| """ |
| pipeline/classifier.py |
| RoBERTa-based CWE classifier — wraps the fine-tuned model for inference. |
| Input: natural language vulnerability description (str) |
| Output: list of top-k CWE predictions with confidence scores |
| """ |
|
|
| from __future__ import annotations |
| import json |
| from pathlib import Path |
| from typing import Optional |
| import torch |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline |
|
|
| |
|
|
# Hugging Face repository used as the default model/tokenizer source.
HF_REPO = "martynattakit/vuln-classifier-roberta"
# Passed to the pipeline as ``max_length``; also scales the character cap
# applied to input text before inference.
MAX_LENGTH = 256
# Number of ranked predictions requested from the pipeline.
TOP_K = 3

# CWE ids whose top-1 prediction triggers a "limited training data" warning.
LOW_CONFIDENCE_CWES = {"CWE-77", "CWE-863"}
|
|
# One row per supported CWE: (id, short description, severity).
# Both public lookup tables are derived from these rows, so they always
# share the same keys in the same order.
_CWE_TABLE = (
    ("CWE-787", "Out-of-bounds Write", "HIGH"),
    ("CWE-79", "Cross-site Scripting (XSS)", "MEDIUM"),
    ("CWE-89", "SQL Injection", "HIGH"),
    ("CWE-416", "Use After Free", "HIGH"),
    ("CWE-78", "OS Command Injection", "HIGH"),
    ("CWE-20", "Improper Input Validation", "MEDIUM"),
    ("CWE-125", "Out-of-bounds Read", "MEDIUM"),
    ("CWE-22", "Path Traversal", "HIGH"),
    ("CWE-352", "Cross-Site Request Forgery (CSRF)", "MEDIUM"),
    ("CWE-434", "Unrestricted File Upload", "HIGH"),
    ("CWE-862", "Missing Authorization", "HIGH"),
    ("CWE-476", "NULL Pointer Dereference", "MEDIUM"),
    ("CWE-287", "Improper Authentication", "HIGH"),
    ("CWE-190", "Integer Overflow", "MEDIUM"),
    ("CWE-502", "Deserialization of Untrusted Data", "HIGH"),
    ("CWE-77", "Command Injection", "HIGH"),
    ("CWE-119", "Buffer Overflow (Generic)", "HIGH"),
    ("CWE-798", "Hardcoded Credentials", "CRITICAL"),
    ("CWE-918", "Server-Side Request Forgery (SSRF)", "HIGH"),
    ("CWE-306", "Missing Authentication", "CRITICAL"),
    ("CWE-362", "Race Condition", "MEDIUM"),
    ("CWE-269", "Improper Privilege Management", "HIGH"),
    ("CWE-94", "Code Injection", "HIGH"),
    ("CWE-863", "Incorrect Authorization", "HIGH"),
    ("CWE-276", "Incorrect Default Permissions", "MEDIUM"),
)

# CWE id -> human-readable name.
CWE_DESCRIPTIONS = {cwe_id: name for cwe_id, name, _ in _CWE_TABLE}

