Source code for lm_polygraph.estimators.lexical_similarity

import numpy as np
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu

from typing import Dict

from .estimator import Estimator

from absl import logging as absl_logging

# This prevents bullshit spam from rouge scorer
absl_logging.set_verbosity(absl_logging.WARNING)


[docs]class LexicalSimilarity(Estimator): """ Estimates the sequence-level uncertainty of a language model following the method of "Lexical Similarity" as provided in the paper https://arxiv.org/abs/2302.09664. Works with both whitebox and blackbox models (initialized using lm_polygraph.utils.model.BlackboxModel/WhiteboxModel). The method calculates mean similarity between all pairs of sampled generations with minus sign. The number of samples is controlled by lm_polygraph.stat_calculators.sample.SamplingGenerationCalculator 'samples_n' parameter. """ def __init__(self, metric: str = "rougeL"): """ Parameters: metric (str): similarity metric (default: 'rougeL'). Possible values: * rouge1 / rouge2 / rougeL * BLEU """ self.metric = metric if self.metric.startswith("rouge"): self.scorer = rouge_scorer.RougeScorer([self.metric], use_stemmer=True) super().__init__(["sample_texts"], "sequence") def __str__(self): return f"LexicalSimilarity_{self.metric}" def _score_single(self, t1: str, t2: str): if self.metric.startswith("rouge"): return self.scorer.score(t1, t2)[self.metric].fmeasure elif self.metric == "BLEU": min_sentence_len = min(len(t1.split()), len(t2.split())) if min_sentence_len == 1: weights = [1.0, 0.0, 0.0, 0.0] elif min_sentence_len == 2: weights = [0.5, 0.5, 0.0, 0.0] elif min_sentence_len == 3: weights = [0.33, 0.33, 0.33, 0.0] else: # default weights in sentence_bleu weights = [0.25, 0.25, 0.25, 0.25] return sentence_bleu([t1.split()], t2.split(), weights=weights) else: raise Exception(f"Unknown metrics for lexical similarity: {self.metric}") def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray: """ Estimates the mean similarity with minus sign for each sample in the input statistics. Parameters: stats (Dict[str, np.ndarray]): input statistics, which for multiple samples includes: * several sampled texts in 'sample_texts' Returns: np.ndarray: float uncertainty for each sample in input statistics. Higher values indicate more uncertain samples. """ batch_texts = stats["sample_texts"] res = [] for texts in batch_texts: sims = [] for i in range(len(texts)): for j in range(i + 1, len(texts)): sims.append(self._score_single(texts[i], texts[j])) res.append(-np.mean(sims)) return np.array(res)