# Source code for lm_polygraph.stat_calculators.extract_claims
import re
import logging
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from .stat_calculator import StatCalculator
from lm_polygraph.utils.openai_chat import OpenAIChat
from lm_polygraph.utils.model import WhiteboxModel
from .claim_level_prompts import CLAIM_EXTRACTION_PROMPTS, MATCHING_PROMPTS
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
# Logger named after the package ("lm_polygraph") rather than this module's __name__.
log = logging.getLogger(name="lm_polygraph")
@dataclass
class Claim:
    """
    A single claim extracted from a model generation.

    Attributes:
        claim_text: textual statement of the claim.
        sentence: the sentence of the generation from which the claim was extracted.
        aligned_token_ids: indices, in the original generation, of the tokens that
            are related to this claim.
    """

    claim_text: str
    sentence: str
    aligned_token_ids: List[int]
class ClaimsExtractor(StatCalculator):
    """
    Extracts claims from the text of the model generation.

    Every greedy generation is split into complete sentences. Each sentence is sent to
    ``openai_chat`` with a language-specific extraction prompt. Every returned claim is
    then matched back to words of the sentence with a second (matching) prompt, and the
    generation tokens covering those words are stored in the resulting ``Claim``.
    """

    def __init__(
        self,
        openai_chat: OpenAIChat,
        sent_separators: str = ".?!。?!\n",
        language: str = "en",
        progress_bar: bool = False,
        extraction_prompts: Dict[str, str] = CLAIM_EXTRACTION_PROMPTS,
        matching_prompts: Dict[str, str] = MATCHING_PROMPTS,
        n_threads: int = 1,
    ):
        """
        Parameters:
            openai_chat (OpenAIChat): chat client whose ``ask`` method is used both for
                claim extraction and for matching claim words to the sentence.
            sent_separators (str): characters that end a sentence.
            language (str): key into ``extraction_prompts`` and ``matching_prompts``.
                "zh" additionally switches to space-separated match words and
                character-level matching.
            progress_bar (bool): whether to show a tqdm progress bar while extracting.
            extraction_prompts (Dict[str, str]): per-language templates formatted with
                ``sent``.
            matching_prompts (Dict[str, str]): per-language templates formatted with
                ``sent`` and ``claim``.
            n_threads (int): number of worker threads issuing chat requests.
        """
        super().__init__()
        log.info(
            f"Initializing ClaimsExtractor with language={language}, n_threads={n_threads}"
        )
        self.language = language
        self.openai_chat = openai_chat
        self.sent_separators = sent_separators
        self.progress_bar = progress_bar
        self.extraction_prompts = extraction_prompts
        self.matching_prompts = matching_prompts
        self.n_threads = n_threads

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        """
        Returns:
            Tuple of (names of statistics produced by ``__call__``,
            names of statistics it reads from ``dependencies``).
        """
        return (
            [
                "claims",
                "claim_texts_concatenated",
                "claim_input_texts_concatenated",
            ],
            [
                "greedy_texts",
                "greedy_tokens",
            ],
        )

    def __call__(
        self,
        dependencies: Dict[str, object],
        texts: List[str],
        model: WhiteboxModel,
        *args,
        **kwargs,
    ) -> Dict[str, List]:
        """
        Extracts the claims out of each generation text.

        Parameters:
            dependencies (Dict[str, object]): input statistics, which includes:
                * 'greedy_texts' (List[str]): generated text for each input text;
                * 'greedy_tokens' (List[List[int]]): token ids of each generated text.
            texts (List[str]): Input texts batch used for model generation.
            model (Model): Model used for generation; only its ``tokenizer`` is used.
        Returns:
            Dict[str, List]: dictionary with:
                * 'claims' (List[List[lm_polygraph.stat_calculators.extract_claims.Claim]]):
                  list of claims for each input text;
                * 'claim_texts_concatenated' (List[str]): list of all textual claims extracted;
                * 'claim_input_texts_concatenated' (List[str]): for each claim in
                  claim_texts_concatenated, corresponding input text.
        """
        all_sent_texts, all_sent_tokens, all_sent_positions, n_sents = [], [], [], []
        for greedy_text, greedy_tokens in zip(
            dependencies["greedy_texts"], dependencies["greedy_tokens"]
        ):
            sent_texts, sent_tokens, sent_positions = self.split_to_sentences(
                greedy_text, greedy_tokens, model.tokenizer
            )
            n_sents.append(len(sent_texts))
            all_sent_texts += sent_texts
            all_sent_tokens += sent_tokens
            all_sent_positions += sent_positions

        # Sentences of the whole batch are processed together so that chat requests
        # can run concurrently. Executor.map yields results in input order.
        with ThreadPoolExecutor(max_workers=self.n_threads) as executor:
            claims_from_sent: List[List[Claim]] = list(
                tqdm(
                    executor.map(
                        self._claims_from_sentence,
                        all_sent_texts,
                        all_sent_tokens,
                        [model.tokenizer] * len(all_sent_texts),
                    ),
                    total=len(all_sent_texts),
                    desc="Extracting claims",
                    disable=not self.progress_bar,
                )
            )

        claims: List[List[Claim]] = []
        claim_texts_concatenated: List[str] = []
        claim_input_texts_concatenated: List[str] = []
        # Index of the first sentence of the current text in the flat per-sentence lists.
        offset = 0
        for i, input_text in enumerate(texts):
            text_claims: List[Claim] = []
            for sent_idx in range(offset, offset + n_sents[i]):
                sent_position = all_sent_positions[sent_idx]
                for claim in claims_from_sent[sent_idx]:
                    # Token ids from _claims_from_sentence are relative to the
                    # sentence; shift them to index into the whole generation.
                    claim.aligned_token_ids = [
                        token_idx + sent_position
                        for token_idx in claim.aligned_token_ids
                    ]
                    text_claims.append(claim)
            offset += n_sents[i]
            claims.append(text_claims)
            claim_texts_concatenated += [claim.claim_text for claim in text_claims]
            claim_input_texts_concatenated += [input_text] * len(text_claims)
        return {
            "claims": claims,
            "claim_texts_concatenated": claim_texts_concatenated,
            "claim_input_texts_concatenated": claim_input_texts_concatenated,
        }

    def split_to_sentences(
        self,
        text: str,
        tokens: List[int],
        tokenizer,
    ) -> Tuple[List[str], List[List[int]], List[int]]:
        """
        Splits a generation into complete sentences and locates each one in `tokens`.

        Parameters:
            text (str): generated text.
            tokens (List[int]): token ids of `text`.
            tokenizer: tokenizer whose ``decode`` maps a list of token ids to text.
        Returns:
            Three parallel lists:
            - the sentence texts;
            - the slice of `tokens` covering each sentence;
            - the index in `tokens` where each sentence's slice starts.
        """
        # Each maximal run of non-separator characters is a sentence. finditer yields
        # its exact character span, so repeated sentences are located correctly.
        # re.escape keeps separators such as ']', '^', '-' or '\\' literal.
        sentence_pattern = f"[^{re.escape(self.sent_separators)}]+"
        spans = [
            (m.group(), m.start(), m.end()) for m in re.finditer(sentence_pattern, text)
        ]
        if len(text) > 0 and text[-1] not in self.sent_separators:
            # Remove last unfinished sentence, because extracting claims
            # from unfinished sentence may lead to hallucinated claims.
            spans = spans[:-1]
        # Whitespace-only fragments (e.g. between "." and "\n") cannot contain claims;
        # skip them to avoid pointless chat requests.
        spans = [span for span in spans if not span[0].isspace()]

        n_tokens = len(tokens)
        sent_start_token_idx, sent_end_token_idx = 0, 0
        all_sent_texts, all_sent_tokens, all_sent_positions = [], [], []
        for sent, sent_start_idx, sent_end_idx in spans:
            # Decode growing token prefixes until the decoded length reaches the
            # sentence's start/end character offset. The loops are bounded by
            # len(tokens): if decoding never reaches the offset (decoded text differs
            # from `text`), an unbounded loop would never terminate.
            while (
                sent_start_token_idx < n_tokens
                and len(tokenizer.decode(tokens[:sent_start_token_idx])) < sent_start_idx
            ):
                sent_start_token_idx += 1
            while (
                sent_end_token_idx < n_tokens
                and len(tokenizer.decode(tokens[:sent_end_token_idx])) < sent_end_idx
            ):
                sent_end_token_idx += 1
            all_sent_texts.append(sent)
            all_sent_tokens.append(tokens[sent_start_token_idx:sent_end_token_idx])
            all_sent_positions.append(sent_start_token_idx)
        return all_sent_texts, all_sent_tokens, all_sent_positions

    def _claims_from_sentence(
        self,
        sent: str,
        sent_tokens: List[int],
        tokenizer,
    ) -> List[Claim]:
        """
        Extracts claims from one sentence and aligns each claim to `sent_tokens`.

        Returns claims whose ``aligned_token_ids`` index into `sent_tokens`. Claims whose
        match words cannot be located in the sentence, or that align to no token,
        are dropped.
        """
        # Extract claims with specific prompt
        extracted_claims = self.openai_chat.ask(
            self.extraction_prompts[self.language].format(sent=sent)
        )
        claims = []
        for claim_text in extracted_claims.split("\n"):
            # Bad claim_text example:
            # - There aren't any claims in this sentence.
            if not claim_text.startswith("- "):
                continue
            if "there aren't any claims" in claim_text.lower():
                continue
            # remove '- ' in the beginning
            claim_text = claim_text[2:].strip()
            # Get words which match the claim using a specific prompt.
            # Example:
            # sent = 'Lanny Flaherty is an American actor born on December 18, 1949, in Pensacola, Florida.'
            # claim = 'Lanny Flaherty was born on December 18, 1949.'
            # GPT response: 'Lanny, Flaherty, born, on, December, 18, 1949'
            # match_words = ['Lanny', 'Flaherty', 'born', 'on', 'December', '18', '1949']
            chat_ask = self.matching_prompts[self.language].format(
                sent=sent,
                claim=claim_text,
            )
            response = self.openai_chat.ask(chat_ask)
            # comma has a different form in Chinese and space works better
            word_separator = " " if self.language == "zh" else ","
            # Drop empty items (trailing separator, doubled spaces, ...).
            # _match_string_zh raises IndexError on an empty word, and _match_string
            # can only place one between two non-letters, which may break the match.
            match_words = [
                word.strip()
                for word in response.strip().split(word_separator)
                if word.strip()
            ]
            # Try to highlight matched symbols in sent
            if self.language == "zh":
                match_string = self._match_string_zh(sent, match_words)
            else:
                match_string = self._match_string(sent, match_words)
            if match_string is None:
                continue
            # Get token positions which intersect with highlighted regions, that is,
            # correspond to the claim
            aligned_token_ids = self._align(sent, match_string, sent_tokens, tokenizer)
            if len(aligned_token_ids) == 0:
                continue
            claims.append(
                Claim(
                    claim_text=claim_text,
                    sentence=sent,
                    aligned_token_ids=aligned_token_ids,
                )
            )
        return claims

    def _match_string(self, sent: str, match_words: List[str]) -> Optional[str]:
        """
        Greedily matches words from `match_words` to `sent`.

        A word matches at a position only if it is not glued to letters on either side,
        i.e. the characters just before and just after it are not alphabetic.

        Parameters:
            sent (str): sentence string
            match_words (List[str]): list of words from sent, in the same order they
                appear in it.
        Returns:
            Optional[str]: string of length len(sent): for each symbol of sent, '^' if it
            belongs to one of the matched words, ' ' otherwise.
            Returns None if matching failed, e.g. because some of match_words are not
            present in sent, or are not listed in the order they appear in sent.
        Example:
            sent = 'Lanny Flaherty is an American actor born on December 18, 1949, in Pensacola, Florida.'
            match_words = ['Lanny', 'Flaherty', 'born', 'on', 'December', '18', '1949']
            returns '^^^^^ ^^^^^^^^' + ' ' * 22 + '^^^^ ^^ ^^^^^^^^ ^^' + ' ' * 2 + '^^^^' + ' ' * 24
        """
        sent_pos = 0  # pointer to the sentence
        match_words_pos = 0  # pointer to the match_words list
        match_chars: List[str] = []
        while sent_pos < len(sent):
            if match_words_pos < len(match_words):
                word = match_words[match_words_pos]
                word_end = sent_pos + len(word)
                starts_word = sent_pos == 0 or not sent[sent_pos - 1].isalpha()
                ends_word = word_end >= len(sent) or not sent[word_end].isalpha()
                if starts_word and ends_word and sent.startswith(word, sent_pos):
                    # Found the current word at sent[sent_pos]; highlight it.
                    match_chars.append("^" * len(word))
                    sent_pos = word_end
                    match_words_pos += 1
                    continue
            # No match at sent[sent_pos], continue with the next position
            match_chars.append(" ")
            sent_pos += 1
        if match_words_pos < len(match_words):
            # Didn't match all words to the sentence.
            # Possibly because the match words are in the wrong order or are not
            # present in the sentence.
            return None
        return "".join(match_chars)

    def _match_string_zh(self, sent: str, match_words: List[str]) -> Optional[str]:
        """
        Greedily matches characters of `match_words` to `sent` (Chinese text).

        Characters of the current word are matched one at a time, left to right. Sentence
        characters that do not match the next expected character are left unhighlighted.

        Returns:
            Optional[str]: string of length len(sent) with '^' at matched characters and
            ' ' elsewhere. Returns None if matching failed, e.g. because characters of
            match_words are not present in sent, or are not in the same order they
            appear in the sentence.
        Example:
            sent = '爱因斯坦也是一位和平主义者。'
            match_words = ['爱因斯坦', '是', '和平', '主义者']
            returns '^^^^ ^  ^^^^^ '
        """
        # Empty words cannot be matched character by character (indexing into them
        # would raise IndexError), so ignore them.
        words = [word for word in match_words if word]
        word_idx = 0  # index of the word currently being matched
        char_idx = 0  # index of the next character to match inside that word
        match_chars: List[str] = []
        for char in sent:
            if word_idx < len(words) and char == words[word_idx][char_idx]:
                match_chars.append("^")
                char_idx += 1
                if char_idx == len(words[word_idx]):
                    word_idx += 1
                    char_idx = 0
            else:
                match_chars.append(" ")
        if word_idx < len(words):
            return None  # Didn't match all characters to the sentence
        return "".join(match_chars)

    def _align(
        self,
        sent: str,
        match_str: str,
        sent_tokens: List[int],
        tokenizer,
    ) -> List[int]:
        """
        Identifies token indices in `sent_tokens` that align with matching characters
        (marked by '^') in `match_str`.

        Each token is decoded on its own and located in `sent` at or after the end of
        the previously located token. Tokens whose text overlaps any '^' are included;
        partial intersections should be uncommon in practice. Tokens that decode to an
        empty string, or whose text cannot be found in the rest of the sentence (for
        example, a byte-level token that is not a complete character), are skipped.

        Parameters:
            sent: the original sentence.
            match_str: a string of the same length as `sent` where '^' characters
                indicate matches.
            sent_tokens: a list of token ids representing the tokenized version of `sent`.
            tokenizer: the tokenizer used to decode tokens.
        Returns:
            A list of integers representing the indices of tokens in `sent_tokens` that
            align with matching characters in `match_str`.
        """
        sent_pos = 0
        aligned_token_ids = []
        for token_i, token_id in enumerate(sent_tokens):
            if sent_pos >= len(sent):
                break
            token_text = tokenizer.decode(token_id)
            if len(token_text) == 0:
                # Skip non-informative token
                continue
            token_start = sent.find(token_text, sent_pos)
            if token_start == -1:
                # Not locatable; skip only this token so later tokens still align.
                continue
            token_end = token_start + len(token_text)
            # If the match string corresponding to the token contains matches,
            # add it to the answer.
            if "^" in match_str[token_start:token_end]:
                aligned_token_ids.append(token_i)
            sent_pos = token_end
        return aligned_token_ids