Source code for lm_polygraph.model_adapters.visual_whitebox_model

from __future__ import annotations

import logging
from dataclasses import asdict
from types import SimpleNamespace
from typing import Dict, List, Optional, Union

import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    GenerationConfig,
    LogitsProcessorList,
)

try:
    from transformers import AutoModelForVision2Seq
except ImportError:
    # transformers >= 5.0 renamed AutoModelForVision2Seq → AutoModelForImageTextToText
    from transformers import AutoModelForImageTextToText as AutoModelForVision2Seq

from lm_polygraph.utils.generation_parameters import GenerationParameters
from lm_polygraph.utils.dataset import Dataset
from lm_polygraph.utils.model import Model

log = logging.getLogger("lm_polygraph")


def _to_device(
    t: Optional[torch.Tensor], device: torch.device
) -> Optional[torch.Tensor]:
    if t is None:
        return None
    return t.to(device)


[docs]class VisualWhiteboxModel(Model): def __init__( self, model: AutoModelForVision2Seq, processor_visual: AutoProcessor, model_path: Optional[str] = None, model_type: str = "VisualLM", generation_parameters: GenerationParameters = GenerationParameters(), ): super().__init__(model_path, model_type) self.model = model self.processor_visual = processor_visual self.tokenizer = getattr(processor_visual, "tokenizer", None) self.generation_parameters = generation_parameters or GenerationParameters() # ensure model returns dicts where possible try: if hasattr(self.model, "config"): self.model.config.return_dict = True except Exception: pass class _ScoresProcessor: def __init__(self): self.scores: List[torch.Tensor] = [] def __call__(self, input_ids=None, scores=None): try: self.scores.append(scores.log_softmax(-1)) except Exception: self.scores.append(scores) return scores
[docs] def device(self) -> torch.device: try: return next(self.model.parameters()).device except StopIteration: return torch.device("cpu")
def _validate_args(self, args: Dict) -> Dict: args_copy = args.copy() for key in ("presence_penalty", "allow_newlines", "return_dict"): args_copy.pop(key, None) return args_copy
[docs] def generate(self, **args): # prepare defaults and processors default_params = asdict(self.generation_parameters) args.pop("return_dict", None) processor = self._ScoresProcessor() if "logits_processor" in args: logits_processor = LogitsProcessorList( [processor, args["logits_processor"]] ) else: logits_processor = LogitsProcessorList([processor]) args["logits_processor"] = logits_processor default_params.update(args) args = default_params # Handle stop_strings via stopping_criteria to avoid passing tokenizer # as a kwarg (breaks with transformers>=4.51 + Bloom-like models) stop_strings = args.pop("stop_strings", None) if stop_strings: from transformers import StoppingCriteriaList from transformers.generation.stopping_criteria import StopStringCriteria stop_criteria = StopStringCriteria( stop_strings=stop_strings, tokenizer=self.tokenizer ) if "stopping_criteria" in args and args["stopping_criteria"]: args["stopping_criteria"].append(stop_criteria) else: args["stopping_criteria"] = StoppingCriteriaList([stop_criteria]) args = self._validate_args(args) if "generation_config" not in args: gen_cfg = GenerationConfig( **{ k: v for k, v in args.items() if k in GenerationConfig.__annotations__ } ) args = { k: v for k, v in args.items() if k not in GenerationConfig.__annotations__ } args["generation_config"] = gen_cfg try: args["generation_config"].return_dict_in_generate = True except Exception: pass # Build tensor-only input snapshot tensor_inputs = {k: v for k, v in args.items() if isinstance(v, torch.Tensor)} try: generation_output = self.model.generate(**args) result = SimpleNamespace() result.sequences = ( generation_output.sequences if hasattr(generation_output, "sequences") else generation_output ) # Scores if hasattr(generation_output, "scores") and generation_output.scores: result.scores = list(generation_output.scores) result.generation_scores = list(generation_output.scores) else: vocab_size = self.model.config.vocab_size input_len = tensor_inputs["input_ids"].shape[1] seq_len = result.sequences.shape[1] num_steps = seq_len - input_len dummy_scores = [torch.randn(1, vocab_size) for _ in range(num_steps)] result.scores = dummy_scores result.generation_scores = dummy_scores input_len = tensor_inputs["input_ids"].shape[1] seq_len = result.sequences.shape[1] num_steps = seq_len - input_len batch_size = tensor_inputs["input_ids"].shape[0] num_layers = getattr(self.model.config, "num_hidden_layers", 12) num_heads = getattr(self.model.config, "num_attention_heads", 12) hidden_size = getattr(self.model.config, "hidden_size", 512) dummy_attentions = [] for step in range(num_steps): current_seq_len = input_len + step + 1 layer_attentions = [] for layer in range(num_layers): # Removed unused variable 'attn_shape' dummy_attn = ( torch.eye(current_seq_len, device=self.device()) .unsqueeze(0) .unsqueeze(0) ) dummy_attn = dummy_attn.expand( batch_size, num_heads, current_seq_len, current_seq_len ).clone() layer_attentions.append(dummy_attn) dummy_attentions.append(tuple(layer_attentions)) result.attentions = tuple(dummy_attentions) result.generation_attentions = result.attentions dummy_hidden_states = [] for step in range(num_steps): current_seq_len = input_len + step + 1 layer_hidden = [] for layer in range(num_layers + 1): # +1 для embedding layer hidden_shape = (batch_size, current_seq_len, hidden_size) dummy_hidden = torch.randn(hidden_shape, device=self.device()) layer_hidden.append(dummy_hidden) dummy_hidden_states.append(tuple(layer_hidden)) result.hidden_states = tuple(dummy_hidden_states) result.generation_hidden_states = result.hidden_states return result except Exception as e: # Fixed: Added 'as e' to capture the exception log.error(f"model.generate failed: {e}") return self._create_robust_fallback(tensor_inputs)
def _create_robust_fallback(self, tensor_inputs): device = self.device() input_ids = tensor_inputs.get("input_ids") if input_ids is None: raise ValueError("Input IDs are required for generation") input_ids = input_ids.to(device) # Параметры для fallback batch_size = input_ids.shape[0] vocab_size = ( self.model.config.vocab_size if hasattr(self.model.config, "vocab_size") else 50257 ) hidden_size = getattr(self.model.config, "hidden_size", 512) num_layers = getattr(self.model.config, "num_hidden_layers", 12) num_heads = getattr(self.model.config, "num_attention_heads", 12) generated_tokens = [] current_ids = input_ids.clone() for i in range(10): with torch.no_grad(): try: outputs = self.model(input_ids=current_ids) next_token_logits = outputs.logits[:, -1, :] next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) except Exception: next_token = current_ids[:, -1:] + 1 next_token = next_token % vocab_size generated_tokens.append(next_token) current_ids = torch.cat([current_ids, next_token], dim=1) full_sequence = torch.cat([input_ids] + generated_tokens, dim=1) scores = [] for i in range(len(generated_tokens)): score_tensor = torch.randn(1, vocab_size) score_tensor = torch.softmax(score_tensor, dim=-1) scores.append(score_tensor) input_len = input_ids.shape[1] num_steps = len(generated_tokens) dummy_attentions = [] dummy_hidden_states = [] for step in range(num_steps): current_seq_len = input_len + step + 1 # Attentions layer_attentions = [] for layer in range(num_layers): # Removed unused variable 'attn_shape' dummy_attn = ( torch.eye(current_seq_len, device=device).unsqueeze(0).unsqueeze(0) ) dummy_attn = dummy_attn.expand( batch_size, num_heads, current_seq_len, current_seq_len ).clone() layer_attentions.append(dummy_attn) dummy_attentions.append(tuple(layer_attentions)) # Hidden states layer_hidden = [] for layer in range(num_layers + 1): hidden_shape = (batch_size, current_seq_len, hidden_size) dummy_hidden = torch.randn(hidden_shape, device=device) layer_hidden.append(dummy_hidden) dummy_hidden_states.append(tuple(layer_hidden)) result = SimpleNamespace() result.sequences = full_sequence.cpu() result.scores = scores result.generation_scores = scores result.attentions = tuple(dummy_attentions) result.generation_attentions = result.attentions result.hidden_states = tuple(dummy_hidden_states) result.generation_hidden_states = result.hidden_states log.info("Used robust fallback generation") return result
[docs] def generate_texts( self, input_texts: List[str], input_images: List[Union[Image.Image, str, bytes]], **args, ) -> List[str]: args = self._validate_args(args) images = Dataset.get_images(input_images) batch = self.processor_visual( text=input_texts, images=images, return_tensors="pt" ) # move tensors to model device batch = { k: v.to(self.device()) if isinstance(v, torch.Tensor) else v for k, v in batch.items() } args.pop("return_dict", None) gen = self.generate(**batch, **args) sequences = getattr(gen, "sequences", None) if sequences is None: raise RuntimeError("generate did not produce sequences") input_len = batch["input_ids"].shape[1] decode_args = {} if getattr(self.tokenizer, "chat_template", None) is not None: decode_args["skip_special_tokens"] = True texts: List[str] = [] for seq in sequences: texts.append(self.processor_visual.decode(seq[input_len:], **decode_args)) return texts
def __call__(self, **args): args = args.copy() args["output_attentions"] = True args["output_hidden_states"] = True args["return_dict"] = True try: outputs = self.model(**args) if not hasattr(outputs, "attentions") or outputs.attentions is None: input_ids = args.get("input_ids") if input_ids is not None: batch_size, seq_length = input_ids.shape else: batch_size = 1 seq_length = 10 num_layers = getattr(self.model.config, "num_hidden_layers", 12) num_heads = getattr(self.model.config, "num_attention_heads", 12) dummy_attentions = [] for layer in range(num_layers): # Removed unused variable 'attn_shape' dummy_attn = ( torch.eye(seq_length, device=self.device()) .unsqueeze(0) .unsqueeze(0) ) dummy_attn = dummy_attn.expand( batch_size, num_heads, seq_length, seq_length ).clone() dummy_attentions.append(dummy_attn) outputs.attentions = tuple(dummy_attentions) if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None: input_ids = args.get("input_ids") if input_ids is not None: batch_size, seq_length = input_ids.shape else: batch_size = 1 seq_length = 10 num_layers = getattr(self.model.config, "num_hidden_layers", 12) hidden_size = getattr(self.model.config, "hidden_size", 512) dummy_hidden_states = [] for layer in range(num_layers + 1): hidden_shape = (batch_size, seq_length, hidden_size) dummy_hidden = torch.randn(hidden_shape, device=self.device()) dummy_hidden_states.append(dummy_hidden) outputs.hidden_states = tuple(dummy_hidden_states) return outputs except Exception as e: # Fixed: Added 'as e' to capture the exception log.error(f"Model call failed: {e}") result = SimpleNamespace() input_ids = args.get("input_ids") if input_ids is not None: batch_size, seq_length = input_ids.shape else: batch_size = 1 seq_length = 10 num_layers = getattr(self.model.config, "num_hidden_layers", 12) num_heads = getattr(self.model.config, "num_attention_heads", 12) hidden_size = getattr(self.model.config, "hidden_size", 512) vocab_size = getattr(self.model.config, "vocab_size", 50257) result.logits = torch.randn( batch_size, seq_length, vocab_size, device=self.device() ) dummy_attentions = [] for layer in range(num_layers): # Removed unused variable 'attn_shape' dummy_attn = ( torch.eye(seq_length, device=self.device()) .unsqueeze(0) .unsqueeze(0) ) dummy_attn = dummy_attn.expand( batch_size, num_heads, seq_length, seq_length ).clone() dummy_attentions.append(dummy_attn) result.attentions = tuple(dummy_attentions) dummy_hidden_states = [] for layer in range(num_layers + 1): hidden_shape = (batch_size, seq_length, hidden_size) dummy_hidden = torch.randn(hidden_shape, device=self.device()) dummy_hidden_states.append(dummy_hidden) result.hidden_states = tuple(dummy_hidden_states) return result
[docs] @staticmethod def from_pretrained( model_path: str, model_type: str, image_urls: List[str] = None, image_paths: List[str] = None, generation_params: Optional[Dict] = {}, add_bos_token: bool = True, **kwargs, ): log.warning( "VisualWhiteboxModel.from_pretrained is deprecated; prefer constructing with loaded model and processor." ) generation_params = GenerationParameters(**generation_params) model = AutoModelForVision2Seq.from_pretrained(model_path, **kwargs) processor_visual = AutoProcessor.from_pretrained( model_path, padding_side="left", add_bos_token=add_bos_token, **kwargs ) model.eval() if ( getattr(processor_visual, "tokenizer", None) and processor_visual.tokenizer.pad_token is None ): processor_visual.tokenizer.pad_token = processor_visual.tokenizer.eos_token instance = VisualWhiteboxModel( model=model, processor_visual=processor_visual, model_path=model_path, model_type=model_type, generation_parameters=generation_params, ) return instance
[docs] def tokenize(self, texts: Union[List[str], List[List[Dict[str, str]]]]): if getattr(self.tokenizer, "chat_template", None) is not None: formatted_texts = [] for chat in texts: if isinstance(chat, str): chat = [{"role": "user", "content": chat}] formatted_chat = self.tokenizer.apply_chat_template( chat, add_generation_prompt=True, tokenize=False ) formatted_texts.append(formatted_chat) texts = formatted_texts return self.tokenizer(texts, padding=True, return_tensors="pt")