s2_5 code
Esse commit está contido em:
@@ -1,3 +0,0 @@
|
||||
[submodule "Perplexica"]
|
||||
path = Perplexica
|
||||
url = https://github.com/ItzCrazyKns/Perplexica
|
||||
-1
Submodule Perplexica deleted from dfb532e4d3
+94
-197
@@ -42,6 +42,7 @@
|
||||
</div>
|
||||
|
||||
## 🥳 Updates
|
||||
- [x] **2025/08/01**: Agent S2.5 is out with new SOTA scores on OSWorld Verified for 100-step and 50-step!
|
||||
- [x] **2025/07/07**: The [Agent S2 paper](https://arxiv.org/abs/2504.00906) is accepted to COLM 2025! See you in Montreal!
|
||||
- [x] **2025/04/01**: Released the [Agent S2 paper](https://arxiv.org/abs/2504.00906) with new SOTA results on OSWorld, WindowsAgentArena, and AndroidWorld!
|
||||
- [x] **2025/03/12**: Released Agent S2 along with v0.2.0 of [gui-agents](https://github.com/simular-ai/Agent-S), the new state-of-the-art for computer use agents (CUA), outperforming OpenAI's CUA/Operator and Anthropic's Claude 3.7 Sonnet Computer-Use!
|
||||
@@ -61,205 +62,125 @@
|
||||
|
||||
## 💡 Introduction
|
||||
|
||||
<p align="center">
|
||||
<img src="./images/agent_s2_teaser.png" width="800">
|
||||
</p>
|
||||
|
||||
Welcome to **Agent S**, an open-source framework designed to enable autonomous interaction with computers through Agent-Computer Interface. Our mission is to build intelligent GUI agents that can learn from past experiences and perform complex tasks autonomously on your computer.
|
||||
|
||||
Whether you're interested in AI, automation, or contributing to cutting-edge agent-based systems, we're excited to have you here!
|
||||
|
||||
## 🎯 Current Results
|
||||
|
||||
<p align="center">
|
||||
<img src="./images/agent_s2_osworld_result.png" width="600">
|
||||
<br>
|
||||
Results of Agent S2's Successful Rate (%) on the OSWorld full test set using Screenshot input only.
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
<table border="0" cellspacing="0" cellpadding="5">
|
||||
<tr>
|
||||
<th>Benchmark</th>
|
||||
<th rowspan="2">OSWorld Verified</th>
|
||||
<th colspan="3"> </th>
|
||||
</tr>
|
||||
<tr>
|
||||
<th>Agent S2.5</th>
|
||||
<th>Agent S2</th>
|
||||
<th>Previous SOTA</th>
|
||||
<th>Δ improve</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>OSWorld (15 step)</td>
|
||||
<td>27.0%</td>
|
||||
<td>22.7% (UI-TARS)</td>
|
||||
<td>+4.3%</td>
|
||||
<td>50 step</td>
|
||||
<td>54.2%</td>
|
||||
<td>45.8%</td>
|
||||
<td>50.6%</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>OSWorld (50 step)</td>
|
||||
<td>34.5%</td>
|
||||
<td>32.6% (OpenAI CUA)</td>
|
||||
<td>+1.9%</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>WindowsAgentArena</td>
|
||||
<td>29.8%</td>
|
||||
<td>19.5% (NAVI)</td>
|
||||
<td>+10.3%</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>AndroidWorld</td>
|
||||
<td>54.3%</td>
|
||||
<td>46.8% (UI-TARS)</td>
|
||||
<td>+7.5%</td>
|
||||
<td>100 step</td>
|
||||
<td>56.0%</td>
|
||||
<td>-</td>
|
||||
<td>53.1%</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
|
||||
## 🛠️ Installation & Setup
|
||||
> **Note**: Our agent returns `pyautogui` code and is intended for a single monitor screen.
|
||||
|
||||
> ❗**Warning**❗: If you are on a Linux machine, creating a `conda` environment will interfere with `pyatspi`. As of now, there's no clean solution for this issue. Proceed through the installation without using `conda` or any virtual environment.
|
||||
### Prerequisites
|
||||
- **Single Monitor**: Our agent is designed for single monitor screens
|
||||
- **Linux Users**: Avoid `conda` environments as they interfere with `pyatspi`
|
||||
- **Security**: The agent runs Python code to control your computer - use with care
|
||||
|
||||
> ⚠️**Disclaimer**⚠️: To leverage the full potential of Agent S2, we utilize [UI-TARS](https://github.com/bytedance/UI-TARS) as a grounding model (7B-DPO or 72B-DPO for better performance). They can be hosted locally, or on Hugging Face Inference Endpoints. Our code supports Hugging Face Inference Endpoints. Check out [Hugging Face Inference Endpoints](https://huggingface.co/learn/cookbook/en/enterprise_dedicated_endpoints) for more information on how to set up and query this endpoint. However, running Agent S2 does not require this model, and you can use alternative API based models for visual grounding, such as Claude.
|
||||
|
||||
Install the package:
|
||||
```
|
||||
### Installation
|
||||
```bash
|
||||
pip install gui-agents
|
||||
```
|
||||
|
||||
Set your LLM API Keys and other environment variables. You can do this by adding the following line to your .bashrc (Linux), or .zshrc (MacOS) file.
|
||||
### API Configuration
|
||||
|
||||
```
|
||||
#### Option 1: Environment Variables
|
||||
Add to your `.bashrc` (Linux) or `.zshrc` (MacOS):
|
||||
```bash
|
||||
export OPENAI_API_KEY=<YOUR_API_KEY>
|
||||
export ANTHROPIC_API_KEY=<YOUR_ANTHROPIC_API_KEY>
|
||||
export HF_TOKEN=<YOUR_HF_TOKEN>
|
||||
```
|
||||
|
||||
Alternatively, you can set the environment variable in your Python script:
|
||||
|
||||
```
|
||||
#### Option 2: Python Script
|
||||
```python
|
||||
import os
|
||||
os.environ["OPENAI_API_KEY"] = "<YOUR_API_KEY>"
|
||||
```
|
||||
|
||||
We also support Azure OpenAI, Anthropic, Gemini, Open Router, and vLLM inference. For more information refer to [models.md](models.md).
|
||||
### Supported Models
|
||||
We support Azure OpenAI, Anthropic, Gemini, Open Router, and vLLM inference. See [models.md](models.md) for details.
|
||||
|
||||
### Setup Retrieval from Web using Perplexica
|
||||
Agent S works best with web-knowledge retrieval. To enable this feature, you need to setup Perplexica:
|
||||
|
||||
1. Ensure Docker Desktop is installed and running on your system.
|
||||
|
||||
2. Navigate to the directory containing the project files.
|
||||
|
||||
```bash
|
||||
cd Perplexica
|
||||
git submodule update --init
|
||||
```
|
||||
|
||||
3. Rename the `sample.config.toml` file to `config.toml`. For Docker setups, you need only fill in the following fields:
|
||||
|
||||
- `OPENAI`: Your OpenAI API key. **You only need to fill this if you wish to use OpenAI's models**.
|
||||
- `OLLAMA`: Your Ollama API URL. You should enter it as `http://host.docker.internal:PORT_NUMBER`. If you installed Ollama on port 11434, use `http://host.docker.internal:11434`. For other ports, adjust accordingly. **You need to fill this if you wish to use Ollama's models instead of OpenAI's**.
|
||||
- `GROQ`: Your Groq API key. **You only need to fill this if you wish to use Groq's hosted models**.
|
||||
- `ANTHROPIC`: Your Anthropic API key. **You only need to fill this if you wish to use Anthropic models**.
|
||||
|
||||
**Note**: You can change these after starting Perplexica from the settings dialog.
|
||||
|
||||
- `SIMILARITY_MEASURE`: The similarity measure to use (This is filled by default; you can leave it as is if you are unsure about it.)
|
||||
|
||||
4. Ensure you are in the directory containing the `docker-compose.yaml` file and execute:
|
||||
|
||||
```bash
|
||||
docker compose up -d
|
||||
```
|
||||
5. Export your Perplexica URL using the port found in the [`docker-compose.yaml`](https://github.com/ItzCrazyKns/Perplexica/blob/master/docker-compose.yaml) file Under `app/ports`, you'll see `3000:3000`. The port is the left-hand number (in this case, 3000).
|
||||
|
||||
```bash
|
||||
export PERPLEXICA_URL=http://localhost:{port}/api/search
|
||||
```
|
||||
6. Our implementation of Agent S incorporates the Perplexica API to integrate a search engine capability, which allows for a more convenient and responsive user experience. If you want to tailor the API to your settings and specific requirements, you may modify the URL and the message of request parameters in `agent_s/query_perplexica.py`. For a comprehensive guide on configuring the Perplexica API, please refer to [Perplexica Search API Documentation](https://github.com/ItzCrazyKns/Perplexica/blob/master/docs/API/SEARCH.md).
|
||||
For a more detailed setup and usage guide, please refer to the [Perplexica Repository](https://github.com/ItzCrazyKns/Perplexica.git).
|
||||
|
||||
> ❗**Warning**❗: The agent will directly run python code to control your computer. Please use with care.
|
||||
### Grounding Models (Required)
|
||||
For optimal performance, we recommend [UI-TARS-1.5-7B](https://huggingface.co/ByteDance-Seed/UI-TARS-1.5-7B) hosted on Hugging Face Inference Endpoints or another provider. See [Hugging Face Inference Endpoints](https://huggingface.co/learn/cookbook/en/enterprise_dedicated_endpoints) for setup instructions.
|
||||
|
||||
## 🚀 Usage
|
||||
|
||||
|
||||
> **Note**: Our best configuration uses Claude 3.7 with extended thinking and UI-TARS-72B-DPO. If you are unable to run UI-TARS-72B-DPO due to resource constraints, UI-TARS-7B-DPO can be used as a lighter alternative with minimal performance degradation.
|
||||
> ⚡️ **Recommended Setup:**
|
||||
> For the best configuration, we recommend using **OpenAI o3-2025-04-16** as the main model, paired with **UI-TARS-1.5-7B** for grounding.
|
||||
|
||||
|
||||
### CLI
|
||||
|
||||
Run Agent S2 with a specific model (default is `gpt-4o`):
|
||||
|
||||
```sh
|
||||
agent_s2 \
|
||||
--provider "anthropic" \
|
||||
--model "claude-3-7-sonnet-20250219" \
|
||||
--grounding_model_provider "anthropic" \
|
||||
--grounding_model "claude-3-7-sonnet-20250219" \
|
||||
```
|
||||
|
||||
Or use a custom endpoint:
|
||||
Run Agent S2.5 with the required parameters:
|
||||
|
||||
```bash
|
||||
agent_s2 \
|
||||
--provider "anthropic" \
|
||||
--model "claude-3-7-sonnet-20250219" \
|
||||
--endpoint_provider "huggingface" \
|
||||
--endpoint_url "<endpoint_url>/v1/"
|
||||
agent_s \
|
||||
--provider openai \
|
||||
--model o3-2025-04-16 \
|
||||
--ground_provider huggingface \
|
||||
--ground_url http://localhost:8080 \
|
||||
--ground_model ui-tars-1.5-7b \
|
||||
--grounding_width 1920 \
|
||||
--grounding_height 1080
|
||||
```
|
||||
|
||||
#### Main Model Settings
|
||||
- **`--provider`**, **`--model`**
|
||||
- Purpose: Specifies the main generation model
|
||||
- Supports: all model providers in [models.md](models.md)
|
||||
- Default: `--provider "anthropic" --model "claude-3-7-sonnet-20250219"`
|
||||
- **`--model_url`**, **`--model_api_key`**
|
||||
- Purpose: Specifies the custom endpoint for the main generation model and your API key
|
||||
- Note: These are optional. If not specified, `gui-agents` will default to your environment variables for the URL and API key.
|
||||
- Supports: all model providers in [models.md](models.md)
|
||||
- Default: None
|
||||
#### Required Parameters
|
||||
- **`--provider`**: Main generation model provider (e.g., openai, anthropic, etc.) - Default: "openai"
|
||||
- **`--model`**: Main generation model name (e.g., o3-2025-04-16) - Default: "o3-2025-04-16"
|
||||
- **`--ground_provider`**: The provider for the grounding model - **Required**
|
||||
- **`--ground_url`**: The URL of the grounding model - **Required**
|
||||
- **`--ground_model`**: The model name for the grounding model - **Required**
|
||||
- **`--grounding_width`**: Width of the output coordinate resolution from the grounding model - **Required**
|
||||
- **`--grounding_height`**: Height of the output coordinate resolution from the grounding model - **Required**
|
||||
|
||||
#### Grounding Configuration Options
|
||||
#### Grounding Model Dimensions
|
||||
The grounding width and height should match the output coordinate resolution of your grounding model:
|
||||
- **UI-TARS-1.5-7B**: Use `--grounding_width 1920 --grounding_height 1080`
|
||||
- **UI-TARS-72B**: Use `--grounding_width 1000 --grounding_height 1000`
|
||||
|
||||
You can use either Configuration 1 or Configuration 2:
|
||||
|
||||
##### **(Default) Configuration 1: API-Based Models**
|
||||
- **`--grounding_model_provider`**, **`--grounding_model`**
|
||||
- Purpose: Specifies the model for visual grounding (coordinate prediction)
|
||||
- Supports: all model providers in [models.md](models.md)
|
||||
- Default: `--grounding_model_provider "anthropic" --grounding_model "claude-3-7-sonnet-20250219"`
|
||||
- ❗**Important**❗ **`--grounding_model_resize_width`**
|
||||
- Purpose: Some API providers automatically rescale images. Therefore, the generated (x, y) will be relative to the rescaled image dimensions, instead of the original image dimensions.
|
||||
- Supports: [Anthropic rescaling](https://docs.anthropic.com/en/docs/build-with-claude/vision#)
|
||||
- Tips: If your grounding is inaccurate even for very simple queries, double check your rescaling width is correct for your machine's resolution.
|
||||
- Default: `--grounding_model_resize_width 1366` (Anthropic)
|
||||
|
||||
##### **Configuration 2: Custom Endpoint**
|
||||
- **`--endpoint_provider`**
|
||||
- Purpose: Specifies the endpoint provider
|
||||
- Supports: HuggingFace TGI, vLLM, Open Router
|
||||
- Default: None
|
||||
|
||||
- **`--endpoint_url`**
|
||||
- Purpose: The URL for your custom endpoint
|
||||
- Default: None
|
||||
|
||||
- **`--endpoint_api_key`**
|
||||
- Purpose: Your API key for your custom endpoint
|
||||
- Note: This is optional. If not specified, `gui-agents` will default to your environment variables for the API key.
|
||||
- Default: None
|
||||
|
||||
> **Note**: Configuration 2 takes precedence over Configuration 1.
|
||||
|
||||
This will show a user query prompt where you can enter your query and interact with Agent S2. You can use any model from the list of supported models in [models.md](models.md).
|
||||
#### Optional Parameters
|
||||
- **`--model_url`**: Custom API URL for main generation model - Default: ""
|
||||
- **`--model_api_key`**: API key for main generation model - Default: ""
|
||||
- **`--ground_api_key`**: API key for grounding model endpoint - Default: ""
|
||||
- **`--max_trajectory_length`**: Maximum number of image turns to keep in trajectory - Default: 8
|
||||
- **`--enable_reflection`**: Enable reflection agent to assist the worker agent - Default: True
|
||||
|
||||
### `gui_agents` SDK
|
||||
|
||||
First, we import the necessary modules. `AgentS2` is the main agent class for Agent S2. `OSWorldACI` is our grounding agent that translates agent actions into executable python code.
|
||||
First, we import the necessary modules. `AgentS2_5` is the main agent class for Agent S2.5. `OSWorldACI` is our grounding agent that translates agent actions into executable python code.
|
||||
```python
|
||||
import pyautogui
|
||||
import io
|
||||
from gui_agents.s2.agents.agent_s import AgentS2
|
||||
from gui_agents.s2.agents.grounding import OSWorldACI
|
||||
from gui_agents.s2_5.agents.agent_s import AgentS2_5
|
||||
from gui_agents.s2_5.agents.grounding import OSWorldACI
|
||||
|
||||
# Load in your API keys.
|
||||
from dotenv import load_dotenv
|
||||
@@ -268,7 +189,7 @@ load_dotenv()
|
||||
current_platform = "linux" # "darwin", "windows"
|
||||
```
|
||||
|
||||
Next, we define our engine parameters. `engine_params` is used for the main agent, and `engine_params_for_grounding` is for grounding. For `engine_params_for_grounding`, we support the Claude, GPT series, and Hugging Face Inference Endpoints.
|
||||
Next, we define our engine parameters. `engine_params` is used for the main agent, and `engine_params_for_grounding` is for grounding. For `engine_params_for_grounding`, we support custom endpoints like HuggingFace TGI, vLLM, and Open Router.
|
||||
|
||||
```python
|
||||
engine_params = {
|
||||
@@ -278,50 +199,45 @@ engine_params = {
|
||||
"api_key": model_api_key, # Optional
|
||||
}
|
||||
|
||||
# Grounding Configuration 1: Load the grounding engine from an API based model
|
||||
grounding_model_provider = "<your_grounding_model_provider>"
|
||||
grounding_model = "<your_grounding_model>"
|
||||
grounding_model_resize_width = 1366
|
||||
screen_width, screen_height = pyautogui.size()
|
||||
# Load the grounding engine from a custom endpoint
|
||||
ground_provider = "<your_ground_provider>"
|
||||
ground_url = "<your_ground_url>"
|
||||
ground_model = "<your_ground_model>"
|
||||
ground_api_key = "<your_ground_api_key>"
|
||||
|
||||
# Set grounding dimensions based on your model's output coordinate resolution
|
||||
# UI-TARS-1.5-7B: grounding_width=1920, grounding_height=1080
|
||||
# UI-TARS-72B: grounding_width=1000, grounding_height=1000
|
||||
grounding_width = 1920 # Width of output coordinate resolution
|
||||
grounding_height = 1080 # Height of output coordinate resolution
|
||||
|
||||
engine_params_for_grounding = {
|
||||
"engine_type": grounding_model_provider,
|
||||
"model": grounding_model,
|
||||
"grounding_width": grounding_model_resize_width,
|
||||
"grounding_height": screen_height
|
||||
* grounding_model_resize_width
|
||||
/ screen_width,
|
||||
}
|
||||
|
||||
# Grounding Configuration 2: Load the grounding engine from a HuggingFace TGI endpoint
|
||||
endpoint_provider = "<your_endpoint_provider>"
|
||||
endpoint_url = "<your_endpoint_url>"
|
||||
endpoint_api_key = "<your_api_key>"
|
||||
|
||||
engine_params_for_grounding = {
|
||||
"engine_type": endpoint_provider,
|
||||
"base_url": endpoint_url,
|
||||
"api_key": endpoint_api_key, # Optional
|
||||
"engine_type": ground_provider,
|
||||
"model": ground_model,
|
||||
"base_url": ground_url,
|
||||
"api_key": ground_api_key, # Optional
|
||||
"grounding_width": grounding_width,
|
||||
"grounding_height": grounding_height,
|
||||
}
|
||||
```
|
||||
|
||||
Then, we define our grounding agent and Agent S2.
|
||||
Then, we define our grounding agent and Agent S2.5.
|
||||
|
||||
```python
|
||||
grounding_agent = OSWorldACI(
|
||||
platform=current_platform,
|
||||
engine_params_for_generation=engine_params,
|
||||
engine_params_for_grounding=engine_params_for_grounding
|
||||
engine_params_for_grounding=engine_params_for_grounding,
|
||||
width=1920, # Optional: screen width
|
||||
height=1080 # Optional: screen height
|
||||
)
|
||||
|
||||
agent = AgentS2(
|
||||
engine_params,
|
||||
grounding_agent,
|
||||
platform=current_platform,
|
||||
action_space="pyautogui",
|
||||
observation_type="screenshot",
|
||||
search_engine="Perplexica", # Assuming you have set up Perplexica.
|
||||
embedding_engine_type="openai" # Supports "gemini", "openai"
|
||||
agent = AgentS2_5(
|
||||
engine_params,
|
||||
grounding_agent,
|
||||
platform=current_platform,
|
||||
max_trajectory_length=8, # Optional: maximum image turns to keep
|
||||
enable_reflection=True # Optional: enable reflection agent
|
||||
)
|
||||
```
|
||||
|
||||
@@ -344,34 +260,15 @@ info, action = agent.predict(instruction=instruction, observation=obs)
|
||||
exec(action[0])
|
||||
```
|
||||
|
||||
Refer to `gui_agents/s2/cli_app.py` for more details on how the inference loop works.
|
||||
|
||||
#### Downloading the Knowledge Base
|
||||
|
||||
Agent S2 uses a knowledge base that continually updates with new knowledge during inference. The knowledge base is initially downloaded when initializing `AgentS2`. The knowledge base is stored as assets under our [GitHub Releases](https://github.com/simular-ai/Agent-S/releases). The `AgentS2` initialization will only download the knowledge base for your specified platform and agent version (e.g s1, s2). If you'd like to download the knowledge base programmatically, you can use the following code:
|
||||
|
||||
```python
|
||||
download_kb_data(
|
||||
version="s2",
|
||||
release_tag="v0.2.2",
|
||||
download_dir="kb_data",
|
||||
platform="linux" # "darwin", "windows"
|
||||
)
|
||||
```
|
||||
|
||||
This will download Agent S2's knowledge base for Linux from release tag `v0.2.2` to the `kb_data` directory. Refer to our [GitHub Releases](https://github.com/simular-ai/Agent-S/releases) or release tags that include the knowledge bases.
|
||||
Refer to `gui_agents/s2_5/cli_app.py` for more details on how the inference loop works.
|
||||
|
||||
### OSWorld
|
||||
|
||||
To deploy Agent S2 in OSWorld, follow the [OSWorld Deployment instructions](osworld_setup/s2/OSWorld.md).
|
||||
|
||||
### WindowsAgentArena
|
||||
|
||||
To deploy Agent S2 in WindowsAgentArena, follow the [WindowsAgentArena Deployment Instructions](WAA_setup.md).
|
||||
To deploy Agent S2.5 in OSWorld, follow the [OSWorld Deployment instructions](osworld_setup/s2_5/OSWorld.md).
|
||||
|
||||
## 💬 Citations
|
||||
|
||||
If you find this codebase useful, please cite
|
||||
If you find this codebase useful, please cite:
|
||||
|
||||
```
|
||||
@misc{Agent-S2,
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
import logging
|
||||
import platform
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from gui_agents.s2_5.agents.grounding import ACI
|
||||
from gui_agents.s2_5.agents.worker import Worker
|
||||
|
||||
logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
|
||||
class UIAgent:
|
||||
"""Base class for UI automation agents"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_params: Dict,
|
||||
grounding_agent: ACI,
|
||||
platform: str = platform.system().lower(),
|
||||
):
|
||||
"""Initialize UIAgent
|
||||
|
||||
Args:
|
||||
engine_params: Configuration parameters for the LLM engine
|
||||
grounding_agent: Instance of ACI class for UI interaction
|
||||
platform: Operating system platform (macos, linux, windows)
|
||||
"""
|
||||
self.engine_params = engine_params
|
||||
self.grounding_agent = grounding_agent
|
||||
self.platform = platform
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset agent state"""
|
||||
pass
|
||||
|
||||
def predict(self, instruction: str, observation: Dict) -> Tuple[Dict, List[str]]:
|
||||
"""Generate next action prediction
|
||||
|
||||
Args:
|
||||
instruction: Natural language instruction
|
||||
observation: Current UI state observation
|
||||
|
||||
Returns:
|
||||
Tuple containing agent info dictionary and list of actions
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class AgentS2_5(UIAgent):
|
||||
"""Agent that uses no hierarchy for less inference time"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_params: Dict,
|
||||
grounding_agent: ACI,
|
||||
platform: str = platform.system().lower(),
|
||||
max_trajectory_length: int = 8,
|
||||
enable_reflection: bool = True,
|
||||
):
|
||||
"""Initialize a minimalist AgentS2 without hierarchy
|
||||
|
||||
Args:
|
||||
engine_params: Configuration parameters for the LLM engine
|
||||
grounding_agent: Instance of ACI class for UI interaction
|
||||
platform: Operating system platform (darwin, linux, windows)
|
||||
max_trajectory_length: Maximum number of image turns to keep
|
||||
enable_reflection: Creates a reflection agent to assist the worker agent
|
||||
"""
|
||||
|
||||
super().__init__(
|
||||
engine_params, grounding_agent, platform
|
||||
)
|
||||
self.max_trajectory_length = max_trajectory_length
|
||||
self.enable_reflection = enable_reflection
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset agent state and initialize components"""
|
||||
self.executor = Worker(
|
||||
engine_params=self.engine_params,
|
||||
grounding_agent=self.grounding_agent,
|
||||
platform=self.platform,
|
||||
max_trajectory_length=self.max_trajectory_length,
|
||||
enable_reflection=self.enable_reflection,
|
||||
)
|
||||
|
||||
def predict(self, instruction: str, observation: Dict) -> Tuple[Dict, List[str]]:
|
||||
# Initialize the three info dictionaries
|
||||
executor_info, actions = self.executor.generate_next_action(
|
||||
instruction=instruction, obs=observation
|
||||
)
|
||||
|
||||
# concatenate the three info dictionaries
|
||||
info = {
|
||||
**{
|
||||
k: v
|
||||
for d in [executor_info or {}]
|
||||
for k, v in d.items()
|
||||
}
|
||||
}
|
||||
|
||||
return info, actions
|
||||
@@ -0,0 +1,617 @@
|
||||
import ast
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import pytesseract
|
||||
from PIL import Image
|
||||
from pytesseract import Output
|
||||
|
||||
from gui_agents.s2_5.memory.procedural_memory import PROCEDURAL_MEMORY
|
||||
from gui_agents.s2_5.core.mllm import LMMAgent
|
||||
from gui_agents.s2_5.utils.common_utils import (
|
||||
call_llm_safe,
|
||||
parse_single_code_from_string,
|
||||
)
|
||||
|
||||
|
||||
class ACI:
|
||||
def __init__(self):
|
||||
self.notes: List[str] = []
|
||||
|
||||
|
||||
# Agent action decorator
|
||||
def agent_action(func):
|
||||
func.is_agent_action = True
|
||||
return func
|
||||
|
||||
|
||||
UBUNTU_APP_SETUP = f"""import subprocess;
|
||||
import difflib;
|
||||
import pyautogui;
|
||||
pyautogui.press('escape');
|
||||
time.sleep(0.5);
|
||||
output = subprocess.check_output(['wmctrl', '-lx']);
|
||||
output = output.decode('utf-8').splitlines();
|
||||
window_titles = [line.split(None, 4)[2] for line in output];
|
||||
closest_matches = difflib.get_close_matches('APP_NAME', window_titles, n=1, cutoff=0.1);
|
||||
if closest_matches:
|
||||
closest_match = closest_matches[0];
|
||||
for line in output:
|
||||
if closest_match in line:
|
||||
window_id = line.split()[0]
|
||||
break;
|
||||
subprocess.run(['wmctrl', '-ia', window_id])
|
||||
subprocess.run(['wmctrl', '-ir', window_id, '-b', 'add,maximized_vert,maximized_horz'])
|
||||
"""
|
||||
|
||||
|
||||
SET_CELL_VALUES_CMD = """import uno
|
||||
import subprocess
|
||||
|
||||
def identify_document_type(component):
|
||||
if component.supportsService("com.sun.star.sheet.SpreadsheetDocument"):
|
||||
return "Calc"
|
||||
|
||||
if component.supportsService("com.sun.star.text.TextDocument"):
|
||||
return "Writer"
|
||||
|
||||
if component.supportsService("com.sun.star.sheet.PresentationDocument"):
|
||||
return "Impress"
|
||||
|
||||
return None
|
||||
|
||||
def cell_ref_to_indices(cell_ref):
|
||||
column_letters = ''.join(filter(str.isalpha, cell_ref))
|
||||
row_number = ''.join(filter(str.isdigit, cell_ref))
|
||||
|
||||
col = sum((ord(char.upper()) - ord('A') + 1) * (26**idx) for idx, char in enumerate(reversed(column_letters))) - 1
|
||||
row = int(row_number) - 1
|
||||
return col, row
|
||||
|
||||
def set_cell_values(new_cell_values: dict[str, str], app_name: str = "Untitled 1", sheet_name: str = "Sheet1"):
|
||||
new_cell_values_idx = {{}}
|
||||
for k, v in new_cell_values.items():
|
||||
try:
|
||||
col, row = cell_ref_to_indices(k)
|
||||
except:
|
||||
col = row = None
|
||||
|
||||
if col is not None and row is not None:
|
||||
new_cell_values_idx[(col, row)] = v
|
||||
|
||||
# Clean up previous TCP connections.
|
||||
subprocess.run(
|
||||
'echo \"osworld-public-evaluation\" | sudo -S ss --kill --tcp state TIME-WAIT sport = :2002',
|
||||
shell=True,
|
||||
check=True,
|
||||
text=True,
|
||||
capture_output=True
|
||||
)
|
||||
|
||||
# Dynamically allow soffice to listen on port 2002.
|
||||
subprocess.run(
|
||||
[
|
||||
"soffice",
|
||||
"--accept=socket,host=localhost,port=2002;urp;StarOffice.Service"
|
||||
]
|
||||
)
|
||||
|
||||
local_context = uno.getComponentContext()
|
||||
resolver = local_context.ServiceManager.createInstanceWithContext(
|
||||
"com.sun.star.bridge.UnoUrlResolver", local_context
|
||||
)
|
||||
context = resolver.resolve(
|
||||
f"uno:socket,host=localhost,port=2002;urp;StarOffice.ComponentContext"
|
||||
)
|
||||
desktop = context.ServiceManager.createInstanceWithContext(
|
||||
"com.sun.star.frame.Desktop", context
|
||||
)
|
||||
|
||||
# Collect all LibreOffice-related opened windows.
|
||||
documents = []
|
||||
for i, component in enumerate(desktop.Components):
|
||||
title = component.Title
|
||||
doc_type = identify_document_type(component)
|
||||
documents.append((i, component, title, doc_type))
|
||||
|
||||
# Find the LibreOffice Calc app and the sheet of interest.
|
||||
spreadsheet = [doc for doc in documents if doc[3] == "Calc"]
|
||||
selected_spreadsheet = [doc for doc in spreadsheet if doc[2] == app_name]
|
||||
if spreadsheet:
|
||||
try:
|
||||
if selected_spreadsheet:
|
||||
spreadsheet = selected_spreadsheet[0][1]
|
||||
else:
|
||||
spreadsheet = spreadsheet[0][1]
|
||||
|
||||
sheet = spreadsheet.Sheets.getByName(sheet_name)
|
||||
except:
|
||||
raise ValueError(f"Could not find sheet {{sheet_name}} in {{app_name}}.")
|
||||
|
||||
for (col, row), value in new_cell_values_idx.items():
|
||||
cell = sheet.getCellByPosition(col, row)
|
||||
|
||||
# Set the cell value.
|
||||
if isinstance(value, (int, float)):
|
||||
cell.Value = value
|
||||
elif isinstance(value, str):
|
||||
if value.startswith("="):
|
||||
cell.Formula = value
|
||||
else:
|
||||
cell.String = value
|
||||
elif isinstance(value, bool):
|
||||
cell.Value = 1 if value else 0
|
||||
elif value is None:
|
||||
cell.clearContents(0)
|
||||
else:
|
||||
raise ValueError(f"Unsupported cell value type: {{type(value)}}")
|
||||
|
||||
else:
|
||||
raise ValueError(f"Could not find LibreOffice Calc app corresponding to {{app_name}}.")
|
||||
|
||||
set_cell_values(new_cell_values={cell_values}, app_name="{app_name}", sheet_name="{sheet_name}")
|
||||
"""
|
||||
|
||||
|
||||
# ACI primitives are parameterized by description, and coordinate generation uses a pretrained grounding model
|
||||
class OSWorldACI(ACI):
|
||||
def __init__(
|
||||
self,
|
||||
platform: str,
|
||||
engine_params_for_generation: Dict,
|
||||
engine_params_for_grounding: Dict,
|
||||
width: int = 1920,
|
||||
height: int = 1080,
|
||||
):
|
||||
self.platform = (
|
||||
platform # Dictates how the switch_applications agent action works.
|
||||
)
|
||||
|
||||
# Configure scaling
|
||||
self.width = width
|
||||
self.height = height
|
||||
|
||||
# Maintain state for save_to_knowledge
|
||||
self.notes = []
|
||||
|
||||
# Coordinates used during ACI execution
|
||||
self.coords1 = None
|
||||
self.coords2 = None
|
||||
|
||||
# Configure the visual grounding model responsible for coordinate generation
|
||||
self.grounding_model = LMMAgent(engine_params_for_grounding)
|
||||
self.engine_params_for_grounding = engine_params_for_grounding
|
||||
|
||||
# Configure text grounding agent
|
||||
self.text_span_agent = LMMAgent(
|
||||
engine_params=engine_params_for_generation,
|
||||
system_prompt=PROCEDURAL_MEMORY.PHRASE_TO_WORD_COORDS_PROMPT,
|
||||
)
|
||||
|
||||
# Given the state and worker's referring expression, use the grounding model to generate (x,y)
|
||||
def generate_coords(self, ref_expr: str, obs: Dict) -> List[int]:
|
||||
|
||||
# Reset the grounding model state
|
||||
self.grounding_model.reset()
|
||||
|
||||
# Configure the context, UI-TARS demo does not use system prompt
|
||||
prompt = f"Query:{ref_expr}\nOutput only the coordinate of one point in your response.\n"
|
||||
self.grounding_model.add_message(
|
||||
text_content=prompt, image_content=obs["screenshot"], put_text_last=True
|
||||
)
|
||||
|
||||
# Generate and parse coordinates
|
||||
response = call_llm_safe(self.grounding_model)
|
||||
print("RAW GROUNDING MODEL RESPONSE:", response)
|
||||
numericals = re.findall(r"\d+", response)
|
||||
assert len(numericals) >= 2
|
||||
return [int(numericals[0]), int(numericals[1])]
|
||||
|
||||
# Calls pytesseract to generate word level bounding boxes for text grounding
|
||||
def get_ocr_elements(self, b64_image_data: str) -> Tuple[str, List]:
|
||||
image = Image.open(BytesIO(b64_image_data))
|
||||
image_data = pytesseract.image_to_data(image, output_type=Output.DICT)
|
||||
|
||||
# Clean text by removing leading and trailing spaces and non-alphabetical characters, but keeping punctuation
|
||||
for i, word in enumerate(image_data["text"]):
|
||||
image_data["text"][i] = re.sub(
|
||||
r"^[^a-zA-Z\s.,!?;:\-\+]+|[^a-zA-Z\s.,!?;:\-\+]+$", "", word
|
||||
)
|
||||
|
||||
ocr_elements = []
|
||||
ocr_table = "Text Table:\nWord id\tText\n"
|
||||
# Obtain the <id, text, group number, word number> for each valid element
|
||||
grouping_map = defaultdict(list)
|
||||
ocr_id = 0
|
||||
for i in range(len(image_data["text"])):
|
||||
block_num = image_data["block_num"][i]
|
||||
if image_data["text"][i]:
|
||||
grouping_map[block_num].append(image_data["text"][i])
|
||||
ocr_table += f"{ocr_id}\t{image_data['text'][i]}\n"
|
||||
ocr_elements.append(
|
||||
{
|
||||
"id": ocr_id,
|
||||
"text": image_data["text"][i],
|
||||
"group_num": block_num,
|
||||
"word_num": len(grouping_map[block_num]),
|
||||
"left": image_data["left"][i],
|
||||
"top": image_data["top"][i],
|
||||
"width": image_data["width"][i],
|
||||
"height": image_data["height"][i],
|
||||
}
|
||||
)
|
||||
ocr_id += 1
|
||||
|
||||
return ocr_table, ocr_elements
|
||||
|
||||
# Given the state and worker's text phrase, generate the coords of the first/last word in the phrase
|
||||
def generate_text_coords(
|
||||
self, phrase: str, obs: Dict, alignment: str = ""
|
||||
) -> List[int]:
|
||||
|
||||
ocr_table, ocr_elements = self.get_ocr_elements(obs["screenshot"])
|
||||
|
||||
alignment_prompt = ""
|
||||
if alignment == "start":
|
||||
alignment_prompt = "**Important**: Output the word id of the FIRST word in the provided phrase.\n"
|
||||
elif alignment == "end":
|
||||
alignment_prompt = "**Important**: Output the word id of the LAST word in the provided phrase.\n"
|
||||
|
||||
# Load LLM prompt
|
||||
self.text_span_agent.reset()
|
||||
self.text_span_agent.add_message(
|
||||
alignment_prompt + "Phrase: " + phrase + "\n" + ocr_table, role="user"
|
||||
)
|
||||
self.text_span_agent.add_message(
|
||||
"Screenshot:\n", image_content=obs["screenshot"], role="user"
|
||||
)
|
||||
|
||||
# Obtain the target element
|
||||
response = call_llm_safe(self.text_span_agent)
|
||||
print("TEXT SPAN AGENT RESPONSE:", response)
|
||||
numericals = re.findall(r"\d+", response)
|
||||
if len(numericals) > 0:
|
||||
text_id = int(numericals[-1])
|
||||
else:
|
||||
text_id = 0
|
||||
elem = ocr_elements[text_id]
|
||||
|
||||
# Compute the element coordinates
|
||||
if alignment == "start":
|
||||
coords = [elem["left"], elem["top"] + (elem["height"] // 2)]
|
||||
elif alignment == "end":
|
||||
coords = [elem["left"] + elem["width"], elem["top"] + (elem["height"] // 2)]
|
||||
else:
|
||||
coords = [
|
||||
elem["left"] + (elem["width"] // 2),
|
||||
elem["top"] + (elem["height"] // 2),
|
||||
]
|
||||
return coords
|
||||
|
||||
# Takes a description based action and assigns the coordinates for any coordinate based action
|
||||
# Raises an error if function can't be parsed
|
||||
def assign_coordinates(self, plan: str, obs: Dict):
|
||||
|
||||
# Reset coords from previous action generation
|
||||
self.coords1, self.coords2 = None, None
|
||||
|
||||
try:
|
||||
# Extract the function name and args
|
||||
action = parse_single_code_from_string(plan.split("Grounded Action")[-1])
|
||||
function_name = re.match(r"(\w+\.\w+)\(", action).group(1)
|
||||
args = self.parse_function_args(action)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error in parsing grounded action: {e}") from e
|
||||
|
||||
# arg0 is a description
|
||||
if (
|
||||
function_name in ["agent.click", "agent.type", "agent.scroll"]
|
||||
and len(args) >= 1
|
||||
and args[0] != None
|
||||
):
|
||||
self.coords1 = self.generate_coords(args[0], obs)
|
||||
# arg0 and arg1 are descriptions
|
||||
elif function_name == "agent.drag_and_drop" and len(args) >= 2:
|
||||
self.coords1 = self.generate_coords(args[0], obs)
|
||||
self.coords2 = self.generate_coords(args[1], obs)
|
||||
# arg0 and arg1 are text phrases
|
||||
elif function_name == "agent.highlight_text_span" and len(args) >= 2:
|
||||
self.coords1 = self.generate_text_coords(args[0], obs, alignment="start")
|
||||
self.coords2 = self.generate_text_coords(args[1], obs, alignment="end")
|
||||
|
||||
# Resize from grounding model dim into OSWorld dim (1920 * 1080)
|
||||
def resize_coordinates(self, coordinates: List[int]) -> List[int]:
|
||||
grounding_width = self.engine_params_for_grounding["grounding_width"]
|
||||
grounding_height = self.engine_params_for_grounding["grounding_height"]
|
||||
|
||||
return [
|
||||
round(coordinates[0] * self.width / grounding_width),
|
||||
round(coordinates[1] * self.height / grounding_height),
|
||||
]
|
||||
|
||||
# Given a generated ACI function, returns a list of argument values, where descriptions are at the front of the list
|
||||
def parse_function_args(self, function: str) -> List[str]:
|
||||
tree = ast.parse(function)
|
||||
call_node = tree.body[0].value
|
||||
|
||||
def safe_eval(node):
|
||||
if isinstance(
|
||||
node, ast.Constant
|
||||
): # Handles literals like numbers, strings, etc.
|
||||
return node.value
|
||||
else:
|
||||
return ast.unparse(node) # Return as a string if not a literal
|
||||
|
||||
positional_args = [safe_eval(arg) for arg in call_node.args]
|
||||
keyword_args = {kw.arg: safe_eval(kw.value) for kw in call_node.keywords}
|
||||
|
||||
res = []
|
||||
|
||||
for key, val in keyword_args.items():
|
||||
if "description" in key:
|
||||
res.append(val)
|
||||
|
||||
for arg in positional_args:
|
||||
res.append(arg)
|
||||
|
||||
return res
|
||||
|
||||
@agent_action
|
||||
def click(
|
||||
self,
|
||||
element_description: str,
|
||||
num_clicks: int = 1,
|
||||
button_type: str = "left",
|
||||
hold_keys: List = [],
|
||||
):
|
||||
"""Click on the element
|
||||
Args:
|
||||
element_description:str, a detailed descriptions of which element to click on. This description should be at least a full sentence.
|
||||
num_clicks:int, number of times to click the element
|
||||
button_type:str, which mouse button to press can be "left", "middle", or "right"
|
||||
hold_keys:List, list of keys to hold while clicking
|
||||
"""
|
||||
x, y = self.resize_coordinates(self.coords1)
|
||||
command = "import pyautogui; "
|
||||
|
||||
# TODO: specified duration?
|
||||
for k in hold_keys:
|
||||
command += f"pyautogui.keyDown({repr(k)}); "
|
||||
command += f"""import pyautogui; pyautogui.click({x}, {y}, clicks={num_clicks}, button={repr(button_type)}); """
|
||||
for k in hold_keys:
|
||||
command += f"pyautogui.keyUp({repr(k)}); "
|
||||
# Return pyautoguicode to click on the element
|
||||
return command
|
||||
|
||||
@agent_action
|
||||
def switch_applications(self, app_code):
|
||||
"""Switch to a different application that is already open
|
||||
Args:
|
||||
app_code:str the code name of the application to switch to from the provided list of open applications
|
||||
"""
|
||||
if self.platform == "darwin":
|
||||
return f"import pyautogui; import time; pyautogui.hotkey('command', 'space', interval=0.5); pyautogui.typewrite({repr(app_code)}); pyautogui.press('enter'); time.sleep(1.0)"
|
||||
elif self.platform == "linux":
|
||||
return UBUNTU_APP_SETUP.replace("APP_NAME", app_code)
|
||||
elif self.platform == "windows":
|
||||
return f"import pyautogui; import time; pyautogui.hotkey('win', 'd', interval=0.5); pyautogui.typewrite({repr(app_code)}); pyautogui.press('enter'); time.sleep(1.0)"
|
||||
|
||||
@agent_action
|
||||
def open(self, app_or_filename: str):
|
||||
"""Open any application or file with name app_or_filename. Use this action to open applications or files on the desktop, do not open manually.
|
||||
Args:
|
||||
app_or_filename:str, the name of the application or filename to open
|
||||
"""
|
||||
if self.platform == "linux":
|
||||
return f"import pyautogui; pyautogui.hotkey('win'); time.sleep(0.5); pyautogui.write({repr(app_or_filename)}); time.sleep(1.0); pyautogui.hotkey('enter'); time.sleep(0.5)"
|
||||
elif self.platform == "darwin":
|
||||
return f"import pyautogui; import time; pyautogui.hotkey('command', 'space', interval=0.5); pyautogui.typewrite({repr(app_or_filename)}); pyautogui.press('enter'); time.sleep(1.0)"
|
||||
|
||||
@agent_action
|
||||
def type(
|
||||
self,
|
||||
element_description: Optional[str] = None,
|
||||
text: str = "",
|
||||
overwrite: bool = False,
|
||||
enter: bool = False,
|
||||
):
|
||||
"""Type text into a specific element
|
||||
Args:
|
||||
element_description:str, a detailed description of which element to enter text in. This description should be at least a full sentence.
|
||||
text:str, the text to type
|
||||
overwrite:bool, Assign it to True if the text should overwrite the existing text, otherwise assign it to False. Using this argument clears all text in an element.
|
||||
enter:bool, Assign it to True if the enter key should be pressed after typing the text, otherwise assign it to False.
|
||||
"""
|
||||
|
||||
select_mod = "command" if self.platform == "darwin" else "ctrl"
|
||||
|
||||
if self.coords1 is not None:
|
||||
# If a node is found, retrieve its coordinates and size
|
||||
# Start typing at the center of the element
|
||||
|
||||
x, y = self.resize_coordinates(self.coords1)
|
||||
|
||||
command = "import pyautogui; "
|
||||
command += f"pyautogui.click({x}, {y}); "
|
||||
|
||||
if overwrite:
|
||||
command += (
|
||||
f"pyautogui.hotkey({repr(select_mod)}, 'a'); "
|
||||
"pyautogui.press('backspace'); "
|
||||
)
|
||||
|
||||
command += f"pyautogui.write({repr(text)}); "
|
||||
|
||||
if enter:
|
||||
command += "pyautogui.press('enter'); "
|
||||
else:
|
||||
# If no element is found, start typing at the current cursor location
|
||||
command = "import pyautogui; "
|
||||
|
||||
if overwrite:
|
||||
command += (
|
||||
f"pyautogui.hotkey({repr(select_mod)}, 'a'); "
|
||||
"pyautogui.press('backspace'); "
|
||||
)
|
||||
|
||||
command += f"pyautogui.write({repr(text)}); "
|
||||
|
||||
if enter:
|
||||
command += "pyautogui.press('enter'); "
|
||||
|
||||
return command
|
||||
|
||||
@agent_action
|
||||
def save_to_knowledge(self, text: List[str]):
|
||||
"""Save facts, elements, texts, etc. to a long-term knowledge bank for reuse during this task. Can be used for copy-pasting text, saving elements, etc.
|
||||
Args:
|
||||
text:List[str] the text to save to the knowledge
|
||||
"""
|
||||
self.notes.extend(text)
|
||||
return """WAIT"""
|
||||
|
||||
@agent_action
|
||||
def drag_and_drop(
|
||||
self, starting_description: str, ending_description: str, hold_keys: List = []
|
||||
):
|
||||
"""Drag from the starting description to the ending description
|
||||
Args:
|
||||
starting_description:str, a very detailed description of where to start the drag action. This description should be at least a full sentence.
|
||||
ending_description:str, a very detailed description of where to end the drag action. This description should be at least a full sentence.
|
||||
hold_keys:List list of keys to hold while dragging
|
||||
"""
|
||||
x1, y1 = self.resize_coordinates(self.coords1)
|
||||
x2, y2 = self.resize_coordinates(self.coords2)
|
||||
|
||||
command = "import pyautogui; "
|
||||
|
||||
command += f"pyautogui.moveTo({x1}, {y1}); "
|
||||
# TODO: specified duration?
|
||||
for k in hold_keys:
|
||||
command += f"pyautogui.keyDown({repr(k)}); "
|
||||
command += f"pyautogui.dragTo({x2}, {y2}, duration=1., button='left'); pyautogui.mouseUp(); "
|
||||
for k in hold_keys:
|
||||
command += f"pyautogui.keyUp({repr(k)}); "
|
||||
|
||||
# Return pyautoguicode to drag and drop the elements
|
||||
|
||||
return command
|
||||
|
||||
@agent_action
|
||||
def highlight_text_span(
|
||||
self, starting_phrase: str, ending_phrase: str, button: str = "left"
|
||||
):
|
||||
"""Highlight a text span between a provided starting phrase and ending phrase. Use this to highlight words, lines, and paragraphs.
|
||||
Args:
|
||||
starting_phrase:str, the phrase that denotes the start of the text span you want to highlight. If you only want to highlight one word, just pass in that single word.
|
||||
ending_phrase:str, the phrase that denotes the end of the text span you want to highlight. If you only want to highlight one word, just pass in that single word.
|
||||
button:str, the button to use to highlight the text span. Defaults to "left". Can be "left", "right", or "middle".
|
||||
"""
|
||||
|
||||
x1, y1 = self.coords1
|
||||
x2, y2 = self.coords2
|
||||
|
||||
command = "import pyautogui; "
|
||||
command += f"pyautogui.moveTo({x1}, {y1}); "
|
||||
command += f"pyautogui.dragTo({x2}, {y2}, duration=1., button='{button}'); pyautogui.mouseUp(); "
|
||||
|
||||
# Return pyautoguicode to drag and drop the elements
|
||||
return command
|
||||
|
||||
@agent_action
|
||||
def set_cell_values(
|
||||
self, cell_values: Dict[str, Any], app_name: str, sheet_name: str
|
||||
):
|
||||
"""Use this to set individual cell values in a spreadsheet. For example, setting A2 to "hello" would be done by passing {"A2": "hello"} as cell_values. The sheet must be opened before this command can be used.
|
||||
Args:
|
||||
cell_values: Dict[str, Any], A dictionary of cell values to set in the spreadsheet. The keys are the cell coordinates in the format "A1", "B2", etc.
|
||||
Supported value types include: float, int, string, bool, formulas.
|
||||
app_name: str, The name of the spreadsheet application. For example, "Some_sheet.xlsx".
|
||||
sheet_name: str, The name of the sheet in the spreadsheet. For example, "Sheet1".
|
||||
"""
|
||||
return SET_CELL_VALUES_CMD.format(
|
||||
cell_values=cell_values, app_name=app_name, sheet_name=sheet_name
|
||||
)
|
||||
|
||||
@agent_action
|
||||
def scroll(self, element_description: str, clicks: int, shift: bool = False):
|
||||
"""Scroll the element in the specified direction
|
||||
Args:
|
||||
element_description:str, a very detailed description of which element to enter scroll in. This description should be at least a full sentence.
|
||||
clicks:int, the number of clicks to scroll can be positive (up) or negative (down).
|
||||
shift:bool, whether to use shift+scroll for horizontal scrolling
|
||||
"""
|
||||
|
||||
x, y = self.resize_coordinates(self.coords1)
|
||||
|
||||
if shift:
|
||||
return f"import pyautogui; import time; pyautogui.moveTo({x}, {y}); time.sleep(0.5); pyautogui.hscroll({clicks})"
|
||||
else:
|
||||
return f"import pyautogui; import time; pyautogui.moveTo({x}, {y}); time.sleep(0.5); pyautogui.vscroll({clicks})"
|
||||
|
||||
@agent_action
|
||||
def hotkey(self, keys: List):
|
||||
"""Press a hotkey combination
|
||||
Args:
|
||||
keys:List the keys to press in combination in a list format (e.g. ['ctrl', 'c'])
|
||||
"""
|
||||
# add quotes around the keys
|
||||
keys = [f"'{key}'" for key in keys]
|
||||
return f"import pyautogui; pyautogui.hotkey({', '.join(keys)})"
|
||||
|
||||
@agent_action
|
||||
def hold_and_press(self, hold_keys: List, press_keys: List):
|
||||
"""Hold a list of keys and press a list of keys
|
||||
Args:
|
||||
hold_keys:List, list of keys to hold
|
||||
press_keys:List, list of keys to press in a sequence
|
||||
"""
|
||||
|
||||
press_keys_str = "[" + ", ".join([f"'{key}'" for key in press_keys]) + "]"
|
||||
command = "import pyautogui; "
|
||||
for k in hold_keys:
|
||||
command += f"pyautogui.keyDown({repr(k)}); "
|
||||
command += f"pyautogui.press({press_keys_str}); "
|
||||
for k in hold_keys:
|
||||
command += f"pyautogui.keyUp({repr(k)}); "
|
||||
|
||||
return command
|
||||
|
||||
@agent_action
|
||||
def wait(self, time: float):
|
||||
"""Wait for a specified amount of time
|
||||
Args:
|
||||
time:float the amount of time to wait in seconds
|
||||
"""
|
||||
return f"""import time; time.sleep({time})"""
|
||||
|
||||
@agent_action
|
||||
def done(
|
||||
self,
|
||||
return_value: Optional[Union[Dict, str, List, Tuple, int, float, bool]] = None,
|
||||
):
|
||||
"""End the current task with a success and the required return value"""
|
||||
self.returned_info = return_value
|
||||
return """DONE"""
|
||||
|
||||
@agent_action
|
||||
def fail(self):
|
||||
"""End the current task with a failure, and replan the whole task."""
|
||||
return """FAIL"""
|
||||
|
||||
|
||||
# ACI that supports the worker-only mode: done() and fail() become task scoped instead
|
||||
class OSWorldWorkerOnlyACI(OSWorldACI):
|
||||
@agent_action
|
||||
def done(
|
||||
self,
|
||||
):
|
||||
"""End the current task with a success. Use this when you believe the entire task has been fully completed."""
|
||||
return """DONE"""
|
||||
|
||||
@agent_action
|
||||
def fail(self):
|
||||
"""End the current task with a failure. Use this when you believe the entire task is impossible to complete."""
|
||||
return """FAIL"""
|
||||
@@ -0,0 +1,206 @@
|
||||
import logging
|
||||
import textwrap
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from gui_agents.s2_5.agents.grounding import ACI
|
||||
from gui_agents.s2_5.core.module import BaseModule
|
||||
from gui_agents.s2_5.memory.procedural_memory import PROCEDURAL_MEMORY
|
||||
from gui_agents.s2_5.utils.common_utils import (
|
||||
call_llm_safe,
|
||||
extract_first_agent_function,
|
||||
parse_single_code_from_string,
|
||||
sanitize_code,
|
||||
split_thinking_response,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
|
||||
class Worker(BaseModule):
|
||||
def __init__(
|
||||
self,
|
||||
engine_params: Dict,
|
||||
grounding_agent: ACI,
|
||||
platform: str = "ubuntu",
|
||||
max_trajectory_length: int = 8,
|
||||
enable_reflection: bool = True,
|
||||
):
|
||||
"""
|
||||
Worker receives the main task and generates actions, without the need of hierarchical planning
|
||||
Args:
|
||||
engine_params: Dict
|
||||
Parameters for the multimodal engine
|
||||
grounding_agent: Agent
|
||||
The grounding agent to use
|
||||
platform: str
|
||||
OS platform the agent runs on (darwin, linux, windows)
|
||||
max_trajectory_length: int
|
||||
The amount of images turns to keep
|
||||
enable_reflection: bool
|
||||
Whether to enable reflection
|
||||
"""
|
||||
super().__init__(engine_params, platform)
|
||||
|
||||
self.grounding_agent = grounding_agent
|
||||
self.max_trajectory_length = max_trajectory_length
|
||||
self.enable_reflection = enable_reflection
|
||||
self.temperature = engine_params.get("temperature", 0.0)
|
||||
self.use_thinking = engine_params.get("model", "") in [
|
||||
"claude-3-7-sonnet-20250219"
|
||||
]
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
if self.platform != "linux":
|
||||
skipped_actions = ["set_cell_values"]
|
||||
else:
|
||||
skipped_actions = []
|
||||
|
||||
sys_prompt = PROCEDURAL_MEMORY.construct_simple_worker_procedural_memory(
|
||||
type(self.grounding_agent), skipped_actions=skipped_actions
|
||||
).replace("CURRENT_OS", self.platform)
|
||||
|
||||
self.generator_agent = self._create_agent(sys_prompt)
|
||||
self.reflection_agent = self._create_agent(
|
||||
PROCEDURAL_MEMORY.REFLECTION_ON_TRAJECTORY
|
||||
)
|
||||
|
||||
self.turn_count = 0
|
||||
self.worker_history = []
|
||||
self.reflections = []
|
||||
self.cost_this_turn = 0
|
||||
self.screenshot_inputs = []
|
||||
|
||||
# Flushing strategy dependant on model context limits
|
||||
def flush_messages(self):
|
||||
engine_type = self.engine_params.get("engine_type", "")
|
||||
|
||||
# Flush strategy for long-context models: keep all text, only keep latest images
|
||||
if engine_type in ["anthropic", "openai", "gemini"]:
|
||||
max_images = self.max_trajectory_length
|
||||
for agent in [self.generator_agent, self.reflection_agent]:
|
||||
# keep latest k images
|
||||
img_count = 0
|
||||
for i in range(len(agent.messages) - 1, -1, -1):
|
||||
for j in range(len(agent.messages[i]["content"])):
|
||||
if "image" in agent.messages[i]["content"][j].get("type", ""):
|
||||
img_count += 1
|
||||
if img_count > max_images:
|
||||
del agent.messages[i]["content"][j]
|
||||
|
||||
# Flush strategy for non-long-context models: drop full turns
|
||||
else:
|
||||
# generator msgs are alternating [user, assistant], so 2 per round
|
||||
if len(self.generator_agent.messages) > 2 * self.max_trajectory_length + 1:
|
||||
self.generator_agent.messages.pop(1)
|
||||
self.generator_agent.messages.pop(1)
|
||||
# reflector msgs are all [(user text, user image)], so 1 per round
|
||||
if len(self.reflection_agent.messages) > self.max_trajectory_length + 1:
|
||||
self.reflection_agent.messages.pop(1)
|
||||
|
||||
def generate_next_action(
|
||||
self,
|
||||
instruction: str,
|
||||
obs: Dict,
|
||||
) -> Tuple[Dict, List]:
|
||||
"""
|
||||
Predict the next action(s) based on the current observation.
|
||||
"""
|
||||
agent = self.grounding_agent
|
||||
generator_message = (
|
||||
""
|
||||
if self.turn_count > 0
|
||||
else "The initial screen is provided. No action has been taken yet."
|
||||
)
|
||||
|
||||
# Load the task into the system prompt
|
||||
if self.turn_count == 0:
|
||||
self.generator_agent.add_system_prompt(
|
||||
self.generator_agent.system_prompt.replace(
|
||||
"TASK_DESCRIPTION", instruction
|
||||
)
|
||||
)
|
||||
|
||||
# Get the per-step reflection
|
||||
reflection = None
|
||||
reflection_thoughts = None
|
||||
if self.enable_reflection:
|
||||
# Load the initial message
|
||||
if self.turn_count == 0:
|
||||
text_content = textwrap.dedent(
|
||||
f"""
|
||||
Task Description: {instruction}
|
||||
Current Trajectory below:
|
||||
"""
|
||||
)
|
||||
updated_sys_prompt = (
|
||||
self.reflection_agent.system_prompt + "\n" + text_content
|
||||
)
|
||||
self.reflection_agent.add_system_prompt(updated_sys_prompt)
|
||||
self.reflection_agent.add_message(
|
||||
text_content="The initial screen is provided. No action has been taken yet.",
|
||||
image_content=obs["screenshot"],
|
||||
role="user",
|
||||
)
|
||||
# Load the latest action
|
||||
else:
|
||||
self.reflection_agent.add_message(
|
||||
text_content=self.worker_history[-1],
|
||||
image_content=obs["screenshot"],
|
||||
role="user",
|
||||
)
|
||||
full_reflection = call_llm_safe(
|
||||
self.reflection_agent,
|
||||
temperature=self.temperature,
|
||||
use_thinking=self.use_thinking,
|
||||
)
|
||||
reflection, reflection_thoughts = split_thinking_response(
|
||||
full_reflection
|
||||
)
|
||||
self.reflections.append(reflection)
|
||||
generator_message += f"REFLECTION: You may use this reflection on the previous action and overall trajectory:\n{reflection}\n"
|
||||
logger.info("REFLECTION: %s", reflection)
|
||||
|
||||
# Add finalized message to conversation
|
||||
generator_message += f"\nCurrent Text Buffer = [{','.join(agent.notes)}]\n"
|
||||
self.generator_agent.add_message(
|
||||
generator_message, image_content=obs["screenshot"], role="user"
|
||||
)
|
||||
|
||||
full_plan = call_llm_safe(
|
||||
self.generator_agent,
|
||||
temperature=self.temperature,
|
||||
use_thinking=self.use_thinking,
|
||||
)
|
||||
plan, plan_thoughts = split_thinking_response(full_plan)
|
||||
# NOTE: currently dropping thinking tokens from context
|
||||
self.worker_history.append(plan)
|
||||
logger.info("FULL PLAN:\n %s", full_plan)
|
||||
self.generator_agent.add_message(plan, role="assistant")
|
||||
|
||||
# Use the grounding agent to convert agent_action("desc") into agent_action([x, y])
|
||||
try:
|
||||
agent.assign_coordinates(plan, obs)
|
||||
plan_code = parse_single_code_from_string(plan.split("Grounded Action")[-1])
|
||||
plan_code = sanitize_code(plan_code)
|
||||
plan_code = extract_first_agent_function(plan_code)
|
||||
exec_code = eval(plan_code)
|
||||
except Exception as e:
|
||||
logger.error("Error in parsing plan code: %s", e)
|
||||
plan_code = "agent.wait(1.0)"
|
||||
exec_code = eval(plan_code)
|
||||
|
||||
executor_info = {
|
||||
"full_plan": full_plan,
|
||||
"executor_plan": plan,
|
||||
"plan_thoughts": plan_thoughts,
|
||||
"plan_code": plan_code,
|
||||
"reflection": reflection,
|
||||
"reflection_thoughts": reflection_thoughts,
|
||||
}
|
||||
self.turn_count += 1
|
||||
|
||||
self.screenshot_inputs.append(obs["screenshot"])
|
||||
self.flush_messages()
|
||||
|
||||
return executor_info, [exec_code]
|
||||
@@ -0,0 +1,276 @@
|
||||
import argparse
|
||||
import datetime
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import pyautogui
|
||||
import sys
|
||||
import time
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from gui_agents.s2_5.agents.grounding import OSWorldACI
|
||||
from gui_agents.s2_5.agents.agent_s import AgentS2_5
|
||||
|
||||
current_platform = platform.system().lower()
|
||||
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
|
||||
log_dir = "logs"
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
file_handler = logging.FileHandler(
|
||||
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||
)
|
||||
debug_handler = logging.FileHandler(
|
||||
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||
)
|
||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||
sdebug_handler = logging.FileHandler(
|
||||
os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||
)
|
||||
|
||||
file_handler.setLevel(logging.INFO)
|
||||
debug_handler.setLevel(logging.DEBUG)
|
||||
stdout_handler.setLevel(logging.INFO)
|
||||
sdebug_handler.setLevel(logging.DEBUG)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
debug_handler.setFormatter(formatter)
|
||||
stdout_handler.setFormatter(formatter)
|
||||
sdebug_handler.setFormatter(formatter)
|
||||
|
||||
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
||||
sdebug_handler.addFilter(logging.Filter("desktopenv"))
|
||||
|
||||
logger.addHandler(file_handler)
|
||||
logger.addHandler(debug_handler)
|
||||
logger.addHandler(stdout_handler)
|
||||
logger.addHandler(sdebug_handler)
|
||||
|
||||
platform_os = platform.system()
|
||||
|
||||
|
||||
def show_permission_dialog(code: str, action_description: str):
|
||||
"""Show a platform-specific permission dialog and return True if approved."""
|
||||
if platform.system() == "Darwin":
|
||||
result = os.system(
|
||||
f'osascript -e \'display dialog "Do you want to execute this action?\n\n{code} which will try to {action_description}" with title "Action Permission" buttons {{"Cancel", "OK"}} default button "OK" cancel button "Cancel"\''
|
||||
)
|
||||
return result == 0
|
||||
elif platform.system() == "Linux":
|
||||
result = os.system(
|
||||
f'zenity --question --title="Action Permission" --text="Do you want to execute this action?\n\n{code}" --width=400 --height=200'
|
||||
)
|
||||
return result == 0
|
||||
return False
|
||||
|
||||
|
||||
def scale_screen_dimensions(width: int, height: int, max_dim_size: int):
|
||||
scale_factor = min(max_dim_size / width, max_dim_size / height, 1)
|
||||
safe_width = int(width * scale_factor)
|
||||
safe_height = int(height * scale_factor)
|
||||
return safe_width, safe_height
|
||||
|
||||
|
||||
def run_agent(agent, instruction: str, scaled_width: int, scaled_height: int):
|
||||
obs = {}
|
||||
traj = "Task:\n" + instruction
|
||||
subtask_traj = ""
|
||||
for _ in range(15):
|
||||
# Get screen shot using pyautogui
|
||||
screenshot = pyautogui.screenshot()
|
||||
screenshot = screenshot.resize((scaled_width, scaled_height), Image.LANCZOS)
|
||||
|
||||
# Save the screenshot to a BytesIO object
|
||||
buffered = io.BytesIO()
|
||||
screenshot.save(buffered, format="PNG")
|
||||
|
||||
# Get the byte value of the screenshot
|
||||
screenshot_bytes = buffered.getvalue()
|
||||
# Convert to base64 string.
|
||||
obs["screenshot"] = screenshot_bytes
|
||||
|
||||
# Get next action code from the agent
|
||||
info, code = agent.predict(instruction=instruction, observation=obs)
|
||||
|
||||
if "done" in code[0].lower() or "fail" in code[0].lower():
|
||||
if platform.system() == "Darwin":
|
||||
os.system(
|
||||
f'osascript -e \'display dialog "Task Completed" with title "OpenACI Agent" buttons "OK" default button "OK"\''
|
||||
)
|
||||
elif platform.system() == "Linux":
|
||||
os.system(
|
||||
f'zenity --info --title="OpenACI Agent" --text="Task Completed" --width=200 --height=100'
|
||||
)
|
||||
|
||||
break
|
||||
|
||||
if "next" in code[0].lower():
|
||||
continue
|
||||
|
||||
if "wait" in code[0].lower():
|
||||
time.sleep(5)
|
||||
continue
|
||||
|
||||
else:
|
||||
time.sleep(1.0)
|
||||
print("EXECUTING CODE:", code[0])
|
||||
|
||||
# Ask for permission before executing
|
||||
exec(code[0])
|
||||
time.sleep(1.0)
|
||||
|
||||
# Update task and subtask trajectories
|
||||
if "reflection" in info and "executor_plan" in info:
|
||||
traj += (
|
||||
"\n\nReflection:\n"
|
||||
+ str(info["reflection"])
|
||||
+ "\n\n----------------------\n\nPlan:\n"
|
||||
+ info["executor_plan"]
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Run AgentS2_5 with specified model.")
|
||||
parser.add_argument(
|
||||
"--provider",
|
||||
type=str,
|
||||
default="openai",
|
||||
help="Specify the provider to use (e.g., openai, anthropic, etc.)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="o3-2025-04-16",
|
||||
help="Specify the model to use (e.g., o3-2025-04-16)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_url",
|
||||
type=str,
|
||||
default="",
|
||||
help="The URL of the main generation model API.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_api_key",
|
||||
type=str,
|
||||
default="",
|
||||
help="The API key of the main generation model.",
|
||||
)
|
||||
|
||||
# Grounding model config: Self-hosted endpoint based (required)
|
||||
parser.add_argument(
|
||||
"--ground_provider",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The provider for the grounding model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ground_url",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The URL of the grounding model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ground_api_key",
|
||||
type=str,
|
||||
default="",
|
||||
help="The API key of the grounding model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ground_model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The model name for the grounding model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grounding_width",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Width of screenshot image after processor rescaling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grounding_height",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Height of screenshot image after processor rescaling",
|
||||
)
|
||||
|
||||
# AgentS2_5 specific arguments
|
||||
parser.add_argument(
|
||||
"--max_trajectory_length",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Maximum number of image turns to keep in trajectory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_reflection",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Enable reflection agent to assist the worker agent",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Re-scales screenshot size to ensure it fits in UI-TARS context limit
|
||||
screen_width, screen_height = pyautogui.size()
|
||||
scaled_width, scaled_height = scale_screen_dimensions(
|
||||
screen_width, screen_height, max_dim_size=2400
|
||||
)
|
||||
|
||||
# Load the general engine params
|
||||
engine_params = {
|
||||
"engine_type": args.provider,
|
||||
"model": args.model,
|
||||
"base_url": args.model_url,
|
||||
"api_key": args.model_api_key,
|
||||
}
|
||||
|
||||
# Load the grounding engine from a custom endpoint
|
||||
engine_params_for_grounding = {
|
||||
"engine_type": args.ground_provider,
|
||||
"model": args.ground_model,
|
||||
"base_url": args.ground_url,
|
||||
"api_key": args.ground_api_key,
|
||||
"grounding_width": args.grounding_width,
|
||||
"grounding_height": args.grounding_height,
|
||||
}
|
||||
|
||||
grounding_agent = OSWorldACI(
|
||||
platform=current_platform,
|
||||
engine_params_for_generation=engine_params,
|
||||
engine_params_for_grounding=engine_params_for_grounding,
|
||||
width=screen_width,
|
||||
height=screen_height,
|
||||
)
|
||||
|
||||
agent = AgentS2_5(
|
||||
engine_params,
|
||||
grounding_agent,
|
||||
platform=current_platform,
|
||||
max_trajectory_length=args.max_trajectory_length,
|
||||
enable_reflection=args.enable_reflection,
|
||||
)
|
||||
|
||||
while True:
|
||||
query = input("Query: ")
|
||||
|
||||
agent.reset()
|
||||
|
||||
# Run the agent on your own device
|
||||
run_agent(agent, query, scaled_width, scaled_height)
|
||||
|
||||
response = input("Would you like to provide another query? (y/n): ")
|
||||
if response.lower() != "y":
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,401 @@
|
||||
import os
|
||||
|
||||
import backoff
|
||||
from anthropic import Anthropic
|
||||
from openai import (
|
||||
AzureOpenAI,
|
||||
APIConnectionError,
|
||||
APIError,
|
||||
AzureOpenAI,
|
||||
OpenAI,
|
||||
RateLimitError,
|
||||
)
|
||||
|
||||
|
||||
class LMMEngine:
|
||||
pass
|
||||
|
||||
|
||||
class LMMEngineOpenAI(LMMEngine):
|
||||
def __init__(
|
||||
self, base_url=None, api_key=None, model=None, rate_limit=-1, temperature=None, organization=None, **kwargs
|
||||
):
|
||||
assert model is not None, "model must be provided"
|
||||
self.model = model
|
||||
self.base_url = base_url
|
||||
self.api_key = api_key
|
||||
self.organization = organization
|
||||
self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
|
||||
self.llm_client = None
|
||||
self.temperature = temperature # Can force temperature to be the same (in the case of o3 requiring temperature to be 1)
|
||||
|
||||
@backoff.on_exception(
|
||||
backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
|
||||
)
|
||||
def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
|
||||
api_key = self.api_key or os.getenv("OPENAI_API_KEY")
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"An API Key needs to be provided in either the api_key parameter or as an environment variable named OPENAI_API_KEY"
|
||||
)
|
||||
organization = self.organization or os.getenv("OPENAI_ORG_ID")
|
||||
if not self.llm_client:
|
||||
if not self.base_url:
|
||||
self.llm_client = OpenAI(api_key=api_key, organization=organization)
|
||||
else:
|
||||
self.llm_client = OpenAI(base_url=self.base_url, api_key=api_key, organization=organization)
|
||||
return (
|
||||
self.llm_client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
max_completion_tokens=max_new_tokens if max_new_tokens else 4096,
|
||||
temperature=temperature if self.temperature is None else self.temperature,
|
||||
**kwargs,
|
||||
)
|
||||
.choices[0]
|
||||
.message.content
|
||||
)
|
||||
|
||||
|
||||
class LMMEngineAnthropic(LMMEngine):
|
||||
def __init__(
|
||||
self, base_url=None, api_key=None, model=None, thinking=False, temperature=None, **kwargs
|
||||
):
|
||||
assert model is not None, "model must be provided"
|
||||
self.model = model
|
||||
self.thinking = thinking
|
||||
self.api_key = api_key
|
||||
self.llm_client = None
|
||||
self.temperature = temperature
|
||||
|
||||
@backoff.on_exception(
|
||||
backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
|
||||
)
|
||||
def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
|
||||
api_key = self.api_key or os.getenv("ANTHROPIC_API_KEY")
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"An API Key needs to be provided in either the api_key parameter or as an environment variable named ANTHROPIC_API_KEY"
|
||||
)
|
||||
if not self.llm_client:
|
||||
self.llm_client = Anthropic(api_key=api_key)
|
||||
# Use the instance temperature if not specified in the call
|
||||
temp = self.temperature if temperature is None else temperature
|
||||
if self.thinking:
|
||||
full_response = self.llm_client.messages.create(
|
||||
system=messages[0]["content"][0]["text"],
|
||||
model=self.model,
|
||||
messages=messages[1:],
|
||||
max_tokens=8192,
|
||||
thinking={"type": "enabled", "budget_tokens": 4096},
|
||||
**kwargs,
|
||||
)
|
||||
thoughts = full_response.content[0].thinking
|
||||
return full_response.content[1].text
|
||||
return (
|
||||
self.llm_client.messages.create(
|
||||
system=messages[0]["content"][0]["text"],
|
||||
model=self.model,
|
||||
messages=messages[1:],
|
||||
max_tokens=max_new_tokens if max_new_tokens else 4096,
|
||||
temperature=temp,
|
||||
**kwargs,
|
||||
)
|
||||
.content[0]
|
||||
.text
|
||||
)
|
||||
|
||||
@backoff.on_exception(
|
||||
backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
|
||||
)
|
||||
# Compatible with Claude-3.7 Sonnet thinking mode
|
||||
def generate_with_thinking(
|
||||
self, messages, temperature=0.0, max_new_tokens=None, **kwargs
|
||||
):
|
||||
"""Generate the next message based on previous messages, and keeps the thinking tokens"""
|
||||
|
||||
full_response = self.llm_client.messages.create(
|
||||
system=messages[0]["content"][0]["text"],
|
||||
model=self.model,
|
||||
messages=messages[1:],
|
||||
max_tokens=8192,
|
||||
thinking={"type": "enabled", "budget_tokens": 4096},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
thoughts = full_response.content[0].thinking
|
||||
answer = full_response.content[1].text
|
||||
full_response = (
|
||||
f"<thoughts>\n{thoughts}\n</thoughts>\n\n<answer>\n{answer}\n</answer>\n"
|
||||
)
|
||||
return full_response
|
||||
|
||||
|
||||
class LMMEngineGemini(LMMEngine):
|
||||
def __init__(
|
||||
self, base_url=None, api_key=None, model=None, rate_limit=-1, temperature=None, **kwargs
|
||||
):
|
||||
assert model is not None, "model must be provided"
|
||||
self.model = model
|
||||
self.base_url = base_url
|
||||
self.api_key = api_key
|
||||
self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
|
||||
self.llm_client = None
|
||||
self.temperature = temperature
|
||||
|
||||
@backoff.on_exception(
|
||||
backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
|
||||
)
|
||||
def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
|
||||
api_key = self.api_key or os.getenv("GEMINI_API_KEY")
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"An API Key needs to be provided in either the api_key parameter or as an environment variable named GEMINI_API_KEY"
|
||||
)
|
||||
base_url = self.base_url or os.getenv("GEMINI_ENDPOINT_URL")
|
||||
if base_url is None:
|
||||
raise ValueError(
|
||||
"An endpoint URL needs to be provided in either the endpoint_url parameter or as an environment variable named GEMINI_ENDPOINT_URL"
|
||||
)
|
||||
if not self.llm_client:
|
||||
self.llm_client = OpenAI(base_url=base_url, api_key=api_key)
|
||||
# Use the temperature passed to generate, otherwise use the instance's temperature, otherwise default to 0.0
|
||||
temp = self.temperature if temperature is None else temperature
|
||||
return (
|
||||
self.llm_client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
max_tokens=max_new_tokens if max_new_tokens else 4096,
|
||||
temperature=temp,
|
||||
**kwargs,
|
||||
)
|
||||
.choices[0]
|
||||
.message.content
|
||||
)
|
||||
|
||||
|
||||
class LMMEngineOpenRouter(LMMEngine):
|
||||
def __init__(
|
||||
self, base_url=None, api_key=None, model=None, rate_limit=-1, temperature=None, **kwargs
|
||||
):
|
||||
assert model is not None, "model must be provided"
|
||||
self.model = model
|
||||
self.base_url = base_url
|
||||
self.api_key = api_key
|
||||
self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
|
||||
self.llm_client = None
|
||||
self.temperature = temperature
|
||||
|
||||
@backoff.on_exception(
|
||||
backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
|
||||
)
|
||||
def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
|
||||
api_key = self.api_key or os.getenv("OPENROUTER_API_KEY")
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"An API Key needs to be provided in either the api_key parameter or as an environment variable named OPENROUTER_API_KEY"
|
||||
)
|
||||
base_url = self.base_url or os.getenv("OPEN_ROUTER_ENDPOINT_URL")
|
||||
if base_url is None:
|
||||
raise ValueError(
|
||||
"An endpoint URL needs to be provided in either the endpoint_url parameter or as an environment variable named OPEN_ROUTER_ENDPOINT_URL"
|
||||
)
|
||||
if not self.llm_client:
|
||||
self.llm_client = OpenAI(base_url=base_url, api_key=api_key)
|
||||
# Use self.temperature if set, otherwise use the temperature argument
|
||||
temp = self.temperature if self.temperature is not None else temperature
|
||||
return (
|
||||
self.llm_client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
max_tokens=max_new_tokens if max_new_tokens else 4096,
|
||||
temperature=temp,
|
||||
**kwargs,
|
||||
)
|
||||
.choices[0]
|
||||
.message.content
|
||||
)
|
||||
|
||||
|
||||
class LMMEngineAzureOpenAI(LMMEngine):
|
||||
def __init__(
|
||||
self,
|
||||
base_url=None,
|
||||
api_key=None,
|
||||
azure_endpoint=None,
|
||||
model=None,
|
||||
api_version=None,
|
||||
rate_limit=-1,
|
||||
temperature=None,
|
||||
**kwargs,
|
||||
):
|
||||
assert model is not None, "model must be provided"
|
||||
self.model = model
|
||||
self.api_version = api_version
|
||||
self.api_key = api_key
|
||||
self.azure_endpoint = azure_endpoint
|
||||
self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
|
||||
self.llm_client = None
|
||||
self.cost = 0.0
|
||||
self.temperature = temperature
|
||||
|
||||
@backoff.on_exception(
|
||||
backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
|
||||
)
|
||||
def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
|
||||
api_key = self.api_key or os.getenv("AZURE_OPENAI_API_KEY")
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"An API Key needs to be provided in either the api_key parameter or as an environment variable named AZURE_OPENAI_API_KEY"
|
||||
)
|
||||
api_version = self.api_version or os.getenv("OPENAI_API_VERSION")
|
||||
if api_version is None:
|
||||
raise ValueError(
|
||||
"api_version must be provided either as a parameter or as an environment variable named OPENAI_API_VERSION"
|
||||
)
|
||||
azure_endpoint = self.azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
|
||||
if azure_endpoint is None:
|
||||
raise ValueError(
|
||||
"An Azure API endpoint needs to be provided in either the azure_endpoint parameter or as an environment variable named AZURE_OPENAI_ENDPOINT"
|
||||
)
|
||||
if not self.llm_client:
|
||||
self.llm_client = AzureOpenAI(
|
||||
azure_endpoint=azure_endpoint,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
)
|
||||
# Use self.temperature if set, otherwise use the temperature argument
|
||||
temp = self.temperature if self.temperature is not None else temperature
|
||||
completion = self.llm_client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
max_tokens=max_new_tokens if max_new_tokens else 4096,
|
||||
temperature=temp,
|
||||
**kwargs,
|
||||
)
|
||||
total_tokens = completion.usage.total_tokens
|
||||
self.cost += 0.02 * ((total_tokens + 500) / 1000)
|
||||
return completion.choices[0].message.content
|
||||
|
||||
|
||||
class LMMEnginevLLM(LMMEngine):
|
||||
def __init__(
|
||||
self, base_url=None, api_key=None, model=None, rate_limit=-1, temperature=None, **kwargs
|
||||
):
|
||||
assert model is not None, "model must be provided"
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
|
||||
self.llm_client = None
|
||||
self.temperature = temperature
|
||||
|
||||
@backoff.on_exception(
|
||||
backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
|
||||
)
|
||||
def generate(
|
||||
self,
|
||||
messages,
|
||||
temperature=0.0,
|
||||
top_p=0.8,
|
||||
repetition_penalty=1.05,
|
||||
max_new_tokens=512,
|
||||
**kwargs
|
||||
):
|
||||
api_key = self.api_key or os.getenv("vLLM_API_KEY")
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"A vLLM API key needs to be provided in either the api_key parameter or as an environment variable named vLLM_API_KEY"
|
||||
)
|
||||
base_url = self.base_url or os.getenv("vLLM_ENDPOINT_URL")
|
||||
if base_url is None:
|
||||
raise ValueError(
|
||||
"An endpoint URL needs to be provided in either the endpoint_url parameter or as an environment variable named vLLM_ENDPOINT_URL"
|
||||
)
|
||||
if not self.llm_client:
|
||||
self.llm_client = OpenAI(base_url=base_url, api_key=api_key)
|
||||
# Use self.temperature if set, otherwise use the temperature argument
|
||||
temp = self.temperature if self.temperature is not None else temperature
|
||||
completion = self.llm_client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
max_tokens=max_new_tokens if max_new_tokens else 4096,
|
||||
temperature=temp,
|
||||
top_p=top_p,
|
||||
extra_body={"repetition_penalty": repetition_penalty},
|
||||
)
|
||||
return completion.choices[0].message.content
|
||||
|
||||
|
||||
class LMMEngineHuggingFace(LMMEngine):
|
||||
def __init__(self, base_url=None, api_key=None, rate_limit=-1, **kwargs):
|
||||
self.base_url = base_url
|
||||
self.api_key = api_key
|
||||
self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
|
||||
self.llm_client = None
|
||||
|
||||
@backoff.on_exception(
|
||||
backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
|
||||
)
|
||||
def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
|
||||
api_key = self.api_key or os.getenv("HF_TOKEN")
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"A HuggingFace token needs to be provided in either the api_key parameter or as an environment variable named HF_TOKEN"
|
||||
)
|
||||
base_url = self.base_url or os.getenv("HF_ENDPOINT_URL")
|
||||
if base_url is None:
|
||||
raise ValueError(
|
||||
"HuggingFace endpoint must be provided as base_url parameter or as an environment variable named HF_ENDPOINT_URL."
|
||||
)
|
||||
if not self.llm_client:
|
||||
self.llm_client = OpenAI(base_url=base_url, api_key=api_key)
|
||||
return (
|
||||
self.llm_client.chat.completions.create(
|
||||
model="tgi",
|
||||
messages=messages,
|
||||
max_tokens=max_new_tokens if max_new_tokens else 4096,
|
||||
temperature=temperature,
|
||||
**kwargs,
|
||||
)
|
||||
.choices[0]
|
||||
.message.content
|
||||
)
|
||||
|
||||
|
||||
class LMMEngineParasail(LMMEngine):
|
||||
def __init__(self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs):
|
||||
assert model is not None, "Parasail model id must be provided"
|
||||
self.base_url = base_url
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
|
||||
self.llm_client = None
|
||||
|
||||
@backoff.on_exception(
|
||||
backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
|
||||
)
|
||||
def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
|
||||
api_key = self.api_key or os.getenv("PARASAIL_API_KEY")
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"A Parasail API key needs to be provided in either the api_key parameter or as an environment variable named PARASAIL_API_KEY"
|
||||
)
|
||||
base_url = self.base_url
|
||||
if base_url is None:
|
||||
raise ValueError(
|
||||
"Parasail endpoint must be provided as base_url parameter or as an environment variable named PARASAIL_ENDPOINT_URL"
|
||||
)
|
||||
if not self.llm_client:
|
||||
self.llm_client = OpenAI(base_url=base_url if base_url else "https://api.parasail.io/v1", api_key=api_key)
|
||||
return (
|
||||
self.llm_client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
max_tokens=max_new_tokens if max_new_tokens else 4096,
|
||||
temperature=temperature,
|
||||
**kwargs
|
||||
)
|
||||
.choices[0].
|
||||
message.content
|
||||
)
|
||||
@@ -0,0 +1,306 @@
|
||||
import base64
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gui_agents.s2_5.core.engine import (
|
||||
LMMEngineAnthropic,
|
||||
LMMEngineAzureOpenAI,
|
||||
LMMEngineHuggingFace,
|
||||
LMMEngineOpenAI,
|
||||
LMMEngineOpenRouter,
|
||||
LMMEngineParasail,
|
||||
LMMEnginevLLM,
|
||||
LMMEngineGemini,
|
||||
)
|
||||
|
||||
|
||||
class LMMAgent:
|
||||
def __init__(self, engine_params=None, system_prompt=None, engine=None):
|
||||
if engine is None:
|
||||
if engine_params is not None:
|
||||
engine_type = engine_params.get("engine_type")
|
||||
if engine_type == "openai":
|
||||
self.engine = LMMEngineOpenAI(**engine_params)
|
||||
elif engine_type == "anthropic":
|
||||
self.engine = LMMEngineAnthropic(**engine_params)
|
||||
elif engine_type == "azure":
|
||||
self.engine = LMMEngineAzureOpenAI(**engine_params)
|
||||
elif engine_type == "vllm":
|
||||
self.engine = LMMEnginevLLM(**engine_params)
|
||||
elif engine_type == "huggingface":
|
||||
self.engine = LMMEngineHuggingFace(**engine_params)
|
||||
elif engine_type == "gemini":
|
||||
self.engine = LMMEngineGemini(**engine_params)
|
||||
elif engine_type == "open_router":
|
||||
self.engine = LMMEngineOpenRouter(**engine_params)
|
||||
elif engine_type == "parasail":
|
||||
self.engine = LMMEngineParasail(**engine_params)
|
||||
else:
|
||||
raise ValueError("engine_type is not supported")
|
||||
else:
|
||||
raise ValueError("engine_params must be provided")
|
||||
else:
|
||||
self.engine = engine
|
||||
|
||||
self.messages = [] # Empty messages
|
||||
|
||||
if system_prompt:
|
||||
self.add_system_prompt(system_prompt)
|
||||
else:
|
||||
self.add_system_prompt("You are a helpful assistant.")
|
||||
|
||||
def encode_image(self, image_content):
|
||||
# if image_content is a path to an image file, check type of the image_content to verify
|
||||
if isinstance(image_content, str):
|
||||
with open(image_content, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
else:
|
||||
return base64.b64encode(image_content).decode("utf-8")
|
||||
|
||||
def reset(
|
||||
self,
|
||||
):
|
||||
|
||||
self.messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": self.system_prompt}],
|
||||
}
|
||||
]
|
||||
|
||||
def add_system_prompt(self, system_prompt):
|
||||
self.system_prompt = system_prompt
|
||||
if len(self.messages) > 0:
|
||||
self.messages[0] = {
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": self.system_prompt}],
|
||||
}
|
||||
else:
|
||||
self.messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": self.system_prompt}],
|
||||
}
|
||||
)
|
||||
|
||||
def remove_message_at(self, index):
|
||||
"""Remove a message at a given index"""
|
||||
if index < len(self.messages):
|
||||
self.messages.pop(index)
|
||||
|
||||
def replace_message_at(
|
||||
self, index, text_content, image_content=None, image_detail="high"
|
||||
):
|
||||
"""Replace a message at a given index"""
|
||||
if index < len(self.messages):
|
||||
self.messages[index] = {
|
||||
"role": self.messages[index]["role"],
|
||||
"content": [{"type": "text", "text": text_content}],
|
||||
}
|
||||
if image_content:
|
||||
base64_image = self.encode_image(image_content)
|
||||
self.messages[index]["content"].append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{base64_image}",
|
||||
"detail": image_detail,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
def add_message(
|
||||
self,
|
||||
text_content,
|
||||
image_content=None,
|
||||
role=None,
|
||||
image_detail="high",
|
||||
put_text_last=False,
|
||||
):
|
||||
"""Add a new message to the list of messages"""
|
||||
|
||||
# API-style inference from OpenAI and AzureOpenAI
|
||||
if isinstance(
|
||||
self.engine,
|
||||
(
|
||||
LMMEngineOpenAI,
|
||||
LMMEngineAzureOpenAI,
|
||||
LMMEngineHuggingFace,
|
||||
LMMEngineGemini,
|
||||
LMMEngineOpenRouter,
|
||||
LMMEngineParasail
|
||||
),
|
||||
):
|
||||
# infer role from previous message
|
||||
if role != "user":
|
||||
if self.messages[-1]["role"] == "system":
|
||||
role = "user"
|
||||
elif self.messages[-1]["role"] == "user":
|
||||
role = "assistant"
|
||||
elif self.messages[-1]["role"] == "assistant":
|
||||
role = "user"
|
||||
|
||||
message = {
|
||||
"role": role,
|
||||
"content": [{"type": "text", "text": text_content}],
|
||||
}
|
||||
|
||||
if isinstance(image_content, np.ndarray) or image_content:
|
||||
# Check if image_content is a list or a single image
|
||||
if isinstance(image_content, list):
|
||||
# If image_content is a list of images, loop through each image
|
||||
for image in image_content:
|
||||
base64_image = self.encode_image(image)
|
||||
message["content"].append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{base64_image}",
|
||||
"detail": image_detail,
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
# If image_content is a single image, handle it directly
|
||||
base64_image = self.encode_image(image_content)
|
||||
message["content"].append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{base64_image}",
|
||||
"detail": image_detail,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Rotate text to be the last message if desired
|
||||
if put_text_last:
|
||||
text_content = message["content"].pop(0)
|
||||
message["content"].append(text_content)
|
||||
|
||||
self.messages.append(message)
|
||||
|
||||
# For API-style inference from Anthropic
|
||||
elif isinstance(self.engine, LMMEngineAnthropic):
|
||||
# infer role from previous message
|
||||
if role != "user":
|
||||
if self.messages[-1]["role"] == "system":
|
||||
role = "user"
|
||||
elif self.messages[-1]["role"] == "user":
|
||||
role = "assistant"
|
||||
elif self.messages[-1]["role"] == "assistant":
|
||||
role = "user"
|
||||
|
||||
message = {
|
||||
"role": role,
|
||||
"content": [{"type": "text", "text": text_content}],
|
||||
}
|
||||
|
||||
if image_content:
|
||||
# Check if image_content is a list or a single image
|
||||
if isinstance(image_content, list):
|
||||
# If image_content is a list of images, loop through each image
|
||||
for image in image_content:
|
||||
base64_image = self.encode_image(image)
|
||||
message["content"].append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": base64_image,
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
# If image_content is a single image, handle it directly
|
||||
base64_image = self.encode_image(image_content)
|
||||
message["content"].append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": base64_image,
|
||||
},
|
||||
}
|
||||
)
|
||||
self.messages.append(message)
|
||||
|
||||
# Locally hosted vLLM model inference
|
||||
elif isinstance(self.engine, LMMEnginevLLM):
|
||||
# infer role from previous message
|
||||
if role != "user":
|
||||
if self.messages[-1]["role"] == "system":
|
||||
role = "user"
|
||||
elif self.messages[-1]["role"] == "user":
|
||||
role = "assistant"
|
||||
elif self.messages[-1]["role"] == "assistant":
|
||||
role = "user"
|
||||
|
||||
message = {
|
||||
"role": role,
|
||||
"content": [{"type": "text", "text": text_content}],
|
||||
}
|
||||
|
||||
if image_content:
|
||||
# Check if image_content is a list or a single image
|
||||
if isinstance(image_content, list):
|
||||
# If image_content is a list of images, loop through each image
|
||||
for image in image_content:
|
||||
base64_image = self.encode_image(image)
|
||||
message["content"].append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image;base64,{base64_image}"
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
# If image_content is a single image, handle it directly
|
||||
base64_image = self.encode_image(image_content)
|
||||
message["content"].append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image;base64,{base64_image}"},
|
||||
}
|
||||
)
|
||||
|
||||
self.messages.append(message)
|
||||
else:
|
||||
raise ValueError("engine_type is not supported")
|
||||
|
||||
def get_response(
|
||||
self,
|
||||
user_message=None,
|
||||
messages=None,
|
||||
temperature=0.0,
|
||||
max_new_tokens=None,
|
||||
use_thinking=False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Generate the next response based on previous messages"""
|
||||
if messages is None:
|
||||
messages = self.messages
|
||||
if user_message:
|
||||
messages.append(
|
||||
{"role": "user", "content": [{"type": "text", "text": user_message}]}
|
||||
)
|
||||
|
||||
# Thinking enabled for Claude Sonnet 3.7 and Gemini 2.5 Pro
|
||||
if use_thinking:
|
||||
return self.engine.generate_with_thinking(
|
||||
messages,
|
||||
temperature=temperature,
|
||||
max_new_tokens=max_new_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Regular generation
|
||||
return self.engine.generate(
|
||||
messages,
|
||||
temperature=temperature,
|
||||
max_new_tokens=max_new_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -0,0 +1,17 @@
|
||||
from typing import Dict, Optional
|
||||
from gui_agents.s2_5.core.mllm import LMMAgent
|
||||
|
||||
|
||||
class BaseModule:
|
||||
def __init__(self, engine_params: Dict, platform: str):
|
||||
self.engine_params = engine_params
|
||||
self.platform = platform
|
||||
|
||||
def _create_agent(
|
||||
self, system_prompt: str = None, engine_params: Optional[Dict] = None
|
||||
) -> LMMAgent:
|
||||
"""Create a new LMMAgent instance"""
|
||||
agent = LMMAgent(engine_params or self.engine_params)
|
||||
if system_prompt:
|
||||
agent.add_system_prompt(system_prompt)
|
||||
return agent
|
||||
@@ -0,0 +1,98 @@
|
||||
import inspect
|
||||
import textwrap
|
||||
|
||||
|
||||
class PROCEDURAL_MEMORY:
|
||||
@staticmethod
|
||||
def construct_simple_worker_procedural_memory(agent_class, skipped_actions):
|
||||
procedural_memory = textwrap.dedent(
|
||||
f"""\
|
||||
You are an expert in graphical user interfaces and Python code. You are responsible for executing the task: `TASK_DESCRIPTION`.
|
||||
You are working in CURRENT_OS.
|
||||
You are provided with:
|
||||
1. A screenshot of the current time step.
|
||||
2. The history of your previous interactions with the UI.
|
||||
3. Access to the following class and methods to interact with the UI:
|
||||
class Agent:
|
||||
"""
|
||||
)
|
||||
|
||||
for attr_name in dir(agent_class):
|
||||
if attr_name in skipped_actions:
|
||||
continue
|
||||
|
||||
attr = getattr(agent_class, attr_name)
|
||||
if callable(attr) and hasattr(attr, "is_agent_action"):
|
||||
# Use inspect to get the full function signature
|
||||
signature = inspect.signature(attr)
|
||||
procedural_memory += f"""
|
||||
def {attr_name}{signature}:
|
||||
'''{attr.__doc__}'''
|
||||
"""
|
||||
|
||||
procedural_memory += textwrap.dedent(
|
||||
"""
|
||||
Your response should be formatted like this:
|
||||
(Previous action verification)
|
||||
Carefully analyze based on the screenshot if the previous action was successful. If the previous action was not successful, provide a reason for the failure.
|
||||
|
||||
(Screenshot Analysis)
|
||||
Closely examine and describe the current state of the desktop along with the currently open applications.
|
||||
|
||||
(Next Action)
|
||||
Based on the current screenshot and the history of your previous interaction with the UI, decide on the next action in natural language to accomplish the given task.
|
||||
|
||||
(Grounded Action)
|
||||
Translate the next action into code using the provided API methods. Format the code like this:
|
||||
```python
|
||||
agent.click("The menu button at the top right of the window", 1, "left")
|
||||
```
|
||||
Note for the code:
|
||||
1. Only perform one action at a time.
|
||||
2. Do not put anything other than python code in the block. You can only use one function call at a time. Do not put more than one function call in the block.
|
||||
3. You must use only the available methods provided above to interact with the UI, do not invent new methods.
|
||||
4. Only return one code block every time. There must be a single line of code in the code block.
|
||||
5. Do not do anything other than the exact specified task. Return with `agent.done()` immediately after the subtask is completed or `agent.fail()` if it cannot be completed.
|
||||
6. Whenever possible, your grounded action should use hot-keys with the agent.hotkey() action instead of clicking or dragging.
|
||||
7. My computer's password is 'osworld-public-evaluation', feel free to use it when you need sudo rights.
|
||||
8. Generate agent.fail() as your grounded action if you get exhaustively stuck on the task and believe it is impossible.
|
||||
9. Generate agent.done() as your grounded action when your believe the task is fully complete.
|
||||
10. Do not use the "command" + "tab" hotkey on MacOS.
|
||||
"""
|
||||
)
|
||||
|
||||
return procedural_memory.strip()
|
||||
|
||||
# For reflection agent, post-action verification mainly for cycle detection
|
||||
REFLECTION_ON_TRAJECTORY = textwrap.dedent(
|
||||
"""
|
||||
You are an expert computer use agent designed to reflect on the trajectory of a task and provide feedback on what has happened so far.
|
||||
You have access to the Task Description and the Current Trajectory of another computer agent. The Current Trajectory is a sequence of a desktop image, chain-of-thought reasoning, and a desktop action for each time step. The last image is the screen's display after the last action.
|
||||
Your task is to generate a reflection. Your generated reflection must fall under one of the cases listed below:
|
||||
|
||||
Case 1. The trajectory is not going according to plan. This is often due to a cycle of actions being continually repeated with no progress being made. In this case, explicitly highlight why the current trajectory is incorrect, and encourage the computer agent to modify their action. However, DO NOT encourage a specific action in particular.
|
||||
Case 2. The trajectory is going according to plan. In this case, simply tell the agent to continue proceeding as planned. DO NOT encourage a specific action in particular.
|
||||
Case 3. You believe the current task has been completed. In this case, tell the agent that the task has been successfully completed.
|
||||
|
||||
To be successful, you must follow the rules below:
|
||||
- **Your output MUST be based on one of the case options above**.
|
||||
- DO NOT suggest any specific future plans or actions. Your only goal is to provide a reflection, not an actual plan or action.
|
||||
- Any response that falls under Case 1 should explain why the trajectory is not going according to plan. You should especially lookout for cycles of actions that are continually repeated with no progress.
|
||||
- Any response that falls under Case 2 should be concise, since you just need to affirm the agent to continue with the current trajectory.
|
||||
"""
|
||||
)
|
||||
|
||||
PHRASE_TO_WORD_COORDS_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
You are an expert in graphical user interfaces. Your task is to process a phrase of text, and identify the most relevant word on the computer screen.
|
||||
You are provided with a phrase, a table with all the text on the screen, and a screenshot of the computer screen. You will identify the single word id that is best associated with the provided phrase.
|
||||
This single word must be displayed on the computer screenshot, and its location on the screen should align with the provided phrase.
|
||||
Each row in the text table provides 2 pieces of data in the following order. 1st is the unique word id. 2nd is the corresponding word.
|
||||
|
||||
To be successful, it is very important to follow all these rules:
|
||||
1. First, think step by step and generate your reasoning about which word id to click on.
|
||||
2. Then, output the unique word id. Remember, the word id is the 1st number in each row of the text table.
|
||||
3. If there are multiple occurrences of the same word, use the surrounding context in the phrase to choose the correct one. Pay very close attention to punctuation and capitalization.
|
||||
|
||||
"""
|
||||
)
|
||||
@@ -0,0 +1,107 @@
|
||||
import re
|
||||
import time
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
def call_llm_safe(
|
||||
agent, temperature: float = 0.0, use_thinking: bool = False
|
||||
) -> str:
|
||||
# Retry if fails
|
||||
max_retries = 3 # Set the maximum number of retries
|
||||
attempt = 0
|
||||
response = ""
|
||||
while attempt < max_retries:
|
||||
try:
|
||||
response = agent.get_response(
|
||||
temperature=temperature, use_thinking=use_thinking
|
||||
)
|
||||
assert response is not None, "Response from agent should not be None"
|
||||
print("Response success!")
|
||||
break # If successful, break out of the loop
|
||||
except Exception as e:
|
||||
attempt += 1
|
||||
print(f"Attempt {attempt} failed: {e}")
|
||||
if attempt == max_retries:
|
||||
print("Max retries reached. Handling failure.")
|
||||
time.sleep(1.0)
|
||||
return response if response is not None else ""
|
||||
|
||||
|
||||
def split_thinking_response(full_response: str) -> Tuple[str, str]:
|
||||
try:
|
||||
# Extract thoughts section
|
||||
thoughts_match = re.search(
|
||||
r"<thoughts>(.*?)</thoughts>", full_response, re.DOTALL
|
||||
)
|
||||
thoughts = thoughts_match.group(1).strip()
|
||||
# Extract answer section
|
||||
answer_match = re.search(r"<answer>(.*?)</answer>", full_response, re.DOTALL)
|
||||
answer = answer_match.group(1).strip()
|
||||
return answer, thoughts
|
||||
except Exception as e:
|
||||
return full_response, ""
|
||||
|
||||
|
||||
def parse_single_code_from_string(input_string):
|
||||
input_string = input_string.strip()
|
||||
if input_string.strip() in ["WAIT", "DONE", "FAIL"]:
|
||||
return input_string.strip()
|
||||
|
||||
# This regular expression will match both ```code``` and ```python code```
|
||||
# and capture the `code` part. It uses a non-greedy match for the content inside.
|
||||
pattern = r"```(?:\w+\s+)?(.*?)```"
|
||||
# Find all non-overlapping matches in the string
|
||||
matches = re.findall(pattern, input_string, re.DOTALL)
|
||||
|
||||
# The regex above captures the content inside the triple backticks.
|
||||
# The `re.DOTALL` flag allows the dot `.` to match newline characters as well,
|
||||
# so the code inside backticks can span multiple lines.
|
||||
|
||||
# matches now contains all the captured code snippets
|
||||
|
||||
codes = []
|
||||
|
||||
for match in matches:
|
||||
match = match.strip()
|
||||
commands = [
|
||||
"WAIT",
|
||||
"DONE",
|
||||
"FAIL",
|
||||
] # fixme: updates this part when we have more commands
|
||||
|
||||
if match in commands:
|
||||
codes.append(match.strip())
|
||||
elif match.split("\n")[-1] in commands:
|
||||
if len(match.split("\n")) > 1:
|
||||
codes.append("\n".join(match.split("\n")[:-1]))
|
||||
codes.append(match.split("\n")[-1])
|
||||
else:
|
||||
codes.append(match)
|
||||
|
||||
if len(codes) <= 0:
|
||||
return "fail"
|
||||
return codes[0]
|
||||
|
||||
|
||||
def sanitize_code(code):
|
||||
# This pattern captures the outermost double-quoted text
|
||||
if "\n" in code:
|
||||
pattern = r'(".*?")'
|
||||
# Find all matches in the text
|
||||
matches = re.findall(pattern, code, flags=re.DOTALL)
|
||||
if matches:
|
||||
# Replace the first occurrence only
|
||||
first_match = matches[0]
|
||||
code = code.replace(first_match, f'"""{first_match[1:-1]}"""', 1)
|
||||
return code
|
||||
|
||||
|
||||
def extract_first_agent_function(code_string):
|
||||
# Regular expression pattern to match 'agent' functions with any arguments, including nested parentheses
|
||||
pattern = r'agent\.[a-zA-Z_]+\((?:[^()\'"]|\'[^\']*\'|"[^"]*")*\)'
|
||||
|
||||
# Find all matches in the string
|
||||
matches = re.findall(pattern, code_string)
|
||||
|
||||
# Return the first match if found, otherwise return None
|
||||
return matches[0] if matches else None
|
||||
+8
-11
@@ -42,34 +42,31 @@ export OPEN_ROUTER_ENDPOINT_URL="https://openrouter.ai/api/v1"
|
||||
```
|
||||
|
||||
```python
|
||||
from gui_agents.s2.agents.agent_s import AgentS2
|
||||
from gui_agents.s2_5.agents.agent_s import AgentS2_5
|
||||
|
||||
engine_params = {
|
||||
"engine_type": 'anthropic', # Allowed Values: 'openai', 'anthropic', 'gemini', 'azure_openai', 'vllm', 'open_router'
|
||||
"model": 'claude-3-5-sonnet-20240620', # Allowed Values: Any Vision and Language Model from the supported APIs
|
||||
"engine_type": 'openai', # Allowed Values: 'openai', 'anthropic', 'gemini', 'azure_openai', 'vllm', 'open_router'
|
||||
"model": 'o3-2025-04-16', # Allowed Values: Any Vision and Language Model from the supported APIs
|
||||
}
|
||||
agent = AgentS2(
|
||||
agent = AgentS2_5(
|
||||
engine_params,
|
||||
grounding_agent,
|
||||
platform=current_platform,
|
||||
action_space="pyautogui",
|
||||
observation_type="mixed",
|
||||
search_engine="LLM"
|
||||
)
|
||||
```
|
||||
|
||||
To use the underlying Multimodal Agent (LMMAgent) which wraps LLMs with message handling functionality, you can use the following code snippet:
|
||||
|
||||
```python
|
||||
from gui_agents.core.mllm import LMMAgent
|
||||
from gui_agents.s2_5.core.mllm import LMMAgent
|
||||
|
||||
engine_params = {
|
||||
"engine_type": 'anthropic', # Allowed Values: 'openai', 'anthropic', 'gemini', 'azure_openai', 'vllm', 'open_router'
|
||||
"model": 'claude-3-5-sonnet-20240620', # Allowed Values: Any Vision and Language Model from the supported APIs
|
||||
"engine_type": 'openai', # Allowed Values: 'openai', 'anthropic', 'gemini', 'azure_openai', 'vllm', 'open_router'
|
||||
"model": 'o3-2025-04-16', # Allowed Values: Any Vision and Language Model from the supported APIs
|
||||
}
|
||||
agent = LMMAgent(
|
||||
engine_params=engine_params,
|
||||
)
|
||||
```
|
||||
|
||||
The `AgentS2` also utilizes this `LMMAgent` internally.
|
||||
The `AgentS2_5` also utilizes this `LMMAgent` internally.
|
||||
@@ -0,0 +1,10 @@
|
||||
# Deplying Agent S2.5 in OSWorld
|
||||
|
||||
# Step 1: Set up Agent S2.5
|
||||
|
||||
Follow the [README.md](https://github.com/simular-ai/Agent-S/blob/main/README.md) to set up Agent S2.5.
|
||||
|
||||
# Step 2: Copying Over Run Files
|
||||
|
||||
If you haven't already, please follow the [OSWorld environment setup](https://github.com/xlang-ai/OSWorld/blob/main/README.md). We've provided the relevant OSWorld run files for evaluation in this `osworld_setup` folder. Please copy this over to your OSWorld folder. `run_local.py` and `lib_run_single_local.py` are for if you want to run locally on VMWare and `run.py` and `lib_run_single.py` are for if you want to run on AWS.
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import *
|
||||
from wrapt_timeout_decorator import *
|
||||
|
||||
logger = logging.getLogger("desktopenv.experiment")
|
||||
|
||||
|
||||
def run_single_example(
|
||||
agent, env, example, max_steps, instruction, args, example_result_dir, scores
|
||||
):
|
||||
runtime_logger = setup_logger(example, example_result_dir)
|
||||
try:
|
||||
agent.reset(runtime_logger)
|
||||
except Exception as e:
|
||||
agent.reset()
|
||||
|
||||
env.reset(task_config=example)
|
||||
time.sleep(60) # Wait for the environment to be ready
|
||||
obs = env._get_obs() # Get the initial observation
|
||||
|
||||
with open(os.path.join(example_result_dir, f"step_0.png"), "wb") as _f:
|
||||
_f.write(obs["screenshot"])
|
||||
|
||||
with open(os.path.join(example_result_dir, "instruction.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(instruction)
|
||||
|
||||
done = False
|
||||
step_idx = 0
|
||||
env.controller.start_recording()
|
||||
while not done and step_idx < max_steps:
|
||||
response, actions = agent.predict(
|
||||
instruction,
|
||||
obs
|
||||
)
|
||||
for action in actions:
|
||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
logger.info("Step %d: %s", step_idx + 1, action)
|
||||
obs, reward, done, info = env.step(action, args.sleep_after_execution)
|
||||
|
||||
logger.info("Reward: %.2f", reward)
|
||||
logger.info("Done: %s", done)
|
||||
# Save screenshot and trajectory information
|
||||
with open(
|
||||
os.path.join(
|
||||
example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"
|
||||
),
|
||||
"wb",
|
||||
) as _f:
|
||||
_f.write(obs["screenshot"])
|
||||
|
||||
response.update(
|
||||
{
|
||||
"step_num": step_idx + 1,
|
||||
"action_timestamp": action_timestamp,
|
||||
"action": action,
|
||||
"reward": reward,
|
||||
"done": done,
|
||||
"info": info,
|
||||
"screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png",
|
||||
}
|
||||
)
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
response
|
||||
)
|
||||
)
|
||||
f.write("\n")
|
||||
if done:
|
||||
logger.info("The episode is done.")
|
||||
break
|
||||
step_idx += 1
|
||||
result = env.evaluate()
|
||||
logger.info("Result: %.2f", result)
|
||||
scores.append(result)
|
||||
with open(
|
||||
os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8"
|
||||
) as f:
|
||||
f.write(f"{result}\n")
|
||||
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||
|
||||
|
||||
def setup_logger(example, example_result_dir):
|
||||
runtime_logger = logging.getLogger(f"desktopenv.example.{example['id']}")
|
||||
runtime_logger.setLevel(logging.DEBUG)
|
||||
runtime_logger.addHandler(
|
||||
logging.FileHandler(os.path.join(example_result_dir, "runtime.log"))
|
||||
)
|
||||
return runtime_logger
|
||||
@@ -0,0 +1,94 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import *
|
||||
from wrapt_timeout_decorator import *
|
||||
|
||||
logger = logging.getLogger("desktopenv.experiment")
|
||||
|
||||
|
||||
def run_single_example(
|
||||
agent, env, example, max_steps, instruction, args, example_result_dir, scores
|
||||
):
|
||||
runtime_logger = setup_logger(example, example_result_dir)
|
||||
try:
|
||||
agent.reset(runtime_logger)
|
||||
except Exception as e:
|
||||
agent.reset()
|
||||
|
||||
env.reset(task_config=example)
|
||||
time.sleep(60) # Wait for the environment to be ready
|
||||
obs = env._get_obs() # Get the initial observation
|
||||
|
||||
with open(os.path.join(example_result_dir, f"step_0.png"), "wb") as _f:
|
||||
_f.write(obs["screenshot"])
|
||||
|
||||
with open(os.path.join(example_result_dir, "instruction.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(instruction)
|
||||
|
||||
done = False
|
||||
step_idx = 0
|
||||
env.controller.start_recording()
|
||||
while not done and step_idx < max_steps:
|
||||
time.sleep(0.5)
|
||||
response, actions = agent.predict(
|
||||
instruction,
|
||||
obs
|
||||
)
|
||||
for action in actions:
|
||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
logger.info("Step %d: %s", step_idx + 1, action)
|
||||
obs, reward, done, info = env.step(action, args.sleep_after_execution)
|
||||
|
||||
logger.info("Reward: %.2f", reward)
|
||||
logger.info("Done: %s", done)
|
||||
# Save screenshot and trajectory information
|
||||
with open(
|
||||
os.path.join(
|
||||
example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"
|
||||
),
|
||||
"wb",
|
||||
) as _f:
|
||||
_f.write(obs["screenshot"])
|
||||
|
||||
response.update(
|
||||
{
|
||||
"step_num": step_idx + 1,
|
||||
"action_timestamp": action_timestamp,
|
||||
"action": action,
|
||||
"reward": reward,
|
||||
"done": done,
|
||||
"info": info,
|
||||
"screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png",
|
||||
}
|
||||
)
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
response
|
||||
)
|
||||
)
|
||||
f.write("\n")
|
||||
if done:
|
||||
logger.info("The episode is done.")
|
||||
break
|
||||
step_idx += 1
|
||||
result = env.evaluate()
|
||||
logger.info("Result: %.2f", result)
|
||||
scores.append(result)
|
||||
with open(
|
||||
os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8"
|
||||
) as f:
|
||||
f.write(f"{result}\n")
|
||||
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||
|
||||
|
||||
def setup_logger(example, example_result_dir):
|
||||
runtime_logger = logging.getLogger(f"desktopenv.example.{example['id']}")
|
||||
runtime_logger.setLevel(logging.DEBUG)
|
||||
runtime_logger.addHandler(
|
||||
logging.FileHandler(os.path.join(example_result_dir, "runtime.log"))
|
||||
)
|
||||
return runtime_logger
|
||||
@@ -0,0 +1,512 @@
|
||||
"""OSWorld's run.py with AgentS2_5."""
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import signal
|
||||
import time
|
||||
from multiprocessing import Process, Manager, current_process, Queue
|
||||
|
||||
|
||||
import lib_run_single
|
||||
from desktop_env.desktop_env import DesktopEnv
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# Logger Configs {{{ #
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
|
||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||
|
||||
stdout_handler.setLevel(logging.INFO)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
|
||||
)
|
||||
|
||||
stdout_handler.setFormatter(formatter)
|
||||
|
||||
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
||||
|
||||
logger.addHandler(stdout_handler)
|
||||
# }}} Logger Configs #
|
||||
|
||||
logger = logging.getLogger("desktopenv.experiment")
|
||||
|
||||
|
||||
# Global variables for signal handling
|
||||
active_environments = []
|
||||
processes = []
|
||||
is_terminating = False
|
||||
|
||||
def distribute_tasks(test_all_meta: dict) -> list:
|
||||
all_tasks = []
|
||||
for domain, examples in test_all_meta.items():
|
||||
for example_id in examples:
|
||||
all_tasks.append((domain, example_id))
|
||||
return all_tasks
|
||||
|
||||
def process_signal_handler(signum, frame, env_idx):
|
||||
logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
|
||||
local_vars = frame.f_locals
|
||||
active_environments = local_vars.get('active_environments', [])
|
||||
for env in active_environments:
|
||||
if env is not None:
|
||||
try:
|
||||
logger.info(f"Process {env_idx + 1} closing environment...")
|
||||
env.close()
|
||||
logger.info(f"Process {env_idx + 1} environment closed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Process {env_idx + 1} error closing environment: {e}")
|
||||
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
|
||||
sys.exit(0)
|
||||
|
||||
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list, engine_params, engine_params_for_grounding):
|
||||
active_environments = []
|
||||
env = None
|
||||
try:
|
||||
# Use IMAGE_ID_MAP for AWS provider to get snapshot_name
|
||||
snapshot_name = None
|
||||
region = getattr(args, 'region', None)
|
||||
if args.provider_name == 'aws' and region is not None:
|
||||
try:
|
||||
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
||||
screen_size = (args.screen_width, args.screen_height)
|
||||
snapshot_name = IMAGE_ID_MAP[region].get(screen_size, IMAGE_ID_MAP[region][(1920, 1080)])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get snapshot_name from IMAGE_ID_MAP: {e}")
|
||||
snapshot_name = None
|
||||
from gui_agents.s2_5.agents.agent_s import AgentS2_5
|
||||
from gui_agents.s2_5.agents.grounding import OSWorldACI
|
||||
grounding_agent = OSWorldACI(
|
||||
platform="linux",
|
||||
engine_params_for_generation=engine_params,
|
||||
engine_params_for_grounding=engine_params_for_grounding,
|
||||
width=args.screen_width,
|
||||
height=args.screen_height,
|
||||
)
|
||||
agent = AgentS2_5(
|
||||
engine_params,
|
||||
grounding_agent,
|
||||
platform="linux",
|
||||
)
|
||||
env = DesktopEnv(
|
||||
path_to_vm=args.path_to_vm,
|
||||
action_space=args.action_space,
|
||||
provider_name=args.provider_name,
|
||||
region=region,
|
||||
snapshot_name=snapshot_name,
|
||||
screen_size=(args.screen_width, args.screen_height),
|
||||
headless=args.headless,
|
||||
os_type = "Ubuntu",
|
||||
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
enable_proxy=True,
|
||||
client_password=getattr(args, 'client_password', ''),
|
||||
)
|
||||
active_environments.append(env)
|
||||
logger.info(f"Process {current_process().name} started.")
|
||||
while True:
|
||||
try:
|
||||
item = task_queue.get(timeout=5)
|
||||
except Exception:
|
||||
break
|
||||
domain, example_id = item
|
||||
try:
|
||||
config_file = os.path.join(
|
||||
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
|
||||
)
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
example = json.load(f)
|
||||
instruction = example["instruction"]
|
||||
example_result_dir = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model,
|
||||
domain,
|
||||
example_id,
|
||||
)
|
||||
os.makedirs(example_result_dir, exist_ok=True)
|
||||
logger.info(f"[{current_process().name}][Domain]: {domain}")
|
||||
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
|
||||
logger.info(f"[{current_process().name}][Instruction]: {instruction}")
|
||||
try:
|
||||
lib_run_single.run_single_example(
|
||||
agent,
|
||||
env,
|
||||
example,
|
||||
args.max_steps,
|
||||
instruction,
|
||||
args,
|
||||
example_result_dir,
|
||||
shared_scores,
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
try:
|
||||
env.controller.end_recording(
|
||||
os.path.join(example_result_dir, "recording.mp4")
|
||||
)
|
||||
except Exception as rec_e:
|
||||
logger.error(f"Failed to end recording: {rec_e}")
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{"Error": f"{domain}/{example_id} - {e}"}
|
||||
)
|
||||
)
|
||||
f.write("\n")
|
||||
except Exception as e:
|
||||
logger.error(f"Task-level error in {current_process().name}: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
except Exception as e:
|
||||
logger.error(f"Process-level error in {current_process().name}: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
finally:
|
||||
logger.info(f"{current_process().name} cleaning up environment...")
|
||||
try:
|
||||
if env:
|
||||
env.close()
|
||||
logger.info(f"{current_process().name} environment closed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"{current_process().name} error during environment cleanup: {e}")
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
global is_terminating, active_environments, processes
|
||||
if is_terminating:
|
||||
return
|
||||
is_terminating = True
|
||||
logger.info(f"Received signal {signum}. Gracefully shutting down...")
|
||||
for env in active_environments:
|
||||
try:
|
||||
logger.info(f"Closing environment...")
|
||||
env.close()
|
||||
logger.info(f"Environment closed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing environment: {e}")
|
||||
for p in processes:
|
||||
if p.is_alive():
|
||||
try:
|
||||
logger.info(f"Sending termination signal to process {p.name}...")
|
||||
p.terminate()
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending termination signal to process: {e}")
|
||||
time.sleep(1)
|
||||
for p in processes:
|
||||
if p.is_alive():
|
||||
try:
|
||||
logger.info(f"Forcefully terminating process {p.name}...")
|
||||
import signal as sig
|
||||
os.kill(p.pid, sig.SIGKILL)
|
||||
except Exception as e:
|
||||
logger.error(f"Error forcefully terminating process: {e}")
|
||||
logger.info("Shutdown complete. Exiting.")
|
||||
sys.exit(0)
|
||||
|
||||
def config() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run end-to-end evaluation on the benchmark"
|
||||
)
|
||||
|
||||
# environment config
|
||||
parser.add_argument("--path_to_vm", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--provider_name", type=str, default="vmware",
|
||||
help="Virtualization provider (vmware, docker, aws, azure, gcp, virtualbox)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--headless", action="store_true", help="Run in headless machine"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--action_space", type=str, default="pyautogui", help="Action type"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--observation_type",
|
||||
choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
default="screenshot",
|
||||
help="Observation type",
|
||||
)
|
||||
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
|
||||
parser.add_argument("--screen_width", type=int, default=1920)
|
||||
parser.add_argument("--screen_height", type=int, default=1080)
|
||||
parser.add_argument("--sleep_after_execution", type=float, default=1.0)
|
||||
parser.add_argument("--max_steps", type=int, default=15)
|
||||
|
||||
parser.add_argument("--domain", type=str, default="all")
|
||||
parser.add_argument(
|
||||
"--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_config_base_dir", type=str, default="evaluation_examples"
|
||||
)
|
||||
parser.add_argument("--result_dir", type=str, default="./results")
|
||||
|
||||
parser.add_argument(
|
||||
"--region", type=str, default="us-east-1", help="AWS region for the VM"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--client_password", type=str, default="", help="Client password"
|
||||
)
|
||||
|
||||
# agent config
|
||||
parser.add_argument("--max_trajectory_length", type=int, default=8)
|
||||
|
||||
|
||||
# lm config
|
||||
parser.add_argument("--model_provider", type=str, default="openai")
|
||||
parser.add_argument("--model", type=str, default="gpt-4o")
|
||||
parser.add_argument(
|
||||
"--model_url",
|
||||
type=str,
|
||||
default="",
|
||||
help="The URL of the main generation model API.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_api_key",
|
||||
type=str,
|
||||
default="",
|
||||
help="The API key of the main generation model.",
|
||||
)
|
||||
parser.add_argument("--model_temperature", type=float, default=None, help="Temperature to fix the generation model at (e.g. o3 can only be run with 1.0)")
|
||||
|
||||
# grounding model config
|
||||
parser.add_argument("--ground_provider", type=str, required=True, help="The provider for the grounding model")
|
||||
parser.add_argument("--ground_url", type=str, required=True, help="The URL of the grounding model")
|
||||
parser.add_argument(
|
||||
"--ground_api_key",
|
||||
type=str,
|
||||
default="",
|
||||
help="The API key of the grounding model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ground_model", type=str, required=True, help="The model name for the grounding model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grounding_width",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Width of screenshot image after processor rescaling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grounding_height",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Height of screenshot image after processor rescaling",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
global processes
|
||||
logger.info("Args: %s", args)
|
||||
all_tasks = distribute_tasks(test_all_meta)
|
||||
logger.info(f"Total tasks: {len(all_tasks)}")
|
||||
|
||||
engine_params = {
|
||||
"engine_type": args.model_provider,
|
||||
"model": args.model,
|
||||
"base_url": getattr(args, 'model_url', ''),
|
||||
"api_key": getattr(args, 'model_api_key', ''),
|
||||
"temperature": getattr(args, 'model_temperature', None),
|
||||
}
|
||||
engine_params_for_grounding = {
|
||||
"engine_type": args.ground_provider,
|
||||
"model": args.ground_model,
|
||||
"base_url": getattr(args, 'ground_url', ''),
|
||||
"api_key": getattr(args, 'ground_api_key', ''),
|
||||
"grounding_width": args.grounding_width,
|
||||
"grounding_height": args.grounding_height,
|
||||
}
|
||||
|
||||
with Manager() as manager:
|
||||
shared_scores = manager.list()
|
||||
task_queue = manager.Queue()
|
||||
for item in all_tasks:
|
||||
task_queue.put(item)
|
||||
num_envs = args.num_envs
|
||||
processes = []
|
||||
for i in range(num_envs):
|
||||
p = Process(
|
||||
target=run_env_tasks,
|
||||
args=(task_queue, args, shared_scores, engine_params, engine_params_for_grounding),
|
||||
name=f"EnvProcess-{i+1}"
|
||||
)
|
||||
p.daemon = True
|
||||
p.start()
|
||||
processes.append(p)
|
||||
logger.info(f"Started process {p.name} with PID {p.pid}")
|
||||
try:
|
||||
while True:
|
||||
alive_count = 0
|
||||
for idx, p in enumerate(processes):
|
||||
if not p.is_alive():
|
||||
logger.warning(f"Process {p.name} died, restarting...")
|
||||
new_p = Process(
|
||||
target=run_env_tasks,
|
||||
args=(task_queue, args, shared_scores, engine_params, engine_params_for_grounding),
|
||||
name=f"EnvProcess-Restart-{idx+1}"
|
||||
)
|
||||
new_p.daemon = True
|
||||
new_p.start()
|
||||
processes[idx] = new_p
|
||||
logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
|
||||
else:
|
||||
alive_count += 1
|
||||
if task_queue.empty():
|
||||
logger.info("All tasks finished.")
|
||||
break
|
||||
if alive_count == 0:
|
||||
logger.error("All processes died, exiting.")
|
||||
break
|
||||
time.sleep(5)
|
||||
for p in processes:
|
||||
p.join()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
|
||||
for p in processes:
|
||||
if p.is_alive():
|
||||
try:
|
||||
logger.info(f"Terminating process {p.name} due to error...")
|
||||
p.terminate()
|
||||
except Exception as term_e:
|
||||
logger.error(f"Error terminating process {p.name}: {term_e}")
|
||||
raise
|
||||
scores = list(shared_scores)
|
||||
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
|
||||
|
||||
|
||||
def get_unfinished(
|
||||
action_space, use_model, observation_type, result_dir, total_file_json
|
||||
):
|
||||
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
|
||||
|
||||
if not os.path.exists(target_dir):
|
||||
return total_file_json
|
||||
|
||||
finished = {}
|
||||
for domain in os.listdir(target_dir):
|
||||
finished[domain] = []
|
||||
domain_path = os.path.join(target_dir, domain)
|
||||
if os.path.isdir(domain_path):
|
||||
for example_id in os.listdir(domain_path):
|
||||
if example_id == "onboard":
|
||||
continue
|
||||
example_path = os.path.join(domain_path, example_id)
|
||||
if os.path.isdir(example_path):
|
||||
if "result.txt" not in os.listdir(example_path):
|
||||
# empty all files under example_id
|
||||
for file in os.listdir(example_path):
|
||||
os.remove(os.path.join(example_path, file))
|
||||
else:
|
||||
finished[domain].append(example_id)
|
||||
|
||||
if not finished:
|
||||
return total_file_json
|
||||
|
||||
for domain, examples in finished.items():
|
||||
if domain in total_file_json:
|
||||
total_file_json[domain] = [
|
||||
x for x in total_file_json[domain] if x not in examples
|
||||
]
|
||||
|
||||
return total_file_json
|
||||
|
||||
|
||||
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
|
||||
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
|
||||
if not os.path.exists(target_dir):
|
||||
print("New experiment, no result yet.")
|
||||
return None
|
||||
|
||||
all_result = []
|
||||
|
||||
for domain in os.listdir(target_dir):
|
||||
domain_path = os.path.join(target_dir, domain)
|
||||
if os.path.isdir(domain_path):
|
||||
for example_id in os.listdir(domain_path):
|
||||
example_path = os.path.join(domain_path, example_id)
|
||||
if os.path.isdir(example_path):
|
||||
if "result.txt" in os.listdir(example_path):
|
||||
# empty all files under example_id
|
||||
try:
|
||||
all_result.append(
|
||||
float(
|
||||
open(
|
||||
os.path.join(example_path, "result.txt"), "r"
|
||||
).read()
|
||||
)
|
||||
)
|
||||
except:
|
||||
all_result.append(0.0)
|
||||
|
||||
if not all_result:
|
||||
print("New experiment, no result yet.")
|
||||
return None
|
||||
else:
|
||||
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
|
||||
return all_result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
####### The complete version of the list of examples #######
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
args = config()
|
||||
|
||||
# save args to json in result_dir/action_space/observation_type/model/args.json
|
||||
path_to_args = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model,
|
||||
"args.json",
|
||||
)
|
||||
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
|
||||
with open(path_to_args, "w", encoding="utf-8") as f:
|
||||
json.dump(vars(args), f, indent=4)
|
||||
|
||||
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
|
||||
test_all_meta = json.load(f)
|
||||
|
||||
if args.domain != "all":
|
||||
test_all_meta = {args.domain: test_all_meta[args.domain]}
|
||||
|
||||
test_file_list = get_unfinished(
|
||||
args.action_space,
|
||||
args.model,
|
||||
args.observation_type,
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
left_info = ""
|
||||
for domain in test_file_list:
|
||||
left_info += f"{domain}: {len(test_file_list[domain])}\n"
|
||||
logger.info(f"Left tasks:\n{left_info}")
|
||||
|
||||
get_result(
|
||||
args.action_space,
|
||||
args.model,
|
||||
args.observation_type,
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
test(args, test_file_list)
|
||||
@@ -0,0 +1,402 @@
|
||||
"""Script to run end-to-end evaluation on the benchmark.
|
||||
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import lib_run_single_local
|
||||
from desktop_env.desktop_env import DesktopEnv
|
||||
from gui_agents.s2_5.agents.agent_s import AgentS2_5
|
||||
from gui_agents.s2_5.agents.grounding import OSWorldACI
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
# Almost deprecated since it's not multi-env, use run_multienv_*.py instead
|
||||
|
||||
# Logger Configs {{{ #
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
|
||||
file_handler = logging.FileHandler(
|
||||
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||
)
|
||||
debug_handler = logging.FileHandler(
|
||||
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||
)
|
||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||
sdebug_handler = logging.FileHandler(
|
||||
os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||
)
|
||||
|
||||
file_handler.setLevel(logging.INFO)
|
||||
debug_handler.setLevel(logging.DEBUG)
|
||||
stdout_handler.setLevel(logging.INFO)
|
||||
sdebug_handler.setLevel(logging.DEBUG)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
debug_handler.setFormatter(formatter)
|
||||
stdout_handler.setFormatter(formatter)
|
||||
sdebug_handler.setFormatter(formatter)
|
||||
|
||||
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
||||
sdebug_handler.addFilter(logging.Filter("desktopenv"))
|
||||
|
||||
logger.addHandler(file_handler)
|
||||
logger.addHandler(debug_handler)
|
||||
logger.addHandler(stdout_handler)
|
||||
logger.addHandler(sdebug_handler)
|
||||
# }}} Logger Configs #
|
||||
|
||||
logger = logging.getLogger("desktopenv.experiment")
|
||||
|
||||
|
||||
def config() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run end-to-end evaluation on the benchmark"
|
||||
)
|
||||
|
||||
# environment config
|
||||
parser.add_argument("--path_to_vm", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--provider_name", type=str, default="vmware",
|
||||
help="Virtualization provider (vmware, docker, aws, azure, gcp, virtualbox)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--headless", action="store_true", help="Run in headless machine"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--action_space", type=str, default="pyautogui", help="Action type"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--observation_type",
|
||||
choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
default="screenshot",
|
||||
help="Observation type",
|
||||
)
|
||||
parser.add_argument("--screen_width", type=int, default=1920)
|
||||
parser.add_argument("--screen_height", type=int, default=1080)
|
||||
parser.add_argument("--sleep_after_execution", type=float, default=3.0)
|
||||
parser.add_argument("--max_steps", type=int, default=15)
|
||||
|
||||
# agent config
|
||||
parser.add_argument("--max_trajectory_length", type=int, default=3)
|
||||
parser.add_argument(
|
||||
"--test_config_base_dir", type=str, default="evaluation_examples"
|
||||
)
|
||||
|
||||
# lm config
|
||||
parser.add_argument("--model", type=str, default="gpt-4o")
|
||||
parser.add_argument("--temperature", type=float, default=1.0)
|
||||
|
||||
# AgentS2 specific config
|
||||
parser.add_argument("--model_provider", type=str, default="openai")
|
||||
parser.add_argument(
|
||||
"--model_url",
|
||||
type=str,
|
||||
default="",
|
||||
help="The URL of the main generation model API.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_api_key",
|
||||
type=str,
|
||||
default="",
|
||||
help="The API key of the main generation model.",
|
||||
)
|
||||
parser.add_argument("--model_temperature", type=float, default=None, help="Temperature to fix the generation model at (e.g. o3 can only be run with 1.0)")
|
||||
|
||||
# grounding model config
|
||||
parser.add_argument("--ground_provider", type=str, required=True, help="The provider for the grounding model")
|
||||
parser.add_argument("--ground_url", type=str, required=True, help="The URL of the grounding model")
|
||||
parser.add_argument(
|
||||
"--ground_api_key",
|
||||
type=str,
|
||||
default="",
|
||||
help="The API key of the grounding model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ground_model", type=str, required=True, help="The model name for the grounding model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grounding_width",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Width of screenshot image after processor rescaling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grounding_height",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Height of screenshot image after processor rescaling",
|
||||
)
|
||||
|
||||
# example config
|
||||
parser.add_argument("--domain", type=str, default="all")
|
||||
parser.add_argument(
|
||||
"--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
|
||||
)
|
||||
|
||||
# logging related
|
||||
parser.add_argument("--result_dir", type=str, default="./results")
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
scores = []
|
||||
max_steps = args.max_steps
|
||||
|
||||
# log args
|
||||
logger.info("Args: %s", args)
|
||||
# set wandb project
|
||||
cfg_args = {
|
||||
"path_to_vm": args.path_to_vm,
|
||||
"provider_name": args.provider_name,
|
||||
"headless": args.headless,
|
||||
"action_space": args.action_space,
|
||||
"observation_type": args.observation_type,
|
||||
"screen_width": args.screen_width,
|
||||
"screen_height": args.screen_height,
|
||||
"sleep_after_execution": args.sleep_after_execution,
|
||||
"max_steps": args.max_steps,
|
||||
"max_trajectory_length": args.max_trajectory_length,
|
||||
"model": args.model,
|
||||
"temperature": args.temperature,
|
||||
"result_dir": args.result_dir,
|
||||
}
|
||||
|
||||
# AgentS2 configuration
|
||||
engine_params = {
|
||||
"engine_type": args.model_provider,
|
||||
"model": args.model,
|
||||
"base_url": getattr(args, 'model_url', ''),
|
||||
"api_key": getattr(args, 'model_api_key', ''),
|
||||
"temperature": getattr(args, 'model_temperature', None),
|
||||
}
|
||||
engine_params_for_grounding = {
|
||||
"engine_type": args.ground_provider,
|
||||
"model": args.ground_model,
|
||||
"base_url": getattr(args, 'ground_url', ''),
|
||||
"api_key": getattr(args, 'ground_api_key', ''),
|
||||
"grounding_width": args.grounding_width,
|
||||
"grounding_height": args.grounding_height,
|
||||
}
|
||||
|
||||
# Create grounding agent
|
||||
grounding_agent = OSWorldACI(
|
||||
platform="linux",
|
||||
engine_params_for_generation=engine_params,
|
||||
engine_params_for_grounding=engine_params_for_grounding,
|
||||
width=args.screen_width,
|
||||
height=args.screen_height,
|
||||
)
|
||||
|
||||
# Create AgentS2 worker
|
||||
agent = AgentS2_5(
|
||||
engine_params,
|
||||
grounding_agent,
|
||||
platform="linux",
|
||||
)
|
||||
|
||||
env = DesktopEnv(
|
||||
provider_name=args.provider_name,
|
||||
path_to_vm=args.path_to_vm,
|
||||
action_space=args.action_space,
|
||||
screen_size=(args.screen_width, args.screen_height),
|
||||
headless=args.headless,
|
||||
os_type = "Ubuntu",
|
||||
require_a11y_tree=args.observation_type
|
||||
in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
enable_proxy=True,
|
||||
snapshot_name="signed_in_state_1"
|
||||
)
|
||||
|
||||
for domain in tqdm(test_all_meta, desc="Domain"):
|
||||
for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False):
|
||||
config_file = os.path.join(
|
||||
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
|
||||
)
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
example = json.load(f)
|
||||
|
||||
logger.info(f"[Domain]: {domain}")
|
||||
logger.info(f"[Example ID]: {example_id}")
|
||||
|
||||
instruction = example["instruction"]
|
||||
|
||||
logger.info(f"[Instruction]: {instruction}")
|
||||
# wandb each example config settings
|
||||
cfg_args["instruction"] = instruction
|
||||
cfg_args["start_time"] = datetime.datetime.now().strftime(
|
||||
"%Y:%m:%d-%H:%M:%S"
|
||||
)
|
||||
# run.config.update(cfg_args)
|
||||
|
||||
example_result_dir = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model,
|
||||
domain,
|
||||
example_id,
|
||||
)
|
||||
os.makedirs(example_result_dir, exist_ok=True)
|
||||
# example start running
|
||||
try:
|
||||
lib_run_single_local.run_single_example(
|
||||
agent,
|
||||
env,
|
||||
example,
|
||||
max_steps,
|
||||
instruction,
|
||||
args,
|
||||
example_result_dir,
|
||||
scores,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Exception in {domain}/{example_id}: {e}")
|
||||
# Only attempt to end recording if controller exists (not Docker provider)
|
||||
if hasattr(env, 'controller') and env.controller is not None:
|
||||
env.controller.end_recording(
|
||||
os.path.join(example_result_dir, "recording.mp4")
|
||||
)
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{"Error": f"Time limit exceeded in {domain}/{example_id}"}
|
||||
)
|
||||
)
|
||||
f.write("\n")
|
||||
|
||||
env.close()
|
||||
logger.info(f"Average score: {sum(scores) / len(scores)}")
|
||||
|
||||
|
||||
def get_unfinished(
|
||||
action_space, use_model, observation_type, result_dir, total_file_json
|
||||
):
|
||||
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
|
||||
|
||||
if not os.path.exists(target_dir):
|
||||
return total_file_json
|
||||
|
||||
finished = {}
|
||||
for domain in os.listdir(target_dir):
|
||||
finished[domain] = []
|
||||
domain_path = os.path.join(target_dir, domain)
|
||||
if os.path.isdir(domain_path):
|
||||
for example_id in os.listdir(domain_path):
|
||||
if example_id == "onboard":
|
||||
continue
|
||||
example_path = os.path.join(domain_path, example_id)
|
||||
if os.path.isdir(example_path):
|
||||
if "result.txt" not in os.listdir(example_path):
|
||||
# empty all files under example_id
|
||||
for file in os.listdir(example_path):
|
||||
os.remove(os.path.join(example_path, file))
|
||||
else:
|
||||
finished[domain].append(example_id)
|
||||
|
||||
if not finished:
|
||||
return total_file_json
|
||||
|
||||
for domain, examples in finished.items():
|
||||
if domain in total_file_json:
|
||||
total_file_json[domain] = [
|
||||
x for x in total_file_json[domain] if x not in examples
|
||||
]
|
||||
|
||||
return total_file_json
|
||||
|
||||
|
||||
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
|
||||
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
|
||||
if not os.path.exists(target_dir):
|
||||
print("New experiment, no result yet.")
|
||||
return None
|
||||
|
||||
all_result = []
|
||||
|
||||
for domain in os.listdir(target_dir):
|
||||
domain_path = os.path.join(target_dir, domain)
|
||||
if os.path.isdir(domain_path):
|
||||
for example_id in os.listdir(domain_path):
|
||||
example_path = os.path.join(domain_path, example_id)
|
||||
if os.path.isdir(example_path):
|
||||
if "result.txt" in os.listdir(example_path):
|
||||
# empty all files under example_id
|
||||
try:
|
||||
all_result.append(
|
||||
float(
|
||||
open(
|
||||
os.path.join(example_path, "result.txt"), "r"
|
||||
).read()
|
||||
)
|
||||
)
|
||||
except:
|
||||
all_result.append(0.0)
|
||||
|
||||
if not all_result:
|
||||
print("New experiment, no result yet.")
|
||||
return None
|
||||
else:
|
||||
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
|
||||
return all_result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
####### The complete version of the list of examples #######
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
args = config()
|
||||
|
||||
# save args to json in result_dir/action_space/observation_type/model/args.json
|
||||
path_to_args = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model,
|
||||
"args.json",
|
||||
)
|
||||
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
|
||||
with open(path_to_args, "w", encoding="utf-8") as f:
|
||||
json.dump(vars(args), f, indent=4)
|
||||
|
||||
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
|
||||
test_all_meta = json.load(f)
|
||||
|
||||
if args.domain != "all":
|
||||
test_all_meta = {args.domain: test_all_meta[args.domain]}
|
||||
|
||||
test_file_list = get_unfinished(
|
||||
args.action_space,
|
||||
args.model,
|
||||
args.observation_type,
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
left_info = ""
|
||||
for domain in test_file_list:
|
||||
left_info += f"{domain}: {len(test_file_list[domain])}\n"
|
||||
logger.info(f"Left tasks:\n{left_info}")
|
||||
|
||||
get_result(
|
||||
args.action_space,
|
||||
args.model,
|
||||
args.observation_type,
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
test(args, test_file_list)
|
||||
-243
@@ -1,243 +0,0 @@
|
||||
import asyncio
|
||||
import os
|
||||
import platform
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from gui_agents.s1.core.AgentS import GraphSearchAgent
|
||||
import io
|
||||
import pyautogui
|
||||
import time
|
||||
from threading import Event, Lock
|
||||
|
||||
# Determine the operating system and select appropriate ACI
|
||||
current_platform = platform.system().lower()
|
||||
if current_platform == "linux":
|
||||
from gui_agents.s1.aci.LinuxOSACI import LinuxACI, UIElement
|
||||
|
||||
grounding_agent = LinuxACI()
|
||||
elif current_platform == "darwin":
|
||||
from gui_agents.s1.aci.MacOSACI import MacOSACI, UIElement
|
||||
|
||||
grounding_agent = MacOSACI()
|
||||
elif current_platform == "windows":
|
||||
from gui_agents.s1.aci.WindowsOSACI import WindowsACI, UIElement
|
||||
|
||||
grounding_agent = WindowsACI()
|
||||
else:
|
||||
raise ValueError(f"Unsupported operating system: {current_platform}")
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Add global lock and status tracking
|
||||
agent_lock = Lock()
|
||||
agent_status = {"is_running": False, "current_instruction": None, "start_time": None}
|
||||
|
||||
# Add a stop event
|
||||
stop_event = Event()
|
||||
|
||||
|
||||
class InstructionData(BaseModel):
|
||||
screenshot: str
|
||||
accessibility_tree: str
|
||||
|
||||
|
||||
class CommandRequest(BaseModel):
|
||||
obs: InstructionData
|
||||
instruction: str
|
||||
|
||||
|
||||
class RunRequest(BaseModel):
|
||||
model: str
|
||||
instruction: str
|
||||
api_key: str | None = None
|
||||
|
||||
|
||||
async def stream_code(code: str):
|
||||
for line in code.splitlines(keepends=True):
|
||||
yield line
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
def run_agent(agent: GraphSearchAgent, instruction: str):
|
||||
global stop_event
|
||||
stop_event.clear() # Reset the stop event
|
||||
obs = {}
|
||||
traj = "Task:\n" + instruction
|
||||
subtask_traj = ""
|
||||
for _ in range(15):
|
||||
# Check if stop was requested
|
||||
if stop_event.is_set():
|
||||
print("Agent execution stopped by user")
|
||||
return
|
||||
|
||||
print("iteration", _)
|
||||
|
||||
obs["accessibility_tree"] = UIElement.systemWideElement()
|
||||
|
||||
# Get screen shot using pyautogui.
|
||||
# Take a screenshot
|
||||
screenshot = pyautogui.screenshot()
|
||||
|
||||
# Save the screenshot to a BytesIO object
|
||||
buffered = io.BytesIO()
|
||||
screenshot.save(buffered, format="PNG")
|
||||
|
||||
# Get the byte value of the screenshot
|
||||
screenshot_bytes = buffered.getvalue()
|
||||
# Convert to base64 string.
|
||||
obs["screenshot"] = screenshot_bytes
|
||||
|
||||
# Get next action code from the agent
|
||||
info, code = agent.predict(instruction=instruction, observation=obs)
|
||||
|
||||
if "done" in code[0].lower() or "fail" in code[0].lower():
|
||||
if platform.system() == "Darwin":
|
||||
os.system(
|
||||
f'osascript -e \'display dialog "Task Completed" with title "OpenACI Agent" buttons "OK" default button "OK"\''
|
||||
)
|
||||
elif platform.system() == "Linux":
|
||||
os.system(
|
||||
f'zenity --info --title="OpenACI Agent" --text="Task Completed" --width=200 --height=100'
|
||||
)
|
||||
|
||||
agent.update_narrative_memory(traj)
|
||||
break
|
||||
|
||||
if "next" in code[0].lower():
|
||||
continue
|
||||
|
||||
if "wait" in code[0].lower():
|
||||
time.sleep(5)
|
||||
continue
|
||||
|
||||
else:
|
||||
time.sleep(1.0)
|
||||
print("EXECUTING CODE:", code[0])
|
||||
|
||||
# Ask for permission before executing
|
||||
exec(code[0])
|
||||
time.sleep(1.0)
|
||||
|
||||
# Update task and subtask trajectories and optionally the episodic memory
|
||||
traj += (
|
||||
"\n\nReflection:\n"
|
||||
+ str(info["reflection"])
|
||||
+ "\n\n----------------------\n\nPlan:\n"
|
||||
+ info["executor_plan"]
|
||||
)
|
||||
subtask_traj = agent.update_episodic_memory(info, subtask_traj)
|
||||
|
||||
|
||||
@app.post("/run")
|
||||
async def run(request: RunRequest):
|
||||
global agent_status
|
||||
|
||||
# Check if agent is already running
|
||||
if not agent_lock.acquire(blocking=False):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="An agent is already running. Use /status to check current run or /stop to stop it.",
|
||||
)
|
||||
|
||||
try:
|
||||
agent_status = {
|
||||
"is_running": True,
|
||||
"current_instruction": request.instruction,
|
||||
"start_time": time.time(),
|
||||
"model": request.model,
|
||||
}
|
||||
|
||||
if "gpt" in request.model:
|
||||
engine_type = "openai"
|
||||
elif "claude" in request.model:
|
||||
engine_type = "anthropic"
|
||||
|
||||
engine_params = {
|
||||
"engine_type": engine_type,
|
||||
"model": request.model,
|
||||
"api_key": request.api_key,
|
||||
}
|
||||
|
||||
print("engine_params", engine_params)
|
||||
|
||||
agent = GraphSearchAgent(
|
||||
engine_params,
|
||||
grounding_agent,
|
||||
platform=current_platform,
|
||||
action_space="pyautogui",
|
||||
observation_type="mixed",
|
||||
)
|
||||
|
||||
agent.reset()
|
||||
print("start the agent")
|
||||
run_agent(agent, request.instruction)
|
||||
|
||||
return {"status": "completed"}
|
||||
|
||||
finally:
|
||||
agent_status = {
|
||||
"is_running": False,
|
||||
"current_instruction": None,
|
||||
"start_time": None,
|
||||
}
|
||||
agent_lock.release()
|
||||
|
||||
|
||||
@app.get("/status")
|
||||
async def get_status():
|
||||
if agent_status["is_running"]:
|
||||
duration = time.time() - agent_status["start_time"]
|
||||
return {
|
||||
"status": "running",
|
||||
"instruction": agent_status["current_instruction"],
|
||||
"model": agent_status["model"],
|
||||
"running_for_seconds": round(duration, 2),
|
||||
}
|
||||
return {"status": "idle"}
|
||||
|
||||
|
||||
@app.post("/execute")
|
||||
async def execute_command_stream(cmd: CommandRequest):
|
||||
engine_params = {
|
||||
"engine_type": "openai",
|
||||
"model": "gpt-4o",
|
||||
}
|
||||
|
||||
agent = GraphSearchAgent(
|
||||
engine_params,
|
||||
grounding_agent,
|
||||
platform=current_platform,
|
||||
action_space="pyautogui",
|
||||
observation_type="mixed",
|
||||
)
|
||||
|
||||
obs = {
|
||||
"screenshot": cmd.obs.screenshot,
|
||||
"accessibility_tree": cmd.obs.accessibility_tree,
|
||||
}
|
||||
instruction = cmd.instruction
|
||||
info, code = agent.predict(instruction=instruction, observation=obs)
|
||||
|
||||
return StreamingResponse(stream_code(code), media_type="text/plain")
|
||||
|
||||
|
||||
@app.post("/stop")
|
||||
async def stop_agent():
|
||||
if not agent_status["is_running"]:
|
||||
raise HTTPException(status_code=404, detail="No agent is currently running")
|
||||
|
||||
global stop_event
|
||||
stop_event.set()
|
||||
return {"status": "stop signal sent"}
|
||||
|
||||
|
||||
import uvicorn
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
"server:app",
|
||||
host="0.0.0.0", # Allows external access
|
||||
port=8000, # Default port for FastAPI
|
||||
reload=True, # Auto-reload on code changes
|
||||
)
|
||||
+1
-2
@@ -35,8 +35,7 @@ setup(
|
||||
extras_require={"dev": ["black"]}, # Code formatter for linting
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"agent_s1=gui_agents.s1.cli_app:main",
|
||||
"agent_s2=gui_agents.s2.cli_app:main",
|
||||
"agent_s=gui_agents.s2_5.cli_app:main",
|
||||
],
|
||||
},
|
||||
classifiers=[
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gui_agents.s1.aci.ACI import ACI, _normalize_key
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def aci():
|
||||
return ACI(top_app_only=True, ocr=False)
|
||||
|
||||
|
||||
def test_normalize_key():
|
||||
"""Test key normalization"""
|
||||
assert _normalize_key("cmd") == "command"
|
||||
assert _normalize_key("ctrl") == "ctrl"
|
||||
assert _normalize_key("shift") == "shift"
|
||||
|
||||
|
||||
def test_hotkey_cmd_normalization(aci):
|
||||
"""Test cmd normalization in hotkey command"""
|
||||
command = aci.hotkey(["cmd", "c"])
|
||||
assert "command" in command
|
||||
assert "cmd" not in command
|
||||
|
||||
|
||||
def test_click_with_cmd_key(aci):
|
||||
"""Test cmd normalization in click command"""
|
||||
aci.nodes = [{"position": (100, 200), "size": (50, 50)}]
|
||||
command = aci.click(0, hold_keys=["cmd"])
|
||||
assert "command" in command
|
||||
assert "cmd" not in command
|
||||
|
||||
|
||||
def test_type_with_overwrite(aci):
|
||||
"""Test type command with overwrite"""
|
||||
aci.nodes = [{"position": (100, 200), "size": (50, 50)}]
|
||||
command = aci.type(0, "test", overwrite=True)
|
||||
assert "command" in command or "ctrl" in command
|
||||
assert "backspace" in command
|
||||
@@ -1,25 +0,0 @@
|
||||
import time
|
||||
|
||||
import pyautogui
|
||||
from AppKit import NSWorkspace
|
||||
|
||||
from gui_agents.s1.aci.MacOSACI import MacOSACI
|
||||
|
||||
agent = MacOSACI()
|
||||
|
||||
|
||||
def test_app_switching():
|
||||
app_or_file_name = "Safari"
|
||||
|
||||
exec(agent.switch_applications(app_or_file_name))
|
||||
|
||||
# Checking the frontmost application
|
||||
frontmost_app = NSWorkspace.sharedWorkspace().frontmostApplication().localizedName()
|
||||
print(frontmost_app)
|
||||
|
||||
# Assert to confirm Safari is the frontmost application
|
||||
assert frontmost_app == "Safari", f"Expected Safari, but got {frontmost_app}"
|
||||
|
||||
|
||||
# Run the test
|
||||
test_app_switching()
|
||||
@@ -1,9 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from gui_agents.aci.UIElementBase import UIElementBase
|
||||
|
||||
|
||||
def test_uielement_base_is_abstract():
|
||||
"""Test that UIElementBase cannot be instantiated directly"""
|
||||
with pytest.raises(TypeError):
|
||||
UIElementBase()
|
||||
@@ -1,38 +0,0 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pyatspi
|
||||
import pytest
|
||||
|
||||
from gui_agents.aci.UIElementLinux import UIElement
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_accessible():
|
||||
mock = Mock()
|
||||
mock.name = "Test Window"
|
||||
mock.getRole.return_value = pyatspi.ROLE_WINDOW
|
||||
mock.getState.return_value.contains.return_value = True
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ui_element(mock_accessible):
|
||||
return UIElement(mock_accessible)
|
||||
|
||||
|
||||
def test_role(ui_element, mock_accessible):
|
||||
"""Test role retrieval"""
|
||||
mock_accessible.getRoleName.return_value = "window"
|
||||
assert ui_element.role() == "window"
|
||||
|
||||
|
||||
def test_position(ui_element, mock_accessible):
|
||||
"""Test position retrieval"""
|
||||
mock_accessible.getPosition.return_value = (100, 200)
|
||||
assert ui_element.position() == (100, 200)
|
||||
|
||||
|
||||
def test_size(ui_element, mock_accessible):
|
||||
"""Test size retrieval"""
|
||||
mock_accessible.getSize.return_value = (300, 400)
|
||||
assert ui_element.size() == (300, 400)
|
||||
@@ -1,54 +0,0 @@
|
||||
from typing import Dict
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gui_agents.s1.aci.MacOSACI import UIElement
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ax_element():
|
||||
mock_element = Mock()
|
||||
mock_element.__repr__ = lambda x: "x:100 y:200"
|
||||
return mock_element
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_size_element():
|
||||
mock_element = Mock()
|
||||
mock_element.__repr__ = lambda x: "w:300 h:400"
|
||||
return mock_element
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ui_element(mock_ax_element):
|
||||
element = UIElement(mock_ax_element)
|
||||
return element
|
||||
|
||||
|
||||
def test_position_parsing(ui_element, mock_ax_element):
|
||||
"""Test position parsing from AX element"""
|
||||
with patch.object(ui_element, "attribute", return_value=mock_ax_element):
|
||||
pos = ui_element.position()
|
||||
assert pos == (100.0, 200.0)
|
||||
|
||||
|
||||
def test_size_parsing(ui_element, mock_size_element):
|
||||
"""Test size parsing from AX element"""
|
||||
with patch.object(ui_element, "attribute", return_value=mock_size_element):
|
||||
size = ui_element.size()
|
||||
assert size == (300.0, 400.0)
|
||||
|
||||
|
||||
def test_get_current_applications(obs: Dict):
|
||||
"""Test getting list of current applications"""
|
||||
with patch("AppKit.NSWorkspace") as mock_workspace:
|
||||
mock_app = Mock()
|
||||
mock_app.activationPolicy.return_value = 0
|
||||
mock_app.localizedName.return_value = "TestApp"
|
||||
mock_workspace.sharedWorkspace.return_value.runningApplications.return_value = [
|
||||
mock_app
|
||||
]
|
||||
|
||||
apps = UIElement.get_current_applications(obs)
|
||||
assert apps == ["TestApp"]
|
||||
@@ -1,46 +0,0 @@
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
import pytest
|
||||
|
||||
from gui_agents.aci.UIElementOSWorld import UIElement
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_xml():
|
||||
return """
|
||||
<root>
|
||||
<application name="TestApp">
|
||||
<window uri:deskat:state.at-spi.gnome.org:active="true">
|
||||
<button uri:deskat:component.at-spi.gnome.org:screencoord="(100,200)"
|
||||
uri:deskat:component.at-spi.gnome.org:size="(300,400)">
|
||||
Click me
|
||||
</button>
|
||||
</window>
|
||||
</application>
|
||||
</root>
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ui_element(sample_xml):
|
||||
tree = ET.ElementTree(ET.fromstring(sample_xml))
|
||||
return UIElement(tree.getroot())
|
||||
|
||||
|
||||
def test_nodeFromTree(sample_xml):
|
||||
"""Test creating UIElement from XML string"""
|
||||
element = UIElement.nodeFromTree(sample_xml)
|
||||
assert element is not None
|
||||
assert isinstance(element, UIElement)
|
||||
|
||||
|
||||
def test_position(ui_element):
|
||||
"""Test position extraction from XML"""
|
||||
button = ui_element.children()[0].children()[0]
|
||||
assert button.position() == (100, 200)
|
||||
|
||||
|
||||
def test_size(ui_element):
|
||||
"""Test size extraction from XML"""
|
||||
button = ui_element.children()[0].children()[0]
|
||||
assert button.size() == (300, 400)
|
||||
Referência em uma Nova Issue
Bloquear um usuário