s3 🧠🤓🤯
This commit is contained in:
+25
-39
@@ -3,6 +3,12 @@
|
|||||||
<small>Use Computer Like a Human</small>
|
<small>Use Computer Like a Human</small>
|
||||||
</h1>
|
</h1>
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
🌐 <a href="https://www.simular.ai/articles/agent-s3">[S3 blog]</a>
|
||||||
|
📄 <a href="https://arxiv.org/abs/2510.02250">[S3 Paper]</a>
|
||||||
|
🎥 <a href="https://www.youtube.com/watch?v=VHr0a3UBsh4">[S3 Video]</a>
|
||||||
|
</p>
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
🌐 <a href="https://www.simular.ai/articles/agent-s2-technical-review">[S2 blog]</a>
|
🌐 <a href="https://www.simular.ai/articles/agent-s2-technical-review">[S2 blog]</a>
|
||||||
📄 <a href="https://arxiv.org/abs/2504.00906">[S2 Paper (COLM 2025)]</a>
|
📄 <a href="https://arxiv.org/abs/2504.00906">[S2 Paper (COLM 2025)]</a>
|
||||||
@@ -50,6 +56,7 @@
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
## 🥳 Updates
|
## 🥳 Updates
|
||||||
|
- [x] **2025/10/02**: Released the [Agent S3 paper](https://arxiv.org/abs/2510.02250), setting a new SOTA of **69.9%** on OSWorld, with strong performance on WindowsAgentArena and AndroidWorld!
|
||||||
- [x] **2025/08/01**: Agent S2.5 is released (gui-agents v0.2.5): simpler, better, and faster! New SOTA on [OSWorld-Verified](https://os-world.github.io)!
|
- [x] **2025/08/01**: Agent S2.5 is released (gui-agents v0.2.5): simpler, better, and faster! New SOTA on [OSWorld-Verified](https://os-world.github.io)!
|
||||||
- [x] **2025/07/07**: The [Agent S2 paper](https://arxiv.org/abs/2504.00906) is accepted to COLM 2025! See you in Montreal!
|
- [x] **2025/07/07**: The [Agent S2 paper](https://arxiv.org/abs/2504.00906) is accepted to COLM 2025! See you in Montreal!
|
||||||
- [x] **2025/04/27**: The Agent S paper won the Best Paper Award 🏆 at ICLR 2025 Agentic AI for Science Workshop!
|
- [x] **2025/04/27**: The Agent S paper won the Best Paper Award 🏆 at ICLR 2025 Agentic AI for Science Workshop!
|
||||||
@@ -77,36 +84,13 @@ Whether you're interested in AI, automation, or contributing to cutting-edge age
|
|||||||
|
|
||||||
## 🎯 Current Results
|
## 🎯 Current Results
|
||||||
|
|
||||||
<div align="center">
|
<p align="center">
|
||||||
<table border="0" cellspacing="0" cellpadding="5">
|
<img src="images/s3_results.png" alt="Agent S3 Results" width="700"/>
|
||||||
<tr>
|
</p>
|
||||||
<th>Benchmark</th>
|
|
||||||
<th>Agent S2.5</th>
|
|
||||||
<th>Previous SOTA</th>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td>OSWorld Verified (100 step)</td>
|
|
||||||
<td><b>56.0%</b></td>
|
|
||||||
<td>53.1%</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td>OSWorld Verified (50 step)</td>
|
|
||||||
<td><b>54.2%</b></td>
|
|
||||||
<td>50.6%</td>
|
|
||||||
</tr>
|
|
||||||
<!-- <tr>
|
|
||||||
<td>WindowsAgentArena</td>
|
|
||||||
<td>29.8%</td>
|
|
||||||
<td>19.5% (NAVI)</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td>AndroidWorld</td>
|
|
||||||
<td>54.3%</td>
|
|
||||||
<td>46.8% (UI-TARS)</td>
|
|
||||||
</tr> -->
|
|
||||||
</table>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
|
On OSWorld, Agent S3 alone reaches 62.6% in the 100-step setting, already exceeding the previous state of the art of 61.4% (Claude Sonnet 4.5). With the addition of Behavior Best-of-N, performance climbs even higher to 69.9%, bringing computer-use agents to within just a few points of human-level accuracy (72%).
|
||||||
|
|
||||||
|
Agent S3 also demonstrates strong zero-shot generalization. On WindowsAgentArena, accuracy rises from 50.2% using only Agent S3 to 56.6% by selecting from 3 rollouts. Similarly, on AndroidWorld, performance improves from 68.1% to 71.6%.
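The "selecting from 3 rollouts" step can be pictured with a generic best-of-N sketch. This is not the Behavior Best-of-N implementation from the paper: `run_rollout` and `score_rollout` are hypothetical stand-ins for running the agent once and for judging a finished trajectory, and the toy scoring rule is only there to make the example runnable.

```python
from typing import Callable, List, Tuple


def best_of_n(
    run_rollout: Callable[[int], str],
    score_rollout: Callable[[str], float],
    n: int = 3,
) -> Tuple[str, float]:
    """Run n independent rollouts and keep the highest-scoring one."""
    if n < 1:
        raise ValueError("n must be at least 1")
    rollouts: List[str] = [run_rollout(seed) for seed in range(n)]
    scored = [(rollout, score_rollout(rollout)) for rollout in rollouts]
    # max() returns the (rollout, score) pair with the largest score.
    return max(scored, key=lambda pair: pair[1])


# Toy stand-ins (hypothetical): a rollout is a string, its score is its length.
def toy_rollout(seed: int) -> str:
    return "step;" * (seed + 1)


best, score = best_of_n(toy_rollout, lambda r: float(len(r)), n=3)
print(best, score)
```

In a real setup the scoring function would be some judge over complete trajectories. The point of the sketch is only that the single-run agent is unchanged and the gain comes from choosing among several runs.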
|
||||||
|
|
||||||
## 🛠️ Installation & Setup
|
## 🛠️ Installation & Setup
|
||||||
|
|
||||||
@@ -117,11 +101,11 @@ Whether you're interested in AI, automation, or contributing to cutting-edge age
|
|||||||
|
|
||||||
|
|
||||||
### Installation
|
### Installation
|
||||||
To install Agent S2.5 without cloning the repository, run
|
To install Agent S3 without cloning the repository, run
|
||||||
```bash
|
```bash
|
||||||
pip install gui-agents
|
pip install gui-agents
|
||||||
```
|
```
|
||||||
If you would like to test Agent S2.5 while making changes, clone the repository and install using
|
If you would like to test Agent S3 while making changes, clone the repository and install using
|
||||||
```
|
```
|
||||||
pip install -e .
|
pip install -e .
|
||||||
```
|
```
|
||||||
@@ -157,7 +141,9 @@ For optimal performance, we recommend [UI-TARS-1.5-7B](https://huggingface.co/By
|
|||||||
|
|
||||||
### CLI
|
### CLI
|
||||||
|
|
||||||
Run Agent S2.5 with the required parameters:
|
Note: this runs Agent S3, our improved agent, without Behavior Best-of-N (bBoN).
|
||||||
|
|
||||||
|
Run Agent S3 with the required parameters:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
agent_s \
|
agent_s \
|
||||||
@@ -196,12 +182,12 @@ The grounding width and height should match the output coordinate resolution of
|
|||||||
|
|
||||||
### `gui_agents` SDK
|
### `gui_agents` SDK
|
||||||
|
|
||||||
First, we import the necessary modules. `AgentS2_5` is the main agent class for Agent S2.5. `OSWorldACI` is our grounding agent that translates agent actions into executable python code.
|
First, we import the necessary modules. `AgentS3` is the main agent class for Agent S3. `OSWorldACI` is our grounding agent that translates agent actions into executable Python code.
|
||||||
```python
|
```python
|
||||||
import pyautogui
|
import pyautogui
|
||||||
import io
|
import io
|
||||||
from gui_agents.s2_5.agents.agent_s import AgentS2_5
|
from gui_agents.s3.agents.agent_s import AgentS3
|
||||||
from gui_agents.s2_5.agents.grounding import OSWorldACI
|
from gui_agents.s3.agents.grounding import OSWorldACI
|
||||||
|
|
||||||
# Load in your API keys.
|
# Load in your API keys.
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
@@ -243,7 +229,7 @@ engine_params_for_grounding = {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Then, we define our grounding agent and Agent S2.5.
|
Then, we define our grounding agent and Agent S3.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
grounding_agent = OSWorldACI(
|
grounding_agent = OSWorldACI(
|
||||||
@@ -254,7 +240,7 @@ grounding_agent = OSWorldACI(
|
|||||||
height=1080 # Optional: screen height
|
height=1080 # Optional: screen height
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = AgentS2_5(
|
agent = AgentS3(
|
||||||
engine_params,
|
engine_params,
|
||||||
grounding_agent,
|
grounding_agent,
|
||||||
platform=current_platform,
|
platform=current_platform,
|
||||||
@@ -282,11 +268,11 @@ info, action = agent.predict(instruction=instruction, observation=obs)
|
|||||||
exec(action[0])
|
exec(action[0])
|
||||||
```
|
```
|
||||||
|
|
||||||
Refer to `gui_agents/s2_5/cli_app.py` for more details on how the inference loop works.
|
Refer to `gui_agents/s3/cli_app.py` for more details on how the inference loop works.
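As a rough illustration of what such an inference loop looks like, here is a minimal sketch built around the `predict()` call shown above. It is not the code in `cli_app.py`: `StubAgent`, `run_loop`, the `DONE`/`FAIL` stop check, and the `max_steps` default are all assumptions made so the example runs without a screen or an API key.

```python
from typing import Callable, Dict, List, Tuple


class StubAgent:
    """Hypothetical stand-in exposing the same predict() shape as the agent."""

    def __init__(self, scripted_actions: List[str]):
        self._actions = list(scripted_actions)

    def predict(self, instruction: str, observation: Dict) -> Tuple[Dict, List[str]]:
        # Replay scripted actions, then signal completion.
        action = self._actions.pop(0) if self._actions else "DONE"
        return {"instruction": instruction}, [action]


def run_loop(
    agent,
    instruction: str,
    get_observation: Callable[[], Dict],
    execute: Callable[[str], None],
    max_steps: int = 15,
) -> int:
    """Call predict() until the agent signals DONE/FAIL or max_steps is hit."""
    for step in range(max_steps):
        obs = get_observation()
        _info, actions = agent.predict(instruction=instruction, observation=obs)
        if actions[0].strip().upper() in ("DONE", "FAIL"):
            return step  # number of actions executed before stopping
        execute(actions[0])
    return max_steps


executed: List[str] = []
steps = run_loop(
    StubAgent(["click(1, 2)", "type('hi')"]),
    "Open the calculator",
    get_observation=lambda: {"screenshot": b""},
    execute=executed.append,  # a real loop would run the action instead
)
print(steps, executed)
```

Passing `execute` as a callable keeps the loop testable: the snippet above records actions in a list, whereas the README example runs them with `exec(action[0])`.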
|
||||||
|
|
||||||
### OSWorld
|
### OSWorld
|
||||||
|
|
||||||
To deploy Agent S2.5 in OSWorld, follow the [OSWorld Deployment instructions](osworld_setup/s2_5/OSWorld.md).
|
To deploy Agent S3 in OSWorld, follow the [OSWorld Deployment instructions](osworld_setup/s3/OSWorld.md).
|
||||||
|
|
||||||
## 💬 Citations
|
## 💬 Citations
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,102 @@
|
|||||||
|
import logging
|
||||||
|
import platform
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
from gui_agents.s3.agents.grounding import ACI
|
||||||
|
from gui_agents.s3.agents.worker import Worker
|
||||||
|
|
||||||
|
logger = logging.getLogger("desktopenv.agent")
|
||||||
|
|
||||||
|
|
||||||
|
class UIAgent:
|
||||||
|
"""Base class for UI automation agents"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
worker_engine_params: Dict,
|
||||||
|
grounding_agent: ACI,
|
||||||
|
platform: str = platform.system().lower(),
|
||||||
|
):
|
||||||
|
"""Initialize UIAgent
|
||||||
|
|
||||||
|
Args:
|
||||||
|
worker_engine_params: Configuration parameters for the worker LLM agent
|
||||||
|
grounding_agent: Instance of ACI class for UI interaction
|
||||||
|
platform: Operating system platform (darwin, linux, windows)
|
||||||
|
"""
|
||||||
|
self.worker_engine_params = worker_engine_params
|
||||||
|
self.grounding_agent = grounding_agent
|
||||||
|
self.platform = platform
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset agent state"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def predict(self, instruction: str, observation: Dict) -> Tuple[Dict, List[str]]:
|
||||||
|
"""Generate next action prediction
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instruction: Natural language instruction
|
||||||
|
observation: Current UI state observation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple containing agent info dictionary and list of actions
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AgentS3(UIAgent):
|
||||||
|
"""Agent that uses no hierarchy for less inference time"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
worker_engine_params: Dict,
|
||||||
|
grounding_agent: ACI,
|
||||||
|
platform: str = platform.system().lower(),
|
||||||
|
max_trajectory_length: int = 8,
|
||||||
|
enable_reflection: bool = True,
|
||||||
|
):
|
||||||
|
"""Initialize a minimalist AgentS3 without hierarchy
|
||||||
|
|
||||||
|
Args:
|
||||||
|
worker_engine_params: Configuration parameters for the worker agent.
|
||||||
|
grounding_agent: Instance of ACI class for UI interaction
|
||||||
|
platform: Operating system platform (darwin, linux, windows)
|
||||||
|
max_trajectory_length: Maximum number of image turns to keep
|
||||||
|
enable_reflection: Creates a reflection agent to assist the worker agent
|
||||||
|
"""
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
worker_engine_params, grounding_agent, platform
|
||||||
|
)
|
||||||
|
self.max_trajectory_length = max_trajectory_length
|
||||||
|
self.enable_reflection = enable_reflection
|
||||||
|
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset agent state and initialize components"""
|
||||||
|
self.executor = Worker(
|
||||||
|
worker_engine_params=self.worker_engine_params,
|
||||||
|
grounding_agent=self.grounding_agent,
|
||||||
|
platform=self.platform,
|
||||||
|
max_trajectory_length=self.max_trajectory_length,
|
||||||
|
enable_reflection=self.enable_reflection,
|
||||||
|
)
|
||||||
|
|
||||||
|
def predict(self, instruction: str, observation: Dict) -> Tuple[Dict, List[str]]:
|
||||||
|
# Ask the worker for its info dictionary and the next actions
|
||||||
|
executor_info, actions = self.executor.generate_next_action(
|
||||||
|
instruction=instruction, obs=observation
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copy the worker's info dictionary into a fresh dict (treating None as empty)
|
||||||
|
info = {
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for d in [executor_info or {}]
|
||||||
|
for k, v in d.items()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return info, actions
|
||||||
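The nested comprehension in `AgentS3.predict` is a leftover from merging several info dictionaries. With a single source it reduces to a shallow copy. The sketch below shows that equivalence with a small helper. `merge_info` and the sample keys `plan` and `reflection` are hypothetical names used only for illustration.

```python
from typing import Dict, Optional


def merge_info(*parts: Optional[Dict]) -> Dict:
    """Merge info dictionaries left to right, skipping None; later keys win."""
    merged: Dict = {}
    for part in parts:
        merged.update(part or {})
    return merged


# Hypothetical worker output, standing in for what generate_next_action returns.
executor_info = {"plan": "click OK", "reflection": None}

# The form used in predict(): one source dictionary, unpacked into a new dict.
original_style = {**{k: v for d in [executor_info or {}] for k, v in d.items()}}

print(original_style == merge_info(executor_info))
print(merge_info(None))
```

Both forms return a new dictionary, so callers can mutate `info` without touching the worker's own state.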
@@ -0,0 +1,278 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Dict, List, Tuple, Optional
|
||||||
|
|
||||||
|
from gui_agents.s3.memory.procedural_memory import PROCEDURAL_MEMORY
|
||||||
|
from gui_agents.s3.utils.common_utils import call_llm_safe, split_thinking_response
|
||||||
|
from gui_agents.s3.core.mllm import LMMAgent
|
||||||
|
|
||||||
|
logger = logging.getLogger("desktopenv.agent")
|
||||||
|
|
||||||
|
def extract_code_block(action: str) -> Tuple[Optional[str], Optional[str]]:
|
||||||
|
"""Extract code and determine type from action string."""
|
||||||
|
if "```python" in action:
|
||||||
|
code_type = "python"
|
||||||
|
code = action.split("```python")[1].split("```")[0].strip()
|
||||||
|
elif "```bash" in action:
|
||||||
|
code_type = "bash"
|
||||||
|
code = action.split("```bash")[1].split("```")[0].strip()
|
||||||
|
elif "```" in action:
|
||||||
|
code_type = None
|
||||||
|
code = action.split("```")[1].split("```")[0].strip()
|
||||||
|
else:
|
||||||
|
code_type = None
|
||||||
|
code = None
|
||||||
|
|
||||||
|
logger.debug(f"Extracted code block: type={code_type}, length={len(code) if code else 0}")
|
||||||
|
return code_type, code
|
||||||
|
|
||||||
|
|
||||||
|
def execute_code(code_type: str, code: str, env_controller) -> Dict:
|
||||||
|
"""Execute code based on its type."""
|
||||||
|
# Log the full code being executed (untruncated)
|
||||||
|
logger.info(f"CODING_AGENT_CODE_EXECUTION - Type: {code_type}\nCode:\n{code}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if code_type == "bash":
|
||||||
|
result = env_controller.run_bash_script(code, timeout=30)
|
||||||
|
elif code_type == "python":
|
||||||
|
result = env_controller.run_python_script(code)
|
||||||
|
else:
|
||||||
|
result = {"status": "error", "error": f"Unknown code type: {code_type}"}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error executing {code_type} code: {e}")
|
||||||
|
return {"status": "error", "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
def format_result(result: Dict, step_count: int) -> str:
|
||||||
|
"""Format execution result into context string."""
|
||||||
|
if not result:
|
||||||
|
logger.warning(f"Step {step_count + 1}: No result returned from execution")
|
||||||
|
return f"""
|
||||||
|
Step {step_count + 1} Error:
|
||||||
|
Error: No result returned from execution
|
||||||
|
"""
|
||||||
|
|
||||||
|
status = result.get('status', 'unknown')
|
||||||
|
return_code = result.get('returncode', result.get('return_code', -1))
|
||||||
|
|
||||||
|
# Handle different response structures for bash vs python
|
||||||
|
if 'returncode' in result:
|
||||||
|
# Bash script response
|
||||||
|
output = result.get('output', '') # Contains both stdout and stderr merged
|
||||||
|
error = result.get('error', '') # Always empty for bash
|
||||||
|
else:
|
||||||
|
# Python script response
|
||||||
|
output = result.get('output', '') # stdout only
|
||||||
|
error = result.get('error', '') # stderr only
|
||||||
|
|
||||||
|
logger.debug(f"Step {step_count + 1}: Status={status}, Return Code={return_code}")
|
||||||
|
|
||||||
|
# Format with better structure for multi-line outputs
|
||||||
|
result_text = f"Step {step_count + 1} Result:\n"
|
||||||
|
result_text += f"Status: {status}\n"
|
||||||
|
result_text += f"Return Code: {return_code}\n"
|
||||||
|
|
||||||
|
if output:
|
||||||
|
result_text += f"Output:\n{output}\n"
|
||||||
|
|
||||||
|
if error:
|
||||||
|
result_text += f"Error:\n{error}\n"
|
||||||
|
|
||||||
|
return result_text
|
||||||
|
|
||||||
|
|
||||||
|
class CodeAgent:
|
||||||
|
"""A dedicated agent for executing code with a budget of steps."""
|
||||||
|
|
||||||
|
def __init__(self, engine_params: Dict, budget: int = 20):
|
||||||
|
"""Initialize the CodeAgent."""
|
||||||
|
if not engine_params:
|
||||||
|
raise ValueError("engine_params cannot be None or empty")
|
||||||
|
|
||||||
|
self.engine_params = engine_params
|
||||||
|
self.budget = budget
|
||||||
|
self.agent = None
|
||||||
|
|
||||||
|
logger.info(f"CodeAgent initialized with budget={budget}")
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Reset the code agent state."""
|
||||||
|
logger.debug("Resetting CodeAgent state")
|
||||||
|
self.agent = LMMAgent(
|
||||||
|
engine_params=self.engine_params,
|
||||||
|
system_prompt=PROCEDURAL_MEMORY.CODE_AGENT_PROMPT
|
||||||
|
)
|
||||||
|
|
||||||
|
def execute(self, task_instruction: str, screenshot: str, env_controller) -> Dict:
|
||||||
|
"""Execute code for the given task with a budget of steps."""
|
||||||
|
logger.info(f"Starting code execution for task: {task_instruction}")
|
||||||
|
logger.info(f"Budget: {self.budget} steps")
|
||||||
|
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
# Add initial task instruction and screenshot context as user message
|
||||||
|
context = f"Task: {task_instruction}\n\nCurrent screenshot is provided for context."
|
||||||
|
self.agent.add_message(context, image_content=screenshot, role="user")
|
||||||
|
|
||||||
|
step_count = 0
|
||||||
|
execution_history = []
|
||||||
|
|
||||||
|
while step_count < self.budget:
|
||||||
|
logger.info(f"Step {step_count + 1}/{self.budget}")
|
||||||
|
|
||||||
|
# Get assistant response (thoughts and code)
|
||||||
|
response = call_llm_safe(self.agent, temperature=1)
|
||||||
|
|
||||||
|
# Log the latest message from the coding agent (untruncated)
|
||||||
|
logger.info(f"CODING_AGENT_LATEST_MESSAGE - Step {step_count + 1}:\n{response}")
|
||||||
|
|
||||||
|
# Check if response is None or empty
|
||||||
|
if not response or response.strip() == "":
|
||||||
|
error_msg = f"Step {step_count + 1}: LLM returned empty response"
|
||||||
|
logger.error(error_msg)
|
||||||
|
raise RuntimeError(error_msg)
|
||||||
|
|
||||||
|
# Parse the response to extract action
|
||||||
|
action, thoughts = split_thinking_response(response)
|
||||||
|
|
||||||
|
execution_history.append({
|
||||||
|
"step": step_count + 1,
|
||||||
|
"action": action,
|
||||||
|
"thoughts": thoughts
|
||||||
|
})
|
||||||
|
|
||||||
|
# Check for completion signals
|
||||||
|
action_upper = action.upper().strip()
|
||||||
|
if action_upper == "DONE":
|
||||||
|
logger.info(f"Step {step_count + 1}: Task completed successfully")
|
||||||
|
completion_reason = "DONE"
|
||||||
|
break
|
||||||
|
elif action_upper == "FAIL":
|
||||||
|
logger.info(f"Step {step_count + 1}: Task failed by agent request")
|
||||||
|
completion_reason = "FAIL"
|
||||||
|
break
|
||||||
|
|
||||||
|
# Extract and execute code
|
||||||
|
code_type, code = extract_code_block(action)
|
||||||
|
|
||||||
|
if code:
|
||||||
|
result = execute_code(code_type, code, env_controller)
|
||||||
|
# Prepare formatted output and error for logging
|
||||||
|
output = result.get("output", "")
|
||||||
|
error = result.get("error", "")
|
||||||
|
message = result.get("message", "")
|
||||||
|
status = result.get("status", "")
|
||||||
|
|
||||||
|
log_lines = [
|
||||||
|
f"CODING_AGENT_EXECUTION_RESULT - Step {step_count + 1}:",
|
||||||
|
f"Status: {status}" if status else None,
|
||||||
|
]
|
||||||
|
|
||||||
|
if output:
|
||||||
|
log_lines.append("Output:\n" + ("-" * 40) + f"\n{output}\n" + ("-" * 40))
|
||||||
|
if error:
|
||||||
|
log_lines.append("Error:\n" + ("!" * 40) + f"\n{error}\n" + ("!" * 40))
|
||||||
|
if message and not output and not error:
|
||||||
|
log_lines.append("Message:\n" + ("-" * 40) + f"\n{message}\n" + ("-" * 40))
|
||||||
|
|
||||||
|
# Remove None entries and join
|
||||||
|
formatted_log = "\n".join([line for line in log_lines if line])
|
||||||
|
logger.info(formatted_log)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Step {step_count + 1}: No code block found in action")
|
||||||
|
result = {"status": "skipped", "message": "No code block found"}
|
||||||
|
logger.info(
|
||||||
|
f"CODING_AGENT_EXECUTION_RESULT - Step {step_count + 1}:\n"
|
||||||
|
f"Status: skipped\n"
|
||||||
|
f"Message:\n{'-' * 40}\n{result['message']}\n{'-' * 40}"
|
||||||
|
)
|
||||||
|
# Add assistant's thoughts and code to message history
|
||||||
|
self.agent.add_message(response, role="assistant")
|
||||||
|
|
||||||
|
# Process result and add formatted environment results as user message
|
||||||
|
result_context = format_result(result, step_count)
|
||||||
|
self.agent.add_message(result_context, role="user")
|
||||||
|
|
||||||
|
step_count += 1
|
||||||
|
|
||||||
|
# Handle budget exhaustion
|
||||||
|
if 'completion_reason' not in locals():
|
||||||
|
logger.info(f"Budget exhausted after {step_count} steps")
|
||||||
|
completion_reason = f"BUDGET_EXHAUSTED_AFTER_{step_count}_STEPS"
|
||||||
|
|
||||||
|
# Generate final summary
|
||||||
|
logger.info("Generating execution summary")
|
||||||
|
summary = self._generate_summary(execution_history, task_instruction)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"task_instruction": task_instruction,
|
||||||
|
"completion_reason": completion_reason,
|
||||||
|
"summary": summary,
|
||||||
|
"execution_history": execution_history,
|
||||||
|
"steps_executed": step_count,
|
||||||
|
"budget": self.budget
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"Code execution completed: steps={step_count}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _generate_summary(self, execution_history: List[Dict], task_instruction: str) -> str:
|
||||||
|
"""Generate summary of code execution session."""
|
||||||
|
if not execution_history:
|
||||||
|
logger.info("No execution history to summarize")
|
||||||
|
return "No actions were executed."
|
||||||
|
|
||||||
|
logger.info(f"Generating summary for {len(execution_history)} steps")
|
||||||
|
|
||||||
|
# Build detailed execution context for summary agent
|
||||||
|
execution_context = f"Task: {task_instruction}\n\nExecution Steps:\n"
|
||||||
|
|
||||||
|
for step in execution_history:
|
||||||
|
step_num = step['step']
|
||||||
|
thoughts = step.get('thoughts', '')
|
||||||
|
action = step.get('action', '')
|
||||||
|
|
||||||
|
execution_context += f"\nStep {step_num}:\n"
|
||||||
|
if thoughts:
|
||||||
|
execution_context += f"Thoughts: {thoughts}\n"
|
||||||
|
execution_context += f"Code: {action}\n"
|
||||||
|
|
||||||
|
# Create summary prompt with same context as coding agent
|
||||||
|
summary_prompt = f"""
|
||||||
|
{execution_context}
|
||||||
|
|
||||||
|
Please provide a concise summary of the code execution session. Focus on:
|
||||||
|
|
||||||
|
1. The code logic implemented at each step
|
||||||
|
2. The outputs and results produced by each code execution
|
||||||
|
3. The progression of the solution approach
|
||||||
|
|
||||||
|
Do not make judgments about success or failure. Simply describe what was attempted and what resulted.
|
||||||
|
|
||||||
|
Keep the summary under 150 words and use clear, factual language.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Generate summary using LLM with dedicated summary system prompt
|
||||||
|
try:
|
||||||
|
summary_agent = LMMAgent(
|
||||||
|
engine_params=self.engine_params,
|
||||||
|
system_prompt=PROCEDURAL_MEMORY.CODE_SUMMARY_AGENT_PROMPT
|
||||||
|
)
|
||||||
|
summary_agent.add_message(summary_prompt, role="user")
|
||||||
|
summary = call_llm_safe(summary_agent, temperature=1)
|
||||||
|
|
||||||
|
if not summary or summary.strip() == "":
|
||||||
|
summary = "Summary generation failed - no response from LLM"
|
||||||
|
logger.warning("Summary generation failed - empty response from LLM")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
summary = f"Summary generation failed: {str(e)}"
|
||||||
|
logger.error(f"Error generating summary: {e}")
|
||||||
|
|
||||||
|
return summary
|
||||||
|
|
||||||
|
|
||||||
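To make the parsing step in `code_agent.py` concrete, the sketch below condenses the same fence-splitting logic used by `extract_code_block` and runs it on sample model replies. It is an illustrative rewrite rather than the file's code. The `FENCE` constant is an assumption, introduced only so the example does not contain literal triple backticks.

```python
from typing import Optional, Tuple

FENCE = "`" * 3  # three backticks, built this way to keep the example fence-safe


def extract_code_block(action: str) -> Tuple[Optional[str], Optional[str]]:
    """Return (code_type, code) for the first fenced block in a model reply."""
    # Typed fences are checked first, python before bash, as in the original.
    for code_type in ("python", "bash"):
        opener = FENCE + code_type
        if opener in action:
            return code_type, action.split(opener)[1].split(FENCE)[0].strip()
    # An untyped fence yields code but no type.
    if FENCE in action:
        return None, action.split(FENCE)[1].strip()
    return None, None


reply = "I will list the files.\n" + FENCE + "bash\nls -la\n" + FENCE
untyped = FENCE + "\nx = 1\n" + FENCE

print(extract_code_block(reply))
print(extract_code_block(untyped))
print(extract_code_block("DONE"))
```

One consequence worth noting: an untyped block returns code with a type of `None`, so `execute_code` reports it as an unknown code type rather than running it.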
@@ -0,0 +1,646 @@
|
|||||||
|
|
||||||
|
import re
|
||||||
|
from collections import defaultdict
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import pytesseract
|
||||||
|
from PIL import Image
|
||||||
|
from pytesseract import Output
|
||||||
|
|
||||||
|
from gui_agents.s3.memory.procedural_memory import PROCEDURAL_MEMORY
|
||||||
|
from gui_agents.s3.core.mllm import LMMAgent
|
||||||
|
from gui_agents.s3.utils.common_utils import call_llm_safe
|
||||||
|
from gui_agents.s3.agents.code_agent import CodeAgent
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger("desktopenv.agent")
|
||||||
|
|
||||||
|
|
||||||
|
class ACI:
|
||||||
|
def __init__(self):
|
||||||
|
self.notes: List[str] = []
|
||||||
|
|
||||||
|
|
||||||
|
# Agent action decorator
|
||||||
|
def agent_action(func):
|
||||||
|
func.is_agent_action = True
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
UBUNTU_APP_SETUP = f"""import subprocess;
|
||||||
|
import difflib;
|
||||||
|
import pyautogui;
|
||||||
|
pyautogui.press('escape');
|
||||||
|
time.sleep(0.5);
|
||||||
|
output = subprocess.check_output(['wmctrl', '-lx']);
|
||||||
|
output = output.decode('utf-8').splitlines();
|
||||||
|
window_titles = [line.split(None, 4)[2] for line in output];
|
||||||
|
closest_matches = difflib.get_close_matches('APP_NAME', window_titles, n=1, cutoff=0.1);
|
||||||
|
if closest_matches:
|
||||||
|
closest_match = closest_matches[0];
|
||||||
|
for line in output:
|
||||||
|
if closest_match in line:
|
||||||
|
window_id = line.split()[0]
|
||||||
|
break;
|
||||||
|
subprocess.run(['wmctrl', '-ia', window_id])
|
||||||
|
subprocess.run(['wmctrl', '-ir', window_id, '-b', 'add,maximized_vert,maximized_horz'])
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
SET_CELL_VALUES_CMD = """import uno
|
||||||
|
import subprocess
|
||||||
|
import unicodedata, json
|
||||||
|
|
||||||
|
def identify_document_type(component):
|
||||||
|
if component.supportsService("com.sun.star.sheet.SpreadsheetDocument"):
|
||||||
|
return "Calc"
|
||||||
|
|
||||||
|
if component.supportsService("com.sun.star.text.TextDocument"):
|
||||||
|
return "Writer"
|
||||||
|
|
||||||
|
if component.supportsService("com.sun.star.presentation.PresentationDocument"):
|
||||||
|
return "Impress"
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _norm_name(s: str | None) -> str | None:
|
||||||
|
if s is None:
|
||||||
|
return None
|
||||||
|
if "\\\\u" in s or "\\\\U" in s or "\\\\x" in s:
|
||||||
|
try:
|
||||||
|
# json.loads handles all the escape forms safely
|
||||||
|
s = json.loads(f"{{s}}")
|
||||||
|
except Exception:
|
||||||
|
# fallback: best-effort
|
||||||
|
try:
|
||||||
|
s = s.encode("utf-8").decode("unicode_escape")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# Normalize (NFC works well across platforms)
|
||||||
|
return unicodedata.normalize("NFC", s)
|
||||||
|
|
||||||
|
def cell_ref_to_indices(cell_ref):
|
||||||
|
column_letters = ''.join(filter(str.isalpha, cell_ref))
|
||||||
|
row_number = ''.join(filter(str.isdigit, cell_ref))
|
||||||
|
|
||||||
|
col = sum((ord(char.upper()) - ord('A') + 1) * (26**idx) for idx, char in enumerate(reversed(column_letters))) - 1
|
||||||
|
row = int(row_number) - 1
|
||||||
|
return col, row
|
||||||
|
|
||||||
|
def set_cell_values(new_cell_values: dict[str, str], app_name: str = "Untitled 1", sheet_name: str = "Sheet1"):
|
||||||
|
app_name = _norm_name(app_name)
|
||||||
|
sheet_name = _norm_name(sheet_name)
|
||||||
|
|
||||||
|
new_cell_values_idx = {{}}
|
||||||
|
for k, v in new_cell_values.items():
|
||||||
|
try:
|
||||||
|
col, row = cell_ref_to_indices(k)
|
||||||
|
except Exception:
|
||||||
|
col = row = None
|
||||||
|
|
||||||
|
if col is not None and row is not None:
|
||||||
|
new_cell_values_idx[(col, row)] = v
|
||||||
|
|
||||||
|
# Clean up previous TCP connections.
|
||||||
|
subprocess.run(
|
||||||
|
'echo \"osworld-public-evaluation\" | sudo -S ss --kill --tcp state TIME-WAIT sport = :2002',
|
||||||
|
shell=True,
|
||||||
|
check=True,
|
||||||
|
text=True,
|
||||||
|
capture_output=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dynamically allow soffice to listen on port 2002.
|
||||||
|
subprocess.run(
|
||||||
|
[
|
||||||
|
"soffice",
|
||||||
|
"--accept=socket,host=localhost,port=2002;urp;StarOffice.Service"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
local_context = uno.getComponentContext()
|
||||||
|
resolver = local_context.ServiceManager.createInstanceWithContext(
|
||||||
|
"com.sun.star.bridge.UnoUrlResolver", local_context
|
||||||
|
)
|
||||||
|
context = resolver.resolve(
|
||||||
|
f"uno:socket,host=localhost,port=2002;urp;StarOffice.ComponentContext"
|
||||||
|
)
|
||||||
|
desktop = context.ServiceManager.createInstanceWithContext(
|
||||||
|
"com.sun.star.frame.Desktop", context
|
||||||
|
)
|
||||||
|
|
||||||
|
# Collect all LibreOffice-related opened windows.
|
||||||
|
documents = []
|
||||||
|
for i, component in enumerate(desktop.Components):
|
||||||
|
title = component.Title
|
||||||
|
doc_type = identify_document_type(component)
|
||||||
|
documents.append((i, component, title, doc_type))
|
||||||
|
|
||||||
|
# Find the LibreOffice Calc app and the sheet of interest.
|
||||||
|
spreadsheet = [doc for doc in documents if doc[3] == "Calc"]
|
||||||
|
selected_spreadsheet = [doc for doc in spreadsheet if doc[2] == app_name]
|
||||||
|
if spreadsheet:
|
||||||
|
try:
|
||||||
|
if selected_spreadsheet:
|
||||||
|
spreadsheet = selected_spreadsheet[0][1]
|
||||||
|
else:
|
||||||
|
spreadsheet = spreadsheet[0][1]
|
||||||
|
|
||||||
|
sheet = spreadsheet.Sheets.getByName(sheet_name)
|
||||||
|
except Exception:
|
||||||
|
raise ValueError(f"Could not find sheet {{sheet_name}} in {{app_name}}.")
|
||||||
|
|
||||||
|
for (col, row), value in new_cell_values_idx.items():
|
||||||
|
cell = sheet.getCellByPosition(col, row)
|
||||||
|
|
||||||
|
# Set the cell value.
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
cell.Value = value
|
||||||
|
elif isinstance(value, str):
|
||||||
|
if value.startswith("="):
|
||||||
|
cell.Formula = value
|
||||||
|
else:
|
||||||
|
cell.String = value
|
||||||
|
elif isinstance(value, bool):
|
||||||
|
cell.Value = 1 if value else 0
|
||||||
|
elif value is None:
|
||||||
|
cell.clearContents(0)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported cell value type: {{type(value)}}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Could not find LibreOffice Calc app corresponding to {{app_name}}.")
|
||||||
|
|
||||||
|
set_cell_values(new_cell_values={cell_values}, app_name="{app_name}", sheet_name="{sheet_name}")
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# ACI primitives are parameterized by description, and coordinate generation uses a pretrained grounding model
|
||||||
|
class OSWorldACI(ACI):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
env,
|
||||||
|
platform: str,
|
||||||
|
engine_params_for_generation: Dict,
|
||||||
|
engine_params_for_grounding: Dict,
|
||||||
|
width: int = 1920,
|
||||||
|
height: int = 1080,
|
||||||
|
code_agent_budget: int = 20,
|
||||||
|
code_agent_engine_params: Optional[Dict] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.env = env
|
||||||
|
self.platform = (
|
||||||
|
platform # Dictates how the switch_applications agent action works.
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configure scaling
|
||||||
|
self.width = width
|
||||||
|
self.height = height
|
||||||
|
|
||||||
|
# Maintain state for save_to_knowledge
|
||||||
|
self.notes = []
|
||||||
|
|
||||||
|
# Screenshot used during ACI execution
|
||||||
|
self.obs = None
|
||||||
|
|
||||||
|
# Configure the visual grounding model responsible for coordinate generation
|
||||||
|
self.grounding_model = LMMAgent(engine_params_for_grounding)
|
||||||
|
self.engine_params_for_grounding = engine_params_for_grounding
|
||||||
|
|
||||||
|
# Configure text grounding agent
|
||||||
|
self.text_span_agent = LMMAgent(
|
||||||
|
engine_params=engine_params_for_generation,
|
||||||
|
system_prompt=PROCEDURAL_MEMORY.PHRASE_TO_WORD_COORDS_PROMPT,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configure code agent
|
||||||
|
code_agent_engine_params = code_agent_engine_params or engine_params_for_generation
|
||||||
|
self.code_agent = CodeAgent(code_agent_engine_params, code_agent_budget)
|
||||||
|
|
||||||
|
# Store task instruction for code agent
|
||||||
|
self.current_task_instruction = None
|
||||||
|
self.last_code_agent_result = None
|
||||||
|
|
||||||
|
# Given the state and worker's referring expression, use the grounding model to generate (x,y)
|
||||||
|
def generate_coords(self, ref_expr: str, obs: Dict) -> List[int]:
|
||||||
|
|
||||||
|
# Reset the grounding model state
|
||||||
|
self.grounding_model.reset()
|
||||||
|
|
||||||
|
# Configure the context, UI-TARS demo does not use system prompt
|
||||||
|
prompt = f"Query:{ref_expr}\nOutput only the coordinate of one point in your response.\n"
|
||||||
|
self.grounding_model.add_message(
|
||||||
|
text_content=prompt, image_content=obs["screenshot"], put_text_last=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate and parse coordinates
|
||||||
|
response = call_llm_safe(self.grounding_model)
|
||||||
|
print("RAW GROUNDING MODEL RESPONSE:", response)
|
||||||
|
numericals = re.findall(r"\d+", response)
|
||||||
|
assert len(numericals) >= 2
|
||||||
|
return [int(numericals[0]), int(numericals[1])]
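The coordinate parsing step in `generate_coords` can be illustrated in isolation. The sketch below is a minimal stand-alone version, not the repository's code: `parse_point` is a hypothetical helper name, and it raises `ValueError` where the method above uses a bare `assert`.

```python
import re
from typing import List


def parse_point(response: str) -> List[int]:
    """Return the first two integers found in `response` as [x, y]."""
    numericals = re.findall(r"\d+", response)
    if len(numericals) < 2:
        # Hypothetical choice: a descriptive error instead of a bare assert
        raise ValueError(f"expected two integers, got: {response!r}")
    return [int(numericals[0]), int(numericals[1])]


print(parse_point("<point>512, 384</point>"))  # [512, 384]
```

Because the pattern is `\d+`, a response with decimals such as `0.25, 0.75` would be read as the integers 0 and 25, so this approach assumes the grounding model emits integer pixel coordinates.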
|
||||||
|
|
||||||
|
# Calls pytesseract to generate word level bounding boxes for text grounding
|
||||||
|
def get_ocr_elements(self, b64_image_data: bytes) -> Tuple[str, List]:  # BytesIO below requires raw bytes, despite the parameter name
|
||||||
|
image = Image.open(BytesIO(b64_image_data))
|
||||||
|
image_data = pytesseract.image_to_data(image, output_type=Output.DICT)
|
||||||
|
|
||||||
|
# Clean text by stripping leading and trailing characters that are not letters, whitespace, or common punctuation (so digits and other symbols at the edges are removed)
|
||||||
|
for i, word in enumerate(image_data["text"]):
|
||||||
|
image_data["text"][i] = re.sub(
|
||||||
|
r"^[^a-zA-Z\s.,!?;:\-\+]+|[^a-zA-Z\s.,!?;:\-\+]+$", "", word
|
||||||
|
)
|
||||||
|
|
||||||
|
ocr_elements = []
|
||||||
|
ocr_table = "Text Table:\nWord id\tText\n"
|
||||||
|
# Obtain the <id, text, group number, word number> for each valid element
|
||||||
|
grouping_map = defaultdict(list)
|
||||||
|
ocr_id = 0
|
||||||
|
for i in range(len(image_data["text"])):
|
||||||
|
block_num = image_data["block_num"][i]
|
||||||
|
if image_data["text"][i]:
|
||||||
|
grouping_map[block_num].append(image_data["text"][i])
|
||||||
|
ocr_table += f"{ocr_id}\t{image_data['text'][i]}\n"
|
||||||
|
ocr_elements.append(
|
||||||
|
{
|
||||||
|
"id": ocr_id,
|
||||||
|
"text": image_data["text"][i],
|
||||||
|
"group_num": block_num,
|
||||||
|
"word_num": len(grouping_map[block_num]),
|
||||||
|
"left": image_data["left"][i],
|
||||||
|
"top": image_data["top"][i],
|
||||||
|
"width": image_data["width"][i],
|
||||||
|
"height": image_data["height"][i],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ocr_id += 1
|
||||||
|
|
||||||
|
return ocr_table, ocr_elements
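To see what the word-cleaning regex in `get_ocr_elements` keeps and drops, it helps to run it on a few strings without OCR. This is a small sketch using the same pattern; `clean_word` is a hypothetical name introduced only for illustration.

```python
import re

# Same pattern as in get_ocr_elements: strip edge characters that are not
# letters, whitespace, or the listed punctuation.
EDGE_PATTERN = r"^[^a-zA-Z\s.,!?;:\-\+]+|[^a-zA-Z\s.,!?;:\-\+]+$"


def clean_word(word: str) -> str:
    """Remove disallowed characters from the start and end of an OCR word."""
    return re.sub(EDGE_PATTERN, "", word)


for raw in ["(world)", "123Hello!", "42"]:
    print(repr(raw), "->", repr(clean_word(raw)))
```

Note that purely numeric words clean down to an empty string, and the OCR loop skips empty strings, so numbers on screen never receive a word id under this scheme.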
|
||||||
|
|
||||||
|
# Given the state and worker's text phrase, generate the coords of the first/last word in the phrase
|
||||||
|
def generate_text_coords(
|
||||||
|
self, phrase: str, obs: Dict, alignment: str = ""
|
||||||
|
) -> List[int]:
|
||||||
|
|
||||||
|
ocr_table, ocr_elements = self.get_ocr_elements(obs["screenshot"])
|
||||||
|
|
||||||
|
alignment_prompt = ""
|
||||||
|
if alignment == "start":
|
||||||
|
alignment_prompt = "**Important**: Output the word id of the FIRST word in the provided phrase.\n"
|
||||||
|
elif alignment == "end":
|
||||||
|
alignment_prompt = "**Important**: Output the word id of the LAST word in the provided phrase.\n"
|
||||||
|
|
||||||
|
# Load LLM prompt
|
||||||
|
self.text_span_agent.reset()
|
||||||
|
self.text_span_agent.add_message(
|
||||||
|
alignment_prompt + "Phrase: " + phrase + "\n" + ocr_table, role="user"
|
||||||
|
)
|
||||||
|
self.text_span_agent.add_message(
|
||||||
|
"Screenshot:\n", image_content=obs["screenshot"], role="user"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Obtain the target element
|
||||||
|
response = call_llm_safe(self.text_span_agent)
|
||||||
|
print("TEXT SPAN AGENT RESPONSE:", response)
|
||||||
|
numericals = re.findall(r"\d+", response)
|
||||||
|
if len(numericals) > 0:
|
||||||
|
text_id = int(numericals[-1])
|
||||||
|
else:
|
||||||
|
text_id = 0
|
||||||
|
elem = ocr_elements[text_id]
|
||||||
|
|
||||||
|
# Compute the element coordinates
|
||||||
|
if alignment == "start":
|
||||||
|
coords = [elem["left"], elem["top"] + (elem["height"] // 2)]
|
||||||
|
elif alignment == "end":
|
||||||
|
coords = [elem["left"] + elem["width"], elem["top"] + (elem["height"] // 2)]
|
||||||
|
else:
|
||||||
|
coords = [
|
||||||
|
elem["left"] + (elem["width"] // 2),
|
||||||
|
elem["top"] + (elem["height"] // 2),
|
||||||
|
]
|
||||||
|
return coords
|
||||||
|
|
||||||
|
    def assign_screenshot(self, obs: Dict):
        self.obs = obs

    def set_task_instruction(self, task_instruction: str):
        """Set the current task instruction for the code agent."""
        self.current_task_instruction = task_instruction

    # Rescale from the grounding model's dimensions to the screen dimensions
    # (self.width x self.height, e.g. 1920 x 1080 for OSWorld)
    def resize_coordinates(self, coordinates: List[int]) -> List[int]:
        grounding_width = self.engine_params_for_grounding["grounding_width"]
        grounding_height = self.engine_params_for_grounding["grounding_height"]

        return [
            round(coordinates[0] * self.width / grounding_width),
            round(coordinates[1] * self.height / grounding_height),
        ]
|
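The rescaling in `resize_coordinates` is a per-axis proportion. A minimal stand-alone sketch follows; `rescale_point` is a hypothetical helper, and the 1000 x 1000 grounding resolution is an assumed value chosen only to make the arithmetic easy to follow.

```python
from typing import List, Tuple


def rescale_point(
    point: List[int], src_size: Tuple[int, int], dst_size: Tuple[int, int]
) -> List[int]:
    """Map a point from the grounding model's image size to the screen size."""
    src_w, src_h = src_size
    dst_w, dst_h = dst_size
    return [
        round(point[0] * dst_w / src_w),
        round(point[1] * dst_h / src_h),
    ]


# Centre of an assumed 1000 x 1000 grounding image on a 1920 x 1080 screen
print(rescale_point([500, 500], (1000, 1000), (1920, 1080)))  # [960, 540]
```

Each axis is scaled independently, so the mapping stays correct even when the grounding image and the screen have different aspect ratios.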
||||||
|
|
||||||
|
    @agent_action
    def click(
        self,
        element_description: str,
        num_clicks: int = 1,
        button_type: str = "left",
        hold_keys: List = [],
    ):
        """Click on the element
        Args:
            element_description:str, a detailed description of which element to click on. This description should be at least a full sentence.
            num_clicks:int, number of times to click the element
            button_type:str, which mouse button to press; can be "left", "middle", or "right"
            hold_keys:List, list of keys to hold while clicking
        """
        coords1 = self.generate_coords(element_description, self.obs)
        x, y = self.resize_coordinates(coords1)
        command = "import pyautogui; "

        # TODO: specified duration?
        for k in hold_keys:
            command += f"pyautogui.keyDown({repr(k)}); "
        # pyautogui is already imported at the start of the command string
        command += f"pyautogui.click({x}, {y}, clicks={num_clicks}, button={repr(button_type)}); "
        for k in hold_keys:
            command += f"pyautogui.keyUp({repr(k)}); "
        # Return pyautogui code to click on the element
        return command
|
||||||
|
|
||||||
|
@agent_action
|
||||||
|
def switch_applications(self, app_code):
|
||||||
|
"""Switch to a different application that is already open
|
||||||
|
Args:
|
||||||
|
app_code:str the code name of the application to switch to from the provided list of open applications
|
||||||
|
"""
|
||||||
|
if self.platform == "darwin":
|
||||||
|
return f"import pyautogui; import time; pyautogui.hotkey('command', 'space', interval=0.5); pyautogui.typewrite({repr(app_code)}); pyautogui.press('enter'); time.sleep(1.0)"
|
||||||
|
elif self.platform == "linux":
|
||||||
|
return UBUNTU_APP_SETUP.replace("APP_NAME", app_code)
|
||||||
|
elif self.platform == "windows":
|
||||||
|
return f"import pyautogui; import time; pyautogui.hotkey('win', 'd', interval=0.5); pyautogui.typewrite({repr(app_code)}); pyautogui.press('enter'); time.sleep(1.0)"
|
||||||
|
else:
|
||||||
|
assert False, f"Unsupported platform: {self.platform}. Supported platforms are: darwin, linux, windows."
|
||||||
|
|
||||||
|
    @agent_action
    def open(self, app_or_filename: str):
        """Open any application or file with name app_or_filename. Use this action to open applications or files on the desktop, do not open manually.
        Args:
            app_or_filename:str, the name of the application or filename to open
        """
        # Note: platforms other than linux and darwin fall through and return None
        if self.platform == "linux":
            # `import time` is required here because the command calls time.sleep
            return f"import pyautogui; import time; pyautogui.hotkey('win'); time.sleep(0.5); pyautogui.write({repr(app_or_filename)}); time.sleep(1.0); pyautogui.hotkey('enter'); time.sleep(0.5)"
        elif self.platform == "darwin":
            return f"import pyautogui; import time; pyautogui.hotkey('command', 'space', interval=0.5); pyautogui.typewrite({repr(app_or_filename)}); pyautogui.press('enter'); time.sleep(1.0)"
|
||||||
|
|
||||||
|
@agent_action
|
||||||
|
def type(
|
||||||
|
self,
|
||||||
|
element_description: Optional[str] = None,
|
||||||
|
text: str = "",
|
||||||
|
overwrite: bool = False,
|
||||||
|
enter: bool = False,
|
||||||
|
):
|
||||||
|
"""Type text/unicode into a specific element
|
||||||
|
Args:
|
||||||
|
element_description:str, a detailed description of which element to enter text in. This description should be at least a full sentence.
|
||||||
|
text:str, the text to type
|
||||||
|
overwrite:bool, Assign it to True if the text should overwrite the existing text, otherwise assign it to False. Using this argument clears all text in an element.
|
||||||
|
enter:bool, Assign it to True if the enter key should be pressed after typing the text, otherwise assign it to False.
|
||||||
|
"""
|
||||||
|
command = "import pyautogui; "
|
||||||
|
command += (
|
||||||
|
"\ntry:\n"
|
||||||
|
" import pyperclip\n"
|
||||||
|
"except ImportError:\n"
|
||||||
|
" import subprocess\n"
|
||||||
|
" subprocess.run('echo \"osworld-public-evaluation\" | sudo -S apt-get install -y xclip xsel', shell=True, check=True)\n"
|
||||||
|
" subprocess.check_call([subprocess.sys.executable, '-m', 'pip', 'install', 'pyperclip'])\n"
|
||||||
|
" import pyperclip\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
if element_description is not None:
|
||||||
|
coords1 = self.generate_coords(element_description, self.obs)
|
||||||
|
x, y = self.resize_coordinates(coords1)
|
||||||
|
command += f"pyautogui.click({x}, {y}); "
|
||||||
|
|
||||||
|
if overwrite:
|
||||||
|
command += (
|
||||||
|
f"pyautogui.hotkey({repr('command' if self.platform == 'darwin' else 'ctrl')}, 'a'); "
|
||||||
|
"pyautogui.press('backspace'); "
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if text contains Unicode characters that pyautogui.write() can't handle
|
||||||
|
has_unicode = any(ord(char) > 127 for char in text)
|
||||||
|
|
||||||
|
if has_unicode:
|
||||||
|
# Use clipboard method for Unicode characters
|
||||||
|
command += f"pyperclip.copy({repr(text)}); "
|
||||||
|
command += f"pyautogui.hotkey({repr('command' if self.platform == 'darwin' else 'ctrl')}, 'v'); "
|
||||||
|
else:
|
||||||
|
# Use regular pyautogui.write() for ASCII text
|
||||||
|
command += f"pyautogui.write({repr(text)}); "
|
||||||
|
|
||||||
|
if enter:
|
||||||
|
command += "pyautogui.press('enter'); "
|
||||||
|
return command
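The `type` action picks between direct key presses and a clipboard paste depending on whether the text is pure ASCII. The sketch below isolates that decision. `build_type_command` is a hypothetical helper, and it only builds the command string, so neither pyautogui nor pyperclip has to be installed to run it.

```python
def build_type_command(text: str, platform: str = "linux") -> str:
    """Build a command string that types `text`, pasting it if it is not ASCII."""
    command = "import pyautogui; "
    # Equivalent to the any(ord(char) > 127 ...) check used in `type`
    if not text.isascii():
        paste_key = "command" if platform == "darwin" else "ctrl"
        command += "import pyperclip; "
        command += f"pyperclip.copy({text!r}); "
        command += f"pyautogui.hotkey({paste_key!r}, 'v'); "
    else:
        command += f"pyautogui.write({text!r}); "
    return command


print(build_type_command("hello"))
print(build_type_command("café", platform="darwin"))
```

Using `repr` on the text keeps quotes and backslashes in the typed string from breaking the generated Python.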
|
||||||
|
|
||||||
|
@agent_action
|
||||||
|
def save_to_knowledge(self, text: List[str]):
|
||||||
|
"""Save facts, elements, texts, etc. to a long-term knowledge bank for reuse during this task. Can be used for copy-pasting text, saving elements, etc.
|
||||||
|
Args:
|
||||||
|
text:List[str] the text to save to the knowledge
|
||||||
|
"""
|
||||||
|
self.notes.extend(text)
|
||||||
|
return """WAIT"""
|
||||||
|
|
||||||
|
@agent_action
|
||||||
|
def drag_and_drop(
|
||||||
|
self, starting_description: str, ending_description: str, hold_keys: List = []
|
||||||
|
):
|
||||||
|
"""Drag from the starting description to the ending description
|
||||||
|
Args:
|
||||||
|
starting_description:str, a very detailed description of where to start the drag action. This description should be at least a full sentence.
|
||||||
|
ending_description:str, a very detailed description of where to end the drag action. This description should be at least a full sentence.
|
||||||
|
hold_keys:List list of keys to hold while dragging
|
||||||
|
"""
|
||||||
|
coords1 = self.generate_coords(starting_description, self.obs)
|
||||||
|
coords2 = self.generate_coords(ending_description, self.obs)
|
||||||
|
x1, y1 = self.resize_coordinates(coords1)
|
||||||
|
x2, y2 = self.resize_coordinates(coords2)
|
||||||
|
|
||||||
|
command = "import pyautogui; "
|
||||||
|
|
||||||
|
command += f"pyautogui.moveTo({x1}, {y1}); "
|
||||||
|
# TODO: specified duration?
|
||||||
|
for k in hold_keys:
|
||||||
|
command += f"pyautogui.keyDown({repr(k)}); "
|
||||||
|
command += f"pyautogui.dragTo({x2}, {y2}, duration=1., button='left'); pyautogui.mouseUp(); "
|
||||||
|
for k in hold_keys:
|
||||||
|
command += f"pyautogui.keyUp({repr(k)}); "
|
||||||
|
|
||||||
|
# Return pyautogui code to drag and drop the elements
|
||||||
|
|
||||||
|
return command
|
||||||
|
|
||||||
|
    @agent_action
    def highlight_text_span(self, starting_phrase: str, ending_phrase: str, button: str = "left"):
        """Highlight a text span between a provided starting phrase and ending phrase. Use this to highlight words, lines, and paragraphs.
        Args:
            starting_phrase:str, the phrase that denotes the start of the text span you want to highlight. If you only want to highlight one word, just pass in that single word.
            ending_phrase:str, the phrase that denotes the end of the text span you want to highlight. If you only want to highlight one word, just pass in that single word.
            button:str, the button to use to highlight the text span. Defaults to "left". Can be "left", "right", or "middle".
        """
        coords1 = self.generate_text_coords(starting_phrase, self.obs, alignment="start")
        coords2 = self.generate_text_coords(ending_phrase, self.obs, alignment="end")
        x1, y1 = coords1
        x2, y2 = coords2

        command = "import pyautogui; "
        command += f"pyautogui.moveTo({x1}, {y1}); "
        command += f"pyautogui.dragTo({x2}, {y2}, duration=1., button={repr(button)}); pyautogui.mouseUp(); "

        # Return pyautogui code to highlight the text span
        return command
|
||||||
|
|
||||||
|
@agent_action
|
||||||
|
def set_cell_values(
|
||||||
|
self, cell_values: Dict[str, Any], app_name: str, sheet_name: str
|
||||||
|
):
|
||||||
|
"""Use this to set individual cell values in a spreadsheet. For example, setting A2 to "hello" would be done by passing {"A2": "hello"} as cell_values. The sheet must be opened before this command can be used.
|
||||||
|
Args:
|
||||||
|
cell_values: Dict[str, Any], A dictionary of cell values to set in the spreadsheet. The keys are the cell coordinates in the format "A1", "B2", etc.
|
||||||
|
Supported value types include: float, int, string, bool, formulas.
|
||||||
|
app_name: str, The name of the spreadsheet application. For example, "Some_sheet.xlsx".
|
||||||
|
sheet_name: str, The name of the sheet in the spreadsheet. For example, "Sheet1".
|
||||||
|
"""
|
||||||
|
return SET_CELL_VALUES_CMD.format(
|
||||||
|
cell_values=cell_values, app_name=app_name, sheet_name=sheet_name
|
||||||
|
)
|
||||||
|
|
||||||
|
@agent_action
|
||||||
|
def call_code_agent(self, task: str = None):
|
||||||
|
"""Call the code agent to execute code for tasks or subtasks that can be completed solely with coding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: str, the task or subtask to execute. If None, uses the current full task instruction.
|
||||||
|
|
||||||
|
**🚨 CRITICAL GUIDELINES:**
|
||||||
|
- **ONLY pass a task parameter for SPECIFIC subtasks** (e.g., "Calculate sum of column B", "Filter data by date")
|
||||||
|
- **NEVER pass a task parameter for full tasks** - let it default to the original task instruction
|
||||||
|
- **NEVER rephrase or modify the original task** - this prevents hallucination corruption
|
||||||
|
- **If unsure, omit the task parameter entirely** to use the original task instruction
|
||||||
|
|
||||||
|
Use this for tasks that can be fully accomplished through code execution, particularly for:
|
||||||
|
- Spreadsheet applications (LibreOffice Calc, Excel): data processing, filtering, sorting, calculations, formulas, data analysis
|
||||||
|
- Document editors (LibreOffice Writer, Word): text processing, content editing, formatting, document manipulation
|
||||||
|
- Code editors (VS Code, text editors): code editing, file processing, text manipulation, configuration
|
||||||
|
- Data analysis tools: statistical analysis, data transformation, reporting
|
||||||
|
- File management: bulk operations, file processing, content extraction
|
||||||
|
- System utilities: configuration, setup, automation
|
||||||
|
"""
|
||||||
|
logger.info("=" * 50)
|
||||||
|
logger.info("GROUNDING AGENT: Calling Code Agent")
|
||||||
|
logger.info("=" * 50)
|
||||||
|
|
||||||
|
# **CRITICAL**: Only use provided task for specific subtasks, otherwise use original task instruction
|
||||||
|
if task is not None:
|
||||||
|
# This is a subtask - use the provided task
|
||||||
|
task_to_execute = task
|
||||||
|
logger.info(f"Executing SUBTASK: {task_to_execute}")
|
||||||
|
else:
|
||||||
|
# This is a full task - use the original task instruction to prevent hallucination
|
||||||
|
task_to_execute = self.current_task_instruction
|
||||||
|
logger.info(f"Executing FULL TASK: {task_to_execute}")
|
||||||
|
|
||||||
|
if task_to_execute:
|
||||||
|
print("obs keys: ", self.obs.keys() if self.obs else None)  # guard: obs may still be None here
|
||||||
|
screenshot = self.obs.get('screenshot', '') if self.obs else ''
|
||||||
|
logger.info(f"Screenshot available: {'Yes' if screenshot else 'No'}")
|
||||||
|
|
||||||
|
logger.info("Executing code agent...")
|
||||||
|
result = self.code_agent.execute(task_to_execute, screenshot, self.env.controller)
|
||||||
|
|
||||||
|
# Store the result for the worker to access
|
||||||
|
self.last_code_agent_result = result
|
||||||
|
|
||||||
|
logger.info("Code agent execution completed")
|
||||||
|
logger.info(f"Result - Completion reason: {result['completion_reason']}")
|
||||||
|
logger.info(f"Steps executed: {result['steps_executed']}")
|
||||||
|
logger.info(f"Summary: {result['summary']}")
|
||||||
|
|
||||||
|
logger.info("=" * 50)
|
||||||
|
logger.info("GROUNDING AGENT: Code Agent Call Finished")
|
||||||
|
logger.info("=" * 50)
|
||||||
|
|
||||||
|
# Return code to be executed in the environment
|
||||||
|
return "import time; time.sleep(2.222)"
|
||||||
|
else:
|
||||||
|
logger.warning("No task instruction available for code agent call")
|
||||||
|
return "import time; time.sleep(1.111)"
|
||||||
|
|
||||||
|
    @agent_action
    def scroll(self, element_description: str, clicks: int, shift: bool = False):
        """Scroll the element in the specified direction
        Args:
            element_description:str, a very detailed description of which element to scroll in. This description should be at least a full sentence.
            clicks:int, the number of clicks to scroll; can be positive (up) or negative (down).
            shift:bool, whether to use shift+scroll for horizontal scrolling
        """
        coords1 = self.generate_coords(element_description, self.obs)
        x, y = self.resize_coordinates(coords1)

        if shift:
            return f"import pyautogui; import time; pyautogui.moveTo({x}, {y}); time.sleep(0.5); pyautogui.hscroll({clicks})"
        else:
            return f"import pyautogui; import time; pyautogui.moveTo({x}, {y}); time.sleep(0.5); pyautogui.vscroll({clicks})"

    @agent_action
    def hotkey(self, keys: List):
        """Press a hotkey combination
        Args:
            keys:List the keys to press in combination in a list format (e.g. ['ctrl', 'c'])
        """
        # Quote each key with repr() so keys containing quote characters stay valid Python
        keys = [repr(key) for key in keys]
        return f"import pyautogui; pyautogui.hotkey({', '.join(keys)})"

    @agent_action
    def hold_and_press(self, hold_keys: List, press_keys: List):
        """Hold a list of keys and press a list of keys
        Args:
            hold_keys:List, list of keys to hold
            press_keys:List, list of keys to press in a sequence
        """

        press_keys_str = "[" + ", ".join([repr(key) for key in press_keys]) + "]"
        command = "import pyautogui; "
        for k in hold_keys:
            command += f"pyautogui.keyDown({repr(k)}); "
        command += f"pyautogui.press({press_keys_str}); "
        for k in hold_keys:
            command += f"pyautogui.keyUp({repr(k)}); "

        return command

    @agent_action
    def wait(self, time: float):
        """Wait for a specified amount of time
        Args:
            time:float the amount of time to wait in seconds
        """
        return f"""import time; time.sleep({time})"""

    @agent_action
    def done(
        self,
    ):
        """End the current task with a success. Use this when you believe the entire task has been fully completed."""
        return """DONE"""

    @agent_action
    def fail(self):
        """End the current task with a failure. Use this when you believe the entire task is impossible to complete."""
        return """FAIL"""
|
||||||
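Several actions in the file above (`click`, `drag_and_drop`, `hold_and_press`) share one pattern: press the modifier keys, emit the core pyautogui call, then release the modifiers. The sketch below factors that pattern into a single function. `wrap_with_held_keys` is a hypothetical helper that does not exist in the repository, and it only builds a string, so pyautogui is not needed to run it.

```python
from typing import List


def wrap_with_held_keys(action: str, hold_keys: List[str]) -> str:
    """Surround a pyautogui statement with keyDown/keyUp calls for each held key."""
    command = "import pyautogui; "
    for key in hold_keys:
        command += f"pyautogui.keyDown({key!r}); "
    command += action
    # Keys are released in the same order they were pressed, as in the actions above
    for key in hold_keys:
        command += f"pyautogui.keyUp({key!r}); "
    return command


print(wrap_with_held_keys("pyautogui.click(10, 20); ", ["shift"]))
```

With an empty `hold_keys` list the helper returns just the import followed by the action, which matches how the actions behave when no modifiers are requested.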
@@ -0,0 +1,335 @@
|
|||||||
|
from functools import partial
|
||||||
|
import logging
|
||||||
|
import textwrap
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
from gui_agents.s3.agents.grounding import ACI
|
||||||
|
from gui_agents.s3.core.module import BaseModule
|
||||||
|
from gui_agents.s3.memory.procedural_memory import PROCEDURAL_MEMORY
|
||||||
|
from gui_agents.s3.utils.common_utils import (
|
||||||
|
call_llm_safe,
|
||||||
|
call_llm_formatted,
|
||||||
|
parse_code_from_string,
|
||||||
|
split_thinking_response,
|
||||||
|
create_pyautogui_code
|
||||||
|
)
|
||||||
|
from gui_agents.s3.utils.formatters import (
|
||||||
|
SINGLE_ACTION_FORMATTER,
|
||||||
|
CODE_VALID_FORMATTER,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger("desktopenv.agent")
|
||||||
|
|
||||||
|
|
||||||
|
class Worker(BaseModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
worker_engine_params: Dict,
|
||||||
|
grounding_agent: ACI,
|
||||||
|
platform: str = "ubuntu",
|
||||||
|
max_trajectory_length: int = 8,
|
||||||
|
enable_reflection: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Worker receives the main task and generates actions, without the need for hierarchical planning
|
||||||
|
Args:
|
||||||
|
worker_engine_params: Dict
|
||||||
|
Parameters for the worker agent
|
||||||
|
grounding_agent: ACI
|
||||||
|
The grounding agent to use
|
||||||
|
platform: str
|
||||||
|
OS platform the agent runs on (darwin, linux, windows)
|
||||||
|
max_trajectory_length: int
|
||||||
|
The number of image turns to keep
|
||||||
|
enable_reflection: bool
|
||||||
|
Whether to enable reflection
|
||||||
|
"""
|
||||||
|
super().__init__(worker_engine_params, platform)
|
||||||
|
|
||||||
|
self.temperature = worker_engine_params.get("temperature", 0.0)
|
||||||
|
self.use_thinking = worker_engine_params.get("model", "") in [
|
||||||
|
"claude-opus-4-20250514",
|
||||||
|
"claude-sonnet-4-20250514",
|
||||||
|
"claude-3-7-sonnet-20250219",
|
||||||
|
"claude-sonnet-4-5-20250929",
|
||||||
|
]
|
||||||
|
self.grounding_agent = grounding_agent
|
||||||
|
self.max_trajectory_length = max_trajectory_length
|
||||||
|
self.enable_reflection = enable_reflection
|
||||||
|
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
if self.platform != "linux":
|
||||||
|
skipped_actions = ["set_cell_values"]
|
||||||
|
else:
|
||||||
|
skipped_actions = []
|
||||||
|
|
||||||
|
sys_prompt = PROCEDURAL_MEMORY.construct_simple_worker_procedural_memory(
|
||||||
|
type(self.grounding_agent), skipped_actions=skipped_actions
|
||||||
|
).replace("CURRENT_OS", self.platform)
|
||||||
|
|
||||||
|
self.generator_agent = self._create_agent(sys_prompt)
|
||||||
|
self.reflection_agent = self._create_agent(PROCEDURAL_MEMORY.REFLECTION_ON_TRAJECTORY)
|
||||||
|
|
||||||
|
self.turn_count = 0
|
||||||
|
self.worker_history = []
|
||||||
|
self.reflections = []
|
||||||
|
self.cost_this_turn = 0
|
||||||
|
self.screenshot_inputs = []
|
||||||
|
|
||||||
|
def flush_messages(self):
|
||||||
|
"""Flush messages based on the model's context limits.
|
||||||
|
|
||||||
|
This method ensures that the agent's message history does not exceed the maximum trajectory length.
|
||||||
|
|
||||||
|
Side Effects:
|
||||||
|
- Modifies the messages of the generator and reflection agents to fit within the context limits.
|
||||||
|
"""
|
||||||
|
engine_type = self.engine_params.get("engine_type", "")
|
||||||
|
|
||||||
|
# Flush strategy for long-context models: keep all text, only keep latest images
|
||||||
|
if engine_type in ["anthropic", "openai", "gemini"]:
|
||||||
|
max_images = self.max_trajectory_length
|
||||||
|
for agent in [self.generator_agent, self.reflection_agent]:
|
||||||
|
if agent is None: continue
|
||||||
|
# keep latest k images
|
||||||
|
img_count = 0
|
||||||
|
for i in range(len(agent.messages) - 1, -1, -1):
|
||||||
|
for j in range(len(agent.messages[i]["content"]) - 1, -1, -1):  # iterate backwards so a deletion never shifts an index still to be visited
|
||||||
|
if "image" in agent.messages[i]["content"][j].get("type", ""):
|
||||||
|
img_count += 1
|
||||||
|
if img_count > max_images:
|
||||||
|
del agent.messages[i]["content"][j]
|
||||||
|
|
||||||
|
# Flush strategy for non-long-context models: drop full turns
|
||||||
|
else:
|
||||||
|
# generator msgs are alternating [user, assistant], so 2 per round
|
||||||
|
if len(self.generator_agent.messages) > 2 * self.max_trajectory_length + 1:
|
||||||
|
self.generator_agent.messages.pop(1)
|
||||||
|
self.generator_agent.messages.pop(1)
|
||||||
|
# reflector msgs are all [(user text, user image)], so 1 per round
|
||||||
|
if len(self.reflection_agent.messages) > self.max_trajectory_length + 1:
|
||||||
|
self.reflection_agent.messages.pop(1)
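The long-context flush strategy above keeps every text part and only the newest `max_images` image parts. A self-contained sketch of that idea follows, assuming each message is a dict whose `content` is a list of parts with a `type` field (the shape the loop above relies on). `keep_latest_images` is a hypothetical name.

```python
from typing import Dict, List


def keep_latest_images(messages: List[Dict], max_images: int) -> None:
    """Delete image parts beyond the newest `max_images`, in place; keep all text."""
    seen = 0
    for message in reversed(messages):
        content = message["content"]
        # Walk parts from last to first so deleting one never shifts an unvisited index
        for j in range(len(content) - 1, -1, -1):
            if "image" in content[j].get("type", ""):
                seen += 1
                if seen > max_images:
                    del content[j]


history = [
    {"role": "user", "content": [{"type": "text", "text": f"turn {i}"}, {"type": "image_url"}]}
    for i in range(3)
]
keep_latest_images(history, max_images=2)
print([len(m["content"]) for m in history])  # [1, 2, 2]
```

The oldest message loses its screenshot but keeps its text, so the model still sees the full action history while the image payload stays bounded.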
|
||||||
|
|
||||||
|
def _generate_reflection(self, instruction: str, obs: Dict) -> Tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Generate a reflection based on the current observation and instruction.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instruction (str): The task instruction.
|
||||||
|
obs (Dict): The current observation containing the screenshot.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[Optional[str], Optional[str]]: The generated reflection text and thoughts, if any (turn_count > 0); otherwise (None, None).
|
||||||
|
|
||||||
|
Side Effects:
|
||||||
|
- Updates reflection agent's history
|
||||||
|
- Generates reflection response with API call
|
||||||
|
"""
|
||||||
|
reflection = None
|
||||||
|
reflection_thoughts = None
|
||||||
|
if self.enable_reflection:
|
||||||
|
# Load the initial message
|
||||||
|
if self.turn_count == 0:
|
||||||
|
text_content = textwrap.dedent(
|
||||||
|
f"""
|
||||||
|
Task Description: {instruction}
|
||||||
|
Current Trajectory below:
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
updated_sys_prompt = (
|
||||||
|
self.reflection_agent.system_prompt + "\n" + text_content
|
||||||
|
)
|
||||||
|
self.reflection_agent.add_system_prompt(updated_sys_prompt)
|
||||||
|
self.reflection_agent.add_message(
|
||||||
|
text_content="The initial screen is provided. No action has been taken yet.",
|
||||||
|
image_content=obs["screenshot"],
|
||||||
|
role="user",
|
||||||
|
)
|
||||||
|
# Load the latest action
|
||||||
|
else:
|
||||||
|
self.reflection_agent.add_message(
|
||||||
|
text_content=self.worker_history[-1],
|
||||||
|
image_content=obs["screenshot"],
|
||||||
|
role="user",
|
||||||
|
)
|
||||||
|
full_reflection = call_llm_safe(
|
||||||
|
self.reflection_agent,
|
||||||
|
temperature=self.temperature,
|
||||||
|
use_thinking=self.use_thinking,
|
||||||
|
)
|
||||||
|
reflection, reflection_thoughts = split_thinking_response(
|
||||||
|
full_reflection
|
||||||
|
)
|
||||||
|
self.reflections.append(reflection)
|
||||||
|
logger.info("REFLECTION THOUGHTS: %s", reflection_thoughts)
|
||||||
|
logger.info("REFLECTION: %s", reflection)
|
||||||
|
return reflection, reflection_thoughts
|
||||||
|
|
||||||
|
def generate_next_action(self, instruction: str, obs: Dict) -> Tuple[Dict, List]:
|
||||||
|
"""
|
||||||
|
Predict the next action(s) based on the current observation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.grounding_agent.assign_screenshot(obs)
|
||||||
|
self.grounding_agent.set_task_instruction(instruction)
|
||||||
|
|
||||||
|
generator_message = "" if self.turn_count > 0 else "The initial screen is provided. No action has been taken yet."
|
||||||
|
|
||||||
|
# Load the task into the system prompt
|
||||||
|
if self.turn_count == 0:
|
||||||
|
prompt_with_instructions = self.generator_agent.system_prompt.replace("TASK_DESCRIPTION", instruction)
|
||||||
|
self.generator_agent.add_system_prompt(prompt_with_instructions)
|
||||||
|
|
||||||
|
# Get the per-step reflection
|
||||||
|
reflection, reflection_thoughts = self._generate_reflection(instruction, obs)
|
||||||
|
if reflection:
|
||||||
|
generator_message += f"REFLECTION: You may use this reflection on the previous action and overall trajectory:\n{reflection}\n"
|
||||||
|
|
||||||
|
# Get the grounding agent's knowledge base buffer
|
||||||
|
generator_message += f"\nCurrent Text Buffer = [{','.join(self.grounding_agent.notes)}]\n"
|
||||||
|
|
||||||
|
# Add code agent result from previous step if available (from full task or subtask execution)
|
||||||
|
if hasattr(self.grounding_agent, 'last_code_agent_result') and self.grounding_agent.last_code_agent_result is not None:
|
||||||
|
code_result = self.grounding_agent.last_code_agent_result
|
||||||
|
generator_message += f"\nCODE AGENT RESULT:\n"
|
||||||
|
generator_message += f"Task/Subtask Instruction: {code_result['task_instruction']}\n"
|
||||||
|
generator_message += f"Steps Completed: {code_result['steps_executed']}\n"
|
||||||
|
generator_message += f"Max Steps: {code_result['budget']}\n"
|
||||||
|
generator_message += f"Completion Reason: {code_result['completion_reason']}\n"
|
||||||
|
generator_message += f"Summary: {code_result['summary']}\n"
|
||||||
|
if code_result['execution_history']:
|
||||||
|
generator_message += f"Execution History:\n"
|
||||||
|
for i, step in enumerate(code_result['execution_history']):
|
||||||
|
action = step['action']
|
||||||
|
# Format code snippets with proper backticks
|
||||||
|
if '```python' in action:
|
||||||
|
# Extract Python code and format it
|
||||||
|
code_start = action.find('```python') + 9
|
||||||
|
code_end = action.find('```', code_start)
|
||||||
|
if code_end != -1:
|
||||||
|
python_code = action[code_start:code_end].strip()
|
||||||
|
generator_message += f"Step {i+1}: \n```python\n{python_code}\n```\n"
|
||||||
|
else:
|
||||||
|
generator_message += f"Step {i+1}: \n{action}\n"
|
||||||
|
elif '```bash' in action:
|
||||||
|
# Extract Bash code and format it
|
||||||
|
code_start = action.find('```bash') + 7
|
||||||
|
code_end = action.find('```', code_start)
|
||||||
|
if code_end != -1:
|
||||||
|
bash_code = action[code_start:code_end].strip()
|
||||||
|
generator_message += f"Step {i+1}: \n```bash\n{bash_code}\n```\n"
|
||||||
|
else:
|
||||||
|
generator_message += f"Step {i+1}: \n{action}\n"
|
||||||
|
else:
|
||||||
|
generator_message += f"Step {i+1}: \n{action}\n"
|
||||||
|
generator_message += "\n"
|
||||||
|
|
||||||
|
# Save code agent result to text file
|
||||||
|
try:
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# Create logs directory if it doesn't exist
|
||||||
|
logs_dir = "logs"
|
||||||
|
if not os.path.exists(logs_dir):
|
||||||
|
os.makedirs(logs_dir)
|
||||||
|
|
||||||
|
# Generate filename with timestamp
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
filename = f"logs/code_agent_result_step_{self.turn_count + 1}_{timestamp}.txt"
|
||||||
|
|
||||||
|
with open(filename, 'w') as f:
|
||||||
|
f.write(f"CODE AGENT RESULT - Step {self.turn_count + 1}\n")
|
||||||
|
f.write(f"Timestamp: {datetime.now().isoformat()}\n")
|
||||||
|
f.write(f"Task/Subtask Instruction: {code_result['task_instruction']}\n")
|
||||||
|
f.write(f"Steps Completed: {code_result['steps_executed']}\n")
|
||||||
|
f.write(f"Max Steps: {code_result['budget']}\n")
|
||||||
|
f.write(f"Completion Reason: {code_result['completion_reason']}\n")
|
||||||
|
f.write(f"Summary: {code_result['summary']}\n")
|
||||||
|
if code_result['execution_history']:
|
||||||
|
f.write(f"\nExecution History:\n")
|
||||||
|
for i, step in enumerate(code_result['execution_history']):
|
||||||
|
f.write(f"\nStep {i+1}:\n")
|
||||||
|
f.write(f"Action: {step['action']}\n")
|
||||||
|
if 'thoughts' in step:
|
||||||
|
f.write(f"Thoughts: {step['thoughts']}\n")
|
||||||
|
|
||||||
|
logger.info(f"Code agent result saved to: {filename}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to save code agent result to file: {e}")
|
||||||
|
|
||||||
|
# Log the code agent result section for debugging (truncated execution history)
|
||||||
|
log_message = f"\nCODE AGENT RESULT:\n"
|
||||||
|
log_message += f"Task/Subtask Instruction: {code_result['task_instruction']}\n"
|
||||||
|
log_message += f"Steps Completed: {code_result['steps_executed']}\n"
|
||||||
|
log_message += f"Max Steps: {code_result['budget']}\n"
|
||||||
|
log_message += f"Completion Reason: {code_result['completion_reason']}\n"
|
||||||
|
log_message += f"Summary: {code_result['summary']}\n"
|
||||||
|
if code_result['execution_history']:
|
||||||
|
log_message += f"Execution History (truncated):\n"
|
||||||
|
# Only log first 3 steps and last 2 steps to keep logs manageable
|
||||||
|
total_steps = len(code_result['execution_history'])
|
||||||
|
for i, step in enumerate(code_result['execution_history']):
|
||||||
|
if i < 3 or i >= total_steps - 2: # First 3 and last 2 steps
|
||||||
|
action = step['action']
|
||||||
|
if '```python' in action:
|
||||||
|
code_start = action.find('```python') + 9
|
||||||
|
code_end = action.find('```', code_start)
|
||||||
|
if code_end != -1:
|
||||||
|
python_code = action[code_start:code_end].strip()
|
||||||
|
log_message += f"Step {i+1}: ```python\n{python_code}\n```\n"
|
||||||
|
else:
|
||||||
|
log_message += f"Step {i+1}: {action}\n"
|
||||||
|
elif '```bash' in action:
|
||||||
|
code_start = action.find('```bash') + 7
|
||||||
|
code_end = action.find('```', code_start)
|
||||||
|
if code_end != -1:
|
||||||
|
bash_code = action[code_start:code_end].strip()
|
||||||
|
log_message += f"Step {i+1}: ```bash\n{bash_code}\n```\n"
|
||||||
|
else:
|
||||||
|
log_message += f"Step {i+1}: {action}\n"
|
||||||
|
else:
|
||||||
|
log_message += f"Step {i+1}: {action}\n"
|
||||||
|
elif i == 3 and total_steps > 5:
|
||||||
|
log_message += f"... (truncated {total_steps - 5} steps) ...\n"
|
||||||
|
|
||||||
|
logger.info(f"WORKER_CODE_AGENT_RESULT_SECTION - Step {self.turn_count + 1}: Code agent result added to generator message:\n{log_message}")
|
||||||
|
|
||||||
|
# Reset the code agent result after adding it to context
|
||||||
|
self.grounding_agent.last_code_agent_result = None
|
||||||
|
|
||||||
|
        # Finalize the generator message
        self.generator_agent.add_message(
            generator_message, image_content=obs["screenshot"], role="user"
        )

        # Generate the plan and next action
        format_checkers = [SINGLE_ACTION_FORMATTER, partial(CODE_VALID_FORMATTER, self.grounding_agent, obs)]
        plan = call_llm_formatted(self.generator_agent, format_checkers, temperature=self.temperature, use_thinking=self.use_thinking)
        self.worker_history.append(plan)
        self.generator_agent.add_message(plan, role="assistant")
        logger.info("PLAN:\n %s", plan)

        # Extract the next action from the plan
        plan_code = parse_code_from_string(plan)
        try:
            assert plan_code, "Plan code should not be empty"
            exec_code = create_pyautogui_code(self.grounding_agent, plan_code, obs)
        except Exception as e:
            logger.error(f"Could not evaluate the following plan code:\n{plan_code}\nError: {e}")
            exec_code = self.grounding_agent.wait(1.333)  # Skip a turn if the code cannot be evaluated

        executor_info = {
            "plan": plan,
            "plan_code": plan_code,
            "exec_code": exec_code,
            "reflection": reflection,
            "reflection_thoughts": reflection_thoughts,
            # None when the grounding agent has no code agent result for this turn
            "code_agent_output": getattr(self.grounding_agent, "last_code_agent_result", None),
        }
        self.turn_count += 1
        self.screenshot_inputs.append(obs["screenshot"])
        self.flush_messages()
        return executor_info, [exec_code]
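The execution-history formatting above repeats the same fence-extraction steps for Python and Bash, once for the generator message and once for the log. A minimal sketch of how that logic could be factored into a helper, assuming only the standard library; `extract_fenced_code` and `format_step` are hypothetical names that do not appear in this commit:

```python
from typing import Optional

FENCE = "`" * 3  # built at runtime so this sketch can itself sit inside a fenced block


def extract_fenced_code(action: str, lang: str) -> Optional[str]:
    """Return the body of the first fenced <lang> block in `action`, or None."""
    opener = FENCE + lang
    start = action.find(opener)
    if start == -1:
        return None
    start += len(opener)
    end = action.find(FENCE, start)
    if end == -1:
        return None  # unterminated fence
    return action[start:end].strip()


def format_step(index: int, action: str) -> str:
    """Format one execution-history step, checking python first, then bash."""
    for lang in ("python", "bash"):
        if FENCE + lang in action:
            code = extract_fenced_code(action, lang)
            if code is not None:
                return f"Step {index + 1}: \n{FENCE}{lang}\n{code}\n{FENCE}\n"
            break  # fence opened but never closed: fall back to the raw action
    return f"Step {index + 1}: \n{action}\n"
```

With such a helper, each loop body would reduce to a single `format_step(i, step['action'])` call.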
@@ -0,0 +1,181 @@
from gui_agents.s3.core.mllm import LMMAgent
from gui_agents.s3.memory.procedural_memory import PROCEDURAL_MEMORY
from gui_agents.s3.utils.common_utils import call_llm_formatted, split_thinking_response, compress_image
from gui_agents.s3.utils.formatters import (
    THOUGHTS_ANSWER_TAG_FORMATTER,
)
from PIL import Image, ImageDraw, ImageFont
from io import BytesIO
from typing import Dict
import base64
import cv2
import numpy as np

class BehaviorNarrator:
    def __init__(self, engine_params):
        self.judge_agent = LMMAgent(engine_params=engine_params)

    @staticmethod
    def extract_mouse_action(action: str) -> list[str]:
        mouse_actions = []
        for sub_action in action.split(';'):
            sub_action = sub_action.strip()
            if sub_action.startswith('pyautogui.click') or sub_action.startswith('pyautogui.moveTo') or sub_action.startswith('pyautogui.dragTo'):
                mouse_actions.append(sub_action)
        return mouse_actions

    @staticmethod
    def mark_action(mouse_actions: list[str], img: Image.Image):
        """Draw a labeled marker on `img` (in place) for each mouse action."""
        draw = ImageDraw.Draw(img)
        font = ImageFont.load_default(25)

        drag_start_width, drag_start_height = None, None

        for mouse_action in mouse_actions:
            width, height = mouse_action.split('(')[1].strip(')').split(', ')[:2]
            width, height = int(width), int(height)

            # Clamp coordinates within bounds
            width = max(0, min(img.width - 1, width))
            height = max(0, min(img.height - 1, height))

            def place_text(label, color):
                bbox = draw.textbbox((0, 0), label, font=font)
                text_w, text_h = bbox[2] - bbox[0], bbox[3] - bbox[1]  # Measure text size
                offset_x, offset_y = -5, 5  # Default offset
                if width + offset_x + text_w > img.width:  # Out of bounds on right
                    offset_x = -text_w - 5
                if height + offset_y + text_h > img.height:  # Out of bounds on bottom
                    offset_y = -text_h - 5
                if width + offset_x < 0:  # Out of bounds on left
                    offset_x = 5
                if height + offset_y < 0:  # Out of bounds on top
                    offset_y = 5
                draw.text((width + offset_x, height + offset_y), label, fill=color, font=font)

            if mouse_action.startswith('pyautogui.click'):
                draw.circle((width, height), radius=3, fill=(255, 0, 0))
                place_text("Click", (255, 0, 0))
            if mouse_action.startswith('pyautogui.moveTo'):
                draw.circle((width, height), radius=3, fill=(0, 0, 255))
                place_text("MoveTo", (0, 0, 255))
                drag_start_height, drag_start_width = height, width
            if mouse_action.startswith('pyautogui.dragTo'):
                draw.line([(drag_start_width, drag_start_height), (width, height)], fill=(0, 255, 0), width=2)
                draw.circle((width, height), radius=3, fill=(0, 255, 0))
                place_text("DragTo", (0, 255, 0))

    @staticmethod
    def get_mouse_action_representation(mouse_actions: list[str]) -> str:
        """
        Returns a text description of the markers drawn for the given mouse actions.
        """
        assert 1 <= len(mouse_actions) <= 2, f"Expected one or two mouse actions, got: {mouse_actions}"
        if len(mouse_actions) == 1:
            action = mouse_actions[0]
            if action.startswith('pyautogui.click'):
                return "The red circle labeled 'Click' marks the position where the mouse was clicked."
            elif action.startswith('pyautogui.moveTo'):
                return "The blue circle labeled 'MoveTo' marks the position where the mouse was moved to."
            else:
                raise ValueError(f"Unknown single action type: {action}")
        else:
            assert mouse_actions[0].startswith('pyautogui.moveTo') and mouse_actions[1].startswith('pyautogui.dragTo')
            return "The blue circle labeled 'MoveTo' marks the starting position of the mouse.\nThe green circle labeled 'DragTo' marks the ending position.\nThe green line illustrates the mouse's drag path."

    @staticmethod
    def get_zoomed_image(image_bytes: bytes, x: int, y: int, width: int = 300, height: int = 300, upscaling: bool = False, scale: int = 4, add_bounding_box: bool = False) -> tuple[bytes, bytes]:
        """Returns a zoomed image centered around (x, y) coordinates.

        Args:
            image_bytes (bytes): The original image in bytes.
            x (int): The x-coordinate of the center point.
            y (int): The y-coordinate of the center point.
            width (int): The width of the zoomed area.
            height (int): The height of the zoomed area.
            upscaling (bool): Whether to upscale and enhance the zoomed image.
            scale (int): The upscaling factor if upscaling is True.
            add_bounding_box (bool): Whether to add a bounding box around the zoomed area in the original image.

        Returns:
            tuple[bytes, bytes]: The zoomed image in bytes, followed by the original image
                with the bounding box drawn (if add_bounding_box is True) or the
                unmodified original bytes otherwise.
        """
        # Find zoom dimensions
        img = Image.open(BytesIO(image_bytes)).convert("RGB")
        cx, cy = x - width // 2, y - height // 2  # Top-left corner of a window centered on (x, y)
        W, H = img.size
        left = min(max(cx, 0), W - width)
        top = min(max(cy, 0), H - height)
        right = left + width
        bottom = top + height
        zoomed_img = img.crop((left, top, right, bottom))
        # Add noticeable bounding box to original image
        if add_bounding_box:
            draw_img = img.copy()
            draw = ImageDraw.Draw(draw_img)
            draw.rectangle([left, top, right, bottom], outline="red", width=3)
            original_with_box_bytes = compress_image(image=draw_img)  # Compress to reduce size
        else:
            original_with_box_bytes = image_bytes
        if upscaling:
            # Upscale and enhance zoomed image
            zoomed_img = cv2.cvtColor(np.array(zoomed_img), cv2.COLOR_RGB2BGR)  # PIL -> OpenCV
            zoomed_img = cv2.resize(zoomed_img, None, fx=scale, fy=scale, interpolation=cv2.INTER_LANCZOS4)
            zoomed_img = cv2.fastNlMeansDenoisingColored(zoomed_img, None, 5, 5, 7, 21)  # light denoise (helps with JPEG speckle)
            zoomed_img = Image.fromarray(cv2.cvtColor(zoomed_img, cv2.COLOR_BGR2RGB))  # OpenCV -> PIL
        zoomed_img_bytes = compress_image(image=zoomed_img)  # Compress to reduce size
        return zoomed_img_bytes, original_with_box_bytes

    def judge(self, screenshot_num: int, before_img_bytes: bytes, after_img_bytes: bytes, pyautogui_action: str) -> Dict[str, str]:
        if pyautogui_action == "DONE":
            return {
                "fact_thoughts": "The agent has indicated that it is done with the task.",
                "fact_answer": "The agent has indicated that it is done with the task."
            }
        elif pyautogui_action == "FAIL":
            return {
                "fact_thoughts": "The agent has indicated that it is impossible to proceed further with the task.",
                "fact_answer": "The agent has indicated that it is impossible to proceed further with the task."
            }
        # Prepare ANNOTATED BEFORE image
        mouse_actions = BehaviorNarrator.extract_mouse_action(pyautogui_action)
        before_img = Image.open(BytesIO(before_img_bytes))
        BehaviorNarrator.mark_action(mouse_actions, before_img)
        out_buffer = BytesIO()
        before_img.save(out_buffer, format='PNG')
        marked_before_img_bytes = out_buffer.getvalue()
        marked_before_img_message = {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64.b64encode(marked_before_img_bytes).decode('utf-8')}", "detail": "high"}}
        if mouse_actions:
            coords = mouse_actions[-1].split('(')[1].strip(')').split(', ')
            x, y = int(coords[0]), int(coords[1])
            zoomed_after_img_bytes, marked_after_img_bytes = BehaviorNarrator.get_zoomed_image(image_bytes=after_img_bytes, x=x, y=y, width=300, height=300, scale=4, upscaling=True, add_bounding_box=True)
            after_img_message = {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64.b64encode(marked_after_img_bytes).decode('utf-8')}", "detail": "high"}}
            zoomed_after_img_message = {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64.b64encode(zoomed_after_img_bytes).decode('utf-8')}", "detail": "high"}}
        else:
            after_img_message = {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64.b64encode(after_img_bytes).decode('utf-8')}", "detail": "high"}}
            zoomed_after_img_message = None

        fact_message = [{"role": "system", "content": PROCEDURAL_MEMORY.BEHAVIOR_NARRATOR_SYSTEM_PROMPT}]
        fact_message_content = [
            {"type": "text", "text": "BEFORE:"},
            marked_before_img_message,
            {"type": "text", "text": f"Agent Action: {pyautogui_action}"},
            {"type": "text", "text": "AFTER:"},
            after_img_message
        ]
        if zoomed_after_img_message:
            fact_message_content += [
                {"type": "text", "text": "ZOOMED AFTER:"},
                zoomed_after_img_message
            ]
        fact_message += [{"role": "user", "content": fact_message_content}]
        fact_response = call_llm_formatted(self.judge_agent, [THOUGHTS_ANSWER_TAG_FORMATTER], messages=fact_message, temperature=0.0)
        fact_answer, fact_thoughts = split_thinking_response(fact_response)

        result = {
            "fact_thoughts": fact_thoughts,
            "fact_answer": f"Fact Caption from Screenshot {screenshot_num}: {fact_answer}"
        }
        return result
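The crop window in `get_zoomed_image` is computed by centering a box on `(x, y)` and shifting it back inside the image. The sketch below isolates that arithmetic in a hypothetical `crop_window` helper (not part of this commit). As an added assumption, it also shrinks the box when the image is smaller than the requested window, which the code above does not do:

```python
from typing import Tuple


def crop_window(x: int, y: int, width: int, height: int, img_w: int, img_h: int) -> Tuple[int, int, int, int]:
    """Return (left, top, right, bottom) for a box centered on (x, y), kept inside the image."""
    # Assumption: never ask for a window larger than the image itself.
    width, height = min(width, img_w), min(height, img_h)
    # Center the box, then clamp its top-left corner so the box stays in bounds.
    left = max(0, min(x - width // 2, img_w - width))
    top = max(0, min(y - height // 2, img_h - height))
    return left, top, left + width, top + height


if __name__ == "__main__":
    print(crop_window(960, 540, 300, 300, 1920, 1080))    # box fully inside the image
    print(crop_window(10, 10, 300, 300, 1920, 1080))      # box pushed away from the top-left edge
    print(crop_window(1915, 1075, 300, 300, 1920, 1080))  # box pushed away from the bottom-right edge
```

Near an edge the box is shifted rather than truncated, so the click position is no longer at the exact center of the zoomed crop.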
@@ -0,0 +1,100 @@
import os
import base64
from typing import List, Tuple, Optional

from gui_agents.s3.core.mllm import LMMAgent
from gui_agents.s3.memory.procedural_memory import PROCEDURAL_MEMORY
from gui_agents.s3.utils.common_utils import call_llm_formatted, split_thinking_response


def get_final_screenshot_file(task_dir: str) -> str:
    """Get the final screenshot file name from a task directory."""
    screenshot_files = []
    for filename in os.listdir(task_dir):
        if filename.startswith("step_") and filename.endswith(".png"):
            screenshot_files.append(filename)

    if not screenshot_files:
        return "step_0.png"  # fallback

    # Sort by step number and get the last one
    def extract_step_num(filename):
        try:
            return int(filename.split("_")[1].split(".")[0])
        except (IndexError, ValueError):
            # Names such as "step_final.png" have no numeric step; sort them first
            return 0

    screenshot_files.sort(key=extract_step_num)
    return screenshot_files[-1]


def image_to_openai_message_format(image_path: str, caption: str = "") -> Optional[dict]:
    """Convert an image file to OpenAI message format."""
    if not os.path.exists(image_path):
        return None

    try:
        with open(image_path, "rb") as image_file:
            image_data = base64.b64encode(image_file.read()).decode('utf-8')

        content = []
        if caption:
            content.append({"type": "text", "text": caption})

        content.append({
            "type": "image_url",
            "image_url": {
                "url": f"data:image/png;base64,{image_data}",
                "detail": "high"
            }
        })

        return {"role": "user", "content": content}
    except Exception as e:
        print(f"Error loading image {image_path}: {e}")
        return None


class ComparativeJudge:
    def __init__(self, engine_params):
        self.judge_agent = LMMAgent(engine_params=engine_params)

    def judge(self, task_description: str, task: str, result_dirs: List[str], all_fact_captions: List[List[str]]) -> Tuple[str, str, Optional[str]]:
        """
        Fact captions + initial/final screenshots judging.
        Pipeline: use provided fact captions → include initial/final screenshots → judge.
        """
        num_trajectories = len(result_dirs)
        system_prompt = PROCEDURAL_MEMORY.VLM_EVALUATOR_PROMPT_COMPARATIVE_BASELINE
        system_prompt = system_prompt.replace("<TASK_DESCRIPTION_INPUT>", task_description)
        system_prompt = system_prompt.replace("<NUMBER OF TRAJECTORIES>", str(num_trajectories))

        messages = [{"role": "system", "content": system_prompt}]

        for i, (result_dir, fact_captions) in enumerate(zip(result_dirs, all_fact_captions)):
            task_dir = os.path.join(result_dir, task.split("/")[0], task.split("/")[1])
            result_initial_screenshot = os.path.join(task_dir, "step_0.png")
            result_final_screenshot = os.path.join(task_dir, get_final_screenshot_file(task_dir))
            initial_screenshot_message = image_to_openai_message_format(result_initial_screenshot, caption=f"Initial screenshot of result{i+1}")
            final_screenshot_message = image_to_openai_message_format(result_final_screenshot, caption=f"Final screenshot of result{i+1}")
            if initial_screenshot_message is not None and final_screenshot_message is not None:
                messages.append(initial_screenshot_message)
                messages.append(final_screenshot_message)
            if fact_captions:
                messages.append({"role": "user", "content": [{"type": "text", "text": f"Fact captions for Trajectory {i+1}:"}] + [{"type": "text", "text": caption} for caption in fact_captions]})

        messages.append({"role": "user", "content": [{"type": "text", "text": f"Please evaluate the {num_trajectories} trajectories based on the criteria provided in the system prompt."}]})

        response = call_llm_formatted(self.judge_agent, [], messages=messages)
        answer, thoughts = split_thinking_response(response)

        try:
            judge_choice = int(answer)
            if 1 <= judge_choice <= num_trajectories:
                selected_trajectory = result_dirs[judge_choice - 1]
            else:
                selected_trajectory = None
        except ValueError:
            selected_trajectory = None

        return answer, thoughts, selected_trajectory
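The last step of `ComparativeJudge.judge` maps the model's answer, expected to be a 1-based trajectory index, onto one of the result directories, and yields `None` when the answer is not a valid index. A minimal sketch of that mapping as a standalone function; `select_trajectory` is a hypothetical name used only for illustration:

```python
from typing import List, Optional


def select_trajectory(answer: str, result_dirs: List[str]) -> Optional[str]:
    """Map a 1-based index given as text to a result directory, or None if invalid."""
    try:
        choice = int(answer)
    except ValueError:
        return None  # the answer was not a bare integer
    if 1 <= choice <= len(result_dirs):
        return result_dirs[choice - 1]
    return None  # the index is out of range


if __name__ == "__main__":
    dirs = ["results/run_a", "results/run_b", "results/run_c"]  # illustrative paths
    print(select_trajectory("2", dirs))
    print(select_trajectory("Trajectory 2", dirs))
    print(select_trajectory("7", dirs))
```

Only a bare integer is accepted, so an answer with surrounding words is treated as no selection.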
@@ -0,0 +1,364 @@
import argparse
import datetime
import io
import logging
import os
import platform
import pyautogui
import signal
import sys
import time

from PIL import Image

from gui_agents.s3.agents.grounding import OSWorldACI
from gui_agents.s3.agents.agent_s import AgentS3

current_platform = platform.system().lower()

# Global flag to track pause state for debugging
paused = False

def get_char():
    """Get a single character from stdin without pressing Enter"""
    try:
        # Import termios and tty on Unix-like systems
        if platform.system() in ["Darwin", "Linux"]:
            import termios
            import tty
            fd = sys.stdin.fileno()
            old_settings = termios.tcgetattr(fd)
            try:
                tty.setraw(sys.stdin.fileno())
                ch = sys.stdin.read(1)
            finally:
                termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
            return ch
        else:
            # Windows fallback
            import msvcrt
            return msvcrt.getch().decode('utf-8', errors='ignore')
    except Exception:
        return input()  # Fallback for non-terminal environments

def signal_handler(signum, frame):
    """Handle Ctrl+C signal for debugging during agent execution"""
    global paused

    if not paused:
        print("\n\n🔸 Agent-S Workflow Paused 🔸")
        print("=" * 50)
        print("Options:")
        print(" • Press Ctrl+C again to quit")
        print(" • Press Esc to resume workflow")
        print("=" * 50)

        paused = True

        while paused:
            try:
                print("\n[PAUSED] Waiting for input... ", end="", flush=True)
                char = get_char()

                if ord(char) == 3:  # Ctrl+C
                    print("\n\n🛑 Exiting Agent-S...")
                    sys.exit(0)
                elif ord(char) == 27:  # Esc
                    print("\n\n▶️ Resuming Agent-S workflow...")
                    paused = False
                    break
                else:
                    print(f"\n Unknown command: '{char}' (ord: {ord(char)})")

            except KeyboardInterrupt:
                print("\n\n🛑 Exiting Agent-S...")
                sys.exit(0)
    else:
        # Already paused, second Ctrl+C means quit
        print("\n\n🛑 Exiting Agent-S...")
        sys.exit(0)

# Set up signal handler for Ctrl+C
signal.signal(signal.SIGINT, signal_handler)

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)

file_handler = logging.FileHandler(
    os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
    os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(
    os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
)

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)

platform_os = platform.system()


def show_permission_dialog(code: str, action_description: str):
    """Show a platform-specific permission dialog and return True if approved."""
    if platform.system() == "Darwin":
        result = os.system(
            f'osascript -e \'display dialog "Do you want to execute this action?\n\n{code} which will try to {action_description}" with title "Action Permission" buttons {{"Cancel", "OK"}} default button "OK" cancel button "Cancel"\''
        )
        return result == 0
    elif platform.system() == "Linux":
        result = os.system(
            f'zenity --question --title="Action Permission" --text="Do you want to execute this action?\n\n{code}" --width=400 --height=200'
        )
        return result == 0
    return False


def scale_screen_dimensions(width: int, height: int, max_dim_size: int):
    scale_factor = min(max_dim_size / width, max_dim_size / height, 1)
    safe_width = int(width * scale_factor)
    safe_height = int(height * scale_factor)
    return safe_width, safe_height


def run_agent(agent, instruction: str, scaled_width: int, scaled_height: int):
    global paused
    obs = {}
    traj = "Task:\n" + instruction
    subtask_traj = ""
    for step in range(15):
        # Check if we're in paused state and wait
        while paused:
            time.sleep(0.1)
        # Get screen shot using pyautogui
        screenshot = pyautogui.screenshot()
        screenshot = screenshot.resize((scaled_width, scaled_height), Image.LANCZOS)

        # Save the screenshot to a BytesIO object
        buffered = io.BytesIO()
        screenshot.save(buffered, format="PNG")

        # Get the byte value of the screenshot
        screenshot_bytes = buffered.getvalue()
        # Store the raw PNG bytes in the observation
        obs["screenshot"] = screenshot_bytes

        # Check again for pause state before prediction
        while paused:
            time.sleep(0.1)

        print(f"\n🔄 Step {step + 1}/15: Getting next action from agent...")

        # Get next action code from the agent
        info, code = agent.predict(instruction=instruction, observation=obs)

        if "done" in code[0].lower() or "fail" in code[0].lower():
            if platform.system() == "Darwin":
                os.system(
                    f'osascript -e \'display dialog "Task Completed" with title "OpenACI Agent" buttons "OK" default button "OK"\''
                )
            elif platform.system() == "Linux":
                os.system(
                    f'zenity --info --title="OpenACI Agent" --text="Task Completed" --width=200 --height=100'
                )

            break

        if "next" in code[0].lower():
            continue

        if "wait" in code[0].lower():
            print("⏳ Agent requested wait...")
            time.sleep(5)
            continue

        else:
            time.sleep(1.0)
            print("EXECUTING CODE:", code[0])

            # Check for pause state before execution
            while paused:
                time.sleep(0.1)

            # Execute the generated code (no permission dialog is shown here)
            exec(code[0])
            time.sleep(1.0)

            # Update task and subtask trajectories
            if "reflection" in info and "executor_plan" in info:
                traj += (
                    "\n\nReflection:\n"
                    + str(info["reflection"])
                    + "\n\n----------------------\n\nPlan:\n"
                    + info["executor_plan"]
                )


def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Run AgentS3 with specified model.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--provider",
|
||||||
|
type=str,
|
||||||
|
default="openai",
|
||||||
|
help="Specify the provider to use (e.g., openai, anthropic, etc.)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
default="gpt-5-2025-08-07",
|
||||||
|
help="Specify the model to use (e.g., gpt-5-2025-08-07)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_url",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="The URL of the main generation model API.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_api_key",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="The API key of the main generation model.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_temperature",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Temperature to fix the generation model at (e.g. o3 can only be run with 1.0)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Grounding model config: Self-hosted endpoint based (required)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ground_provider",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The provider for the grounding model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ground_url",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The URL of the grounding model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ground_api_key",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="The API key of the grounding model.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ground_model",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The model name for the grounding model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--grounding_width",
|
||||||
|
type=int,
|
||||||
|
required=True,
|
||||||
|
help="Width of screenshot image after processor rescaling",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--grounding_height",
|
||||||
|
type=int,
|
||||||
|
required=True,
|
||||||
|
help="Height of screenshot image after processor rescaling",
|
||||||
|
)
|
||||||
|
|
||||||
|
    # AgentS3 specific arguments
    parser.add_argument(
        "--max_trajectory_length",
        type=int,
        default=8,
        help="Maximum number of image turns to keep in trajectory",
    )
    parser.add_argument(
        "--enable_reflection",
        # BooleanOptionalAction (Python 3.9+) keeps the default of True and adds a
        # --no-enable_reflection switch; store_true with default=True could never
        # be turned off from the command line.
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Enable reflection agent to assist the worker agent",
    )

    args = parser.parse_args()

    # Re-scales screenshot size to ensure it fits in UI-TARS context limit
    screen_width, screen_height = pyautogui.size()
    scaled_width, scaled_height = scale_screen_dimensions(
        screen_width, screen_height, max_dim_size=2400
    )

    # Load the general engine params
    engine_params = {
        "engine_type": args.provider,
        "model": args.model,
        "base_url": args.model_url,
        "api_key": args.model_api_key,
        "temperature": getattr(args, 'model_temperature', None),
    }

    # Load the grounding engine from a custom endpoint
    engine_params_for_grounding = {
        "engine_type": args.ground_provider,
        "model": args.ground_model,
        "base_url": args.ground_url,
        "api_key": args.ground_api_key,
        "grounding_width": args.grounding_width,
        "grounding_height": args.grounding_height,
    }

    grounding_agent = OSWorldACI(
        platform=current_platform,
        engine_params_for_generation=engine_params,
        engine_params_for_grounding=engine_params_for_grounding,
        width=screen_width,
        height=screen_height,
    )

    agent = AgentS3(
        engine_params,
        grounding_agent,
        platform=current_platform,
        max_trajectory_length=args.max_trajectory_length,
        enable_reflection=args.enable_reflection,
    )

    while True:
        query = input("Query: ")

        agent.reset()

        # Run the agent on your own device
        run_agent(agent, query, scaled_width, scaled_height)

        response = input("Would you like to provide another query? (y/n): ")
        if response.lower() != "y":
            break


if __name__ == "__main__":
    main()
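A note on on-by-default switches such as `--enable_reflection`: the following is a minimal, self-contained sketch of how `argparse` treats them. The flag names `--always_on` and `--toggle` are hypothetical and exist only for this illustration; `argparse.BooleanOptionalAction` requires Python 3.9 or later.

```python
import argparse

parser = argparse.ArgumentParser()

# store_true combined with default=True: the value is True whether or not
# the flag is passed, so it cannot be disabled from the command line.
parser.add_argument("--always_on", action="store_true", default=True)

# BooleanOptionalAction also registers a --no-toggle form that sets False.
parser.add_argument("--toggle", action=argparse.BooleanOptionalAction, default=True)

defaults = parser.parse_args([])
disabled = parser.parse_args(["--no-toggle"])

print(defaults.always_on, defaults.toggle)
print(disabled.always_on, disabled.toggle)
```

In this sketch only `--toggle` can be switched off; `--always_on` stays `True` in both parses.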
@@ -0,0 +1,405 @@
import os

import backoff
from anthropic import Anthropic
from openai import (
    APIConnectionError,
    APIError,
    AzureOpenAI,
    OpenAI,
    RateLimitError,
)


class LMMEngine:
    pass


class LMMEngineOpenAI(LMMEngine):
    def __init__(
        self, base_url=None, api_key=None, model=None, rate_limit=-1, temperature=None, organization=None, **kwargs
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.base_url = base_url
        self.api_key = api_key
        self.organization = organization
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
        self.llm_client = None
        self.temperature = temperature  # Can force temperature to be the same (in the case of o3 requiring temperature to be 1)

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        api_key = self.api_key or os.getenv("OPENAI_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named OPENAI_API_KEY"
            )
        organization = self.organization or os.getenv("OPENAI_ORG_ID")
        if not self.llm_client:
            if not self.base_url:
                self.llm_client = OpenAI(api_key=api_key, organization=organization)
            else:
                self.llm_client = OpenAI(base_url=self.base_url, api_key=api_key, organization=organization)
        return (
            self.llm_client.chat.completions.create(
                model=self.model,
                messages=messages,
                # max_completion_tokens=max_new_tokens if max_new_tokens else 4096,
                temperature=temperature if self.temperature is None else self.temperature,
                **kwargs,
            )
            .choices[0]
            .message.content
        )


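The temperature handling in `LMMEngineOpenAI.generate` lets an engine-level temperature, when configured, take priority over the per-call argument. A minimal sketch of that precedence rule follows; `resolve_temperature` is a hypothetical helper name used only for illustration and is not part of the commit.

```python
from typing import Optional


def resolve_temperature(
    instance_temperature: Optional[float], call_temperature: float = 0.0
) -> float:
    """Prefer the engine-level temperature when one was configured."""
    # Compare against None rather than relying on truthiness, so that an
    # explicitly configured temperature of 0.0 is still honoured.
    if instance_temperature is not None:
        return instance_temperature
    return call_temperature


print(resolve_temperature(None, 0.3))  # no engine-level value: per-call value is used
print(resolve_temperature(1.0, 0.0))   # engine-level value overrides the call
```

Checking `is not None` instead of truthiness matters here because `0.0` is a valid, and common, temperature setting.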
class LMMEngineAnthropic(LMMEngine):
    def __init__(
        self, base_url=None, api_key=None, model=None, thinking=False, temperature=None, **kwargs
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.thinking = thinking
        self.api_key = api_key
        self.llm_client = None
        self.temperature = temperature

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        api_key = self.api_key or os.getenv("ANTHROPIC_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named ANTHROPIC_API_KEY"
            )
        self.llm_client = Anthropic(api_key=api_key)
        # Use self.temperature if set, otherwise use the temperature argument
        # (the argument defaults to 0.0, so it can never be None)
        temp = self.temperature if self.temperature is not None else temperature
        if self.thinking:
            full_response = self.llm_client.messages.create(
                system=messages[0]["content"][0]["text"],
                model=self.model,
                messages=messages[1:],
                max_tokens=8192,
                thinking={"type": "enabled", "budget_tokens": 4096},
                **kwargs,
            )
            # content[0] is the thinking block; only the text block is returned here
            return full_response.content[1].text
        return (
            self.llm_client.messages.create(
                system=messages[0]["content"][0]["text"],
                model=self.model,
                messages=messages[1:],
                max_tokens=max_new_tokens if max_new_tokens else 4096,
                temperature=temp,
                **kwargs,
            )
            .content[0]
            .text
        )

    # Compatible with Claude-3.7 Sonnet thinking mode
    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate_with_thinking(
        self, messages, temperature=0.0, max_new_tokens=None, **kwargs
    ):
        """Generate the next message based on previous messages, keeping the thinking tokens"""
        api_key = self.api_key or os.getenv("ANTHROPIC_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named ANTHROPIC_API_KEY"
            )
        self.llm_client = Anthropic(api_key=api_key)
        full_response = self.llm_client.messages.create(
            system=messages[0]["content"][0]["text"],
            model=self.model,
            messages=messages[1:],
            max_tokens=8192,
            thinking={"type": "enabled", "budget_tokens": 4096},
            **kwargs,
        )

        thoughts = full_response.content[0].thinking
        answer = full_response.content[1].text
        full_response = (
            f"<thoughts>\n{thoughts}\n</thoughts>\n\n<answer>\n{answer}\n</answer>\n"
        )
        return full_response


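`generate_with_thinking` reads the thinking block and the text block by position (`content[0]` and `content[1]`) and wraps them in `<thoughts>` and `<answer>` tags. As a hedged alternative, the sketch below selects blocks by their `type` field instead of by index. The `Block` dataclass is a hypothetical stand-in for the SDK's content-block objects, and `split_thinking` and `format_with_thoughts` are illustrative names that do not appear in the commit.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Block:
    """Hypothetical stand-in for a response content block."""
    type: str
    text: str = ""
    thinking: str = ""


def split_thinking(blocks: List[Block]) -> Tuple[str, str]:
    """Collect thinking and answer text by block type rather than position."""
    thoughts = "".join(b.thinking for b in blocks if b.type == "thinking")
    answer = "".join(b.text for b in blocks if b.type == "text")
    return thoughts, answer


def format_with_thoughts(thoughts: str, answer: str) -> str:
    """Same tag layout as the string built in generate_with_thinking."""
    return f"<thoughts>\n{thoughts}\n</thoughts>\n\n<answer>\n{answer}\n</answer>\n"


blocks = [Block(type="thinking", thinking="plan"), Block(type="text", text="ok")]
thoughts, answer = split_thinking(blocks)
print(format_with_thoughts(thoughts, answer))
```

Selecting by type makes the sketch tolerant of responses whose blocks arrive in a different order or include additional block kinds, which positional indexing would not be.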
class LMMEngineGemini(LMMEngine):
    def __init__(
        self, base_url=None, api_key=None, model=None, rate_limit=-1, temperature=None, **kwargs
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.base_url = base_url
        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
        self.llm_client = None
        self.temperature = temperature

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        api_key = self.api_key or os.getenv("GEMINI_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named GEMINI_API_KEY"
            )
        base_url = self.base_url or os.getenv("GEMINI_ENDPOINT_URL")
        if base_url is None:
            raise ValueError(
                "An endpoint URL needs to be provided in either the base_url parameter or as an environment variable named GEMINI_ENDPOINT_URL"
            )
        if not self.llm_client:
            self.llm_client = OpenAI(base_url=base_url, api_key=api_key)
        # Use self.temperature if set, otherwise use the temperature argument
        # (the argument defaults to 0.0, so it can never be None)
        temp = self.temperature if self.temperature is not None else temperature
        return (
            self.llm_client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=max_new_tokens if max_new_tokens else 4096,
                temperature=temp,
                **kwargs,
            )
            .choices[0]
            .message.content
        )


class LMMEngineOpenRouter(LMMEngine):
    def __init__(
        self, base_url=None, api_key=None, model=None, rate_limit=-1, temperature=None, **kwargs
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.base_url = base_url
        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
        self.llm_client = None
        self.temperature = temperature

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        api_key = self.api_key or os.getenv("OPENROUTER_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named OPENROUTER_API_KEY"
            )
        base_url = self.base_url or os.getenv("OPEN_ROUTER_ENDPOINT_URL")
        if base_url is None:
            raise ValueError(
                "An endpoint URL needs to be provided in either the base_url parameter or as an environment variable named OPEN_ROUTER_ENDPOINT_URL"
            )
        if not self.llm_client:
            self.llm_client = OpenAI(base_url=base_url, api_key=api_key)
        # Use self.temperature if set, otherwise use the temperature argument
        temp = self.temperature if self.temperature is not None else temperature
        return (
            self.llm_client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=max_new_tokens if max_new_tokens else 4096,
                temperature=temp,
                **kwargs,
            )
            .choices[0]
            .message.content
        )


class LMMEngineAzureOpenAI(LMMEngine):
    def __init__(
        self,
        base_url=None,
        api_key=None,
        azure_endpoint=None,
        model=None,
        api_version=None,
        rate_limit=-1,
        temperature=None,
        **kwargs,
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.api_version = api_version
        self.api_key = api_key
        self.azure_endpoint = azure_endpoint
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
        self.llm_client = None
        self.cost = 0.0
        self.temperature = temperature

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        api_key = self.api_key or os.getenv("AZURE_OPENAI_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named AZURE_OPENAI_API_KEY"
            )
        api_version = self.api_version or os.getenv("OPENAI_API_VERSION")
        if api_version is None:
            raise ValueError(
                "api_version must be provided either as a parameter or as an environment variable named OPENAI_API_VERSION"
            )
        azure_endpoint = self.azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
        if azure_endpoint is None:
            raise ValueError(
                "An Azure API endpoint needs to be provided in either the azure_endpoint parameter or as an environment variable named AZURE_OPENAI_ENDPOINT"
            )
        if not self.llm_client:
            self.llm_client = AzureOpenAI(
                azure_endpoint=azure_endpoint,
                api_key=api_key,
                api_version=api_version,
            )
        # Use self.temperature if set, otherwise use the temperature argument
        temp = self.temperature if self.temperature is not None else temperature
        completion = self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_new_tokens if max_new_tokens else 4096,
            temperature=temp,
            **kwargs,
        )
        total_tokens = completion.usage.total_tokens
        self.cost += 0.02 * ((total_tokens + 500) / 1000)
        return completion.choices[0].message.content


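The Azure engine accumulates a running cost with `0.02 * ((total_tokens + 500) / 1000)`. A small worked sketch of that arithmetic follows. Reading `0.02` as a per-1,000-token rate and `500` as a fixed token padding is an interpretation, not something the commit states, and `estimate_cost` is a hypothetical helper name.

```python
def estimate_cost(
    total_tokens: int, rate_per_1k: float = 0.02, padding_tokens: int = 500
) -> float:
    """Mirror of the expression used in LMMEngineAzureOpenAI.generate."""
    return rate_per_1k * ((total_tokens + padding_tokens) / 1000)


# A call that used 1,500 tokens is billed as (1500 + 500) / 1000 = 2.0 units.
cost = estimate_cost(1500)
print(f"{cost:.2f}")
```

Because the padding is added on every call, many small requests accumulate a noticeably higher estimate than one large request with the same token total.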
class LMMEnginevLLM(LMMEngine):
    def __init__(
        self, base_url=None, api_key=None, model=None, rate_limit=-1, temperature=None, **kwargs
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.api_key = api_key
        self.base_url = base_url
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
        self.llm_client = None
        self.temperature = temperature

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(
        self,
        messages,
        temperature=0.0,
        top_p=0.8,
        repetition_penalty=1.05,
        max_new_tokens=512,
        **kwargs,
    ):
        api_key = self.api_key or os.getenv("vLLM_API_KEY")
        if api_key is None:
            raise ValueError(
                "A vLLM API key needs to be provided in either the api_key parameter or as an environment variable named vLLM_API_KEY"
            )
        base_url = self.base_url or os.getenv("vLLM_ENDPOINT_URL")
        if base_url is None:
            raise ValueError(
                "An endpoint URL needs to be provided in either the base_url parameter or as an environment variable named vLLM_ENDPOINT_URL"
            )
        if not self.llm_client:
            self.llm_client = OpenAI(base_url=base_url, api_key=api_key)
        # Use self.temperature if set, otherwise use the temperature argument
        temp = self.temperature if self.temperature is not None else temperature
        completion = self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_new_tokens if max_new_tokens else 4096,
            temperature=temp,
            top_p=top_p,
            extra_body={"repetition_penalty": repetition_penalty},
        )
        return completion.choices[0].message.content


class LMMEngineHuggingFace(LMMEngine):
    def __init__(self, base_url=None, api_key=None, rate_limit=-1, **kwargs):
        self.base_url = base_url
        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
        self.llm_client = None

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        api_key = self.api_key or os.getenv("HF_TOKEN")
        if api_key is None:
            raise ValueError(
                "A HuggingFace token needs to be provided in either the api_key parameter or as an environment variable named HF_TOKEN"
            )
        base_url = self.base_url or os.getenv("HF_ENDPOINT_URL")
        if base_url is None:
            raise ValueError(
                "HuggingFace endpoint must be provided as base_url parameter or as an environment variable named HF_ENDPOINT_URL."
            )
        if not self.llm_client:
            self.llm_client = OpenAI(base_url=base_url, api_key=api_key)
        return (
            self.llm_client.chat.completions.create(
                model="tgi",
                messages=messages,
                max_tokens=max_new_tokens if max_new_tokens else 4096,
                temperature=temperature,
                **kwargs,
            )
            .choices[0]
            .message.content
        )


class LMMEngineParasail(LMMEngine):
    def __init__(self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs):
        assert model is not None, "Parasail model id must be provided"
        self.base_url = base_url
        self.model = model
        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
        self.llm_client = None

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        api_key = self.api_key or os.getenv("PARASAIL_API_KEY")
        if api_key is None:
            raise ValueError(
                "A Parasail API key needs to be provided in either the api_key parameter or as an environment variable named PARASAIL_API_KEY"
            )
        # Fall back to the environment variable named in the error message below
        base_url = self.base_url or os.getenv("PARASAIL_ENDPOINT_URL")
        if base_url is None:
            raise ValueError(
                "Parasail endpoint must be provided as base_url parameter or as an environment variable named PARASAIL_ENDPOINT_URL"
            )
        if not self.llm_client:
            self.llm_client = OpenAI(base_url=base_url if base_url else "https://api.parasail.io/v1", api_key=api_key)
        return (
            self.llm_client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=max_new_tokens if max_new_tokens else 4096,
                temperature=temperature,
                **kwargs,
            )
            .choices[0]
            .message.content
        )
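Most engine constructors above derive `request_interval` from `rate_limit` (requests per minute, with `-1` meaning unlimited), although the code shown here never consumes it. The sketch below shows the conversion and one way such an interval could be enforced. The `Throttle` class is entirely hypothetical and not part of the commit; its injectable `clock` and `sleep` parameters are there only to make it easy to exercise without real delays.

```python
import time
from typing import Callable, Optional


def request_interval(rate_limit: int) -> float:
    """Seconds between requests; -1 means unlimited, as in the constructors."""
    return 0.0 if rate_limit == -1 else 60.0 / rate_limit


class Throttle:
    """Hypothetical helper: keeps calls at least `interval` seconds apart."""

    def __init__(
        self,
        interval: float,
        clock: Callable[[], float] = time.monotonic,
        sleep: Callable[[float], None] = time.sleep,
    ):
        self.interval = interval
        self.clock = clock
        self.sleep = sleep
        self._last: Optional[float] = None

    def wait(self) -> None:
        # Sleep only for the part of the interval that has not yet elapsed.
        if self._last is not None:
            remaining = self.interval - (self.clock() - self._last)
            if remaining > 0:
                self.sleep(remaining)
        self._last = self.clock()


throttle = Throttle(request_interval(-1))
throttle.wait()  # unlimited rate: never sleeps
```

With `rate_limit=30`, `request_interval` yields two seconds, so a `Throttle` built from it would pause consecutive calls until two seconds have passed since the previous one.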
@@ -0,0 +1,305 @@
import base64

import numpy as np

from gui_agents.s3.core.engine import (
    LMMEngineAnthropic,
    LMMEngineAzureOpenAI,
    LMMEngineHuggingFace,
    LMMEngineOpenAI,
    LMMEngineOpenRouter,
    LMMEngineParasail,
    LMMEnginevLLM,
    LMMEngineGemini,
)


class LMMAgent:
    def __init__(self, engine_params=None, system_prompt=None, engine=None):
        if engine is None:
            if engine_params is not None:
                engine_type = engine_params.get("engine_type")
                if engine_type == "openai":
                    self.engine = LMMEngineOpenAI(**engine_params)
                elif engine_type == "anthropic":
                    self.engine = LMMEngineAnthropic(**engine_params)
                elif engine_type == "azure":
                    self.engine = LMMEngineAzureOpenAI(**engine_params)
                elif engine_type == "vllm":
                    self.engine = LMMEnginevLLM(**engine_params)
                elif engine_type == "huggingface":
                    self.engine = LMMEngineHuggingFace(**engine_params)
                elif engine_type == "gemini":
                    self.engine = LMMEngineGemini(**engine_params)
                elif engine_type == "open_router":
                    self.engine = LMMEngineOpenRouter(**engine_params)
                elif engine_type == "parasail":
                    self.engine = LMMEngineParasail(**engine_params)
                else:
                    raise ValueError(f"engine_type '{engine_type}' is not supported")
            else:
                raise ValueError("engine_params must be provided")
        else:
            self.engine = engine

        self.messages = []  # Empty messages

        if system_prompt:
            self.add_system_prompt(system_prompt)
        else:
            self.add_system_prompt("You are a helpful assistant.")

    def encode_image(self, image_content):
        # A str is treated as a path to an image file; anything else is
        # assumed to be raw image bytes
        if isinstance(image_content, str):
            with open(image_content, "rb") as image_file:
                return base64.b64encode(image_file.read()).decode("utf-8")
        else:
            return base64.b64encode(image_content).decode("utf-8")

    def reset(self):
        self.messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": self.system_prompt}],
            }
        ]

    def add_system_prompt(self, system_prompt):
        self.system_prompt = system_prompt
        if len(self.messages) > 0:
            self.messages[0] = {
                "role": "system",
                "content": [{"type": "text", "text": self.system_prompt}],
            }
        else:
            self.messages.append(
                {
                    "role": "system",
                    "content": [{"type": "text", "text": self.system_prompt}],
                }
            )

    def remove_message_at(self, index):
        """Remove a message at a given index"""
        if index < len(self.messages):
            self.messages.pop(index)

    def replace_message_at(
        self, index, text_content, image_content=None, image_detail="high"
    ):
        """Replace a message at a given index"""
        if index < len(self.messages):
            self.messages[index] = {
                "role": self.messages[index]["role"],
                "content": [{"type": "text", "text": text_content}],
            }
            if image_content:
                base64_image = self.encode_image(image_content)
                self.messages[index]["content"].append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{base64_image}",
                            "detail": image_detail,
                        },
                    }
                )

    def add_message(
        self,
        text_content,
        image_content=None,
        role=None,
        image_detail="high",
        put_text_last=False,
    ):
        """Add a new message to the list of messages"""

        # API-style inference from OpenAI and AzureOpenAI
        if isinstance(
            self.engine,
            (
                LMMEngineOpenAI,
                LMMEngineAzureOpenAI,
                LMMEngineHuggingFace,
                LMMEngineGemini,
                LMMEngineOpenRouter,
                LMMEngineParasail,
            ),
        ):
            # infer role from previous message
            if role != "user":
                if self.messages[-1]["role"] == "system":
                    role = "user"
                elif self.messages[-1]["role"] == "user":
                    role = "assistant"
                elif self.messages[-1]["role"] == "assistant":
                    role = "user"

            message = {
                "role": role,
                "content": [{"type": "text", "text": text_content}],
            }

            if isinstance(image_content, np.ndarray) or image_content:
                # Check if image_content is a list or a single image
                if isinstance(image_content, list):
                    # If image_content is a list of images, loop through each image
                    for image in image_content:
                        base64_image = self.encode_image(image)
                        message["content"].append(
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/png;base64,{base64_image}",
                                    "detail": image_detail,
                                },
                            }
                        )
                else:
                    # If image_content is a single image, handle it directly
                    base64_image = self.encode_image(image_content)
                    message["content"].append(
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{base64_image}",
                                "detail": image_detail,
                            },
                        }
                    )

            # Rotate text to be the last message if desired
            if put_text_last:
                text_content = message["content"].pop(0)
                message["content"].append(text_content)

            self.messages.append(message)

        # For API-style inference from Anthropic
        elif isinstance(self.engine, LMMEngineAnthropic):
            # infer role from previous message
            if role != "user":
                if self.messages[-1]["role"] == "system":
                    role = "user"
                elif self.messages[-1]["role"] == "user":
                    role = "assistant"
                elif self.messages[-1]["role"] == "assistant":
                    role = "user"

            message = {
                "role": role,
                "content": [{"type": "text", "text": text_content}],
            }

            if image_content:
                # Check if image_content is a list or a single image
                if isinstance(image_content, list):
                    # If image_content is a list of images, loop through each image
                    for image in image_content:
                        base64_image = self.encode_image(image)
                        message["content"].append(
                            {
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": "image/png",
                                    "data": base64_image,
                                },
                            }
                        )
                else:
                    # If image_content is a single image, handle it directly
                    base64_image = self.encode_image(image_content)
                    message["content"].append(
                        {
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": "image/png",
                                "data": base64_image,
                            },
                        }
                    )
            self.messages.append(message)

        # Locally hosted vLLM model inference
        elif isinstance(self.engine, LMMEnginevLLM):
            # infer role from previous message
            if role != "user":
                if self.messages[-1]["role"] == "system":
                    role = "user"
                elif self.messages[-1]["role"] == "user":
                    role = "assistant"
                elif self.messages[-1]["role"] == "assistant":
                    role = "user"

            message = {
                "role": role,
                "content": [{"type": "text", "text": text_content}],
            }

            if image_content:
                # Check if image_content is a list or a single image
                if isinstance(image_content, list):
                    # If image_content is a list of images, loop through each image
                    for image in image_content:
                        base64_image = self.encode_image(image)
                        message["content"].append(
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image;base64,{base64_image}"
                                },
                            }
                        )
                else:
                    # If image_content is a single image, handle it directly
                    base64_image = self.encode_image(image_content)
                    message["content"].append(
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image;base64,{base64_image}"},
                        }
                    )

            self.messages.append(message)
        else:
            raise ValueError("engine_type is not supported")

    def get_response(
        self,
        user_message=None,
        messages=None,
        temperature=0.0,
        max_new_tokens=None,
        use_thinking=False,
        **kwargs,
    ):
        """Generate the next response based on previous messages"""
        if messages is None:
            messages = self.messages
        if user_message:
            messages.append(
                {"role": "user", "content": [{"type": "text", "text": user_message}]}
            )

        # Thinking-mode generation (only engines that implement generate_with_thinking)
        if use_thinking:
            return self.engine.generate_with_thinking(
                messages,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                **kwargs,
            )

        # Regular generation
        return self.engine.generate(
            messages,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            **kwargs,
        )
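`LMMAgent.add_message` does two things for every engine family: it infers the role of the new message from the previous one, and it wraps each image in the payload shape the target API expects. The sketch below isolates both ideas under stated assumptions. The function names `next_role`, `openai_image_part` and `anthropic_image_part` are hypothetical, and `next_role` simplifies the original by always returning a role, whereas the original leaves `role` unchanged when the previous role is not one of system, user or assistant.

```python
import base64
from typing import Optional


def next_role(previous_role: str, requested: Optional[str] = None) -> str:
    """Only an explicit "user" request is honoured; otherwise roles alternate."""
    if requested == "user":
        return "user"
    return "assistant" if previous_role == "user" else "user"


def openai_image_part(png_bytes: bytes, detail: str = "high") -> dict:
    """Image part in the OpenAI-style chat format (data URL)."""
    b64 = base64.b64encode(png_bytes).decode("utf-8")
    return {
        "type": "image_url",
        "image_url": {"url": f"data:image/png;base64,{b64}", "detail": detail},
    }


def anthropic_image_part(png_bytes: bytes) -> dict:
    """Image part in the Anthropic-style format (base64 source object)."""
    b64 = base64.b64encode(png_bytes).decode("utf-8")
    return {
        "type": "image",
        "source": {"type": "base64", "media_type": "image/png", "data": b64},
    }


print(next_role("system"), next_role("user"), next_role("assistant"))
print(openai_image_part(b"abc")["image_url"]["url"])
```

One consequence of the inference rule, visible in `next_role`, is that passing `role="assistant"` explicitly has no special effect: the role is still derived from the previous message.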
@@ -0,0 +1,17 @@
from typing import Dict, Optional
from gui_agents.s3.core.mllm import LMMAgent


class BaseModule:
    def __init__(self, engine_params: Dict, platform: str):
        self.engine_params = engine_params
        self.platform = platform

    def _create_agent(
        self, system_prompt: Optional[str] = None, engine_params: Optional[Dict] = None
    ) -> LMMAgent:
        """Create a new LMMAgent instance"""
        agent = LMMAgent(engine_params or self.engine_params)
        if system_prompt:
            agent.add_system_prompt(system_prompt)
        return agent
@@ -0,0 +1,367 @@
|
|||||||
|
import inspect
|
||||||
|
import textwrap
|
||||||
|
|
||||||
|
|
||||||
|
class PROCEDURAL_MEMORY:
|
||||||
|
|
||||||
|
FORMATTING_FEEDBACK_PROMPT = textwrap.dedent("""
|
||||||
|
Your previous response was not formatted correctly. You must respond again to replace your previous response. Do not make reference to this message while fixing the response. Please address the following issues below to improve the previous response:
|
||||||
|
FORMATTING_FEEDBACK
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def construct_simple_worker_procedural_memory(agent_class, skipped_actions):
|
||||||
|
procedural_memory = textwrap.dedent(
|
||||||
|
f"""\
|
||||||
|
You are an expert in graphical user interfaces and Python code. You are responsible for executing the task: `TASK_DESCRIPTION`.
|
||||||
|
You are working in CURRENT_OS.
|
||||||
|
|
||||||
|
# GUIDELINES
|
||||||
|
|
||||||
|
## Agent Usage Guidelines
|
||||||
|
You have access to both GUI and code agents. Choose the appropriate agent based on the task requirements:
|
||||||
|
|
||||||
### GUI Agent
- **Use for**: clicking, typing, navigation, file operations, tasks requiring specific application features, visual elements, interactive features, application UI, complex formatting, print/export settings, multi-step workflows, pivot tables, charts

### Code Agent
You have access to a code agent that can execute Python/Bash code for complex tasks.

**Usage Strategy**:
- **Full Task**: Use `agent.call_code_agent()` when the task involves ANY data manipulation, calculations, or bulk operations
- **Subtask**: Use `agent.call_code_agent("specific subtask")` for focused data tasks
- **CRITICAL**: If calling the code agent for the full task, pass the original task instruction without rewording or modification

### Code Agent Result Interpretation
- The code agent runs Python/Bash code in the background (up to 20 steps), independently performing tasks like file modification, package installation, or system operations.
- After execution, you receive a report with:
    * Steps completed (actual steps run)
    * Max steps (step budget)
    * Completion reason: DONE (success), FAIL (gave up), or BUDGET_EXHAUSTED (used all steps)
    * Summary of work done
    * Full execution history
- Interpretation:
    * DONE: The code agent finished before using all steps, believing the task was completed through code.
    * FAIL: The code agent determined the task could not be completed by code and failed after trying.
    * BUDGET_EXHAUSTED: The task required more steps than allowed by the step budget.

### Code Agent Verification
- After the code agent modifies files, your job is to find and verify these files via GUI actions (e.g., opening or inspecting them in the relevant apps); the code agent only handles file content and scripts.
- ALWAYS verify code agent results with GUI actions before using agent.done(); NEVER trust code agent output alone. If verification or the code agent fails, use GUI actions to finish the task and only use agent.done() if results match expectations.
- **CRITICAL**: Files modified by code agent may not show changes in currently open applications - you MUST close and reopen the entire application. Reloading the page/file is insufficient.

Never assume a task is done based on appearances. Always ensure the specific requested action has been performed and verify the modification. If you haven't executed any actions, the task is not complete.

### END OF GUIDELINES

You are provided with:
1. A screenshot of the current time step.
2. The history of your previous interactions with the UI.
3. Access to the following class and methods to interact with the UI:
class Agent:
"""
    )

    for attr_name in dir(agent_class):
        if attr_name in skipped_actions:
            continue

        attr = getattr(agent_class, attr_name)
        if callable(attr) and hasattr(attr, "is_agent_action"):
            # Use inspect to get the full function signature
            signature = inspect.signature(attr)
            procedural_memory += f"""
    def {attr_name}{signature}:
    '''{attr.__doc__}'''
"""

    procedural_memory += textwrap.dedent(
        """
Your response should be formatted like this:
(Previous action verification)
Carefully analyze based on the screenshot if the previous action was successful. If the previous action was not successful, provide a reason for the failure.

(Screenshot Analysis)
Closely examine and describe the current state of the desktop along with the currently open applications.

(Next Action)
Based on the current screenshot and the history of your previous interaction with the UI, decide on the next action in natural language to accomplish the given task.

(Grounded Action)
Translate the next action into code using the provided API methods. Format the code like this:
```python
agent.click("The menu button at the top right of the window", 1, "left")
```
Note for the grounded action:
1. Only perform one action at a time.
2. Do not put anything other than python code in the block. You can only use one function call at a time. Do not put more than one function call in the block.
3. You must use only the available methods provided above to interact with the UI, do not invent new methods.
4. Only return one code block every time. There must be a single line of code in the code block.
5. Do not do anything other than the exact specified task. Return with `agent.done()` immediately after the subtask is completed or `agent.fail()` if it cannot be completed.
6. Whenever possible, your grounded action should use hot-keys with the agent.hotkey() action instead of clicking or dragging.
7. My computer's password is 'osworld-public-evaluation', feel free to use it when you need sudo rights.
8. Generate agent.fail() as your grounded action if you get exhaustively stuck on the task and believe it is impossible.
9. Generate agent.done() as your grounded action when you believe the task is fully complete.
10. Do not use the "command" + "tab" hotkey on macOS.
11. Prefer hotkeys and application features over clicking on text elements when possible. Highlighting text is fine.
"""
    )

    return procedural_memory.strip()

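The loop above builds the API listing in the prompt by introspecting the agent class. The following is a minimal, self-contained sketch of that idea, not the project's actual classes: `agent_action`, `DemoAgent`, and `describe_actions` are hypothetical names introduced only for illustration, and the only behavior taken from the source is "include callables that carry an `is_agent_action` attribute, skip names in `skipped_actions`, and render each one with `inspect.signature`".

```python
import inspect


def agent_action(func):
    """Hypothetical marker decorator: tags a method as an agent action."""
    func.is_agent_action = True
    return func


class DemoAgent:
    """Stand-in for the real agent class; all methods are illustrative."""

    @agent_action
    def click(self, element_description: str, num_clicks: int = 1, button_type: str = "left"):
        """Click on the described element."""

    @agent_action
    def done(self):
        """End the task with success."""

    def helper(self):
        """Not tagged, so it is left out of the listing."""


def describe_actions(agent_class, skipped_actions=()):
    """Render one `def name(signature)` stub per tagged method."""
    stubs = []
    for attr_name in dir(agent_class):
        if attr_name in skipped_actions:
            continue
        attr = getattr(agent_class, attr_name)
        if callable(attr) and hasattr(attr, "is_agent_action"):
            signature = inspect.signature(attr)
            stubs.append(f"def {attr_name}{signature}:\n    '''{attr.__doc__}'''")
    return "\n".join(stubs)


api_text = describe_actions(DemoAgent, skipped_actions=["done"])
print(api_text)
```

Because the methods are looked up on the class rather than on an instance, the rendered signatures include `self`. Skipping by name keeps an action defined on the class while hiding it from the model.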
# For reflection agent, post-action verification mainly for cycle detection
REFLECTION_ON_TRAJECTORY = textwrap.dedent(
    """
You are an expert computer use agent designed to reflect on the trajectory of a task and provide feedback on what has happened so far.
You have access to the Task Description and the Current Trajectory of another computer agent. The Current Trajectory is a sequence of a desktop image, chain-of-thought reasoning, and a desktop action for each time step. The last image is the screen's display after the last action.

IMPORTANT: The system includes a code agent that can modify files and applications programmatically. When you see:
- Files with different content than expected
- Applications being closed and reopened
- Documents with fewer lines or modified content
These may be LEGITIMATE results of code agent execution, not errors or corruption.

Your task is to generate a reflection. Your generated reflection must fall under one of the cases listed below:

Case 1. The trajectory is not going according to plan. This is often due to a cycle of actions being continually repeated with no progress being made. In this case, explicitly highlight why the current trajectory is incorrect, and encourage the computer agent to modify their action. However, DO NOT encourage a specific action in particular.
Case 2. The trajectory is going according to plan. In this case, simply tell the agent to continue proceeding as planned. DO NOT encourage a specific action in particular.
Case 3. You believe the current task has been completed. In this case, tell the agent that the task has been successfully completed.

To be successful, you must follow the rules below:
- **Your output MUST be based on one of the case options above**.
- DO NOT suggest any specific future plans or actions. Your only goal is to provide a reflection, not an actual plan or action.
- Any response that falls under Case 1 should explain why the trajectory is not going according to plan. You should especially look out for cycles of actions that are continually repeated with no progress.
- Any response that falls under Case 2 should be concise, since you just need to affirm that the agent should continue with the current trajectory.
- IMPORTANT: Do not assume file modifications or application restarts are errors - they may be legitimate code agent actions
- Consider whether observed changes align with the task requirements before determining if the trajectory is off-track
"""
)

PHRASE_TO_WORD_COORDS_PROMPT = textwrap.dedent(
    """
You are an expert in graphical user interfaces. Your task is to process a phrase of text, and identify the most relevant word on the computer screen.
You are provided with a phrase, a table with all the text on the screen, and a screenshot of the computer screen. You will identify the single word id that is best associated with the provided phrase.
This single word must be displayed on the computer screenshot, and its location on the screen should align with the provided phrase.
Each row in the text table provides 2 pieces of data in the following order. 1st is the unique word id. 2nd is the corresponding word.

To be successful, it is very important to follow all these rules:
1. First, think step by step and generate your reasoning about which word id to click on.
2. Then, output the unique word id. Remember, the word id is the 1st number in each row of the text table.
3. If there are multiple occurrences of the same word, use the surrounding context in the phrase to choose the correct one. Pay very close attention to punctuation and capitalization.

"""
)

CODE_AGENT_PROMPT = textwrap.dedent("""\
You are a code execution agent with a limited step budget to complete tasks.

# Core Guidelines:
- Execute Python/Bash code step-by-step to progress toward the goal
- Use sudo with: "echo osworld-public-evaluation | sudo -S [COMMANDS]"
- Username: "user"
- Print results and handle errors appropriately
- Code execution may not show immediately on screen

# CRITICAL: Incremental Step-by-Step Approach
- Break down complex tasks into small, self-contained steps
- Each step should contain a single, focused code snippet that advances toward the goal
- Code from each step does NOT persist to the next step - write complete, standalone snippets
- Example workflow:
    * Step 1: Write code to locate/find the target file
    * Step 2: Write code to **THOROUGHLY** inspect/read the file contents
    * Step 3: Write code to modify the file based on findings
    * Step 4: Write code to verify the changes
- If verification fails (the modification did not work as intended), return to Step 3 and rewrite the modification code. Repeat until verification succeeds.
- Do NOT write entire scripts in one step - focus on one small task per step

# CRITICAL: File Modification Strategy
- ALWAYS prioritize modifying existing open files IN PLACE rather than creating new files
- The screenshot context shows which file is currently open and should be modified
- For open documents (LibreOffice .docx/.xlsx, text editors, etc.), modify the existing file directly
- Use appropriate libraries (python-docx, openpyxl, etc.) to modify files in place
- CRITICAL: When modifying files, perform COMPLETE OVERWRITES, not appends
- For documents: replace all paragraphs/sheets with new content
- For text files: write the complete new content, overwriting the old
- Only create new files when explicitly required by the task
- Verify your reasoning aligns with the user's intent for the open file

# CRITICAL: Thorough File Inspection Guidelines
- **ALWAYS inspect file contents AND data types before and after modifications**
- Check cell values, formats, data types, number formats, decimal separators, and formatting properties
- For spreadsheets: inspect cell values, number formats, date formats, currency formats, and cell properties
- For documents: inspect text content, formatting, styles, and structural elements
- Verify that modifications actually changed the intended properties (not just values)
- Compare before/after states to ensure changes were applied correctly

# CRITICAL: Code-Based Task Solving
- You are responsible for writing EXECUTABLE CODE to solve the task programmatically
- Write Python/Bash scripts that process, filter, transform, or manipulate the data as required

# CRITICAL: Preserve Document Structure and Formatting
- When modifying documents/spreadsheets, PRESERVE the original structure, headers, and formatting
- NEVER modify column headers, row headers, document titles, or sheet names unless explicitly requested
- Maintain fonts, colors, borders, cell formatting, paragraph styles, etc.
- Only change the content/data, not the structure or visual presentation
- Use libraries that support formatting preservation (python-docx, openpyxl, etc.)
- The goal is to keep the document looking exactly the same, just with different content
- **For column reordering**: Preserve table position - reorder columns within the table without shifting the table itself

# CRITICAL: Final Step Requirement
- At the final step before completing the task (the step before you return DONE), you MUST print out the contents of any files you modified
- Use appropriate commands to display the final state of modified files:
    * For text files: `cat filename` or `head -n 50 filename` for large files
    * For Python files: `cat filename.py`
    * For configuration files: `cat filename.conf`
    * For any other file type: use appropriate viewing commands
- This ensures the user can see exactly what changes were made to the files

# CRITICAL: Verification Instructions
- When you complete a task that modifies files, you MUST provide clear verification instructions
- Include specific details about what the GUI agent should check:
    * Which files were modified and their expected final state
    * What the content should look like (number of lines, key data points, etc.)
    * How to verify the changes are correct
    * Whether the task is complete or if additional GUI actions are needed
- This helps the GUI agent understand what to expect and how to verify your work correctly

# Response Format:
You MUST respond using exactly this format:

<thoughts>
Your step-by-step reasoning about what needs to be done and how to approach the current step.
</thoughts>

<answer>
Return EXACTLY ONE of the following options:

For Python code:
```python
your_python_code_here
```

For Bash commands:
```bash
your_bash_commands_here
```

For task completion:
DONE

For task failure:
FAIL
</answer>

# Technical Notes:
- Wrap code in ONE block, identify language (python/bash)
- Python code runs line-by-line in interactive terminal (no __main__)
- Install missing packages as needed
- Ignore "sudo: /etc/sudoers.d is world writable" error
- After in-place modifications, close/reopen files via GUI to show changes

Focus on progress within your step budget.
""")

CODE_SUMMARY_AGENT_PROMPT = textwrap.dedent("""\
You are a code execution summarizer. Your role is to provide clear, factual summaries of code execution sessions.

Key responsibilities:
- Summarize the code logic and approach used at each step
- Describe the outputs and results produced by code execution
- Explain the progression of the solution approach
- Use neutral, objective language without making judgments about success or failure
- Focus on what was attempted and what resulted
- Keep summaries concise and well-structured

CRITICAL: Include verification instructions for the GUI agent
- If files were modified, provide specific verification guidance:
    * What files were changed and their expected final state
    * What the GUI agent should look for when verifying
    * How to verify the changes are correct
    * Whether the task appears complete or if additional GUI actions are needed
- This helps the GUI agent understand what to expect and verify your work properly

Always maintain a factual, non-judgmental tone.
""")

BEHAVIOR_NARRATOR_SYSTEM_PROMPT = textwrap.dedent("""\
You are an expert in computer usage responsible for analyzing what happened after a computer action is taken.

**Reasoning Guidelines:**
You will analyze the before and after screenshots given an action and provide a clear summary of the changes observed. Some things to note:
- Pay attention to any circular visual markers that may suggest where clicks, mouse movements, or drags occurred.
- Clicks will be marked with a red circle and labeled Click
- Moving the mouse without clicking will be marked with a blue circle and labeled MoveTo
- Drag and drops will have an initial blue circle labeled MoveTo, a green circle labeled DragTo, and a green line connecting the two circles.
- If any mouse action occurred, the after screenshot will be accompanied by a zoomed-in view of the area around the action to help you see changes more clearly.
- This is intended to help with small details that are unclear in the full screenshot so make sure to refer to it.
- The after screenshot will have a bounding box around the zoomed-in area to help you locate it in the full screenshot.
- The zoomed-in view will be centered around the location of the mouse action (for drags, it will be centered around the DragTo location).
- Focus on the changes that were induced by the action, rather than irrelevant details (e.g. the time change in the system clock).
- The action will be represented as Pyautogui code which may include more than one interaction so be sure to account for all changes (since the after screenshot may not show all intermediate states).
- Note that even if the action is expected to cause a change, it may not have. Never assume that the action was successful without clear evidence in the screenshots.
- Do not rely on the coordinates of the action to determine what changed; always refer to the visual marker as the true location of the action.
- Your response will be used to caption the differences between before and after screenshots so they must be extremely precise.
- Make sure to include the <thoughts>...</thoughts> and <answer>...</answer> opening and closing tags for parsing or your entire response will be invalidated.

Please format your response as follows below.
<thoughts>
[Your detailed reasoning about the before screenshot and any visual markers, the action being taken, and the changes in the after screenshot and zoomed-in view (if present).]
</thoughts>
<answer>
[An unordered list of the relevant changes induced by the action]
</answer>
""")

VLM_EVALUATOR_PROMPT_COMPARATIVE_BASELINE = textwrap.dedent("""\
You are a meticulous and impartial evaluator, tasked with judging <NUMBER OF TRAJECTORIES> sequences of OS desktop actions to determine which one better completes the user's request. Your evaluation must be strict, detailed, and adhere to the provided criteria.

**User Request:**
<TASK_DESCRIPTION_INPUT>

**Judge Guidelines:**
These guidelines are to help you evaluate both sequences of actions. These are strict guidelines and should not be deviated from.
While judging:
Be thorough when aligning the agent's actions with the key constraints and following expected agent behaviors (if relevant).
The agent is always expected to complete the task; key constraints take precedence over these guidelines which act as tie breakers.
Always double-check the agent's calculations for accuracy.
Explicitly state which rows and columns must be selected.
Always verify that exact values match the user's request.
Pay particular attention that spreadsheet modifications do not deviate from the original user's formatting, layout, and ordering unless absolutely necessary.

Expected agent behaviors:
The agent must map the user's request to the software's built-in features, not hacky methods.
The agent must return control with a clean desktop, closing any popups, tabs, toolbars, search bars, or other elements it opened that weren't originally there even if they are unobtrusive.
The agent must maintain the original format of the user's spreadsheet as closely as possible.
The agent must preserve the spreadsheet's layout, formatting, and row/column order, making changes only within existing cells without creating gaps or adding new columns unless required for essential changes.
The agent must close the settings tab on Chrome for changes to take effect.
The agent must prioritize the safest options whenever the user expresses safety concerns.
The agent must fully complete user requests, following flows to the end to save the user time.
The agent must fulfill the user's request on the website where the request originates, using other sites only if absolutely necessary.
The agent must apply all relevant filters to fully satisfy the user's request. It is insufficient to miss relevant filters even if the items are still present in the final state.

**Reasoning Structure:**
1. **Evaluate both sequences of actions against relevant judge guidelines.** Explicitly list EACH AND EVERY judge guideline, whether it applies, and, if so, verify that it was met, partially met, or not met at all for both sequences.
2. **Reason about the differences between the two sequences.** Consider which sequence better meets the judge guidelines. If they both meet the guidelines equally, consider which sequence is more efficient, effective, or cleaner.
3. **Provide a brief justification for your decision, highlighting which judge guidelines were met and which were missed.**

**Reasoning Guidelines:**
- You will be provided <NUMBER OF TRAJECTORIES> results, each result is in the form of initial_screenshot, final_screenshot.
- You **must** refer to final_screenshot to understand what has changed from initial_screenshot to final_screenshot. These facts are accurate; **Do not assume what has changed or likely changed.**
- You can cite facts during reasoning, e.g., Fact 2, Facts 1-2, but **must** refer to fact captions for accurate changes.
- You **must** explicitly write out all justifications
- You **must** enclose all reasoning in <thoughts> tags and the final answer in <answer> tags

- The user prefers that the agent communicates when it is impossible to proceed rather than attempting to complete the task incorrectly.
- If at least one trajectory is deemed impossible to proceed, it should be chosen if the other trajectory doesn't satisfy the request either.
- You **must** explicitly state when either trajectory was deemed impossible to proceed.
- You **must** explicitly write out all reasoning and justifications

Which sequence of actions better completes the user request OR correctly notes the request is impossible? Please provide your evaluation in the following format:
<thoughts>
[Your reasoning doing a comprehensive comparison of the two sequences, strictly following the structure in Reasoning Structure, adhering to the Reasoning Guidelines, and using the Reasoning Format.]
</thoughts>
<answer>
[The index of the better sequence, a single integer from 1 to <NUMBER OF TRAJECTORIES>]
</answer>
""")
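The code agent prompt above fixes a response format: reasoning in `<thoughts>`, and exactly one of a Python block, a Bash block, `DONE`, or `FAIL` in `<answer>`. A caller could classify such a response roughly as follows. This is a sketch under stated assumptions: `parse_code_agent_answer` and its `"invalid"` fallback are hypothetical and are not taken from the project's code.

```python
import re

FENCE = "`" * 3  # triple backtick, built programmatically for readability
ANSWER = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
CODE = re.compile(FENCE + r"(python|bash)\s*\n(.*?)" + FENCE, re.DOTALL)


def parse_code_agent_answer(response: str):
    """Return (kind, payload); kind is 'python', 'bash', 'done', 'fail', or 'invalid'."""
    match = ANSWER.search(response)
    if match is None:
        return "invalid", ""
    answer = match.group(1).strip()

    block = CODE.search(answer)
    if block is not None:
        return block.group(1), block.group(2).strip()
    if answer == "DONE":
        return "done", ""
    if answer == "FAIL":
        return "fail", ""
    return "invalid", answer


sample = (
    "<thoughts>List the files first.</thoughts>\n"
    f"<answer>\n{FENCE}bash\nls -la\n{FENCE}\n</answer>"
)
print(parse_code_agent_answer(sample))
```

Returning an explicit `"invalid"` kind, rather than raising, leaves the choice between re-prompting and giving up to the caller.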
@@ -0,0 +1,172 @@
import re
import time
from io import BytesIO
from PIL import Image

from typing import Tuple, Dict

from gui_agents.s3.memory.procedural_memory import PROCEDURAL_MEMORY

import logging

logger = logging.getLogger("desktopenv.agent")

def create_pyautogui_code(agent, code: str, obs: Dict) -> str:
    """
    Attempts to evaluate the code into a pyautogui code snippet with grounded actions using the observation screenshot.

    Args:
        agent (ACI): The grounding agent to use for evaluation.
        code (str): The code string to evaluate.
        obs (Dict): The current observation containing the screenshot.

    Returns:
        exec_code (str): The pyautogui code to execute the grounded action.

    Raises:
        Exception: If there is an error in evaluating the code.
    """
    agent.assign_screenshot(obs)  # Necessary for grounding
    # `code` is expected to be a single `agent.<action>(...)` expression;
    # it is evaluated with the local name `agent` in scope.
    exec_code = eval(code)
    return exec_code

def call_llm_safe(
    agent, temperature: float = 0.0, use_thinking: bool = False, **kwargs
) -> str:
    """Calls agent.get_response, retrying on failure; returns "" if every attempt fails."""
    max_retries = 3  # Set the maximum number of retries
    attempt = 0
    response = ""
    while attempt < max_retries:
        try:
            response = agent.get_response(
                temperature=temperature, use_thinking=use_thinking, **kwargs
            )
            assert response is not None, "Response from agent should not be None"
            print("Response success!")
            break  # If successful, break out of the loop
        except Exception as e:
            attempt += 1
            print(f"Attempt {attempt} failed: {e}")
            if attempt == max_retries:
                print("Max retries reached. Handling failure.")
        time.sleep(1.0)
    return response if response is not None else ""

def call_llm_formatted(generator, format_checkers, **kwargs):
    """
    Calls the generator agent's LLM and ensures correct formatting.

    Args:
        generator (ACI): The generator agent to call.
        format_checkers (List[Callable]): Functions that take the response and return a tuple of (success, feedback).
        **kwargs: Additional keyword arguments for the LLM call.

    Returns:
        response (str): The formatted response from the generator agent.
    """
    max_retries = 3  # Set the maximum number of retries
    attempt = 0
    response = ""
    messages = generator.messages.copy()  # Copy messages to avoid modifying the original
    while attempt < max_retries:
        response = call_llm_safe(generator, messages=messages, **kwargs)

        # Prepare feedback messages for incorrect formatting
        feedback_msgs = []
        for format_checker in format_checkers:
            success, feedback = format_checker(response)
            if not success:
                feedback_msgs.append(feedback)
        if not feedback_msgs:
            # logger.info(f"Response formatted correctly on attempt {attempt} for {generator.engine.model}")
            break
        logger.error(f"Response formatting error on attempt {attempt} for {generator.engine.model}. Response: {response} {', '.join(feedback_msgs)}")
        messages.append(
            {
                "role": "assistant",
                "content": [{"type": "text", "text": response}],
            }
        )
        logger.info(f"Bad response: {response}")
        delimiter = "\n- "
        formatting_feedback = f"- {delimiter.join(feedback_msgs)}"
        messages.append(
            {
                "role": "user",
                "content": [{"type": "text", "text": PROCEDURAL_MEMORY.FORMATTING_FEEDBACK_PROMPT.replace("FORMATTING_FEEDBACK", formatting_feedback)}],
            }
        )
        logger.info("Feedback:\n%s", formatting_feedback)

        attempt += 1
        if attempt == max_retries:
            logger.error("Max retries reached when formatting response. Handling failure.")
        time.sleep(1.0)
    return response


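The re-prompting loop in `call_llm_formatted` can be illustrated without an LLM. The sketch below is a simplified stand-in, not the function above: `generate_with_format_checks` and the `ask` callable are hypothetical, and the feedback is appended as plain text rather than through `FORMATTING_FEEDBACK_PROMPT`.

```python
from typing import Callable, List, Tuple

Checker = Callable[[str], Tuple[bool, str]]


def generate_with_format_checks(
    ask: Callable[[List[dict]], str],
    messages: List[dict],
    checkers: List[Checker],
    max_retries: int = 3,
) -> str:
    """Call `ask` until every checker passes or the retry budget is spent."""
    messages = list(messages)  # work on a copy so the caller's history is untouched
    response = ""
    for _ in range(max_retries):
        response = ask(messages)
        feedback = [msg for ok, msg in (check(response) for check in checkers) if not ok]
        if not feedback:
            break
        # Show the model its own bad answer, then the list of formatting problems.
        messages.append({"role": "assistant", "content": response})
        messages.append({"role": "user", "content": "- " + "\n- ".join(feedback)})
    return response


# Hypothetical stand-in for the LLM: wrong on the first call, right on the second.
replies = iter(["no tags here", "<answer>2</answer>"])
has_answer_tag = lambda r: ("<answer>" in r, "Response must contain an <answer> tag.")
history = [{"role": "user", "content": "Pick the better trajectory."}]
result = generate_with_format_checks(lambda msgs: next(replies), history, [has_answer_tag])
print(result)  # prints "<answer>2</answer>"
```

As in the original, the last response is returned even when the retry budget runs out, so callers still need to handle a badly formatted result.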
def split_thinking_response(full_response: str) -> Tuple[str, str]:
    """Splits a response into (answer, thoughts) using the <thoughts> and <answer> tags.

    Note that str.split does not raise when a tag is missing, so a missing tag
    yields unsplit text for that part rather than an empty string.
    """
    try:
        # Extract thoughts section
        thoughts = full_response.split("<thoughts>")[-1].split("</thoughts>")[0].strip()

        # Extract answer section
        answer = full_response.split("<answer>")[-1].split("</answer>")[0].strip()

        return answer, thoughts
    except Exception:
        return full_response, ""

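Because `str.split` returns the whole string when the delimiter is absent, the function above does not return empty thoughts for a response without tags. If a caller needs a reliable "tags missing" signal, one option is a regex-based variant. The following is a hedged sketch, not part of the project: `split_thinking_response_strict` is a hypothetical name, and unlike the original it uses the first occurrence of each tag pair.

```python
import re
from typing import Tuple


def split_thinking_response_strict(full_response: str) -> Tuple[str, str]:
    """Return (answer, thoughts); thoughts is "" unless both tag pairs are present."""
    thoughts = re.search(r"<thoughts>(.*?)</thoughts>", full_response, re.DOTALL)
    answer = re.search(r"<answer>(.*?)</answer>", full_response, re.DOTALL)
    if thoughts is None or answer is None:
        return full_response, ""
    return answer.group(1).strip(), thoughts.group(1).strip()


tagged = "<thoughts>Trajectory 1 saved the file.</thoughts>\n<answer>1</answer>"
print(split_thinking_response_strict(tagged))
print(split_thinking_response_strict("plain text without tags"))
```

With this variant, a check such as `split_thinking_response_strict(r)[1] != ""` only passes when both tag pairs are present and the thoughts section is non-empty.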
def parse_code_from_string(input_string):
    """Parses a string to extract the last code snippet enclosed in triple backticks (```)

    Args:
        input_string (str): The input string containing code snippets.

    Returns:
        str: The last code snippet found in the input string, or an empty string if no code is found.
    """
    input_string = input_string.strip()

    # This regular expression will match both ```code``` and ```python code```
    # and capture the `code` part. It uses a non-greedy match for the content inside.
    pattern = r"```(?:\w+\s+)?(.*?)```"

    # Find all non-overlapping matches in the string
    matches = re.findall(pattern, input_string, re.DOTALL)
    if len(matches) == 0:
        return ""
    relevant_code = matches[-1]  # We only care about the last match given it is the grounded action
    return relevant_code

def extract_agent_functions(code):
    """Extracts all agent function calls from the given code.

    Args:
        code (str): The code string to search for agent function calls.

    Returns:
        list: A list of all agent function calls found in the code.
    """
    # Matches `agent.<name>(...)`; `.*` is greedy and does not cross newlines
    pattern = r'(agent\.\w+\(\s*.*\))'
    return re.findall(pattern, code)

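The two regular expressions above are easier to follow with a worked input. The sketch below copies both patterns verbatim and runs them on made-up responses; `last_code_block` is a hypothetical helper that mirrors `parse_code_from_string`, and the sample strings are illustrative only.

```python
import re

FENCE = "`" * 3  # triple backtick, built programmatically for readability
CODE_BLOCK = re.compile(FENCE + r"(?:\w+\s+)?(.*?)" + FENCE, re.DOTALL)
AGENT_CALL = re.compile(r"(agent\.\w+\(\s*.*\))")


def last_code_block(text: str) -> str:
    """Return the content of the last fenced block, or "" if there is none."""
    matches = CODE_BLOCK.findall(text.strip())
    return matches[-1] if matches else ""


response = (
    "(Next Action)\nOpen the menu.\n"
    f"{FENCE}python\nagent.wait(1.0)\n{FENCE}\n"
    "(Grounded Action)\n"
    f'{FENCE}python\nagent.click("The menu button", 1, "left")\n{FENCE}'
)
code = last_code_block(response)
calls = AGENT_CALL.findall(code)
print(calls)

# The call pattern is greedy within a line: two calls on one line count as one match.
one_line = AGENT_CALL.findall('agent.click("a"); agent.done()')
two_lines = AGENT_CALL.findall('agent.click("a")\nagent.done()')
print(len(one_line), len(two_lines))
```

Only the last fenced block is kept, so an earlier block in the response is ignored. The one-line case matters for any check that counts matches: a response with two calls separated by `;` on a single line still yields one match.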
def compress_image(image_bytes: bytes = None, image: Image.Image = None) -> bytes:
    """Compresses an image by re-encoding it in WebP format.

    Args:
        image_bytes (bytes): The image data to compress. Used only when `image` is not provided.
        image (Image.Image): An already loaded PIL image. Takes precedence over `image_bytes`.

    Returns:
        bytes: The compressed image data.
    """
    if not image:
        image = Image.open(BytesIO(image_bytes))
    output = BytesIO()
    image.save(output, format="WEBP")
    compressed_image_bytes = output.getvalue()
    return compressed_image_bytes
@@ -0,0 +1,37 @@
"""This file contains various formatting checks used to reprompt an agent for correctly formatted responses."""
from gui_agents.s3.utils.common_utils import (
    extract_agent_functions,
    parse_code_from_string,
    create_pyautogui_code,
    split_thinking_response,
)

single_action_check = lambda response: len(extract_agent_functions(parse_code_from_string(response))) == 1
single_action_error_msg = "Incorrect code: There must be a single agent action in the code response."
SINGLE_ACTION_FORMATTER = lambda response: (
    single_action_check(response), single_action_error_msg
)


def _attempt_code_creation(agent, code, obs):
    """Attempts to create a pyautogui code snippet from the response code."""
    try:
        return create_pyautogui_code(agent, code, obs)
    except Exception:
        return None


code_valid_check = lambda agent, obs, response: _attempt_code_creation(agent, parse_code_from_string(response), obs) is not None
code_valid_error_msg = "Incorrect code: The agent action must be a valid function and use valid parameters from the docstring list."
CODE_VALID_FORMATTER = lambda agent, obs, response: (
    code_valid_check(agent, obs, response), code_valid_error_msg
)

thoughts_answer_tag_check = lambda response: split_thinking_response(response)[1] != ""
thoughts_answer_tag_error_msg = "Incorrect response: The response must contain both <thoughts>...</thoughts> and <answer>...</answer> tags."
THOUGHTS_ANSWER_TAG_FORMATTER = lambda response: (
    thoughts_answer_tag_check(response), thoughts_answer_tag_error_msg
)

integer_answer_check = lambda response: split_thinking_response(response)[0].strip().isdigit()
integer_answer_error_msg = "Incorrect response: The <answer>...</answer> tag must contain a single integer."
INTEGER_ANSWER_FORMATTER = lambda response: (
    integer_answer_check(response), integer_answer_error_msg
)
Binary file not shown (221 KiB).
@@ -0,0 +1,54 @@
# Deploying Agent S3 in OSWorld

# Step 1: Set up Agent S3

Follow the [README.md](https://github.com/simular-ai/Agent-S/blob/main/README.md) to set up Agent S3.

# Step 2: Copying Over Run Files

If you haven't already, please follow the [OSWorld environment setup](https://github.com/xlang-ai/OSWorld/blob/main/README.md). We've provided the relevant OSWorld run files for evaluation in this `osworld_setup` folder; copy them over to your OSWorld folder. Use `run_local.py` to run locally on VMware, and `run.py` together with `lib_run_single.py` to run on AWS. All run commands are listed, in order, in `run.sh`. Copy over the files in `osworld_setup/s3/bbon` as well.
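
As a rough illustration of the copy step, the Python sketch below copies a list of files and folders from one directory into another and reports any names it could not find. The helper name `copy_run_files` and the paths in the usage note are assumptions made for illustration; they are not part of Agent S3 or OSWorld, so adjust them to your own checkout.

```python
import shutil
from pathlib import Path
from typing import Iterable, List


def copy_run_files(src_dir: str, dst_dir: str, names: Iterable[str]) -> List[str]:
    """Copy the named files or folders from src_dir into dst_dir.

    Hypothetical helper: returns the names that were not found in src_dir.
    """
    src, dst = Path(src_dir), Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    missing = []
    for name in names:
        source = src / name
        target = dst / name
        if source.is_dir():
            # Merge into an existing folder instead of failing (Python 3.8+).
            shutil.copytree(source, target, dirs_exist_ok=True)
        elif source.is_file():
            target.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(source, target)
        else:
            missing.append(name)
    return missing
```

One possible call, assuming the run files sit directly under `osworld_setup/s3`, would be `copy_run_files("osworld_setup/s3", "/path/to/OSWorld", ["run.py", "lib_run_single.py", "run_local.py", "run.sh", "bbon"])`; a non-empty return value lists the names that still need to be located by hand.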

# Step 3: Switch the AMI

Set the image AMI for the AWS provider in `desktop_env/providers/aws/manager.py` to `"ami-0b505e9d0d99ba88c"`.

# Step 4: Generating Facts

After your OSWorld runs have finished and produced result directories, run `generate_facts.py` to generate fact captions for screenshot pairs:

```bash
python osworld_setup/s3/bbon/generate_facts.py \
    --results-dirs \
        results1/pyautogui/screenshot/gpt-5-2025-08-07 \
        results2/pyautogui/screenshot/gpt-5-2025-08-07 \
    --model "gpt-5-2025-08-07" \
    --engine-type "openai" \
    --temperature 1.0
```

This will populate your result directories with `fact_captions.jsonl` files containing behavioral descriptions of screenshot differences.
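
To inspect what was generated, a minimal sketch along the following lines can read one task directory's `fact_captions.jsonl`. It relies only on the `fact_answer` field that `generate_facts.py` writes per line; the function name `read_fact_answers` is a hypothetical helper, and any other fields in the records are not assumed.

```python
import json
from pathlib import Path
from typing import List


def read_fact_answers(task_dir: str) -> List[str]:
    """Return the `fact_answer` strings stored in <task_dir>/fact_captions.jsonl."""
    path = Path(task_dir) / "fact_captions.jsonl"
    if not path.exists():
        return []
    answers = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            record = json.loads(line)
            if "fact_answer" in record:
                answers.append(record["fact_answer"])
    return answers
```

An empty list means either that the file is missing or that no line carried a `fact_answer`, which is a quick way to spot task directories where fact generation did not complete.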

# Step 5: Run the Judge

Finally, run `run_judge.py` to evaluate the trajectories using the generated fact captions:

```bash
python osworld_setup/s3/bbon/run_judge.py \
    --results-dirs \
        results1/pyautogui/screenshot/gpt-5-2025-08-07 \
        results2/pyautogui/screenshot/gpt-5-2025-08-07 \
    --output-dir "judge_results" \
    --examples-path "evaluation_examples/examples" \
    --model "gpt-5-2025-08-07" \
    --engine-type "openai" \
    --temperature 1.0
```

This will:
- Compare trajectories across different result directories
- Use the facts to judge which trajectory performs better
- Generate evaluation results
- Save results to the specified output directory

The judge will create files like `BoN2.json`, `BoN3.json`, etc., showing the performance comparison as you add more trajectories.
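
To compare rounds side by side, a small sketch like the one below can collect the `score` entry that `run_judge.py` stores in each `BoN<N>.json`. The helper name `summarize_bon_scores` is hypothetical; the keys it returns (`optimal`, `minimum`, `expected_value`, `res`, `actual score`) are whatever the judge script wrote, and files without a `score` entry are skipped.

```python
import json
import re
from pathlib import Path
from typing import Dict


def summarize_bon_scores(output_dir: str) -> Dict[int, dict]:
    """Map N -> the "score" dict stored in <output_dir>/BoN<N>.json."""
    summary = {}
    for path in Path(output_dir).glob("BoN*.json"):
        match = re.fullmatch(r"BoN(\d+)\.json", path.name)
        if not match:
            continue
        with path.open("r", encoding="utf-8") as f:
            data = json.load(f)
        score = data.get("score")
        if isinstance(score, dict):
            summary[int(match.group(1))] = score
    # Sort by the number of trajectories compared.
    return dict(sorted(summary.items()))
```

Printing `summary[n]["actual score"]` next to `summary[n]["optimal"]` for each `n` gives a rough view of how close the judge's selections come to the best available trajectory as more runs are added.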
@@ -0,0 +1,201 @@
import os
import json
import asyncio
import argparse
from typing import List, Optional
from dotenv import load_dotenv

from gui_agents.s3.bbon.behavior_narrator import BehaviorNarrator
from utils import get_new_tasks_classification

load_dotenv()


async def generate_single_fact_caption(task_dir: str, screenshot_files: List[str], i: int, judge: BehaviorNarrator, trajectory_lines: List[str]):
    """Generate a single fact caption for a screenshot pair."""
    before_file = os.path.join(task_dir, screenshot_files[i])
    after_file = os.path.join(task_dir, screenshot_files[i + 1])

    # Load action from trajectory data if available
    pyautogui_action = None
    if i < len(trajectory_lines):
        try:
            data = json.loads(trajectory_lines[i])
            pyautogui_action = data.get("exec_code")
        except (json.JSONDecodeError, AttributeError):
            # Malformed line or non-object JSON: handled by the check below
            pass

    if pyautogui_action is None:
        raise ValueError(f"No pyautogui action found for step {i+1}")

    # Read image bytes
    try:
        with open(before_file, "rb") as f:
            before_bytes = f.read()
        with open(after_file, "rb") as f:
            after_bytes = f.read()
    except Exception as e:
        raise Exception(f"Error reading images: {e}") from e

    # Generate fact caption using behavior narrator
    result = await asyncio.to_thread(judge.judge, before_bytes, after_bytes, pyautogui_action)
    result["screenshot_num"] = i + 1

    return result


async def generate_fact_captions_parallel(task_dir: str, judge: BehaviorNarrator, step_semaphore: Optional[asyncio.Semaphore] = None):
    """Generate fact captions for a task directory when they don't exist (parallelized version)."""
    print(f"Generating fact captions for {task_dir}...")

    # Find all screenshot files
    screenshot_files = []
    for filename in os.listdir(task_dir):
        if filename.startswith("step_") and filename.endswith(".png"):
            screenshot_files.append(filename)

    # Sort by step number
    def extract_step_num(filename):
        try:
            return int(filename.split("_")[1].split(".")[0])
        except (IndexError, ValueError):
            return 0

    screenshot_files.sort(key=extract_step_num)

    if len(screenshot_files) < 2:
        print(f"Not enough screenshots to generate fact captions in {task_dir}")
        return []

    # Load trajectory data once
    trajectory_lines = []
    trajectory_file = os.path.join(task_dir, "traj.jsonl")
    if os.path.exists(trajectory_file):
        try:
            with open(trajectory_file, "r") as f:
                trajectory_lines = f.readlines()
        except (OSError, UnicodeDecodeError):
            # Unreadable trajectory file: each step then fails with "No pyautogui action found"
            pass

    # Use shared semaphore to limit concurrent judge calls
    if step_semaphore is None:
        step_semaphore = asyncio.Semaphore(5)  # Default limit

    async def bounded_task(task_func, *args, **kwargs):
        async with step_semaphore:
            return await task_func(*args, **kwargs)

    try:
        # Create bounded tasks for parallel execution
        bounded_tasks = [
            bounded_task(generate_single_fact_caption, task_dir, screenshot_files, i, judge, trajectory_lines)
            for i in range(len(screenshot_files) - 1)
        ]
        results = await asyncio.gather(*bounded_tasks, return_exceptions=True)
    except Exception as e:
        print(f"Error in parallel execution: {e}")
        return []

    # Process results and save to file
    fact_captions = []
    successful_results = []
    fact_captions_file = os.path.join(task_dir, "fact_captions.jsonl")

    for i, result in enumerate(results):
        if isinstance(result, Exception):
            print(f"Error generating fact caption for step {i+1}: {result}")
            continue
        successful_results.append(result)
        fact_caption = f"Fact Caption from Screenshot {result['screenshot_num']}: {result['fact_answer']}"
        fact_captions.append(fact_caption)

    # Save all results to file at once
    if successful_results:
        with open(fact_captions_file, "w") as f:
            for result in successful_results:
                f.write(json.dumps(result) + "\n")

    print(f"Generated {len(fact_captions)} fact captions for {task_dir}")
    return fact_captions


async def main(engine_params: dict, results_dirs: List[str]):
    """Main function to generate fact captions for multiple task directories.

    Args:
        engine_params: Engine parameters for BehaviorNarrator
        results_dirs: List of results directories to analyze for task classification
    """
    # Get task IDs automatically using get_new_tasks_classification
    tasks_classification = get_new_tasks_classification(results_dirs)
    task_ids = tasks_classification['variance']

    print(f"Found {len(task_ids)} variance tasks to process")
    judge = BehaviorNarrator(engine_params=engine_params)

    # Get concurrency settings from environment
    per_step = int(os.getenv("DIFFCAP_PER_STEP_CONCURRENCY", "100"))
    per_taskdir = int(os.getenv("DIFFCAP_PER_TASKDIR_CONCURRENCY", "4"))

    # Build list of task directories to process
    task_dirs = []
    for task_id in task_ids:
        domain, example_id = task_id.split("/")

        # Check each results directory for this task
        for results_dir in results_dirs:
            task_dir = os.path.join(results_dir, domain, example_id)

            try:
                if "fact_captions.jsonl" in os.listdir(task_dir):
                    print(f"Fact captions already exist for {task_dir}")
                    continue
            except FileNotFoundError:
                continue

            task_dirs.append(task_dir)

    if not task_dirs:
        print("No new task directories to process.")
        return

    print(f"Scheduling {len(task_dirs)} task directories...")

    # Set up semaphores for concurrency control
    shared_step_semaphore = asyncio.Semaphore(per_step)
    taskdir_semaphore = asyncio.Semaphore(per_taskdir)

    async def run_one(task_dir):
        async with taskdir_semaphore:
            print(f"Processing {task_dir}")
            return await generate_fact_captions_parallel(task_dir, judge, step_semaphore=shared_step_semaphore)

    # Execute all tasks in parallel
    results = await asyncio.gather(*[run_one(d) for d in task_dirs], return_exceptions=True)

    # Report results
    failures = sum(1 for r in results if isinstance(r, Exception))
    if failures:
        print(f"Completed with {failures} failures out of {len(task_dirs)} task directories.")
    else:
        print("Completed all task directories successfully.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate fact captions for OSWorld task directories")
    parser.add_argument("--results-dirs", nargs="+", required=True, help="List of results directories to analyze for task classification")
    parser.add_argument("--model", default="gpt-5-2025-08-07", help="Model to use for generation")
    parser.add_argument("--engine-type", default="openai", help="Engine type")
    parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for generation")

    args = parser.parse_args()

    # Engine parameters
    engine_params = {
        "model": args.model,
        "engine_type": args.engine_type,
        "temperature": args.temperature
    }

    print(f"Results directories: {args.results_dirs}")
    asyncio.run(main(engine_params, args.results_dirs))
@@ -0,0 +1,212 @@
import json
import os
import asyncio
import argparse
import concurrent.futures
from typing import List, Tuple, Optional
from dotenv import load_dotenv
from tqdm.asyncio import tqdm_asyncio

load_dotenv()

from utils import (
    get_new_tasks_classification,
    evaluate_comparative_results,
    load_task_instruction,
    load_facts
)
from gui_agents.s3.bbon.comparative_judge import ComparativeJudge

def run_judge(task: str, task_instruction: str, result_dirs: List[str], judge: ComparativeJudge) -> Tuple[str, str, Optional[str]]:
    """
    Fact captions + initial/final screenshots judging.
    Pipeline: load trajectories → load existing fact captions → include initial/final screenshots → judge.
    """
    # 1. Use provided task instruction
    # task_instruction is now a direct input parameter

    # 2. Load fact captions for all trajectories
    all_fact_captions = []
    for result_dir in result_dirs:
        task_dir = os.path.join(result_dir, task.split("/")[0], task.split("/")[1])
        fact_captions = load_facts(task_dir)
        all_fact_captions.append(fact_captions)

    # 3. Use the new Judge class method
    return judge.judge(task_instruction, task, result_dirs, all_fact_captions)

def evaluate_trajectories(task: str, task_instruction: str, result_dirs: List[str], judge: ComparativeJudge) -> Tuple[str, str, dict]:
    """Wrapper that runs fact-only MCQ judge and returns results."""
    answer, thoughts, selected_trajectory = run_judge(task, task_instruction, result_dirs, judge)

    record = {
        "selected_trajectory": selected_trajectory,
        "answer": answer,
        "thoughts": thoughts,
    }

    print(f"✅ Added task {task} (MCQ fact-only)")
    return answer, thoughts, record

asyncio.get_event_loop().set_default_executor(
    concurrent.futures.ThreadPoolExecutor(max_workers=100)
)

async def run_async(task: str, task_instruction: str, result_dirs: List[str], judge: ComparativeJudge):
    """Async wrapper for fact-only MCQ evaluation."""
    return await asyncio.to_thread(
        evaluate_trajectories, task=task, task_instruction=task_instruction, result_dirs=result_dirs, judge=judge
    )


async def evaluate_and_save(result_dirs: List[str], output_file_path: str, examples_path: str, engine_params: dict):
    """Main evaluation function that processes tasks and saves results."""
    res = get_new_tasks_classification(results_dirs=result_dirs)
    for key in res:
        print(f"{key}: {res[key]}")
    optimal, minimum, expected_value = res["optimal"], res["minimum"], res["expected_value"]
    print(f"optimal score: {optimal}, minimum score: {minimum}")

    variance = res["variance"]

    judge = ComparativeJudge(engine_params=engine_params)

    # Load existing results
    if os.path.exists(output_file_path):
        with open(output_file_path, "r", encoding="utf-8") as f:
            try:
                data = json.load(f)
                if not isinstance(data, dict):
                    data = {}
            except json.JSONDecodeError:
                data = {}
    else:
        data = {}

    # Prepare async tasks only for tasks not yet in data
    tasks = []
    task_names = []
    for task in variance:
        if str(task) in data:
            print(f"⚠️ Task {task} already exists in results, skipping.")
            continue

        # Load task instruction from examples path
        task_instruction = load_task_instruction(task, examples_path)
        if task_instruction is None:
            print(f"⚠️ No task instruction found for {task}, skipping...")
            continue

        tasks.append(run_async(task, task_instruction, result_dirs, judge))
        task_names.append(task)

    # Run only new tasks
    results = await tqdm_asyncio.gather(*tasks)
    # Merge into existing results
    for task, (ans, thoughts, record) in zip(task_names, results):
        data[str(task)] = record

    # Write the per-task records first so evaluate_comparative_results can read them back
    os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
    with open(output_file_path, "w") as f:
        json.dump(data, f, indent=2)

    res = evaluate_comparative_results(result_dirs, json_path=output_file_path)
    gain, maximum_gain = res
    data["score"] = {
        "optimal": optimal,
        "minimum": minimum,
        "expected_value": expected_value,
        "res": res,
        "actual score": minimum + gain
    }
    with open(output_file_path, "w") as f:
        json.dump(data, f, indent=2)

    return results

async def run_experiment(shuffled_runs: List[str], output_dir: str, examples_path: str, engine_params: dict, start_round: int = 2, max_rounds: Optional[int] = None):
    """
    Run fact-only experiments progressively: the first start_round dirs, then start_round+1 dirs, etc.
    """
    if max_rounds is None:
        max_rounds = len(shuffled_runs)

    os.makedirs(output_dir, exist_ok=True)

    for i in range(start_round, max_rounds + 1):  # start at start_round (default 2)
        test_dirs = shuffled_runs[:i]
        output_file_path = os.path.join(output_dir, f"BoN{i}.json")

        print(f"Running fact-only experiment with {i} dirs → {output_file_path}")
        await evaluate_and_save(test_dirs, output_file_path, examples_path, engine_params)


async def main(shuffled_runs: Optional[List[str]] = None, output_dir: Optional[str] = None, examples_path: Optional[str] = None, engine_params: Optional[dict] = None, start_round: int = 2, max_rounds: Optional[int] = None):
    """Main function to run fact-only judge experiments.

    Args:
        shuffled_runs: List of result directory paths to compare
        output_dir: Directory to save results
        examples_path: Path to examples directory containing task instructions
        engine_params: Engine parameters for the judge
        start_round: Starting round number (default: 2)
        max_rounds: Maximum number of rounds to run (default: len(shuffled_runs))
    """
    if shuffled_runs is None:
        print("Error: shuffled_runs must be provided")
        return

    if output_dir is None:
        print("Error: output_dir must be provided")
        return

    if examples_path is None:
        print("Error: examples_path must be provided")
        return

    if engine_params is None:
        print("Error: engine_params must be provided")
        return

    await run_experiment(shuffled_runs, output_dir, examples_path, engine_params, start_round, max_rounds)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run fact-only judge experiments on OSWorld task directories")
    parser.add_argument("--results-dirs", nargs="+", required=True, help="List of results directories to analyze")
    parser.add_argument("--output-dir", required=True, help="Directory to save results")
    parser.add_argument("--examples-path", required=True, help="Path to examples directory containing task instructions")
    parser.add_argument("--start-round", type=int, default=2, help="Starting round number (default: 2)")
    parser.add_argument("--max-rounds", type=int, default=None, help="Maximum number of rounds to run (default: len(results_dirs))")
    parser.add_argument("--model", default="gpt-5-2025-08-07", help="Model to use for judging")
    parser.add_argument("--engine-type", default="openai", help="Engine type")
    parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for generation")

    args = parser.parse_args()

    # Engine parameters
    engine_params = {
        "model": args.model,
        "engine_type": args.engine_type,
        "temperature": args.temperature
    }

    print(f"Results directories: {args.results_dirs}")
    print(f"Output directory: {args.output_dir}")
    print(f"Examples path: {args.examples_path}")
    print(f"Start round: {args.start_round}")
    print(f"Max rounds: {args.max_rounds}")
    print(f"Engine params: {engine_params}")

    # Run fact-only evaluation
    asyncio.run(
        main(
            shuffled_runs=args.results_dirs,
            output_dir=args.output_dir,
            examples_path=args.examples_path,
            engine_params=engine_params,
            start_round=args.start_round,
            max_rounds=args.max_rounds,
        )
    )
@@ -0,0 +1,283 @@
import logging
import os
import re
import json
from PIL import Image
from typing import Optional, List
import base64

def image_to_openai_message_format(image_path: str, caption: str = None) -> Optional[dict]:
    """Convert an image file to OpenAI message format."""
    if not os.path.exists(image_path):
        print(f"Image file not found: {image_path}")
        return None

    try:
        with open(image_path, "rb") as f:
            image_bytes = f.read()

        if not image_bytes:
            print(f"Empty image file: {image_path}")
            return None

        base64_image = base64.b64encode(image_bytes).decode("utf-8")

        if not base64_image:
            print(f"Failed to encode image to base64: {image_path}")
            return None

        content = []
        if caption:
            content.append({"type": "text", "text": caption})

        content.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{base64_image}"}
        })

        return {"role": "user", "content": content}

    except Exception as e:
        print(f"Error processing image {image_path}: {e}")
        return None


def load_facts(task_dir: str) -> List[str]:
    """Load existing facts from the fact_captions.jsonl file in task_dir."""
    fact_captions_file = os.path.join(task_dir, "fact_captions.jsonl")

    if not os.path.exists(fact_captions_file):
        print(f"fact_captions.jsonl not found at {fact_captions_file}")
        return []

    fact_captions = []
    with open(fact_captions_file, "r") as f:
        for line in f:
            if line.strip():
                data = json.loads(line)
                if "fact_answer" in data:
                    fact_captions.append(data["fact_answer"])

    return fact_captions

def load_task_instruction(task: str, examples_path: str) -> Optional[str]:
    """
    Load task instruction from examples path.

    Args:
        task: Task ID in format "domain/example_id"
        examples_path: Path to the examples directory (e.g., "/home/ubuntu/Simular/OSWorld/evaluation_examples/examples")

    Returns:
        Task instruction string or None if not found
    """
    domain, example_id = task.split("/", 1)

    # Construct path to the JSON file
    json_file_path = os.path.join(examples_path, domain, f"{example_id}.json")

    if not os.path.exists(json_file_path):
        logging.warning(f"Example file not found: {json_file_path}")
        return None

    try:
        with open(json_file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Extract instruction from the JSON
        if "instruction" in data:
            instruction = data["instruction"]
            if instruction and instruction.strip():
                return instruction.strip()

        logging.warning(f"No non-empty 'instruction' found in {json_file_path}")
        return None

    except Exception as e:
        logging.warning(f"Error reading example file {json_file_path}: {e}")
        return None


def get_final_screenshot_file(result_dir: str) -> Optional[str]:
    """
    Finds the screenshot file with the largest valid step index in the given directory.
    Works with filenames like step_0.png, step_1_20250.png, step-2.png, etc.
    Only considers .png files (case-insensitive).
    If the highest index file is invalid/corrupted, it tries the next lower index.
    Returns None if no valid matching files are found.
    """
    # First, collect all valid step files with their indices
    step_files = {}
    pattern = re.compile(r"step[_\-]?(\d+)", re.IGNORECASE)

    for fname in os.listdir(result_dir):
        if not fname.lower().endswith(".png"):
            continue
        match = pattern.match(fname)
        if match:
            idx = int(match.group(1))
            step_files[idx] = fname
    if not step_files:
        return None
    # Sort indices in descending order (highest first)
    sorted_indices = sorted(step_files.keys(), reverse=True)
    # Try each file from highest to lowest index
    for idx in sorted_indices:
        fname = step_files[idx]
        file_path = os.path.join(result_dir, fname)
        # Check if file exists and is valid
        if os.path.exists(file_path) and is_valid_image(file_path):
            return fname
        else:
            print(f"Invalid or corrupted image at step {idx}: {fname}, trying previous step...")
    return None

def is_valid_image(file_path: str) -> bool:
    """
    Check if an image file is valid by trying to open it with PIL.
    Also checks if file is not empty.
    """
    try:
        # Check file size first (quick check)
        if os.path.getsize(file_path) == 0:
            return False

        # Try to open and verify the image
        with Image.open(file_path) as img:
            img.verify()  # This will raise an exception if image is corrupted
        return True
    except Exception as e:
        print(f"Image validation failed for {file_path}: {e}")
        return False


def get_new_tasks_classification(results_dirs: List[str]):
    # Step 1: collect domain/task_ids for each trajectory
    tasks_per_dir = []
    for results_dir in results_dirs:
        domain_tasks = set()
        for domain in os.listdir(results_dir):
            domain_dir = os.path.join(results_dir, domain)
            if not os.path.isdir(domain_dir):
                continue
            for task_id in os.listdir(domain_dir):
                task_dir = os.path.join(domain_dir, task_id)
                if os.path.isdir(task_dir):
                    domain_tasks.add(f"{domain}/{task_id}")
        tasks_per_dir.append(domain_tasks)

    # Step 2: find tasks common to all trajectories
    common_tasks = set.intersection(*tasks_per_dir)

    constant_tasks = []
    variance_tasks = []
    constant_tasks_scores = []
    optimal_sum = 0.0
    expected_value = 0.0

    # Step 3: evaluate each common task
    for domain_task in sorted(common_tasks):
        domain, task_id = domain_task.split("/", 1)
        results = []
        for results_dir in results_dirs:
            task_dir = os.path.join(results_dir, domain, task_id)
            result_file = os.path.join(task_dir, "result.txt")
            if os.path.isfile(result_file):
                with open(result_file, "r") as f:
                    try:
                        val = float(f.read().strip())
                        results.append(val)
                    except ValueError:
                        continue

        if not results:  # skip if no valid results
            logging.warning(f"No valid results for {domain_task}")
            continue

        # classification
        if all(r == results[0] for r in results):
            constant_tasks.append(domain_task)
            constant_tasks_scores.append(results[0])
        else:
            variance_tasks.append(domain_task)

        # accumulate optimal and expected value
        # minimum_sum += min(results)  # Disabled: this incorrectly counted the minimum of variance tasks as well
        optimal_sum += max(results)
        expected_value += sum(results) / len(results)

    return {
        "constant": constant_tasks,  # Constant tasks are not evaluated by the judge
        "variance": variance_tasks,  # Variance tasks are evaluated by the judge
        "minimum": sum(constant_tasks_scores),  # Sum of constant task scores (easy + hard)
        "optimal": optimal_sum,  # Score obtained if the best trajectory is picked for every task
        "expected_value": expected_value,  # Sum over tasks of the mean score across trajectories
    }

def check_selected_trajectory(results_dirs: List[str], selected_trajectory: str, task: str):
    """
    results_dirs: list of directories in format results_dir/<domain>/<task_id>
    selected_trajectory: the path of the selected trajectory
    task: string in format "<domain>/<task_id>"

    Returns (selected_val, optimal_val)
    """
    domain, task_id = task.split("/")
    all_results = []

    if not any(
        os.path.commonpath([os.path.abspath(selected_trajectory), os.path.abspath(rd)]) == os.path.abspath(rd)
        for rd in results_dirs
    ):
        return None, None

    for rd in results_dirs:
        result_file = os.path.join(rd, domain, task_id, "result.txt")
        if os.path.isfile(result_file):
            try:
                all_results.append(float(open(result_file).read().strip()))
            except ValueError:
                pass

    selected_file = os.path.join(selected_trajectory, domain, task_id, "result.txt")
    if not os.path.isfile(selected_file):
        return None, max(all_results) if all_results else None

    try:
        selected_val = float(open(selected_file).read().strip())
    except ValueError:
        return None, max(all_results) if all_results else None

    optimal_val = max(all_results) if all_results else selected_val
    return selected_val, optimal_val

def evaluate_comparative_results(results_dirs: list[str], json_path: str = None):
    """
    Opens comparative_judge_results.json (default) or a given path,
    evaluates each task, and returns the accumulated scores.

    Args:
        results_dirs: list of result directories
        json_path: optional path to comparative_judge_results.json

    Returns:
        (judge_score, optimal_score): sums of selected_val and optimal_val
        over all tasks for which both values are available
    """
    judge_score = 0
    optimal_score = 0
    if json_path is None:
        json_path = "comparative_judge_results.json"

    with open(json_path, "r") as f:
        data = json.load(f)

    for task, info in data.items():
        selected_trajectory = info.get("selected_trajectory")
        if selected_trajectory:
            selected_val, optimal_val = check_selected_trajectory(results_dirs, selected_trajectory, task)
            if selected_val is not None and optimal_val is not None:
                print(f"task: {task}, selected_val: {selected_val}, optimal_val: {optimal_val}")
                judge_score += selected_val
                optimal_score += optimal_val
    return judge_score, optimal_score
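The comparison between the judge's pick and the best available rollout can be tried on toy data. The following is a minimal sketch, not the repository's implementation: `read_score` and `selected_and_optimal` are hypothetical helpers, the directory names and the `chrome/task-001` task are made up, and the selected directory is assumed to be exactly one of the rollout directories (the original also accepts paths nested inside one).

```python
import os
import tempfile


def read_score(path):
    """Return the float stored in a result.txt file, or None if unreadable."""
    try:
        with open(path) as f:
            return float(f.read().strip())
    except (OSError, ValueError):
        return None


def selected_and_optimal(rollout_dirs, selected_dir, task):
    """Score of the selected rollout and the best score across all rollouts."""
    domain, task_id = task.split("/")
    scores = {
        rd: read_score(os.path.join(rd, domain, task_id, "result.txt"))
        for rd in rollout_dirs
    }
    available = [s for s in scores.values() if s is not None]
    return scores.get(selected_dir), (max(available) if available else None)


# Toy data: two hypothetical rollouts of the same task with different outcomes.
root = tempfile.mkdtemp()
rollouts = []
for name, score in [("results1", 0.0), ("results2", 1.0)]:
    task_dir = os.path.join(root, name, "chrome", "task-001")
    os.makedirs(task_dir)
    with open(os.path.join(task_dir, "result.txt"), "w") as f:
        f.write(f"{score}\n")
    rollouts.append(os.path.join(root, name))

# The "judge" picked the first rollout, which failed; the second one succeeded.
selected, optimal = selected_and_optimal(rollouts, rollouts[0], "chrome/task-001")
print(selected, optimal)
```

Summing `selected` and `optimal` over all tasks gives the two totals returned by `evaluate_comparative_results`; the gap between them is what a perfect judge would recover.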
@@ -0,0 +1,90 @@
import datetime
import json
import logging
import os
import time

logger = logging.getLogger("desktopenv.experiment")


def run_single_example(
    agent, env, example, max_steps, instruction, args, example_result_dir, scores
):
    runtime_logger = setup_logger(example, example_result_dir)
    try:
        agent.reset(runtime_logger)
    except Exception:
        # Retry without the logger if reset(runtime_logger) fails
        agent.reset()

    env.reset(task_config=example)
    time.sleep(60)  # Wait for the environment to be ready
    obs = env._get_obs()  # Get the initial observation

    with open(os.path.join(example_result_dir, "step_0.png"), "wb") as _f:
        _f.write(obs["screenshot"])

    with open(os.path.join(example_result_dir, "instruction.txt"), "w", encoding="utf-8") as f:
        f.write(instruction)

    done = False
    step_idx = 0
    # env.controller.start_recording()
    while not done and step_idx < max_steps:
        response, actions = agent.predict(
            instruction,
            obs
        )
        for action in actions:
            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
            logger.info("Step %d: %s", step_idx + 1, action)
            obs, reward, done, info = env.step(action, args.sleep_after_execution)

            logger.info("Reward: %.2f", reward)
            logger.info("Done: %s", done)
            # Save screenshot and trajectory information
            with open(
                os.path.join(
                    example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"
                ),
                "wb",
            ) as _f:
                _f.write(obs["screenshot"])

            response.update(
                {
                    "step_num": step_idx + 1,
                    "action_timestamp": action_timestamp,
                    "action": action,
                    "reward": reward,
                    "done": done,
                    "info": info,
                    "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png",
                }
            )
            # Append one JSON object per executed action
            with open(os.path.join(example_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
                f.write(json.dumps(response, ensure_ascii=False))
                f.write("\n")
            if done:
                logger.info("The episode is done.")
                break
        step_idx += 1
    result = env.evaluate()
    logger.info("Result: %.2f", result)
    scores.append(result)
    with open(
        os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8"
    ) as f:
        f.write(f"{result}\n")
    # env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))


def setup_logger(example, example_result_dir):
    runtime_logger = logging.getLogger(f"desktopenv.example.{example['id']}")
    runtime_logger.setLevel(logging.DEBUG)
    runtime_logger.addHandler(
        logging.FileHandler(os.path.join(example_result_dir, "runtime.log"))
    )
    return runtime_logger
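`run_single_example` appends one JSON object per executed action to `traj.jsonl`. A minimal sketch of writing and reading back that format follows; `append_step` and `load_trajectory` are hypothetical helpers and the two records are made-up examples that only use a subset of the fields written above.

```python
import json
import os
import tempfile


def append_step(traj_path, record):
    """Append one step record as a single JSON line."""
    with open(traj_path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False))
        f.write("\n")


def load_trajectory(traj_path):
    """Read a JSON Lines trajectory file into a list of dicts."""
    steps = []
    with open(traj_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:  # skip blank lines
                steps.append(json.loads(line))
    return steps


traj_path = os.path.join(tempfile.mkdtemp(), "traj.jsonl")
append_step(traj_path, {"step_num": 1, "action": "click(10, 20)", "reward": 0, "done": False})
append_step(traj_path, {"step_num": 2, "action": "DONE", "reward": 0, "done": True})

steps = load_trajectory(traj_path)
print(len(steps), steps[-1]["done"])
```

Because the file is opened in append mode for every step, a crash mid-episode still leaves all earlier steps on disk, which is presumably why the runner also appends its `{"Error": ...}` record to the same file.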
@@ -0,0 +1,518 @@
"""OSWorld's run.py with AgentS3.

Script to run end-to-end evaluation on the benchmark.
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
"""

import argparse
import datetime
import json
import logging
import os
import sys
import signal
import time
from multiprocessing import Process, Manager, current_process, Queue

import lib_run_single
from desktop_env.desktop_env import DesktopEnv

from dotenv import load_dotenv

load_dotenv()

# Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

stdout_handler = logging.StreamHandler(sys.stdout)

stdout_handler.setLevel(logging.INFO)

formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)

stdout_handler.setFormatter(formatter)

stdout_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(stdout_handler)
# }}} Logger Configs #

logger = logging.getLogger("desktopenv.experiment")


# Global variables for signal handling
active_environments = []
processes = []
is_terminating = False

def distribute_tasks(test_all_meta: dict) -> list:
    all_tasks = []
    for domain, examples in test_all_meta.items():
        for example_id in examples:
            all_tasks.append((domain, example_id))
    return all_tasks


def process_signal_handler(signum, frame, env_idx):
    logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
    local_vars = frame.f_locals
    active_environments = local_vars.get('active_environments', [])
    for env in active_environments:
        if env is not None:
            try:
                logger.info(f"Process {env_idx + 1} closing environment...")
                env.close()
                logger.info(f"Process {env_idx + 1} environment closed successfully")
            except Exception as e:
                logger.error(f"Process {env_idx + 1} error closing environment: {e}")
    logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
    sys.exit(0)

def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list, engine_params, engine_params_for_grounding):
    active_environments = []
    env = None
    try:
        # Use IMAGE_ID_MAP for AWS provider to get snapshot_name
        snapshot_name = None
        region = getattr(args, 'region', None)
        if args.provider_name == 'aws' and region is not None:
            try:
                from desktop_env.providers.aws.manager import IMAGE_ID_MAP
                screen_size = (args.screen_width, args.screen_height)
                snapshot_name = IMAGE_ID_MAP[region].get(screen_size, IMAGE_ID_MAP[region][(1920, 1080)])
            except Exception as e:
                logger.error(f"Failed to get snapshot_name from IMAGE_ID_MAP: {e}")
                snapshot_name = None
        from gui_agents.s3.agents.agent_s import AgentS3
        from gui_agents.s3.agents.grounding import OSWorldACI
        env = DesktopEnv(
            path_to_vm=args.path_to_vm,
            action_space=args.action_space,
            provider_name=args.provider_name,
            region=region,
            snapshot_name=snapshot_name,
            screen_size=(args.screen_width, args.screen_height),
            headless=args.headless,
            os_type="Ubuntu",
            require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
            enable_proxy=True,
            client_password=getattr(args, 'client_password', ''),
        )
        grounding_agent = OSWorldACI(
            env=env,
            platform="linux",
            engine_params_for_generation=engine_params,
            engine_params_for_grounding=engine_params_for_grounding,
            width=args.screen_width,
            height=args.screen_height,
        )
        agent = AgentS3(
            engine_params,
            grounding_agent,
            platform="linux",
        )

        active_environments.append(env)
        logger.info(f"Process {current_process().name} started.")
        while True:
            # Stop this worker once the queue stays empty for 5 seconds
            try:
                item = task_queue.get(timeout=5)
            except Exception:
                break
            domain, example_id = item
            try:
                config_file = os.path.join(
                    args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
                )
                with open(config_file, "r", encoding="utf-8") as f:
                    example = json.load(f)
                instruction = example["instruction"]
                example_result_dir = os.path.join(
                    args.result_dir,
                    args.action_space,
                    args.observation_type,
                    args.model,
                    domain,
                    example_id,
                )
                os.makedirs(example_result_dir, exist_ok=True)
                logger.info(f"[{current_process().name}][Domain]: {domain}")
                logger.info(f"[{current_process().name}][Example ID]: {example_id}")
                logger.info(f"[{current_process().name}][Instruction]: {instruction}")
                try:
                    lib_run_single.run_single_example(
                        agent,
                        env,
                        example,
                        args.max_steps,
                        instruction,
                        args,
                        example_result_dir,
                        shared_scores,
                    )
                except Exception as e:
                    import traceback
                    logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
                    logger.error(traceback.format_exc())
                    try:
                        env.controller.end_recording(
                            os.path.join(example_result_dir, "recording.mp4")
                        )
                    except Exception as rec_e:
                        logger.error(f"Failed to end recording: {rec_e}")
                    # Record the failure in the trajectory file of this example
                    with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                        f.write(
                            json.dumps(
                                {"Error": f"{domain}/{example_id} - {e}"}
                            )
                        )
                        f.write("\n")
            except Exception as e:
                logger.error(f"Task-level error in {current_process().name}: {e}")
                import traceback
                logger.error(traceback.format_exc())
    except Exception as e:
        logger.error(f"Process-level error in {current_process().name}: {e}")
        import traceback
        logger.error(traceback.format_exc())
    finally:
        logger.info(f"{current_process().name} cleaning up environment...")
        try:
            if env:
                env.close()
                logger.info(f"{current_process().name} environment closed successfully")
        except Exception as e:
            logger.error(f"{current_process().name} error during environment cleanup: {e}")

def signal_handler(signum, frame):
    global is_terminating, active_environments, processes
    if is_terminating:
        return
    is_terminating = True
    logger.info(f"Received signal {signum}. Gracefully shutting down...")
    for env in active_environments:
        try:
            logger.info("Closing environment...")
            env.close()
            logger.info("Environment closed successfully")
        except Exception as e:
            logger.error(f"Error closing environment: {e}")
    # First ask every worker to terminate
    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Sending termination signal to process {p.name}...")
                p.terminate()
            except Exception as e:
                logger.error(f"Error sending termination signal to process: {e}")
    time.sleep(1)
    # Then kill any worker that is still alive after the grace period
    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Forcefully terminating process {p.name}...")
                os.kill(p.pid, signal.SIGKILL)
            except Exception as e:
                logger.error(f"Error forcefully terminating process: {e}")
    logger.info("Shutdown complete. Exiting.")
    sys.exit(0)

def config() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )

    # environment config
    parser.add_argument("--path_to_vm", type=str, default=None)
    parser.add_argument(
        "--provider_name", type=str, default="vmware",
        help="Virtualization provider (vmware, docker, aws, azure, gcp, virtualbox)"
    )
    parser.add_argument(
        "--headless", action="store_true", help="Run in headless machine"
    )
    parser.add_argument(
        "--action_space", type=str, default="pyautogui", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
        default="screenshot",
        help="Observation type",
    )
    parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
    parser.add_argument("--screen_width", type=int, default=1920)
    parser.add_argument("--screen_height", type=int, default=1080)
    parser.add_argument("--sleep_after_execution", type=float, default=1.0)
    parser.add_argument("--max_steps", type=int, default=15)

    parser.add_argument("--domain", type=str, default="all")
    parser.add_argument(
        "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
    )
    parser.add_argument(
        "--test_config_base_dir", type=str, default="evaluation_examples"
    )
    parser.add_argument("--result_dir", type=str, default="./results")

    parser.add_argument(
        "--region", type=str, default="us-east-1", help="AWS region for the VM"
    )
    parser.add_argument(
        "--client_password", type=str, default="", help="Client password"
    )

    # agent config
    parser.add_argument("--max_trajectory_length", type=int, default=8)


    # lm config
    parser.add_argument("--model_provider", type=str, default="openai")
    parser.add_argument("--model", type=str, default="gpt-4o")
    parser.add_argument(
        "--model_url",
        type=str,
        default="",
        help="The URL of the main generation model API.",
    )
    parser.add_argument(
        "--model_api_key",
        type=str,
        default="",
        help="The API key of the main generation model.",
    )
    parser.add_argument("--model_temperature", type=float, default=None, help="Temperature to fix the generation model at (e.g. o3 can only be run with 1.0)")

    # grounding model config
    parser.add_argument("--ground_provider", type=str, required=True, help="The provider for the grounding model")
    parser.add_argument("--ground_url", type=str, required=True, help="The URL of the grounding model")
    parser.add_argument(
        "--ground_api_key",
        type=str,
        default="",
        help="The API key of the grounding model.",
    )
    parser.add_argument(
        "--ground_model", type=str, required=True, help="The model name for the grounding model"
    )
    parser.add_argument(
        "--grounding_width",
        type=int,
        required=True,
        help="Width of screenshot image after processor rescaling",
    )
    parser.add_argument(
        "--grounding_height",
        type=int,
        required=True,
        help="Height of screenshot image after processor rescaling",
    )

    args = parser.parse_args()

    return args


def test(args: argparse.Namespace, test_all_meta: dict) -> None:
    global processes
    logger.info("Args: %s", args)
    all_tasks = distribute_tasks(test_all_meta)
    logger.info(f"Total tasks: {len(all_tasks)}")

    engine_params = {
        "engine_type": args.model_provider,
        "model": args.model,
        "base_url": getattr(args, 'model_url', ''),
        "api_key": getattr(args, 'model_api_key', ''),
        "temperature": getattr(args, 'model_temperature', None),
    }
    engine_params_for_grounding = {
        "engine_type": args.ground_provider,
        "model": args.ground_model,
        "base_url": getattr(args, 'ground_url', ''),
        "api_key": getattr(args, 'ground_api_key', ''),
        "grounding_width": args.grounding_width,
        "grounding_height": args.grounding_height,
    }

    with Manager() as manager:
        shared_scores = manager.list()
        task_queue = manager.Queue()
        for item in all_tasks:
            task_queue.put(item)
        num_envs = args.num_envs
        processes = []
        for i in range(num_envs):
            p = Process(
                target=run_env_tasks,
                args=(task_queue, args, shared_scores, engine_params, engine_params_for_grounding),
                name=f"EnvProcess-{i+1}"
            )
            p.daemon = True
            p.start()
            processes.append(p)
            logger.info(f"Started process {p.name} with PID {p.pid}")
        try:
            # Monitor the workers: restart dead ones until the queue is drained
            while True:
                alive_count = 0
                for idx, p in enumerate(processes):
                    if not p.is_alive():
                        logger.warning(f"Process {p.name} died, restarting...")
                        new_p = Process(
                            target=run_env_tasks,
                            args=(task_queue, args, shared_scores, engine_params, engine_params_for_grounding),
                            name=f"EnvProcess-Restart-{idx+1}"
                        )
                        new_p.daemon = True
                        new_p.start()
                        processes[idx] = new_p
                        logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
                    else:
                        alive_count += 1
                if task_queue.empty():
                    logger.info("All tasks finished.")
                    break
                if alive_count == 0:
                    logger.error("All processes died, exiting.")
                    break
                time.sleep(5)
            for p in processes:
                p.join()
        except KeyboardInterrupt:
            logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
            raise
        except Exception as e:
            logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
            for p in processes:
                if p.is_alive():
                    try:
                        logger.info(f"Terminating process {p.name} due to error...")
                        p.terminate()
                    except Exception as term_e:
                        logger.error(f"Error terminating process {p.name}: {term_e}")
            raise
        scores = list(shared_scores)
    logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")


def get_unfinished(
    action_space, use_model, observation_type, result_dir, total_file_json
):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)

    if not os.path.exists(target_dir):
        return total_file_json

    finished = {}
    for domain in os.listdir(target_dir):
        finished[domain] = []
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                if example_id == "onboard":
                    continue
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" not in os.listdir(example_path):
                        # Unfinished example: delete its partial files so it is rerun from scratch
                        for file in os.listdir(example_path):
                            os.remove(os.path.join(example_path, file))
                    else:
                        finished[domain].append(example_id)

    if not finished:
        return total_file_json

    # Drop every finished example from the list of tasks still to run
    for domain, examples in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [
                x for x in total_file_json[domain] if x not in examples
            ]

    return total_file_json


def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(target_dir):
        print("New experiment, no result yet.")
        return None

    all_result = []

    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" in os.listdir(example_path):
                        # Read the stored score; an unreadable result counts as 0.0
                        try:
                            with open(os.path.join(example_path, "result.txt"), "r") as f:
                                all_result.append(float(f.read()))
                        except Exception:
                            all_result.append(0.0)

    if not all_result:
        print("New experiment, no result yet.")
        return None
    else:
        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
        return all_result


if __name__ == "__main__":
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)
    ####### The complete version of the list of examples #######
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = config()

    # save args to json in result_dir/action_space/observation_type/model/args.json
    path_to_args = os.path.join(
        args.result_dir,
        args.action_space,
        args.observation_type,
        args.model,
        "args.json",
    )
    os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
    with open(path_to_args, "w", encoding="utf-8") as f:
        json.dump(vars(args), f, indent=4)

    with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
        test_all_meta = json.load(f)

    if args.domain != "all":
        test_all_meta = {args.domain: test_all_meta[args.domain]}

    test_file_list = get_unfinished(
        args.action_space,
        args.model,
        args.observation_type,
        args.result_dir,
        test_all_meta,
    )
    left_info = ""
    for domain in test_file_list:
        left_info += f"{domain}: {len(test_file_list[domain])}\n"
    logger.info(f"Left tasks:\n{left_info}")

    get_result(
        args.action_space,
        args.model,
        args.observation_type,
        args.result_dir,
        test_all_meta,
    )
    test(args, test_file_list)
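The resume behaviour of `run.py` (skip every example that already has a `result.txt`) can be sketched as two pure functions. This is an illustration under assumptions, not the script's code: `finished_examples` and `remaining_tasks` are hypothetical names, the domains and example ids are made up, and unlike `get_unfinished` this sketch neither deletes partial output nor mutates its input.

```python
import os
import tempfile


def finished_examples(target_dir):
    """Map domain -> example ids that already contain a result.txt."""
    finished = {}
    for domain in sorted(os.listdir(target_dir)):
        domain_path = os.path.join(target_dir, domain)
        if not os.path.isdir(domain_path):
            continue
        finished[domain] = [
            ex for ex in sorted(os.listdir(domain_path))
            if os.path.isfile(os.path.join(domain_path, ex, "result.txt"))
        ]
    return finished


def remaining_tasks(all_meta, finished):
    """Return a new dict with the finished examples filtered out."""
    return {
        domain: [ex for ex in examples if ex not in finished.get(domain, [])]
        for domain, examples in all_meta.items()
    }


# Toy layout: example "a" is finished, example "b" was interrupted.
target = tempfile.mkdtemp()
os.makedirs(os.path.join(target, "chrome", "a"))
os.makedirs(os.path.join(target, "chrome", "b"))
with open(os.path.join(target, "chrome", "a", "result.txt"), "w") as f:
    f.write("1.0\n")

all_meta = {"chrome": ["a", "b"], "gimp": ["c"]}
todo = remaining_tasks(all_meta, finished_examples(target))
print(todo)
```

Keeping the filter free of side effects makes it easy to preview what a rerun would execute before any partial results are removed.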
@@ -0,0 +1,54 @@
# Step 1: Complete 2 or more rollouts on either AWS or locally.
# Use a different --result_dir for each rollout (e.g. "results1", "results2") so that Steps 2 and 3 can compare them.
# --ground_model is required by both scripts; replace the placeholder with the name of your grounding model.

# Option A: AWS
python run.py \
    --provider_name "aws" \
    --headless \
    --num_envs 10 \
    --max_steps 100 \
    --domain "all" \
    --test_all_meta_path evaluation_examples/test_nogdrive.json \
    --result_dir "results" \
    --region "us-east-1" \
    --model_provider "openai" \
    --model "gpt-5-2025-08-07" \
    --model_temperature 1.0 \
    --ground_provider "huggingface" \
    --ground_url "<YOUR_HUGGINGFACE_ENDPOINT_URL>/v1" \
    --ground_model "<YOUR_GROUNDING_MODEL>" \
    --grounding_width 1920 \
    --grounding_height 1080 \
    --sleep_after_execution 3

# Option B: local VMware VM
python run_local.py \
    --path_to_vm "/Users/user/OSWorld/vmware_vm_data/Ubuntu0/Ubuntu0.vmx" \
    --provider_name "vmware" \
    --headless \
    --max_steps 100 \
    --domain "all" \
    --test_all_meta_path evaluation_examples/test_nogdrive.json \
    --result_dir "results" \
    --model_provider "openai" \
    --model "gpt-5-2025-08-07" \
    --model_temperature 1.0 \
    --ground_provider "huggingface" \
    --ground_url "<YOUR_HUGGINGFACE_ENDPOINT_URL>/v1" \
    --ground_model "<YOUR_GROUNDING_MODEL>" \
    --grounding_width 1920 \
    --grounding_height 1080

# Step 2: Generate Facts
python generate_facts.py \
    --results-dirs \
        results1/pyautogui/screenshot/gpt-5-2025-08-07 \
        results2/pyautogui/screenshot/gpt-5-2025-08-07 \
    --model "gpt-5-2025-08-07" \
    --engine-type "openai" \
    --temperature 1.0

# Step 3: Run the Judge. Make sure the order of the results-dirs is the same as the order above.
python run_judge.py \
    --results-dirs \
        results1/pyautogui/screenshot/gpt-5-2025-08-07 \
        results2/pyautogui/screenshot/gpt-5-2025-08-07 \
    --output-dir "judge_results" \
    --examples-path "evaluation_examples/examples" \
    --model "gpt-5-2025-08-07" \
    --engine-type "openai" \
    --temperature 1.0
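The `--results-dirs` values in Steps 2 and 3 follow the layout that `run.py` writes to, namely `<result_dir>/<action_space>/<observation_type>/<model>`. A small sketch for composing those paths is shown below; `rollout_dir` is a hypothetical helper, and its defaults simply mirror the argparse defaults of `run.py`.

```python
import os


def rollout_dir(result_dir, action_space="pyautogui",
                observation_type="screenshot", model="gpt-4o"):
    """Directory where one rollout stores its per-domain results."""
    return os.path.join(result_dir, action_space, observation_type, model)


# One directory per rollout, in the order they should be passed to both scripts.
dirs = [rollout_dir(r, model="gpt-5-2025-08-07") for r in ("results1", "results2")]
print(" ".join(dirs))
```

Building the list once and passing it unchanged to `generate_facts.py` and `run_judge.py` is one way to keep the ordering identical between the two steps.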
@@ -0,0 +1,399 @@
"""Script to run end-to-end evaluation on the benchmark.
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
"""

import argparse
import datetime
import json
import logging
import os
import sys

from tqdm import tqdm

import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from gui_agents.s3.agents.agent_s import AgentS3
from gui_agents.s3.agents.grounding import OSWorldACI

from dotenv import load_dotenv

load_dotenv()

# Almost deprecated since it's not multi-env, use run_multienv_*.py instead

# Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

# The file handlers below write into logs/, so make sure it exists
os.makedirs("logs", exist_ok=True)

file_handler = logging.FileHandler(
    os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
    os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(
    os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
)

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #

logger = logging.getLogger("desktopenv.experiment")


def config() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )

    # environment config
    parser.add_argument("--path_to_vm", type=str, default=None)
    parser.add_argument(
        "--provider_name", type=str, default="vmware",
        help="Virtualization provider (vmware, docker, aws, azure, gcp, virtualbox)"
    )
    parser.add_argument(
        "--headless", action="store_true", help="Run in headless machine"
    )
    parser.add_argument(
        "--action_space", type=str, default="pyautogui", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
        default="screenshot",
        help="Observation type",
    )
    parser.add_argument("--screen_width", type=int, default=1920)
    parser.add_argument("--screen_height", type=int, default=1080)
    parser.add_argument("--sleep_after_execution", type=float, default=3.0)
    parser.add_argument("--max_steps", type=int, default=15)

    # agent config
    parser.add_argument("--max_trajectory_length", type=int, default=3)
    parser.add_argument(
        "--test_config_base_dir", type=str, default="evaluation_examples"
    )

    # lm config
    parser.add_argument("--model", type=str, default="gpt-4o")
    parser.add_argument("--temperature", type=float, default=1.0)

    # Agent S3 specific config
    parser.add_argument("--model_provider", type=str, default="openai")
    parser.add_argument(
        "--model_url",
        type=str,
        default="",
        help="The URL of the main generation model API.",
    )
    parser.add_argument(
        "--model_api_key",
        type=str,
        default="",
        help="The API key of the main generation model.",
    )
    parser.add_argument("--model_temperature", type=float, default=None, help="Temperature to fix the generation model at (e.g. o3 can only be run with 1.0)")

    # grounding model config
    parser.add_argument("--ground_provider", type=str, required=True, help="The provider for the grounding model")
    parser.add_argument("--ground_url", type=str, required=True, help="The URL of the grounding model")
    parser.add_argument(
        "--ground_api_key",
        type=str,
        default="",
        help="The API key of the grounding model.",
    )
    parser.add_argument(
        "--ground_model", type=str, required=True, help="The model name for the grounding model"
    )
    parser.add_argument(
        "--grounding_width",
        type=int,
        required=True,
        help="Width of screenshot image after processor rescaling",
    )
    parser.add_argument(
        "--grounding_height",
        type=int,
        required=True,
        help="Height of screenshot image after processor rescaling",
    )

    # example config
    parser.add_argument("--domain", type=str, default="all")
    parser.add_argument(
        "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
    )

    # logging related
    parser.add_argument("--result_dir", type=str, default="./results")
    args = parser.parse_args()

    return args


def test(args: argparse.Namespace, test_all_meta: dict) -> None:
    scores = []
    max_steps = args.max_steps

    # log args
    logger.info("Args: %s", args)
    # set wandb project
    cfg_args = {
        "path_to_vm": args.path_to_vm,
        "provider_name": args.provider_name,
        "headless": args.headless,
        "action_space": args.action_space,
        "observation_type": args.observation_type,
        "screen_width": args.screen_width,
        "screen_height": args.screen_height,
        "sleep_after_execution": args.sleep_after_execution,
        "max_steps": args.max_steps,
        "max_trajectory_length": args.max_trajectory_length,
        "model": args.model,
        "temperature": args.temperature,
        "result_dir": args.result_dir,
    }

    # AgentS2 configuration
    engine_params = {
        "engine_type": args.model_provider,
        "model": args.model,
        "base_url": getattr(args, 'model_url', ''),
        "api_key": getattr(args, 'model_api_key', ''),
        "temperature": getattr(args, 'model_temperature', None),
    }
    engine_params_for_grounding = {
        "engine_type": args.ground_provider,
        "model": args.ground_model,
        "base_url": getattr(args, 'ground_url', ''),
        "api_key": getattr(args, 'ground_api_key', ''),
        "grounding_width": args.grounding_width,
        "grounding_height": args.grounding_height,
    }

    env = DesktopEnv(
        provider_name=args.provider_name,
        path_to_vm=args.path_to_vm,
        action_space=args.action_space,
        screen_size=(args.screen_width, args.screen_height),
        headless=args.headless,
        os_type="Ubuntu",
        require_a11y_tree=args.observation_type
        in ["a11y_tree", "screenshot_a11y_tree", "som"],
        enable_proxy=True,
    )

    grounding_agent = OSWorldACI(
        env=env,
        platform="linux",
        engine_params_for_generation=engine_params,
        engine_params_for_grounding=engine_params_for_grounding,
        width=args.screen_width,
        height=args.screen_height,
    )
    agent = AgentS3(
        engine_params,
        grounding_agent,
        platform="linux",
    )

    for domain in tqdm(test_all_meta, desc="Domain"):
        for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False):
            config_file = os.path.join(
                args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
            )
            with open(config_file, "r", encoding="utf-8") as f:
                example = json.load(f)

            logger.info(f"[Domain]: {domain}")
            logger.info(f"[Example ID]: {example_id}")

            instruction = example["instruction"]

            logger.info(f"[Instruction]: {instruction}")
            # wandb each example config settings
            cfg_args["instruction"] = instruction
            cfg_args["start_time"] = datetime.datetime.now().strftime(
                "%Y:%m:%d-%H:%M:%S"
            )
            # run.config.update(cfg_args)

            example_result_dir = os.path.join(
                args.result_dir,
                args.action_space,
                args.observation_type,
                args.model,
                domain,
                example_id,
            )
            os.makedirs(example_result_dir, exist_ok=True)
            # example start running
            try:
                lib_run_single.run_single_example(
                    agent,
                    env,
                    example,
                    max_steps,
                    instruction,
                    args,
                    example_result_dir,
                    scores,
                )
            except Exception as e:
                logger.error(f"Exception in {domain}/{example_id}: {e}")
                # Only attempt to end recording if controller exists (not Docker provider)
                if hasattr(env, 'controller') and env.controller is not None:
                    env.controller.end_recording(
                        os.path.join(example_result_dir, "recording.mp4")
                    )
                with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                    f.write(
                        json.dumps(
                            {"Error": f"Time limit exceeded in {domain}/{example_id}"}
                        )
                    )
                    f.write("\n")

    env.close()
    # Guard against ZeroDivisionError when no example produced a score
    if scores:
        logger.info(f"Average score: {sum(scores) / len(scores)}")
    else:
        logger.info("Average score: n/a (no examples were scored)")


def get_unfinished(
    action_space, use_model, observation_type, result_dir, total_file_json
):
    """Drop examples that already have a result.txt from total_file_json.

    Files left behind by unfinished examples are deleted so they can be rerun.
    """
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)

    if not os.path.exists(target_dir):
        return total_file_json

    finished = {}
    for domain in os.listdir(target_dir):
        finished[domain] = []
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                if example_id == "onboard":
                    continue
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" not in os.listdir(example_path):
                        # empty all files under example_id
                        for file in os.listdir(example_path):
                            os.remove(os.path.join(example_path, file))
                    else:
                        finished[domain].append(example_id)

    if not finished:
        return total_file_json

    for domain, examples in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [
                x for x in total_file_json[domain] if x not in examples
            ]

    return total_file_json


def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(target_dir):
        print("New experiment, no result yet.")
        return None

    all_result = []

    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" in os.listdir(example_path):
                        # read the recorded score; count unreadable results as 0.0
                        try:
                            with open(
                                os.path.join(example_path, "result.txt"), "r"
                            ) as f:
                                all_result.append(float(f.read()))
                        except (OSError, ValueError):
                            all_result.append(0.0)

    if not all_result:
        print("New experiment, no result yet.")
        return None
    else:
        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
        return all_result


if __name__ == "__main__":
    ####### The complete version of the list of examples #######
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = config()

    # save args to json in result_dir/action_space/observation_type/model/args.json
    path_to_args = os.path.join(
        args.result_dir,
        args.action_space,
        args.observation_type,
        args.model,
        "args.json",
    )
    os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
    with open(path_to_args, "w", encoding="utf-8") as f:
        json.dump(vars(args), f, indent=4)

    with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
        test_all_meta = json.load(f)

    if args.domain != "all":
        test_all_meta = {args.domain: test_all_meta[args.domain]}

    test_file_list = get_unfinished(
        args.action_space,
        args.model,
        args.observation_type,
        args.result_dir,
        test_all_meta,
    )
    left_info = ""
    for domain in test_file_list:
        left_info += f"{domain}: {len(test_file_list[domain])}\n"
    logger.info(f"Left tasks:\n{left_info}")

    get_result(
        args.action_space,
        args.model,
        args.observation_type,
        args.result_dir,
        test_all_meta,
    )
    test(args, test_file_list)