| """
|
| FARA Backend Server for HuggingFace Space
|
| Provides WebSocket communication and REST API for the React frontend
|
| """
|
|
|
| import asyncio
|
| import base64
|
| import logging
|
| import os
|
|
|
|
|
| import sys
|
| import tempfile
|
| import uuid
|
| from datetime import datetime
|
| from typing import Dict, Optional
|
|
|
| import httpx
|
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
| from fastapi.middleware.cors import CORSMiddleware
|
| from fastapi.responses import JSONResponse
|
| from playwright._impl._errors import TargetClosedError
|
|
|
| sys.path.insert(0, "/app")
|
| from fara import FaraAgent
|
| from fara.browser.browser_bb import BrowserBB
|
|
|
|
|
# Root logging at INFO; this module logs under its own name.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Modal settings are read once at import time. An unset variable becomes "",
# which the rest of the module treats as "not configured".
MODAL_TRACE_STORAGE_URL = os.environ.get("MODAL_TRACE_STORAGE_URL", "")
MODAL_TOKEN_ID = os.environ.get("MODAL_TOKEN_ID", "")
MODAL_TOKEN_SECRET = os.environ.get("MODAL_TOKEN_SECRET", "")

# Proxy-auth headers are only attached when both halves of the credential are
# present; otherwise the client config carries ``None`` for its headers.
if MODAL_TOKEN_ID and MODAL_TOKEN_SECRET:
    _modal_auth_headers = {
        "Modal-Key": MODAL_TOKEN_ID,
        "Modal-Secret": MODAL_TOKEN_SECRET,
    }
else:
    _modal_auth_headers = None

# Client configuration handed to ``FaraAgent`` as ``client_config``.
# NOTE(review): ``base_url`` has no fallback and is ``None`` when
# FARA_ENDPOINT_URL is unset -- confirm the agent tolerates that.
ENDPOINT_CONFIG = {
    "model": os.environ.get("FARA_MODEL_NAME", "microsoft/Fara-7B"),
    "base_url": os.environ.get("FARA_ENDPOINT_URL"),
    "api_key": os.environ.get("FARA_API_KEY", "not-needed"),
    "default_headers": _modal_auth_headers,
}

# Model identifiers returned by GET /api/models.
AVAILABLE_MODELS = ["microsoft/Fara-7B"]
|
|
|
# ASGI application that hosts the REST routes and the /ws WebSocket endpoint.
app = FastAPI(title="FARA Backend")

