import numpy as np
import logging
import os
import torch
from typing import Dict, List
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
from .estimator import Estimator
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
def split_classes(
    tokens: List[str],
    batch_size: int,
    semantic_bert_path: str = "sentence-transformers/bert-base-nli-mean-tokens",
    sim_threshold: float = 0.85,
) -> np.ndarray:
    """Greedily cluster tokens into semantic classes by embedding similarity.

    Every token is embedded with the encoder loaded from
    ``semantic_bert_path``; the embedding is the mean of the last hidden
    states over the positions where the attention mask is 1.  Tokens are then
    processed in input order: a token joins the existing class whose centroid
    has the highest cosine similarity, unless that similarity is below
    ``sim_threshold``, in which case it starts a new class.  A class centroid
    is the running mean of the embeddings of its members.

    Args:
        tokens: Strings to cluster.
        batch_size: Number of tokens embedded per forward pass.
        semantic_bert_path: Name or path passed to
            ``AutoTokenizer.from_pretrained`` / ``AutoModel.from_pretrained``.
        sim_threshold: Minimum cosine similarity to the closest centroid for
            a token to join that class.

    Returns:
        Integer array of length ``len(tokens)``; entry ``k`` is the class id
        of ``tokens[k]``.  Class ids are consecutive integers starting at 0,
        numbered in order of creation.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(semantic_bert_path)
    model = AutoModel.from_pretrained(semantic_bert_path).to(device)

    # Read the embedding width from the model config rather than from
    # ``model.pooler``, so encoders without a BERT-style pooler also work.
    classes_embeddings: np.ndarray = np.zeros(shape=(0, model.config.hidden_size))
    classes_sizes: List[int] = []
    sample_to_class: List[int] = []

    progress = tqdm(
        range(0, len(tokens), batch_size),
        total=(len(tokens) + batch_size - 1) // batch_size,
    )
    for i in progress:
        batch = tokenizer(
            tokens[i : i + batch_size], return_tensors="pt", padding=True
        )
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            state = model(**batch).last_hidden_state

        # Mean-pool each sequence over its non-padding positions only.
        embeddings = (
            torch.vstack(
                [
                    hidden[mask == 1].mean(0)
                    for mask, hidden in zip(batch["attention_mask"], state)
                ]
            )
            .cpu()
            .numpy()
        )

        for embedding in embeddings:
            sims = (
                cosine_similarity([embedding], classes_embeddings)[0]
                if classes_sizes
                else None
            )
            if sims is None or sims.max() < sim_threshold:
                # No class is close enough: open a new one seeded with this
                # embedding.
                classes_embeddings = np.append(
                    classes_embeddings, embedding.reshape(1, -1), axis=0
                )
                sample_to_class.append(len(classes_sizes))
                # The new class already has one member; the running-mean
                # update below relies on this count being accurate.
                classes_sizes.append(1)
            else:
                cl = int(sims.argmax())
                # Incremental mean: fold the new member into the centroid.
                classes_embeddings[cl] = (
                    classes_embeddings[cl] * classes_sizes[cl] + embedding
                ) / (classes_sizes[cl] + 1)
                classes_sizes[cl] += 1
                sample_to_class.append(cl)

        progress.set_description(
            f"{min(i + batch_size, len(tokens))} tokens, {len(classes_sizes)} classes"
        )

    # Explicit integer dtype so that an empty input still yields an array
    # usable as class indices (np.array([]) would default to float64).
    return np.array(sample_to_class, dtype=np.int64)
class SemanticEntropyToken(Estimator):
    """Token-level estimator scoring each step by entropy over semantic classes.

    The tokenizer vocabulary is partitioned into classes with
    ``split_classes``.  For every generation step the token probabilities are
    summed per class, and the score is ``-mean(p_c * log(p_c))`` taken over
    all classes.

    The class assignment is cached as ``<tokenizer name>_classes.npy`` inside
    ``tokenizer_save_path``.  NOTE: the cache file name depends only on the
    tokenizer name, so an existing cache is reused even if
    ``semantic_bert_path`` changes.
    """

    def __init__(
        self,
        tokenizer_path: str,
        tokenizer_save_path: str,
        semantic_bert_path: str = "sentence-transformers/bert-base-nli-mean-tokens",
        batch_size: int = 10,
    ):
        """Load the cached token classes, or compute and cache them.

        Args:
            tokenizer_path: Name or path of the tokenizer whose vocabulary is
                clustered.
            tokenizer_save_path: Directory holding the cached class file; it
                is created if it does not exist.
            semantic_bert_path: Encoder used by ``split_classes`` to embed the
                vocabulary tokens.
            batch_size: Batch size used by ``split_classes``.
        """
        super().__init__(["greedy_log_probs"], "token")
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path, padding_side="left", add_bos_token=True
        )
        # Strip trailing slashes so a path like "models/foo/" is cached as
        # "foo_classes.npy" instead of the collision-prone "_classes.npy".
        tokenizer_name = tokenizer.name_or_path.rstrip("/").split("/")[-1]
        tokenizer_classes_path = os.path.join(
            tokenizer_save_path, tokenizer_name + "_classes.npy"
        )

        self.classes: np.ndarray
        if os.path.exists(tokenizer_classes_path):
            log.info(f"Loading tokenizer classes from {tokenizer_classes_path}")
            self.classes = np.load(tokenizer_classes_path)
        else:
            # Create the target directory up front so that a bad path fails
            # before the expensive clustering rather than after it.
            if tokenizer_save_path:
                os.makedirs(tokenizer_save_path, exist_ok=True)
            # Decode the vocabulary only when it is actually needed.
            tokens = [tokenizer.decode([i]) for i in range(len(tokenizer))]
            self.classes = split_classes(tokens, batch_size, semantic_bert_path)
            log.info(f"Saving tokenizer classes at {tokenizer_classes_path}")
            np.save(tokenizer_classes_path, self.classes)

    def __str__(self):
        """Return the estimator's display name."""
        return "SemanticEntropyToken"

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """Compute the class-level entropy score for every generation step.

        Args:
            stats: Must contain ``"greedy_log_probs"``.  Presumably one entry
                per sample, each a sequence of per-step log-probability
                vectors over the vocabulary -- TODO confirm against the
                caller.

        Returns:
            A nested list (despite the annotation): one inner list per
            sample, holding one score per step.  The last step of every
            sample is skipped.
        """
        logprobs = stats["greedy_log_probs"]
        # Fixed number of classes, so every step is averaged over the same
        # denominator even if its probability vector is shorter than the
        # class map.
        num_classes = int(self.classes.max()) + 1 if self.classes.size else 0

        sem_ent: List[List[float]] = []
        for s_lp in logprobs:
            sample_scores: List[float] = []
            for lp in s_lp[:-1]:
                # Align the probability vector and the class map on their
                # common prefix; the two can differ in length when the model
                # head and the tokenizer disagree on vocabulary size.
                p = np.exp(lp)[: len(self.classes)]
                class_probs = np.bincount(
                    self.classes[: len(p)], weights=p, minlength=num_classes
                )
                # Use the convention 0 * log(0) = 0: classes with no
                # probability mass contribute nothing instead of NaN.
                nonzero = class_probs != 0
                plogp = np.zeros_like(class_probs)
                plogp[nonzero] = class_probs[nonzero] * np.log(class_probs[nonzero])
                sample_scores.append(-np.mean(plogp))
            sem_ent.append(sample_scores)
        return sem_ent
if __name__ == "__main__":
    # Demo: print the class id assigned to each of a handful of sentiment
    # words. The batch size exceeds the word count, so all words are
    # embedded in a single batch.
    demo_tokens = [
        "bad",
        "awful",
        "terrible",
        "horrible",
        "good",
        "fine",
        "excellent",
        "great",
    ]
    demo_batch_size = 100
    demo_classes = split_classes(demo_tokens, demo_batch_size)
    print(demo_classes)