from copy import copy
from typing import List, Optional

import torch
from transformers.generation import GenerateDecoderOnlyOutput

from lm_polygraph.utils.generation_parameters import GenerationParameters
from lm_polygraph.utils.model import Model
class WhiteboxModelvLLM(Model):
    """Basic whitebox model adapter for using vLLM in stat calculators and uncertainty estimators.

    The adapter wraps a vLLM engine object and exposes a HuggingFace-like
    surface (``generate``, ``tokenize``, ``generate_texts``, ``__call__``).
    Prompts arrive as token ids, are decoded back to text, sent to the vLLM
    engine, and the engine's per-token log-probabilities are converted into
    dense tensors packed in a ``GenerateDecoderOnlyOutput``.
    """

    # Sampling attributes copied from ``generation_parameters`` onto
    # ``sampling_params`` at construction time.
    _SAMPLING_OVERRIDES = (
        "presence_penalty",
        "repetition_penalty",
        "temperature",
        "top_p",
        "top_k",
    )

    # System message placed in front of every user prompt in instruct mode.
    _SYSTEM_PROMPT = "You are a knowledgeable assistant who answers questions concisely and accurately and strictly follows output formatting instructions."

    def __init__(
        self,
        model,
        sampling_params,
        generation_parameters: Optional[GenerationParameters] = None,
        device: str = "cuda",
        instruct: bool = False,
    ):
        """Wrap a vLLM engine.

        Args:
            model: engine object providing ``get_tokenizer()``, ``generate()``
                and ``chat()``.
            sampling_params: sampling configuration object; it is modified in
                place (``stop`` and the sampling overrides are written to it).
            generation_parameters: source of stop strings and sampling
                overrides. A fresh ``GenerationParameters()`` is created when
                omitted, so instances never share one default object.
            device: value reported by :meth:`device`.
            instruct: when true, prompts are sent through ``model.chat`` with
                a fixed system message instead of ``model.generate``.
        """
        if generation_parameters is None:
            generation_parameters = GenerationParameters()

        self.model = model
        self.tokenizer = self.model.get_tokenizer()
        # Padding reuses the EOS token so batched tokenization can pad.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.sampling_params = sampling_params
        self.generation_parameters = generation_parameters
        self.instruct = instruct

        stop_strings = getattr(self.generation_parameters, "stop_strings", None)
        self.sampling_params.stop = list(stop_strings) if stop_strings is not None else []

        for param in self._SAMPLING_OVERRIDES:
            value = getattr(self.generation_parameters, param, None)
            # A missing/None value must not wipe out the value already
            # configured on sampling_params.
            if value is not None:
                setattr(self.sampling_params, param, value)

        self.base_device = device
        self.model_type = "vLLMCausalLM"

    def _build_sampling_params(self, options):
        """Return a per-call copy of the sampling params adjusted by ``options``.

        Reads ``num_return_sequences`` (default 1) and, when present,
        ``max_new_tokens`` from ``options``.
        """
        sampling_params = copy(self.sampling_params)
        sampling_params.n = options.get("num_return_sequences", 1)
        if "max_new_tokens" in options:
            sampling_params.max_tokens = options["max_new_tokens"]
        return sampling_params

    def _run_model(self, texts, sampling_params, *args):
        """Send ``texts`` to the engine (chat or plain) and return its raw outputs.

        Extra positional ``args`` are forwarded ahead of the prompts, exactly
        as ``generate`` has always done.
        """
        if not self.instruct:
            return self.model.generate(*args, texts, sampling_params)

        chats = [
            [
                {"role": "system", "content": self._SYSTEM_PROMPT},
                {"role": "user", "content": text},
            ]
            for text in texts
        ]
        return self.model.chat(*args, chats, sampling_params)

    def generate(self, *args, **kwargs):
        """Generate continuations for tokenized prompts.

        Keyword Args:
            input_ids: batch of prompt token ids (required); decoded to text
                with special tokens skipped before being sent to the engine.
            num_return_sequences: number of samples per prompt (default 1).
            max_new_tokens: optional cap on generated tokens.

        Returns:
            ``GenerateDecoderOnlyOutput`` built by :meth:`post_processing`.

        Raises:
            KeyError: if ``input_ids`` is not supplied.
        """
        sampling_params = self._build_sampling_params(kwargs)
        texts = self.tokenizer.batch_decode(
            kwargs["input_ids"], skip_special_tokens=True
        )
        return self.post_processing(self._run_model(texts, sampling_params, *args))

    def device(self):
        """Return the device string given at construction."""
        return self.base_device

    def tokenize(self, texts):
        """Tokenize ``texts`` into padded PyTorch tensors."""
        return self.tokenizer(texts, return_tensors="pt", padding=True)

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`generate`."""
        return self.generate(*args, **kwargs)

    def generate_texts(self, input_texts: List[str], **args):
        """Generate from raw text prompts and return the generated strings.

        Args:
            input_texts: prompt strings.
            **args: same options as :meth:`generate`
                (``num_return_sequences``, ``max_new_tokens``).

        Returns:
            Flat list of generated texts, prompt by prompt, with all samples
            of one prompt adjacent.
        """
        # The engine is queried directly: going through generate() would
        # require input_ids and would return tensors instead of texts.
        sampling_params = self._build_sampling_params(args)
        request_outputs = self._run_model(list(input_texts), sampling_params)
        return [
            completion.text
            for request_output in request_outputs
            for completion in request_output.outputs
        ]

    def post_processing(self, outputs):
        """Convert raw engine outputs into a ``GenerateDecoderOnlyOutput``.

        For every sampled completion this builds:

        * a ``(max_seq_len, vocab_size)`` float tensor filled with ``-inf``
          except at the token ids for which the engine reported a
          log-probability;
        * a ``(max_seq_len,)`` long tensor of generated token ids, padded
          with the EOS token id.

        ``max_seq_len`` is the longest completion in the batch. The same list
        of log-probability tensors is returned as both ``logits`` and
        ``scores``.

        NOTE(review): each completion's ``logprobs`` must be populated
        (presumably by requesting logprobs in the sampling params - confirm
        against the vLLM configuration used by callers).
        """
        # Token ids are zero-based indices, so the tensor width must be
        # (largest id + 1); `default=-1` covers tokenizers without added tokens.
        highest_added_id = max(self.tokenizer.added_tokens_decoder.keys(), default=-1)
        vocab_size = max(self.tokenizer.vocab_size, highest_added_id + 1)

        completions = [
            completion
            for request_output in outputs
            for completion in request_output.outputs
        ]
        # An empty batch yields empty lists instead of failing inside max().
        max_seq_len = max((len(c.token_ids) for c in completions), default=0)

        logits = []
        sequences = []
        for completion in completions:
            log_prob = torch.full((max_seq_len, vocab_size), -torch.inf)
            sequence = torch.full(
                (max_seq_len,), self.tokenizer.eos_token_id, dtype=torch.long
            )
            for step, step_logprobs in enumerate(completion.logprobs):
                token_ids = torch.tensor(list(step_logprobs.keys()))
                values = torch.tensor(
                    [entry.logprob for entry in step_logprobs.values()]
                )
                log_prob[step, token_ids] = values
                sequence[step] = completion.token_ids[step]
            logits.append(log_prob)
            sequences.append(sequence)

        return GenerateDecoderOnlyOutput(
            sequences=sequences, logits=logits, scores=logits
        )