# Fully permissive CORS: any origin, method and header, with credentials.
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard lists with allow_credentials=True -- confirm this is intended
# before exposing the service beyond the Space's own frontend.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# In-memory registries, local to this process.
# connection id -> accepted WebSocket
active_connections: Dict[str, WebSocket] = {}
# trace id -> session created for that trace
active_sessions: Dict[str, "FaraSession"] = {}
|
|
|
|
|
class FaraSession:
    """Drive one FARA agent run and stream its progress over a WebSocket.

    A session owns a ``FaraAgent``, the ``BrowserBB`` manager behind it and a
    temporary directory used for screenshots and downloads. ``run_task`` is
    the entry point: it announces the run, initializes the agent, steps it
    while emitting ``agent_progress`` events, finishes with ``agent_complete``
    or ``agent_error``, and always releases resources through ``close``.

    NOTE(review): the step loop reaches into private ``FaraAgent`` members
    (``_page``, ``_chat_history``, ``_playwright_controller``, ...), so it is
    coupled to the installed ``fara`` version.
    """

    # Upper bound on agent rounds; also reported to clients as ``maxSteps``.
    MAX_STEPS = 50
    # Seconds to wait for a pending captcha before failing the run.
    CAPTCHA_TIMEOUT_SECONDS = 60
    # Token accounting is estimated, not measured: a flat figure per step for
    # input and the raw response length divided by this ratio for output.
    ESTIMATED_INPUT_TOKENS_PER_STEP = 1000
    CHARS_PER_OUTPUT_TOKEN = 4

    def __init__(self, trace_id: str, websocket: "WebSocket"):
        """Create an idle session bound to *trace_id* and *websocket*.

        No browser or agent is created here; that happens in ``initialize``.
        """
        self.trace_id = trace_id
        self.websocket = websocket
        self.agent: Optional[FaraAgent] = None
        self.browser_manager: Optional[BrowserBB] = None
        self.screenshots_dir: Optional[str] = None
        self.is_running = False
        self.should_stop = False
        self.step_count = 0
        self.start_time: Optional[datetime] = None
        self.total_input_tokens = 0
        self.total_output_tokens = 0

    async def initialize(self, start_page: str = "https://www.bing.com/"):
        """Create the temp directory, browser manager and agent, then start it.

        Args:
            start_page: URL passed to the agent as its start page.

        Returns:
            ``True`` once ``FaraAgent.initialize`` has completed.
        """
        # One scratch directory doubles as screenshot and download folder; it
        # is removed again in ``close``.
        self.screenshots_dir = tempfile.mkdtemp(prefix="fara_screenshots_")

        self.browser_manager = BrowserBB(
            headless=True,
            viewport_height=900,
            viewport_width=1440,
            page_script_path=None,
            browser_channel="chromium",
            browser_data_dir=None,
            downloads_folder=self.screenshots_dir,
            to_resize_viewport=True,
            single_tab_mode=True,
            animate_actions=False,
            use_browser_base=False,
            logger=logger,
        )

        self.agent = FaraAgent(
            browser_manager=self.browser_manager,
            client_config=ENDPOINT_CONFIG,
            start_page=start_page,
            downloads_folder=self.screenshots_dir,
            save_screenshots=True,
            max_rounds=self.MAX_STEPS,
        )

        await self.agent.initialize()
        return True

    async def send_event(self, event: dict):
        """Send *event* as JSON to the client.

        Delivery is best effort: any failure is logged and swallowed so a
        dropped socket does not abort the caller.
        """
        try:
            await self.websocket.send_json(event)
        except Exception as e:
            logger.error(f"Error sending event: {e}")

    async def get_screenshot_base64(self) -> Optional[str]:
        """Return the active page's screenshot as a PNG data URL.

        If the page closes during capture, the active page is looked up again
        and the capture is retried once.

        Returns:
            A ``data:image/png;base64,...`` string, or ``None`` when there is
            no agent, no page, or the capture failed (failures are logged).
        """
        if not self.agent:
            return None

        try:
            return await self._capture_data_url()
        except TargetClosedError:
            logger.warning(
                "Page closed while getting screenshot, attempting recovery..."
            )
            try:
                return await self._capture_data_url()
            except Exception as e:
                logger.error(f"Recovery screenshot failed: {e}")
        except Exception as e:
            logger.error(f"Error getting screenshot: {e}")
        return None

    async def _capture_data_url(self) -> Optional[str]:
        """Screenshot the active page once; ``None`` if there is no page."""
        page = self._get_active_page()
        if not page:
            return None
        screenshot_bytes = await self.agent._playwright_controller.get_screenshot(
            page
        )
        return f"data:image/png;base64,{base64.b64encode(screenshot_bytes).decode()}"

    def _get_active_page(self):
        """Return the last page of the browser context, else the agent's page.

        Returns ``None`` when the session has no agent.
        """
        agent = self.agent
        if not agent:
            return None

        manager = agent.browser_manager
        context = manager._context if manager else None
        if context:
            pages = context.pages
            if pages:
                # The last entry is treated as the active page.
                return pages[-1]
        return agent._page

    async def run_task(self, instruction: str, model_id: str):
        """Run *instruction* end to end and stream the result to the client.

        Emits ``agent_start`` first, then delegates to
        ``_run_agent_with_streaming``. Any exception is logged and reported
        as an ``agent_error`` event instead of propagating. Resources are
        always released afterwards.

        Args:
            instruction: Task text handed to the agent as the user message.
            model_id: Echoed back in ``agent_start`` only; the agent itself
                is configured from ``ENDPOINT_CONFIG``.
        """
        self.is_running = True
        self.should_stop = False
        self.step_count = 0
        self.start_time = datetime.now()
        self.total_input_tokens = 0
        self.total_output_tokens = 0

        try:
            await self.send_event(
                {
                    "type": "agent_start",
                    "agentTrace": {
                        "id": self.trace_id,
                        "instruction": instruction,
                        "modelId": model_id,
                        "timestamp": self.start_time.isoformat(),
                        "isRunning": True,
                        "traceMetadata": {
                            "traceId": self.trace_id,
                            "inputTokensUsed": 0,
                            "outputTokensUsed": 0,
                            "duration": 0,
                            "numberOfSteps": 0,
                            "maxSteps": self.MAX_STEPS,
                            "completed": False,
                        },
                    },
                }
            )

            await self.initialize()
            await self._run_agent_with_streaming(instruction)

        except Exception as e:
            logger.exception("Error running agent task")
            await self.send_event({"type": "agent_error", "error": str(e)})
        finally:
            self.is_running = False
            await self.close()

    async def _run_agent_with_streaming(self, user_message: str):
        """Step the agent up to ``max_rounds`` times, streaming every step.

        Each completed step produces one ``agent_progress`` event. The loop
        ends with ``agent_complete`` (final state ``success``,
        ``max_steps_reached`` or ``stopped``) or with ``agent_error`` when a
        step fails.

        Raises:
            RuntimeError: if the agent or its page is missing, or a captcha
                is still pending after ``CAPTCHA_TIMEOUT_SECONDS``.
        """
        agent = self.agent
        if agent is None:
            raise RuntimeError("Agent is not initialized")

        # NOTE(review): ``run_task`` has already initialized the agent; this
        # second call assumes ``FaraAgent.initialize`` is safe to repeat.
        await agent.initialize()
        if agent._page is None:
            # Explicit check rather than ``assert`` so it survives ``-O``.
            raise RuntimeError("Page should be initialized")

        scaled_screenshot = await agent._get_scaled_screenshot()

        if agent.save_screenshots:
            await agent._playwright_controller.get_screenshot(
                agent._page,
                path=os.path.join(
                    agent.downloads_folder, f"screenshot{agent._num_actions}.png"
                ),
            )

        from fara.types import ImageObj, UserMessage

        # Seed the conversation with the first screenshot plus the task text.
        agent._chat_history.append(
            UserMessage(
                content=[ImageObj.from_pil(scaled_screenshot), user_message],
                is_original=True,
            )
        )

        is_stop_action = False

        for i in range(agent.max_rounds):
            if self.should_stop:
                await self.send_event(
                    {
                        "type": "agent_complete",
                        "traceMetadata": self._get_metadata(),
                        "final_state": "stopped",
                    }
                )
                return

            is_first_round = i == 0
            step_start_time = datetime.now()

            # A cleared captcha event means a captcha is still pending.
            if not agent.browser_manager._captcha_event.is_set():
                logger.info(
                    f"Waiting {self.CAPTCHA_TIMEOUT_SECONDS}s for captcha to finish..."
                )
                captcha_solved = await agent.wait_for_captcha_with_timeout(
                    self.CAPTCHA_TIMEOUT_SECONDS
                )
                if (
                    not captcha_solved
                    and not agent.browser_manager._captcha_event.is_set()
                ):
                    raise RuntimeError("Captcha timed out")

            try:
                # The scaled screenshot is only passed on the first round.
                function_call, raw_response = await agent.generate_model_call(
                    is_first_round, scaled_screenshot if is_first_round else None
                )

                thoughts, action_dict = agent._parse_thoughts_and_action(raw_response)
                action_args = action_dict.get("arguments", {})
                action = action_args["action"]

                logger.info(
                    f"\nThought #{i + 1}: {thoughts}\nAction #{i + 1}: {action}"
                )

                (
                    is_stop_action,
                    action_description,
                ) = await self._execute_action_with_recovery(agent, function_call)

                # Keep the agent pointed at whichever page is active now.
                active_page = self._get_active_page()
                if active_page and active_page != agent._page:
                    logger.info("Updating agent page reference to active page")
                    agent._page = active_page

                screenshot_base64 = await self.get_screenshot_base64()

            except TargetClosedError as e:
                logger.error(f"Unrecoverable page error: {e}")
                await self.send_event(
                    {
                        "type": "agent_error",
                        "error": f"Browser page closed unexpectedly: {str(e)}",
                    }
                )
                return
            except Exception as e:
                logger.exception(f"Error in agent step {i + 1}")
                await self.send_event({"type": "agent_error", "error": str(e)})
                return

            step_duration = (datetime.now() - step_start_time).total_seconds()
            step_input_tokens = self.ESTIMATED_INPUT_TOKENS_PER_STEP
            step_output_tokens = len(raw_response) // self.CHARS_PER_OUTPUT_TOKEN

            self.total_input_tokens += step_input_tokens
            self.total_output_tokens += step_output_tokens
            self.step_count += 1

            step = {
                "stepId": str(uuid.uuid4()),
                "traceId": self.trace_id,
                "stepNumber": self.step_count,
                "thought": thoughts,
                "actions": [
                    {
                        "function_name": action,
                        "description": action_description,
                        "parameters": action_args,
                    }
                ],
                "image": screenshot_base64,
                "duration": step_duration,
                "inputTokensUsed": step_input_tokens,
                "outputTokensUsed": step_output_tokens,
                "timestamp": datetime.now().isoformat(),
            }

            await self.send_event(
                {
                    "type": "agent_progress",
                    "agentStep": step,
                    "traceMetadata": self._get_metadata(),
                }
            )

            if is_stop_action:
                break

        final_state = "success" if is_stop_action else "max_steps_reached"
        await self.send_event(
            {
                "type": "agent_complete",
                "traceMetadata": self._get_metadata(completed=True),
                "final_state": final_state,
            }
        )

    async def _execute_action_with_recovery(self, agent, function_call):
        """Execute one action, tolerating a page that closed mid-action.

        When the page closes and a different active page exists, the agent is
        re-pointed at it and the step is reported as a navigation. Otherwise
        the ``TargetClosedError`` is re-raised unchanged.

        Returns:
            ``(is_stop_action, action_description)``.
        """
        try:
            is_stop_action, _, action_description = await agent.execute_action(
                function_call
            )
            return is_stop_action, action_description
        except TargetClosedError:
            logger.warning(
                "Page closed during action execution, attempting recovery..."
            )
            new_page = self._get_active_page()
            if not new_page or new_page == agent._page:
                raise

            logger.info("Recovered with new active page")
            agent._page = new_page
            # Short pause before the next screenshot is taken.
            await asyncio.sleep(1)
            return False, "Action completed (page navigation occurred)"

    def _get_metadata(self, completed: bool = False) -> dict:
        """Return the trace metadata attached to progress/complete events.

        Args:
            completed: Value reported under ``"completed"``.

        Returns:
            Dict with trace id, accumulated token estimates, seconds elapsed
            since ``start_time`` (``0`` if the run has not started), the step
            count, ``MAX_STEPS`` and the *completed* flag.
        """
        duration = 0
        if self.start_time:
            duration = (datetime.now() - self.start_time).total_seconds()

        return {
            "traceId": self.trace_id,
            "inputTokensUsed": self.total_input_tokens,
            "outputTokensUsed": self.total_output_tokens,
            "duration": duration,
            "numberOfSteps": self.step_count,
            "maxSteps": self.MAX_STEPS,
            "completed": completed,
        }

    async def stop(self):
        """Ask the step loop to stop; it checks the flag before each round."""
        self.should_stop = True

    async def close(self):
        """Close the agent and delete the screenshot directory.

        Safe to call more than once, including concurrently: references are
        detached before the first ``await``, so a second caller finds nothing
        left to release. Failures are logged, never raised.
        """
        agent, self.agent = self.agent, None
        self.browser_manager = None
        screenshots_dir, self.screenshots_dir = self.screenshots_dir, None

        if agent:
            try:
                await agent.close()
            except Exception as e:
                logger.error(f"Error closing agent: {e}")

        if screenshots_dir and os.path.exists(screenshots_dir):
            import shutil

            try:
                shutil.rmtree(screenshots_dir)
            except Exception as e:
                logger.error(f"Error cleaning up screenshots: {e}")
