import torch
import logging
import traceback
import numpy as np
from torch.nn.utils.rnn import pad_sequence
from typing import List, Dict, Tuple
from .stat_calculator import StatCalculator
from lm_polygraph.utils.model import WhiteboxModel
# Module-level logger named after this module; used below to report failing batches.
log = logging.getLogger(__name__)
def _batch_tokens(tokens_list: List[List[int]], model: WhiteboxModel):
    """
    Right-pad a list of token-id sequences into a single batch.

    Parameters:
        tokens_list: token-id sequences, possibly of different lengths.
        model: object exposing ``tokenizer.pad_token_id``, which is used as
            the padding value.

    Returns:
        Dict with two tensors of shape ``(len(tokens_list), max_len)``:
            - ``"input_ids"``: the padded token ids (``torch.long``),
            - ``"attention_mask"``: boolean mask, ``True`` on real tokens and
              ``False`` on padding.
    """
    pad_id = model.tokenizer.pad_token_id
    # Force an integer dtype: torch.tensor([]) would otherwise be float32 and
    # make the whole padded batch unusable as token ids.
    token_tensors = [torch.tensor(t, dtype=torch.long) for t in tokens_list]
    tokens = pad_sequence(token_tensors, batch_first=True, padding_value=pad_id)

    # Derive the mask from the true sequence lengths rather than from
    # `tokens != pad_id`: a genuine token that happens to equal the pad id
    # (e.g. a trailing EOS when pad is aliased to EOS) must stay attended.
    lengths = torch.tensor([len(t) for t in tokens_list], dtype=torch.long)
    positions = torch.arange(tokens.shape[1])
    attn_mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
    return {"input_ids": tokens, "attention_mask": attn_mask}
class ModelScoreCalculator(StatCalculator):
    """
    Scores the greedy generation of each sample against a prompted reference
    text, using the raw logits the model assigns to the generated tokens.

    The result is returned under the key ``"model_rh"``.
    """

    def __init__(self, prompt: str = 'Paraphrase "{}": ', batch_size: int = 10):
        """
        Parameters:
            prompt: format string with a single ``{}`` placeholder that is
                filled with each reference text before tokenization.
            batch_size: number of (source, target) pairs sent through the
                model in one forward pass.
        """
        super().__init__()
        self.batch_size = batch_size
        self.prompt = prompt

    def _score(
        self, model: WhiteboxModel, srcs: List[List[int]], tgts: List[List[int]]
    ) -> List[List[float]]:
        """
        Return, for every (source, target) pair, the logit the model assigns
        to each target token.

        Parameters:
            model: model wrapper; must be callable with ``input_ids`` /
                ``attention_mask`` (and ``labels`` for non-causal models),
                return an object with ``.logits`` indexed as
                (sample, position, vocabulary), and expose ``model_type``,
                ``device()`` and ``tokenizer``.
            srcs: token ids of the conditioning sequences. For ``"CausalLM"``
                models each source is expected to be non-empty, since the
                first target token is read from the last source position.
            tgts: token ids of the sequences to score, aligned with ``srcs``.

        Returns:
            One list per pair, holding one raw (un-normalized) logit per
            target token.

        Raises:
            RuntimeError: re-raised from the forward pass after the offending
                batch has been logged.
        """
        score_list = []
        is_causal = model.model_type == "CausalLM"
        for i in range(0, len(srcs), self.batch_size):
            src_list = srcs[i : i + self.batch_size]
            tgt_list = tgts[i : i + self.batch_size]
            try:
                with torch.no_grad():
                    device = model.device()
                    if is_causal:
                        # Decoder-only: feed source and target as one
                        # right-padded sequence.
                        encoded = _batch_tokens(
                            [s + t for s, t in zip(src_list, tgt_list)], model
                        )
                        logits = model(
                            input_ids=encoded["input_ids"].to(device),
                            attention_mask=encoded["attention_mask"].to(device),
                        ).logits
                        # Assumes the usual causal-LM alignment (logits at
                        # position p predict the token at p + 1), so target
                        # token k of a sample is predicted at
                        # len(source) - 1 + k.
                        first_positions = [len(s) - 1 for s in src_list]
                    else:
                        encoded_src = _batch_tokens(src_list, model)
                        encoded_tgt = _batch_tokens(tgt_list, model)
                        logits = model(
                            input_ids=encoded_src["input_ids"].to(device),
                            attention_mask=encoded_src["attention_mask"].to(device),
                            labels=encoded_tgt["input_ids"].long().to(device),
                        ).logits
                        # Assumes the Hugging Face seq2seq convention that,
                        # when `labels` is passed, decoder logits at position
                        # k are aligned with labels[k].
                        first_positions = [0] * len(src_list)

                    # Positions are computed per sample; the batch size plays
                    # no role in where a token's logit lives.
                    for sample_logits, tgt, first in zip(
                        logits, tgt_list, first_positions
                    ):
                        score_list.append(
                            [
                                sample_logits[first + k, token].item()
                                for k, token in enumerate(tgt)
                            ]
                        )
            except RuntimeError:
                # Record which batch broke the model, then let the caller
                # decide what to do instead of silently exiting with status 0.
                log.exception("Model scoring failed for a batch")
                log.error(f"source: {src_list}")
                log.error(f"target: {tgt_list}")
                raise
        return score_list

    def __call__(
        self,
        dependencies: Dict[str, np.ndarray],
        texts: List[str],
        model: WhiteboxModel,
        max_new_tokens: int = 100,
        **kwargs,
    ) -> Dict[str, List[List[float]]]:
        """
        Compute the reference-conditioned score of the greedy generations.

        Parameters:
            dependencies: must contain ``"greedy_tokens"`` (token ids of the
                generations to score) and ``"target_texts"`` (reference texts
                inserted into the prompt).
            texts: input texts; not used by this calculator.
            model: model wrapper used for tokenization and scoring.
            max_new_tokens: not used by this calculator.

        Returns:
            Dict with the single key ``"model_rh"``, mapping to one list of
            per-token scores for each sample.
        """
        preds = dependencies["greedy_tokens"]
        prompted_refs = model.tokenizer(
            [self.prompt.format(s) for s in dependencies["target_texts"]]
        )["input_ids"]
        return {"model_rh": self._score(model, prompted_refs, preds)}