| | --- |
| | base_model: |
| | - Qwen/Qwen2.5-VL-3B-Instruct |
| | language: |
| | - en |
| | license: apache-2.0 |
| | tags: |
| | - gui |
| | - agent |
| | pipeline_tag: image-text-to-text |
| | library_name: transformers |
| | --- |
| | |
| | # InfiGUI-R1-3B |
| | |
| | This repository contains the model from the [InfiGUI-R1](https://arxiv.org/abs/2504.14239) paper. The model is based on `Qwen2.5-VL-3B-Instruct` and trained using the proposed Actor2Reasoner framework, enhanced through reinforcement learning to improve its planning and reflection capabilities for GUI tasks. |
| | |
| | ## Quick Start |
| | |
| | ### Installation |
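| | First, install the required dependencies. The example below also uses `opencv-python` and `requests`, and loads the model with FlashAttention-2, which requires the `flash-attn` package: |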
| | ```bash |
| | pip install transformers qwen-vl-utils |
| | ``` |
| | |
| | ### Example: GUI Grounding and Trajectory Tasks |
| | ```python |
| | import cv2 |
| | import json |
| | import torch |
| | import requests |
| | from PIL import Image |
| | from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor |
| | from qwen_vl_utils import process_vision_info, smart_resize |
| | |
| | MAX_IMAGE_PIXELS = 5600*28*28 |
| | |
| | # Load model and processor |
| | model = Qwen2_5_VLForConditionalGeneration.from_pretrained( |
| | "Reallm-Labs/InfiGUI-R1-3B", |
| | torch_dtype=torch.bfloat16, |
| | attn_implementation="flash_attention_2", |
| | device_map="auto" |
| | ) |
| | processor = AutoProcessor.from_pretrained("Reallm-Labs/InfiGUI-R1-3B", max_pixels=MAX_IMAGE_PIXELS, padding_side="left") |
| | |
| | # Prepare image |
| | img_url = "https://raw.githubusercontent.com/Reallm-Labs/InfiGUI-R1/main/images/test_img.png" |
| | response = requests.get(img_url) |
| | with open("test_img.png", "wb") as f: |
| | f.write(response.content) |
| | image = Image.open("test_img.png") |
| | width, height = image.size |
| | new_height, new_width = smart_resize(height, width, max_pixels=MAX_IMAGE_PIXELS) |
| | |
| | # Prepare inputs |
| | instruction = "View detailed storage space usage" |
| | |
| | system_prompt = "You FIRST think about the reasoning process as an internal monologue and then provide the final answer. |
| | The reasoning process MUST BE enclosed within <think> </think> tags." |
| | tool_prompt = " |
| | |
| | # Tools |
| | |
| | You may call one or more functions to assist with the user query. |
| | |
| | You are provided with function signatures within <tools></tools> XML tags: |
| | <tools> |
| | {\"type\": \"function\", \"function\": {\"name\": \"mobile_use\", \"description\": \"Use a touchscreen to interact with a mobile device, and take screenshots.\ |
| | * This is an interface to a mobile device with touchscreen. You can perform actions like clicking, typing, swiping, etc.\ |
| | * Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions.\ |
| | * The screen's resolution is " + str(new_width) + "x" + str(new_height) + ".\ |
| | * Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges unless asked.\", \"parameters\": {\"properties\": {\"action\": {\"description\": \"The action to perform. The available actions are:\ |
| | * `key`: Perform a key event on the mobile device.\ |
| | - This supports adb's `keyevent` syntax.\ |
| | - Examples: \\\"volume_up\\\", \\\"volume_down\\\", \\\"power\\\", \\\"camera\\\", \\\"clear\\\".\ |
| | * `click`: Click the point on the screen with coordinate (x, y).\ |
| | * `long_press`: Press the point on the screen with coordinate (x, y) for specified seconds.\ |
| | * `swipe`: Swipe from the starting point with coordinate (x, y) to the end point with coordinates2 (x2, y2).\ |
| | * `type`: Input the specified text into the activated input box.\ |
| | * `system_button`: Press the system button.\ |
| | * `open`: Open an app on the device.\ |
| | * `wait`: Wait specified seconds for the change to happen.\ |
| | * `terminate`: Terminate the current task and report its completion status.\", \"enum\": [\"key\", \"click\", \"long_press\", \"swipe\", \"type\", \"system_button\", \"open\", \"wait\", \"terminate\"], \"type\": \"string\"}, \"coordinate\": {\"description\": \"(x, y): The x (pixels from the left edge) and y (pixels from the top edge) coordinates to move the mouse to. Required only by `action=click`, `action=long_press`, and `action=swipe`.\", \"type\": \"array\"}, \"coordinate2\": {\"description\": \"(x, y): The x (pixels from the left edge) and y (pixels from the top edge) coordinates to move the mouse to. Required only by `action=swipe`.\", \"type\": \"array\"}, \"text\": {\"description\": \"Required only by `action=key`, `action=type`, and `action=open`.\", \"type\": \"string\"}, \"time\": {\"description\": \"The seconds to wait. Required only by `action=long_press` and `action=wait`.\", \"type\": \"number\"}, \"button\": {\"description\": \"Back means returning to the previous interface, Home means returning to the desktop, Menu means opening the application background menu, and Enter means pressing the enter. Required only by `action=system_button`\", \"enum\": [\"Back\", \"Home\", \"Menu\", \"Enter\"], \"type\": \"string\"}, \"status\": {\"description\": \"The status of the task. Required only by `action=terminate`.\", \"type\": \"string\", \"enum\": [\"success\", \"failure\"]}}, \"required\": [\"action\"], \"type\": \"object\"}}} |
| | </tools> |
| | |
| | For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags: |
| | <tool_call> |
| | {\"name\": <function-name>, \"arguments\": <args-json-object>} |
| | </tool_call>" |
| | grounding_prompt = f'''The screen's resolution is {new_width}x{new_height}. |
| | Point to the UI element most relevant to "{instruction}", output its coordinates using JSON format: |
| | ```json |
| | [ |
| | {{"point_2d": [x, y], "label": "object name/description"}} |
| | ]```''' |
| | trajectory_prompt = f"The user query: {instruction} |
| | Task progress (You have done the following operation on the current device): " |
| | |
| |
|
| | # Build messages |
| | grounding_messages = [ |
| | {"role": "system", "content": system_prompt}, |
| | { |
| | "role": "user", |
| | "content": [ |
| | {"type": "image", "image": "test_img.png"}, |
| | {"type": "text", "text": grounding_prompt} |
| | ] |
| | } |
| | ] |
| | trajectory_messages = [ |
| | {"role": "system", "content": system_prompt + tool_prompt}, |
| | { |
| | "role": "user", |
| | "content": [ |
| | {"type": "text", "text": trajectory_prompt}, |
| | {"type": "image", "image": "test_img.png"} |
| | ], |
| | }, |
| | ] |
| | messages = [grounding_messages, trajectory_messages] |
| | |
| | # Process and generate |
| | text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) |
| | image_inputs, video_inputs = process_vision_info(messages) |
| | inputs = processor(text=text, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt").to("cuda") |
| | generated_ids = model.generate(**inputs, max_new_tokens=512) |
| | output_text = processor.batch_decode( |
| | [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)], |
| | skip_special_tokens=True, |
| | clean_up_tokenization_spaces=False |
| | ) |
| | |
| | # Visualize results |
| | output_text = [ot.split("</think>")[-1] for ot in output_text] |
| | |
| | grounding_output = output_text[0].replace("```json", "").replace("```", "").strip() |
| | trajectory_output = output_text[1].replace("<tool_call>", "").replace("</tool_call>", "").strip() |
| | |
| | try: |
| | grounding_output = json.loads(grounding_output) |
| | trajectory_output = json.loads(trajectory_output) |
| | |
| | grounding_coords = grounding_output[0]['point_2d'] |
| | trajectory_coords = trajectory_output["arguments"]['coordinate'] if "coordinate" in trajectory_output["arguments"] else None |
| | |
| | grounding_label = grounding_output[0]['label'] |
| | trajectory_label = json.dumps(trajectory_output["arguments"]) |
| | |
| | # Load the original image |
| | img = cv2.imread("test_img.png") |
| | if img is None: |
| | raise ValueError("Could not load the image") |
| | |
| | height, width = img.shape[:2] |
| | |
| | # Create copies for each visualization |
| | grounding_img = img.copy() |
| | trajectory_img = img.copy() |
| | |
| | # Visualize grounding coordinates |
| | if grounding_coords: |
| | x = int(grounding_coords[0] / new_width * width) |
| | y = int(grounding_coords[1] / new_height * height) |
| | |
| | cv2.circle(grounding_img, (x, y), 10, (0, 0, 255), -1) |
| | cv2.putText(grounding_img, grounding_label, (x+10, y-10), |
| | cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2) |
| | cv2.imwrite("grounding_output.png", grounding_img) |
| | print("Predicted coordinates:", grounding_coords) |
| | print(f"Grounding visualization saved to grounding_output.png") |
| | |
| | # Visualize trajectory coordinates |
| | if trajectory_coords: |
| | x = int(trajectory_coords[0] / new_width * width) |
| | y = int(trajectory_coords[1] / new_height * height) |
| | |
| | cv2.circle(trajectory_img, (x, y), 10, (0, 0, 255), -1) |
| | cv2.putText(trajectory_img, trajectory_label, (x+10, y-10), |
| | cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2) |
| | cv2.imwrite("trajectory_output.png", trajectory_img) |
| | print("Predicted action:", trajectory_label) |
| | print(f"Trajectory visualization saved to trajectory_output.png") |
| | |
| | except: |
| | print("Error: Failed to parse coordinates or process image") |
| | ``` |
| | |
| | For more information, please refer to our [repo](https://github.com/Reallm-Labs/InfiGUI-R1). |