# Source code for lm_polygraph.utils.register_stat_calculators

import os
import logging

from lm_polygraph.stat_calculators import *
from lm_polygraph.utils.deberta import Deberta, MultilingualDeberta
from lm_polygraph.utils.openai_chat import OpenAIChat
from lm_polygraph.utils.model import Model, BlackboxModel

from typing import Dict, List, Optional, Tuple

log = logging.getLogger("lm_polygraph")


def register_stat_calculators(
    deberta_batch_size: int = 10,  # TODO: rename to NLI model
    deberta_device: Optional[str] = None,  # TODO: rename to NLI model
    language: str = "en",
    n_ccp_alternatives: int = 10,
    cache_path=os.path.expanduser("~") + "/.cache",
    model: Optional["Model"] = None,
) -> Tuple[Dict[str, "StatCalculator"], Dict[str, List[str]]]:
    """
    Registers all available statistic calculators to be seen by UEManager
    for properly organizing the calculations order.

    Parameters:
        deberta_batch_size: forwarded as ``batch_size`` to the NLI model.
        deberta_device: forwarded as ``device`` to the NLI model.
        language: ``"en"`` selects ``Deberta``; ``"zh"``, ``"ar"`` and
            ``"ru"`` select ``MultilingualDeberta``. The value is also
            forwarded to ``ClaimsExtractor``.
        n_ccp_alternatives: forwarded as ``n_alternatives`` to
            ``GreedyProbsCalculator``.
        cache_path: forwarded as ``cache_path`` to ``OpenAIChat``.
        model: when it is a ``BlackboxModel`` instance, the blackbox
            calculators are registered instead of the white-box ones.

    Returns:
        A ``(stat_calculators, stat_dependencies)`` tuple. The first dict
        maps each statistic name to the calculator instance that lists it
        in its ``stats``; the second maps each statistic name to that
        calculator's ``stat_dependencies``.

    Raises:
        ValueError: if ``language`` is not one of the supported codes, or
            if two calculators declare the same statistic.
    """
    multilingual_languages = ("zh", "ar", "ru")

    # Validate the argument up front so that an unsupported language fails
    # before anything is logged or any model is loaded.
    if language != "en" and language not in multilingual_languages:
        raise ValueError(f"Unsupported language: {language}")

    stat_calculators: Dict[str, "StatCalculator"] = {}
    stat_dependencies: Dict[str, List[str]] = {}

    log.info("=" * 100)
    log.info("Loading NLI model...")

    if language == "en":
        nli_model = Deberta(batch_size=deberta_batch_size, device=deberta_device)
    else:
        nli_model = MultilingualDeberta(
            batch_size=deberta_batch_size,
            device=deberta_device,
        )

    log.info("=" * 100)
    log.info("Initializing stat calculators...")

    openai_chat = OpenAIChat(cache_path=cache_path)

    def _register(calculator: "StatCalculator"):
        # Every statistic must have exactly one producer; a second
        # calculator claiming an already-registered statistic is an error.
        for stat in calculator.stats:
            if stat in stat_calculators:
                raise ValueError(
                    "A statistic is supposed to be processed by a single calculator only."
                )
            stat_calculators[stat] = calculator
            stat_dependencies[stat] = calculator.stat_dependencies

    _register(InitialStateCalculator())
    _register(SemanticMatrixCalculator(nli_model=nli_model))
    _register(SemanticClassesCalculator())

    if isinstance(model, BlackboxModel):
        _register(BlackboxGreedyTextsCalculator())
        _register(BlackboxSamplingGenerationCalculator())
    else:
        # NOTE(review): the original indentation was lost, so the extent of
        # this branch is inferred. Presumably every calculator below belongs
        # to the non-blackbox path -- confirm against the upstream module.
        _register(GreedyProbsCalculator(n_alternatives=n_ccp_alternatives))
        _register(EntropyCalculator())
        _register(GreedyLMProbsCalculator())
        _register(SamplingGenerationCalculator())
        _register(BartScoreCalculator())
        _register(ModelScoreCalculator())
        _register(EmbeddingsCalculator())
        _register(EnsembleTokenLevelDataCalculator())
        _register(CrossEncoderSimilarityMatrixCalculator(nli_model=nli_model))
        _register(GreedyAlternativesNLICalculator(nli_model=nli_model))
        _register(GreedyAlternativesFactPrefNLICalculator(nli_model=nli_model))
        _register(ClaimsExtractor(openai_chat=openai_chat, language=language))
        _register(
            PromptCalculator(
                "Question: {q}\n Possible answer:{a}\n "
                "Is the possible answer:\n (A) True\n (B) False\n The possible answer is:",
                "True",
                "p_true",
                sample_text_dependency=None,  # Not calculate T text samples for P(True)
            )
        )
        _register(
            PromptCalculator(
                "Question: {q}\n Here are some ideas that were brainstormed: {s}\n Possible answer:{a}\n "
                "Is the possible answer:\n (A) True\n (B) False\n The possible answer is:",
                "True",
                "p_true_sampling",
            )
        )
        _register(
            PromptCalculator(
                "Question: {q}\n Possible answer:{a}\n "
                "Is the possible answer True or False? The possible answer is: ",
                "True",
                "p_true_claim",
                input_text_dependency="claim_input_texts_concatenated",
                sample_text_dependency=None,
                generation_text_dependency="claim_texts_concatenated",
            )
        )

    log.info("Done intitializing stat calculators...")

    return stat_calculators, stat_dependencies