| import os |
| import base64 |
| import tempfile |
| from inference import Chat, get_conv_template |
| import torch |
|
|
def save_base64_to_tempfile(base64_str: str, suffix: str) -> str:
    """Decode a base64 payload to a named temporary file and return its path.

    Accepts either a raw base64 string or a data-URI-style string
    (``"data:image/jpeg;base64,...."``); anything up to and including the
    first comma is treated as the header and stripped.

    The file is created with ``delete=False`` — the caller owns it and is
    responsible for removing it (e.g. via ``os.unlink``).

    Raises:
        binascii.Error: if the payload is not valid base64.
    """
    header_removed = base64_str
    if ',' in base64_str:
        # Drop the "data:<mime>;base64" prefix, keep everything after the comma.
        header_removed = base64_str.split(',', 1)[1]

    data = base64.b64decode(header_removed)
    # Context manager guarantees the handle is closed even if write() fails
    # (the original manual open/write/close leaked the handle on error).
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(data)
    return tmp.name
|
|
class EndpointHandler:
    """Inference endpoint wrapper around a multimodal ``Chat`` model.

    Keeps per-process conversation state (``self.chat.conv``), the most
    recently extracted vision feature, and the current modality, and
    dispatches text / image / video requests to the underlying model.
    """

    def __init__(self, model_path: str):
        """Load the model from *model_path* onto GPU when available, else CPU."""
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.chat = Chat(
            model_path=model_path,
            device=device,
            num_gpus=1,
            max_new_tokens=1024,
            load_8bit=False,
        )
        # Embedding of the most recent image/video; None for text-only turns.
        self.vision_feature = None
        # One of "text" / "image" / "video"; selects the model's input path.
        self.modal_type = "text"
        self.chat.conv = get_conv_template("husky").copy()

    def __call__(self, data: dict) -> dict:
        """Handle one inference request.

        Recognized keys in *data* (all optional):
          - ``"inputs"``: prompt text (defaults to "").
          - ``"image"`` / ``"video"``: filesystem path or base64 payload.
          - ``"clear_history"``: truthy value resets the conversation.

        Returns ``{"output": text}`` on success, ``{"error": msg}`` on failure.
        """
        if data.get("clear_history"):
            # Client asked for a fresh conversation: drop all cached state.
            self.chat.conv = get_conv_template("husky").copy()
            self.vision_feature = None
            self.modal_type = "text"

        prompt = data.get("inputs", "")
        image_input = data.get("image", None)
        video_input = data.get("video", None)

        print("📨 收到 prompt:", repr(prompt))

        if image_input:
            if os.path.exists(image_input):
                # Local file path: embed directly.
                self.vision_feature = self.chat.get_image_embedding(image_input)
            else:
                # Otherwise treat the value as a base64 payload: decode to a
                # temp file, embed, and always clean the temp file up.
                tmp_path = save_base64_to_tempfile(image_input, suffix=".jpg")
                try:
                    self.vision_feature = self.chat.get_image_embedding(tmp_path)
                finally:
                    # Fix: previously the temp file leaked when embedding raised.
                    os.unlink(tmp_path)
            self.modal_type = "image"
            # New media starts a fresh conversation.
            self.chat.conv = get_conv_template("husky").copy()

        elif video_input:
            if os.path.exists(video_input):
                self.vision_feature = self.chat.get_video_embedding(video_input)
            else:
                tmp_path = save_base64_to_tempfile(video_input, suffix=".mp4")
                print("📼 保存临时视频路径:", tmp_path)
                try:
                    self.vision_feature = self.chat.get_video_embedding(tmp_path)
                finally:
                    # Fix: previously the temp file leaked when embedding raised.
                    os.unlink(tmp_path)
            self.modal_type = "video"
            self.chat.conv = get_conv_template("husky").copy()

            # Debug check that the embedding is a tensor (video branch only,
            # as in the original — presumably image embeddings were trusted).
            if isinstance(self.vision_feature, torch.Tensor):
                print("📏 视觉特征张量 shape:", self.vision_feature.shape)
            else:
                print("❌ self.vision_feature 不是张量,类型:", type(self.vision_feature))

        else:
            # Pure-text turn: drop any stale vision context.
            self.modal_type = "text"
            self.vision_feature = None

        try:
            print("🧠 当前 modal_type:", self.modal_type)
            print("🧠 是否有视觉特征:", self.vision_feature is not None)

            conversations = self.chat.ask(prompt, self.chat.conv, modal_type=self.modal_type)
            output = self.chat.answer(conversations, self.vision_feature, modal_type=self.modal_type)

            # Strip once; the original recomputed output.strip() three times.
            text = output.strip()
            print("📤 推理输出:", repr(text))

            # Record the assistant's reply in the conversation history.
            self.chat.conv.messages[-1][1] = text
            return {"output": text}

        except Exception as e:
            # Endpoint boundary: log the full traceback, return the error
            # message to the client instead of crashing the worker.
            import traceback
            print("❌ 推理出错:")
            traceback.print_exc()
            return {"error": str(e)}
|
|