Source code for lm_polygraph.utils.model

import torch
import openai
import time
import logging
import json

from dataclasses import asdict
from typing import List, Dict, Optional, Union
from abc import abstractmethod, ABC
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    AutoConfig,
    LogitsProcessorList,
    BartForConditionalGeneration,
)
from huggingface_hub import InferenceClient

from lm_polygraph.utils.generation_parameters import (
    GenerationParameters,
    GenerationParametersFactory,
)
from lm_polygraph.utils.ensemble_utils.ensemble_generator import EnsembleGenerationMixin
from lm_polygraph.utils.ensemble_utils.dropout import replace_dropout

log = logging.getLogger("lm_polygraph")


class Model(ABC):
    """
    Abstract model class. Used as base class for both White-box models and Black-box models.
    """

    def __init__(self, model_path: str, model_type: str):
        """
        Parameters:
            model_path (str): unique model path where it can be found.
            model_type (str): description of additional model properties. Can be 'Blackbox' or
                model specifications in the case of white-box.
        """
        self.model_path = model_path
        self.model_type = model_type

    @abstractmethod
    def generate_texts(self, input_texts: List[str], **args) -> List[str]:
        """
        Abstract method. Generates a list of model answers using input texts batch.

        Parameters:
            input_texts (List[str]): input texts batch.
        Return:
            List[str]: corresponding model generations. Have the same length as `input_texts`.
        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        # NotImplementedError is the idiomatic signal for an unimplemented abstract
        # method; it is still an Exception subclass, so broad handlers keep working.
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def generate(self, **args):
        """
        Abstract method. Generates the model output with scores from batch formed by HF Tokenizer.
        Not implemented for black-box models.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def __call__(self, **args):
        """
        Abstract method. Calls the model on the input batch. Returns the resulted scores.
        Not implemented for black-box models.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError("Not implemented")
class BlackboxModel(Model):
    """
    Black-box model class. Has no access to model scores and logits.
    Currently implemented blackbox models: OpenAI models, Huggingface models.

    Examples:

    ```python
    >>> from lm_polygraph import BlackboxModel
    >>> model = BlackboxModel.from_openai(
    ...     'YOUR_OPENAI_TOKEN',
    ...     'gpt-3.5-turbo'
    ... )
    ```

    ```python
    >>> from lm_polygraph import BlackboxModel
    >>> model = BlackboxModel.from_huggingface(
    ...     hf_api_token='YOUR_API_TOKEN',
    ...     hf_model_id='google/t5-large-ssm-nqo'
    ... )
    ```
    """

    def __init__(
        self,
        openai_api_key: str = None,
        model_path: str = None,
        hf_api_token: str = None,
        generation_parameters: Optional[GenerationParameters] = None,
        supports_logprobs: bool = False,
    ):
        """
        Parameters:
            openai_api_key (Optional[str]): OpenAI API key if the blackbox model comes from OpenAI. Default: None.
            model_path (Optional[str]): Unique model path. Openai model name, if `openai_api_key` is specified,
                huggingface path, if `hf_api_token` is specified. Default: None.
            hf_api_token (Optional[str]): Huggingface API token if the blackbox model comes from HF. Default: None.
            generation_parameters (Optional[GenerationParameters]): parameters to use in model generation.
                Default: None, meaning a fresh `GenerationParameters()` is created for this model.
            supports_logprobs (bool): Whether the model supports returning log probabilities. Default: False.
        """
        super().__init__(model_path, "Blackbox")
        # Build the default per instance instead of in the signature, so models
        # never share (and cannot accidentally mutate) one default object.
        if generation_parameters is None:
            generation_parameters = GenerationParameters()
        self.generation_parameters = generation_parameters
        self.openai_api_key = openai_api_key
        self.supports_logprobs = supports_logprobs
        if openai_api_key is not None:
            self.openai_api = openai.OpenAI(api_key=openai_api_key)
        self.hf_api_token = hf_api_token
        # State filled in by the OpenAI branch of `generate_texts`; initialised
        # here so reading it before any generation does not raise AttributeError.
        self.last_response = None
        self.logprobs = []
        self.tokens = []

    def _validate_args(self, args):
        """
        Validates and adapts arguments for BlackboxModel generation.

        Parameters:
            args (dict): The arguments to validate.

        Returns:
            dict: A copy of `args` with unsupported HF-only keys removed and
                HF argument names renamed to their API counterparts.
        """
        args_copy = args.copy()

        # BlackboxModel specific validation
        for delete_key in [
            "do_sample",
            "min_length",
            "top_k",
            "repetition_penalty",
            "min_new_tokens",
            "num_beams",
            "allow_newlines",
            "stop_strings",
        ]:
            args_copy.pop(delete_key, None)

        # Map HF argument names to OpenAI/HF API argument names.
        # Both length keys map to "max_tokens"; when both are present the later
        # entry ("max_new_tokens") wins.
        key_mapping = {
            "num_return_sequences": "n",
            "max_length": "max_tokens",
            "max_new_tokens": "max_tokens",
        }
        for key, replace_key in key_mapping.items():
            if key in args_copy:
                args_copy[replace_key] = args_copy.pop(key)

        return args_copy

    def _query(self, payload):
        """
        Sends chat messages to the Huggingface inference API.

        Parameters:
            payload: chat messages passed to `InferenceClient.chat_completion`.

        Returns:
            str: the API response serialized as a JSON string.
        """
        client = InferenceClient(model=self.model_path, token=self.hf_api_token)
        response = client.chat_completion(payload)
        raw_json = json.dumps(response, indent=2)
        return raw_json

    @staticmethod
    def from_huggingface(hf_api_token: str, hf_model_id: str, **kwargs):
        """
        Initializes a blackbox model from huggingface.

        Parameters:
            hf_api_token (str): Huggingface API token.
            hf_model_id (str): model path in huggingface.
            **kwargs: only `generation_parameters` is read; other keys are ignored.
        """
        generation_parameters = kwargs.pop("generation_parameters", None)
        return BlackboxModel(
            hf_api_token=hf_api_token,
            model_path=hf_model_id,
            generation_parameters=generation_parameters,
        )

    @staticmethod
    def from_openai(
        openai_api_key: str, model_path: str, supports_logprobs: bool = False, **kwargs
    ):
        """
        Initializes a blackbox model from OpenAI API.

        Parameters:
            openai_api_key (str): OpenAI API key.
            model_path (str): model name in OpenAI.
            supports_logprobs (bool): Whether the model supports returning log probabilities. Default: False.
            **kwargs: only `generation_parameters` is read; other keys are ignored.
        """
        generation_parameters = kwargs.pop("generation_parameters", None)
        return BlackboxModel(
            openai_api_key=openai_api_key,
            model_path=model_path,
            supports_logprobs=supports_logprobs,
            generation_parameters=generation_parameters,
        )

    @staticmethod
    def _to_chat_messages(prompt):
        """Converts a prompt (string or list of chat dicts) into OpenAI chat messages."""
        if isinstance(prompt, str):
            # If prompt is a string, create a single message with "user" role
            return [{"role": "user", "content": prompt}]
        if isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
            # If prompt is a list of dictionaries, assume it's already structured as chat
            return prompt
        raise ValueError(
            "Invalid prompt format. Must be either a string or a list of dictionaries."
        )

    def _generate_openai_texts(self, input_texts, args):
        """Queries the OpenAI chat API once per prompt and records logprobs if requested."""
        # Reset per-call state
        self.last_response = None
        self.logprobs = []
        self.tokens = []

        # Check if we need to return logprobs
        return_logprobs = args.pop("output_scores", False)
        logprobs_args = {}
        if return_logprobs and self.supports_logprobs:
            logprobs_args["logprobs"] = True
            # OpenAI supports returning top logprobs, default to 5
            logprobs_args["top_logprobs"] = args.pop("top_logprobs", 5)

        texts = []
        for prompt in input_texts:
            messages = self._to_chat_messages(prompt)

            # Retry on any API failure: up to 5 retries after the first attempt.
            retries = 0
            while True:
                try:
                    response = self.openai_api.chat.completions.create(
                        model=self.model_path,
                        messages=messages,
                        **args,
                        **logprobs_args,
                    )
                    break
                except Exception as e:
                    if retries > 4:
                        raise Exception(
                            f"OpenAI request failed after {retries + 1} attempts"
                        ) from e
                    retries += 1

            if args.get("n", 1) == 1:
                choice = response.choices[0]
                texts.append(choice.message.content)
                # Store logprobs if available
                if return_logprobs and hasattr(choice, "logprobs"):
                    self.logprobs.append(choice.logprobs)
                    # Extract token information if available
                    if hasattr(choice.logprobs, "content"):
                        self.tokens.append(
                            [item.token for item in choice.logprobs.content]
                        )
            else:
                # For multiple returns, we don't collect logprobs for now
                texts.append([resp.message.content for resp in response.choices])

            # Store the last response for later use
            self.last_response = response
        return texts

    def _generate_hf_text(self, prompt):
        """Queries the Huggingface inference API for one prompt and returns the generated text."""
        start = time.time()
        while True:
            current_time = time.time()
            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt},
            ]
            output = self._query(messages)
            # `_query` returns the response as a JSON string; decode it so the
            # structural checks below can match. Without this the loop never
            # terminated, because a str is neither a dict nor a list.
            if isinstance(output, str):
                output = json.loads(output)

            if isinstance(output, dict):
                if "error" in output:
                    if "estimated_time" in output:
                        # Model is still loading: report remaining time and retry.
                        estimated_time = float(output["estimated_time"])
                        elapsed_time = current_time - start
                        print(
                            f"{output['error']}. Estimated time: {round(estimated_time - elapsed_time, 2)} sec."
                        )
                        time.sleep(5)
                        continue
                    log.error(f"{output['error']}")
                    raise Exception(f"Huggingface API error: {output['error']}")
                # Chat-completion shaped response.
                # NOTE(review): assumes the `choices[0].message.content` layout of
                # huggingface_hub's ChatCompletionOutput — confirm against the
                # installed huggingface_hub version.
                choices = output.get("choices")
                if choices:
                    return choices[0]["message"]["content"]
                raise Exception("Unexpected response format from Huggingface API")

            if isinstance(output, list):
                # Text-generation shaped response
                return output[0]["generated_text"]

            raise Exception("Unexpected response format from Huggingface API")

    def generate_texts(self, input_texts: List[str], **args) -> List[str]:
        """
        Generates a list of model answers using input texts batch.

        Parameters:
            input_texts (List[str]): input texts batch.
            **args: generation arguments overriding `self.generation_parameters`.
        Return:
            List[str]: corresponding model generations. Have the same length as `input_texts`.
                For OpenAI models with `n` > 1, each element is a list of generations.
                Empty if neither OpenAI nor HF credentials are configured.
        Raises:
            Exception: if scores/attentions/hidden states are requested from a model
                without logprobs support, or if the remote API keeps failing.
            ValueError: if an OpenAI prompt is neither a string nor a list of dicts.
        """
        # Apply default parameters first, then override with provided args
        default_params = asdict(self.generation_parameters)
        default_params.update(args)
        args = self._validate_args(default_params)

        # Check if we're trying to access features that require logprobs support
        if (
            any(
                args.get(arg, False)
                for arg in [
                    "output_scores",
                    "output_attentions",
                    "output_hidden_states",
                ]
            )
            and not self.supports_logprobs
        ):
            raise Exception("Cannot access logits for blackbox model")

        texts = []
        if self.openai_api_key is not None:
            texts = self._generate_openai_texts(input_texts, args)
        elif self.hf_api_token is not None and self.model_path is not None:
            for prompt in input_texts:
                texts.append(self._generate_hf_text(prompt))
        else:
            print(
                "Please provide HF API token and model id for using models from HF or openai API key for using OpenAI models"
            )
        return texts

    def generate(self, **args):
        """
        For OpenAI models with logprobs support, returns a lightweight wrapper around OpenAI API response.
        For other blackbox models, raises an exception as this is not implemented.

        Parameters:
            **args: Arguments to pass to `generate_texts` (must include `input_texts`).

        Returns:
            object: An object with `sequences` (generated texts) and `scores`
                (the logprobs collected during generation).

        Raises:
            Exception: If the model doesn't support logprobs.
        """
        if not self.supports_logprobs:
            raise Exception("Cannot access logits of blackbox model")

        # `generate_texts` applies the default parameters and validation itself.
        # Validating here as well would let defaults re-applied in the second pass
        # (e.g. an HF-named length default remapped to "max_tokens") overwrite
        # values the caller had explicitly overridden.
        call_args = dict(args)
        call_args["output_scores"] = True

        sequences = self.generate_texts(**call_args)

        # Return a simple object with the necessary attributes for compatibility
        class OpenAIGenerationOutput:
            def __init__(self, sequences, scores):
                self.sequences = sequences
                self.scores = scores

        return OpenAIGenerationOutput(sequences, self.logprobs)

    def __call__(self, **args):
        """
        Not implemented for blackbox models.
        """
        raise Exception("Cannot access logits of blackbox model")

    def tokenizer(self, *args, **kwargs):
        """
        Not implemented for blackbox models.
        """
        raise Exception("Cannot access logits of blackbox model")
class WhiteboxModel(Model):
    """
    White-box model class. Has access to model scores and logits.
    Currently implemented only for Huggingface models.

    Examples:

    ```python
    >>> from lm_polygraph import WhiteboxModel
    >>> model = WhiteboxModel.from_pretrained(
    ...     "bigscience/bloomz-3b",
    ... )
    ```
    """

    def __init__(
        self,
        model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        model_path: str = None,
        model_type: str = "CausalLM",
        generation_parameters: Optional[GenerationParameters] = None,
        instruct: bool = False,
    ):
        """
        Parameters:
            model (AutoModelForCausalLM): HuggingFace model.
            tokenizer (AutoTokenizer): HuggingFace tokenizer.
            model_path (Optional[str]): Unique model path in HuggingFace.
            model_type (str): Additional model specifications.
            generation_parameters (Optional[GenerationParameters]): parameters to use in model generation.
                Default: None, meaning a fresh `GenerationParameters()` is created for this model.
            instruct (bool): when True, `tokenize` formats inputs with the tokenizer's
                chat template. Default: False.
        """
        super().__init__(model_path, model_type)
        self.model = model
        self.tokenizer = tokenizer
        # Build the default per instance instead of in the signature, so models
        # never share (and cannot accidentally mutate) one default object.
        if generation_parameters is None:
            generation_parameters = GenerationParameters()
        self.generation_parameters = generation_parameters
        self.instruct = instruct

    def _validate_args(self, args):
        """
        Validates and adapts arguments for WhiteboxModel generation.

        Parameters:
            args (dict): The arguments to validate.

        Returns:
            dict: A copy of `args` without the keys that HF `model.generate` does not accept.
        """
        args_copy = args.copy()

        # WhiteboxModel specific validation
        if "presence_penalty" in args_copy and args_copy["presence_penalty"] != 0.0:
            log.warning(
                "Skipping requested argument presence_penalty={}".format(
                    args_copy["presence_penalty"]
                )
            )

        # Remove arguments that are not supported by the HF model.generate function
        keys_to_remove = ["presence_penalty", "allow_newlines"]
        for key in keys_to_remove:
            args_copy.pop(key, None)

        return args_copy

    class _ScoresProcessor:
        # Stores original token scores instead of the ones modified with generation parameters
        def __init__(self):
            self.scores = []

        def __call__(self, input_ids=None, scores=None):
            self.scores.append(scores.log_softmax(-1))
            return scores

    class _SanitizeLogitsProcessor:
        # Replaces inf/nan in logits with finite values to prevent
        # RuntimeError in torch.multinomial during sampling.
        # Uses per-row max/min of finite values to preserve distribution shape.
        def __call__(self, input_ids=None, scores=None):
            if torch.isfinite(scores).all():
                return scores
            finite_mask = torch.isfinite(scores)
            scores_for_max = torch.where(
                finite_mask,
                scores,
                torch.tensor(float("-inf"), dtype=scores.dtype, device=scores.device),
            )
            scores_for_min = torch.where(
                finite_mask,
                scores,
                torch.tensor(float("inf"), dtype=scores.dtype, device=scores.device),
            )
            row_max = scores_for_max.max(dim=-1, keepdim=True).values
            row_min = scores_for_min.min(dim=-1, keepdim=True).values
            # Rows with no finite entry at all fall back to 0
            row_max = torch.where(
                torch.isfinite(row_max), row_max, torch.zeros_like(row_max)
            )
            row_min = torch.where(
                torch.isfinite(row_min), row_min, torch.zeros_like(row_min)
            )
            scores = torch.where(torch.isposinf(scores), row_max, scores)
            scores = torch.where(torch.isneginf(scores), row_min, scores)
            scores = torch.nan_to_num(scores, nan=0.0)
            return scores

    def generate(self, **args):
        """
        Generates the model output with scores from batch formed by HF Tokenizer.

        Parameters:
            **args: Any arguments that can be passed to model.generate function from HuggingFace.
        Returns:
            ModelOutput: HuggingFace generation output with `scores` overridden by the
                log-softmaxed scores captured before user logits processors ran; the
                scores returned by HF are kept in `generation_scores`.
        """
        default_params = asdict(self.generation_parameters)

        # add ScoresProcessor and SanitizeLogitsProcessor
        processor = self._ScoresProcessor()
        sanitizer = self._SanitizeLogitsProcessor()
        if "logits_processor" in args:
            logits_processor = LogitsProcessorList(
                [sanitizer, processor, args["logits_processor"]]
            )
        else:
            logits_processor = LogitsProcessorList([sanitizer, processor])
        args["logits_processor"] = logits_processor

        # update default parameters with passed arguments
        default_params.update(args)
        args = default_params

        # Handle stop_strings via stopping_criteria to avoid passing tokenizer
        # as a kwarg (breaks with transformers>=4.51 + Bloom-like models)
        stop_strings = args.pop("stop_strings", None)
        if stop_strings:
            from transformers import StoppingCriteriaList
            from transformers.generation.stopping_criteria import StopStringCriteria

            stop_criteria = StopStringCriteria(
                stop_strings=stop_strings, tokenizer=self.tokenizer
            )
            existing_criteria = args.get("stopping_criteria")
            if existing_criteria:
                # Build a new list rather than appending in place: the existing
                # list belongs to the caller and would otherwise grow by one
                # criterion on every call that reuses it.
                args["stopping_criteria"] = StoppingCriteriaList(
                    list(existing_criteria) + [stop_criteria]
                )
            else:
                args["stopping_criteria"] = StoppingCriteriaList([stop_criteria])

        args = self._validate_args(args)
        generation = self.model.generate(**args)

        # override generation.scores with original scores from model
        generation.generation_scores = generation.scores
        generation.scores = processor.scores

        return generation

    def generate_texts(self, input_texts: List[str], **args) -> List[str]:
        """
        Generates a list of model answers using input texts batch.

        Parameters:
            input_texts (List[str]): input texts batch.
            **args: generation arguments overriding `self.generation_parameters`.
        Return:
            List[str]: corresponding model generations. Have the same length as `input_texts`.
        """
        # Apply default parameters first, then override with provided args
        default_params = asdict(self.generation_parameters)
        default_params.update(args)
        args = self._validate_args(default_params)

        args["return_dict_in_generate"] = True
        batch: Dict[str, torch.Tensor] = self.tokenize(input_texts)
        batch = {k: v.to(self.device()) for k, v in batch.items()}
        sequences = self.generate(**batch, **args).sequences.cpu()
        input_len = batch["input_ids"].shape[1]
        texts = []

        decode_args = {}
        if self.tokenizer.chat_template is not None:
            decode_args["skip_special_tokens"] = True

        for seq in sequences:
            if self.model_type == "CausalLM":
                # Causal LMs echo the prompt: keep only the tokens after it
                texts.append(self.tokenizer.decode(seq[input_len:], **decode_args))
            else:
                # Otherwise drop only the first token of the output sequence
                texts.append(self.tokenizer.decode(seq[1:], **decode_args))

        return texts

    def __call__(self, **args):
        """
        Calls the model on the input batch. Returns the resulted scores.
        """
        return self.model(**args)

    def device(self):
        """
        Returns the device the model is currently loaded on.

        Returns:
            The `device` attribute of the wrapped model.
        """
        return self.model.device

    @staticmethod
    def from_pretrained(
        model_path: str,
        generation_params: Optional[Dict] = None,
        add_bos_token: bool = True,
        **kwargs,
    ):
        """
        Initializes the model from HuggingFace. Automatically determines model type.

        Parameters:
            model_path (str): model path in HuggingFace.
            generation_params (Optional[Dict]): generation arguments for
                lm_polygraph.utils.generation_parameters.GenerationParameters.
                Default: None, treated as an empty dict.
            add_bos_token (bool): tokenizer argument. Default: True.
            **kwargs: forwarded to the HF config, model and tokenizer `from_pretrained` calls.
        Raises:
            ValueError: if none of the config architectures is a supported generation model.
        """
        log.warning(
            "WhiteboxModel#from_pretrained is deprecated and will be removed in the next release. Please instantiate WhiteboxModel directly by passing an already loaded model, tokenizer and model path."
        )

        # Avoid a shared mutable default dict in the signature
        if generation_params is None:
            generation_params = {}

        config = AutoConfig.from_pretrained(
            model_path, trust_remote_code=True, **kwargs
        )

        if any("CausalLM" in architecture for architecture in config.architectures):
            model_type = "CausalLM"
            model = AutoModelForCausalLM.from_pretrained(
                model_path, trust_remote_code=True, **kwargs
            )
        elif any(
            ("Seq2SeqLM" in architecture) or ("ConditionalGeneration" in architecture)
            for architecture in config.architectures
        ):
            model_type = "Seq2SeqLM"
            model = AutoModelForSeq2SeqLM.from_pretrained(model_path, **kwargs)
            if "falcon" in model_path:
                model.transformer.alibi = True
        elif any(
            "JAISLMHeadModel" in architecture for architecture in config.architectures
        ):
            model_type = "CausalLM"
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                trust_remote_code=True,
                **kwargs,
            )
        elif any("BartModel" in architecture for architecture in config.architectures):
            model_type = "Seq2SeqLM"
            model = BartForConditionalGeneration.from_pretrained(model_path, **kwargs)
        else:
            raise ValueError(
                f"Model {model_path} is not adapted for the sequence generation task"
            )

        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            padding_side="left",
            add_bos_token=add_bos_token,
            **kwargs,
        )

        model.eval()
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # NOTE(review): dataclasses.asdict only accepts dataclass instances; confirm
        # that model.config is one for the supported transformers versions.
        generation_params = GenerationParametersFactory.from_params(
            yaml_config=generation_params,
            native_config=asdict(model.config),
        )

        instance = WhiteboxModel(
            model, tokenizer, model_path, model_type, generation_params
        )

        return instance

    def tokenize(
        self, texts: Union[List[str], List[List[Dict[str, str]]]]
    ) -> Dict[str, torch.Tensor]:
        """
        Tokenizes input texts batch into a dictionary using the model tokenizer.

        Parameters:
            texts (Union[List[str], List[List[Dict[str, str]]]]): input batch of plain
                strings or chat message lists.
        Returns:
            dict[str, torch.Tensor]: tensors dictionary obtained by tokenizing input texts batch.
        """
        # Apply the chat template only for models flagged as instruct
        add_start_symbol = True
        if self.instruct:
            formatted_texts = []
            for chat in texts:
                if isinstance(chat, str):
                    chat = [{"role": "user", "content": chat}]
                formatted_chat = self.tokenizer.apply_chat_template(
                    chat, add_generation_prompt=True, tokenize=False
                )
                formatted_texts.append(formatted_chat)
            texts = formatted_texts
            # Special tokens are not added again after chat templating
            add_start_symbol = False

        return self.tokenizer(
            texts,
            padding=True,
            return_tensors="pt",
            add_special_tokens=add_start_symbol,
        )
def create_ensemble(
    models: Optional[List[WhiteboxModel]] = None,
    mc: bool = False,
    seed: int = 1,
    mc_seeds: Optional[List[int]] = None,
    ensembling_mode: str = "pe",
    dropout_rate: float = 0.1,
    **kwargs,
) -> WhiteboxModel:
    """
    Turns the first model of `models` into a Monte-Carlo dropout ensemble in place.

    The underlying HF model's class is replaced by a dynamically created
    `EnsembleModel` that mixes in `EnsembleGenerationMixin`, its dropout layers
    are replaced via `replace_dropout`, and it is switched to train mode.

    Parameters:
        models (Optional[List[WhiteboxModel]]): models to ensemble; only the first one is used.
        mc (bool): must be True; only Monte-Carlo ensembling is available. Default: False.
        seed (int): base seed stored on the ensemble. Default: 1.
        mc_seeds (Optional[List[int]]): one seed per Monte-Carlo member. Default: None, meaning `[1]`.
        ensembling_mode (str): ensembling mode stored on the ensemble. Default: "pe".
        dropout_rate (float): dropout probability passed to `replace_dropout`. Default: 0.1.
        **kwargs: accepted for compatibility and ignored.
    Returns:
        WhiteboxModel: the first model of `models`, whose `.model` is now the ensemble.
    Raises:
        IndexError: if `models` is empty or None.
        ValueError: if `mc` is False.
    """
    if not models:
        # Same exception type as indexing an empty list, with a clearer message
        raise IndexError("create_ensemble requires at least one model")
    if not mc:
        # Checked before touching the model so a rejected call has no side effects
        raise ValueError(
            "Only Monte-Carlo ensembling is available. Please set the corresponding argument value to True"
        )
    if mc_seeds is None:
        mc_seeds = [1]

    model = models[0]
    ens = model.model
    # Swap the class in place so the already loaded weights are reused
    ens.__class__ = type(
        "EnsembleModel", (model.model.__class__, EnsembleGenerationMixin), {}
    )

    ens.mc = True
    ens.mc_seeds = mc_seeds
    ens.base_seed = seed
    ens.ensembling_mode = ensembling_mode
    ens.mc_models_num = len(mc_seeds)

    replace_dropout(
        ens.config._name_or_path, ens, p=dropout_rate, share_across_tokens=True
    )

    # Train mode on the ensemble after dropout replacement
    ens.train()

    return model