Source code for lm_polygraph.estimators.semantic_entropy_token

import numpy as np
import logging
import os
import torch
from typing import Dict, List
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
from .estimator import Estimator

# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)


def split_classes(
    tokens: List[str],
    batch_size: int,
    semantic_bert_path: str = "sentence-transformers/bert-base-nli-mean-tokens",
    sim_threshold: float = 0.85,
) -> np.ndarray:
    """Greedily cluster token strings by cosine similarity of their embeddings.

    Every token is embedded with the encoder loaded from ``semantic_bert_path``
    (mean of the last hidden states over positions whose attention mask is 1).
    Tokens are then visited in order: a token joins the most similar existing
    class if that similarity is at least ``sim_threshold``; otherwise it opens
    a new class. A class is represented by the running mean (centroid) of the
    embeddings of its members.

    Args:
        tokens: Strings to cluster.
        batch_size: Number of tokens embedded per forward pass.
        semantic_bert_path: Name or path passed to ``AutoTokenizer`` and
            ``AutoModel`` ``from_pretrained``.
        sim_threshold: Minimum cosine similarity to the best-matching centroid
            required to join an existing class.

    Returns:
        Integer array with one entry per token; entry ``i`` is the class index
        of ``tokens[i]``. Class indices start at 0 and are assigned in order of
        first appearance.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    if len(tokens) == 0:
        # Nothing to cluster: avoid loading the encoder at all.
        return np.zeros(0, dtype=np.int64)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(semantic_bert_path)
    model = AutoModel.from_pretrained(semantic_bert_path).to(device)

    # Centroid matrix, one row per class. Created lazily so that its width is
    # taken from the actual embeddings rather than from a model-specific layer.
    classes_embeddings = None
    classes_sizes: List[int] = []
    sample_to_class: List[int] = []

    rng = tqdm(
        range(0, len(tokens), batch_size),
        total=(len(tokens) + batch_size - 1) // batch_size,
    )
    for i in rng:
        batch = tokenizer(tokens[i : i + batch_size], return_tensors="pt", padding=True)
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            state = model(**batch).last_hidden_state
        # Mean-pool each sequence over its non-padding positions only.
        embeddings = (
            torch.vstack(
                [
                    hidden[attn == 1].mean(0)
                    for attn, hidden in zip(batch["attention_mask"], state)
                ]
            )
            .cpu()
            .numpy()
        )
        if classes_embeddings is None:
            classes_embeddings = np.zeros(shape=(0, embeddings.shape[1]))

        for embedding in embeddings:
            best_class = -1
            best_sim = -np.inf
            if classes_sizes:
                sims = cosine_similarity([embedding], classes_embeddings)[0]
                best_class = int(sims.argmax())
                best_sim = sims[best_class]

            if best_class < 0 or best_sim < sim_threshold:
                # No sufficiently similar class: open a new one.
                classes_embeddings = np.append(
                    classes_embeddings, embedding.reshape(1, -1), axis=0
                )
                sample_to_class.append(len(classes_sizes))
                # The centroid already contains this first member, so the
                # class size starts at 1 (needed for a correct running mean).
                classes_sizes.append(1)
            else:
                # Fold the new member into the running mean of the class.
                size = classes_sizes[best_class]
                classes_embeddings[best_class] = (
                    classes_embeddings[best_class] * size + embedding
                ) / (size + 1)
                classes_sizes[best_class] = size + 1
                sample_to_class.append(best_class)

        rng.set_description(
            f"{min(i + batch_size, len(tokens))} tokens, {len(classes_sizes)} classes"
        )

    return np.array(sample_to_class, dtype=np.int64)
class SemanticEntropyToken(Estimator):
    """Token-level uncertainty computed over semantic classes of vocabulary tokens.

    On construction, every token id of the given tokenizer is decoded and the
    resulting strings are grouped into classes with ``split_classes``; the
    token-to-class mapping is cached on disk as a ``.npy`` file. At call time,
    the probability mass of each generation step is aggregated per class and
    an entropy-like score is computed over the class distribution.

    The base ``Estimator`` is initialised with ``["greedy_log_probs"]`` and
    ``"token"`` (presumably the required statistics and the estimation level —
    confirm against ``Estimator``).
    """

    def __init__(
        self,
        tokenizer_path: str,
        tokenizer_save_path: str,
        semantic_bert_path: str = "sentence-transformers/bert-base-nli-mean-tokens",
        batch_size: int = 10,
    ):
        """Load the cached token classes, or compute and cache them.

        Args:
            tokenizer_path: Name or path passed to ``AutoTokenizer.from_pretrained``.
            tokenizer_save_path: Directory holding the cached class file
                ``<tokenizer name>_classes.npy``; created if missing when a
                new file has to be written.
            semantic_bert_path: Encoder used by ``split_classes``.
            batch_size: Batch size used by ``split_classes``.
        """
        super().__init__(["greedy_log_probs"], "token")
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path, padding_side="left", add_bos_token=True
        )
        tokenizer_classes_path = os.path.join(
            tokenizer_save_path, tokenizer.name_or_path.split("/")[-1] + "_classes.npy"
        )

        # classes[i] is the semantic class index of token id i.
        self.classes: np.ndarray
        if os.path.exists(tokenizer_classes_path):
            log.info(f"Loading tokenizer classes from {tokenizer_classes_path}")
            self.classes = np.load(tokenizer_classes_path)
        else:
            # Decode the vocabulary only when the classes must be recomputed.
            tokens = [tokenizer.decode([i]) for i in range(len(tokenizer))]
            self.classes = split_classes(tokens, batch_size, semantic_bert_path)
            log.info(f"Saving tokenizer classes at {tokenizer_classes_path}")
            save_dir = os.path.dirname(tokenizer_classes_path)
            if save_dir:
                # np.save does not create missing parent directories.
                os.makedirs(save_dir, exist_ok=True)
            np.save(tokenizer_classes_path, self.classes)

    def __str__(self):
        return "SemanticEntropyToken"

    def __call__(self, stats: Dict[str, np.ndarray]) -> List[List[float]]:
        """Compute the class-level score for every generation step.

        Args:
            stats: Must contain ``"greedy_log_probs"``; assumed to hold, for
                each sequence, one vector of log-probabilities over the
                vocabulary per generation step (TODO confirm with producer).

        Returns:
            One list per sequence with one score per step, excluding the last
            step of each sequence. Each score is ``-mean(p_c * log(p_c))`` over
            all classes, where ``p_c`` is the summed probability of class ``c``
            and classes with zero mass contribute 0.
        """
        logprobs = stats["greedy_log_probs"]
        # Fixed number of classes so the mean always uses the same denominator.
        n_classes = int(self.classes.max()) + 1 if len(self.classes) else 0

        sem_ent: List[List[float]] = []
        for s_lp in logprobs:
            sample_ent: List[float] = []
            # The final step is skipped (presumably the end-of-sequence step).
            for lp in s_lp[:-1]:
                p = np.exp(lp)[: len(self.classes)]
                # Keep ids and weights the same length even if the model emits
                # fewer entries than the tokenizer has ids.
                classes = self.classes[: len(p)]
                class_probs = np.bincount(classes, weights=p, minlength=n_classes)
                # Use the convention 0 * log(0) = 0 instead of producing NaN.
                plogp = np.zeros_like(class_probs)
                nonzero = class_probs > 0
                plogp[nonzero] = class_probs[nonzero] * np.log(class_probs[nonzero])
                sample_ent.append(-np.mean(plogp))
            sem_ent.append(sample_ent)
        return sem_ent
if __name__ == "__main__":
    # Manual smoke test: cluster a few negative and positive adjectives in a
    # single batch and print the resulting class index of each word.
    demo_words = [
        "bad",
        "awful",
        "terrible",
        "horrible",
        "good",
        "fine",
        "excellent",
        "great",
    ]
    demo_classes = split_classes(demo_words, 100)
    print(demo_classes)