|
|
|
|
|
@app.get("/api/models")
async def get_models():
    """Serve the list of model identifiers this backend offers as JSON."""
    model_ids = AVAILABLE_MODELS
    return JSONResponse(content=model_ids)
|
|
|
|
|
@app.post("/api/traces")
async def store_trace(trace_data: dict):
    """
    Forward a task trace to the Modal trace storage endpoint.

    The request body is posted as JSON to ``MODAL_TRACE_STORAGE_URL`` with the
    Modal proxy-auth headers added here, so the credentials stay on the
    server side.

    Responses:
        * upstream 200 -> the upstream JSON body, unchanged
        * other upstream status -> same status, ``{"success": False, ...}``
        * 503 -> storage URL or Modal credentials are not configured
        * 504 -> the upstream request timed out (30 s client timeout)
        * 500 -> any other failure, with the exception text as the error
    """

    def _failure(status: int, message: str) -> JSONResponse:
        # Every error path shares this response body shape.
        return JSONResponse(
            status_code=status,
            content={"success": False, "error": message},
        )

    if not MODAL_TRACE_STORAGE_URL:
        logger.warning("Modal trace storage URL not configured")
        return _failure(503, "Trace storage not configured")

    if not (MODAL_TOKEN_ID and MODAL_TOKEN_SECRET):
        logger.warning("Modal proxy auth credentials not configured")
        return _failure(503, "Modal auth not configured")

    outbound_headers = {
        "Content-Type": "application/json",
        "Modal-Key": MODAL_TOKEN_ID,
        "Modal-Secret": MODAL_TOKEN_SECRET,
    }

    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            upstream = await client.post(
                MODAL_TRACE_STORAGE_URL,
                json=trace_data,
                headers=outbound_headers,
            )

        if upstream.status_code != 200:
            error_text = upstream.text
            logger.error(
                f"Failed to store trace: {upstream.status_code} - {error_text}"
            )
            return _failure(upstream.status_code, f"Modal API error: {error_text}")

        stored = upstream.json()
        logger.info(
            f"Trace stored successfully: {stored.get('trace_id', 'unknown')}"
        )
        return JSONResponse(content=stored)
    except httpx.TimeoutException:
        logger.error("Timeout storing trace to Modal")
        return _failure(504, "Timeout connecting to trace storage")
    except Exception as e:
        logger.exception("Error storing trace")
        return _failure(500, str(e))
