Source code for lm_polygraph.stat_calculators.greedy_probs

import torch
import numpy as np

from typing import Dict, List

from .embeddings import get_embeddings_from_output
from .stat_calculator import StatCalculator
from lm_polygraph.utils.model import WhiteboxModel, BlackboxModel


class BlackboxGreedyTextsCalculator(StatCalculator):
    """
    Produces generated texts for a Blackbox model (lm_polygraph.BlackboxModel).

    The only statistic returned is stored under the 'greedy_texts' key.
    """

    def __init__(self):
        # First list: names of the statistics this calculator returns.
        # Second list: presumably the statistics it depends on (none here) --
        # confirm against StatCalculator.__init__.
        stat_names = ["greedy_texts"]
        dependency_names = []
        super().__init__(stat_names, dependency_names)

    def __call__(
        self,
        dependencies: Dict[str, np.array],
        texts: List[str],
        model: BlackboxModel,
        max_new_tokens: int = 100,
    ) -> Dict[str, np.ndarray]:
        """
        Generates one text per input with the given Blackbox model.

        Parameters:
            dependencies (Dict[str, np.ndarray]): input statistics, can be empty (not used).
            texts (List[str]): Input texts batch used for model generation.
            model (Model): Model used for generation.
            max_new_tokens (int): Maximum number of new tokens at model generation. Default: 100.
        Returns:
            Dict[str, np.ndarray]: dictionary holding, at the 'greedy_texts' key,
                whatever `model.generate_texts` returns for the batch
                (presumably one generated string per input text -- confirm
                against BlackboxModel.generate_texts).
        """
        generation_kwargs = {
            "input_texts": texts,
            "max_new_tokens": max_new_tokens,
            # A single generation is requested per input.
            "n": 1,
        }
        with torch.no_grad():
            generated = model.generate_texts(**generation_kwargs)
        return {"greedy_texts": generated}
