| import torch |
| from transformers import RobertaTokenizer, RobertaModel |
| import numpy as np |
| from scipy.special import softmax |
| import gradio as gr |
| import re |
| from huggingface_hub import hf_hub_download |
|
|
| |
class CodeClassifier(torch.nn.Module):
    """Encoder plus a two-layer linear head that produces class logits.

    The encoder's pooled output (768 features) is projected down to 512
    features and then to ``num_labels`` logits. No activation is applied
    between the two linear layers.
    """

    def __init__(self, base_model, num_labels=6):
        """Store the encoder and build the linear head.

        Args:
            base_model: Callable encoder invoked with ``input_ids`` and
                ``attention_mask`` keyword arguments; its result must expose
                a ``pooler_output`` attribute.
            num_labels: Number of logits produced per example.
        """
        super().__init__()
        self.base = base_model
        # Attribute names and creation order are kept as-is so existing
        # checkpoints (keyed by "reduction.*" / "classifier.*") still load.
        self.reduction = torch.nn.Linear(in_features=768, out_features=512)
        self.classifier = torch.nn.Linear(in_features=512, out_features=num_labels)

    def forward(self, input_ids, attention_mask):
        """Return the logits for a batch of tokenized inputs."""
        encoded = self.base(input_ids=input_ids, attention_mask=attention_mask)
        projected = self.reduction(encoded.pooler_output)
        logits = self.classifier(projected)
        return logits
|
|
| |
# Module-level setup: pick a device, load the CodeBERT tokenizer/encoder,
# download the fine-tuned checkpoint and restore it into the classifier.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
base_model = RobertaModel.from_pretrained('microsoft/codebert-base')

model = CodeClassifier(base_model)

checkpoint_path = hf_hub_download(repo_id="martynattakit/CodeSentinel-Model", filename="best_model.pt")
# NOTE(security): torch.load unpickles the downloaded file, which can execute
# arbitrary code if the hub repository is compromised. Consider passing
# weights_only=True once the checkpoint is confirmed to contain only tensors
# and plain containers.
checkpoint = torch.load(checkpoint_path, map_location=device)

# The checkpoint is either a bare state dict or a training checkpoint that
# stores the weights under 'model_state_dict'.
model_state = checkpoint.get('model_state_dict', checkpoint)

# strict=False tolerates key mismatches, so surface them explicitly: a silently
# skipped key would leave that layer at its random initialisation and make
# every prediction meaningless.
load_result = model.load_state_dict(model_state, strict=False)
if load_result.missing_keys:
    print("WARNING: keys missing from checkpoint (left at initial values):", load_result.missing_keys)
if load_result.unexpected_keys:
    print("WARNING: checkpoint keys not used by the model:", load_result.unexpected_keys)
print("Classifier weight shape:", model.classifier.weight.shape)
model.eval()
model.to(device)
|
|
| |
# Maps the classifier's output index to a (CWE identifier, description) pair.
# The position of each entry in the tuple below is its class index.
label_map = dict(enumerate((
    ('none', 'No Vulnerability Detected'),
    ('cwe-121', 'Stack-based Buffer Overflow'),
    ('cwe-78', 'OS Command Injection'),
    ('cwe-190', 'Integer Overflow or Wraparound'),
    ('cwe-191', 'Integer Underflow'),
    ('cwe-122', 'Heap-based Buffer Overflow'),
)))
|
|
def load_c_file(file):
    """Read an uploaded source file and return its text.

    Args:
        file: ``None`` (no upload), a filesystem path (``str`` or path-like),
            or an object whose ``name`` attribute is the path of the uploaded
            file. Which of these the UI passes depends on the Gradio version
            and the ``gr.File`` configuration, so all are accepted.

    Returns:
        The file contents, ``""`` when ``file`` is ``None``, or a string
        starting with ``"Error reading file: "`` if the file cannot be read.
        This function never raises; the result is shown directly in the UI.
    """
    if file is None:
        return ""
    try:
        # Check for a path first: pathlib paths also have a `.name`, but it is
        # only the final component, not the full location.
        if isinstance(file, str) or hasattr(file, "__fspath__"):
            path = file
        else:
            path = file.name
        # errors="replace" keeps legacy-encoded sources (e.g. Latin-1
        # comments) loadable instead of rejecting the whole file.
        with open(path, 'r', encoding='utf-8', errors='replace') as f:
            return f.read()
    except Exception as e:  # UI boundary: report the problem instead of raising
        return f"Error reading file: {str(e)}"
