import numpy as np
from typing import Dict
from .estimator import Estimator
[docs]class ClaimConditionedProbability(Estimator):
def __init__(self):
super().__init__(
[
"greedy_tokens",
"greedy_tokens_alternatives",
"greedy_tokens_alternatives_nli",
],
"sequence",
)
def __str__(self):
return "CCP"
def _reduce(self, logprobs: list[float]):
return np.exp(np.sum(logprobs))
def _combine_nli(self, forward: str, backward: str):
"""
Combines two NLI predictions NLI(x, y) and NLI(y, x) into a single prediction.
Prioritizes "entail" or "contra" if present, otherwise returns "neutral".
"""
if forward == backward:
return forward
if all(x in [forward, backward] for x in ["entail", "contra"]):
return "neutral"
for x in ["entail", "contra"]:
if x in [forward, backward]:
return x
return "neutral"
def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
words = stats["greedy_tokens"]
alternatives = stats["greedy_tokens_alternatives"]
alternatives_nli = stats["greedy_tokens_alternatives_nli"]
prob_nli = []
for sample_words, sample_alternatives, sample_alternatives_nli in zip(
words,
alternatives,
alternatives_nli,
):
sample_mnlis = []
for word, word_alternatives, word_alternatives_nli in zip(
sample_words,
sample_alternatives,
sample_alternatives_nli,
):
entail_logprobs, entail_words = [], []
contra_logprobs, contra_words = [], []
for i in range(len(word_alternatives)):
word_alt, logprob = word_alternatives[i]
nli_outcome = self._combine_nli(
word_alternatives_nli[0][i],
word_alternatives_nli[i][0],
)
if i == 0 or nli_outcome == "entail":
entail_logprobs.append(logprob)
entail_words.append(word_alt)
elif nli_outcome == "contra":
contra_logprobs.append(logprob)
contra_words.append(word_alt)
entail_logprob = np.logaddexp.reduce(entail_logprobs)
total_logprob = np.logaddexp.reduce(entail_logprobs + contra_logprobs)
sample_mnlis.append(entail_logprob - total_logprob)
prob_nli.append(self._reduce(sample_mnlis))
return -np.array(prob_nli)