from typing import Dict, List, Tuple
from functools import partial
import torch
import numpy as np
from transformers import PreTrainedModel
from .stat_calculator import StatCalculator
from lm_polygraph.utils.token_restoration import (
get_collect_fn,
)
class EnsembleTokenLevelDataCalculator(StatCalculator):
    """Calculator for the ``"ensemble_token_scores"`` statistic.

    The calculator is currently disabled: ``__call__`` always raises
    ``NotImplementedError``. The previous implementation is kept in
    ``_compute_ensemble_token_scores`` so it can be re-enabled once it has
    been ported to a newer ``transformers`` version.
    """

    # Per-mode score entries copied (as tensors) into the output dict, in the
    # order they were originally emitted.
    _MODE_SCORE_KEYS = (
        "weights",
        "scores_unbiased",
        "entropy",
        "entropy_top5",
        "entropy_top10",
        "entropy_top15",
    )

    def __init__(self):
        super().__init__()

    def __call__(
        self,
        dependencies: Dict[str, np.array],
        texts: List[str],
        model: PreTrainedModel,
        max_new_tokens: int = 100,
    ) -> Dict[str, np.ndarray]:
        """Compute ensemble token-level scores (currently disabled).

        Parameters:
            dependencies: Previously computed statistics. The disabled
                implementation reads ``"ensemble_model"`` and
                ``"ensemble_generation_params"`` from it.
            texts: Input texts to generate from.
            model: Model wrapper whose ``tokenize`` method and ``tokenizer``
                are used by the disabled implementation.
            max_new_tokens: Fallback generation length used when
                ``"generation_max_length"`` is not configured.

        Raises:
            NotImplementedError: Always. The calculator is disabled.
        """
        # Ensemble-based UE methods have been disabled due to dependency on old
        # transformers code, which prevents bumping transformers version in
        # dependencies past 4.40.0. This is a temporary solution until the
        # code is updated to work with the latest transformers version.
        # To re-enable, replace this raise with:
        #     return self._compute_ensemble_token_scores(
        #         dependencies, texts, model, max_new_tokens
        #     )
        raise NotImplementedError(
            "Ensemble UE methods are not working properly in this version. Consider downgrading to 0.3.0"
        )

    def _compute_ensemble_token_scores(
        self,
        dependencies: Dict[str, np.array],
        texts: List[str],
        model: PreTrainedModel,
        max_new_tokens: int = 100,
    ) -> Dict[str, Dict]:
        """Run ensemble generation and collect per-token uncertainty scores.

        This is the implementation that ``__call__`` used before it was
        disabled. It is not called anywhere at the moment.

        Returns:
            ``{"ensemble_token_scores": output_dict}``. ``output_dict``
            holds the raw ``pe``/``ep`` collected score objects, the
            ``probas``/``log_probas`` taken from the ``pe`` scores, and the
            entries in ``_MODE_SCORE_KEYS`` (converted with ``torch.Tensor``)
            taken from the scores of the model's ``ensembling_mode``.

        Raises:
            NotImplementedError: If ``ensembling_mode`` is neither ``"pe"``
                nor ``"ep"``.
        """
        ensemble_model = dependencies["ensemble_model"]
        batch: Dict[str, torch.Tensor] = model.tokenize(texts)
        # NOTE(review): `device` is called as a method here, whereas Hugging Face
        # models expose it as a property. Confirm against the ensemble wrapper.
        batch = {k: v.to(ensemble_model.device()) for k, v in batch.items()}

        # Work on a copy so the shared dict in `dependencies` is never mutated.
        generation_params = dict(dependencies["ensemble_generation_params"])
        # Pop (rather than get) the keys consumed here so they are not forwarded
        # again through **generation_params. A leftover "num_return_sequences"
        # would otherwise collide with the explicit keyword below (TypeError).
        max_length = generation_params.pop("generation_max_length", max_new_tokens)
        min_length = generation_params.pop("generation_min_length", 2)
        num_return_sequences = generation_params.pop("num_return_sequences", 5)

        model_config = ensemble_model.model.config
        if "mbart" in model_config._name_or_path:
            # For mBART checkpoints, start decoding from the tokenizer's
            # target-language code token.
            model_config.decoder_start_token_id = model.tokenizer.lang_code_to_id[
                model.tokenizer.tgt_lang
            ]

        # Unless beams or sampling were configured explicitly, run beam search
        # with one beam per requested return sequence.
        if (
            generation_params.get("num_beams") is None
            and generation_params.get("do_sample") is None
        ):
            generation_params["num_beams"] = num_return_sequences

        with torch.no_grad():
            output = ensemble_model.generate(
                **batch,
                max_length=max_length,
                min_length=min_length,
                output_scores=True,
                return_dict_in_generate=True,
                num_return_sequences=num_return_sequences,
                **generation_params,
            )

        # Bind the arguments shared by both collections. Only the uncertainty
        # source differs between the pe and ep passes.
        collect_fn = partial(
            get_collect_fn(output),
            output,
            len(batch["input_ids"]),
            num_return_sequences,
            model_config.vocab_size,
            model_config.pad_token_id,
        )
        pe_token_level_scores = collect_fn(
            ensemble_uncertainties=output["pe_uncertainties"]
        )
        ep_token_level_scores = collect_fn(
            ensemble_uncertainties=output["ep_uncertainties"]
        )

        output_dict = {
            "pe_token_level_scores": pe_token_level_scores,
            "ep_token_level_scores": ep_token_level_scores,
            "probas": pe_token_level_scores["probas"],
            "log_probas": pe_token_level_scores["log_probas"],
        }

        scores_by_mode = {"pe": pe_token_level_scores, "ep": ep_token_level_scores}
        mode = ensemble_model.model.ensembling_mode
        if mode not in scores_by_mode:
            raise NotImplementedError(f"Unsupported ensembling mode: {mode!r}")
        mode_scores = scores_by_mode[mode]
        output_dict.update(
            {key: torch.Tensor(mode_scores[key]) for key in self._MODE_SCORE_KEYS}
        )

        return {"ensemble_token_scores": output_dict}