# Source code for lm_polygraph.utils.ensemble_utils.ensemble_beam

import warnings
from dataclasses import dataclass
from typing import Optional, Union, Dict, List, Tuple

import torch
import torch.distributed as dist
from torch.distributions.categorical import Categorical
from torch import nn

from transformers import GenerationMixin

try:
    from transformers.generation.beam_search import BeamScorer
except ImportError:
    # transformers >= 5.0 removed BeamScorer entirely
    BeamScorer = None

from transformers.generation.logits_process import (
    LogitsProcessorList,
)
from transformers.generation.stopping_criteria import (
    StoppingCriteriaList,
    validate_stopping_criteria,
)
from transformers.generation.utils import ModelOutput

try:
    from transformers.generation.utils import (
        BeamSearchOutput,
        BeamSearchDecoderOnlyOutput,
    )
except ImportError:
    # transformers >= 5.0 renamed these classes
    from transformers.generation.utils import (
        GenerateBeamEncoderDecoderOutput as BeamSearchOutput,
        GenerateBeamDecoderOnlyOutput as BeamSearchDecoderOnlyOutput,
    )


[docs]class EnsembleBeamSearchMixin(GenerationMixin):
# NOTE(review): judging by the class header that immediately precedes it,
# this dataclass belongs inside EnsembleBeamSearchMixin; indent it one level
# when restoring this module to runnable form.
@dataclass
class BeamSearchEncoderDecoderOutput(ModelOutput):
    """Output of ensemble beam search for encoder-decoder generation models.

    Attention weights and hidden states of the encoder are exposed through the
    ``encoder_attentions`` and ``encoder_hidden_states`` attributes; those of
    the decoder through ``decoder_attentions`` and ``decoder_hidden_states``
    (``cross_attentions`` is documented below).

    Besides the standard beam-search outputs, this container carries
    per-ensemble-member scores/logits and ensemble uncertainty values.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is
            either equal to `max_length` or shorter if all batches finished
            early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam transition scores for each vocabulary token at each generation
            step, consisting of log probabilities of tokens conditioned on the
            log softmax of previously generated tokens in this beam.
            `(max_length-1,)`-shaped tuple of `torch.FloatTensor`, each tensor
            of shape `(batch_size*num_beams, config.vocab_size)`.
        models_scores (`tuple(list(torch.FloatTensor))`, *optional*):
            Per-ensemble-member counterpart of `scores`. Presumably one tuple
            entry per generation step, each a list holding one tensor per
            ensemble model. TODO confirm against the code that fills it.
        models_beam_next_token_logits (`tuple(torch.FloatTensor)`, *optional*):
            Next-token logits produced by the ensemble members for the current
            beams. Exact layout is not established here; TODO confirm.
        pe_uncertainties (`dict[str, list(torch.FloatTensor)]`, *optional*):
            Uncertainty values, presumably keyed by uncertainty-measure name.
            "pe" presumably denotes the product-of-expectations ensemble
            combination. TODO confirm.
        ep_uncertainties (`dict[str, list(torch.FloatTensor)]`, *optional*):
            Same structure as `pe_uncertainties`; "ep" presumably denotes the
            expectation-of-products ensemble combination. TODO confirm.
        beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of the generated token id at each generation step,
            of shape `(batch_size*num_return_sequences, max_length-1)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the encoder) of
            shape `(batch_size, num_heads, sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings
            + one for the output of each layer) of shape
            `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element
            for each layer of the decoder) of `torch.FloatTensor` of shape
            `(batch_size*num_beams*num_return_sequences, num_heads, generated_length, sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element
            for each layer of the decoder) of `torch.FloatTensor` of shape
            `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element
            for each layer of the decoder) of `torch.FloatTensor` of shape
            `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
    """

    # Every field defaults to None, so `sequences` is annotated Optional too.
    sequences: Optional[torch.LongTensor] = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    # Ensemble-specific outputs (see docstring; semantics partly unconfirmed).
    models_scores: Optional[Tuple[List[torch.FloatTensor]]] = None
    models_beam_next_token_logits: Optional[Tuple[torch.FloatTensor]] = None
    pe_uncertainties: Optional[Dict[str, List[torch.FloatTensor]]] = None
    ep_uncertainties: Optional[Dict[str, List[torch.FloatTensor]]] = None
    beam_indices: Optional[torch.LongTensor] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None