from dataclasses import dataclass
from typing import Dict, List

import torch
from PIL.Image import Image
from transformers import LlamaTokenizerFast
from transformers.processing_utils import ProcessorMixin

from .image_processing_vlm import VLMImageProcessor
from .conversation import get_conv_template


class DictOutput(object):
    """Mixin that exposes an output dataclass's fields via a dict-style interface."""

    def keys(self):
        return self.__dict__.keys()

    def __getitem__(self, item):
        return self.__dict__[item]

    def __setitem__(self, key, value):
        self.__dict__[key] = value


@dataclass
class VLChatProcessorOutput(DictOutput):
    sft_format: str
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    num_image_tokens: torch.IntTensor

    def __len__(self):
        return len(self.input_ids)


@dataclass
class BatchedVLChatProcessorOutput(DictOutput):
    sft_format: List[str]
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    attention_mask: torch.Tensor
    images_seq_mask: torch.BoolTensor
    images_emb_mask: torch.BoolTensor

    def to(self, device, dtype=torch.bfloat16):
        self.input_ids = self.input_ids.to(device)
        self.attention_mask = self.attention_mask.to(device)
        self.images_seq_mask = self.images_seq_mask.to(device)
        self.images_emb_mask = self.images_emb_mask.to(device)
        self.pixel_values = self.pixel_values.to(device=device, dtype=dtype)
        return self

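
# Illustrative note (not part of the original file): because the output
# dataclasses mix @dataclass with DictOutput, their fields can be read both as
# attributes and with dict-style indexing:
#
#     out = VLChatProcessorOutput(
#         sft_format="...", input_ids=ids, pixel_values=px, num_image_tokens=n
#     )
#     assert out.input_ids is out["input_ids"]
#     assert set(out.keys()) == {"sft_format", "input_ids", "pixel_values", "num_image_tokens"}
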
class VLChatProcessor(ProcessorMixin):
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")

    attributes = ["image_processor", "tokenizer"]

    system_prompt = (
        "You are a helpful language and vision assistant. "
        "You are able to understand the visual content that the user provides, "
        "and assist the user with a variety of tasks using natural language."
    )

    def __init__(
        self,
        image_processor: VLMImageProcessor,
        tokenizer: LlamaTokenizerFast,
        image_tag: str = "<image_placeholder>",
        image_start_tag: str = "<begin_of_image>",
        image_end_tag: str = "<end_of_image>",
        pad_tag: str = "<|▁pad▁|>",
        num_image_tokens: int = 576,
        add_special_token: bool = False,
        sft_format: str = "deepseek",
        mask_prompt: bool = True,
        ignore_id: int = -100,
        **kwargs,
    ):
        self.image_processor = image_processor
        self.tokenizer = tokenizer

        # register the image tag as a special token if the tokenizer lacks it
        image_id = self.tokenizer.vocab.get(image_tag)
        if image_id is None:
            special_tokens = [image_tag]
            special_tokens_dict = {"additional_special_tokens": special_tokens}
            self.tokenizer.add_special_tokens(special_tokens_dict)
            print(f"Added image tag {image_tag} to the tokenizer")

        self.image_tag = image_tag
        self.image_start_tag = image_start_tag
        self.image_end_tag = image_end_tag
        self.pad_tag = pad_tag

        self.num_image_tokens = num_image_tokens
        self.add_special_token = add_special_token
        self.sft_format = sft_format
        self.mask_prompt = mask_prompt
        self.ignore_id = ignore_id

        super().__init__(
            image_processor,
            tokenizer,
            image_tag,
            num_image_tokens,
            add_special_token,
            sft_format,
            mask_prompt,
            ignore_id,
            **kwargs,
        )

    def new_chat_template(self):
        conv = get_conv_template(self.sft_format)
        conv.set_system_message(self.system_prompt)
        return conv

    def apply_sft_template_for_multi_turn_prompts(
        self,
        conversations: List[Dict[str, str]],
        sft_format: str = "deepseek",
        system_prompt: str = "",
    ):
        """
        Applies the SFT template to a multi-turn conversation.

        An example conversation:
            conversation = [
                {
                    "role": "User",
                    "content": "<image_placeholder> is Figure 1.\n<image_placeholder> is Figure 2.\nWhich image is brighter?",
                    "images": [
                        "./multi-images/attribute_comparison_1.png",
                        "./multi-images/attribute_comparison_2.png"
                    ]
                },
                {
                    "role": "Assistant",
                    "content": ""
                }
            ]

        Args:
            conversations (List[Dict]): A conversation given as a list of message dicts with "role" and "content" keys.
            sft_format (str, optional): The name of the SFT template to use. Defaults to "deepseek".
            system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".

        Returns:
            sft_prompt (str): The formatted prompt.
        """

        conv = get_conv_template(sft_format)
        conv.set_system_message(system_prompt)
        for message in conversations:
            conv.append_message(message["role"], message["content"].strip())
        sft_prompt = conv.get_prompt().strip()

        return sft_prompt

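    # Illustrative sketch (not part of the original file): rendering a
    # conversation into a single prompt string. The exact markup depends on the
    # "deepseek" template defined in conversation.py; `processor` is assumed to
    # be an already-constructed VLChatProcessor.
    #
    #     sft_prompt = processor.apply_sft_template_for_multi_turn_prompts(
    #         conversations=[
    #             {"role": "User", "content": "<image_placeholder> Describe this image."},
    #             {"role": "Assistant", "content": ""},
    #         ],
    #         sft_format="deepseek",
    #         system_prompt=processor.system_prompt,
    #     )
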
    @property
    def image_token(self):
        return self.image_tag

    @property
    def image_id(self):
        image_id = self.tokenizer.vocab.get(self.image_tag)
        return image_id

    @property
    def image_start_id(self):
        image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
        return image_start_id

    @property
    def image_end_id(self):
        image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
        return image_end_id

    @property
    def image_start_token(self):
        return self.image_start_tag

    @property
    def image_end_token(self):
        return self.image_end_tag

    @property
    def pad_id(self):
        pad_id = self.tokenizer.vocab.get(self.pad_tag)
        return pad_id

    def add_image_token(
        self,
        image_indices: List[int],
        input_ids: torch.LongTensor,
    ):
        """
        Inserts `num_image_tokens` image tokens, wrapped by the image start and
        end tokens, at each image placeholder position.

        Args:
            image_indices (List[int]): indices of the image placeholder token in input_ids.
            input_ids (torch.LongTensor): [N]

        Returns:
            input_ids (torch.LongTensor): [N + image tokens]
            num_image_tokens (torch.IntTensor): [n_images]
        """

        input_slices = []

        start = 0
        for index in image_indices:
            if self.add_special_token:
                end = index + 1
            else:
                end = index

            # original text tokens before the image placeholder
            input_slices.append(input_ids[start:end])

            # expand: <begin_of_image> + num_image_tokens * <image_placeholder> + <end_of_image>
            input_slices.append(self.image_start_id * torch.ones((1,), dtype=torch.long))
            input_slices.append(
                self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
            )
            input_slices.append(self.image_end_id * torch.ones((1,), dtype=torch.long))
            start = index + 1

        # the remaining text tokens after the last image
        input_slices.append(input_ids[start:])

        # concatenate all slices into the final sequence
        input_ids = torch.cat(input_slices, dim=0)
        num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))

        return input_ids, num_image_tokens

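    # Illustrative sketch (not part of the original file): with the default
    # num_image_tokens=576 and add_special_token=False, a tokenized sequence
    #
    #     [t0, t1, <image_placeholder>, t2]
    #
    # comes out of add_image_token as
    #
    #     [t0, t1, <begin_of_image>, 576 x <image_placeholder>, <end_of_image>, t2]
    #
    # i.e. each placeholder is replaced by num_image_tokens image tokens wrapped
    # in the image start/end tokens, and num_image_tokens is recorded per image.
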
    def process_one(
        self,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image] = None,
        **kwargs,
    ):
        """
        Processes a single prompt (or conversation) together with its images.

        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            **kwargs:

        Returns:
            outputs (VLChatProcessorOutput): the output of the processor,
                - sft_format (str): the formatted prompt;
                - input_ids (torch.LongTensor): [N + image tokens]
                - pixel_values (torch.FloatTensor): [n_images, 3, H, W]
                - num_image_tokens (torch.IntTensor): the number of image tokens per image
        """

        assert (
            prompt is None or conversations is None
        ), "prompt and conversations cannot be used at the same time."

        if prompt is None:
            # apply the SFT template to the conversation
            sft_format = self.apply_sft_template_for_multi_turn_prompts(
                conversations=conversations,
                sft_format=self.sft_format,
                system_prompt=self.system_prompt,
            )
        else:
            sft_format = prompt

        # tokenize
        input_ids = self.tokenizer.encode(sft_format)
        input_ids = torch.LongTensor(input_ids)

        # expand each image placeholder into image tokens
        image_token_mask: torch.BoolTensor = input_ids == self.image_id
        image_indices = image_token_mask.nonzero()
        input_ids, num_image_tokens = self.add_image_token(
            image_indices=image_indices,
            input_ids=input_ids,
        )

        # preprocess the images into pixel values
        images_outputs = self.image_processor(images, return_tensors="pt")

        prepare = VLChatProcessorOutput(
            sft_format=sft_format,
            input_ids=input_ids,
            pixel_values=images_outputs.pixel_values,
            num_image_tokens=num_image_tokens,
        )

        return prepare

    def __call__(
        self,
        *,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image] = None,
        force_batchify: bool = True,
        **kwargs,
    ):
        """
        Processes one prompt or conversation with images, optionally batchifying
        the result.

        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            force_batchify (bool): whether to wrap the single example into a batch;
            **kwargs:

        Returns:
            outputs (VLChatProcessorOutput or BatchedVLChatProcessorOutput):
                the output of the processor, containing input_ids, pixel_values,
                and num_image_tokens, plus attention_mask, images_seq_mask, and
                images_emb_mask when batched.
        """

        prepare = self.process_one(
            prompt=prompt, conversations=conversations, images=images
        )

        if force_batchify:
            prepare = self.batchify([prepare])

        return prepare

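    # Illustrative sketch (not part of the original file): a typical end-to-end
    # call. The checkpoint name below is only an example of one that ships this
    # processor class; any compatible checkpoint works.
    #
    #     from PIL import Image as PILImage
    #
    #     processor = VLChatProcessor.from_pretrained("deepseek-ai/deepseek-vl-7b-chat")
    #     conversation = [
    #         {"role": "User", "content": "<image_placeholder> Describe this image."},
    #         {"role": "Assistant", "content": ""},
    #     ]
    #     pil_images = [PILImage.open("example.png").convert("RGB")]
    #     inputs = processor(conversations=conversation, images=pil_images)
    #     inputs = inputs.to("cuda")  # BatchedVLChatProcessorOutput.to(device, dtype=torch.bfloat16)
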
    def batchify(
        self, prepare_list: List[VLChatProcessorOutput]
    ) -> BatchedVLChatProcessorOutput:
        """
        Preprocesses the inputs for multimodal inference.

        Args:
            prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.

        Returns:
            BatchedVLChatProcessorOutput: The batched inputs to use for multimodal inference.
        """

        batch_size = len(prepare_list)
        sft_format = []
        n_images = []
        seq_lens = []
        for prepare in prepare_list:
            n_images.append(len(prepare.num_image_tokens))
            seq_lens.append(len(prepare))

        input_token_max_len = max(seq_lens)
        max_n_images = max(1, max(n_images))

        batched_input_ids = torch.full(
            (batch_size, input_token_max_len), self.pad_id
        ).long()
        batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
        batched_pixel_values = torch.zeros(
            (batch_size, max_n_images, *self.image_processor.default_shape)
        ).float()
        batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
        batched_images_emb_mask = torch.zeros(
            (batch_size, max_n_images, self.num_image_tokens)
        ).bool()

        for i, prepare in enumerate(prepare_list):
            input_ids = prepare.input_ids
            seq_len = len(prepare)
            n_image = len(prepare.num_image_tokens)

            # left-pad: place each sequence at the right end of its row
            batched_attention_mask[i, -seq_len:] = 1
            batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
            batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id

            if n_image > 0:
                batched_pixel_values[i, :n_image] = prepare.pixel_values
                for j, n_image_tokens in enumerate(prepare.num_image_tokens):
                    batched_images_emb_mask[i, j, :n_image_tokens] = True

            sft_format.append(prepare.sft_format)

        batched_prepares = BatchedVLChatProcessorOutput(
            input_ids=batched_input_ids,
            attention_mask=batched_attention_mask,
            pixel_values=batched_pixel_values,
            images_seq_mask=batched_images_seq_mask,
            images_emb_mask=batched_images_emb_mask,
            sft_format=sft_format,
        )

        return batched_prepares
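
# Illustrative sketch (not part of the original file): tensor shapes produced
# by batchify for a batch of B examples whose longest sequence has T tokens
# after image expansion, with at most K images per example (K >= 1):
#
#     input_ids       : (B, T)   long, left-padded with pad_id
#     attention_mask  : (B, T)   long, 1 on real tokens, 0 on padding
#     pixel_values    : (B, K, *image_processor.default_shape)   float
#     images_seq_mask : (B, T)   bool, True where input_ids == image_id
#     images_emb_mask : (B, K, num_image_tokens)   bool, True for real image slots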