diff --git a/xtuner/chat/__init__.py b/xtuner/chat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/xtuner/chat/backend/__init__.py b/xtuner/chat/backend/__init__.py new file mode 100644 index 000000000..54351fa29 --- /dev/null +++ b/xtuner/chat/backend/__init__.py @@ -0,0 +1,5 @@ +from .encoder import VisionEncoderForDeploy +from .huggingface import HFBackend +from .lmdeploy import LMDeployBackend + +__all__ = ['VisionEncoderForDeploy', 'HFBackend', 'LMDeployBackend'] diff --git a/xtuner/chat/backend/base.py b/xtuner/chat/backend/base.py new file mode 100644 index 000000000..0a0fd4bbe --- /dev/null +++ b/xtuner/chat/backend/base.py @@ -0,0 +1,26 @@ +from abc import abstractmethod + +from xtuner.types import HybridChatTemplate + + +class BaseBackend(): + + @property + def chat_template(self) -> HybridChatTemplate: + pass + + @abstractmethod + def create_streamer(self, iterable=False): + pass + + @abstractmethod + def chat(self, messages, streamer=None, generation_config=None): + pass + + # @abstractmethod + # def response_with_function_call(self, response: str): + # pass + + # @abstractmethod + # def response_with_code_interpreter(self, response: str): + # pass diff --git a/xtuner/chat/backend/encoder.py b/xtuner/chat/backend/encoder.py new file mode 100644 index 000000000..af05b78df --- /dev/null +++ b/xtuner/chat/backend/encoder.py @@ -0,0 +1,308 @@ +import base64 +import os +from io import BytesIO +from typing import List, Literal, Optional, Union + +import requests +import torch +from peft import PeftModel +from PIL import Image +from torch import nn +from transformers import AutoModel, CLIPImageProcessor, CLIPVisionModel + +from xtuner.dataset.utils import expand2square + + +def load_image_from_base64(image: Union[bytes, str]) -> Image.Image: + """load image from base64 format.""" + return Image.open(BytesIO(base64.b64decode(image))) + + +def load_image(image_url: str) -> Image.Image: + """load image from url, local path or openai 
GPT4V.""" + + headers = { + 'User-Agent': + 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ' + '(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3' + } + if image_url.startswith('http'): + response = requests.get(image_url, headers=headers) + response.raise_for_status() + + # Open the image using PIL + img = Image.open(BytesIO(response.content)) + elif image_url.startswith('data:image'): + img = load_image_from_base64(image_url.split(',')[1]) + else: + img = Image.open(image_url) + + return img + + +ModelHub = Literal['huggingface', 'modelscope'] + + +class VisionEncoderForDeploy(nn.Module): + + def __init__(self, + model_name_or_path: str, + projector_name_or_path: str, + adapter_name_or_path: str = None, + select_layer: int = -2, + hub: ModelHub = 'huggingface', + device='cuda'): + + super().__init__() + + # model_path = self._parse_model_path(xtuner_model_name_or_path, hub) + # visual_encoder_path = self._parse_visual_encoder_path( + # model_path, visual_encoder_name_or_path, hub + # ) + # projector_path = self._parse_projector_path(model_path) + + # # parse visual encoder adapter path. 
+ # vis_enc_adapter_path = self._parse_vis_enc_adapter_path(model_path) + + self.select_layer = select_layer + self.image_processor = CLIPImageProcessor.from_pretrained( + model_name_or_path) + print(f'Load Image Processor From {model_name_or_path}') + + visual_encoder = CLIPVisionModel.from_pretrained( + model_name_or_path, torch_dtype=torch.float16) + print(f'Load Visual Encoder From {model_name_or_path}') + + # when path is None, means without visual encoder adapter + if adapter_name_or_path: + self.visual_encoder = PeftModel.from_pretrained( + visual_encoder, adapter_name_or_path) + print(f'Load Visual Encoder Adapter From {adapter_name_or_path}') + else: + self.visual_encoder = visual_encoder + + self.projector = AutoModel.from_pretrained( + projector_name_or_path, + torch_dtype=torch.float16, + trust_remote_code=True) + print(f'Load Projector from {projector_name_or_path}') + + self.dtype = torch.float16 + self.device = device + self.to(self.device) + self.to(self.dtype) + + def process_img(self, image: Image.Image) -> List[torch.Tensor]: + """Preprocess the input image, including expanding to square and + normalization. + + Args: + image (Image.Image): The input image need to be preprocessed. + + Returns: + torch.Tensor: The preprocessed image tensor. + """ + + if isinstance(image, str): + image = load_image(image) + + if not isinstance(image, Image.Image): + raise TypeError(f"Don't support {type(image).__name__}, " + 'the image type must be `PIL.Image`.') + + processor = self.image_processor + image_mean = processor.image_mean + + background_color = tuple(int(x * 255) for x in image_mean) + squared_img = expand2square(image, background_color) + + processed = processor.preprocess(squared_img, return_tensors='pt') + img_tensor = processed['pixel_values'][0] # shape: 3, h, w + + # before this line, `img_tensor` is on cpu. 
+ img_tensor = img_tensor.to(self.device).to(self.dtype) + return img_tensor + + @torch.no_grad() + def forward(self, images: List[Union[str, + Image.Image]]) -> List[torch.Tensor]: + """Obtain the corresponding embeddings based on the images. + + Args: + images (List[Image.Image]): The input images. The data layout + for each image is (c, h, w). + + Returns: + List[torch.Tensor]: The list of extracted features from images. + The data layout for each tensor should be (tokens, dims). + """ + + num_imgs = len(images) + + img_tensors = [self.process_img(img) for img in images] + + # Determine if all image sizes are consistent. + # TODO (pppppM): Confirm when the image size will be inconsistent + shape_consistant = all(x.shape == img_tensors[0].shape + for x in img_tensors) + + from transformers.modeling_outputs import BaseModelOutputWithPooling + + if shape_consistant: + # Batch inference when all image sizes are consistent. + # img_tensors[0] shape: (3, h, w) + # tensor shape: (num_imgs, 3, h, w) + tensor = torch.stack(img_tensors, dim=0) + + enc_out = self.visual_encoder(tensor, output_hidden_states=True) + enc_out: BaseModelOutputWithPooling + + # feat shape: (num_imgs, tokens, dims) + feat = self.projector(enc_out.hidden_states[self.select_layer][:, + 1:]) + + # Split along the batch dimension + # The feature of each image corresponds to a tensor. + # len(features): num_imgs, features[0] shape:(1, tokens, dims) + features = torch.chunk(feat, num_imgs, dim=0) + + # per image feature's layout should be (tokens, dims) + features = [x.flatten(0, 1) for x in features] + + else: + features = [] + for tensor in img_tensors: + tensor: torch.Tensor + # The visual encoder requires a data layout of (bs, c, h, w). 
+ # tensor shape: (3, h, w) batch_tensor shape: (1, 3, h, w) + batch_tensor = tensor.unsqueeze(0) + enc_out = self.visual_encoder( + batch_tensor, output_hidden_states=True) + enc_out: BaseModelOutputWithPooling + # feat shape: (1, tokens, dims) + feat = self.projector( + enc_out.hidden_states[self.select_layer][:, 1:]) + features.append(feat) + + return features + + def _parse_model_path(self, name_or_path: str, hub: ModelHub) -> str: + """Parse and get the directory path of the model. It supports load + model from local directory or download from the hub. + + Args: + name_or_path (str): The directory path or name of the model. + hub (str): The hub to download models from. + + Returns: + str: The local directory path of the model. + + Raises: + NotImplementedError: If the input hub is not supported currently. + """ + + if os.path.isdir(name_or_path): + model_path = name_or_path + else: + if hub == 'huggingface': + from huggingface_hub import snapshot_download + model_path = snapshot_download(repo_id=name_or_path) + elif hub == 'modelscope': + from modelscope import snapshot_download + model_path = snapshot_download(model_id=name_or_path) + else: + raise NotImplementedError( + 'Only supports downloading models from `Huggingface` or ' + '`Modelscope`.') + + return model_path + + def _parse_visual_encoder_path(self, model_path: str, + visual_encoder_name_or_path: str, + hub: ModelHub) -> str: + """Parse and get the directory path of the visual encoder. It supports + load visual encoder from local directory, download from the hub, or + find it in the XTuner model directory. + + Args: + model_path (str): The directory path of the model. + visual_encoder_name_or_path (Optional[str]): The directory path or + name of the visual encoder. + hub (str): The hub to download models from. + + Returns: + str: The local directory path of the visual encoder. + + Raises: + NotImplementedError: If the input hub is not supported currently. 
+ """ + + if 'visual_encoder' in os.listdir(model_path): + assert visual_encoder_name_or_path is None + visual_encoder_path = os.path.join(model_path, 'visual_encoder') + elif os.path.isdir(visual_encoder_name_or_path): + visual_encoder_path = visual_encoder_name_or_path + else: + if hub == 'huggingface': + from huggingface_hub import snapshot_download + visual_encoder_path = snapshot_download( + repo_id=visual_encoder_name_or_path) + elif hub == 'modelscope': + from modelscope import snapshot_download + visual_encoder_path = snapshot_download( + model_id=visual_encoder_name_or_path) + else: + raise NotImplementedError( + 'Only supports downloading models from `Huggingface` or ' + '`Modelscope`.') + + return visual_encoder_path + + def _parse_projector_path(self, model_path: str) -> Optional[str]: + """Parse the path of the `projector` model according to the model path. + + Args: + model_path (str): The path to the model directory. + + Raises: + ValueError: If the 'projector' directory is not found in the + `model_path`. + + Returns: + Optional[str]: The full path of 'projector' directory if exists, + else raises ValueError. + """ + if 'projector' in os.listdir(model_path): + projector_path = os.path.join(model_path, 'projector') + else: + # Raises exception if 'projector' directory/folder not found + raise ValueError('Projector directory not found in given path') + return projector_path + + def _parse_vis_enc_adapter_path(self, model_path: str) -> Optional[str]: + """Parses the model path and returns the path to + 'visual_encoder_adapter' directory. + + Args: + model_path (str): The path to the model directory. + + Returns: + Optional[str]: The full path of 'visual_encoder_adapter' directory if exists, + else returns None. 
+ """ + if 'visual_encoder_adapter' in os.listdir(model_path): + adapter_path = os.path.join(model_path, 'visual_encoder_adapter') + else: + # Returns None if 'visual_encoder_adapter' directory/folder not found + adapter_path = None + return adapter_path + + +if __name__ == '__main__': + img = load_image('llava.jpeg') + model = VisionEncoderForDeploy('xtuner/llava-internlm-7b', + 'openai/clip-vit-large-patch14-336') + + model.cuda() + model.eval() + outputs = model([img]) diff --git a/xtuner/chat/backend/huggingface.py b/xtuner/chat/backend/huggingface.py new file mode 100644 index 000000000..51e742327 --- /dev/null +++ b/xtuner/chat/backend/huggingface.py @@ -0,0 +1,224 @@ +from typing import Optional + +import torch +from peft import PeftModel +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig) +from transformers import GenerationConfig as HFGenerationConfig +from transformers import PreTrainedModel, PreTrainedTokenizer + +from xtuner.chat.streamer import HFTextIteratorStreamer, HFTextStreamer +from xtuner.model.utils import LoadWoInit +from xtuner.tools.utils import get_stop_criteria +from xtuner.types import HybridChatMessages, HybridChatTemplate, SampleParams +from .base import BaseBackend + + +class _HFBackend(BaseBackend): + + def __init__( + self, + chat_template: HybridChatTemplate, + llm: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + ) -> None: + super().__init__() + + self.llm = llm + self.llm.cuda() + self.tokenizer = tokenizer + + self._chat_template = chat_template + + @property + def chat_template(self) -> HybridChatTemplate: + return self._chat_template + + @property + def eos_token_id(self): + if self.tokenizer.pad_token_id: + return self.tokenizer.eos_token_id + else: + return self.tokenizer.eos_token_id + + @property + def pad_token_id(self): + return self.tokenizer.pad_token_id + + def build_llm_and_tokenizer(self, + model_name_or_path, + adapter=None, + bits=None): + + if bits is None: + 
quantization_config = None + load_in_8bit = False + elif bits == 4: + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4') + load_in_8bit = False + elif bits == 8: + quantization_config = None + load_in_8bit = True + + tokenizer = AutoTokenizer.from_pretrained( + model_name_or_path, + trust_remote_code=True, + encode_special_tokens=True) + + with LoadWoInit(): + model = AutoModelForCausalLM.from_pretrained( + model_name_or_path, + device_map='auto', + load_in_8bit=load_in_8bit, + quantization_config=quantization_config, + trust_remote_code=True, + torch_dtype=torch.float16) + + if adapter is not None: + model = PeftModel.from_pretrained(model, adapter) + + model.eval() + return model, tokenizer + + def response_with_code_interpreter(self, response: str): + return False + + def response_with_function_call(self, response: str): + return False + + def create_streamer(self, chat_template=None, iterable=False): + if iterable: + return HFTextIteratorStreamer( + self.tokenizer, skip_prompt=True, chat_template=chat_template) + else: + return HFTextStreamer( + self.tokenizer, skip_prompt=True, chat_template=chat_template) + + def parse_sample_params(self, params: SampleParams) -> HFGenerationConfig: + + if params is None: + params = SampleParams() + + hf_gen_config = HFGenerationConfig( + max_new_tokens=params.max_new_tokens, + do_sample=params.temperature > 0, + temperature=params.temperature, + top_k=params.top_k, + top_p=params.top_p, + repetition_penalty=params.repetition_penalty, + seed=params.seed, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id) + + stop_words = params.stop_words + stop_words.extend(self.chat_template.stop_words) + + return hf_gen_config, stop_words + + def chat(self, + messages: HybridChatMessages, + streamer=None, + sample_params: 
Optional[SampleParams] = None): + + prompt = messages.apply_chat_template(self.chat_template) + ids = self.tokenizer.encode(prompt, return_tensors='pt') + + hf_gen_config, stop_words = self.parse_sample_params(sample_params) + + stop_criteria = get_stop_criteria( + tokenizer=self.tokenizer, stop_words=stop_words) + + generate_output = self.llm.generate( + inputs=ids.cuda(), + streamer=streamer, + generation_config=hf_gen_config, + stopping_criteria=stop_criteria) + + output = self.tokenizer.decode( + generate_output[0][len(ids[0]):], skip_special_tokens=True) + + for word in stop_words: + output = output.rstrip(word) + + return output + + +class HFBackend(_HFBackend): + + def __init__( + self, + chat_template: HybridChatTemplate, + llm: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + vision_tower: Optional[torch.nn.Module] = None, + ) -> None: + super().__init__(chat_template, llm, tokenizer) + + if vision_tower: + self.vision_tower = vision_tower + self.vision_tower.cuda() + self.vision_tower.eval() + else: + self.vision_tower = None + + def chat(self, + messages: HybridChatMessages, + streamer=None, + sample_params=None): + + img_urls = messages.collect_img_urls() + + if self.vision_tower is None or len(img_urls) == 0: + return super().chat(messages, streamer, sample_params) + + prompt = messages.apply_chat_template(self.chat_template) + + img_features = self.vision_tower(img_urls) + + # prompt, img_ranges = _insert_img_pad_tokens( + # prompt, self.chat_template.image_token, img_features, + # self.tokenizer.pad_token + # ) + + chunks = prompt.split(self.chat_template.image_token) + assert len(chunks) - 1 == len(img_urls) + chunk_embeddings = [] + for i in range(len(chunks)): + + chunk_ids = self.tokenizer.encode(chunks[i], return_tensors='pt') + chunk_ids = chunk_ids.to(self.llm.device) + chunk_emb = self.llm.get_input_embeddings()(chunk_ids) + chunk_embeddings.append(chunk_emb) + + if i < len(chunks) - 1: + 
chunk_embeddings.append(img_features[i].unsqueeze(0)) + + embeddings = torch.cat(chunk_embeddings, dim=1) + + hf_gen_config, stop_words = self.parse_sample_params(sample_params) + + stop_criteria = get_stop_criteria( + tokenizer=self.tokenizer, stop_words=stop_words) + + generate_output = self.llm.generate( + input_ids=None, + inputs_embeds=embeddings, + streamer=streamer, + generation_config=hf_gen_config, + bos_token_id=self.tokenizer.bos_token_id, + stopping_criteria=stop_criteria) + + output = self.tokenizer.decode( + generate_output[0], skip_special_tokens=True) + + for word in stop_words: + output = output.rstrip(word) + + return output diff --git a/xtuner/chat/backend/lmdeploy/__init__.py b/xtuner/chat/backend/lmdeploy/__init__.py new file mode 100644 index 000000000..139c066fb --- /dev/null +++ b/xtuner/chat/backend/lmdeploy/__init__.py @@ -0,0 +1,3 @@ +from .backend import LMDeployBackend + +__all__ = ['LMDeployBackend'] diff --git a/xtuner/chat/backend/lmdeploy/_encoder.py b/xtuner/chat/backend/lmdeploy/_encoder.py new file mode 100644 index 000000000..3466eb30f --- /dev/null +++ b/xtuner/chat/backend/lmdeploy/_encoder.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import asyncio +import queue +import time +from threading import Thread +from typing import List, Union + +import torch +from lmdeploy.utils import get_logger +from PIL.Image import Image + +logger = get_logger('lmdeploy') + + +class Record: + """Batching manager.""" + + def __init__(self): + self.number = [] + self.waiting = [] + self.done = [] + self.res_que = [] + self.total = 0 + + def enqueue(self, images: List[Image], que: Union[queue.Queue, + asyncio.Queue]): + """add ith request to manager.""" + self.number.append(len(images)) + self.waiting.extend(images) + self.res_que.append(que) + self.total += len(images) + self.log('received', len(images)) + + def dequeue(self, max_batch_size): + """try to dequeue max batch size images.""" + inputs = self.waiting[:max_batch_size] + self.waiting = self.waiting[max_batch_size:] + self.total -= len(inputs) + self.log('process', len(inputs)) + return inputs + + def nofify(self): + """set result if request i is finished.""" + if len(self.number) == 0 or self.number[0] > len(self.done): + return False + num_images = self.number.pop(0) + outputs = self.done[:num_images] + self.done = self.done[num_images:] + que = self.res_que.pop(0) + if isinstance(que, queue.Queue): + que.put(outputs) + else: + que._loop.call_soon_threadsafe(que.put_nowait, outputs) + self.log('done', num_images) + return True + + def log(self, task: str, num: int): + logger.info(f'ImageEncoder {task} {num} images, ' + f'left {self.total} images.') + + +class _AsyncEncoderWrapper: + """Image encoder.""" + + def __init__(self, model, max_batch_size: int = 16): + self.model = model + self.max_batch_size = max_batch_size + self.loop = asyncio.new_event_loop() + self.work_thread = self._start_work_thread() + torch.cuda.empty_cache() + + def _start_work_thread(self): + """internal thread.""" + + def _work_thread(): + asyncio.set_event_loop(self.loop) + self.que = asyncio.Queue() + self.loop.run_until_complete(self._forward_loop()) + + thread = 
Thread(target=_work_thread, daemon=True) + thread.start() + return thread + + async def _forward_loop(self): + """working loop to process images.""" + logger.info('start ImageEncoder._forward_loop') + record = Record() + while True: + while record.total == 0 or (self.que.qsize() and + record.total < self.max_batch_size): + item = await self.que.get() + record.enqueue(item[0], item[1]) + inputs = record.dequeue(self.max_batch_size) + outputs = self.forward(inputs) + record.done.extend(outputs) + while record.nofify(): + pass + + def forward(self, inputs: List[Image]): + """Model forward.""" + time_start = time.perf_counter() + outputs = self.model.forward(inputs) + time_end = time.perf_counter() + logger.info(f'ImageEncoder forward {len(inputs)} images, ' + f'cost {time_end - time_start:.3f}s') + return outputs + + def infer(self, inputs: List[Image]): + """infer.""" + outputs = queue.Queue() + item = (inputs, outputs) + self.loop.call_soon_threadsafe(self.que.put_nowait, item) + results = outputs.get() + return results + + async def async_infer(self, inputs: List[Image]): + """async infer.""" + outputs = asyncio.Queue() + item = (inputs, outputs) + self.loop.call_soon_threadsafe(self.que.put_nowait, item) + results = await outputs.get() + return results diff --git a/xtuner/chat/backend/lmdeploy/_engine.py b/xtuner/chat/backend/lmdeploy/_engine.py new file mode 100644 index 000000000..d81d30c6c --- /dev/null +++ b/xtuner/chat/backend/lmdeploy/_engine.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import numpy as np +from lmdeploy.serve.async_engine import AsyncEngine +from lmdeploy.vl.constants import IMAGE_DUMMY_TOKEN_INDEX + +from xtuner.types import HybridChatMessages, HybridChatTemplate + + +class _MMAsyncEngine(AsyncEngine): + """Visual Language Async inference engine.""" + + def __init__(self, + chat_template: HybridChatTemplate, + *args, + encoder=None, + **kwargs) -> None: + super().__init__(*args, **kwargs) + assert self.model_name == 'base' + self.encoder = encoder + self.chat_template = chat_template + + async def _get_prompt_input(self, prompt: HybridChatMessages, + do_preprocess: bool, sequence_start: bool): + """get input_ids, embeddings and offsets.""" + + decorated = prompt.apply_chat_template(self.chat_template) + segs = decorated.split(self.chat_template.image_token) + + results = {} + input_ids = [] + if len(segs) > 1: + assert self.encoder is not None + img_urls = prompt.collect_img_urls() + features = await self.encoder.async_infer(img_urls) + features = [x.cpu().numpy() for x in features] + input_ids = [] + begins = [] + ends = [] + for i, seg in enumerate(segs): + if i > 0: + image_dim = features[i - 1].shape[0] + begins.append(len(input_ids)) + ends.append(begins[-1] + image_dim) + input_ids.extend([IMAGE_DUMMY_TOKEN_INDEX] * image_dim) + seg_ids = self.tokenizer.encode( + seg, add_bos=((i == 0) and sequence_start)) + input_ids.extend(seg_ids) + ranges = np.stack([begins, ends], axis=1).tolist() + results['input_embeddings'] = features + results['input_embedding_ranges'] = ranges + else: + input_ids = self.tokenizer.encode( + decorated, add_bos=sequence_start) + + results['input_ids'] = input_ids + results['prompt'] = decorated + return results + + # def batch_infer(self, prompts: Union[VLPromptType, List[Dict], + # List[VLPromptType], List[List[Dict]]], + # **kwargs): + # """Inference a batch of prompts.""" + # # prompts = self._convert_prompts(prompts) + # return super().batch_infer(prompts, **kwargs) + + # def stream_infer(self, 
prompts: Union[VLPromptType, List[Dict], + # List[VLPromptType], + # List[List[Dict]]], **kwargs): + # """Inference a batch of prompts with stream mode.""" + # # prompts = self._convert_prompts(prompts) + # return super().stream_infer(prompts, **kwargs) + + # def __call__(self, prompts, **kwargs): + # """Inference a batch of prompts.""" + # # prompts = self._convert_prompts(prompts) + # return super().__call__(prompts, **kwargs) + + # def chat(self, prompts: VLPromptType, **kwargs): + # """chat.""" + # # _prompts = self._convert_prompts(prompts) + # sess = super().chat(_prompts, **kwargs) + + # # recover prompts & history + # sess._prompt = prompts + # last_round = sess.history[-1] + # sess.history[-1] = (prompts, last_round[-1]) + # return sess diff --git a/xtuner/chat/backend/lmdeploy/backend.py b/xtuner/chat/backend/lmdeploy/backend.py new file mode 100644 index 000000000..1df25fe81 --- /dev/null +++ b/xtuner/chat/backend/lmdeploy/backend.py @@ -0,0 +1,107 @@ +import asyncio +import os +from typing import List, Optional, Union + +from lmdeploy.utils import get_logger + +from xtuner.types import HybridChatMessages, HybridChatTemplate, SampleParams +from ...streamer import LMDeployTextIteratorStreamer, LMDeployTextStreamer +from ..base import BaseBackend +from ._encoder import _AsyncEncoderWrapper +from ._engine import _MMAsyncEngine + +os.environ['TM_LOG_LEVEL'] = 'ERROR' +logger = get_logger('lmdeploy') +logger.setLevel('ERROR') + +_StreamerType = Union[LMDeployTextStreamer, LMDeployTextIteratorStreamer] + + +class LMDeployBackend(BaseBackend): + + def __init__(self, + chat_template, + llm_name_or_path, + vision_encoder=None) -> None: + super().__init__() + + if vision_encoder: + encoder = _AsyncEncoderWrapper(vision_encoder) + else: + encoder = None + + self._engine = _MMAsyncEngine( + chat_template, + encoder=encoder, + model_path=llm_name_or_path, + model_name='base') + + self._chat_template = chat_template + + @property + def chat_template(self) -> 
HybridChatTemplate:
+        return self._chat_template
+
+    def create_streamer(self, iterable=False):
+
+        if iterable:
+            return LMDeployTextIteratorStreamer()
+        else:
+            return LMDeployTextStreamer()
+
+    def parse_sample_params(self,
+                            params: SampleParams) -> 'LMDGenerationConfig':
+
+        if params is None:
+            params = SampleParams()
+
+        stop_words = list(params.stop_words)
+        stop_words.extend(self.chat_template.stop_words)
+
+        from lmdeploy.messages import GenerationConfig as LMDGenerationConfig
+        lmd_gen_config = LMDGenerationConfig(
+            max_new_tokens=params.max_new_tokens,
+            temperature=params.temperature,
+            top_k=params.top_k,
+            top_p=params.top_p,
+            repetition_penalty=params.repetition_penalty,
+            random_seed=params.seed,
+            stop_words=stop_words)
+
+        return lmd_gen_config
+
+    def chat(self,
+             messages: HybridChatMessages,
+             streamer: Optional[_StreamerType] = None,
+             sample_params: Optional[SampleParams] = None):
+
+        lmd_gen_config = self.parse_sample_params(sample_params)
+        # NOTE(review): a random session id is generated below per request
+        import random
+
+        generator = self._engine.generate(
+            messages, random.randint(1, 100000), gen_config=lmd_gen_config)
+
+        async def get_response():
+            out = ''
+            async for res in generator:
+                out += res.response
+                if streamer:
+                    streamer.put(res.response)
+            if streamer:
+                streamer.end()
+            return out
+
+        loop = asyncio.new_event_loop()
+        response = loop.run_until_complete(get_response())
+        return response
+
+    def batch_infer(self,
+                    messages: List[HybridChatMessages],
+                    sample_params: Optional[SampleParams] = None):
+
+        lmd_gen_config = self.parse_sample_params(sample_params)
+
+        results = self._engine.batch_infer(messages, gen_config=lmd_gen_config)
+
+        return [r.text for r in results]
diff --git a/xtuner/chat/conversation.py b/xtuner/chat/conversation.py
new file mode 100644
index 000000000..a26616221
--- /dev/null
+++ b/xtuner/chat/conversation.py
@@ -0,0 +1,147 @@
+from xtuner.chat.backend import HFBackend
+from xtuner.types.chat import (ChatMsg, HybridChatMessages, 
ImageContentItem,
+                               TextContentItem)
+
+
+class Conversation():
+
+    def __init__(self,
+                 backend: HFBackend,
+                 name=None,
+                 system=None,
+                 functions=None,
+                 code_interpreter=None) -> None:
+
+        self.name = name
+        self.backend = backend
+        self.system = system
+        self.functions = functions
+        self.code_interpreter = code_interpreter
+        self._messages = HybridChatMessages()
+
+        if system:
+            msg = ChatMsg(role='system', content=system)
+            self._messages.messages.append(msg)
+
+    @property
+    def messages(self):
+        return self._messages
+
+    def add_message(self, role, content):
+        if role == 'system':
+            assert isinstance(content, str)
+            msg = ChatMsg(role='system', content=content)
+            self._messages.messages.append(msg)
+        elif role == 'user':
+            self._add_user(content)
+        elif role == 'assistant':
+            assert isinstance(content, str)
+            msg = ChatMsg(role='assistant', content=content)
+            self._messages.messages.append(msg)
+
+    def _add_user(self, content):
+
+        if isinstance(content, str):
+            msg = ChatMsg(role='user', content=content)
+            self._messages.messages.append(msg)
+        elif isinstance(content, list):
+            _content = []
+            for item in content:
+                if isinstance(item, (ImageContentItem, TextContentItem)):
+                    _content.append(item)
+                    continue
+
+                assert isinstance(item, dict)
+                assert 'type' in item
+                assert item['type'] in ('image_url', 'text')
+                if item['type'] == 'image_url':
+                    _item = ImageContentItem(image_url=item['image_url'])
+                    _content.append(_item)
+                elif item['type'] == 'text':
+                    _item = TextContentItem(text=item['text'])
+                    _content.append(_item)
+                else:
+                    raise NotImplementedError
+
+            msg = ChatMsg(role='user', content=_content)
+            self._messages.messages.append(msg)
+        else:
+            raise TypeError
+
+    def run(self, content, sample_params=None, streamer=None):
+
+        self.add_message(role='user', content=content)
+        response = self.backend.chat(self.messages, streamer, sample_params)
+        self.add_message(role='assistant', content=response)
+        return response
+
+    def regenerate(self):
+
+        assert self._messages.messages[-1].role == 'assistant'
+ 
self._messages.messages.pop() + return self.backend.chat(self.messages) + + def create_streamer(self, iterable=False): + return self.backend.create_streamer(iterable=iterable) + + +if __name__ == '__main__': + + from xtuner.types import HybridChatTemplate + chat_template = HybridChatTemplate( + system='<|im_start|>system\n{system}<|im_end|>\n', + user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n', + assistant='{assistant}<|im_end|>\n', + stop_words=['<|im_end|>'], + image_token='', + function_call= + '{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 + function_result= + '<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 + functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n' + ) + + from transformers import AutoModelForCausalLM, AutoTokenizer + + from xtuner.chat.backend import HFBackend, VisionEncoderForDeploy + + llm = AutoModelForCausalLM.from_pretrained( + '/mnt/petrelfs/share_data/linzhihao/model/models--internlm--internlm2-chat-7b/snapshots/2292b86b21cb856642782cebed0a453997453b1f', + trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained( + '/mnt/petrelfs/share_data/linzhihao/model/models--internlm--internlm2-chat-7b/snapshots/2292b86b21cb856642782cebed0a453997453b1f', + trust_remote_code=True) + vision_tower = VisionEncoderForDeploy( + model_name_or_path='openai/clip-vit-large-patch14-336', + adapter_name_or_path= + '/mnt/petrelfs/share_data/linzhihao/model/models--xtuner--llava-internlm2-7b/snapshots/f363b45ce4787bd0a21d43ed724a70ee40ce69b2/visual_encoder_adapter', + projector_name_or_path= + '/mnt/petrelfs/share_data/linzhihao/model/models--xtuner--llava-internlm2-7b/snapshots/f363b45ce4787bd0a21d43ed724a70ee40ce69b2/projector' + ) + + llm.cuda() + + backend = HFBackend( + chat_template, + llm, + tokenizer, + vision_tower, + ) + + conv = Conversation(backend) + print(conv.chat('who are you?')) 
+ + from xtuner.chat.backend import LMDeployBackend + backend = LMDeployBackend( + chat_template, + '/mnt/petrelfs/share_data/linzhihao/model/models--internlm--internlm2-chat-7b/snapshots/2292b86b21cb856642782cebed0a453997453b1f', + vision_tower) + conv = Conversation(backend) + print(conv.chat('who are you?')) + + content = [ + TextContentItem(text='Please describe this image'), + ImageContentItem(image_url='llava.jpeg') + ] + + print(conv.chat(content)) diff --git a/xtuner/chat/streamer/__init__.py b/xtuner/chat/streamer/__init__.py new file mode 100644 index 000000000..7f83155fc --- /dev/null +++ b/xtuner/chat/streamer/__init__.py @@ -0,0 +1,7 @@ +from .huggingface import HFTextIteratorStreamer, HFTextStreamer +from .lmdeploy import LMDeployTextIteratorStreamer, LMDeployTextStreamer + +__all__ = [ + 'HFTextIteratorStreamer', 'HFTextStreamer', 'LMDeployTextIteratorStreamer', + 'LMDeployTextStreamer' +] diff --git a/xtuner/chat/streamer/huggingface.py b/xtuner/chat/streamer/huggingface.py new file mode 100644 index 000000000..91b0f29aa --- /dev/null +++ b/xtuner/chat/streamer/huggingface.py @@ -0,0 +1,37 @@ +from transformers import TextIteratorStreamer, TextStreamer +from transformers.models.auto import AutoTokenizer + + +class HFTextIteratorStreamer(TextIteratorStreamer): + + def __init__(self, + tokenizer: AutoTokenizer, + skip_prompt: bool = False, + timeout=None, + chat_template=None, + **decode_kwargs): + super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs) + self.chat_template = chat_template + + def on_finalized_text(self, text: str, stream_end: bool = False): + + for word in self.chat_template.stop_words: + text = text.rstrip(word) + super().on_finalized_text(text, stream_end) + + +class HFTextStreamer(TextStreamer): + + def __init__(self, + tokenizer: AutoTokenizer, + skip_prompt: bool = False, + chat_template=None, + **decode_kwargs): + super().__init__(tokenizer, skip_prompt, **decode_kwargs) + self.chat_template = chat_template + + def 
on_finalized_text(self, text: str, stream_end: bool = False): + + for word in self.chat_template.stop_words: + text = text.rstrip(word) + super().on_finalized_text(text, stream_end) diff --git a/xtuner/chat/streamer/lmdeploy.py b/xtuner/chat/streamer/lmdeploy.py new file mode 100644 index 000000000..2ec03e482 --- /dev/null +++ b/xtuner/chat/streamer/lmdeploy.py @@ -0,0 +1,49 @@ +from queue import Queue +from typing import Optional + +from transformers.generation.streamers import BaseStreamer + + +class LMDeployTextStreamer(BaseStreamer): + + def put(self, text): + self.on_finalized_text(text) + + def end(self): + """Flushes any remaining cache and prints a newline to stdout.""" + self.on_finalized_text('', stream_end=True) + + def on_finalized_text(self, text: str, stream_end: bool = False): + """Prints the new text to stdout. + + If the stream is ending, also prints a newline. + """ + print(text, flush=True, end='' if not stream_end else None) + + +class LMDeployTextIteratorStreamer(LMDeployTextStreamer): + + def __init__(self, timeout: Optional[float] = None): + super().__init__() + self.text_queue = Queue() + self.stop_signal = None + self.timeout = timeout + + def on_finalized_text(self, text: str, stream_end: bool = False): + """Put the new text in the queue. + + If the stream is ending, also put a stop signal in the queue. 
+ """ + self.text_queue.put(text, timeout=self.timeout) + if stream_end: + self.text_queue.put(self.stop_signal, timeout=self.timeout) + + def __iter__(self): + return self + + def __next__(self): + value = self.text_queue.get(timeout=self.timeout) + if value == self.stop_signal: + raise StopIteration() + else: + return value diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/agent.json b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/agent.json index 89a82e4aa..667c82d21 100644 --- a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/agent.json +++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/agent.json @@ -2,61 +2,62 @@ "messages": [ {"role": "system", "content": "You are InternLM2-Chat, a harmless AI assistant"}, { - "role": "user", - "content": "Please help me process and visualize this dataset.", + "role": "user", + "content": "Please help me process and visualize this dataset.", "files": [{"path": "data.csv", "size": "10K"}] - }, + }, { - "role": "assistant", - "content": "I have processed the data and visualized it for you.", + "role": "assistant", + "content": "I have processed the data and visualized it for you.", "code_interpreter_call": "```python\nimport plotly.express as px\nimport pandas as pd\n\n# Load the data into a pandas dataframe\ndf = pd.read_csv('data.csv')\n\n# Create a scatter plot of rainfall vs wind direction\nfig = px.scatter(df, x='WindDir9am', y='Rainfall', color='WindDir3pm',\n labels={'WindDir9am': 'Wind Direction 9am', 'Rainfall': '\n\nRainfall', 'WindDir3pm': 'Wind Direction 3pm'},\n title='Rainfall vs Wind Direction',\n template='plotly_dark',\n width=600, height=500)\n\n# Add a hover effect to show the date\nfig.update_traces(hovertemplate='Date: %{text}
Wind Direction 9am: %{x}
Rainfall: %{y}
Wind Direction 3pm: %{marker.color}')\n\n# Show the plot\nfig.show()\n```" - }, + }, { - "role": "code_interpreter", + "role": "code_interpreter", "content": "![image](xxx.png)" - }, + }, { - "role": "assistant", + "role": "assistant", "content": "Since the code output is not included here, I cannot provide specific chart content. However, if the code executed correctly, it should display a polar plot with two filled areas representing the relationship between wind direction at 9 am and rainfall, and between wind direction at 3 pm and rainfall, respectively. The values for each direction are based on the average rainfall calculated from the provided dataset. The chart should have a clear title, a legend, and be intuitive for comparing rainfall with different wind directions. Given the use of a dark theme, the overall appearance of the chart should be bright lines and filled areas on a dark background." - }, + }, { - "role": "user", + "role": "user", "content": "I want to know today's weather in Shanghai" }, { - "role": "assistant", - "content": "Sure, I will search for the weather of Shanghai.", + "role": "assistant", + "content": "Sure, I will search for the weather of Shanghai.", "function_call": { - "name": "get_current_weather", + "name": "get_current_weather", "parameters": {"location": "Shanghai"} } - }, + }, { - "role": "function", - "name": "get_current_weather", + "role": "function", + "name": "get_current_weather", "content": "{'temperature': 22}" - }, + }, { - "role": "assistant", + "role": "assistant", "content": "The weather in Shanghai is 22 celsius" } - ], - + ], + "functions": [ { - "name": "get_current_weather", - "description": "Get the current weather in a given location", + "name": "get_current_weather", + "description": "Get the current weather in a given location", "parameters": { - "type": "object", + "type": "object", "properties": { "location": { - "type": "string", + "type": "string", "description": "The city and state, e.g. 
San Francisco, CA", - "unit": {"type": "string"}}, + "unit": {"type": "string"}}, "required": ["location"] } } } - ], - - "code_interpreter": "You now have access to a Jupyter notebook environment supporting Python code execution. Just send code to python to run in this stateful environment. This feature is suitable for:\n- Data analysis or processing (such as data manipulation and graphic creation)\n- Complex calculations (such as math and physics problems)\n- Programming examples (for understanding programming concepts or language features)\n- Text processing and analysis (including text analysis and natural language processing)\n- Machine learning and data science (model training and data visualization)\n- File operations and data import (handling CSV, JSON, etc. formats)"} \ No newline at end of file + ], + + "code_interpreter": "You now have access to a Jupyter notebook environment supporting Python code execution. Just send code to python to run in this stateful environment. This feature is suitable for:\n- Data analysis or processing (such as data manipulation and graphic creation)\n- Complex calculations (such as math and physics problems)\n- Programming examples (for understanding programming concepts or language features)\n- Text processing and analysis (including text analysis and natural language processing)\n- Machine learning and data science (model training and data visualization)\n- File operations and data import (handling CSV, JSON, etc. 
formats)" +} diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/example.py b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/example.py index 8e7a0d0dd..e9d5796bc 100644 --- a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/example.py +++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/example.py @@ -2,22 +2,24 @@ from xtuner.types import HybridChatTemplate, TrainingHybridChatMessages - chat_template = HybridChatTemplate( system='<|im_start|>system\n{system}<|im_end|>\n', user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n', assistant='{assistant}<|im_end|>\n', stop_words=['<|im_end|>'], image_token='', - files='<|im_start|>user name=file\n{files}<|im_end|>\n', - function_call='{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 - function_result='<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 + files='<|im_start|>user name=file\n{files}<|im_end|>\n', + function_call= + '{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 + function_result= + '<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n', - code_interpreter_call='{assistant}<|action_start|><|interpreter|>\n{code_interpreter_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 - code_interpreter_result='<|im_start|>environment name=<|interpreter|>\n{code_interpreter_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 - code_interpreter='<|im_start|>system name=<|interpreter|>\n{code_interpreter}<|im_end|>\n' - -) + code_interpreter_call= + '{assistant}<|action_start|><|interpreter|>\n{code_interpreter_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 + code_interpreter_result= + '<|im_start|>environment 
name=<|interpreter|>\n{code_interpreter_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 + code_interpreter= + '<|im_start|>system name=<|interpreter|>\n{code_interpreter}<|im_end|>\n') agent_data = json.load(open('agent.json')) @@ -25,5 +27,7 @@ print(msg.apply_chat_template(chat_template)) from transformers import AutoTokenizer -tokenizer = AutoTokenizer.from_pretrained('internlm/internlm2-chat-7b', trust_remote_code=True) -print(msg.tokenize(tokenizer, chat_template)) \ No newline at end of file + +tokenizer = AutoTokenizer.from_pretrained( + 'internlm/internlm2-chat-7b', trust_remote_code=True) +print(msg.tokenize(tokenizer, chat_template)) diff --git a/xtuner/model/auto.py b/xtuner/model/auto.py new file mode 100644 index 000000000..d525f80e4 --- /dev/null +++ b/xtuner/model/auto.py @@ -0,0 +1,20 @@ +from mmengine import Config + +from xtuner.model.base import BaseTune +from xtuner.registry import BUILDER + + +class AutoModel(): + + @classmethod + def from_config(cls, config: str): + config = Config.fromfile(config) + model: BaseTune = BUILDER.build(config.model) + return model + + @classmethod + def from_pretrained(cls, config: str, checkpoint: str): + config = Config.fromfile(config) + model: BaseTune = BUILDER.build(config.model) + model.load_checkpoint(checkpoint) + return model diff --git a/xtuner/model/base.py b/xtuner/model/base.py new file mode 100644 index 000000000..84c4c1879 --- /dev/null +++ b/xtuner/model/base.py @@ -0,0 +1,71 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import abstractclassmethod, abstractmethod + +from mmengine.model import BaseModel + +from xtuner.types import HybridChatMessages, HybridChatTemplate + + +class BaseTune(BaseModel): + + def __init__(): + super().__init__() + + def init_weights(self): + """Parent class method. + + To avoid overwriting the loaded weights, overload it to an empty + function. 
+ """ + pass + + def avoid_override_weights(self): + self._is_init = True + + @abstractmethod + @property + def chat_template(self) -> HybridChatTemplate: + pass + + @abstractmethod + @property + def llm(self): + pass + + @abstractmethod + @property + def tokenizer(self): + pass + + @abstractmethod + def gradient_checkpointing_enable(self): + pass + + def forward(self, data, data_samples=None, mode='loss'): + """Overload parent class method, only support training.""" + + if mode == 'loss': + return self.compute_loss(data) + else: + raise NotImplementedError( + f"{type(self)}'s forward is only supported for use during " + 'training. If you want to get predictions or chat, please ' + "directly use `llm`'s forward.") + + @abstractmethod + def chat(self, messages: HybridChatMessages, sample_params, streamer): + pass + + @abstractmethod + def save_checkpoint(self, *args, **kwargs): + pass + + @abstractmethod + def load_checkpoint(self, *args, **kwargs) -> 'BaseTune': + pass + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.llm, name) diff --git a/xtuner/model/encoders/__init__.py b/xtuner/model/encoders/__init__.py new file mode 100644 index 000000000..4a863bf0c --- /dev/null +++ b/xtuner/model/encoders/__init__.py @@ -0,0 +1,2 @@ +from .base import EncoderWrapper +from .llava import LlavaEncoderWrapper diff --git a/xtuner/model/encoders/base.py b/xtuner/model/encoders/base.py new file mode 100644 index 000000000..462c42091 --- /dev/null +++ b/xtuner/model/encoders/base.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from abc import abstractclassmethod, abstractmethod +from typing import List, Union + +import torch +from PIL import Image +from torch import nn + +_ImageType = Union[str, Image.Image] + + +class EncoderWrapper(nn.Module): + + def __init__(self): + super().__init__() + + @abstractmethod + @property + def encoder(self): + pass + + @abstractmethod + @property + def projector(self): + pass + + @abstractmethod + def post_init_proj(self, llm): + pass + + @abstractmethod + def preprocess(self, image: _ImageType) -> torch.Tensor: + pass + + @abstractmethod + def batch_infer(images: List[_ImageType]) -> List[torch.Tensor]: + pass + + @abstractmethod + def gradient_checkpointing_enable(self): + pass + + @abstractclassmethod + def save_checkpoint(self, *args, **kwargs): + pass + + @abstractclassmethod + def load_checkpoint(self, *args, **kwargs) -> 'EncoderWrapper': + pass + + @abstractclassmethod + def only_build_processor(self, *args, **kwargs): + pass diff --git a/xtuner/model/encoders/llava.py b/xtuner/model/encoders/llava.py new file mode 100644 index 000000000..1267fa8ba --- /dev/null +++ b/xtuner/model/encoders/llava.py @@ -0,0 +1,284 @@ +import base64 +import os +from collections import OrderedDict +from io import BytesIO +from typing import List, Literal, Optional, Union + +import requests +import torch +from accelerate import load_checkpoint_in_model +from peft import LoraConfig, PeftModel +from PIL import Image +from torch import nn +from transformers import AutoModel, CLIPImageProcessor, CLIPVisionModel + +from xtuner.dataset.utils import expand2square +from xtuner.utils.config import build_from_cfg_or_obj +from ..modules import ProjectorConfig, ProjectorModel +from ..utils import (LoadWoInit, get_peft_model_state_dict, + prepare_for_vision_lora) +from .base import BaseEncoder, _ImageType + + +def load_image_from_base64(image: Union[bytes, str]) -> Image.Image: + """load image from base64 format.""" + return Image.open(BytesIO(base64.b64decode(image))) + + +def 
load_image(image_url: str) -> Image.Image: + """load image from url, local path or openai GPT4V.""" + + headers = { + 'User-Agent': + 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ' + '(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3' + } + if image_url.startswith('http'): + response = requests.get(image_url, headers=headers) + response.raise_for_status() + + # Open the image using PIL + img = Image.open(BytesIO(response.content)) + elif image_url.startswith('data:image'): + img = load_image_from_base64(image_url.split(',')[1]) + else: + img = Image.open(image_url) + + return img + + +class LlavaEncoderWrapper(BaseEncoder): + + def __init__(self, + model_name_or_path: str, + lora=None, + select_layer: int = -2, + freeze_clip: bool = True): + + super().__init__() + + assert not (lora is not None and freeze_clip) + self._projector = None + self.proj_inited = False + self.freeze_clip = freeze_clip + self.select_layer = select_layer + + _res = self.build_processor_and_encoder(model_name_or_path) + self._processor, self._encoder = _res + + if self.freeze_clip: + self._encoder.requires_grad_(False) + + if lora: + self.with_lora = True + lora_conf = build_from_cfg_or_obj(lora, accept=LoraConfig) + self._encoder = prepare_for_vision_lora(self._encoder, lora_conf) + else: + self.with_lora = False + + def post_init_proj(self, config: ProjectorConfig): + self._projector = ProjectorModel(config) + self.proj_inited = True + + def build_processor_and_encoder(self, model_name_or_path: str): + with LoadWoInit: + processor = CLIPImageProcessor.from_pretrained(model_name_or_path) + encoder = CLIPVisionModel.from_pretrained( + model_name_or_path, torch_dtype=torch.float16) + return processor, encoder + + @classmethod + def only_build_processor(self, model_name_or_path: str): + return CLIPImageProcessor.from_pretrained(model_name_or_path) + + @property + def encoder(self) -> CLIPVisionModel: + return self._encoder + + @property + def processor(self): + return 
self._processor + + @property + def projector(self) -> ProjectorModel: + if self._projector: + return self._projector + else: + raise RuntimeError('The projector has not been created yet, ' + 'please execute `post_init_proj` first.') + + def gradient_checkpointing_enable(self): + # For backward compatibility + if hasattr(self.encoder, 'enable_input_require_grads'): + self.encoder.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + self.encoder.get_input_embeddings().register_forward_hook( + make_inputs_require_grad) + + # enable gradient checkpointing for memory efficiency + self.encoder.gradient_checkpointing_enable() + + self.projector.enable_input_require_grads() + self.projector.gradient_checkpointing_enable() + + def preprocess(self, image: _ImageType) -> List[torch.Tensor]: + """Preprocess the input image, including expanding to square and + normalization. + + Args: + image (Image.Image): The input image need to be preprocessed. + Returns: + torch.Tensor: The preprocessed image tensor. + """ + + if isinstance(image, str): + image = load_image(image) + + if not isinstance(image, Image.Image): + raise TypeError(f"Don't support {type(image).__name__}, " + 'the image type must be `PIL.Image`.') + + processor = self.processor + image_mean = processor.image_mean + + background_color = tuple(int(x * 255) for x in image_mean) + squared_img = expand2square(image, background_color) + + processed = processor.preprocess(squared_img, return_tensors='pt') + img_tensor = processed['pixel_values'][0] # shape: 3, h, w + + # before this line, `img_tensor` is on cpu. 
+ img_tensor = img_tensor.to(self.device).to(self.dtype) + return img_tensor + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + outputs = self.encoder(pixel_values, output_hidden_states=True) + embeddings = self.projector( + outputs.hidden_states[self.select_layer][:, 1:]) + return embeddings + + @torch.no_grad() + def batch_infer(self, images: List[_ImageType]) -> List[torch.Tensor]: + """Obtain the corresponding embeddings based on the images. + + Args: + images (List[Image.Image]): The input images. The data layout + for each image is (c, h, w). + Returns: + List[torch.Tensor]: The list of extracted features from images. + The data layout for each tensor should be (tokens, dims). + """ + + num_imgs = len(images) + + img_tensors = [self.process_img(img) for img in images] + + # Determine if all image sizes are consistent. + # TODO (pppppM): Confirm when the image size will be inconsistent + shape_consistant = all(x.shape == img_tensors[0].shape + for x in img_tensors) + + from transformers.modeling_outputs import BaseModelOutputWithPooling + + if shape_consistant: + # Batch inference when all image sizes are consistent. + # img_tensors[0] shape: (3, h, w) + # tensor shape: (num_imgs, 3, h, w) + tensor = torch.stack(img_tensors, dim=0) + + enc_out = self.visual_encoder(tensor, output_hidden_states=True) + enc_out: BaseModelOutputWithPooling + + # feat shape: (num_imgs, tokens, dims) + feat = self.projector(enc_out.hidden_states[self.select_layer][:, + 1:]) + + # Split along the batch dimension + # The feature of each image corresponds to a tensor. + # len(features): num_imgs, features[0] shape:(1, tokens, dims) + features = torch.chunk(feat, num_imgs, dim=0) + + # per image feature's layout should be (tokens, dims) + features = [x.flatten(0, 1) for x in features] + + else: + features = [] + for tensor in img_tensors: + tensor: torch.Tensor + # The visual encoder requires a data layout of (bs, c, h, w). 
+ # tensor shape: (3, h, w) batch_tensor shape: (1, 3, h, w) + batch_tensor = tensor.unsqueeze(0) + enc_out = self.visual_encoder( + batch_tensor, output_hidden_states=True) + enc_out: BaseModelOutputWithPooling + # feat shape: (1, tokens, dims) + feat = self.projector( + enc_out.hidden_states[self.select_layer][:, 1:]) + features.append(feat) + + return features + + def save_checkpoint(self, dir: str): + + if self.with_lora: + _save_dir = os.path.join(dir, 'visual_encoder_adapter') + self.encoder.save_pretrained(_save_dir, safe_serialization=False) + + if not self.freeze_clip: + _save_dir = os.path.join(dir, 'visual_encoder') + self.encoder.save_pretrained(_save_dir, safe_serialization=False) + self.processor.save_pretrained(_save_dir) + + _save_dir = os.path.join(dir, 'projector') + self.projector.save_pretrained(_save_dir) + + def load_checkpoint(self, dir): + + if self.with_lora: + _ckpt_dir = os.path.join(dir, 'visual_encoder_adapter') + self.encoder.load_adapter(_ckpt_dir) + + if not self.freeze_clip: + _ckpt_dir = os.path.join(dir, 'visual_encoder') + load_checkpoint_in_model(self.encoder, _ckpt_dir) + load_checkpoint_in_model(self.processor, _ckpt_dir) + + if self.proj_inited: + _ckpt_dir = os.path.join(dir, 'projector') + load_checkpoint_in_model(self.projector, _ckpt_dir) + else: + ProjectorModel.from_pretrained(_ckpt_dir) + + def state_dict(self, *args, **kwargs): + + state_dict = super().state_dict(*args, **kwargs) + to_return = OrderedDict() + # Step 1. encoder + if self.with_lora: + to_return.update( + get_peft_model_state_dict(self.encoder, state_dict=state_dict)) + elif not self.freeze_clip: + to_return.update( + {k: v + for k, v in state_dict.items() if '_encoder.' in k}) + + # Step 2. Projector + to_return.update( + {k: v + for k, v in state_dict.items() if '_projector.' 
in k}) + + return to_return + + +# if __name__ == '__main__': +# img = load_image('llava.jpeg') +# model = VisionEncoderForDeploy('xtuner/llava-internlm-7b', +# 'openai/clip-vit-large-patch14-336') + +# model.cuda() +# model.eval() +# outputs = model([img]) diff --git a/xtuner/model/hybrid.py b/xtuner/model/hybrid.py index 0f0fc7e76..662db5e7e 100644 --- a/xtuner/model/hybrid.py +++ b/xtuner/model/hybrid.py @@ -1,107 +1,142 @@ # Copyright (c) OpenMMLab. All rights reserved. from collections import OrderedDict +from typing import Dict, Optional, Union import torch import torch.distributed as dist from mmengine.model import BaseModel from peft import LoraConfig from torch import nn +from transformers import PreTrainedModel, PreTrainedTokenizer from xtuner.registry import BUILDER +from xtuner.types import HybridChatMessages, HybridChatTemplate from xtuner.utils.config import build_from_cfg_or_obj -from .modules import ProjectorConfig, ProjectorModel, dispatch_modules -from .utils import (LoadWoInit, enable_hf_model_gradient_checkpointing, - get_peft_model_state_dict, prepare_for_llm_lora, - prepare_for_vision_lora, - smart_tokenizer_and_embedding_resize) +from .base import BaseTune +from .encoders import EncoderWrapper +from .modules import ProjectorConfig, dispatch_modules +from .utils import (LoadWoInit, get_peft_model_state_dict, + prepare_for_llm_lora, smart_tokenizer_and_embedding_resize) -class HybridFinetune(BaseModel): +class HybridFinetune(BaseTune): def __init__( self, - llm, - visual_encoder=None, - visual_select_layer=-2, - projector_depth=2, - pretrained_pth=None, - tokenizer=None, - llm_lora=None, - visual_encoder_lora=None, - freeze_llm=False, - freeze_visual_encoder=False, - use_activation_checkpointing=True, - use_varlen_attn=False, + llm: Union[PreTrainedModel, Dict], + tokenizer: Union[PreTrainedTokenizer, Dict], + chat_template: HybridChatTemplate, + visual_encoder: Optional[Union[EncoderWrapper, Dict]] = None, + audio_encoder: 
Optional[Union[EncoderWrapper, Dict]] = None, + video_encoder: Optional[Union[EncoderWrapper, Dict]] = None, + proj_depth: int = 2, + llm_lora: Optional[Union[LoraConfig, Dict]] = None, + freeze_llm: bool = False, + use_gradient_checkpointing: bool = True, + use_varlen_attn: bool = False, ): super().__init__() + tokenizer = build_from_cfg_or_obj( + tokenizer, accept=PreTrainedTokenizer) + smart_tokenizer_and_embedding_resize(tokenizer, self.llm) + self._tokenizer: PreTrainedModel = tokenizer + + self._chat_template = chat_template + # Build the base language model without initialization. # This will greatly reduce the time to build the model. with LoadWoInit(): - self.llm = build_from_cfg_or_obj(llm, nn.Module) - if visual_encoder: - visual_encoder = build_from_cfg_or_obj(visual_encoder, - nn.Module) - self.visual_encoder = visual_encoder - self.visual_select_layer = visual_select_layer - self.llm.config.use_cache = False - dispatch_modules(self.llm, use_varlen_attn=use_varlen_attn) - - if tokenizer is not None: - if isinstance(tokenizer, dict): - tokenizer = BUILDER.build(tokenizer) - smart_tokenizer_and_embedding_resize(tokenizer, self.llm) - - projector_config = ProjectorConfig( - visual_hidden_size=self.visual_encoder.config.hidden_size, - llm_hidden_size=self.llm.config.hidden_size, - depth=projector_depth) - self.projector = ProjectorModel(projector_config).to( - self.visual_encoder.dtype) + self._llm: PreTrainedModel = build_from_cfg_or_obj(llm, nn.Module) + self._llm.config.use_cache = False self.freeze_llm = freeze_llm - self.freeze_visual_encoder = freeze_visual_encoder if self.freeze_llm: self.llm.requires_grad_(False) - if self.freeze_visual_encoder: - self.visual_encoder.requires_grad_(False) - - if use_activation_checkpointing: - # For backward compatibility - enable_hf_model_gradient_checkpointing(self.llm) - enable_hf_model_gradient_checkpointing(self.visual_encoder) - - self.projector.enable_input_require_grads() - 
self.projector.gradient_checkpointing_enable() - - self.use_llm_lora = llm_lora is not None - self.use_visual_encoder_lora = visual_encoder_lora is not None + self.with_lora = llm_lora is not None # Prepare the model for LoRA if specified - if self.use_llm_lora: + if self.with_lora: lora_conf = build_from_cfg_or_obj(llm_lora, accept=LoraConfig) - self.llm = prepare_for_llm_lora(self.llm, lora_conf, - use_activation_checkpointing) - - if self.use_visual_encoder_lora: - lora_conf = build_from_cfg_or_obj( - visual_encoder_lora, accept=LoraConfig) - self.visual_encoder = prepare_for_vision_lora( - self.visual_encoder, lora_conf, use_activation_checkpointing) - self._is_init = True + self.llm = prepare_for_llm_lora(self.llm, lora_conf) # Determines whether to calculate attention based on the # seq_len dimension (use_varlen_attn = False) or the actual length of # the sequence. self.use_varlen_attn = use_varlen_attn + dispatch_modules(self.llm, use_varlen_attn=use_varlen_attn) + + if visual_encoder: + visual_encoder = build_from_cfg_or_obj(visual_encoder, + EncoderWrapper) + self.visual_encoder: EncoderWrapper = visual_encoder + _proj_config = ProjectorConfig( + visual_hidden_size=self.visual_encoder.hidden_size, + llm_hidden_size=self.llm.config.hidden_size, + depth=proj_depth) + + self.visual_encoder.post_init_proj(_proj_config) + else: + self.visual_encoder = None + + if audio_encoder: + audio_encoder = build_from_cfg_or_obj(audio_encoder, + EncoderWrapper) + self.audio_encoder: EncoderWrapper = audio_encoder + _proj_config = ProjectorConfig( + visual_hidden_size=self.audio_encoder.hidden_size, + llm_hidden_size=self.llm.config.hidden_size, + depth=proj_depth) + + self.audio_encoder.post_init_proj(_proj_config) + else: + self.audio_encoder = None + + if video_encoder: + video_encoder = build_from_cfg_or_obj(video_encoder, + EncoderWrapper) + self.video_encoder: EncoderWrapper = video_encoder + _proj_config = ProjectorConfig( + 
visual_hidden_size=self.video_encoder.hidden_size, + llm_hidden_size=self.llm.config.hidden_size, + depth=proj_depth) + + self.video_encoder.post_init_proj(_proj_config) + else: + self.video_encoder = None + + if use_gradient_checkpointing: + self.gradient_checkpointing_enable() + + self.avoid_override_weights() + + @property + def llm(self) -> PreTrainedModel: + return self._llm + + @property + def tokenizer(self) -> PreTrainedTokenizer: + return self._tokenizer - def init_weights(self): - """Parent class method. + @property + def chat_template(self) -> HybridChatTemplate: + return self._chat_template - To avoid overwriting the loaded weights, overload it to an empty - function. - """ - pass + def gradient_checkpointing_enable(self): + # For backward compatibility + if hasattr(self.llm, 'enable_input_require_grads'): + self.llm.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + self.llm.get_input_embeddings().register_forward_hook( + make_inputs_require_grad) + + # enable gradient checkpointing for memory efficiency + self.llm.gradient_checkpointing_enable() + self.visual_encoder.gradient_checkpointing_enable() def forward(self, data, data_samples=None, mode='loss'): """Overload parent class method, only support training.""" @@ -132,10 +167,7 @@ def _get_vision_embeds_and_ranges(self, data): batch_total_imgs = len(img_rngs) - visual_outputs = self.visual_encoder( - pixel_values, output_hidden_states=True) - features = self.projector( - visual_outputs.hidden_states[self.visual_select_layer][:, 1:]) + features = self.visual_encoder(pixel_values) batch_total_imgs, real_img_tokens, _ = features.shape for i in range(batch_total_imgs): @@ -144,12 +176,12 @@ def _get_vision_embeds_and_ranges(self, data): img_emb = features[i] img_bs_ind = img_belongs[i] + # pack 导致的截断 if real_img_tokens == exp_img_tokens: img_embeds.append(img_emb) - elif not real_img_tokens == exp_img_tokens and img_start == 0: 
+ elif real_img_tokens != exp_img_tokens and img_start == 0: img_embeds.append(img_emb[real_img_tokens - img_end:]) - elif (not real_img_tokens == exp_img_tokens - and img_end == tokens): + elif (real_img_tokens != exp_img_tokens and img_end == tokens): img_embeds.append(img_emb[:exp_img_tokens]) else: raise RuntimeError @@ -176,13 +208,8 @@ def _insert_mm_embeddings(self, flat_embeds, mm_embeds, ranges): return flat_embeds + _empty_embeds - def compute_loss(self, data): - + def _compute_postion_ids(self, data): input_ids = data['input_ids'] - labels = data['labels'] - # position_ids = data['position_ids'] - attention_mask = data['attention_mask'] - # breakpoint() bs, tokens = input_ids.shape if self.use_varlen_attn: assert bs == 1 @@ -206,12 +233,23 @@ def compute_loss(self, data): position_ids = torch.arange(0, tokens).unsqueeze(0).repeat(bs, 1) + def compute_loss(self, data): + + input_ids = data['input_ids'] + labels = data['labels'] + attention_mask = data['attention_mask'] + + bs, tokens = input_ids.shape + position_ids = self._compute_postion_ids(data) + input_embeds = self.llm.get_input_embeddings()(input_ids) bs, tokens, dim = input_embeds.shape flat_embeds = input_embeds.flatten(0, 1) img_embs, flat_bs_img_rngs = self._get_vision_embeds_and_ranges(data) + # audio_embs, flat_bs_img_rngs = self._get_vision_embeds_and_ranges(data) + # video_embs, flat_bs_img_rngs = self._get_vision_embeds_and_ranges(data) flat_embeds = self._insert_mm_embeddings(flat_embeds, img_embs, flat_bs_img_rngs) input_embeds = flat_embeds.reshape(bs, tokens, dim) @@ -229,32 +267,22 @@ def compute_loss(self, data): def state_dict(self, *args, **kwargs): state_dict = super().state_dict(*args, **kwargs) to_return = OrderedDict() - # Step 1. 
visual_encoder - if self.use_visual_encoder_lora: - to_return.update( - get_peft_model_state_dict( - self.visual_encoder, state_dict=state_dict)) - elif not self.freeze_visual_encoder: - to_return.update({ - k: v - for k, v in state_dict.items() if 'visual_encoder.' in k - }) - # Step 2. LLM + + # Step 1. LLM if self.use_llm_lora: to_return.update( get_peft_model_state_dict(self.llm, state_dict=state_dict)) elif not self.freeze_llm: to_return.update( {k: v - for k, v in state_dict.items() if 'llm.' in k}) - # Step 3. Projector + for k, v in state_dict.items() if '_llm.' in k}) + + # Step 2. Visual Encoder to_return.update( {k: v - for k, v in state_dict.items() if 'projector.' in k}) + for k, v in state_dict.items() if 'visual_encoder.' in k}) return to_return - def __getattr__(self, name: str): - try: - return super().__getattr__(name) - except AttributeError: - return getattr(self.llm, name) + def chat(self, messages: HybridChatMessages, sample_params, streamer): + + prompt = messages.apply_chat_template(self.chat_template) diff --git a/xtuner/types/__init__.py b/xtuner/types/__init__.py index cc230e8f8..79ea745af 100644 --- a/xtuner/types/__init__.py +++ b/xtuner/types/__init__.py @@ -1,6 +1,11 @@ +from .chat import (ChatMsg, HybridChatMessages, ImageContentItem, + TextContentItem) from .chat_template import HybridChatTemplate +from .sample_params import SampleParams from .train import RawTrainingData, TrainingHybridChatMessages __all__ = [ - 'HybridChatTemplate', 'RawTrainingData', 'TrainingHybridChatMessages' + 'ChatMsg', 'HybridChatMessages', 'ImageContentItem', 'TextContentItem', + 'HybridChatTemplate', 'SampleParams', 'RawTrainingData', + 'TrainingHybridChatMessages' ] diff --git a/xtuner/types/chat.py b/xtuner/types/chat.py index cd0a4d4a7..0e48391f6 100644 --- a/xtuner/types/chat.py +++ b/xtuner/types/chat.py @@ -6,7 +6,7 @@ class TextContentItem(BaseModel): - type: Literal['text'] + type: Literal['text'] = 'text' text: str def apply_chat_template(self, 
chat_template: HybridChatTemplate) -> str: @@ -14,7 +14,7 @@ def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: class ImageContentItem(BaseModel): - type: Literal['image_url'] + type: Literal['image_url'] = 'image_url' image_url: str def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: diff --git a/xtuner/types/sample_params.py b/xtuner/types/sample_params.py new file mode 100644 index 000000000..137809648 --- /dev/null +++ b/xtuner/types/sample_params.py @@ -0,0 +1,14 @@ +from typing import Optional + +from pydantic import BaseModel + + +class SampleParams(BaseModel): + + max_new_tokens: int = 512 + temperature: float = 0.1 + top_k: int = 40 + top_p: float = 0.75 + repetition_penalty: float = 1.0 + stop_words: list = [] + seed: Optional[int] = None