Source code for lm_polygraph.stat_calculators.extract_claims

import re
import logging

from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass

from .stat_calculator import StatCalculator
from lm_polygraph.utils.openai_chat import OpenAIChat
from lm_polygraph.utils.model import WhiteboxModel
from .claim_level_prompts import CLAIM_EXTRACTION_PROMPTS, MATCHING_PROMPTS

from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

log = logging.getLogger("lm_polygraph")


@dataclass
class Claim:
    """A single claim extracted from one sentence of a model generation."""

    # Text of the extracted claim
    claim_text: str
    # The sentence of the generation, from which the claim was extracted
    sentence: str
    # Indices in the original generation of the tokens, which are related to the current claim
    aligned_token_ids: List[int]
class ClaimsExtractor(StatCalculator):
    """
    Extracts claims from the text of the model generation.
    """

    def __init__(
        self,
        openai_chat: OpenAIChat,
        sent_separators: str = ".?!。?!\n",
        language: str = "en",
        progress_bar: bool = False,
        extraction_prompts: Dict[str, str] = CLAIM_EXTRACTION_PROMPTS,
        matching_prompts: Dict[str, str] = MATCHING_PROMPTS,
        n_threads: int = 1,
    ):
        """
        Parameters:
            openai_chat (OpenAIChat): chat client used to extract claims and to
                match claim words back to the sentence.
            sent_separators (str): characters, each of which terminates a sentence.
            language (str): key used to select prompts from `extraction_prompts`
                and `matching_prompts`; "zh" additionally switches to
                character-level matching.
            progress_bar (bool): whether to show a tqdm progress bar.
            extraction_prompts (Dict[str, str]): per-language claim extraction
                prompts, formatted with `sent`.
            matching_prompts (Dict[str, str]): per-language matching prompts,
                formatted with `sent` and `claim`.
            n_threads (int): number of worker threads used to process sentences.
        """
        super().__init__()
        log.info(
            f"Initializing ClaimsExtractor with language={language}, n_threads={n_threads}"
        )
        self.language = language
        self.openai_chat = openai_chat
        self.sent_separators = sent_separators
        self.progress_bar = progress_bar
        self.extraction_prompts = extraction_prompts
        self.matching_prompts = matching_prompts
        self.n_threads = n_threads

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        """
        Returns the names of the statistics this calculator produces (first list)
        and the names of the statistics it depends on (second list).
        """
        return (
            [
                "claims",
                "claim_texts_concatenated",
                "claim_input_texts_concatenated",
            ],
            [
                "greedy_texts",
                "greedy_tokens",
            ],
        )

    def __call__(
        self,
        dependencies: Dict[str, object],
        texts: List[str],
        model: WhiteboxModel,
        *args,
        **kwargs,
    ) -> Dict[str, List]:
        """
        Extracts the claims out of each generation text.

        Parameters:
            dependencies (Dict[str, object]): input statistics, which includes:
                * 'greedy_texts': generation text for each input text;
                * 'greedy_tokens': generation tokens for each input text.
            texts (List[str]): Input texts batch used for model generation.
            model (Model): Model used for generation.
        Returns:
            Dict[str, List]: dictionary with :
                * 'claims' (List[List[lm_polygraph.stat_calculators.extract_claims.Claim]]):
                  list of claims for each input text;
                * 'claim_texts_concatenated' (List[str]): list of all textual claims extracted;
                * 'claim_input_texts_concatenated' (List[str]): for each claim in
                  claim_texts_concatenated, corresponding input text.
        """
        # Flatten the sentences of all generations, so that they can be
        # processed in one thread pool; n_sents remembers how many sentences
        # belong to each generation.
        all_sent_texts, all_sent_tokens, all_sent_positions, n_sents = [], [], [], []
        for greedy_text, greedy_tokens in zip(
            dependencies["greedy_texts"], dependencies["greedy_tokens"]
        ):
            sent_texts, sent_tokens, sent_positions = self.split_to_sentences(
                greedy_text, greedy_tokens, model.tokenizer
            )
            n_sents.append(len(sent_texts))
            all_sent_texts += sent_texts
            all_sent_tokens += sent_tokens
            all_sent_positions += sent_positions

        with ThreadPoolExecutor(max_workers=self.n_threads) as executor:
            claims_from_sent: List[List[Claim]] = list(
                tqdm(
                    executor.map(
                        self._claims_from_sentence,
                        all_sent_texts,
                        all_sent_tokens,
                        [model.tokenizer] * len(all_sent_texts),
                    ),
                    total=len(all_sent_texts),
                    desc="Extracting claims",
                    disable=not self.progress_bar,
                )
            )

        claims: List[List[Claim]] = []
        claim_texts_concatenated: List[str] = []
        claim_input_texts_concatenated: List[str] = []
        sent_idx = 0  # index into the flattened per-sentence lists
        for i, input_text in enumerate(texts):
            text_claims: List[Claim] = []
            for _ in range(n_sents[i]):
                sent_claims = claims_from_sent[sent_idx]
                sent_position = all_sent_positions[sent_idx]
                sent_idx += 1
                for claim in sent_claims:
                    # Token ids are relative to the sentence; shift them so that
                    # they index into the tokens of the whole generation.
                    claim.aligned_token_ids = [
                        token_id + sent_position
                        for token_id in claim.aligned_token_ids
                    ]
                text_claims += sent_claims
                # Store the claim strings (not the Claim objects), as promised
                # by the List[str] contract of this statistic.
                claim_texts_concatenated += [c.claim_text for c in sent_claims]
                claim_input_texts_concatenated += [input_text] * len(sent_claims)
            claims.append(text_claims)

        return {
            "claims": claims,
            "claim_texts_concatenated": claim_texts_concatenated,
            "claim_input_texts_concatenated": claim_input_texts_concatenated,
        }

    def split_to_sentences(
        self,
        text: str,
        tokens: List[int],
        tokenizer,
    ) -> Tuple[List[str], List[List[int]], List[int]]:
        """
        Splits the generation into sentences and locates each sentence in `tokens`.

        Parameters:
            text (str): generation text.
            tokens (List[int]): generation tokens.
            tokenizer: object with a `decode` method, used to decode token prefixes.
        Returns:
            Tuple[List[str], List[List[int]], List[int]]:
                * sentence texts (non-empty pieces between separators);
                * tokens of each sentence;
                * index in `tokens` of the first token of each sentence.
        """
        # Escape the separators: characters such as ']', '^', '-' or '\\' would
        # otherwise change the meaning of the character class.
        separator_pattern = "[" + re.escape(self.sent_separators) + "]"
        sentences = [s for s in re.split(separator_pattern, text) if len(s) > 0]
        if len(text) > 0 and text[-1] not in self.sent_separators:
            # Remove last unfinished sentence, because extracting claims
            # from unfinished sentence may lead to hallucinated claims.
            sentences = sentences[:-1]

        n_tokens = len(tokens)
        sent_start_token_idx, sent_end_token_idx = 0, 0
        sent_end_idx = 0
        all_sent_texts, all_sent_tokens, all_sent_positions = [], [], []
        for s in sentences:
            # Find sentence location in text: text[sent_start_idx:sent_end_idx].
            # Only separator characters lie between the end of the previous
            # sentence and the start of this one, so the first occurrence found
            # from the previous end is the right one even if the same text
            # also appears inside an earlier sentence.
            sent_start_idx = text.find(s, sent_end_idx)
            sent_end_idx = sent_start_idx + len(s)

            # Iteratively decode tokenized text until decoded sequence length is
            # greater or equal to the starting position of current sentence.
            # Find sentence location in tokens: tokens[sent_start_token_idx:sent_end_token_idx].
            # The indices are bounded by the number of tokens, so the loops
            # terminate even if the decoded text is shorter than `text`.
            while (
                sent_start_token_idx < n_tokens
                and len(tokenizer.decode(tokens[:sent_start_token_idx]))
                < sent_start_idx
            ):
                sent_start_token_idx += 1
            while (
                sent_end_token_idx < n_tokens
                and len(tokenizer.decode(tokens[:sent_end_token_idx])) < sent_end_idx
            ):
                sent_end_token_idx += 1

            all_sent_texts.append(s)
            all_sent_tokens.append(tokens[sent_start_token_idx:sent_end_token_idx])
            all_sent_positions.append(sent_start_token_idx)
        return all_sent_texts, all_sent_tokens, all_sent_positions

    def _claims_from_sentence(
        self,
        sent: str,
        sent_tokens: List[int],
        tokenizer,
    ) -> List[Claim]:
        """
        Extracts claims from a single sentence and aligns each claim to sentence tokens.

        Parameters:
            sent (str): sentence text.
            sent_tokens (List[int]): tokens of the sentence.
            tokenizer: object with a `decode` method, used for token alignment.
        Returns:
            List[Claim]: claims with `aligned_token_ids` relative to `sent_tokens`.
                Claims which could not be matched to the sentence are skipped.
        """
        # Extract claims with specific prompt
        extracted_claims = self.openai_chat.ask(
            self.extraction_prompts[self.language].format(sent=sent)
        )
        is_chinese = self.language == "zh"
        claims = []
        for claim_text in extracted_claims.split("\n"):
            # Bad claim_text example:
            # - There aren't any claims in this sentence.
            if not claim_text.startswith("- "):
                continue
            if "there aren't any claims" in claim_text.lower():
                continue
            # remove '- ' in the beginning
            claim_text = claim_text[2:].strip()

            # Get words which matches the claim using specific prompt
            # Example:
            # sent = 'Lanny Flaherty is an American actor born on December 18, 1949, in Pensacola, Florida.'
            # claim = 'Lanny Flaherty was born on December 18, 1949.'
            # GPT response: 'Lanny, Flaherty, born, on, December, 18, 1949'
            # match_words = ['Lanny', 'Flaherty', 'born', 'on', 'December', '18', '1949']
            chat_ask = self.matching_prompts[self.language].format(
                sent=sent,
                claim=claim_text,
            )
            response = self.openai_chat.ask(chat_ask)
            # comma has a different form in Chinese and space works better
            word_separator = " " if is_chinese else ","
            # Drop empty entries (repeated or trailing separators), which
            # cannot be matched to anything in the sentence.
            match_words = [
                word
                for word in (w.strip() for w in response.strip().split(word_separator))
                if word
            ]

            # Try to highlight matched symbols in sent
            if is_chinese:
                match_string = self._match_string_zh(sent, match_words)
            else:
                match_string = self._match_string(sent, match_words)
            if match_string is None:
                continue

            # Get token positions which intersect with highlighted regions,
            # that is, correspond to the claim
            aligned_token_ids = self._align(sent, match_string, sent_tokens, tokenizer)
            if len(aligned_token_ids) == 0:
                continue
            claims.append(
                Claim(
                    claim_text=claim_text,
                    sentence=sent,
                    aligned_token_ids=aligned_token_ids,
                )
            )
        return claims

    def _match_string(self, sent: str, match_words: List[str]) -> Optional[str]:
        """
        Greedily matching words from `match_words` to `sent`.

        Parameters:
            sent (str): sentence string
            match_words (List[str]): list of words from sent, in the same order they appear in it.
        Returns:
            Optional[str]: string of length len(sent), for each symbol in sent, '^' if it
                is covered by one of the match_words aligned to sent, ' ' otherwise.
                Returns None if matching failed, e.g. due to words in match_words, which
                are not present in sent, or if the words are specified not in the same
                order they appear in the sentence.
        Example:
            sent = 'Bob was born in 1949.'
            match_words = ['Bob', 'born', '1949']
            return '^^^     ^^^^    ^^^^ '
        """
        sent_len = len(sent)
        sent_pos = 0  # pointer to the sentence
        match_words_pos = 0  # pointer to the match_words list
        # Pieces of the match string, joined at the end
        marks: List[str] = []
        while sent_pos < sent_len:
            # A word may only start where the previous symbol is not a letter
            starts_on_boundary = sent_pos == 0 or not sent[sent_pos - 1].isalpha()
            if starts_on_boundary and match_words_pos < len(match_words):
                cur_match_word = match_words[match_words_pos]
                right_idx = sent_pos + len(cur_match_word)
                # ... and may only end where the next symbol is not a letter
                ends_on_boundary = (
                    right_idx >= sent_len or not sent[right_idx].isalpha()
                )
                if ends_on_boundary and sent.startswith(cur_match_word, sent_pos):
                    # Found match at sent[sent_pos]: highlight it and move on
                    # to the next word.
                    marks.append("^" * len(cur_match_word))
                    sent_pos += len(cur_match_word)
                    match_words_pos += 1
                    continue
            # No match at sent[sent_pos], continue with the next position
            marks.append(" ")
            sent_pos += 1

        if match_words_pos < len(match_words):
            # Didn't match all words to the sentence.
            # Possibly because the match words are in the wrong order or are not present in sentence.
            return None
        return "".join(marks)

    def _match_string_zh(self, sent: str, match_words: List[str]) -> Optional[str]:
        """
        Greedily matching characters from `match_words` to `sent` for Chinese.

        Parameters:
            sent (str): sentence string
            match_words (List[str]): words from sent, in the same order they appear in it.
                Empty words are ignored.
        Returns:
            Optional[str]: string of length len(sent) with '^' for matched characters
                and ' ' otherwise. Returns None if matching failed, e.g. due to
                characters in match_words, which are not present in sent, or if the
                characters are not in the same order they appear in the sentence.
        Example:
            sent = '爱因斯坦也是一位和平主义者。'
            match_words = ['爱因斯坦', '是', '和平', '主义者']
            return '^^^^ ^  ^^^^^ '
        """
        # An empty word has no character to compare with and would raise
        # IndexError below, so drop such words upfront.
        words = [word for word in match_words if word]
        char_pos = 0  # pointer to the character inside the current word
        word_pos = 0  # pointer to the words list
        marks: List[str] = []
        # Iterate through each character in the input Chinese text
        for char in sent:
            # Check if the current character matches the next expected character
            if word_pos < len(words) and char == words[word_pos][char_pos]:
                marks.append("^")
                char_pos += 1
                if char_pos == len(words[word_pos]):
                    # Current word is fully matched, switch to the next one
                    char_pos = 0
                    word_pos += 1
            else:
                marks.append(" ")

        # Check if all characters in match_words have been matched
        if word_pos < len(words):
            return None  # Didn't match all characters to the sentence
        return "".join(marks)

    def _align(
        self,
        sent: str,
        match_str: str,
        sent_tokens: List[int],
        tokenizer,
    ) -> List[int]:
        """
        Identifies token indices in `sent_tokens` that align with matching characters (marked by '^')
        in `match_str`. All tokens, which textual representations intersect with any of matching
        characters, are included. Partial intersections should be uncommon in practice.

        Args:
            sent: the original sentence.
            match_str: a string of the same length as `sent` where '^' characters indicate matches.
            sent_tokens: a list of token ids representing the tokenized version of `sent`.
            tokenizer: the tokenizer used to decode tokens.

        Returns:
            A list of integers representing the indices of tokens in `sent_tokens` that align with
            matching characters in `match_str`.
        """
        sent_pos = 0
        aligned_token_ids = []
        # Iteratively find position of each new token.
        for token_i, token in enumerate(sent_tokens):
            if sent_pos >= len(sent):
                break
            token_text = tokenizer.decode(token)
            if len(token_text) == 0:
                # Skip non-informative token
                continue
            # Find the nearest position of the token text at or after sent_pos
            token_pos = sent.find(token_text, sent_pos)
            if token_pos == -1:
                # The token text does not occur in the rest of the sentence,
                # so no further token can be aligned.
                break
            token_end = token_pos + len(token_text)
            # If the match string corresponding to the token contains matches, add to answer
            if "^" in match_str[token_pos:token_end]:
                aligned_token_ids.append(token_i)
            sent_pos = token_end
        return aligned_token_ids