MartyNattakit
Reset and redo the project
9613bf1
"""
pipeline/classifier.py
RoBERTa-based CWE classifier — wraps the fine-tuned model for inference.
Input: natural language vulnerability description (str)
Output: list of top-k CWE predictions with confidence scores
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Optional
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
# ── Constants ────────────────────────────────────────────────────────────────
# Model identifier passed to the transformers pipeline as both model and tokenizer.
HF_REPO = "martynattakit/vuln-classifier-roberta"
# Token limit given to the pipeline (with truncation enabled); also used to
# derive a rough character cap (MAX_LENGTH * 4) applied before tokenization.
MAX_LENGTH = 256
# Number of ranked predictions requested from the pipeline for each input.
TOP_K = 3
# CWEs where the model is known to be unreliable (from eval report).
# A top-1 prediction in this set causes classify() to attach a warning.
LOW_CONFIDENCE_CWES = {
"CWE-77", # 0 samples in training — never predicts correctly
"CWE-863", # F1 0.60 — overlaps with CWE-862
}
# Short human-readable name for each CWE id. Used to fill the "description"
# field of every prediction; ids missing from this table are reported as
# "Unknown" by the classifier.
CWE_DESCRIPTIONS = {
"CWE-787": "Out-of-bounds Write",
"CWE-79": "Cross-site Scripting (XSS)",
"CWE-89": "SQL Injection",
"CWE-416": "Use After Free",
"CWE-78": "OS Command Injection",
"CWE-20": "Improper Input Validation",
"CWE-125": "Out-of-bounds Read",
"CWE-22": "Path Traversal",
"CWE-352": "Cross-Site Request Forgery (CSRF)",
"CWE-434": "Unrestricted File Upload",
"CWE-862": "Missing Authorization",
"CWE-476": "NULL Pointer Dereference",
"CWE-287": "Improper Authentication",
"CWE-190": "Integer Overflow",
"CWE-502": "Deserialization of Untrusted Data",
"CWE-77": "Command Injection",
"CWE-119": "Buffer Overflow (Generic)",
"CWE-798": "Hardcoded Credentials",
"CWE-918": "Server-Side Request Forgery (SSRF)",
"CWE-306": "Missing Authentication",
"CWE-362": "Race Condition",
"CWE-269": "Improper Privilege Management",
"CWE-94": "Code Injection",
"CWE-863": "Incorrect Authorization",
"CWE-276": "Incorrect Default Permissions",
}
# Severity label ("MEDIUM", "HIGH" or "CRITICAL") assigned to each CWE id.
# Used to fill the "severity" field of every prediction; ids missing from this
# table are reported as "UNKNOWN" by the classifier.
SEVERITY_MAP = {
"CWE-787": "HIGH",
"CWE-79": "MEDIUM",
"CWE-89": "HIGH",
"CWE-416": "HIGH",
"CWE-78": "HIGH",
"CWE-20": "MEDIUM",
"CWE-125": "MEDIUM",
"CWE-22": "HIGH",
"CWE-352": "MEDIUM",
"CWE-434": "HIGH",
"CWE-862": "HIGH",
"CWE-476": "MEDIUM",
"CWE-287": "HIGH",
"CWE-190": "MEDIUM",
"CWE-502": "HIGH",
"CWE-77": "HIGH",
"CWE-119": "HIGH",
"CWE-798": "CRITICAL",
"CWE-918": "HIGH",
"CWE-306": "CRITICAL",
"CWE-362": "MEDIUM",
"CWE-269": "HIGH",
"CWE-94": "HIGH",
"CWE-863": "HIGH",
"CWE-276": "MEDIUM",
}
# ── Classifier class ─────────────────────────────────────────────────────────
class CWEClassifier:
    """
    Wraps the fine-tuned RoBERTa model for CWE classification.

    The transformers pipeline is created lazily on the first successful call
    to ``classify`` — importing the module and constructing this class are
    cheap, while the first inference pays the model-loading cost.
    """

    def __init__(self, repo: str = HF_REPO, device: Optional[str] = None):
        """
        Store the configuration; no model is loaded here.

        Args:
            repo: Model identifier handed to the pipeline as both model and
                tokenizer.
            device: Device string such as "cuda" or "cpu". When omitted,
                "cuda" is chosen if ``torch.cuda.is_available()`` is true,
                otherwise "cpu".
        """
        self.repo = repo
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self._pipeline = None

    def _load(self):
        """Create the text-classification pipeline once; later calls are no-ops."""
        if self._pipeline is not None:
            return

        # Keep the historical integer mapping for the two plain names, but
        # pass any other device string (e.g. "cuda:1", "mps") straight through
        # instead of silently falling back to CPU.
        if self.device == "cuda":
            device = 0
        elif self.device == "cpu":
            device = -1
        else:
            device = self.device

        print(f"[CWEClassifier] Loading model from {self.repo}...")
        self._pipeline = pipeline(
            "text-classification",
            model=self.repo,
            tokenizer=self.repo,
            device=device,
            top_k=TOP_K,
            truncation=True,
            max_length=MAX_LENGTH,
        )
        print("[CWEClassifier] Model loaded.")

    def classify(self, text: str) -> dict:
        """
        Classify a vulnerability description.

        Args:
            text: Natural language vulnerability description.
                Should follow the structured format:
                "This function performs X on Y without Z, which may allow..."

        Returns:
            {
                "top1": { "cwe_id", "description", "severity", "confidence" },
                "top3": [ { "cwe_id", "description", "severity", "confidence" }, ... ],
                "warning": str | None,  # set if top1 is a known weak class
                                        # or its confidence is below 0.5
                "raw_scores": { cwe_id: score, ... }
            }
            Confidences and raw scores are rounded to 4 decimal places.

        Raises:
            ValueError: If ``text`` is empty or whitespace-only.
        """
        # Validate first so bad input fails fast without loading the model.
        if not text or not text.strip():
            raise ValueError("Input text cannot be empty.")

        self._load()

        raw = self._pipeline(text[:MAX_LENGTH * 4])  # rough char limit before tokenizer
        # A single string normally comes back as [[{label, score}, ...]];
        # tolerate a flat [{label, score}, ...] as well.
        # NOTE(review): exact shape depends on the transformers version — confirm.
        predictions = raw[0] if raw and isinstance(raw[0], list) else raw

        results = [
            {
                "cwe_id": pred["label"],
                "description": CWE_DESCRIPTIONS.get(pred["label"], "Unknown"),
                "severity": SEVERITY_MAP.get(pred["label"], "UNKNOWN"),
                "confidence": round(pred["score"], 4),
            }
            for pred in predictions
        ]
        top1 = results[0]

        # A known-unreliable class takes precedence over the generic
        # low-confidence warning.
        warning = None
        if top1["cwe_id"] in LOW_CONFIDENCE_CWES:
            warning = (
                f"{top1['cwe_id']} has limited training data — "
                f"confidence may be unreliable. Review top-3 predictions."
            )
        elif top1["confidence"] < 0.5:
            warning = (
                f"Low confidence ({top1['confidence']:.0%}) — "
                f"input may not match known vulnerability patterns."
            )

        return {
            "top1": top1,
            "top3": results,
            "warning": warning,
            "raw_scores": {p["label"]: round(p["score"], 4) for p in predictions},
        }
# ── Module-level singleton ───────────────────────────────────────────────────
_classifier: Optional[CWEClassifier] = None


def get_classifier() -> CWEClassifier:
    """Return the shared module-level classifier, creating it on first use."""
    global _classifier
    if _classifier is not None:
        return _classifier
    # First request: build the instance and remember it for later callers.
    _classifier = CWEClassifier()
    return _classifier
def classify(text: str) -> dict:
    """Classify ``text`` using the shared classifier from ``get_classifier()``.

    Convenience wrapper so callers do not need to instantiate a classifier
    themselves; the result is whatever the shared instance's ``classify``
    method returns.
    """
    shared = get_classifier()
    return shared.classify(text)
# ── CLI test ─────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Manual smoke test: run a few sample descriptions through a fresh
    # classifier and print the top prediction plus any warning.
    test_cases = [
        "This function constructs a SQL query by concatenating user-controlled input without parameterization, which may allow an attacker to inject arbitrary SQL commands.",
        "This function reflects user-supplied data into the HTTP response without encoding, which may allow an attacker to inject malicious scripts.",
        "This function performs operations on a memory buffer without verifying bounds, which may allow an attacker to read or write out-of-bounds memory.",
    ]
    clf = CWEClassifier()
    for text in test_cases:
        result = clf.classify(text)
        top1 = result["top1"]
        print(f"Input: {text[:60]}...")
        # The separator was a mis-encoded em dash in the printed output.
        print(f" Top-1: {top1['cwe_id']} ({top1['severity']}) — {top1['confidence']:.1%}")
        if result["warning"]:
            print(f" ⚠ {result['warning']}")
        print()