Source code for lm_polygraph.utils.ensemble_utils.ensemble_generator

import inspect
from typing import Optional, Dict, Any, List, Tuple

import torch
from transformers import GenerationMixin
from transformers.generation.utils import ModelOutput

from lm_polygraph.utils.ensemble_utils.ensemble_beam import EnsembleBeamSearchMixin
from lm_polygraph.utils.ensemble_utils.ensemble_sample import EnsembleSampleMixin
from lm_polygraph.utils.ensemble_utils.ensemble_greedy import EnsembleGreedyMixin


[docs]class EnsembleGenerationMixin( EnsembleBeamSearchMixin, EnsembleSampleMixin, EnsembleGreedyMixin, GenerationMixin ):
[docs] def add_ensemble_models(self, models, devices): self._models_list = list(models)
@property def tokenizer(self): if hasattr(self, "_tokenizer"): return self._tokenizer return None @tokenizer.setter def tokenizer(self, value=None): self._tokenizer = value @property def models(self): return [self] + self._models_list @property def ensembling_mode(self): if self._ensembling_mode is not None: return self._ensembling_mode return "pe" @ensembling_mode.setter def ensembling_mode(self, value="pe"): self._ensembling_mode = value @property def mc(self): if hasattr(self, "_mc") and self._mc is not None: return self._mc return False @mc.setter def mc(self, value=False): self._mc = value @property def mc_models_num(self): if hasattr(self, "_mc_models_num"): return self._mc_models_num return 1 @mc_models_num.setter def mc_models_num(self, num=1): self._mc_models_num = num @property def base_seed(self): return self._base_seed @base_seed.setter def base_seed(self, seed=42): self._base_seed = seed @property def mc_seeds(self): return self._mc_seeds @mc_seeds.setter def mc_seeds(self, seeds=[]): self._mc_seeds = seeds @property def models_beam_logits_iter(self): if hasattr(self, "_models_beam_logits_iter"): return self._models_beam_logits_iter @models_beam_logits_iter.setter def models_beam_logits_iter(self, value): self._models_beam_logits_iter = value
[docs] def calculate_entropy_based_measures(self, enable=True): self.calculate_entropies = enable
def _prepare_encoder_decoder_kwargs_for_generation( self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None, ) -> Dict[str, Any]: if getattr(self, "models", None) is None: self._models_list = [] if self.mc: # 1. get encoders encoder = self.get_encoder() # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device # as the inputs. if hasattr(encoder, "_hf_hook"): encoder._hf_hook.io_same_device = True # 2. prepare encoder args and encoder kwargs from model kwargs irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] encoder_kwargs = { argument: value for argument, value in model_kwargs.items() if not any(argument.startswith(p) for p in irrelevant_prefix) } encoder_signature = set(inspect.signature(encoder.forward).parameters) encoder_accepts_wildcard = ( "kwargs" in encoder_signature or "model_kwargs" in encoder_signature ) if not encoder_accepts_wildcard: encoder_kwargs = { argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature } # 3. make sure that encoder returns `ModelOutput` model_input_name = ( model_input_name if model_input_name is not None else self.main_input_name ) encoder_kwargs["return_dict"] = True encoder_kwargs[model_input_name] = inputs_tensor outs = [] for i in range(self.mc_models_num): torch.manual_seed(self.mc_seeds[i]) outs.append(encoder(**encoder_kwargs)) torch.manual_seed(self.base_seed) model_kwargs["encoder_outputs"]: List[ModelOutput] = outs else: model_kwargs["encoder_outputs"]: List[ModelOutput] = [] for model in self.models: # 1. get encoder encoder = self.get_encoder() # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device # as the inputs. if hasattr(encoder, "_hf_hook"): encoder._hf_hook.io_same_device = True # 2. prepare encoder args and encoder kwargs from model kwargs irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] encoder_kwargs = { argument: value for argument, value in model_kwargs.items() if not any(argument.startswith(p) for p in irrelevant_prefix) } encoder_signature = set(inspect.signature(encoder.forward).parameters) encoder_accepts_wildcard = ( "kwargs" in encoder_signature or "model_kwargs" in encoder_signature ) if not encoder_accepts_wildcard: encoder_kwargs = { argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature } # 3. make sure that encoder returns `ModelOutput` model_input_name = ( model_input_name if model_input_name is not None else self.main_input_name ) encoder_kwargs["return_dict"] = True encoder_kwargs[model_input_name] = inputs_tensor encoder_kwargs[model_input_name].to(model.device) encoder_kwargs = { k: v.to(model.device) for k, v in encoder_kwargs.items() if hasattr(v, "to") } model_kwargs["encoder_outputs"].append(encoder(**encoder_kwargs)) return model_kwargs @staticmethod def _expand_inputs_for_generation( expand_size: int = 1, is_encoder_decoder: bool = False, input_ids: Optional[torch.LongTensor] = None, **model_kwargs, ) -> Tuple[torch.LongTensor, Dict[str, Any]]: """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]""" def _expand_dict_for_generation(dict_to_expand): for key in dict_to_expand: if dict_to_expand[key] is not None and isinstance( dict_to_expand[key], torch.Tensor ): dict_to_expand[key] = dict_to_expand[key].repeat_interleave( expand_size, dim=0 ) return dict_to_expand if input_ids is not None: input_ids = input_ids.repeat_interleave(expand_size, dim=0) model_kwargs = _expand_dict_for_generation(model_kwargs) if is_encoder_decoder: if model_kwargs.get("encoder_outputs") is None: raise ValueError( "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined." ) encoder_outputs_expanded = [] for output in model_kwargs["encoder_outputs"]: encoder_outputs_expanded.append(_expand_dict_for_generation(output)) model_kwargs["encoder_outputs"] = encoder_outputs_expanded return input_ids, model_kwargs