File size: 1,240 Bytes
8337f1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# query_classifier/query_classifier.py
from transformers import pipeline
import os

# Location of the trained model: a folder named "trained_model" that sits one
# level above the directory containing this file. Edit this constant if the
# folder is renamed or moved.
CLASSIFIER_PATH = os.path.join(
    os.path.dirname(__file__),
    os.pardir,
    "trained_model",
)

class QueryClassifier:
    """Wrap a text-classification pipeline loaded from ``CLASSIFIER_PATH``.

    The pipeline's raw labels are translated to category names through
    ``label_map``; any label missing from the map is reported as "unknown".
    """

    def __init__(self):
        """Load the model and tokenizer from ``CLASSIFIER_PATH``.

        Raises:
            FileNotFoundError: If ``CLASSIFIER_PATH`` does not exist.
        """
        print("Loading query classifier...")

        # Fail early with an actionable message instead of letting the
        # pipeline loader fail on a missing path.
        model_missing = not os.path.exists(CLASSIFIER_PATH)
        if model_missing:
            message = (
                f"Classifier not found at {CLASSIFIER_PATH}. "
                "Please train the model and place it in a folder named 'trained_model'."
            )
            raise FileNotFoundError(message)

        # Model and tokenizer are both loaded from the same location.
        loader_kwargs = {"model": CLASSIFIER_PATH, "tokenizer": CLASSIFIER_PATH}
        self.pipeline = pipeline("text-classification", **loader_kwargs)

        # Raw pipeline label -> category name returned by classify().
        self.label_map = {"LABEL_0": "general", "LABEL_1": "realtime"}
        print("Classifier ready.")

    def classify(self, query: str) -> tuple[str, float]:
        """Classify *query* and return ``(category, confidence)``.

        The pipeline is called with ``truncation=True`` and ``max_length=128``
        and only its first result is used. ``category`` is the mapped name for
        that result's label ("unknown" when the label is not in ``label_map``)
        and ``confidence`` is that result's score.
        """
        predictions = self.pipeline(query, truncation=True, max_length=128)
        top = predictions[0]
        raw_label, score = top["label"], top["score"]
        return self.label_map.get(raw_label, "unknown"), score