File size: 6,641 Bytes
ef0d2f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3a8f3a
ef0d2f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3a8f3a
 
 
 
 
 
ef0d2f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
"""
api/main.py
FastAPI application — exposes the vulnerability classification pipeline
as an HTTP API. Single endpoint: POST /classify

Run locally:
    uvicorn api.main:app --reload --port 8000

Run in HF Spaces:
    uvicorn api.main:app --host 0.0.0.0 --port 7860
"""

from __future__ import annotations
import time
from contextlib import asynccontextmanager
from typing import Optional

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

from pipeline.router import get_router

# ── Lifespan — warm up classifier on startup ──────────────────────────────────
# RoBERTa (125MB) loads fast — warm it up at startup so first request isn't slow
# Qwen (5GB) is lazy-loaded on first code input — too large to preload

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook passed to the FastAPI application.

    Before the application starts serving, the shared router is asked to
    build its classifier so that work is not deferred to a request. Nothing
    else is preloaded here. A status line is printed before and after the
    warm-up, and once more when the application shuts down.
    """
    print("[API] Warming up RoBERTa classifier...")
    # Preload the classifier only; no other router component is touched here.
    get_router()._get_classifier()
    print("[API] Ready.")

    yield

    print("[API] Shutting down.")


# ── App ───────────────────────────────────────────────────────────────────────

# ASGI application object; this is the `app` in `uvicorn api.main:app`.
# The title, description and version below are the API's published metadata,
# and `lifespan` wires in the startup warm-up defined earlier in this module.
app = FastAPI(
    title="CodeSentinel API",
    description=(
        "Vulnerability classification API. "
        "Classifies code snippets and CVE descriptions into CWE categories, "
        "with optional ATLAS AI/ML attack pattern matching."
    ),
    version="0.1.0",
    lifespan=lifespan,
)

# Allow frontend (HF Spaces, localhost) to call the API.
# Every origin and every request header is accepted, but only the POST and GET
# methods are allowed for cross-origin calls.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # tighten this in production
    allow_methods=["POST", "GET"],
    allow_headers=["*"],
)


# ── Schemas ───────────────────────────────────────────────────────────────────

class ClassifyRequest(BaseModel):
    # Request body for POST /classify.
    #
    # `input` is the only field. It is required (`...`) and must be between
    # 10 and 8000 characters long; shorter or longer payloads are rejected by
    # validation before the endpoint body runs. The name shadows the `input`
    # builtin inside this class body only and is kept because it is the
    # public JSON key clients send.
    #
    # The description and example are published in the generated API schema,
    # so the description uses a real em dash rather than a mis-encoded one.
    input: str = Field(
        ...,
        min_length=10,
        max_length=8000,
        description="Raw input — code snippet, CVE description, or bug report.",
        examples=[
            "def get_user(name): return db.execute('SELECT * FROM users WHERE name=' + name)"
        ],
    )


class CWEPrediction(BaseModel):
    # One alternative CWE candidate in a classification response.
    # classify() builds these from the entries of result["alternatives"],
    # copying exactly the four keys below.
    cwe_id:      str
    description: str
    severity:    str
    # Score copied through from the pipeline result; its range is not
    # visible in this module — TODO confirm against the pipeline.
    confidence:  float


class ATLASMatch(BaseModel):
    # ATLAS technique match attached to a classification response.
    # classify() constructs it as ATLASMatch(**result["atlas_match"]), so the
    # keys of the pipeline's dict must line up with the field names below.
    atlas_id:        str
    technique:       str
    tactic:          str
    # NOTE: a string here, unlike the float `confidence` on the CWE models.
    confidence:      str
    matched_signals: list[str]
    description:     str
    mitigations:     list[str]
    # NOTE(review): Optional with no default — whether this key may be omitted
    # depends on the installed pydantic major version; confirm the pipeline
    # always supplies it.
    real_world:      Optional[str]
    reasoning:       str


class ClassifyResponse(BaseModel):
    # Response body of POST /classify. classify() fills every field from the
    # dict returned by the router's run() call.

    # Primary result
    cwe_id:       str
    cwe_name:     str
    severity:     str
    confidence:   float

    # Explanation (Qwen's structured description or original input)
    description:  str

    # Top-3 alternatives (excluding top-1)
    alternatives: list[CWEPrediction]

    # ATLAS match — None if no AI/ML signals detected
    atlas_match:  Optional[ATLASMatch]

    # Metadata
    input_type:   str   # "code" or "text"
    # Taken from result.get("warning"), so it is None when the pipeline sets none.
    warning:      Optional[str]
    # Copied from result["elapsed_s"]; presumably seconds, per the name — TODO confirm.
    elapsed_s:    float


class HealthResponse(BaseModel):
    # Response body of GET /health.
    status:  str
    version: str
    # Maps a model name to a load-status string, as assembled in health().
    models:  dict


class ErrorResponse(BaseModel):
    # Error body shape. It is referenced as the documented 400/500 response
    # model on /classify and uses the same two keys that the module's 404/500
    # exception handlers emit.
    # NOTE(review): classify() reports failures by raising HTTPException, whose
    # default JSON body may not include an "error" key — confirm the documented
    # shape matches what clients actually receive.
    error:   str
    detail:  Optional[str]


@app.get("/health", response_model=HealthResponse)
async def health():
    """Health check — returns model load status."""
    # The docstring above is kept to a single line on purpose: FastAPI
    # publishes route docstrings as the operation description.
    #
    # Reports which pipeline components the shared router has instantiated.
    # A component counts as loaded when its private router attribute is set.
    router = get_router()

    def _state(component, missing_label):
        # Translate "attribute is set" into the status string sent to clients.
        return "loaded" if component is not None else missing_label

    return {
        "status": "ok",
        # Read the version from the application object so this endpoint
        # cannot drift from the value passed to FastAPI(...).
        "version": app.version,
        "models": {
            "roberta": _state(router._classifier, "not loaded"),
            "qwen": _state(router._code_analyzer, "not loaded (lazy)"),
            "atlas_matcher": _state(router._atlas_matcher, "not loaded (lazy)"),
        },
    }


@app.post(
    "/classify",
    response_model=ClassifyResponse,
    responses={
        400: {"model": ErrorResponse, "description": "Invalid input"},
        500: {"model": ErrorResponse, "description": "Internal classification error"},
    },
)
async def classify(request: ClassifyRequest):
    """
    Classify a vulnerability input.

    - **Code snippets** → Qwen analyzes → RoBERTa classifies → CWE result
    - **CVE/text descriptions** → RoBERTa classifies directly → CWE result
    - **AI/ML-related inputs** → also runs ATLAS pattern matcher

    Returns a unified output card with CWE ID, severity, explanation,
    remediation hints, and optional ATLAS technique match.
    """
    # Maintainer notes (kept out of the docstring, which FastAPI publishes as
    # the operation description):
    #   * request  - validated body; only `request.input` is read.
    #   * returns  - a ClassifyResponse assembled from the dict returned by
    #                the router's run() call.
    #   * raises   - HTTPException 400 when the pipeline raises ValueError,
    #                HTTPException 500 for any other exception it raises.
    # NOTE(review): run() is invoked synchronously inside this coroutine, so
    # the event loop is occupied until it returns. Moving it to a worker
    # thread would require confirming the router is safe to call concurrently.
    try:
        router = get_router()
        result = router.run(request.input)
    except ValueError as e:
        # Chain the original exception so the cause stays visible in tracebacks.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Classification failed: {e}",
        ) from e

    # Alternatives are optional in the pipeline result; default to none.
    alternatives = [
        CWEPrediction(
            cwe_id=alt["cwe_id"],
            description=alt["description"],
            severity=alt["severity"],
            confidence=alt["confidence"],
        )
        for alt in result.get("alternatives", [])
    ]

    # A missing or falsy "atlas_match" entry means no ATLAS section is returned.
    atlas_data = result.get("atlas_match")
    atlas_match = ATLASMatch(**atlas_data) if atlas_data else None

    return ClassifyResponse(
        cwe_id=result["cwe_id"],
        cwe_name=result["cwe_name"],
        severity=result["severity"],
        confidence=result["confidence"],
        description=result["description"],
        alternatives=alternatives,
        atlas_match=atlas_match,
        input_type=result["input_type"],
        warning=result.get("warning"),
        elapsed_s=result["elapsed_s"],
    )


# ── Error handlers ────────────────────────────────────────────────────────────

@app.exception_handler(404)
async def not_found_handler(request: Request, exc):
    """Answer 404 errors with a JSON body that echoes the requested URL.

    The exception object itself is not inspected; only the request URL is
    reported back in the ``detail`` field.
    """
    body = {"error": "Not found", "detail": str(request.url)}
    return JSONResponse(content=body, status_code=404)


@app.exception_handler(500)
async def server_error_handler(request: Request, exc):
    """Answer 500 errors with a fixed, generic JSON body.

    Neither the request nor the exception is read, so no internal error
    detail is included in the response.
    """
    body = {"error": "Internal server error", "detail": "Check server logs."}
    return JSONResponse(content=body, status_code=500)