Source code for lm_polygraph.stat_calculators.stat_calculator

import numpy as np

from typing import List, Dict, Tuple
from abc import ABC, abstractmethod
from lm_polygraph.utils.model import Model
from lm_polygraph.utils.common import polygraph_module_init


class StatCalculator(ABC):
    """
    Abstract class for some particular statistics calculation.

    Used to re-use the same statistics across different uncertainty estimators
    at `lm_polygraph.estimators`. See the list of available calculators at
    lm_polygraph/stat_calculators/__init__.py.

    While estimators specify `stats_dependencies` to re-use these
    StatCalculator calculations, calculators can also specify dependencies on
    other calculators. UEManager at lm_polygraph.utils.manager will order all
    the needed calculators and estimators to be called in the correct order.
    Any cyclic dependencies among calculators will be spotted by UEManager and
    end with an exception.

    Each new StatCalculator needs to be registered at
    lm_polygraph/stat_calculators/__init__.py to be seen by UEManager.
    """

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        """
        Placeholder method to return the list of statistics and dependencies
        for the calculator.

        Returns:
            Tuple[List[str], List[str]]: a pair `(stats, stat_dependencies)`;
                the constructor stores them in exactly this order.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError(
            f"Implement static meta_info() method {__class__.__name__}"
        )

    @polygraph_module_init
    def __init__(self):
        """
        Initializes the calculator from the class-level `meta_info()`.

        The constructor takes no arguments. It stores the two lists returned
        by `meta_info()`:
            stats: List[str]: Names of statistics which can be calculated
                by using this StatCalculator.
            stat_dependencies: List[str]: Names of statistics which this
                calculator needs to use. Can be any names of other
                StatCalculators. Any cyclic dependencies among calculators
                will be spotted by UEManager and end with an exception.
        """
        # Looked up through the concrete class so that a subclass override of
        # meta_info() is the one that gets used.
        self._stats, self._stat_dependencies = self.__class__.meta_info()

    @abstractmethod
    def __call__(
        self,
        dependencies: Dict[str, np.ndarray],
        texts: List[str],
        model: Model,
        max_new_tokens: int = 100,
        **kwargs,
    ) -> Dict[str, np.ndarray]:
        """
        Abstract method. Calculates the statistic based on the other provided
        statistics.

        Parameters:
            dependencies (Dict[str, np.ndarray]): input statistics, which
                includes values from statistics calculators for each
                `stat_dependencies`.
            texts (List[str]): Input texts batch used for model generation.
            model (Model): Model used for generation.
            max_new_tokens (int): Maximum number of new tokens at model
                generation. Default: 100.
        Returns:
            Dict[str, np.ndarray]: dictionary with calculated statistics under
                all keys from `stats`.
        Raises:
            NotImplementedError: if a subclass delegates to this base
                implementation instead of providing its own.
        """
        # NotImplementedError matches meta_info() and is still caught by any
        # caller that handles the generic Exception raised previously.
        raise NotImplementedError("Not implemented")

    @property
    def stats(self) -> List[str]:
        """
        Returns:
            List[str]: Names of statistics which can be calculated by this
                class.
        """
        return self._stats

    @property
    def stat_dependencies(self) -> List[str]:
        """
        Returns:
            List[str]: Names of statistics dependencies which this class
                needs at __call__.
        """
        return self._stat_dependencies