Source code for lm_polygraph.estimators.num_sem_sets

import numpy as np
import logging
from typing import Dict
from .estimator import Estimator
import torch.nn as nn

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Softmax module applied along dimension 1.
# NOTE(review): not referenced by any code visible in this file — confirm
# whether it is still needed before removing.
softmax = nn.Softmax(dim=1)


class NumSemSets(Estimator):
    """
    Estimates the sequence-level uncertainty of a language model following the method of
    "Number of Semantic Sets" as provided in the paper https://arxiv.org/abs/2305.19187.
    Works with both whitebox and blackbox models (initialized using
    lm_polygraph.utils.model.BlackboxModel/WhiteboxModel).
    """

    def __init__(
        self,
        verbose: bool = False,
    ):
        """
        Parameters:
            verbose (bool): when True, the generated answers of every input
                are written to the module logger at DEBUG level.
        """
        # The list names the statistics read by this estimator; the second
        # argument is passed through to the base Estimator unchanged.
        super().__init__(
            [
                "semantic_matrix_entail",
                "semantic_matrix_contra",
                "sample_texts",
            ],
            "sequence",
        )
        self.verbose = verbose

    def __str__(self):
        return "NumSemSets"

    def find_connected_components(self, graph):
        """
        Finds the connected components of an undirected graph.

        Parameters:
            graph (List[List[int]]): adjacency list; graph[v] holds the
                neighbors of node v, nodes are numbered 0..len(graph)-1.
        Returns:
            List[List[int]]: one list of nodes per connected component. Nodes
                inside a component are listed in depth-first pre-order, and
                components are ordered by their smallest node.
        """
        visited = [False] * len(graph)
        components = []
        for start in range(len(graph)):
            if visited[start]:
                continue
            visited[start] = True
            component = [start]
            # Explicit stack of neighbor iterators instead of recursion, so
            # long chains of nodes cannot exceed Python's recursion limit.
            # The visiting order is the same as for a recursive DFS.
            stack = [iter(graph[start])]
            while stack:
                for neighbor in stack[-1]:
                    if not visited[neighbor]:
                        visited[neighbor] = True
                        component.append(neighbor)
                        stack.append(iter(graph[neighbor]))
                        break
                else:
                    # All neighbors of the node on top are processed.
                    stack.pop()
            components.append(component)
        return components

    def U_NumSemSets(self, i, stats):
        """
        Counts the semantic sets among the samples of the i-th input.

        Two samples are linked when entailment exceeds contradiction in both
        directions; the result is the number of connected components of the
        graph formed by these links.

        Parameters:
            i (int): index of the input along the first axis of the matrices.
            stats (Dict[str, np.ndarray]): must contain
                'semantic_matrix_entail' and 'semantic_matrix_contra', both
                indexable as [i, :, :].
        Returns:
            float: number of semantic sets (connected components).
        """
        # We have forward upper triangular and backward in lower triangular
        # parts of the semantic matrices
        W_entail = stats["semantic_matrix_entail"][i, :, :]
        W_contra = stats["semantic_matrix_contra"][i, :, :]

        # We check that for every pairing (both forward and backward)
        # the condition satisfies
        W = (W_entail > W_contra).astype(int)
        # Multiply by its transpose to get symmetric matrix of full condition
        W = W * np.transpose(W)
        # Take upper triangular part with no diag so every pair is seen once
        W = np.triu(W, k=1)

        # Create an adjacency list representation of the graph
        graph = [[] for _ in range(W.shape[0])]
        for src, row in enumerate(W):
            for dst in np.where(row != 0)[0].tolist():
                graph[src].append(dst)
                graph[dst].append(src)

        # Find the connected components
        connected_components = self.find_connected_components(graph)

        # Cast to float for consistency with other estimators
        return float(len(connected_components))

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Estimates the number of semantic sets for each sample in the input statistics.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics, which for multiple samples includes:
                * generated samples in 'sample_texts',
                * matrix with corresponding semantic similarities in
                  'semantic_matrix_entail' and 'semantic_matrix_contra'
        Returns:
            np.ndarray: number of semantic sets for each sample in input statistics.
                Higher values indicate more uncertain samples.
        """
        res = []
        for i, answers in enumerate(stats["sample_texts"]):
            if self.verbose:
                log.debug(f"generated answers: {answers}")
            res.append(self.U_NumSemSets(i, stats))
        return np.array(res)