| """Keypoint–Argument Matching Endpoints""" |
|
|
| from fastapi import APIRouter, HTTPException |
| from datetime import datetime |
| import logging |
|
|
| from models import ( |
| PredictionRequest, |
| PredictionResponse, |
| BatchPredictionRequest, |
| BatchPredictionResponse |
| ) |
|
|
| from services import kpa_model_manager |
|
|
# Module-level logger, named after this module so records can be filtered per endpoint file.
logger = logging.getLogger(__name__)

# Router that collects the KPA endpoints declared below.
router = APIRouter()
|
|
|
|
@router.get("/model-info", tags=["KPA"])
async def get_model_info():
    """
    Return information about the loaded KPA model.

    The metadata is read from `kpa_model_manager.get_model_info()`; any key
    the manager does not provide is replaced by a default value, and a
    timestamp of the request is added.

    Returns a dict with the keys `model_name`, `device`, `max_length`,
    `num_labels`, `loaded` and `timestamp`.

    Responds with HTTP 500 if the metadata cannot be read or assembled.
    """
    try:
        model_info = kpa_model_manager.get_model_info()

        return {
            "model_name": model_info.get("model_name", "unknown"),
            "device": model_info.get("device", "cpu"),
            "max_length": model_info.get("max_length", 256),
            "num_labels": model_info.get("num_labels", 2),
            "loaded": model_info.get("loaded", False),
            # datetime.now() without a tz argument is naive local time, so the
            # ISO string carries no UTC offset.
            "timestamp": datetime.now().isoformat(),
        }

    except Exception as e:
        # logger.exception logs at ERROR level like before, but also records
        # the traceback so the failure can be diagnosed from the logs.
        logger.exception("Model info error: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get model info: {str(e)}",
        ) from e
|
|
|
|
@router.post("/predict", response_model=PredictionResponse, tags=["KPA"])
async def predict_kpa(request: PredictionRequest):
    """
    Predict keypoint-argument matching for a single pair.

    - **argument**: The argument text
    - **key_point**: The key point to evaluate

    Returns the predicted class (apparie / non_apparie) with probabilities.

    Responds with HTTP 500 if the model call or the response construction fails.
    """
    try:
        result = kpa_model_manager.predict(
            argument=request.argument,
            key_point=request.key_point,
        )

        # Copy only the fields the response model exposes; a missing key in
        # the manager's result raises KeyError and is reported as a 500 below.
        response = PredictionResponse(
            prediction=result["prediction"],
            confidence=result["confidence"],
            label=result["label"],
            probabilities=result["probabilities"],
        )

        logger.info(
            f"KPA Prediction: {response.label} "
            f"(conf={response.confidence:.4f})"
        )

        return response

    except Exception as e:
        # logger.exception logs at ERROR level like before, but also records
        # the traceback so the failure can be diagnosed from the logs.
        logger.exception("KPA prediction error: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {str(e)}",
        ) from e
|
|
|
|
# Sentinel stored in `prediction` for pairs whose prediction failed.
_ERROR_PREDICTION = -1


def _summarize_batch(results) -> dict:
    """Aggregate per-pair results into the summary dict of the batch response.

    Items whose `prediction` equals the error sentinel count as failed; the
    class totals and the average confidence are computed over the remaining
    items only. With no successful item the average confidence is 0.0.
    """
    succeeded = [r for r in results if r.prediction != _ERROR_PREDICTION]

    if succeeded:
        confidence_sum = sum(r.confidence for r in succeeded)
        average_confidence = round(confidence_sum / len(succeeded), 4)
    else:
        # Avoid dividing by zero when every pair failed or the batch is empty.
        average_confidence = 0.0

    return {
        "total_apparie": sum(1 for r in succeeded if r.prediction == 1),
        "total_non_apparie": sum(1 for r in succeeded if r.prediction == 0),
        "average_confidence": average_confidence,
        "successful_predictions": len(succeeded),
        "failed_predictions": len(results) - len(succeeded),
    }


@router.post("/batch-predict", response_model=BatchPredictionResponse, tags=["KPA"])
async def batch_predict_kpa(request: BatchPredictionRequest):
    """
    Predict keypoint-argument matching for multiple argument/keypoint pairs.

    - **pairs**: List of items to classify

    Returns predictions for all pairs, in request order, together with a
    summary (class totals, average confidence, success/failure counts).
    A pair whose prediction fails is reported with `prediction=-1` and
    `label="error"` instead of failing the whole batch.

    Responds with HTTP 500 if the batch as a whole cannot be processed.
    """
    try:
        results = []

        for index, item in enumerate(request.pairs):
            try:
                result = kpa_model_manager.predict(
                    argument=item.argument,
                    key_point=item.key_point,
                )
                results.append(
                    PredictionResponse(
                        prediction=result["prediction"],
                        confidence=result["confidence"],
                        label=result["label"],
                        probabilities=result["probabilities"],
                    )
                )
            except Exception:
                # Best effort: one bad pair must not abort the batch, but the
                # failure is logged (with traceback) instead of being dropped
                # silently, so it can be traced back to its position.
                logger.exception("Batch KPA prediction failed for item %d", index)
                results.append(
                    PredictionResponse(
                        prediction=_ERROR_PREDICTION,
                        confidence=0.0,
                        label="error",
                        probabilities={"error": 1.0},
                    )
                )

        summary = _summarize_batch(results)

        logger.info(f"Batch KPA prediction completed — {len(results)} items processed")

        return BatchPredictionResponse(
            predictions=results,
            total_processed=len(results),
            summary=summary,
        )

    except Exception as e:
        # logger.exception logs at ERROR level like before, but also records
        # the traceback so the failure can be diagnosed from the logs.
        logger.exception("Batch KPA prediction error: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Batch prediction failed: {str(e)}",
        ) from e