"""
Backend Code Generation Model Training Pipeline
===============================================

A comprehensive training pipeline for building an AI model that generates
framework-agnostic backend code with full application scaffolding.

Features:
- Data collection from multiple sources
- Multi-framework support (Express.js, FastAPI, Django, Flask, etc.)
- Full application scaffolding generation
- Model training with transformer architecture
- Evaluation and benchmarking tools
"""
import ast
import asyncio
import json
import logging
import os
import random
import subprocess
import tempfile
import time
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any

import aiohttp
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn as nn
from datasets import Dataset as HFDataset
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
    Trainer, DataCollatorForLanguageModeling
)
# Configure logging once at import time; every component in this module
# reports progress through the module-level ``logger``.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
@dataclass
class CodeExample:
    """Represents a single training example."""

    # Free-text summary of what the project does.
    description: str
    # Short requirement strings (e.g. "Uses express", "User registration").
    requirements: List[str]
    # Backend framework label such as 'express' or 'fastapi'.
    framework: str
    # Programming language label such as 'python' or 'javascript'.
    language: str
    # Mapping of file name -> file content.
    code_files: Dict[str, str]
    # Summary of the project layout ('files', 'directories', ...).
    project_structure: Dict[str, Any]
    # Extra information about the example's origin (stars, url, 'synthetic', ...).
    metadata: Dict[str, Any]
