Source code for lm_polygraph.stat_calculators.greedy_lm_probs

import torch
import numpy as np

from typing import Dict, List, Tuple

from .stat_calculator import StatCalculator
from lm_polygraph.utils.model import WhiteboxModel


class GreedyLMProbsCalculator(StatCalculator):
    """
    Calculates probabilities of the model generations without input texts.
    Used to calculate P(y_t|y_<t) subtrahend in PointwiseMutualInformation.
    """

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        """
        Returns the statistics and dependencies for the calculator.
        """
        return ["greedy_lm_log_probs", "greedy_lm_log_likelihoods"], ["greedy_tokens"]

    def __init__(self):
        super().__init__()

    @staticmethod
    def _forward_log_probs(
        model: WhiteboxModel, batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """Runs the model on `batch` (no gradients) and returns log_softmax of its logits."""
        with torch.no_grad():
            if model.model_type == "Seq2SeqLM":
                # Seq2seq models get the same ids on the decoder side.
                output = model.model(**batch, decoder_input_ids=batch["input_ids"])
            else:
                output = model.model(**batch)
            return output.logits.log_softmax(-1)

    @staticmethod
    def _extract_for_sequence(
        seq_log_probs: torch.Tensor, toks: List[int]
    ) -> Tuple[np.ndarray, List[float]]:
        """
        Cuts the trailing `len(toks)` positions out of one sequence's log-probs.

        Parameters:
            seq_log_probs (torch.Tensor): log-probabilities for one sequence,
                indexed as [position, token id]; must have at least
                `len(toks)` positions (callers check this).
            toks (List[int]): token ids of the generation.
        Returns:
            Tuple[np.ndarray, List[float]]:
                - distributions at the trailing `len(toks)` positions with the
                  very last position dropped (`len(toks) - 1` rows, or 0 rows
                  when `toks` is empty);
                - for every j, the log-probability of `toks[j]` read at the
                  j-th of the trailing `len(toks)` positions.
        """
        n_toks = len(toks)
        # Use an explicit non-negative start: `-n_toks` is 0 for an empty
        # generation, which would select the whole sequence instead of nothing.
        start = seq_log_probs.shape[0] - n_toks
        tail = seq_log_probs[start:]

        distributions = tail[:-1].cpu().float().numpy()

        # NOTE(review): token j is scored with the logits at its own position;
        # for a causal LM those logits usually describe the *next* token.
        # Kept as-is to preserve existing outputs — confirm intended alignment.
        token_ids = torch.as_tensor(toks, dtype=torch.long, device=tail.device)
        positions = torch.arange(n_toks, device=tail.device)
        # Single gather instead of one `.item()` call per token.
        log_likelihoods = tail[positions, token_ids].float().cpu().tolist()
        return distributions, log_likelihoods

    def _batched(
        self, tokens: List[List[int]], model: WhiteboxModel
    ) -> Tuple[List[np.ndarray], List[List[float]]]:
        """
        Scores all generations in one forward pass over their re-tokenized text.

        Raises:
            ValueError: if a re-tokenized sequence is shorter than the
                original token list (decode/encode round trip is not exact).
        """
        batch = model.tokenize([model.tokenizer.decode(t) for t in tokens])
        batch = {k: v.to(model.device()) for k, v in batch.items()}
        logprobs = self._forward_log_probs(model, batch)

        all_distributions = []
        all_log_likelihoods = []
        for seq_log_probs, toks in zip(logprobs, tokens):
            if len(seq_log_probs) < len(toks):
                raise ValueError("tokenizer(tokenizer.decode(t)) != t")
            distributions, log_likelihoods = self._extract_for_sequence(
                seq_log_probs, toks
            )
            all_distributions.append(distributions)
            all_log_likelihoods.append(log_likelihoods)
        return all_distributions, all_log_likelihoods

    def _one_by_one(
        self, tokens: List[List[int]], model: WhiteboxModel
    ) -> Tuple[List[np.ndarray], List[List[float]]]:
        """Scores each generation separately, feeding its token ids directly to the model."""
        all_distributions = []
        all_log_likelihoods = []
        for toks in tokens:
            input_ids = torch.LongTensor([toks]).to(model.device())
            batch = {
                "input_ids": input_ids,
                "attention_mask": torch.ones_like(input_ids).to(model.device()),
            }
            seq_log_probs = self._forward_log_probs(model, batch)[0]
            # Internal invariant: the model was fed exactly `toks`.
            assert len(seq_log_probs) >= len(toks)
            distributions, log_likelihoods = self._extract_for_sequence(
                seq_log_probs, toks
            )
            all_distributions.append(distributions)
            all_log_likelihoods.append(log_likelihoods)
        return all_distributions, all_log_likelihoods

    def __call__(
        self,
        dependencies: Dict[str, np.ndarray],
        texts: List[str],
        model: WhiteboxModel,
        max_new_tokens: int = 100,
        **kwargs,
    ) -> Dict[str, np.ndarray]:
        """
        Calculates log-probabilities of the generated tokens when the model is
        run on the generation alone, i.e. without the input text.

        Parameters:
            dependencies (Dict[str, np.ndarray]): input statistics, consisting of:
                - 'greedy_tokens' (List[List[int]]): tokenized model generations for each input text.
            texts (List[str]): Input texts batch used for model generation.
                Not used by this calculator.
            model (Model): Model used for generation.
            max_new_tokens (int): Maximum number of new tokens at model generation. Default: 100.
                Not used by this calculator.
        Returns:
            Dict[str, np.ndarray]: dictionary with the following items:
                - 'greedy_lm_log_probs' (List[np.ndarray]): for each generation, logarithms of
                    autoregressive probability distributions when generating the generated text
                    without input text (the last position is dropped, so there is one row fewer
                    than there are tokens).
                - 'greedy_lm_log_likelihoods' (List[List[float]]): log-probabilities of generating
                    text without input. P(y_t | y_<t) for all t.
        """
        tokens = dependencies["greedy_tokens"]
        try:
            greedy_lm_log_probs, greedy_lm_ll = self._batched(tokens, model)
        except ValueError:
            # case where tokenizer(tokenizer.decode(t)) != t (or the batched
            # path raised another ValueError); process each sequence separately
            greedy_lm_log_probs, greedy_lm_ll = self._one_by_one(tokens, model)

        return {
            "greedy_lm_log_probs": greedy_lm_log_probs,
            "greedy_lm_log_likelihoods": greedy_lm_ll,
        }