# Source code for lm_polygraph.stat_calculators.ensemble_token_data

from typing import Dict, List, Tuple
from functools import partial

import torch
import numpy as np
from transformers import PreTrainedModel

from .stat_calculator import StatCalculator
from lm_polygraph.utils.token_restoration import (
    get_collect_fn,
)


class EnsembleTokenLevelDataCalculator(StatCalculator):
    """
    Stat calculator registered for the "ensemble_token_scores" statistic.

    The calculator is currently disabled: calling it always raises
    ``NotImplementedError``. The previous implementation is preserved in
    ``_legacy_ensemble_token_scores`` so it can be ported later; nothing in
    this class calls that helper.
    """

    # Entries taken from the score dict selected by the ensembling mode and
    # converted with ``torch.Tensor`` in the legacy implementation.
    _MODE_SPECIFIC_KEYS = (
        "weights",
        "scores_unbiased",
        "entropy",
        "entropy_top5",
        "entropy_top10",
        "entropy_top15",
    )

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        """
        Returns the statistics and dependencies for the calculator.
        """
        return ["ensemble_token_scores"], []

    def __init__(self):
        super().__init__()

    def __call__(
        self,
        dependencies: Dict[str, np.array],
        texts: List[str],
        model: PreTrainedModel,
        max_new_tokens: int = 100,
    ) -> Dict[str, np.ndarray]:
        """
        Always raises: ensemble-based UE methods are disabled.

        Parameters:
            dependencies: statistics computed by other calculators (unused
                while the calculator is disabled).
            texts: input texts (unused while the calculator is disabled).
            model: model wrapper (unused while the calculator is disabled).
            max_new_tokens: generation length limit (unused while the
                calculator is disabled).

        Raises:
            NotImplementedError: on every call.
        """
        # Ensemble-based UE methods have been disabled due to dependency on old
        # transformers code, which prevents bumping transformers version in
        # dependencies past 4.40.0. This is a temporary solution until the
        # code is updated to work with the latest transformers version.
        raise NotImplementedError(
            "Ensemble UE methods are not working properly in this version. Consider downgrading to 0.3.0"
        )

    def _legacy_ensemble_token_scores(
        self,
        dependencies: Dict[str, np.array],
        texts: List[str],
        model: PreTrainedModel,
        max_new_tokens: int = 100,
    ) -> Dict[str, np.ndarray]:
        """
        Previous body of ``__call__``, kept for reference and not called.

        Reads ``dependencies["ensemble_model"]`` and
        ``dependencies["ensemble_generation_params"]``, generates with the
        ensemble model, and collects token-level scores through
        ``get_collect_fn``.

        Parameters:
            dependencies: must contain "ensemble_model" and
                "ensemble_generation_params".
            texts: input texts, tokenized with ``model.tokenize``.
            model: wrapper providing ``tokenize`` and ``tokenizer``.
            max_new_tokens: fallback for ``max_length`` when
                "generation_max_length" is absent from the generation params.

        Returns:
            ``{"ensemble_token_scores": output_dict}``.

        Raises:
            NotImplementedError: if the ensembling mode is neither "pe"
                nor "ep".
        """
        ensemble_model = dependencies["ensemble_model"]

        batch: Dict[str, torch.Tensor] = model.tokenize(texts)
        batch = {k: v.to(ensemble_model.device()) for k, v in batch.items()}

        generation_params = dependencies["ensemble_generation_params"]
        max_length = generation_params.get("generation_max_length", max_new_tokens)
        min_length = generation_params.get("generation_min_length", 2)
        num_return_sequences = generation_params.get("num_return_sequences", 5)

        model_config = ensemble_model.model.config
        if "mbart" in model_config._name_or_path:
            model_config.decoder_start_token_id = model.tokenizer.lang_code_to_id[
                model.tokenizer.tgt_lang
            ]

        # Default num_beams to the number of returned sequences when neither
        # beams nor sampling were configured. As before, this writes into the
        # params object obtained from ``dependencies``.
        if generation_params.get("num_beams") is None and (
            "do_sample" not in generation_params
            or generation_params["do_sample"] is None
        ):
            generation_params["num_beams"] = num_return_sequences

        # Merge into a single kwargs dict so that a key present in
        # generation_params (e.g. "num_return_sequences") is not passed twice,
        # which would raise TypeError. Explicit values take precedence.
        generate_kwargs = dict(generation_params)
        generate_kwargs.update(
            max_length=max_length,
            min_length=min_length,
            output_scores=True,
            return_dict_in_generate=True,
            num_return_sequences=num_return_sequences,
        )

        with torch.no_grad():
            output = ensemble_model.generate(**batch, **generate_kwargs)

        batch_length = len(batch["input_ids"])

        collect_fn = get_collect_fn(output)
        collect_fn = partial(
            collect_fn,
            output,
            batch_length,
            num_return_sequences,
            ensemble_model.model.config.vocab_size,
            ensemble_model.model.config.pad_token_id,
        )

        pe_token_level_scores = collect_fn(
            ensemble_uncertainties=output["pe_uncertainties"]
        )
        ep_token_level_scores = collect_fn(
            ensemble_uncertainties=output["ep_uncertainties"]
        )

        # "probas" and "log_probas" are always taken from the "pe" scores.
        output_dict = {
            "pe_token_level_scores": pe_token_level_scores,
            "ep_token_level_scores": ep_token_level_scores,
            "probas": pe_token_level_scores["probas"],
            "log_probas": pe_token_level_scores["log_probas"],
        }

        # The remaining entries come from the score dict matching the mode.
        ensembling_mode = ensemble_model.model.ensembling_mode
        if ensembling_mode == "pe":
            mode_scores = pe_token_level_scores
        elif ensembling_mode == "ep":
            mode_scores = ep_token_level_scores
        else:
            raise NotImplementedError

        output_dict.update(
            {key: torch.Tensor(mode_scores[key]) for key in self._MODE_SPECIFIC_KEYS}
        )

        return {"ensemble_token_scores": output_dict}