class DataCollector:
    """Collects training data from GitHub and from built-in synthetic templates.

    Every example produced by any method is appended to ``collected_examples``.
    """

    # Frameworks whose synthetic examples are not JavaScript projects.
    _SYNTHETIC_LANGUAGES = {
        'fastapi': 'python',
        'django': 'python',
        'flask': 'python',
        'gin': 'go',
        'fiber': 'go',
    }

    def __init__(self):
        # Optional token read from the environment; requests are sent
        # without an Authorization header when it is not set.
        self.github_token = os.getenv('GITHUB_TOKEN')
        self.collected_examples: List["CodeExample"] = []

    async def collect_github_repositories(self, queries: List[str], max_repos: int = 100):
        """Collect backend projects from GitHub.

        Args:
            queries: Search strings, each sent to the repository search API.
            max_repos: Total repository budget, split evenly across the
                queries (at least one repository per query).
        """
        logger.info("Starting GitHub repository collection...")

        headers = {'Authorization': f'token {self.github_token}'} if self.github_token else {}

        async with aiohttp.ClientSession(headers=headers) as session:
            per_query = max(1, max_repos // max(1, len(queries)))
            for query in queries:
                await self._search_github_repos(session, query, per_query)

    async def _search_github_repos(self, session: "aiohttp.ClientSession", query: str, limit: int):
        """Search GitHub for repositories matching ``query`` and process each hit.

        Failures are logged and swallowed so that one bad query does not
        abort the whole collection run.
        """
        url = "https://api.github.com/search/repositories"
        params = {
            'q': query,
            'sort': 'stars',
            'order': 'desc',
            'per_page': min(limit, 100)
        }

        try:
            async with session.get(url, params=params) as response:
                if response.status == 200:
                    data = await response.json()
                    for repo in data.get('items', []):
                        await self._process_repository(session, repo)
                else:
                    logger.warning(f"GitHub API request failed: {response.status}")
        except Exception as e:
            logger.error(f"Error searching GitHub: {e}")

    async def _process_repository(self, session: "aiohttp.ClientSession", repo: Dict):
        """Fetch a repository's top-level listing and extract an example from it."""
        logger.info(f"Processing repository: {repo.get('full_name', '<unknown>')}")

        try:
            contents_url = f"https://api.github.com/repos/{repo['full_name']}/contents"
            async with session.get(contents_url) as response:
                if response.status == 200:
                    contents = await response.json()
                    await self._extract_code_example(session, repo, contents)
                else:
                    # Previously ignored silently, which hid rate limiting
                    # and permission problems.
                    logger.warning(
                        f"Could not list contents of {repo['full_name']}: HTTP {response.status}"
                    )
        except Exception as e:
            logger.error(f"Error processing repository {repo.get('full_name')}: {e}")

    async def _extract_code_example(self, session: "aiohttp.ClientSession", repo: Dict, contents: List[Dict]):
        """Build a ``CodeExample`` from a repository listing and store it.

        Nothing is stored when the framework or language cannot be
        identified, or when no important file could be downloaded.
        """
        if not isinstance(contents, list):
            return

        # The key may be present with a null value, so ``.get(key, '')`` is
        # not enough to guarantee a string.
        description = repo.get('description') or ''

        framework = self._identify_framework(contents, description)
        language = self._identify_language(contents)
        if not framework or not language:
            return

        code_files: Dict[str, str] = {}
        for item in contents:
            name = item.get('name', '')
            if item.get('type') != 'file' or not self._is_important_file(name):
                continue
            download_url = item.get('download_url')
            if not download_url:
                continue
            try:
                async with session.get(download_url) as response:
                    if response.status == 200:
                        code_files[name] = await response.text()
            except Exception as e:
                # Best effort: one unreadable file must not discard the repo.
                logger.debug(f"Skipping {name}: {e}")
                continue

        if not code_files:
            return

        self.collected_examples.append(CodeExample(
            description=description,
            requirements=self._extract_requirements(code_files),
            framework=framework,
            language=language,
            code_files=code_files,
            project_structure=self._analyze_structure(contents),
            metadata={
                'stars': repo.get('stargazers_count', 0),
                'forks': repo.get('forks_count', 0),
                'url': repo.get('html_url'),
                'created_at': repo.get('created_at'),
                'updated_at': repo.get('updated_at')
            }
        ))

    def _identify_framework(self, contents: List[Dict], description: str) -> Optional[str]:
        """Identify the backend framework used.

        File-name signatures are tried first, in declaration order; only the
        first two names of each signature are required, and a name matches
        when it is a substring of any top-level file name. If no signature
        matches, the framework name is looked for in the description.

        Returns:
            The framework label, or ``None`` when nothing matches.
        """
        filenames = [item.get('name', '').lower() for item in contents if item.get('type') == 'file']

        # NOTE(review): 'koa' matches any project with a package.json that
        # was not claimed earlier, and 'fiber' can never win over 'gin' by
        # files alone - confirm whether that is intended.
        frameworks = {
            'express': ['package.json', 'app.js', 'server.js'],
            'fastapi': ['requirements.txt', 'main.py', 'app.py'],
            'django': ['manage.py', 'settings.py', 'requirements.txt'],
            'flask': ['app.py', 'requirements.txt'],
            'nestjs': ['nest-cli.json', 'package.json'],
            'koa': ['package.json'],
            'gin': ['go.mod', 'main.go'],
            'fiber': ['go.mod', 'main.go'],
        }

        for framework, required_files in frameworks.items():
            if all(any(req in filename for filename in filenames) for req in required_files[:2]):
                return framework

        desc_lower = (description or '').lower()
        for framework in frameworks:
            if framework in desc_lower:
                return framework

        return None

    def _identify_language(self, contents: List[Dict]) -> Optional[str]:
        """Identify the primary programming language from file extensions.

        Only extensions that map to a known language are counted, so
        documentation or configuration files (``.md``, ``.json``, ...) cannot
        outvote the source files.

        Returns:
            The language label of the most frequent known extension (first
            seen wins on ties), or ``None`` when no known extension occurs.
        """
        lang_map = {
            '.js': 'javascript',
            '.ts': 'typescript',
            '.py': 'python',
            '.go': 'go',
            '.java': 'java',
            '.cs': 'csharp',
            '.rb': 'ruby',
            '.php': 'php'
        }

        suffixes = (
            Path(item.get('name', '')).suffix.lower()
            for item in contents
            if item.get('type') == 'file'
        )
        counts = Counter(ext for ext in suffixes if ext in lang_map)
        if not counts:
            return None
        return lang_map[counts.most_common(1)[0][0]]

    def _is_important_file(self, filename: str) -> bool:
        """Check if a file is worth downloading for training.

        Matching is case-insensitive and by substring, so e.g.
        ``user.controller.js`` and ``Dockerfile.dev`` both qualify.
        """
        important_patterns = [
            'package.json', 'requirements.txt', 'go.mod', 'pom.xml',
            'dockerfile', 'docker-compose.yml', 'readme.md',
            'app.py', 'main.py', 'server.js', 'app.js', 'index.js',
            'settings.py', 'config.py', 'routes.py', 'models.py',
            'controller.js', 'service.js', 'middleware.js'
        ]

        filename_lower = filename.lower()
        return any(pattern in filename_lower for pattern in important_patterns)

    @staticmethod
    def _requirement_name(line: str) -> str:
        """Return the package name at the start of a requirements.txt line.

        Comment lines and pip options (``-r``, ``-e``, ``--index-url``) give
        an empty string. The name ends at the first character that is not a
        letter, digit, ``-``, ``_`` or ``.``, which covers every version
        operator, extras, environment markers and inline comments.
        """
        line = line.strip()
        if not line or line.startswith(('#', '-')):
            return ''
        name_chars: List[str] = []
        for ch in line:
            if not (ch.isalnum() or ch in '-_.'):
                break
            name_chars.append(ch)
        return ''.join(name_chars)

    def _extract_requirements(self, code_files: Dict[str, str]) -> List[str]:
        """Derive up to ten requirement strings from the downloaded files.

        Sources, in order: the first five ``dependencies`` of package.json,
        the first five packages of requirements.txt, then endpoint lines
        found in ``.js`` / ``.py`` files.
        """
        requirements: List[str] = []

        if 'package.json' in code_files:
            try:
                pkg_data = json.loads(code_files['package.json'])
            except (ValueError, TypeError):
                pkg_data = None
            deps = pkg_data.get('dependencies') if isinstance(pkg_data, dict) else None
            if isinstance(deps, dict):
                requirements.extend(f"Uses {dep}" for dep in list(deps)[:5])

        if 'requirements.txt' in code_files:
            names = (
                self._requirement_name(line)
                for line in code_files['requirements.txt'].splitlines()
            )
            deps = [name for name in names if name]
            requirements.extend(f"Uses {dep}" for dep in deps[:5])

        for filename, content in code_files.items():
            if filename.endswith(('.js', '.py')):
                requirements.extend(self._extract_endpoints(content))

        return requirements[:10]

    def _extract_endpoints(self, code_content: str) -> List[str]:
        """Extract up to five lines that look like API endpoint definitions.

        A line qualifies when it contains an ``app.<verb>(`` call (which
        also covers ``@app.<verb>(`` decorators) or when it contains
        ``def `` together with an HTTP verb as a substring.
        """
        route_calls = ('app.get(', 'app.post(', 'app.put(', 'app.delete(')
        verbs = ('get', 'post', 'put', 'delete')

        endpoints: List[str] = []
        for line in code_content.split('\n'):
            s = line.strip()
            is_route = any(call in s for call in route_calls)
            is_handler = 'def ' in s and any(verb in s for verb in verbs)
            if is_route or is_handler:
                endpoints.append(f"Implements {s}")
                if len(endpoints) == 5:
                    break

        return endpoints

    def _analyze_structure(self, contents: List[Dict]) -> Dict[str, Any]:
        """Summarise the top-level layout of a repository listing."""
        structure: Dict[str, Any] = {
            'files': [],
            'directories': [],
            'total_files': 0,
            'has_tests': False,
            'has_docs': False
        }

        for item in contents:
            kind = item.get('type')
            name = item.get('name', '')
            if kind == 'file':
                structure['files'].append(name)
                structure['total_files'] += 1
                lowered = name.lower()
                if 'test' in lowered:
                    structure['has_tests'] = True
                if lowered in ['readme.md', 'docs.md']:
                    structure['has_docs'] = True
            elif kind == 'dir':
                structure['directories'].append(name)

        return structure

    @classmethod
    def _language_for_framework(cls, framework: str) -> str:
        """Return the language label for a synthetic framework ('javascript' by default)."""
        return cls._SYNTHETIC_LANGUAGES.get(framework, 'javascript')

    def generate_synthetic_examples(self, count: int = 100):
        """Generate ``count`` synthetic training examples from fixed templates.

        A template and one of its frameworks are drawn at random for every
        example; the results are appended to ``collected_examples``.
        """
        logger.info(f"Generating {count} synthetic examples...")

        templates = [
            {
                'description': 'REST API for user management',
                'requirements': ['User registration', 'User authentication', 'Profile management'],
                'frameworks': ['express', 'fastapi', 'django']
            },
            {
                'description': 'E-commerce backend API',
                'requirements': ['Product catalog', 'Shopping cart', 'Order processing', 'Payment integration'],
                'frameworks': ['nestjs', 'fastapi', 'django']
            },
            {
                'description': 'Task management system',
                'requirements': ['Task CRUD operations', 'User assignments', 'Status tracking'],
                'frameworks': ['express', 'flask', 'gin']
            },
            {
                'description': 'Blog platform backend',
                'requirements': ['Article management', 'User comments', 'Category system'],
                'frameworks': ['express', 'django', 'fastapi']
            }
        ]

        for _ in range(count):
            template = random.choice(templates)
            framework = random.choice(template['frameworks'])

            example = CodeExample(
                description=template['description'],
                # Copy so examples never share (and mutate) one list.
                requirements=list(template['requirements']),
                framework=framework,
                language=self._language_for_framework(framework),
                code_files=self._generate_code_for_template(template, framework),
                project_structure=self._generate_synthetic_structure(framework),
                metadata={'synthetic': True}
            )

            self.collected_examples.append(example)

    def _generate_code_for_template(self, template: Dict, framework: str) -> Dict[str, str]:
        """Generate code files for a template and framework.

        Only 'express' and 'fastapi' have real starter files; every other
        framework receives a single placeholder file.
        """
        if framework == 'express':
            return {
                'package.json': json.dumps({
                    "name": template['description'].lower().replace(' ', '-'),
                    "version": "1.0.0",
                    "dependencies": {
                        "express": "^4.18.0",
                        "mongoose": "^6.0.0",
                        "bcrypt": "^5.0.0",
                        "jsonwebtoken": "^8.5.0"
                    }
                }, indent=2),
                'app.js': '''const express = require('express');
const mongoose = require('mongoose');
const app = express();

// Middleware
app.use(express.json());

// Routes
app.get('/health', (req, res) => {
  res.json({ status: 'OK' });
});

// Start server
const PORT = process.env.PORT || 3000;
app.listen(PORT, () => {
  console.log(`Server running on port ${PORT}`);
});

module.exports = app;'''
            }
        elif framework == 'fastapi':
            return {
                'requirements.txt': '''fastapi==0.68.0
uvicorn==0.15.0
sqlalchemy==1.4.23
pydantic==1.8.2''',
                'main.py': '''from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional

app = FastAPI()

class Item(BaseModel):
    id: Optional[int] = None
    name: str
    description: str

@app.get("/")
async def root():
    return {"message": "Hello World"}

@app.get("/health")
async def health_check():
    return {"status": "OK"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)'''
            }
        else:
            return {'placeholder.txt': 'Generated code placeholder'}

    def _generate_synthetic_structure(self, framework: str) -> Dict[str, Any]:
        """Generate the project structure advertised for a synthetic example.

        Returns an empty dict for frameworks without a predefined layout.
        """
        if framework in ['express', 'nestjs']:
            return {
                'files': ['package.json', 'app.js', 'README.md'],
                'directories': ['routes', 'controllers', 'middleware', 'models'],
                'total_files': 3,
                'has_tests': True,
                'has_docs': True
            }
        elif framework in ['fastapi', 'django', 'flask']:
            return {
                'files': ['requirements.txt', 'main.py', 'README.md'],
                'directories': ['models', 'routes', 'services'],
                'total_files': 3,
                'has_tests': True,
                'has_docs': True
            }
        else:
            return {}

    def save_dataset(self, filepath: str):
        """Save collected examples to ``filepath`` as a JSON array."""
        data = [asdict(example) for example in self.collected_examples]
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        logger.info(f"Saved {len(data)} examples to {filepath}")
class DataPreprocessor:
    """Preprocesses collected data for training.

    Turns ``CodeExample`` objects into prompt/completion text pairs and
    tokenizes them into a Hugging Face dataset.
    """

    # Upper bound applied to every tokenized sequence.
    _DEFAULT_MAX_LENGTH = 1024

    def __init__(self, tokenizer_name: str = "microsoft/DialoGPT-medium"):
        """Load the tokenizer and derive the maximum sequence length.

        Args:
            tokenizer_name: Name or path passed to ``AutoTokenizer.from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        if self.tokenizer.pad_token is None:
            # Reuse EOS for padding when the tokenizer defines no pad token.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.max_length = self._resolve_max_length(self.tokenizer)

    @classmethod
    def _resolve_max_length(cls, tokenizer: Any) -> int:
        """Return ``min(1024, tokenizer.model_max_length)`` with a safe fallback.

        Missing, non-numeric, non-positive or implausibly large (>= 100000)
        values fall back to 1024.
        """
        model_max = getattr(tokenizer, 'model_max_length', cls._DEFAULT_MAX_LENGTH)
        try:
            if model_max and 0 < model_max < 100000:
                return min(cls._DEFAULT_MAX_LENGTH, int(model_max))
        except (TypeError, ValueError):
            pass
        return cls._DEFAULT_MAX_LENGTH

    def preprocess_examples(self, examples: List["CodeExample"]) -> List[Dict[str, str]]:
        """Convert examples to training format.

        Returns:
            One dict per example with the keys 'input', 'output',
            'framework' and 'language'.
        """
        return [
            {
                'input': self._create_input_text(example),
                'output': self._create_output_text(example),
                'framework': example.framework,
                'language': example.language
            }
            for example in examples
        ]

    def _create_input_text(self, example: "CodeExample") -> str:
        """Create the model input (prompt) text for one example."""
        input_parts: List[str] = [
            # A missing description is rendered as empty instead of "None".
            f"Description: {example.description or ''}",
            f"Framework: {example.framework}",
            f"Language: {example.language}",
            "Requirements:",
        ]
        input_parts.extend(f"- {req}" for req in example.requirements)
        input_parts.append("Generate the backend application:")

        return "\n".join(input_parts)

    def _create_output_text(self, example: "CodeExample") -> str:
        """Create the model output (completion) text for one example.

        The text lists the project directories followed by every file,
        each wrapped in ``--- <name> ---`` / ``--- End ---`` markers.
        """
        output_parts: List[str] = ["Project Structure:"]
        output_parts.extend(
            f"/{directory}/" for directory in example.project_structure.get('directories', [])
        )

        output_parts.append("\nGenerated Files:")
        for filename, content in example.code_files.items():
            output_parts.append(f"\n--- {filename} ---")
            output_parts.append(content)
            output_parts.append("--- End ---")

        return "\n".join(output_parts)

    def create_training_dataset(self, processed_examples: List[Dict[str, str]]) -> "HFDataset":
        """Create a tokenized Hugging Face dataset for training.

        Args:
            processed_examples: Output of ``preprocess_examples``.

        Returns:
            A dataset holding the original text columns plus the tokenizer
            outputs for the combined prompt/completion text.
        """

        def tokenize_function(examples: Dict[str, List[str]]):
            # NOTE(review): the start/separator markers are tokenized as
            # ordinary text unless they are registered as special tokens on
            # the tokenizer - confirm this is intended.
            texts: List[str] = [
                f"<|startoftext|>{inp}<|separator|>{out}<|endoftext|>"
                for inp, out in zip(examples['input'], examples['output'])
            ]

            return self.tokenizer(
                texts,
                truncation=True,
                padding=True,
                max_length=self.max_length
            )

        dataset_dict = {
            column: [ex[column] for ex in processed_examples]
            for column in ('input', 'output', 'framework', 'language')
        }

        dataset = HFDataset.from_dict(dataset_dict)
        return dataset.map(tokenize_function, batched=True)
class CodeGenerationModel:
    """Causal language model wrapper for backend code generation."""

    def __init__(self, base_model: str = "microsoft/DialoGPT-medium"):
        """Load tokenizer and model weights for ``base_model``."""
        self.base_model = base_model
        self.tokenizer = AutoTokenizer.from_pretrained(base_model)
        self.model = AutoModelForCausalLM.from_pretrained(base_model)

        if self.tokenizer.pad_token is None:
            # Reuse EOS for padding when the tokenizer defines no pad token.
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def fine_tune(self, dataset: "HFDataset", output_dir: str = "./trained_model"):
        """Fine-tune the model on backend code generation.

        The dataset is split 80/20 into train and eval parts; the model and
        its tokenizer are written to ``output_dir`` afterwards.

        Args:
            dataset: Tokenized dataset.
            output_dir: Directory for checkpoints and the final model.

        Raises:
            ValueError: If ``dataset`` is empty.
        """
        logger.info("Starting model fine-tuning...")

        total = len(dataset)
        if total == 0:
            raise ValueError("Cannot fine-tune on an empty dataset")

        training_args = TrainingArguments(
            output_dir=output_dir,
            overwrite_output_dir=True,
            num_train_epochs=1,
            per_device_train_batch_size=1,
            per_device_eval_batch_size=1,
            warmup_steps=50,
            max_steps=100,
            logging_steps=10,
            save_steps=50,
            save_total_limit=2,
            prediction_loss_only=True,
            fp16=torch.cuda.is_available(),
            dataloader_pin_memory=False,
            gradient_accumulation_steps=4,
            learning_rate=5e-5,
        )

        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )

        # Keep at least one training row: int(0.8 * 1) would otherwise put
        # the only example into the eval split.
        train_size = max(1, int(0.8 * total))
        eval_size = total - train_size
        if eval_size > 0:
            train_dataset, eval_dataset = torch.utils.data.random_split(
                dataset, [train_size, eval_size]
            )
        else:
            train_dataset, eval_dataset = dataset, None

        trainer = Trainer(
            model=self.model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )

        trainer.train()
        trainer.save_model()
        # The trainer was not given the tokenizer, so save it explicitly to
        # make the output directory loadable on its own.
        self.tokenizer.save_pretrained(output_dir)

        logger.info("Fine-tuning completed!")

    def generate_code(self, description: str, framework: str, language: str,
                      requirements: Optional[List[str]] = None) -> str:
        """Generate backend code for the given description.

        The prompt follows the same layout as the training inputs
        (description, framework, language, requirements list, instruction)
        wrapped in the same start/separator markers.

        Args:
            description: What the application should do.
            framework: Target framework label.
            language: Target language label.
            requirements: Optional requirement bullet points.

        Returns:
            The decoded continuation, without the prompt.
        """
        prompt_lines = [
            f"Description: {description}",
            f"Framework: {framework}",
            f"Language: {language}",
            "Requirements:",
        ]
        prompt_lines.extend(f"- {req}" for req in (requirements or []))
        prompt_lines.append("Generate the backend application:")
        prompt = "<|startoftext|>" + "\n".join(prompt_lines) + "<|separator|>"

        model_max_len = getattr(self.tokenizer, 'model_max_length', 1024)
        if model_max_len is None or model_max_len > 100000:
            max_len = 1024
        else:
            max_len = min(1024, int(model_max_len))

        encoded = self.tokenizer(prompt, return_tensors='pt', truncation=True, max_length=max_len)
        # Training may have moved the model to an accelerator; inputs must
        # live on the same device.
        device = self.model.device
        input_ids = encoded['input_ids'].to(device)
        attention_mask = encoded['attention_mask'].to(device)

        # Leave training mode so dropout does not perturb sampling.
        self.model.eval()
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_length=max_len,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # Drop the prompt by token count; slicing the decoded string by
        # character count breaks when decoding does not reproduce the
        # prompt text exactly.
        new_tokens = outputs[0][input_ids.shape[-1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)
class ModelEvaluator:
    """Evaluates model performance with lightweight heuristics."""

    def __init__(self):
        # Average scores of the most recent ``benchmark_model`` run.
        self.metrics: Dict[str, float] = {}

    def evaluate_code_quality(self, generated_code: str, language: str) -> Dict[str, float]:
        """Evaluate generated code quality.

        Returns:
            Scores in ``[0, 1]`` under the keys 'syntax_correctness',
            'completeness' and 'best_practices'.
        """
        return {
            'syntax_correctness': self._check_syntax(generated_code, language),
            'completeness': self._check_completeness(generated_code),
            'best_practices': self._check_best_practices(generated_code, language)
        }

    def _check_syntax(self, code: str, language: str) -> float:
        """Score the syntactic validity of ``code``.

        Blank output scores 0.0 for every language. Python is parsed with
        ``ast`` (1.0 or 0.0). JavaScript scores 0.8 when both braces occur
        and 0.5 otherwise; every other language scores 0.5.
        """
        if not code or not code.strip():
            # An empty module parses fine, but an empty generation must not
            # count as correct code.
            return 0.0

        if language == 'python':
            try:
                ast.parse(code)
            except (SyntaxError, ValueError):
                # ValueError: e.g. null bytes in the source on some versions.
                return 0.0
            return 1.0

        if language == 'javascript' and '{' in code and '}' in code:
            return 0.8

        return 0.5

    def _check_completeness(self, code: str) -> float:
        """Score completeness as (indicators found) / 3, capped at 1.0."""
        completeness_indicators = [
            'import', 'require', 'function', 'def', 'class',
            'app.', 'router.', '@app.', 'app.listen', 'if __name__'
        ]

        indicators_found = sum(1 for indicator in completeness_indicators if indicator in code)
        return min(indicators_found / 3.0, 1.0)

    def _check_best_practices(self, code: str, language: str) -> float:
        """Score adherence to a few best practices, 0.2 per satisfied check.

        Checks: error handling present, a comment marker present, and one
        language-specific idiom (main guard for Python, ``const``/``let``
        for JavaScript).
        """
        best_practices_score = 0.0

        if 'try:' in code or 'catch' in code:
            best_practices_score += 0.2

        if any(comment in code for comment in ['#', '//', '/*']):
            best_practices_score += 0.2

        if language == 'python':
            if 'if __name__ == "__main__"' in code:
                best_practices_score += 0.2
        elif language == 'javascript':
            if 'const' in code or 'let' in code:
                best_practices_score += 0.2

        return min(best_practices_score, 1.0)

    def benchmark_model(self, model: 'CodeGenerationModel', test_cases: List[Dict]) -> Dict[str, float]:
        """Benchmark ``model`` on ``test_cases`` and return average scores.

        Args:
            model: Object exposing ``generate_code(description, framework, language)``.
            test_cases: Dicts with the keys 'description', 'framework' and
                'language'.

        Returns:
            Averages under 'syntax', 'completeness' and 'best_practices'
            (all 0.0 when there are no test cases). The result is also
            stored in ``self.metrics``.
        """
        total_scores = {'syntax': 0.0, 'completeness': 0.0, 'best_practices': 0.0}

        for i, test_case in enumerate(test_cases):
            generated_code = model.generate_code(
                test_case['description'],
                test_case['framework'],
                test_case['language']
            )

            scores = self.evaluate_code_quality(generated_code, test_case['language'])

            total_scores['syntax'] += scores['syntax_correctness']
            total_scores['completeness'] += scores['completeness']
            total_scores['best_practices'] += scores['best_practices']

            logger.info(f"Test case {i+1}: {scores}")

        num_cases = max(1, len(test_cases))
        avg_scores = {key: value / num_cases for key, value in total_scores.items()}

        self.metrics = avg_scores
        return avg_scores
class TrainingPipeline:
    """Main training pipeline orchestrator.

    Recognised ``config`` keys (all optional): 'tokenizer', 'base_model',
    'output_dir', 'github_token', 'max_repos', 'synthetic_examples' and
    'dataset_path'.
    """

    def __init__(self, config: Dict[str, Any]):
        """Create all pipeline components from ``config``."""
        self.config = config
        self.data_collector = DataCollector()
        # Honour an explicitly configured token; otherwise the collector
        # keeps the value it read from the environment.
        configured_token = config.get('github_token')
        if configured_token:
            self.data_collector.github_token = configured_token
        self.preprocessor = DataPreprocessor(config.get('tokenizer', 'microsoft/DialoGPT-medium'))
        self.model = CodeGenerationModel(config.get('base_model', 'microsoft/DialoGPT-medium'))
        self.evaluator = ModelEvaluator()

    async def run_full_pipeline(self):
        """Run the complete training pipeline.

        Steps: collect data (GitHub only when a token is available, plus
        synthetic examples), save the raw dataset, preprocess, fine-tune
        and benchmark.

        Returns:
            The average benchmark scores reported by the evaluator.
        """
        logger.info("Starting full training pipeline...")

        logger.info("Step 1: Collecting training data...")

        if self.data_collector.github_token:
            github_queries = [
                'express api backend',
                'fastapi python backend',
                'django rest api',
                'nodejs backend server',
                'flask api backend'
            ]
            await self.data_collector.collect_github_repositories(
                github_queries, max_repos=self.config.get('max_repos', 50)
            )

        self.data_collector.generate_synthetic_examples(
            count=self.config.get('synthetic_examples', 200)
        )

        self.data_collector.save_dataset(self.config.get('dataset_path', 'raw_dataset.json'))

        logger.info("Step 2: Preprocessing data...")
        processed_examples = self.preprocessor.preprocess_examples(self.data_collector.collected_examples)
        training_dataset = self.preprocessor.create_training_dataset(processed_examples)

        logger.info("Step 3: Training model...")
        self.model.fine_tune(training_dataset, output_dir=self.config.get('output_dir', './trained_model'))

        logger.info("Step 4: Evaluating model...")
        test_cases = [
            {
                'description': 'REST API for user management with authentication',
                'framework': 'express',
                'language': 'javascript'
            },
            {
                'description': 'FastAPI backend for e-commerce platform',
                'framework': 'fastapi',
                'language': 'python'
            },
            {
                'description': 'Django REST API for blog platform',
                'framework': 'django',
                'language': 'python'
            }
        ]

        benchmark_results = self.evaluator.benchmark_model(self.model, test_cases)
        logger.info(f"Benchmark results: {benchmark_results}")

        logger.info("Training pipeline completed!")
        return benchmark_results
def main() -> None:
    """Run the full pipeline, then print one sample generation."""
    config = {
        'base_model': 'microsoft/DialoGPT-medium',
        'tokenizer': 'microsoft/DialoGPT-medium',
        'output_dir': './backend_code_model',
        'github_token': os.getenv('GITHUB_TOKEN'),
    }

    pipeline = TrainingPipeline(config)

    asyncio.run(pipeline.run_full_pipeline())

    logger.info("\nTesting trained model...")
    generated_code = pipeline.model.generate_code(
        description="Create a REST API for managing tasks with CRUD operations",
        framework="express",
        language="javascript"
    )

    print("\nGenerated Code:")
    print("=" * 50)
    print(generated_code)


if __name__ == "__main__":
    main()