Source code for lm_polygraph.generation_metrics.openai_fact_check

import os
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional

import numpy as np
from tqdm import tqdm

from lm_polygraph.utils.openai_chat import OpenAIChat
from .generation_metric import GenerationMetric
from lm_polygraph.stat_calculators.claim_level_prompts import *


class OpenAIFactCheck(GenerationMetric):
    """
    Calculates for each claim whether it is true or not, using an OpenAI-compatible
    chat model accessed through lm_polygraph.utils.openai_chat.OpenAIChat.

    Each claim is checked with two chat requests: the fact-check prompt for the
    configured language, then the summarize prompt, which also receives the first
    reply. The summarized reply is parsed into a label: 0 if the claim is judged
    true, 1 if judged false, np.nan if no verdict marker is found.
    """

    # Substrings that mark a "true" / "false" verdict in the summarized reply
    # (English, Chinese and Arabic words). Quoted forms such as '"True"' need no
    # separate entry: substring matching on "True" already covers them.
    _TRUE_MARKERS = ("True", "是", "真", "نعم")
    _FALSE_MARKERS = ("False", "否", "假", "لا")

    def __init__(
        self,
        llm_url: Optional[str] = None,
        openai_model: str = "gpt-4o",
        cache_path: str = os.path.expanduser("~") + "/.cache",
        language: str = "en",
        progress_bar: bool = False,
        fact_check_prompts: Dict[str, str] = OPENAI_FACT_CHECK_PROMPTS,
        fact_check_summarize_prompt: Dict[str, str] = OPENAI_FACT_CHECK_SUMMARIZE_PROMPT,
        n_threads: int = 1,
        timeout: int = 600,
        max_tokens: Optional[int] = None,
        rewrite_cache: bool = False,
    ):
        """
        Parameters:
            llm_url (Optional[str]): forwarded to OpenAIChat as `base_url`.
            openai_model (str): model name forwarded to OpenAIChat.
            cache_path (str): cache directory forwarded to OpenAIChat.
            language (str): key used to select the prompt templates from
                `fact_check_prompts` and `fact_check_summarize_prompt`.
            progress_bar (bool): show a tqdm progress bar while verifying claims.
            fact_check_prompts (Dict[str, str]): per-language template for the
                first request; formatted with `claim` and `input`.
            fact_check_summarize_prompt (Dict[str, str]): per-language template
                for the second request; formatted with `claim`, `input` and `reply`.
            n_threads (int): number of worker threads used to verify claims.
            timeout (int): forwarded to OpenAIChat.
            max_tokens (Optional[int]): forwarded to OpenAIChat.
            rewrite_cache (bool): forwarded to OpenAIChat.
        """
        # NOTE(review): presumably declares the required stats and the metric
        # level ("claim") to GenerationMetric — confirm against the base class.
        super().__init__(["input_texts"], "claim")
        self.openai_chat = OpenAIChat(
            base_url=llm_url,
            openai_model=openai_model,
            cache_path=cache_path,
            timeout=timeout,
            max_tokens=max_tokens,
            rewrite_cache=rewrite_cache,
        )
        self.language = language
        self.progress_bar = progress_bar
        self.fact_check_prompts = fact_check_prompts
        self.fact_check_summarize_prompt = fact_check_summarize_prompt
        self.n_threads = n_threads

    def __str__(self):
        return "OpenAIFactCheck"

    def _score_single(self, args: tuple[str, str]) -> float:
        """
        Fact-check a single claim.

        Parameters:
            args (tuple[str, str]): pair (claim_text, input_text).
        Returns:
            float: 0 if the summarized reply contains a "true" marker, 1 if it
                contains a "false" marker (and no "true" marker), np.nan otherwise.
        """
        claim, input_text = args
        # First request: the fact-check prompt for this claim and its input.
        reply = self.openai_chat.ask(
            self.fact_check_prompts[self.language].format(
                claim=claim,
                input=input_text,
            )
        )
        # Second request: the summarize prompt, which also sees the first reply;
        # its answer is what gets parsed for a verdict.
        verdict = self.openai_chat.ask(
            self.fact_check_summarize_prompt[self.language].format(
                claim=claim,
                input=input_text,
                reply=reply,
            )
        ).strip()
        # "True" markers are tested first, so a verdict containing both kinds of
        # marker is labelled true (0).
        # NOTE(review): substring matching means e.g. "不是" ("is not") contains
        # "是" and is labelled true — confirm the summarize prompt constrains the
        # answer to a bare verdict word.
        if any(marker in verdict for marker in self._TRUE_MARKERS):
            return 0
        if any(marker in verdict for marker in self._FALSE_MARKERS):
            return 1
        return np.nan

    def __call__(
        self,
        stats: Dict[str, np.ndarray],
        target_texts: List[str],
    ) -> np.ndarray:
        """
        For each claim in stats['claims'], asks the OpenAI model whether the claim is correct or not.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics, which for multiple samples includes:
                * the input text of each sample in 'input_texts'
                * for each generation, list of lm_polygraph.stat_calculators.extract_claims.Claim in 'claims'
            target_texts (List[str]): ground-truth texts (not used by this metric).
        Returns:
            list of lists (annotated as np.ndarray): one list per sample holding a label
                for each of its claims: 1 if the claim is false, 0 if it is true,
                np.nan if the model's verdict could not be parsed.
        Raises:
            ValueError: if stats['claims'] holds more samples than stats['input_texts'],
                in which case claims could not be paired with their inputs.
        """
        input_texts = stats["input_texts"]
        claims_per_sample = stats["claims"]
        # zip() below would silently drop trailing claim lists, leaving their
        # labels misaligned or empty; fail loudly instead.
        if len(claims_per_sample) > len(input_texts):
            raise ValueError(
                f"stats['claims'] has {len(claims_per_sample)} samples but "
                f"stats['input_texts'] has only {len(input_texts)}"
            )

        # Flatten to (claim_text, input_text) pairs so all claims in the batch
        # can be verified concurrently.
        all_inputs = [
            (claim.claim_text, input_text)
            for input_text, sample_claims in zip(input_texts, claims_per_sample)
            for claim in sample_claims
        ]
        with ThreadPoolExecutor(max_workers=self.n_threads) as executor:
            # executor.map yields results in input order; the regrouping below
            # relies on that.
            all_outputs = list(
                tqdm(
                    executor.map(self._score_single, all_inputs),
                    total=len(all_inputs),
                    desc="Verifying claims",
                    disable=not self.progress_bar,
                )
            )

        # Split the flat results back into per-sample lists. A running offset
        # avoids re-slicing the remaining outputs for every sample (quadratic).
        claim_labels = []
        offset = 0
        for sample_claims in claims_per_sample:
            n_claims = len(sample_claims)
            claim_labels.append(all_outputs[offset : offset + n_claims])
            offset += n_claims
        return claim_labels