import io
import logging
import os

import torch
from PIL import Image
from torchvision.transforms import ToTensor

from model import get_model

# Number of classes the detector is built with (passed to get_model).
# NOTE(review): presumably background (id 0) plus the three colour labels the
# handler maps (ids 1-3) — confirm against the training configuration.
NUM_CLASSES = 4

# Detections whose score is below this value are left out of the response.
CONFIDENCE_THRESHOLD = 0.5

# Use the default CUDA device when PyTorch reports one is available,
# otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class EndpointHandler:
    """Endpoint handler that runs the detection model on a single request image.

    The model checkpoint is loaded once at construction time. Each call decodes
    the image from the request, runs the model, and returns the detections whose
    score is at least ``CONFIDENCE_THRESHOLD``.
    """

    # Returned (as ``[{"error": ...}]``) when the request carries no usable image.
    _NO_IMAGE_ERROR = "No image data provided in request."

    def __init__(self, path: str = ""):
        """Build the model, load its weights, and prepare preprocessing.

        Args:
            path: Directory that contains ``model.pt``. The default ``""``
                resolves ``model.pt`` relative to the current working directory.

        Errors from ``torch.load`` (for example a missing file) and from
        ``load_state_dict`` are not caught here and propagate to the caller.
        """
        self.model_weights_path = os.path.join(path, "model.pt")
        self.model = get_model(NUM_CLASSES).to(DEVICE)
        # NOTE(review): torch.load unpickles the file, which can execute arbitrary
        # code. Only ship trusted checkpoints; consider ``weights_only=True`` once
        # the checkpoint is confirmed to hold only tensors and plain containers.
        checkpoint = torch.load(self.model_weights_path, map_location=DEVICE)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        # Put the model in evaluation mode before serving requests.
        self.model.eval()

        # Converts a PIL image into a float tensor for the model.
        self.preprocess = ToTensor()

        # Model label id -> name reported to clients; unlisted ids become "unknown".
        self.label_map = {1: "yellow", 2: "red", 3: "blue"}

    def preprocess_frame(self, image_bytes):
        """Decode an image and return it as a batch-of-one tensor on ``DEVICE``.

        Args:
            image_bytes: Encoded image data (``bytes``, ``bytearray`` or
                ``memoryview``) in any format PIL can open. An already-decoded
                ``PIL.Image.Image`` is also accepted and used directly.

        Returns:
            The preprocessed image tensor with a leading batch dimension of 1.

        Raises:
            PIL.UnidentifiedImageError: if the data is not a recognisable image.
        """
        if isinstance(image_bytes, Image.Image):
            image = image_bytes.convert("RGB")
        else:
            # ``convert`` returns a new, fully loaded image, so closing the
            # decoder at the end of the ``with`` leaves ``image`` valid.
            with Image.open(io.BytesIO(image_bytes)) as decoded:
                image = decoded.convert("RGB")
        return self.preprocess(image).unsqueeze(0).to(DEVICE)

    def __call__(self, data):
        """Run detection on the image carried in ``data["body"]``.

        Args:
            data: Request payload; the encoded image is read from its
                ``"body"`` entry.

        Returns:
            A list of detections, each shaped like
            ``{"label": str, "score": float, "box": {"xmin": int, "ymin": int,
            "xmax": int, "ymax": int}}``, keeping only scores
            ``>= CONFIDENCE_THRESHOLD``. Scores are rounded to 2 decimals and box
            coordinates are truncated to ints. If the body is missing or empty,
            or if any step fails, a single-element list ``[{"error": message}]``
            is returned instead of raising.
        """
        try:
            if "body" not in data:
                return [{"error": self._NO_IMAGE_ERROR}]

            image_bytes = data["body"]
            # A null or zero-length body is "no image" too; without this check
            # PIL would fail later with a less helpful decode error.
            if image_bytes is None or (
                isinstance(image_bytes, (bytes, bytearray, memoryview))
                and not image_bytes
            ):
                return [{"error": self._NO_IMAGE_ERROR}]

            image_tensor = self.preprocess_frame(image_bytes)

            with torch.no_grad():
                predictions = self.model(image_tensor)

            # The input is a batch of one image, so only the first prediction
            # entry is relevant.
            prediction = predictions[0]
            boxes = prediction["boxes"].cpu().tolist()
            labels = prediction["labels"].cpu().tolist()
            scores = prediction["scores"].cpu().tolist()

            results = []
            for box, label, score in zip(boxes, labels, scores):
                if score < CONFIDENCE_THRESHOLD:
                    continue
                # int() truncates toward zero when snapping to pixel coordinates.
                x1, y1, x2, y2 = map(int, box)
                results.append({
                    "label": self.label_map.get(label, "unknown"),
                    "score": round(score, 2),
                    "box": {
                        "xmin": x1,
                        "ymin": y1,
                        "xmax": x2,
                        "ymax": y2,
                    },
                })
            return results
        except Exception as e:
            # Top-level request boundary: keep answering with a JSON error, but
            # record the full traceback server-side instead of discarding it.
            logging.getLogger(__name__).exception("Inference request failed")
            return [{"error": str(e)}]