|
|
|
|
|
@app.get("/api/random-question")
async def get_random_question():
    """Pick one of the built-in example tasks at random.

    Returns a JSON object of the form ``{"question": <text>}``.
    """
    import random

    example_tasks = (
        "Search for the latest news about AI agents",
        "Find the weather forecast for San Francisco",
        "Go to GitHub and search for 'computer use agent'",
        "Find the top trending repositories on GitHub today",
        "Search for Python tutorials on YouTube",
        "Look up the current stock price of Microsoft",
        "Find the schedule for upcoming SpaceX launches",
        "Search for healthy breakfast recipes",
    )
    chosen = random.choice(example_tasks)
    return JSONResponse(content={"question": chosen})
|
|
|
|
|
# Strong references to in-flight agent tasks. The event loop only keeps weak
# references to tasks, so a task nobody else holds could be garbage-collected
# before it finishes. Each task removes itself from the set when done.
_background_tasks: set = set()


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one client connection: start, stop and ping agent tasks.

    After accepting, a ``heartbeat`` message carrying a fresh UUID is sent.
    Incoming JSON objects are then dispatched on their ``type``:

    * ``user_task``: build a ``FaraSession`` from ``trace`` (``id``,
      ``instruction``, ``modelId``) and run it as a background task.
    * ``stop_task``: ask the session registered under ``trace_id`` to stop.
    * ``ping``: reply with ``{"type": "pong"}``.

    Other message types are ignored. When the connection ends for any reason,
    the connection is unregistered and every session bound to this socket is
    stopped, closed and removed from ``active_sessions``.
    """
    await websocket.accept()

    connection_id = str(uuid.uuid4())
    active_connections[connection_id] = websocket

    try:
        # Sent inside the try block so a failed send still reaches the
        # cleanup below and cannot leave a stale ``active_connections`` entry.
        await websocket.send_json(
            {
                "type": "heartbeat",
                "uuid": str(uuid.uuid4()),
                "timestamp": datetime.now().isoformat(),
            }
        )

        while True:
            data = await websocket.receive_json()
            if not isinstance(data, dict):
                # Valid JSON that is not an object cannot carry a "type".
                logger.warning("Ignoring non-object WebSocket message")
                continue

            message_type = data.get("type")

            if message_type == "user_task":
                # ``or`` also covers an explicit null sent by the client.
                trace = data.get("trace") or {}
                task_trace_id = trace.get("id") or str(uuid.uuid4())
                instruction = trace.get("instruction", "")
                model_id = trace.get("modelId", "microsoft/Fara-7B")

                session = FaraSession(task_trace_id, websocket)
                active_sessions[task_trace_id] = session

                # Run in the background so this loop keeps receiving
                # stop/ping messages while the agent works.
                task = asyncio.create_task(session.run_task(instruction, model_id))
                _background_tasks.add(task)
                task.add_done_callback(_background_tasks.discard)

            elif message_type == "stop_task":
                stop_trace_id = data.get("trace_id")
                if stop_trace_id and stop_trace_id in active_sessions:
                    await active_sessions[stop_trace_id].stop()

            elif message_type == "ping":
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: {connection_id}")
    except Exception as e:
        logger.exception(f"WebSocket error: {e}")
    finally:
        active_connections.pop(connection_id, None)

        # Snapshot first: ``close`` awaits, and other connections may add or
        # remove sessions meanwhile, which would break live dict iteration.
        owned_sessions = [
            (session_trace_id, session)
            for session_trace_id, session in active_sessions.items()
            if session.websocket is websocket
        ]
        for session_trace_id, session in owned_sessions:
            # Tell the step loop to exit before tearing the browser down.
            await session.stop()
            await session.close()
            # Only drop the entry if it still refers to this session.
            if active_sessions.get(session_trace_id) is session:
                del active_sessions[session_trace_id]
|
|
|
|
|
@app.get("/api/health")
async def health_check():
    """Liveness probe: always answers ``{"status": "healthy"}``."""
    return dict(status="healthy")
|
|
|
|
|
if __name__ == "__main__":
    # Direct execution: serve the app on every interface, port 8000.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)
|
|
|