# Source code for lm_polygraph.generation_metrics.comet

import re
from typing import Dict, List, Optional

import numpy as np

from .generation_metric import GenerationMetric


class Comet(GenerationMetric):
    """
    Calculates COMET metric (https://aclanthology.org/2020.emnlp-main.213/)
    between model-generated texts and ground truth texts.
    """

    def __init__(self, source_ignore_regex=None, lang="en"):
        """
        Parameters:
            source_ignore_regex (Optional[str]): regular expression applied to
                every input text before scoring. Its first capturing group is
                used as the source sentence, so the pattern must contain at
                least one group. If None (or empty), input texts are passed
                to the scorer unchanged.
            lang (str): accepted for interface compatibility; it is not used
                by this implementation.

        Raises:
            ValueError: if `source_ignore_regex` has no capturing group.
        """
        super().__init__(["greedy_texts", "input_texts"], "sequence")

        # Imported lazily so the `evaluate` package is only required when the
        # metric is actually instantiated.
        from evaluate import load

        self.scorer = load("comet")

        if source_ignore_regex:
            pattern = re.compile(source_ignore_regex)
            # _filter_text extracts group(1); fail early with a clear message
            # instead of an IndexError on the first matching source text.
            if pattern.groups < 1:
                raise ValueError(
                    f"source_ignore_regex {pattern.pattern!r} must contain "
                    "at least one capturing group"
                )
            self.source_ignore_regex = pattern
        else:
            self.source_ignore_regex = None

    def __str__(self):
        return "Comet"

    def _filter_text(self, text: str, ignore_regex: Optional[re.Pattern]) -> str:
        """
        Extracts the source sentence from `text` using `ignore_regex`.

        Parameters:
            text (str): raw input text.
            ignore_regex (Optional[re.Pattern]): compiled pattern whose first
                capturing group holds the part of `text` to keep. If None,
                `text` is returned unchanged.

        Returns:
            str: contents of the first capturing group of the first match,
                or `text` itself when no pattern is given.

        Raises:
            ValueError: if a pattern is given but does not match `text`.
        """
        if ignore_regex is None:
            return text

        match = ignore_regex.search(text)
        if match is None:
            raise ValueError(
                f"Source text {text} does not match the ignore regex {ignore_regex}"
            )
        return match.group(1)

    def __call__(
        self,
        stats: Dict[str, np.ndarray],
        target_texts: List[str],
    ) -> np.ndarray:
        """
        Calculates COMET (https://aclanthology.org/2020.emnlp-main.213/)
        between stats['greedy_texts'] and target_texts, using
        stats['input_texts'] as the sources.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics, which for
                multiple samples includes:
                * model-generated texts in 'greedy_texts'
                * input texts before translation in 'input_texts'
            target_texts (List[str]): ground-truth texts
        Returns:
            np.ndarray: list of COMET Scores for each sample in input.
        Raises:
            ValueError: if a source-ignore regex is configured and some input
                text does not match it.
        """
        sources = [
            self._filter_text(src, self.source_ignore_regex)
            for src in stats["input_texts"]
        ]
        result = self.scorer.compute(
            predictions=stats["greedy_texts"],
            references=target_texts,
            sources=sources,
        )
        return np.array(result["scores"])