import numpy as np
import logging
from typing import Dict
from .estimator import Estimator
import torch.nn as nn
# Softmax over dimension 1 of its input. NOTE(review): nothing in the code shown
# alongside this module references it — confirm whether it is still needed.
softmax = nn.Softmax(dim=1)
# Module-level logger named after this module; the estimator's verbose mode
# writes its debug output here.
log = logging.getLogger(name=__name__)
class NumSemSets(Estimator):
    """
    Estimates the sequence-level uncertainty of a language model following the method of
    "Number of Semantic Sets" as provided in the paper https://arxiv.org/abs/2305.19187.
    Works with both whitebox and blackbox models (initialized using
    lm_polygraph.utils.model.BlackboxModel/WhiteboxModel).

    For every input, the sampled answers are treated as nodes of a graph. Two
    answers are linked when the entailment score exceeds the contradiction
    score in both directions. The estimate is the number of connected
    components of that graph.
    """

    def __init__(
        self,
        verbose: bool = False,
    ):
        """
        Parameters:
            verbose (bool): if True, the generated answers of every input are
                logged at DEBUG level during estimation. Default: False.
        """
        super().__init__(
            [
                "semantic_matrix_entail",
                "semantic_matrix_contra",
                "sample_texts",
            ],
            "sequence",
        )
        self.verbose = verbose

    def __str__(self):
        return "NumSemSets"

    def find_connected_components(self, graph):
        """
        Finds the connected components of a graph given as an adjacency list.

        An explicit stack is used instead of recursion, so large graphs cannot
        exceed Python's recursion limit. Nodes are still visited in the same
        depth-first preorder a recursive traversal would produce.

        Parameters:
            graph (list of list of int): ``graph[v]`` lists the neighbours of
                node ``v``; nodes are the integers ``0 .. len(graph) - 1``.
        Returns:
            list of list of int: one list of node indices per component, in
                discovery order.
        """
        visited = [False] * len(graph)
        components = []
        for start in range(len(graph)):
            if visited[start]:
                continue
            visited[start] = True
            component = [start]
            # Each entry is the partially consumed neighbour iterator of a node
            # on the current DFS path, playing the role of a recursive frame.
            stack = [iter(graph[start])]
            while stack:
                for neighbor in stack[-1]:
                    if not visited[neighbor]:
                        visited[neighbor] = True
                        component.append(neighbor)
                        stack.append(iter(graph[neighbor]))
                        break
                else:
                    # Every neighbour of the top node has been explored: backtrack.
                    stack.pop()
            components.append(component)
        return components

    def U_NumSemSets(self, i, stats):
        """
        Computes the number of semantic sets for the ``i``-th input.

        Parameters:
            i (int): index of the input within ``stats``.
            stats (Dict[str, np.ndarray]): must provide "semantic_matrix_entail"
                and "semantic_matrix_contra"; each is indexed as ``[i, :, :]`` to
                obtain a square matrix over that input's sampled answers.
        Returns:
            float: the number of connected components (semantic sets).
        """
        # We have forward upper triangular and backward in lower triangular
        # parts of the semantic matrices
        W_entail = stats["semantic_matrix_entail"][i, :, :]
        W_contra = stats["semantic_matrix_contra"][i, :, :]
        # entails[a, b] is True when entailment beats contradiction for (a, b).
        entails = W_entail > W_contra
        # Two answers are equivalent only when the condition holds in both
        # directions, which yields a symmetric relation.
        equivalent = np.logical_and(entails, np.transpose(entails))
        # Undirected adjacency list of the equivalence graph. Self-loops on the
        # diagonal do not affect the number of components.
        graph = [np.flatnonzero(row).tolist() for row in equivalent]
        connected_components = self.find_connected_components(graph)
        # Cast to float for consistency with other estimators
        return float(len(connected_components))

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Estimates the number of semantic sets for each sample in the input statistics.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics, which for multiple samples includes:
                * generated samples in 'sample_texts',
                * matrix with corresponding semantic similarities in
                  'semantic_matrix_entail' and 'semantic_matrix_contra'
        Returns:
            np.ndarray: number of semantic sets for each sample in input statistics.
                Higher values indicate more uncertain samples.
        """
        res = []
        for i, answers in enumerate(stats["sample_texts"]):
            if self.verbose:
                # Lazy %-formatting: the message is only built if DEBUG is enabled.
                log.debug("generated answers: %s", answers)
            res.append(self.U_NumSemSets(i, stats))
        return np.array(res)