Source code for lm_polygraph.generation_metrics.aggregated_metric

import numpy as np

from typing import List, Dict
from .generation_metric import GenerationMetric


[docs]class AggregatedMetric(GenerationMetric): """ Aggregated metric class, which wraps a base metric and aggregates its results for multi-target datasets. """ def __init__(self, base_metric: GenerationMetric, aggregation: str = "max"): self.base_metric = base_metric self.level = base_metric.level self.stats_dependencies = base_metric.stats_dependencies self.aggregation = aggregation def __str__(self): # Check if base_metric is a wrapper, else return base_metric return str(getattr(self.base_metric, "base_metric", self.base_metric)) def __call__( self, stats: Dict[str, np.ndarray], target_texts: List[str], ) -> np.ndarray: """ Calculates aggregated metric between stats['greedy_texts'] and target_texts. Parameters: stats (Dict[str, np.ndarray]): calculated stats target_texts (List[str]): ground-truth texts Returns: np.ndarray: list of aggregated metric values for each sample in input. """ metric_values = [] for i, (targets, greedy_text) in enumerate( zip(target_texts, stats["greedy_texts"]) ): # truncate stats to only process one sample at a time truncated_stats = { k: [v[i]] for k, v in stats.items() if k in self.stats_dependencies + ["greedy_texts"] } sample_metric_values = [] for target in targets: value = self.base_metric(truncated_stats, [target]) sample_metric_values.append(value) if self.aggregation == "max": metric_values.append(np.nanmax(sample_metric_values)) else: raise ValueError(f"Unknown aggregation type: {self.aggregation}") return np.array(metric_values)