import torch
import numpy as np
from typing import Dict, List, Tuple
from .stat_calculator import StatCalculator
from lm_polygraph.model_adapters.visual_whitebox_model import VisualWhiteboxModel
[docs]class AttentionForwardPassCalculatorVisual(StatCalculator):
"""
For VisualLM (vision-language) WhiteboxModel: computes attention weights
during forward pass using greedy tokens + image inputs.
"""
def __init__(self):
super().__init__()
def __call__(
self,
dependencies: Dict[str, np.ndarray],
texts: List[str],
model: VisualWhiteboxModel,
max_new_tokens: int = 100,
) -> Dict[str, np.ndarray]:
"""
Parameters:
dependencies: includes greedy_tokens and images.
texts: List of input texts
model: VisualWhiteboxModel wrapper.
Returns:
Dict with 'forwardpass_attention_weights' numpy array.
"""
images = dependencies["images"]
forwardpass_attention_weights = []
for i in range(len(texts)):
# Process each sample individually
encoding = model.processor_visual(
text=str(texts[i]), images=images[i], return_tensors="pt"
).to(model.device())
# Forward pass with attention output
with torch.no_grad():
outputs = model(**encoding)
# Safely get attentions
if hasattr(outputs, "attentions") and outputs.attentions is not None:
# Convert to CPU and numpy
attentions_cpu = []
for attention in outputs.attentions:
if attention is not None:
attentions_cpu.append(attention.cpu().float())
else:
# Create dummy attention if None
batch_size = encoding["input_ids"].shape[0]
seq_len = encoding["input_ids"].shape[1]
# Get model config for proper dimensions
try:
_cfg = getattr(
model.model.config,
"text_config",
model.model.config,
)
num_heads = _cfg.num_attention_heads
except Exception: # Fixed: removed unused 'e'
num_heads = 12 # fallback
# Create identity matrix for dummy attention
dummy_attn = (
torch.eye(seq_len, device="cpu")
.unsqueeze(0)
.unsqueeze(0)
)
dummy_attn = dummy_attn.expand(
batch_size, num_heads, seq_len, seq_len
)
attentions_cpu.append(dummy_attn)
if attentions_cpu:
# Stack along layer dimension
forwardpass_attentions = torch.stack(
attentions_cpu, dim=0
).numpy()
else:
# Fallback if no attentions
batch_size, seq_len = encoding["input_ids"].shape
try:
_cfg = getattr(
model.model.config, "text_config", model.model.config
)
num_layers = _cfg.num_hidden_layers
num_heads = _cfg.num_attention_heads
except Exception: # Fixed: removed unused 'e'
num_layers = 12
num_heads = 12
forwardpass_attentions = np.ones(
(num_layers, batch_size, num_heads, seq_len, seq_len)
)
else:
# Fallback if model didn't return attentions
batch_size, seq_len = encoding["input_ids"].shape
try:
_cfg = getattr(
model.model.config, "text_config", model.model.config
)
num_layers = _cfg.num_hidden_layers
num_heads = _cfg.num_attention_heads
except Exception: # Fixed: removed unused 'e'
num_layers = 12
num_heads = 12
forwardpass_attentions = np.ones(
(num_layers, batch_size, num_heads, seq_len, seq_len)
)
forwardpass_attention_weights.append(forwardpass_attentions)
# Handle padding if sequence lengths vary
try:
forwardpass_attention_weights = np.array(forwardpass_attention_weights)
except Exception: # Fixed: removed unused 'e'
# in this case we have various len of input_ids+greedy_tokens in batch, so pad before concat
max_seq_length = np.max(
[el.shape[-1] for el in forwardpass_attention_weights]
)
forwardpass_attention_weights_padded = []
for el in forwardpass_attention_weights:
buf_el = el
if el.shape[-1] != max_seq_length:
pad_mask = (
(0, 0),
(0, 0),
(0, max_seq_length - el.shape[-1]),
(0, max_seq_length - el.shape[-1]),
)
buf_el = np.pad(el, pad_mask, constant_values=np.nan)
forwardpass_attention_weights_padded.append(buf_el)
forwardpass_attention_weights = np.array(
forwardpass_attention_weights_padded
)
return {"forwardpass_attention_weights": forwardpass_attention_weights}