# CWE id -> severity label ("MEDIUM", "HIGH" or "CRITICAL").
SEVERITY_MAP = {cwe_id: level for cwe_id, _, level in _CWE_TABLE}
|
|
| |
|
|
class CWEClassifier:
    """
    Wrap the fine-tuned RoBERTa model for CWE classification.

    The Hugging Face pipeline is built lazily on the first ``classify`` call,
    so importing the module and constructing the class are fast; only the
    first inference pays the model-loading cost.

    Attributes:
        repo: Model identifier passed to ``pipeline`` as both model and tokenizer.
        device: Device the pipeline is placed on.
    """

    def __init__(self, repo: str = HF_REPO, device: Optional[str] = None):
        """
        Args:
            repo: Model repository or path to load from.
            device: Device string such as "cpu", "cuda" or "cuda:1". When
                omitted, "cuda" is used if ``torch.cuda.is_available()``,
                otherwise "cpu".
        """
        self.repo = repo
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self._pipeline = None

    def _load(self):
        """Build the inference pipeline on first use; later calls are no-ops."""
        if self._pipeline is not None:
            return

        print(f"[CWEClassifier] Loading model from {self.repo}...")
        self._pipeline = pipeline(
            "text-classification",
            model=self.repo,
            tokenizer=self.repo,
            # Pass the device through unchanged so values like "cuda:1" or
            # "mps" are honoured instead of silently falling back to CPU.
            device=self.device,
            top_k=TOP_K,
            truncation=True,
            max_length=MAX_LENGTH,
        )
        print("[CWEClassifier] Model loaded.")

    def classify(self, text: str) -> dict:
        """
        Classify a vulnerability description.

        Args:
            text: Natural language vulnerability description.
                Should follow the structured format:
                "This function performs X on Y without Z, which may allow..."

        Returns:
            {
                "top1": { "cwe_id", "description", "severity", "confidence" },
                "top3": [ { "cwe_id", "description", "severity", "confidence" }, ... ],
                "warning": str | None,
                "raw_scores": { cwe_id: score, ... }
            }
            "top3" holds every prediction returned by the pipeline (TOP_K
            were requested). "warning" is set when top1 is listed in
            LOW_CONFIDENCE_CWES, or otherwise when its confidence is below 0.5.

        Raises:
            ValueError: If ``text`` is empty or only whitespace.
        """
        # Validate before loading so bad input fails immediately instead of
        # after an expensive model load.
        if not text or not text.strip():
            raise ValueError("Input text cannot be empty.")

        self._load()

        # Character-level cap before tokenization; the pipeline additionally
        # truncates to MAX_LENGTH. (The factor 4 presumably approximates
        # characters per token - TODO confirm.)
        raw = self._pipeline(text[:MAX_LENGTH * 4])
        # NOTE(review): for a single string this is expected to come back as
        # [[{...}, ...]]; a flat [{...}, ...] is accepted as well in case the
        # installed transformers version does not add the outer list.
        predictions = raw[0] if raw and isinstance(raw[0], list) else raw

        results = [
            {
                "cwe_id": pred["label"],
                "description": CWE_DESCRIPTIONS.get(pred["label"], "Unknown"),
                "severity": SEVERITY_MAP.get(pred["label"], "UNKNOWN"),
                "confidence": round(pred["score"], 4),
            }
            for pred in predictions
        ]

        top1 = results[0]

        # The weak-class warning takes precedence over the generic
        # low-confidence warning.
        warning = None
        if top1["cwe_id"] in LOW_CONFIDENCE_CWES:
            warning = (
                f"{top1['cwe_id']} has limited training data — "
                f"confidence may be unreliable. Review top-3 predictions."
            )
        elif top1["confidence"] < 0.5:
            warning = (
                f"Low confidence ({top1['confidence']:.0%}) — "
                f"input may not match known vulnerability patterns."
            )

        return {
            "top1": top1,
            "top3": results,
            "warning": warning,
            "raw_scores": {r["cwe_id"]: r["confidence"] for r in results},
        }
|
|
|
|
| |
|
|
# Shared classifier instance; stays None until get_classifier() is first called.
_classifier: Optional[CWEClassifier] = None


def get_classifier() -> CWEClassifier:
    """Return the shared module-level classifier, creating it on first use."""
    global _classifier
    if _classifier is not None:
        return _classifier
    _classifier = CWEClassifier()
    return _classifier
|
|
|
|
def classify(text: str) -> dict:
    """Classify *text* using the shared module-level classifier.

    Convenience wrapper so callers do not have to instantiate
    ``CWEClassifier`` themselves.
    """
    shared = get_classifier()
    return shared.classify(text)
|
|
|
|
| |
|
|
if __name__ == "__main__":
    # Manual smoke test: classify a few sample descriptions and print the
    # top-1 prediction (and any warning) for each.
    test_cases = [
        "This function constructs a SQL query by concatenating user-controlled input without parameterization, which may allow an attacker to inject arbitrary SQL commands.",
        "This function reflects user-supplied data into the HTTP response without encoding, which may allow an attacker to inject malicious scripts.",
        "This function performs operations on a memory buffer without verifying bounds, which may allow an attacker to read or write out-of-bounds memory.",
    ]

    clf = CWEClassifier()
    for text in test_cases:
        result = clf.classify(text)
        top1 = result["top1"]
        print(f"Input: {text[:60]}...")
        print(f" Top-1: {top1['cwe_id']} ({top1['severity']}) — {top1['confidence']:.1%}")
        if result["warning"]:
            print(f" ⚠ {result['warning']}")
        print()
|
|