Source code for lm_polygraph.stat_calculators.greedy_probs

import torch
import numpy as np

from typing import Dict, List, Tuple, Union

from .embeddings import get_embeddings_from_output
from .stat_calculator import StatCalculator
from lm_polygraph.model_adapters import WhiteboxModel, WhiteboxModelvLLM


[docs]class GreedyProbsCalculator(StatCalculator): """ For Whitebox model (lm_polygraph.WhiteboxModel), at input texts batch calculates: * generation texts * tokens of the generation texts * probabilities distribution of the generated tokens * attention masks across the model (if applicable) * embeddings from the model """
[docs] @staticmethod def meta_info() -> Tuple[List[str], List[str]]: """ Returns the statistics and dependencies for the calculator. """ return [ "input_texts", "input_tokens", "greedy_log_probs", "greedy_tokens", "greedy_tokens_alternatives", "greedy_texts", "greedy_log_likelihoods", "embeddings", "attention_all", "tokenizer", ], []
def __init__( self, output_attentions: bool = True, output_hidden_states: bool = False, n_alternatives: int = 10, ): super().__init__() self.output_attentions = output_attentions self.output_hidden_states = output_hidden_states self.n_alternatives = n_alternatives def _preprocess_attention( self, attentions: torch.Tensor, current_idx: int, start_idx: int, end_idx: int, prompt_len: int, ) -> torch.Tensor: """ Preprocess attention weights before stacking. Parameters: attentions (torch.Tensor): Attention weights from a specific layer and head for a current token current_idx (int): Current position in the sequence start_idx (int): Start index of the generated tokens end_idx (int): End index of the generated tokens for current position prompt_len (int): Length of the prompt Returns: torch.Tensor: Preprocessed attention weights """ # Handle attention tensor processing for models with varying attention sizes (e.g. Gemma) n_attentions = attentions.shape[-1] # Handle empty tensor case if attentions.nelement() == 0: return torch.zeros(abs(current_idx), device=attentions.device) # Handle cases where attention size is smaller than expected if n_attentions < end_idx: if start_idx < 0: return attentions[start_idx:n_attentions] return attentions[n_attentions - current_idx : n_attentions] # Handle cases where attention spans beyond expected range if (n_attentions - current_idx) > end_idx and start_idx < 0: return attentions[prompt_len : prompt_len + current_idx] # Default case: return attention slice within expected range return attentions[start_idx:end_idx] def __call__( self, dependencies: Dict[str, np.array], texts: List[str], model: Union[WhiteboxModel, WhiteboxModelvLLM], max_new_tokens: int = 100, ) -> Dict[str, np.ndarray]: """ Calculates the statistics of probabilities at each token position in the generation. Parameters: dependencies (Dict[str, np.ndarray]): input statistics, can be empty (not used). texts (List[str]): Input texts batch used for model generation. model (Model): Model used for generation. max_new_tokens (int): Maximum number of new tokens at model generation. Default: 100. Returns: Dict[str, np.ndarray]: dictionary with the following items: - 'input_tokens' (List[List[int]]): tokenized input texts, - 'greedy_log_probs' (List[List[np.array]]): logarithms of autoregressive probability distributions at each token, - 'greedy_texts' (List[str]): model generations corresponding to the inputs, - 'greedy_tokens' (List[List[int]]): tokenized model generations, - 'attention' (List[List[np.array]]): attention maps at each token, if applicable to the model, - 'greedy_log_likelihoods' (List[List[float]]): log-probabilities of the generated tokens. """ batch: Dict[str, torch.Tensor] = model.tokenize(texts) batch = {k: v.to(model.device()) for k, v in batch.items()} with torch.no_grad(): out = model.generate( **batch, output_scores=True, return_dict_in_generate=True, max_new_tokens=max_new_tokens, min_new_tokens=2, output_attentions=self.output_attentions, output_hidden_states=self.output_hidden_states, num_return_sequences=1, suppress_tokens=( [] if model.generation_parameters.allow_newlines else [ t for t in range(len(model.tokenizer)) if "\n" in model.tokenizer.decode([t]) ] ), ) logits = torch.stack(out.scores, dim=1) if model.model_type == "vLLMCausalLM": logits = logits.transpose(1, 0) sequences = out.sequences if self.output_attentions: attentions = out.attentions if self.output_hidden_states: embeddings_encoder, embeddings_decoder = get_embeddings_from_output( out, batch, model.model_type ) if embeddings_decoder.dtype == torch.bfloat16: embeddings_decoder = embeddings_decoder.to( torch.float16 ) # numpy does not support bfloat16 cut_logits = [] cut_sequences = [] cut_texts = [] cut_alternatives = [] for i in range(len(texts)): if model.model_type == "CausalLM": idx = batch["input_ids"].shape[1] seq = sequences[i, idx:].cpu() elif model.model_type == "vLLMCausalLM": seq = sequences[i].cpu() else: seq = sequences[i, 1:].cpu() length, text_length = len(seq), len(seq) for j in range(len(seq)): if seq[j] == model.tokenizer.eos_token_id: length = j + 1 text_length = j break cut_sequences.append(seq[:length].tolist()) cut_texts.append(model.tokenizer.decode(seq[:text_length])) cut_logits.append(logits[i, :length, :].cpu().numpy()) cut_alternatives.append([[] for _ in range(length)]) for j in range(length): lt = logits[i, j, :].cpu().numpy() best_tokens = np.argpartition(lt, -self.n_alternatives) ln = len(best_tokens) best_tokens = best_tokens[ln - self.n_alternatives : ln] for t in best_tokens: cut_alternatives[-1][j].append((t.item(), lt[t].item())) cut_alternatives[-1][j].sort( key=lambda x: x[0] == cut_sequences[-1][j], reverse=True, ) ll = [] for i in range(len(texts)): log_probs = cut_logits[i] tokens = cut_sequences[i] assert len(tokens) == len(log_probs) ll.append([log_probs[j, tokens[j]] for j in range(len(log_probs))]) attention_all = [] if self.output_attentions and (model.model_type != "vLLMCausalLM"): prompt_len = batch["input_ids"].shape[1] for i in range(len(texts)): c = len(cut_sequences[i]) _cfg = getattr(model.model.config, "text_config", model.model.config) attn_mask = np.zeros( shape=( _cfg.num_attention_heads * _cfg.num_hidden_layers, c, c, ) ) for j in range(1, c): # Get attention dimensions current_attention_len = attentions[j][0].shape[-1] # Default case: use relative indexing from end start_idx = -j end_idx = current_attention_len # Special case for models like Gemma that maintain consistent attention lengths if attentions[0][0].shape[-1] == current_attention_len: start_idx = prompt_len end_idx = prompt_len + j stacked_attention = torch.vstack( [ self._preprocess_attention( attentions[j][layer][0][head][0], j, start_idx, end_idx, prompt_len, ) for layer in range(len(attentions[j])) for head in range(len(attentions[j][layer][0])) ] ) if stacked_attention.dtype == torch.bfloat16: stacked_attention = stacked_attention.to( torch.float16 ) # numpy does not support bfloat16 attn_mask[:, j, :j] = stacked_attention.cpu().numpy() attention_all.append(attn_mask) if not self.output_hidden_states: embeddings_dict = {} elif model.model_type == "CausalLM": embeddings_dict = { "embeddings_decoder": embeddings_decoder.cpu().detach().numpy(), } elif model.model_type == "Seq2SeqLM": embeddings_dict = { "embeddings_encoder": embeddings_encoder.cpu().detach().numpy(), "embeddings_decoder": embeddings_decoder.cpu().detach().numpy(), } else: raise NotImplementedError result_dict = { "input_tokens": batch["input_ids"].to("cpu").tolist(), "greedy_log_probs": cut_logits, "greedy_tokens": cut_sequences, "greedy_tokens_alternatives": cut_alternatives, "greedy_texts": cut_texts, "greedy_log_likelihoods": ll, } result_dict.update(embeddings_dict) if self.output_attentions: result_dict.update({"attention_all": attention_all}) result_dict.update({"tokenizer": model.tokenizer}) return result_dict