from typing import Dict
import numpy as np
import torch
from .estimator import Estimator
[docs]def all_ep_estimators():
return [EPStu(), EPSrmi()]
[docs]def all_pe_estimators():
return [PEStu(), PESrmi()]
[docs]def get_seq_level_ue(
sequence_level_data: Dict[str, torch.Tensor],
) -> Dict[str, np.ndarray]:
softmax_t = 1
model_log_probas = sequence_level_data[
"log_probas"
] # num_obs x num_models x num_beams
num_models = model_log_probas.shape[1]
ens_log_probas = (
torch.tensor(model_log_probas).logsumexp(1) - torch.tensor(num_models).log()
) # num_obs x num_beams
ens_log_probas = ens_log_probas.numpy()
ens_probas = np.exp(ens_log_probas)
ens_probas_exp = ens_probas**softmax_t
weights = ens_probas_exp / ens_probas_exp.sum(-1, keepdims=True)
tu = (ens_probas * weights).sum(-1) # num_obs
ens_log_probas = np.repeat(ens_log_probas[:, None, :], num_models, axis=1)
rmi_weights = np.repeat(weights[:, None, :], num_models, axis=1)
rmi_base = (ens_log_probas - model_log_probas) * rmi_weights
rmi = rmi_base.sum(-1).mean(1) # num_obs
rmi_abs = np.abs(rmi_base).sum(-1).mean(1) # num_obs
uncertainty_estimates = {
"tu": -tu,
"rmi": rmi,
"rmi-abs": rmi_abs,
}
return uncertainty_estimates
[docs]class EPStu(Estimator):
def __init__(self):
super().__init__(["ensemble_token_scores"], "sequence")
def __str__(self):
return "EPStu"
def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
sequence_level_data = stats["ensemble_token_scores"]["ep_token_level_scores"]
return get_seq_level_ue(sequence_level_data)["tu"]
[docs]class EPSrmi(Estimator):
def __init__(self):
super().__init__(["ensemble_token_scores"], "sequence")
def __str__(self):
return "EPSrmi"
def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
sequence_level_data = stats["ensemble_token_scores"]["ep_token_level_scores"]
return get_seq_level_ue(sequence_level_data)["rmi"]
[docs]class EPSrmiabs(Estimator):
def __init__(self):
super().__init__(["ensemble_token_scores"], "sequence")
def __str__(self):
return "EPSrmiabs"
def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
sequence_level_data = stats["ensemble_token_scores"]["ep_token_level_scores"]
return get_seq_level_ue(sequence_level_data)["rmi-abs"]
[docs]class PEStu(Estimator):
def __init__(self):
super().__init__(["ensemble_token_scores"], "sequence")
def __str__(self):
return "PEStu"
def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
sequence_level_data = stats["ensemble_token_scores"]["pe_token_level_scores"]
return get_seq_level_ue(sequence_level_data)["tu"]
[docs]class PESrmi(Estimator):
def __init__(self):
super().__init__(["ensemble_token_scores"], "sequence")
def __str__(self):
return "PESrmi"
def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
sequence_level_data = stats["ensemble_token_scores"]["pe_token_level_scores"]
return get_seq_level_ue(sequence_level_data)["rmi"]
[docs]class PESrmiabs(Estimator):
def __init__(self):
super().__init__(["ensemble_token_scores"], "sequence")
def __str__(self):
return "PESrmiabs"
def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
sequence_level_data = stats["ensemble_token_scores"]["pe_token_level_scores"]
return get_seq_level_ue(sequence_level_data)["rmi-abs"]