Source code for lm_polygraph.stat_calculators.statistic_extraction

import gc
import torch
import numpy as np

from typing import Dict, List, Tuple

from .stat_calculator import StatCalculator
from lm_polygraph.utils.model import WhiteboxModel
from .greedy_probs import GreedyProbsCalculator


class TrainingStatisticExtractionCalculator(StatCalculator):
    """
    Extracts statistics from a training dataset and an optional background
    training dataset by running the base calculators over every batch.

    The extraction is performed only once per instance: the first call
    returns the collected statistics, every later call returns an empty
    dictionary.
    """

    # Per-batch entries that are never accumulated into the result.
    _SKIPPED_STATS = ("input_tokens", "input_texts", "target_texts")

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        """
        Returns the statistics and dependencies for the calculator.

        Note that the keys actually produced by `__call__` are built as
        `<dataset prefix> + <statistic name returned by a base calculator>`,
        where the prefix is "train_" or "background_train_".
        """
        return [
            "train_embeddings",
            "background_train_embeddings",
            "train_greedy_log_likelihoods",
        ], []

    def __init__(self, train_dataset=None, background_train_dataset=None):
        """
        Parameters:
            train_dataset: iterable of batches, each a 2- or 3-element
                sequence whose first two elements are the input texts and
                the target texts. `None` disables this dataset.
            background_train_dataset: same format as `train_dataset`;
                `None` disables this dataset.
        """
        super().__init__()
        # NOTE(review): not read anywhere in this class; presumably consumed
        # by code outside of it - confirm before removing.
        self.hidden_layer = -1
        self.train_dataset = train_dataset
        self.background_train_dataset = background_train_dataset
        self.statistics_extracted = False
        self.base_calculators = [GreedyProbsCalculator(output_hidden_states=True)]

    @staticmethod
    def _split_batch(batch):
        """Returns (input_texts, target_texts) of a 2- or 3-element batch."""
        if len(batch) not in (2, 3):
            raise ValueError(
                f"Expected batch with 2 or 3 elements, got {len(batch)}"
            )
        # An optional third element of the batch is not used here.
        return batch[0], batch[1]

    def __call__(
        self,
        dependencies: Dict[str, np.ndarray],
        texts: List[str],
        model: WhiteboxModel,
        max_new_tokens: int = 100,
        background_train_dataset_max_new_tokens: int = 100,
    ) -> Dict[str, np.ndarray]:
        """
        Runs the base calculators over both datasets and merges the
        per-batch statistics.

        Parameters:
            dependencies (Dict[str, np.ndarray]): not used by this calculator.
            texts (List[str]): not used by this calculator; the input texts
                are taken from the dataset batches instead.
            model (WhiteboxModel): model passed to the base calculators.
            max_new_tokens (int): generation limit passed to the base
                calculators for the training dataset. Default: 100.
            background_train_dataset_max_new_tokens (int): generation limit
                passed to the base calculators for the background training
                dataset. Default: 100.
        Returns:
            Dict[str, np.ndarray]: statistics keyed by
                "train_<stat>" / "background_train_<stat>". Empty if the
                statistics were already extracted by a previous call.
        Raises:
            ValueError: if a batch does not have 2 or 3 elements.
        """
        if self.statistics_extracted:
            return {}

        # Each dataset gets its own key prefix and its own generation limit.
        # (Previously the whole list of prefixes was compared with "train_",
        # which is never true, so `max_new_tokens` was silently ignored.)
        named_datasets = (
            ("train_", self.train_dataset, max_new_tokens),
            (
                "background_train_",
                self.background_train_dataset,
                background_train_dataset_max_new_tokens,
            ),
        )

        train_stats: Dict[str, list] = {}
        for prefix, dataset, dataset_max_new_tokens in named_datasets:
            if dataset is None:
                continue

            for batch in dataset:
                inp_texts, target_texts = self._split_batch(batch)
                batch_stats = {
                    "input_texts": inp_texts,
                    "target_texts": target_texts,
                }

                for stat_calculator in self.base_calculators:
                    new_stats = stat_calculator(
                        batch_stats, inp_texts, model, dataset_max_new_tokens
                    )
                    # Statistics already present in the batch are kept as is.
                    for stat, stat_value in new_stats.items():
                        batch_stats.setdefault(stat, stat_value)

                for stat, stat_value in batch_stats.items():
                    if stat in self._SKIPPED_STATS:
                        continue
                    train_stats.setdefault(prefix + stat, []).append(stat_value)

                # Release memory held by the processed batch before the next one.
                torch.cuda.empty_cache()
                gc.collect()

        result_train_stat = {}
        for stat, batch_values in train_stats.items():
            # Statistics missing for at least one batch, and tokenizer-related
            # entries, cannot be merged and are dropped.
            if "tokenizer" in stat or any(value is None for value in batch_values):
                continue
            if isinstance(batch_values[0], list):
                # Per-batch lists are flattened into a single list.
                result_train_stat[stat] = [
                    item for sublist in batch_values for item in sublist
                ]
            else:
                result_train_stat[stat] = np.concatenate(batch_values)

        self.statistics_extracted = True
        return result_train_stat