class GreedyProbsCalculator(StatCalculator):
    """
    For Whitebox model (lm_polygraph.WhiteboxModel), at input texts batch calculates:
    * generation texts
    * tokens of the generation texts
    * per-step score distributions of the generated tokens
    * the highest-scoring alternative tokens at every generation step
    * embeddings from the model
    """

    def __init__(self, n_alternatives: int = 10):
        """
        Parameters:
            n_alternatives (int): Number of highest-scoring tokens kept per
                generation step in 'greedy_tokens_alternatives'. Default: 10.
        """
        # NOTE(review): "train_greedy_log_likelihoods" and "embeddings" are
        # declared here, but __call__ returns neither key (it returns
        # "embeddings_decoder" / "embeddings_encoder" instead). Left unchanged
        # because how StatCalculator uses these names is not visible here.
        super().__init__(
            [
                "input_tokens",
                "greedy_log_probs",
                "greedy_tokens",
                "greedy_tokens_alternatives",
                "greedy_texts",
                "greedy_log_likelihoods",
                "train_greedy_log_likelihoods",
                "embeddings",
            ],
            [],
        )
        self.n_alternatives = n_alternatives

    def _top_alternatives(self, step_scores: np.ndarray, chosen_token: int) -> List[tuple]:
        """Return (token_id, score) pairs for the top-scoring tokens of one step.

        Parameters:
            step_scores (np.ndarray): 1-D array of scores over the vocabulary.
            chosen_token (int): token actually generated at this step; if it is
                among the selected tokens it is moved to the front.
        Returns:
            List[tuple]: at most `self.n_alternatives` pairs. Apart from the
                chosen token being first, the order is the (unsorted) order
                produced by `np.argpartition`.
        """
        vocab_size = len(step_scores)
        # Clamp so that asking for more alternatives than there are tokens
        # returns the whole vocabulary instead of raising inside argpartition.
        k = min(self.n_alternatives, vocab_size)
        if k <= 0:
            return []
        top_ids = np.argpartition(step_scores, -k)[vocab_size - k :]
        alternatives = [(t.item(), step_scores[t].item()) for t in top_ids]
        # Stable sort: only the generated token (if present) moves to the front.
        alternatives.sort(key=lambda pair: pair[0] == chosen_token, reverse=True)
        return alternatives

    def __call__(
        self,
        dependencies: Dict[str, np.ndarray],
        texts: List[str],
        model: WhiteboxModel,
        max_new_tokens: int = 100,
    ) -> Dict[str, np.ndarray]:
        """
        Calculates the statistics of probabilities at each token position in the generation.

        Parameters:
            dependencies (Dict[str, np.ndarray]): input statistics, can be empty (not used).
            texts (List[str]): Input texts batch used for model generation.
            model (Model): Model used for generation.
            max_new_tokens (int): Maximum number of new tokens at model generation. Default: 100.
        Returns:
            Dict[str, np.ndarray]: dictionary with the following items:
                - 'input_tokens' (List[List[int]]): tokenized input texts,
                - 'greedy_log_probs' (List[np.ndarray]): for each input, the stacked
                    per-step scores returned by `model.generate`, cut after the first
                    EOS token (presumably log-probabilities -- confirm that
                    `model.generate` normalizes its scores),
                - 'greedy_tokens' (List[List[int]]): tokenized model generations,
                    including the EOS token when one was generated,
                - 'greedy_tokens_alternatives' (List[List[List[tuple]]]): for each
                    generated token, up to `n_alternatives` (token_id, score) pairs
                    with the generated token first when it is among them,
                - 'greedy_texts' (List[str]): model generations corresponding to the
                    inputs, decoded without the EOS token,
                - 'greedy_log_likelihoods' (List[List[float]]): score of each generated token,
                - 'embeddings_decoder': decoder embeddings from `get_embeddings_from_output`,
                - 'embeddings_encoder': encoder embeddings (only for "Seq2SeqLM" models).
        Raises:
            NotImplementedError: if `model.model_type` is neither "CausalLM" nor "Seq2SeqLM".
        """
        batch: Dict[str, torch.Tensor] = model.tokenize(texts)
        batch = {k: v.to(model.device()) for k, v in batch.items()}

        # Unless newlines are allowed, forbid every token whose decoded text
        # contains a newline character.
        if model.generation_parameters.allow_newlines:
            suppress_tokens = []
        else:
            suppress_tokens = [
                t
                for t in range(len(model.tokenizer))
                if "\n" in model.tokenizer.decode([t])
            ]

        with torch.no_grad():
            out = model.generate(
                **batch,
                output_scores=True,
                return_dict_in_generate=True,
                max_new_tokens=max_new_tokens,
                min_new_tokens=2,
                output_attentions=False,
                output_hidden_states=True,
                num_return_sequences=1,
                suppress_tokens=suppress_tokens,
            )
            logits = torch.stack(out.scores, dim=1)
            sequences = out.sequences
            embeddings_encoder, embeddings_decoder = get_embeddings_from_output(
                out, batch, model.model_type
            )

        cut_logits = []
        cut_sequences = []
        cut_texts = []
        cut_alternatives = []
        for i in range(len(texts)):
            if model.model_type == "CausalLM":
                # Generated sequences start with the prompt; skip it.
                idx = batch["input_ids"].shape[1]
                seq = sequences[i, idx:].cpu()
            else:
                # Skip the first token of the output
                # (presumably the decoder start token -- confirm).
                seq = sequences[i, 1:].cpu()

            # Tokens are kept up to and including the first EOS; the decoded
            # text stops just before it.
            length, text_length = len(seq), len(seq)
            for j in range(len(seq)):
                if seq[j] == model.tokenizer.eos_token_id:
                    length = j + 1
                    text_length = j
                    break

            tokens = seq[:length].tolist()
            # One device-to-host copy per sample; its rows are reused below.
            sample_scores = logits[i, :length, :].cpu().numpy()

            cut_sequences.append(tokens)
            cut_texts.append(model.tokenizer.decode(seq[:text_length]))
            cut_logits.append(sample_scores)
            cut_alternatives.append(
                [self._top_alternatives(sample_scores[j], tokens[j]) for j in range(length)]
            )

        ll = []
        for log_probs, tokens in zip(cut_logits, cut_sequences):
            # Internal invariant: one score row per generated token.
            assert len(tokens) == len(log_probs)
            ll.append([step[token] for step, token in zip(log_probs, tokens)])

        if model.model_type == "CausalLM":
            embeddings_dict = {
                "embeddings_decoder": embeddings_decoder,
            }
        elif model.model_type == "Seq2SeqLM":
            embeddings_dict = {
                "embeddings_encoder": embeddings_encoder,
                "embeddings_decoder": embeddings_decoder,
            }
        else:
            raise NotImplementedError

        result_dict = {
            "input_tokens": batch["input_ids"].to("cpu").tolist(),
            "greedy_log_probs": cut_logits,
            "greedy_tokens": cut_sequences,
            "greedy_tokens_alternatives": cut_alternatives,
            "greedy_texts": cut_texts,
            "greedy_log_likelihoods": ll,
        }
        result_dict.update(embeddings_dict)

        return result_dict