|
|
# One pass over the source that recognises comments AND string/char literals.
# Literals are matched only so that comment markers inside them (for example
# the "//" in "http://host") are skipped rather than treated as comments.
_C_COMMENT_OR_LITERAL = re.compile(
    r'''
      //[^\n]*                  # line comment, up to the end of the line
    | /\*.*?\*/                 # block comment, may span lines
    | "(?:\\.|[^"\\\n])*"       # string literal, honouring backslash escapes
    | '(?:\\.|[^'\\\n])*'       # character literal
    ''',
    re.DOTALL | re.VERBOSE,
)


def clean_code(code):
    """Strip C/C++ comments from ``code`` and collapse all whitespace.

    Comments are replaced by a single space (as the C preprocessor does) so
    that tokens on either side of a comment are not fused together. Text
    inside string and character literals is left untouched, even if it
    contains ``//`` or ``/* */``.

    Args:
        code: C or C++ source text.

    Returns:
        The source on a single line, with every run of whitespace reduced to
        one space and no leading or trailing whitespace.
    """
    def _strip_comment(match):
        token = match.group(0)
        # Comments start with '/', literals start with a quote character.
        return ' ' if token.startswith('/') else token

    without_comments = _C_COMMENT_OR_LITERAL.sub(_strip_comment, code)
    return ' '.join(without_comments.split())
|
|
# Inputs at or above this many characters are rejected before any processing.
_MAX_CODE_CHARS = 1500000
# Token budget passed to the tokenizer; longer inputs are truncated to this.
_MAX_TOKENS = 256


def evaluate_code(code):
    """Classify a C/C++ snippet and return the predicted vulnerability label.

    Args:
        code: Source text to analyse (may be ``None`` or empty when the UI
            text box has no content).

    Returns:
        ``"<cwe-id> <description>"`` for the highest-scoring class, or one of
        the messages ``"No code provided"``, ``"Code too large"`` or
        ``"Error during prediction: ..."``. This function never raises; the
        result is shown directly in the UI.
    """
    try:
        if code is None:
            return "No code provided"
        if len(code) >= _MAX_CODE_CHARS:
            return "Code too large"
        # Without this guard an empty box would still be sent through the
        # model and reported as a confident-looking label.
        if not code.strip():
            return "No code provided"

        cleaned_code = clean_code(code)
        if not cleaned_code:
            # The input consisted solely of comments.
            return "No code provided"

        # NOTE: inputs longer than _MAX_TOKENS tokens are truncated, so only
        # the beginning of a long file influences the prediction.
        inputs = tokenizer(
            cleaned_code,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=_MAX_TOKENS,
        ).to(device)
        print("Input shape:", inputs['input_ids'].shape)

        with torch.no_grad():
            logits = model(**inputs).cpu().numpy()
        print("Raw logits:", logits)

        probs = softmax(logits, axis=1)
        # Convert to a plain int so the label lookup does not depend on how
        # NumPy integer scalars hash and compare.
        pred = int(np.argmax(probs, axis=1)[0])
        cwe, description = label_map[pred]
        return f"{cwe} {description}"

    except Exception as e:  # UI boundary: report the problem instead of raising
        return f"Error during prediction: {str(e)}"
|
|
# Gradio UI: a code text box and a file upload side by side, a "Check" button,
# and a read-only text box that shows the predicted label.
# NOTE(review): the nesting below was reconstructed from a flattened paste of
# this file; in particular confirm which container `check_btn` belongs to.
with gr.Blocks() as web:
    with gr.Row():
        with gr.Column(scale=1):
            code_box = gr.Textbox(lines=20, label="** C/C++ Code", placeholder="Paste your C or C++ code here...")
        with gr.Column(scale=1):
            cc_file = gr.File(label="Upload C/C++ File (.c or .cpp)", file_types=[".c", ".cpp"])
            check_btn = gr.Button("Check")

    with gr.Row():
        gr.Markdown("### Result:")

    with gr.Row():
        with gr.Column(scale=1):
            label_box = gr.Textbox(label="Vulnerability", interactive=False)

    # Uploading (or clearing) a file replaces the text box contents with
    # whatever load_c_file returns for it.
    cc_file.change(fn=load_c_file, inputs=cc_file, outputs=code_box)
    # Clicking "Check" classifies the current text box contents.
    check_btn.click(fn=evaluate_code, inputs=code_box, outputs=[label_box])


# Starts the web server at module execution time (there is no __main__ guard).
web.launch()