# query_classifier/query_classifier.py
from transformers import pipeline
import os
# Directory holding the fine-tuned classifier (model and tokenizer files).
# Resolved relative to this module's own directory, one level up:
#     <directory of this file>/../trained_model
# Change the last path component if the model folder is renamed.
CLASSIFIER_PATH = os.path.join(
    os.path.dirname(__file__),
    "..",
    "trained_model",
)
class QueryClassifier:
    """Two-way query classifier backed by a locally stored transformers model.

    Wraps a ``text-classification`` pipeline loaded from ``CLASSIFIER_PATH``
    and maps its raw labels onto the categories ``"general"`` and
    ``"realtime"``.
    """

    def __init__(self):
        """Load the classifier model and tokenizer from ``CLASSIFIER_PATH``.

        Raises:
            FileNotFoundError: if ``CLASSIFIER_PATH`` is not an existing
                directory (the trained model has not been placed there).
        """
        print("Loading query classifier...")
        # isdir rather than exists: a stray *file* named "trained_model" would
        # pass an existence check and then fail inside transformers with a far
        # less helpful error than the one raised here.
        if not os.path.isdir(CLASSIFIER_PATH):
            raise FileNotFoundError(
                f"Classifier not found at {CLASSIFIER_PATH}. "
                "Please train the model and place it in a folder named 'trained_model'."
            )
        # Model and tokenizer are both loaded from the same directory.
        self.pipeline = pipeline(
            "text-classification",
            model=CLASSIFIER_PATH,
            tokenizer=CLASSIFIER_PATH,
        )
        # Generic "LABEL_<id>" names (reported when the model config does not
        # name its classes) -> project category names.
        self.label_map = {"LABEL_0": "general", "LABEL_1": "realtime"}
        print("Classifier ready.")

    def classify(self, query: str, *, max_length: int = 128) -> tuple[str, float]:
        """Classify a single query.

        Args:
            query: Text to classify.
            max_length: Token limit passed to the tokenizer; longer inputs are
                truncated to this many tokens. Defaults to 128.

        Returns:
            ``(category, confidence)``. ``category`` is a value from
            ``self.label_map`` (``"general"`` or ``"realtime"`` by default), or
            ``"unknown"`` when the predicted label is not recognised.
            ``confidence`` is the pipeline's score for the predicted label.
        """
        result = self.pipeline(query, truncation=True, max_length=max_length)[0]
        label = result["label"]
        confidence = result["score"]
        category = self.label_map.get(label)
        if category is None:
            # A model saved with named classes (custom id2label) reports the
            # category itself, e.g. "realtime" instead of "LABEL_1"; accept it
            # rather than misreporting a valid prediction as "unknown".
            category = label if label in self.label_map.values() else "unknown"
        return category, confidence