|
|
| from transformers import pipeline
|
| import os
|
|
|
|
|
|
|
# Location of the trained model: a "trained_model" folder that sits one
# directory above the folder containing this module.
CLASSIFIER_PATH = os.path.join(
    os.path.split(__file__)[0],
    "..",
    "trained_model",
)
|
|
|
class QueryClassifier:
    """Classify a user query as ``"general"`` or ``"realtime"``.

    Wraps a Hugging Face ``text-classification`` pipeline loaded from a local
    model directory (``CLASSIFIER_PATH`` unless another path is given).
    """

    # Generic labels emitted by a fine-tuned model that has no custom
    # ``id2label`` names, mapped onto this application's category names.
    _DEFAULT_LABEL_MAP = {"LABEL_0": "general", "LABEL_1": "realtime"}

    def __init__(self, model_path=None):
        """Load the classifier pipeline (model and tokenizer) from disk.

        Args:
            model_path: Directory containing the trained model and tokenizer.
                When ``None`` (the default), ``CLASSIFIER_PATH`` is used.

        Raises:
            FileNotFoundError: If the model path is not an existing directory.
        """
        print("Loading query classifier...")

        # Resolved here rather than as a default argument so the module
        # constant is only looked up when actually needed.
        path = CLASSIFIER_PATH if model_path is None else model_path

        # Fail early with a clear message; a plain file at this path is
        # rejected too, since model and tokenizer are loaded from a folder.
        if not os.path.isdir(path):
            raise FileNotFoundError(
                f"Classifier not found at {path}. "
                "Please train the model and place it in a folder named 'trained_model'."
            )

        self.pipeline = pipeline(
            "text-classification",
            model=path,
            tokenizer=path,
        )

        # Per-instance copy so callers can adjust it without affecting others.
        self.label_map = dict(self._DEFAULT_LABEL_MAP)

        print("Classifier ready.")

    def classify(self, query: str) -> tuple[str, float]:
        """Classify a single query.

        Args:
            query: The user's text. Tokenized input is truncated to at most
                128 tokens.

        Returns:
            ``(category, confidence)``: ``category`` is one of the values in
            ``label_map`` (or ``"unknown"`` for an unrecognised label), and
            ``confidence`` is the pipeline's score for the predicted label.
        """
        result = self.pipeline(query, truncation=True, max_length=128)[0]
        label = result["label"]

        if label in self.label_map:
            category = self.label_map[label]
        elif label in self.label_map.values():
            # A model saved with readable id2label names already emits the
            # category itself; don't misreport it as "unknown".
            category = label
        else:
            category = "unknown"

        return category, float(result["score"])