import warnings
from dataclasses import dataclass
from typing import Optional, Union, Dict, List, Tuple
import torch
import torch.distributed as dist
from torch.distributions.categorical import Categorical
from torch import nn
from transformers import GenerationMixin
try:
from transformers.generation.beam_search import BeamScorer
except ImportError:
# transformers >= 5.0 removed BeamScorer entirely
BeamScorer = None
from transformers.generation.logits_process import (
LogitsProcessorList,
)
from transformers.generation.stopping_criteria import (
StoppingCriteriaList,
validate_stopping_criteria,
)
from transformers.generation.utils import ModelOutput
try:
from transformers.generation.utils import (
BeamSearchOutput,
BeamSearchDecoderOnlyOutput,
)
except ImportError:
# transformers >= 5.0 renamed these classes
from transformers.generation.utils import (
GenerateBeamEncoderDecoderOutput as BeamSearchOutput,
GenerateBeamDecoderOnlyOutput as BeamSearchDecoderOnlyOutput,
)
class EnsembleBeamSearchMixin(GenerationMixin):
    """Generation mixin that runs beam search over an ensemble of models.

    `beam_search` reads the following attributes from the host object:
    `models` (ensemble members, used when `mc` is falsy), `mc`,
    `mc_models_num`, `mc_seeds` and `base_seed` (used when `mc` is truthy: the
    host model itself is called once per seed), `ensembling_mode` (`"pe"` or
    `"ep"`), the optional `calculate_entropies` flag (defaults to `True`), and
    the usual `config`, `generation_config` and `device`.
    """

    @staticmethod
    def _token_uncertainties(ensemble_probs, member_entropies, mean_member_log_probs):
        """Return the six per-token uncertainty measures for one ensemble distribution.

        Parameters:
            ensemble_probs: next-token probabilities of the combined ensemble
                distribution, one row per beam.
            member_entropies: stacked per-member next-token entropies (member
                axis first); averaged over members here.
            mean_member_log_probs: member-averaged next-token log-probabilities.

        Return:
            `dict` mapping `total_uncertainty`, `data_uncertainty`,
            `mutual_information`, `epkl_total_uncertainty`, `epkl` and `rmi` to
            one tensor each.
        """
        total_unc = Categorical(ensemble_probs).entropy()
        # Averaged per call so every returned dict owns its own tensor.
        data_unc = member_entropies.mean(0)
        epkl_total_unc = -(mean_member_log_probs * ensemble_probs).sum(-1)
        return {
            "total_uncertainty": total_unc,
            "data_uncertainty": data_unc,
            "mutual_information": total_unc - data_unc,
            "epkl_total_uncertainty": epkl_total_unc,
            "epkl": epkl_total_unc - data_unc,
            "rmi": epkl_total_unc - total_unc,
        }

    def beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:
        r"""
        Generate token ids with **beam search decoding**, scoring every step with
        an ensemble of models instead of a single model.

        At each step every ensemble member is run on the current beams and its
        next-token logits are turned into log-probabilities. Two combined
        distributions are then formed:

        - `pe`: the uniform average of the members' next-token probabilities.
        - `ep`: the members' next-token probabilities weighted by the
          probability each member has assigned so far to the tokens of the beam
          it is extending.

        `self.ensembling_mode` selects which of the two drives the beam search.
        When `calculate_entropies` is enabled, per-token uncertainty measures
        are computed for both.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
                The sequence used as a prompt for the generation.
            beam_scorer (`BeamScorer`):
                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
                sorted during generation.
            logits_processor (`LogitsProcessorList`, *optional*):
                Processors applied at every step to the combined scores and, separately, to each member's scores.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                Criteria used to tell if the generation loop should stop.
            max_length (`int`, *optional*):
                **DEPRECATED**. Use `stopping_criteria` directly to cap the number of generated tokens.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token. Defaults to `self.generation_config.pad_token_id`.
            eos_token_id (`int`, *optional*):
                The id of the *end-of-sequence* token. Defaults to `self.generation_config.eos_token_id`.
            output_attentions (`bool`, *optional*):
                Whether to return attention tensors. Defaults to the generation config value.
            output_hidden_states (`bool`, *optional*):
                Whether to return hidden states. Defaults to the generation config value.
            output_scores (`bool`, *optional*):
                Whether to return the prediction scores. Defaults to the generation config value.
            return_dict_in_generate (`bool`, *optional*):
                Whether to return a [`~utils.ModelOutput`] instead of only the generated ids. Defaults to the
                generation config value.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to keep running forward passes until every peer has finished (needed for ZeRO stage 3).
            model_kwargs:
                Additional kwargs forwarded to `prepare_inputs_for_generation`. Must contain `encoder_outputs`,
                indexable by ensemble-member position (one entry per member).

        Return:
            `torch.LongTensor` with the generated tokens when `return_dict_in_generate` is falsy; otherwise a
            `BeamSearchEncoderDecoderOutput` if `self.config.is_encoder_decoder` is true, else a
            `BeamSearchDecoderOnlyOutput` (which does not carry the ensemble-specific fields).

        Raises:
            ValueError: if the batch dimension of `input_ids` is not `num_beams * batch_size`.
            KeyError: if `encoder_outputs` is missing from `model_kwargs`.
            NotImplementedError: if `self.ensembling_mode` is neither `"pe"` nor `"ep"`.
        """
        if getattr(self, "models", None) is None:
            self._models_list = []

        # init values
        logits_processor = (
            logits_processor if logits_processor is not None else LogitsProcessorList()
        )
        stopping_criteria = (
            stopping_criteria
            if stopping_criteria is not None
            else StoppingCriteriaList()
        )
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(
                stopping_criteria, max_length
            )
        if len(stopping_criteria) == 0:
            warnings.warn(
                "You don't have defined any stopping_criteria, this will likely loop forever",
                UserWarning,
            )
        pad_token_id = (
            pad_token_id
            if pad_token_id is not None
            else self.generation_config.pad_token_id
        )
        eos_token_id = (
            eos_token_id
            if eos_token_id is not None
            else self.generation_config.eos_token_id
        )
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        output_scores = (
            output_scores
            if output_scores is not None
            else self.generation_config.output_scores
        )
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.generation_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.generation_config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate
            if return_dict_in_generate is not None
            else self.generation_config.return_dict_in_generate
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        batch_beam_size, cur_len = input_ids.shape
        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        models_scores = [] if (return_dict_in_generate and output_scores) else None
        beam_indices = (
            tuple(() for _ in range(batch_beam_size))
            if (return_dict_in_generate and output_scores)
            else None
        )
        decoder_attentions = (
            () if (return_dict_in_generate and output_attentions) else None
        )
        cross_attentions = (
            () if (return_dict_in_generate and output_attentions) else None
        )
        decoder_hidden_states = (
            () if (return_dict_in_generate and output_hidden_states) else None
        )

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        # (taken from the first ensemble member's encoder outputs only)
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = (
                model_kwargs["encoder_outputs"][0].get("attentions")
                if output_attentions
                else None
            )
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"][0].get("hidden_states")
                if output_hidden_states
                else None
            )

        # Only the first beam of each batch entry starts "alive"; the others get
        # a very low score so the first step does not pick duplicates.
        beam_scores = torch.zeros(
            (batch_size, num_beams), dtype=torch.float, device=input_ids.device
        )
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only

        # NOTE(review): raises KeyError when `encoder_outputs` is absent, so
        # decoder-only use is not supported by this implementation.
        encoder_outputs = model_kwargs.pop("encoder_outputs")

        calculate_entropies = getattr(self, "calculate_entropies", True)
        self.models_beam_tokens_iter = None
        # Never appended to below; returned as-is (empty) in the output object.
        models_beam_next_token_logits = []

        uncertainty_keys = (
            "total_uncertainty",
            "data_uncertainty",
            "mutual_information",
            "epkl_total_uncertainty",
            "epkl",
            "rmi",
        )
        pe_uncertainties = {}
        ep_uncertainties = {}
        if calculate_entropies:
            pe_uncertainties = {key: [] for key in uncertainty_keys}
            ep_uncertainties = {key: [] for key in uncertainty_keys}

        if self.mc:
            num_models = self.mc_models_num
        else:
            num_models = len(self.models)

        # Per-member log-probabilities of the tokens chosen so far on each beam;
        # grows by one column per generated token.
        self.models_beam_logits_iter = None

        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(
                    0.0 if this_peer_finished else 1.0
                ).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # ---- build one set of model inputs per ensemble member ----
            model_inputs = []
            if self.mc:
                for i in range(self.mc_models_num):
                    torch.manual_seed(self.mc_seeds[i])
                    model_inputs.append(
                        self.prepare_inputs_for_generation(
                            input_ids,
                            encoder_outputs=encoder_outputs[i],
                            **model_kwargs,
                        )
                    )
                torch.manual_seed(self.base_seed)
            else:
                for i in range(num_models):
                    dev = self.models[i].device
                    # NOTE(review): this keeps only kwargs that expose `.to`, so
                    # non-tensor entries are dropped from `model_kwargs` here.
                    # Preserved as-is; confirm that this is intended.
                    model_kwargs = {
                        k: v.to(dev)
                        for k, v in model_kwargs.items()
                        if hasattr(v, "to")
                    }
                    model_inputs.append(
                        self.prepare_inputs_for_generation(
                            input_ids.to(dev),
                            encoder_outputs=encoder_outputs[i],
                            **model_kwargs,
                        )
                    )

            # ---- forward pass of every ensemble member ----
            models_outputs = []
            if self.mc:
                for i in range(self.mc_models_num):
                    # Re-seed before each pass so member i sees the same
                    # random state at every decoding step.
                    torch.manual_seed(self.mc_seeds[i])
                    models_outputs.append(
                        self(
                            **model_inputs[i],
                            return_dict=True,
                            output_attentions=output_attentions,
                            output_hidden_states=output_hidden_states,
                        )
                    )
                torch.manual_seed(self.base_seed)
            else:
                for i, model in enumerate(self.models):
                    models_outputs.append(
                        model(
                            **model_inputs[i],
                            return_dict=True,
                            output_attentions=output_attentions,
                            output_hidden_states=output_hidden_states,
                        )
                    )

            # A finished peer only needs the forward passes above (to stay in
            # step with the other peers); skip the rest of the decoding step.
            # This check must live at `while` level: inside the per-model `for`
            # loops a `continue` would merely advance to the next model.
            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            # ---- per-member next-token log-probabilities ----
            models_next_token_probas = []
            models_next_token_logits = []  # holds log-probabilities, not raw logits
            models_entropies = []
            # NOTE: `outputs` is deliberately left bound to the LAST member's
            # outputs after this loop; it is reused below for the returned
            # attentions / hidden states and for updating `model_kwargs`.
            for outputs in models_outputs:
                model_next_token_logits = outputs.logits[:, -1, :].to(self.device)
                model_next_token_scores = nn.functional.log_softmax(
                    model_next_token_logits, dim=-1
                )  # (batch_size * num_beams, vocab_size)
                models_next_token_logits.append(model_next_token_scores)
                models_next_token_probas.append(
                    model_next_token_scores.exp()
                )  # probas of one model
                if calculate_entropies:
                    model_entropy = Categorical(models_next_token_probas[-1]).entropy()
                    models_entropies.append(model_entropy)

            # Stacked once for everything that happens BEFORE the logits
            # processors run (they may modify the per-member tensors in place).
            stacked_log_probs = torch.stack(models_next_token_logits)

            # "pe": log of the uniform average of member probabilities.
            pe_next_token_scores = (
                stacked_log_probs.logsumexp(dim=0) - torch.tensor(num_models).log()
            )

            if self.models_beam_logits_iter is None:
                # First step: no tokens chosen yet, every member has log-weight 0.
                self.models_beam_logits_iter = torch.zeros(
                    (num_models, batch_size * num_beams, 1)
                ).to(input_ids.device)
                models_beam_logits = self.models_beam_logits_iter

            # "ep": members weighted by the log-probability they gave to the
            # tokens already on the beam (`models_beam_logits`), normalised.
            denom = models_beam_logits.logsumexp(dim=0)
            num = (stacked_log_probs + models_beam_logits).logsumexp(dim=0)
            ep_next_token_scores = num - denom

            if calculate_entropies:
                mean_member_log_probs = stacked_log_probs.mean(0)
                stacked_entropies = torch.stack(models_entropies)
                pe_step_uncertainties = self._token_uncertainties(
                    pe_next_token_scores.exp(),
                    stacked_entropies,
                    mean_member_log_probs,
                )
                ep_step_uncertainties = self._token_uncertainties(
                    ep_next_token_scores.exp(),
                    stacked_entropies,
                    mean_member_log_probs,
                )

            if self.ensembling_mode == "pe":
                next_token_scores = pe_next_token_scores
            elif self.ensembling_mode == "ep":
                next_token_scores = ep_next_token_scores
            else:
                raise NotImplementedError(
                    f"Unsupported ensembling_mode {self.ensembling_mode!r}; expected 'pe' or 'ep'."
                )

            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            iter_models_scores = []
            for model_scores in models_next_token_logits:
                model_scores_processed = logits_processor(input_ids, model_scores)
                iter_models_scores.append(model_scores_processed)

            next_token_scores = next_token_scores_processed + beam_scores[
                :, None
            ].expand_as(next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores_processed,)
                    models_scores.append(iter_models_scores)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,)
                        if self.config.is_encoder_decoder
                        else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)
                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )
                if calculate_entropies:
                    for key in uncertainty_keys:
                        pe_uncertainties[key].append(pe_step_uncertainties[key])
                        ep_uncertainties[key].append(ep_step_uncertainties[key])

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(
                batch_size, num_beams * vocab_size
            )
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )
            # Flat index -> (source beam, token id)
            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size

            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=beam_indices,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat(
                [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1
            )

            # Per-member log-probability of the token just chosen on each beam.
            # Re-stacked here (not reusing `stacked_log_probs`) so that any
            # in-place edits made by the logits processors are reflected,
            # exactly as before.
            token_models_beam_logits = torch.stack(models_next_token_logits)[
                :, beam_idx, :
            ]
            token_models_beam_logits = torch.gather(
                token_models_beam_logits,
                -1,
                beam_next_tokens.repeat((num_models), 1).unsqueeze(-1),
            )
            self.models_beam_logits_iter = torch.cat(
                (
                    self.models_beam_logits_iter[:, beam_idx, :],
                    token_models_beam_logits,
                ),
                -1,
            )

            # Finished hypotheses are extended with the pad token, whose
            # selected log-probability may be -inf. Set it to 0.0 (i.e.
            # probability 1) so it does not change the accumulated weight.
            # The resolved `pad_token_id` is used because it is the value the
            # beam scorer was given above.
            finished_beams = beam_next_tokens == pad_token_id
            self.models_beam_logits_iter[:, finished_beams, -1] = 0.0
            models_beam_logits = self.models_beam_logits_iter.sum(-1, keepdims=True)

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past_key_values"] is not None:
                model_kwargs["past_key_values"] = self._reorder_cache(
                    model_kwargs["past_key_values"], beam_idx
                )

            if return_dict_in_generate and output_scores:
                beam_indices = tuple(
                    (
                        beam_indices[beam_idx[i]] + (beam_idx[i],)
                        for i in range(len(beam_indices))
                    )
                )

            # increase cur_len
            cur_len = cur_len + 1

            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=beam_indices,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None

            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    models_scores=models_scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    models_beam_next_token_logits=models_beam_next_token_logits,
                    pe_uncertainties=pe_uncertainties,
                    ep_uncertainties=ep_uncertainties,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            # TODO: This here needs to change for decoder-only ensembles in the future
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]
@dataclass
class BeamSearchEncoderDecoderOutput(ModelOutput):
    """
    Output of ensemble beam search for encoder-decoder models.

    Besides the usual beam-search outputs (sequences, scores, attentions and hidden states of the encoder and the
    decoder), this class carries the ensemble-specific per-member scores and per-token uncertainty estimates.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
            of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
            `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
            config.vocab_size)`).
        models_scores (`list(list(torch.FloatTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            One entry per generation step; each entry is a list with one tensor per ensemble member holding that
            member's next-token log-probabilities after the logits processors were applied.
        models_beam_next_token_logits (`list`, *optional*):
            Ensemble-specific field. NOTE(review): the `beam_search` implementation in this file passes a list it
            never fills, so this is expected to be empty - confirm before relying on it.
        pe_uncertainties (`dict(str, list(torch.FloatTensor))`, *optional*):
            Per-token uncertainty measures computed from the `pe` ensemble distribution. Keys are
            `total_uncertainty`, `data_uncertainty`, `mutual_information`, `epkl_total_uncertainty`, `epkl` and
            `rmi`; each value is a list with one tensor per generation step. Empty dict when entropy calculation is
            disabled.
        ep_uncertainties (`dict(str, list(torch.FloatTensor))`, *optional*):
            Same structure as `pe_uncertainties`, computed from the `ep` ensemble distribution.
        beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, max_length-1)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length,
            sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
    """

    # Field order is part of the interface (positional / tuple access); do not reorder.
    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    models_scores: Optional[Tuple[List[torch.FloatTensor]]] = None
    models_beam_next_token_logits: Optional[Tuple[torch.FloatTensor]] = None
    pe_uncertainties: Optional[Dict[str, List[torch.FloatTensor]]] = None
    ep_uncertainties: Optional[Dict[str, List[torch.FloatTensor]]] = None
    beam_indices: Optional[torch.LongTensor] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None