| |
| import os |
| import subprocess |
| import traceback |
| import gradio as gr |
|
|
MODEL_ID = "CADCODER/CAD-Coder"
REPO_GIT = "https://github.com/CADCODER/CAD-Coder.git"
REPO_DIR = "CAD-Coder"


def _ensure_repo_checkout(url, dest):
    """Clone the git repository at *url* into *dest* unless *dest* already exists.

    Best effort: any exception raised while cloning (for example ``git`` not
    being installed, or a non-zero exit status surfaced by ``check=True``) is
    printed and otherwise ignored, so the app can still start without it.
    """
    if os.path.isdir(dest):
        return
    try:
        print("Cloning CAD-Coder repo...")
        subprocess.run(["git", "clone", url, dest], check=True)
    except Exception as err:
        print("Could not clone repository:", err)


# NOTE(review): no other code visible in this file reads REPO_DIR after the
# clone; confirm the checkout is still needed at startup.
_ensure_repo_checkout(REPO_GIT, REPO_DIR)
|
|
| |
# Hugging Face token from the environment. HF_TOKEN is preferred; if it is
# unset or empty, HF_HUB_API_TOKEN is used; None when neither is set.
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HF_HUB_API_TOKEN")
# Generation backends. Each stays None unless its setup below succeeds.
local_generate = None
api_generate = None


# Try to run the model in-process. Any failure (missing torch/transformers,
# no access to the weights, not enough memory, ...) is printed and leaves
# local_generate as None so a hosted fallback can be tried instead.
try:
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    try:
        import bitsandbytes  # noqa: F401  -- imported only to probe availability
        has_bnb = True
    except Exception:
        has_bnb = False

    print("Loading tokenizer...")
    # `token` replaces the deprecated `use_auth_token` keyword.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token, trust_remote_code=True)

    load_kwargs = {"device_map": "auto", "trust_remote_code": True}
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        load_kwargs["torch_dtype"] = torch.float16
    if has_bnb and has_cuda:
        # 8-bit quantization is only attempted on a CUDA machine: transformers'
        # bitsandbytes integration has historically refused to quantize without
        # a GPU, which previously aborted the whole local load.
        from transformers import BitsAndBytesConfig

        print("bitsandbytes available — will attempt 8-bit load (saves memory).")
        # quantization_config replaces the deprecated `load_in_8bit=True` kwarg.
        load_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    elif has_bnb:
        print("bitsandbytes available but no CUDA GPU — loading without 8-bit quantization.")

    print("Loading model (this can take a while)...")
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, token=hf_token, **load_kwargs)

    # device_map="auto" may spread weights over several devices; inputs are
    # placed on the device holding the model's first parameter.
    device = next(model.parameters()).device
    print("Model loaded on device:", device)

    def local_generate_fn(prompt, max_new_tokens=512):
        """Greedily generate up to *max_new_tokens* tokens continuing *prompt*.

        Returns only the newly generated text. ``generate`` on a causal LM
        returns the prompt tokens followed by the continuation, so the prompt
        prefix is sliced off before decoding.
        """
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        gen = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        prompt_len = inputs["input_ids"].shape[-1]
        return tokenizer.decode(gen[0][prompt_len:], skip_special_tokens=True)

    local_generate = local_generate_fn

except Exception:
    print("Local model load failed or not feasible in this environment.")
    traceback.print_exc()
|
|
| |
# Fallback: when no local model is available, route prompts to the Hugging
# Face hosted inference service for MODEL_ID. Setup failures are printed and
# leave api_generate as None.
if local_generate is None:
    try:
        # InferenceClient replaces the deprecated InferenceApi wrapper.
        from huggingface_hub import InferenceClient

        print("Setting up HF Inference API client as fallback...")
        api_client = InferenceClient(model=MODEL_ID, token=hf_token)

        def api_generate_fn(prompt, max_new_tokens=512):
            """Generate up to *max_new_tokens* tokens continuing *prompt* remotely.

            Returns only the generated continuation as a string
            (``return_full_text=False``, so the prompt is not echoed back).
            Inference/HTTP failures raise exceptions instead of being returned
            as stringified error payloads.
            """
            return api_client.text_generation(
                prompt,
                max_new_tokens=max_new_tokens,
                return_full_text=False,
            )

        api_generate = api_generate_fn
        print("Inference API fallback ready.")
    except Exception as e:
        print("HF Inference API not available:", e)
        traceback.print_exc()
|
|
| |
def generate(prompt, max_new_tokens=512):
    """Run *prompt* through the first configured generation backend.

    The in-process model (``local_generate``) takes priority over the hosted
    fallback (``api_generate``). If neither is set, an error string is
    returned rather than raising.
    """
    backend = local_generate or api_generate
    if not backend:
        return "ERROR: No model loaded and no API fallback available. Check HF_TOKEN and Space hardware."
    return backend(prompt, max_new_tokens=max_new_tokens)
|
|
| |
def run_prompt(prompt, max_tokens=512):
    """Gradio callback: validate *prompt* and return generated code or a message.

    An empty or whitespace-only prompt gets a usage hint instead of a model
    call. Any exception raised while generating is printed with its traceback
    and converted to a ``"Generation error: ..."`` string, so the UI always
    receives text.
    """
    has_text = bool(prompt) and prompt.strip() != ""
    if not has_text:
        return "Enter a prompt describing the CAD sketch you want (e.g., 'rectangle width 10 height 5 with hole radius 1')."
    try:
        # The slider value may arrive as a float; generation wants an int budget.
        token_budget = int(max_tokens)
        return generate(prompt, max_new_tokens=token_budget)
    except Exception as exc:
        traceback.print_exc()
        return f"Generation error: {exc}"
|
|
with gr.Blocks() as demo:
    # Components render in creation order: title, prompt box, token slider,
    # output box, then the Generate button.
    gr.Markdown("# CAD-Coder (Text → CadQuery code)")
    prompt = gr.Textbox(
        label="Natural language prompt",
        lines=4,
        placeholder="e.g. 'create a rectangular plate 100x50 with a centered 10mm hole'...",
    )
    max_tokens = gr.Slider(
        minimum=64,
        maximum=2048,
        step=64,
        value=512,
        label="Max new tokens",
    )
    out = gr.Textbox(label="Generated CadQuery code", lines=18)
    btn = gr.Button("Generate")
    # On click, call run_prompt(prompt, max_tokens) and show its return value in `out`.
    btn.click(fn=run_prompt, inputs=[prompt, max_tokens], outputs=out)
|
|
if __name__ == "__main__":
    # Listen on all interfaces; use $PORT when it is set, otherwise 7860.
    port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=port)
|
|