Source code for lm_polygraph.estimators.semantic_entropy

import numpy as np

from typing import List, Dict, Optional

from .estimator import Estimator


class SemanticEntropy(Estimator):
    """
    Estimates the sequence-level uncertainty of a language model following the method of
    "Semantic entropy" as provided in the paper https://arxiv.org/abs/2302.09664.
    Works only with whitebox models (initialized using lm_polygraph.utils.model.WhiteboxModel).

    This method calculates the generation entropy estimations merged by semantic classes
    using Monte-Carlo. The number of samples is controlled by
    lm_polygraph.stat_calculators.sample.SamplingGenerationCalculator 'samples_n' parameter.

    Two ways of estimating the probability of a semantic class are supported:

    * "sum": log-sum-exp of the log probabilities of the samples in the class;
    * "frequency": the fraction of samples that fall into the class.
    """

    def __init__(
        self, verbose: bool = False, class_probability_estimation: str = "sum"
    ):
        """
        Parameters:
            verbose (bool): stored on the instance as `self.verbose`.
            class_probability_estimation (str): either "sum" or "frequency";
                selects how class probabilities are estimated and which
                statistics the estimator depends on.

        Raises:
            ValueError: if `class_probability_estimation` is neither "sum"
                nor "frequency".
        """
        self.class_probability_estimation = class_probability_estimation
        if self.class_probability_estimation == "sum":
            deps = ["sample_log_probs", "sample_texts", "semantic_classes_entail"]
        elif self.class_probability_estimation == "frequency":
            # Frequencies are computed from class sizes only, so the sample
            # log probabilities are not required.
            deps = ["sample_texts", "semantic_classes_entail"]
        else:
            raise ValueError(
                f"Unknown class_probability_estimation: {self.class_probability_estimation}. Use 'sum' or 'frequency'."
            )

        super().__init__(deps, "sequence")
        self.verbose = verbose

    def __str__(self):
        """Return the estimator name, which depends on the estimation mode."""
        if self.class_probability_estimation == "sum":
            return "SemanticEntropy"
        elif self.class_probability_estimation == "frequency":
            return "SemanticEntropyEmpirical"

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Estimates the semantic entropy for each sample in the input statistics.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics, which for multiple samples includes:
                * generated samples in 'sample_texts',
                * corresponding log probabilities in 'sample_log_probs'
                  (read only when class_probability_estimation is "sum"),
                * semantic classes in 'semantic_classes_entail', a mapping with
                  the keys 'class_to_sample' and 'sample_to_class'.
        Returns:
            np.ndarray: float semantic entropy for each sample in input statistics.
                Higher values indicate more uncertain samples.
        """
        if self.class_probability_estimation == "sum":
            loglikelihoods_list = stats["sample_log_probs"]
            hyps_list = stats["sample_texts"]
        elif self.class_probability_estimation == "frequency":
            loglikelihoods_list = None
            hyps_list = stats["sample_texts"]

        # batched_call reads the class assignments from these attributes.
        self._class_to_sample = stats["semantic_classes_entail"]["class_to_sample"]
        self._sample_to_class = stats["semantic_classes_entail"]["sample_to_class"]

        return self.batched_call(hyps_list, loglikelihoods_list)

    def batched_call(
        self,
        hyps_list: List[List[str]],
        loglikelihoods_list: Optional[List[List[float]]],
        log_weights: Optional[List[List[float]]] = None,
    ) -> np.ndarray:
        """
        Computes the semantic entropy for every item of a batch.

        Relies on `self._class_to_sample` and `self._sample_to_class`, which
        `__call__` sets from the 'semantic_classes_entail' statistic; they must
        be assigned before this method is invoked directly.

        Parameters:
            hyps_list (List[List[str]]): sampled texts for each batch item.
            loglikelihoods_list (Optional[List[List[float]]]): log probabilities
                of the samples for each batch item. Only read in "sum" mode.
            log_weights (Optional[List[List[float]]]): per-sample log weights for
                each batch item. A missing list, or a `None` entry for an item,
                is treated as all-zero log weights (i.e. weight 1 per sample).
                The argument is never modified.
        Returns:
            np.ndarray: one value per batch item: the negated mean over samples
                of the sample's class log probability times exp(log weight).
        """
        if log_weights is None:
            log_weights = [None] * len(hyps_list)

        semantic_logits = []
        # Iteration over batch
        for i, hyps in enumerate(hyps_list):
            if self.class_probability_estimation == "sum":
                sample_lls = np.array(loglikelihoods_list[i])
                # Log probability of a class = log-sum-exp over its samples.
                class_lp = [
                    np.logaddexp.reduce(sample_lls[np.array(class_idx)])
                    for class_idx in self._class_to_sample[i]
                ]
            elif self.class_probability_estimation == "frequency":
                num_samples = len(hyps)
                class_lp = np.log(
                    [
                        len(class_idx) / num_samples
                        for class_idx in self._class_to_sample[i]
                    ]
                )

            # Use a local default rather than writing into `log_weights`, so a
            # list passed in by the caller is left untouched.
            item_log_weights = log_weights[i]
            if item_log_weights is None:
                item_log_weights = [0] * len(hyps)

            sample_classes = self._sample_to_class[i]
            semantic_logits.append(
                -np.mean(
                    [
                        class_lp[sample_classes[j]] * np.exp(item_log_weights[j])
                        for j in range(len(hyps))
                    ]
                )
            )

        return np.array(semantic_logits)