# Source code for lm_polygraph.generation_metrics.generation_metric

import numpy as np

from typing import List, Dict
from abc import ABC, abstractmethod
from lm_polygraph.utils.common import polygraph_module_init


class GenerationMetric(ABC):
    """
    Abstract base class for generation metrics.

    A generation metric measures ground-truth uncertainty by comparing
    model-generated text with dataset ground-truth text. This ground-truth
    uncertainty is further compared with different estimators' uncertainties
    in UEManager using ue_metrics.
    """

    @polygraph_module_init
    def __init__(self, stats_dependencies: List[str], level: str) -> None:
        """
        Store the metric's statistic dependencies and estimation level.

        Parameters:
            stats_dependencies (List[str]): listed statistics which need to be
                calculated and passed in __call__ method. Statistics should
                include only names of statistics registered in
                lm_polygraph/stat_calculators/__init__.py
            level (str): uncertainty estimation level. Possible values:
                * 'sequence': method should output GenerationMetric for each
                  input sequence in __call__.
                * 'claim': accepted by the check below; presumably one value
                  per claim in the generated text -- confirm against the
                  claim-level implementations.
                * 'token': method should output GenerationMetric for each
                  token in input sequence in __call__.

        Raises:
            AssertionError: if `level` is not one of the values listed above.
        """
        # Kept as an assert so the exception type seen by existing callers
        # does not change; the message makes a bad value easy to diagnose.
        assert level in ["sequence", "claim", "token"], (
            f"level must be 'sequence', 'claim' or 'token', got {level!r}"
        )
        self.level = level
        self.stats_dependencies = stats_dependencies

    @abstractmethod
    def __str__(self) -> str:
        """
        Abstract method. Returns unique name of the generation metric.

        Class parameters which affect generation metric estimates should also
        be included in the unique name to diversify between generation
        metrics.

        Raises:
            NotImplementedError: if a subclass delegates to this base
                implementation.
        """
        # NotImplementedError is a subclass of Exception, so callers that
        # caught the previous generic Exception keep working.
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def __call__(
        self,
        stats: Dict[str, np.ndarray],
        target_texts: List[str],
    ) -> np.ndarray:
        """
        Abstract method. Measures ground-truth uncertainty by comparing
        model-generated statistics from `stats` with dataset ground-truth
        texts from `target_texts`.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics, which includes
                values from statistics calculators for each entry of
                self.stats_dependencies,
            target_texts (List[str]): ground-truth texts,

        Returns:
            np.ndarray: list of float ground-truth uncertainties calculated by
                comparing input statistics with ground truth texts.
                Should be 1-dimensional (in case of token-level, generation
                metrics from different samples should be concatenated).
                Higher values should indicate more confident samples
                (more similar to ground truth texts).

        Raises:
            NotImplementedError: if a subclass delegates to this base
                implementation.
        """
        raise NotImplementedError("Not implemented")