"""Stat calculator computing model scores (module ``lm_polygraph.stat_calculators.model_score``)."""

import torch
import logging
import traceback
import numpy as np
from torch.nn.utils.rnn import pad_sequence
from typing import List, Dict, Tuple
from .stat_calculator import StatCalculator
from lm_polygraph.utils.model import WhiteboxModel

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


def _batch_tokens(tokens_list: List[List[int]], model: WhiteboxModel):
    token_tensors = [torch.tensor(t) for t in tokens_list]
    tokens = pad_sequence(
        token_tensors, batch_first=True, padding_value=model.tokenizer.pad_token_id
    )
    attn_mask = tokens != model.tokenizer.pad_token_id
    return {"input_ids": tokens, "attention_mask": attn_mask}


class ModelScoreCalculator(StatCalculator):
    """
    Scores each greedy generation with the model itself, conditioned on a
    prompted reference text.

    For every sample, the reference from ``dependencies["target_texts"]`` is
    inserted into ``prompt`` and tokenized. The model then scores every token of
    the greedy generation given that prompted reference. The result is stored
    under ``"model_rh"`` as one list of per-token scores per sample.
    """

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        """
        Returns the statistics and dependencies for the calculator.
        """
        # NOTE(review): "input_tokens" is declared but not read by __call__,
        # while "target_texts" is read but not declared — presumably it is
        # supplied by the caller directly; confirm before changing this list.
        return ["model_rh"], ["greedy_tokens", "input_tokens"]

    def __init__(self, prompt: str = 'Paraphrase "{}": ', batch_size: int = 10):
        """
        Parameters:
            prompt (str): format string with one ``{}`` placeholder into which
                each reference text is inserted before tokenization.
            batch_size (int): number of samples scored per forward pass.
        """
        super().__init__()
        self.batch_size = batch_size
        self.prompt = prompt

    @staticmethod
    def _gather_target_scores(
        sample_logits: torch.Tensor, first_pos: int, tgt: List[int]
    ) -> List[float]:
        """Return ``sample_logits[first_pos + k, tgt[k]]`` for every k as Python floats."""
        device = sample_logits.device
        positions = torch.arange(first_pos, first_pos + len(tgt), device=device)
        token_ids = torch.tensor(tgt, dtype=torch.long, device=device)
        # One gather + one host transfer per sample instead of one .item() per token.
        return sample_logits[positions, token_ids].tolist()

    def _score(
        self, model: WhiteboxModel, srcs: List[List[int]], tgts: List[List[int]]
    ) -> List[List[float]]:
        """
        Score every target token with the model, conditioned on its source.

        Parameters:
            model: scoring model; ``model.model_type == "CausalLM"`` selects
                decoder-only scoring, any other value encoder-decoder scoring.
            srcs: conditioning token ids per sample (assumed non-empty for
                CausalLM models — TODO confirm with callers).
            tgts: token ids to score per sample.
        Returns:
            For each sample, one float per token of its target: the logit the
            model assigns to that token at the position that predicts it.
            NOTE(review): these are raw logits, not log-probabilities — confirm
            that downstream estimators expect unnormalized scores.
        Raises:
            RuntimeError: re-raised after the failing batch has been logged.
        """
        score_list: List[List[float]] = []
        device = model.device()
        for i in range(0, len(srcs), self.batch_size):
            src_list = srcs[i : i + self.batch_size]
            tgt_list = tgts[i : i + self.batch_size]
            try:
                with torch.no_grad():
                    if model.model_type == "CausalLM":
                        encoded = _batch_tokens(
                            [s + t for s, t in zip(src_list, tgt_list)], model
                        )
                        logits = model(
                            input_ids=encoded["input_ids"].to(device),
                            attention_mask=encoded["attention_mask"].to(device),
                        ).logits
                        # Sequences are right-padded, so sample j's target starts
                        # at position len(src_list[j]); the logit predicting the
                        # token at position p sits at position p - 1.
                        first_positions = [len(s) - 1 for s in src_list]
                    else:
                        encoded_src = _batch_tokens(src_list, model)
                        encoded_tgt = _batch_tokens(tgt_list, model)
                        logits = model(
                            input_ids=encoded_src["input_ids"].to(device),
                            attention_mask=encoded_src["attention_mask"].to(device),
                            labels=encoded_tgt["input_ids"].long().to(device),
                        ).logits
                        # Assumes HF-style seq2seq models, which build decoder
                        # inputs by shifting labels right, so decoder position k
                        # predicts label k.
                        first_positions = [0] * len(tgt_list)
                    for sample_logits, first_pos, tgt in zip(
                        logits, first_positions, tgt_list
                    ):
                        score_list.append(
                            self._gather_target_scores(sample_logits, first_pos, tgt)
                        )
            except RuntimeError:
                # Log the offending batch, then propagate instead of silently
                # terminating the whole process with a success exit code.
                log.exception("Model scoring failed")
                log.error("source: %s", src_list)
                log.error("target: %s", tgt_list)
                raise
        return score_list

    def __call__(
        self,
        dependencies: Dict[str, np.array],
        texts: List[str],
        model: WhiteboxModel,
        max_new_tokens: int = 100,
        **kwargs,
    ) -> Dict[str, np.ndarray]:
        """
        Calculates the model score of each greedy generation given its reference.

        Parameters:
            dependencies: must contain ``"greedy_tokens"`` (token ids of each
                greedy generation) and ``"target_texts"`` (reference text per
                sample).
            texts: input texts (not used).
            model: model used for scoring.
            max_new_tokens: not used.
        Returns:
            Dict with ``"model_rh"``: per-sample lists of per-token scores.
        """
        preds = dependencies["greedy_tokens"]
        prompted_refs = model.tokenizer(
            [self.prompt.format(ref) for ref in dependencies["target_texts"]]
        )["input_ids"]
        return {"model_rh": self._score(model, prompted_refs, preds)}