# Source code for lm_polygraph.stat_calculators.statistic_extraction_visual

import gc
import torch
import numpy as np
from tqdm import tqdm
from typing import Dict, List, Tuple

from .stat_calculator import StatCalculator
from lm_polygraph.model_adapters.visual_whitebox_model import VisualWhiteboxModel
from .greedy_visual_probs import GreedyProbsVisualCalculator


class TrainingStatisticExtractionCalculatorVisual(StatCalculator):
    """Collect statistics over the training datasets of a visual model.

    On the first call, every batch of ``train_dataset`` and
    ``background_train_dataset`` is passed through the base calculators and
    the per-batch outputs are merged under keys prefixed with ``"train_"`` or
    ``"background_train_"``. Every later call returns an empty dict.
    """

    # Batch entries that are handed to the base calculators as inputs but are
    # never exported as training statistics.
    _INPUT_ONLY_STATS = (
        "input_tokens",
        "input_texts",
        "target_texts",
        "images",
    )

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        """Return a list of three training-statistic names and an empty list.

        NOTE(review): presumably (provided statistics, required dependencies)
        following the StatCalculator convention -- confirm against the base
        class.
        """
        return [
            "train_embeddings",
            "background_train_embeddings",
            "train_greedy_log_likelihoods",
        ], []

    def __init__(self, train_dataset=None, background_train_dataset=None):
        """Store the datasets and set up the base calculators.

        Parameters:
            train_dataset: iterable yielding ``(input_texts, target_texts,
                images)`` batches, or None to skip it.
            background_train_dataset: iterable of the same form, or None to
                skip it.
        """
        super().__init__()
        # Not read anywhere in this class; kept for backward compatibility.
        self.hidden_layer = -1
        self.train_dataset = train_dataset
        self.background_train_dataset = background_train_dataset
        self.statistics_extracted = False
        self.base_calculators = [GreedyProbsVisualCalculator()]

    def __call__(
        self,
        dependencies: Dict[str, np.ndarray],
        texts: List[str],
        model: VisualWhiteboxModel,
        max_new_tokens: int = 100,
        background_train_dataset_max_new_tokens: int = 100,
    ) -> Dict[str, np.ndarray]:
        """Extract the training statistics once and return them.

        Parameters:
            dependencies: not read by this calculator.
            texts: not read by this calculator; the texts come from the
                datasets given to the constructor.
            model: model forwarded to every base calculator; its
                ``get_images`` method is applied to each batch of images.
            max_new_tokens: generation limit forwarded to the base
                calculators for ``train_dataset``.
            background_train_dataset_max_new_tokens: generation limit
                forwarded for ``background_train_dataset``.

        Returns:
            Mapping from prefixed statistic name to the values merged over all
            batches (a flat list when the per-batch values are lists, a
            concatenated array otherwise). Empty on every call after the
            first.
        """
        if self.statistics_extracted:
            return {}

        train_stats: Dict[str, list] = {}
        named_datasets = (
            ("train_", self.train_dataset, max_new_tokens),
            (
                "background_train_",
                self.background_train_dataset,
                background_train_dataset_max_new_tokens,
            ),
        )

        for prefix, dataset, dataset_max_new_tokens in named_datasets:
            if dataset is None:
                continue

            for inp_texts, target_texts, images in tqdm(dataset):
                batch_stats: Dict[str, np.ndarray] = {
                    "images": model.get_images(images),
                    "input_texts": inp_texts,
                    "target_texts": target_texts,
                }

                for stat_calculator in self.base_calculators:
                    new_stats = stat_calculator(
                        batch_stats, inp_texts, model, dataset_max_new_tokens
                    )
                    # First writer wins: a value already present in the batch
                    # is never overwritten by a calculator output.
                    for stat, stat_value in new_stats.items():
                        batch_stats.setdefault(stat, stat_value)

                for stat, stat_value in batch_stats.items():
                    if stat in self._INPUT_ONLY_STATS:
                        continue
                    train_stats.setdefault(prefix + stat, []).append(stat_value)

                # Release cached GPU memory and collect garbage between
                # batches.
                # NOTE(review): nesting level reconstructed from a collapsed
                # source listing -- confirm it belongs inside the batch loop.
                torch.cuda.empty_cache()
                gc.collect()

        result_train_stat = {}
        for key, batches in train_stats.items():
            # Skip tokenizer/processor entries and any statistic that is
            # missing (None) for at least one batch.
            if (
                "tokenizer" in key
                or "processor" in key
                or any(batch is None for batch in batches)
            ):
                continue

            if isinstance(batches[0], list):
                result_train_stat[key] = [
                    item for batch in batches for item in batch
                ]
            else:
                result_train_stat[key] = np.concatenate(batches)

        self.statistics_extracted = True
        return result_train_stat