# Source code for lm_polygraph.utils.normalize

from typing import List, Tuple, Dict

import numpy as np

from lm_polygraph.utils.manager import UEManager
from lm_polygraph.utils.common import seq_man_key


def _concat_mans_data(mans_data_dicts, names):
    """Merge per-manager arrays into a single array per requested name.

    Every requested name is translated into a dictionary key with
    ``seq_man_key`` and looked up in each manager dictionary; the values
    found are joined with ``np.concatenate`` in the order the managers
    were given.

    Args:
        mans_data_dicts: List of dictionaries, one per manager, each holding
            data of one particular type. All of them are expected to expose
            the same keys.
        names: List of value types to pull out of the dictionaries.

    Returns:
        Dictionary keyed by the entries of ``names`` whose values are the
        concatenated arrays gathered from all managers.

    Raises:
        KeyError: If the key derived from a name is absent from one of the
            manager dictionaries.
    """

    def _lookup(single_man_data, value_name):
        # Resolve the key for this name and fetch it from one manager's
        # data, re-raising a miss with a more descriptive message.
        lookup_key = seq_man_key(value_name)
        try:
            return single_man_data[lookup_key]
        except KeyError:
            raise KeyError(f"{lookup_key} not found in manager data")

    return {
        value_name: np.concatenate(
            [_lookup(single_man_data, value_name) for single_man_data in mans_data_dicts]
        )
        for value_name in names
    }


def get_mans_ues_metrics(
    man_paths: List[str],
    ue_method_names: List[str],
    gen_metric_names: List[str],
) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
    """Extracts and concats data from a list of paths to saved manager data files.

    Each path is loaded with ``UEManager.load``; the ``estimations`` and
    ``gen_metrics`` attributes of the loaded managers are then concatenated
    per requested name via ``_concat_mans_data``.

    Args:
        man_paths: List of paths to manager data files.
        ue_method_names: List of UE methods to extract.
        gen_metric_names: List of gen_metrics to extract.

    Returns:
        Tuple of two dictionaries:
            - First dictionary contains UE method data, where keys are
              method names and values are concatenated arrays of UE method
              data from all managers.
            - Second dictionary contains gen_metric data, where keys are
              metric names and values are concatenated arrays of gen_metric
              data from all managers.

    Raises:
        KeyError: Propagated from ``_concat_mans_data`` when a requested
            name is missing from one of the loaded managers.
    """
    mans = [UEManager.load(path) for path in man_paths]

    # Split every loaded manager into its two data dictionaries.
    mans_ues = [man.estimations for man in mans]
    mans_gen_metrics = [man.gen_metrics for man in mans]

    ues = _concat_mans_data(mans_ues, ue_method_names)
    gen_metrics = _concat_mans_data(mans_gen_metrics, gen_metric_names)

    return ues, gen_metrics


def filter_nans(
    gen_metrics: np.ndarray, ues: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """Filters out NaNs from gen_metrics and ues if they occur at least in one of the arrays.

    A position is dropped from BOTH arrays when either array holds a NaN
    there, so the two returned arrays stay aligned element by element.

    Args:
        gen_metrics: Array of gen_metrics.
        ues: Array of ues. The combined mask is applied to both inputs, so
            this is expected to have the same shape as ``gen_metrics``.

    Returns:
        Tuple of two arrays:
            - First array contains gen_metrics with NaNs removed.
            - Second array contains ues with NaNs removed.
    """
    nan_mask = np.isnan(gen_metrics) | np.isnan(ues)

    # Invert once and reuse the same selection for both arrays.
    keep_mask = ~nan_mask
    gen_metrics = gen_metrics[keep_mask]
    ues = ues[keep_mask]

    return gen_metrics, ues