format
Esse commit está contido em:
@@ -17,6 +17,7 @@ current_platform = platform.system().lower()
|
||||
# Global flag to track pause state for debugging
|
||||
paused = False
|
||||
|
||||
|
||||
def get_char():
|
||||
"""Get a single character from stdin without pressing Enter"""
|
||||
try:
|
||||
@@ -24,6 +25,7 @@ def get_char():
|
||||
if platform.system() in ["Darwin", "Linux"]:
|
||||
import termios
|
||||
import tty
|
||||
|
||||
fd = sys.stdin.fileno()
|
||||
old_settings = termios.tcgetattr(fd)
|
||||
try:
|
||||
@@ -35,14 +37,16 @@ def get_char():
|
||||
else:
|
||||
# Windows fallback
|
||||
import msvcrt
|
||||
return msvcrt.getch().decode('utf-8', errors='ignore')
|
||||
|
||||
return msvcrt.getch().decode("utf-8", errors="ignore")
|
||||
except:
|
||||
return input() # Fallback for non-terminal environments
|
||||
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
"""Handle Ctrl+C signal for debugging during agent execution"""
|
||||
global paused
|
||||
|
||||
|
||||
if not paused:
|
||||
print("\n\n🔸 Agent-S Workflow Paused 🔸")
|
||||
print("=" * 50)
|
||||
@@ -50,14 +54,14 @@ def signal_handler(signum, frame):
|
||||
print(" • Press Ctrl+C again to quit")
|
||||
print(" • Press Esc to resume workflow")
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
paused = True
|
||||
|
||||
|
||||
while paused:
|
||||
try:
|
||||
print("\n[PAUSED] Waiting for input... ", end="", flush=True)
|
||||
char = get_char()
|
||||
|
||||
|
||||
if ord(char) == 3: # Ctrl+C
|
||||
print("\n\n🛑 Exiting Agent-S...")
|
||||
sys.exit(0)
|
||||
@@ -67,7 +71,7 @@ def signal_handler(signum, frame):
|
||||
break
|
||||
else:
|
||||
print(f"\n Unknown command: '{char}' (ord: {ord(char)})")
|
||||
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n🛑 Exiting Agent-S...")
|
||||
sys.exit(0)
|
||||
@@ -76,6 +80,7 @@ def signal_handler(signum, frame):
|
||||
print("\n\n🛑 Exiting Agent-S...")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
# Set up signal handler for Ctrl+C
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
@@ -175,7 +180,7 @@ def run_agent(agent: UIAgent, instruction: str):
|
||||
time.sleep(0.1)
|
||||
|
||||
print(f"\n🔄 Step {step + 1}/15: Getting next action from agent...")
|
||||
|
||||
|
||||
# Get next action code from the agent
|
||||
info, code = agent.predict(instruction=instruction, observation=obs)
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ current_platform = platform.system().lower()
|
||||
# Global flag to track pause state for debugging
|
||||
paused = False
|
||||
|
||||
|
||||
def get_char():
|
||||
"""Get a single character from stdin without pressing Enter"""
|
||||
try:
|
||||
@@ -26,6 +27,7 @@ def get_char():
|
||||
if platform.system() in ["Darwin", "Linux"]:
|
||||
import termios
|
||||
import tty
|
||||
|
||||
fd = sys.stdin.fileno()
|
||||
old_settings = termios.tcgetattr(fd)
|
||||
try:
|
||||
@@ -37,14 +39,16 @@ def get_char():
|
||||
else:
|
||||
# Windows fallback
|
||||
import msvcrt
|
||||
return msvcrt.getch().decode('utf-8', errors='ignore')
|
||||
|
||||
return msvcrt.getch().decode("utf-8", errors="ignore")
|
||||
except:
|
||||
return input() # Fallback for non-terminal environments
|
||||
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
"""Handle Ctrl+C signal for debugging during agent execution"""
|
||||
global paused
|
||||
|
||||
|
||||
if not paused:
|
||||
print("\n\n🔸 Agent-S Workflow Paused 🔸")
|
||||
print("=" * 50)
|
||||
@@ -52,14 +56,14 @@ def signal_handler(signum, frame):
|
||||
print(" • Press Ctrl+C again to quit")
|
||||
print(" • Press Esc to resume workflow")
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
paused = True
|
||||
|
||||
|
||||
while paused:
|
||||
try:
|
||||
print("\n[PAUSED] Waiting for input... ", end="", flush=True)
|
||||
char = get_char()
|
||||
|
||||
|
||||
if ord(char) == 3: # Ctrl+C
|
||||
print("\n\n🛑 Exiting Agent-S...")
|
||||
sys.exit(0)
|
||||
@@ -69,7 +73,7 @@ def signal_handler(signum, frame):
|
||||
break
|
||||
else:
|
||||
print(f"\n Unknown command: '{char}' (ord: {ord(char)})")
|
||||
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n🛑 Exiting Agent-S...")
|
||||
sys.exit(0)
|
||||
@@ -78,6 +82,7 @@ def signal_handler(signum, frame):
|
||||
print("\n\n🛑 Exiting Agent-S...")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
# Set up signal handler for Ctrl+C
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
@@ -155,7 +160,7 @@ def run_agent(agent, instruction: str, scaled_width: int, scaled_height: int):
|
||||
# Check if we're in paused state and wait
|
||||
while paused:
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
# Get screen shot using pyautogui
|
||||
screenshot = pyautogui.screenshot()
|
||||
screenshot = screenshot.resize((scaled_width, scaled_height), Image.LANCZOS)
|
||||
@@ -174,7 +179,7 @@ def run_agent(agent, instruction: str, scaled_width: int, scaled_height: int):
|
||||
time.sleep(0.1)
|
||||
|
||||
print(f"\n🔄 Step {step + 1}/15: Getting next action from agent...")
|
||||
|
||||
|
||||
# Get next action code from the agent
|
||||
info, code = agent.predict(instruction=instruction, observation=obs)
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ current_platform = platform.system().lower()
|
||||
# Global flag to track pause state for debugging
|
||||
paused = False
|
||||
|
||||
|
||||
def get_char():
|
||||
"""Get a single character from stdin without pressing Enter"""
|
||||
try:
|
||||
@@ -26,6 +27,7 @@ def get_char():
|
||||
if platform.system() in ["Darwin", "Linux"]:
|
||||
import termios
|
||||
import tty
|
||||
|
||||
fd = sys.stdin.fileno()
|
||||
old_settings = termios.tcgetattr(fd)
|
||||
try:
|
||||
@@ -37,14 +39,16 @@ def get_char():
|
||||
else:
|
||||
# Windows fallback
|
||||
import msvcrt
|
||||
return msvcrt.getch().decode('utf-8', errors='ignore')
|
||||
|
||||
return msvcrt.getch().decode("utf-8", errors="ignore")
|
||||
except:
|
||||
return input() # Fallback for non-terminal environments
|
||||
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
"""Handle Ctrl+C signal for debugging during agent execution"""
|
||||
global paused
|
||||
|
||||
|
||||
if not paused:
|
||||
print("\n\n🔸 Agent-S Workflow Paused 🔸")
|
||||
print("=" * 50)
|
||||
@@ -52,14 +56,14 @@ def signal_handler(signum, frame):
|
||||
print(" • Press Ctrl+C again to quit")
|
||||
print(" • Press Esc to resume workflow")
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
paused = True
|
||||
|
||||
|
||||
while paused:
|
||||
try:
|
||||
print("\n[PAUSED] Waiting for input... ", end="", flush=True)
|
||||
char = get_char()
|
||||
|
||||
|
||||
if ord(char) == 3: # Ctrl+C
|
||||
print("\n\n🛑 Exiting Agent-S...")
|
||||
sys.exit(0)
|
||||
@@ -69,7 +73,7 @@ def signal_handler(signum, frame):
|
||||
break
|
||||
else:
|
||||
print(f"\n Unknown command: '{char}' (ord: {ord(char)})")
|
||||
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n🛑 Exiting Agent-S...")
|
||||
sys.exit(0)
|
||||
@@ -78,6 +82,7 @@ def signal_handler(signum, frame):
|
||||
print("\n\n🛑 Exiting Agent-S...")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
# Set up signal handler for Ctrl+C
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
@@ -173,7 +178,7 @@ def run_agent(agent, instruction: str, scaled_width: int, scaled_height: int):
|
||||
time.sleep(0.1)
|
||||
|
||||
print(f"\n🔄 Step {step + 1}/15: Getting next action from agent...")
|
||||
|
||||
|
||||
# Get next action code from the agent
|
||||
info, code = agent.predict(instruction=instruction, observation=obs)
|
||||
|
||||
@@ -249,7 +254,7 @@ def main():
|
||||
"--model_temperature",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Temperature to fix the generation model at (e.g. o3 can only be run with 1.0)"
|
||||
help="Temperature to fix the generation model at (e.g. o3 can only be run with 1.0)",
|
||||
)
|
||||
|
||||
# Grounding model config: Self-hosted endpoint based (required)
|
||||
@@ -318,7 +323,7 @@ def main():
|
||||
"model": args.model,
|
||||
"base_url": args.model_url,
|
||||
"api_key": args.model_api_key,
|
||||
"temperature": getattr(args, 'model_temperature', None),
|
||||
"temperature": getattr(args, "model_temperature", None),
|
||||
}
|
||||
|
||||
# Load the grounding engine from a custom endpoint
|
||||
|
||||
@@ -66,9 +66,7 @@ class AgentS3(UIAgent):
|
||||
enable_reflection: Creates a reflection agent to assist the worker agent
|
||||
"""
|
||||
|
||||
super().__init__(
|
||||
worker_engine_params, grounding_agent, platform
|
||||
)
|
||||
super().__init__(worker_engine_params, grounding_agent, platform)
|
||||
self.max_trajectory_length = max_trajectory_length
|
||||
self.enable_reflection = enable_reflection
|
||||
|
||||
@@ -91,12 +89,6 @@ class AgentS3(UIAgent):
|
||||
)
|
||||
|
||||
# concatenate the three info dictionaries
|
||||
info = {
|
||||
**{
|
||||
k: v
|
||||
for d in [executor_info or {}]
|
||||
for k, v in d.items()
|
||||
}
|
||||
}
|
||||
info = {**{k: v for d in [executor_info or {}] for k, v in d.items()}}
|
||||
|
||||
return info, actions
|
||||
|
||||
@@ -7,6 +7,7 @@ from gui_agents.s3.core.mllm import LMMAgent
|
||||
|
||||
logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
|
||||
def extract_code_block(action: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""Extract code and determine type from action string."""
|
||||
if "```python" in action:
|
||||
@@ -21,8 +22,10 @@ def extract_code_block(action: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
else:
|
||||
code_type = None
|
||||
code = None
|
||||
|
||||
logger.debug(f"Extracted code block: type={code_type}, length={len(code) if code else 0}")
|
||||
|
||||
logger.debug(
|
||||
f"Extracted code block: type={code_type}, length={len(code) if code else 0}"
|
||||
)
|
||||
return code_type, code
|
||||
|
||||
|
||||
@@ -30,7 +33,7 @@ def execute_code(code_type: str, code: str, env_controller) -> Dict:
|
||||
"""Execute code based on its type."""
|
||||
# Log the full code being executed (untruncated)
|
||||
logger.info(f"CODING_AGENT_CODE_EXECUTION - Type: {code_type}\nCode:\n{code}")
|
||||
|
||||
|
||||
try:
|
||||
if code_type == "bash":
|
||||
result = env_controller.run_bash_script(code, timeout=30)
|
||||
@@ -38,9 +41,9 @@ def execute_code(code_type: str, code: str, env_controller) -> Dict:
|
||||
result = env_controller.run_python_script(code)
|
||||
else:
|
||||
result = {"status": "error", "error": f"Unknown code type: {code_type}"}
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing {code_type} code: {e}")
|
||||
return {"status": "error", "error": str(e)}
|
||||
@@ -54,97 +57,99 @@ def format_result(result: Dict, step_count: int) -> str:
|
||||
Step {step_count + 1} Error:
|
||||
Error: No result returned from execution
|
||||
"""
|
||||
|
||||
status = result.get('status', 'unknown')
|
||||
return_code = result.get('returncode', result.get('return_code', -1))
|
||||
|
||||
|
||||
status = result.get("status", "unknown")
|
||||
return_code = result.get("returncode", result.get("return_code", -1))
|
||||
|
||||
# Handle different response structures for bash vs python
|
||||
if 'returncode' in result:
|
||||
if "returncode" in result:
|
||||
# Bash script response
|
||||
output = result.get('output', '') # Contains both stdout and stderr merged
|
||||
error = result.get('error', '') # Always empty for bash
|
||||
output = result.get("output", "") # Contains both stdout and stderr merged
|
||||
error = result.get("error", "") # Always empty for bash
|
||||
else:
|
||||
# Python script response
|
||||
output = result.get('output', '') # stdout only
|
||||
error = result.get('error', '') # stderr only
|
||||
|
||||
output = result.get("output", "") # stdout only
|
||||
error = result.get("error", "") # stderr only
|
||||
|
||||
logger.debug(f"Step {step_count + 1}: Status={status}, Return Code={return_code}")
|
||||
|
||||
|
||||
# Format with better structure for multi-line outputs
|
||||
result_text = f"Step {step_count + 1} Result:\n"
|
||||
result_text += f"Status: {status}\n"
|
||||
result_text += f"Return Code: {return_code}\n"
|
||||
|
||||
|
||||
if output:
|
||||
result_text += f"Output:\n{output}\n"
|
||||
|
||||
|
||||
if error:
|
||||
result_text += f"Error:\n{error}\n"
|
||||
|
||||
|
||||
return result_text
|
||||
|
||||
|
||||
class CodeAgent:
|
||||
"""A dedicated agent for executing code with a budget of steps."""
|
||||
|
||||
|
||||
def __init__(self, engine_params: Dict, budget: int = 20):
|
||||
"""Initialize the CodeAgent."""
|
||||
if not engine_params:
|
||||
raise ValueError("engine_params cannot be None or empty")
|
||||
|
||||
|
||||
self.engine_params = engine_params
|
||||
self.budget = budget
|
||||
self.agent = None
|
||||
|
||||
|
||||
logger.info(f"CodeAgent initialized with budget={budget}")
|
||||
self.reset()
|
||||
|
||||
|
||||
def reset(self):
|
||||
"""Reset the code agent state."""
|
||||
logger.debug("Resetting CodeAgent state")
|
||||
self.agent = LMMAgent(
|
||||
engine_params=self.engine_params,
|
||||
system_prompt=PROCEDURAL_MEMORY.CODE_AGENT_PROMPT
|
||||
engine_params=self.engine_params,
|
||||
system_prompt=PROCEDURAL_MEMORY.CODE_AGENT_PROMPT,
|
||||
)
|
||||
|
||||
|
||||
def execute(self, task_instruction: str, screenshot: str, env_controller) -> Dict:
|
||||
"""Execute code for the given task with a budget of steps."""
|
||||
logger.info(f"Starting code execution for task: {task_instruction}")
|
||||
logger.info(f"Budget: {self.budget} steps")
|
||||
|
||||
|
||||
self.reset()
|
||||
|
||||
|
||||
# Add initial task instruction and screenshot context as user message
|
||||
context = f"Task: {task_instruction}\n\nCurrent screenshot is provided for context."
|
||||
context = (
|
||||
f"Task: {task_instruction}\n\nCurrent screenshot is provided for context."
|
||||
)
|
||||
self.agent.add_message(context, image_content=screenshot, role="user")
|
||||
|
||||
|
||||
step_count = 0
|
||||
execution_history = []
|
||||
|
||||
|
||||
while step_count < self.budget:
|
||||
logger.info(f"Step {step_count + 1}/{self.budget}")
|
||||
|
||||
|
||||
# Get assistant response (thoughts and code)
|
||||
response = call_llm_safe(self.agent, temperature=1)
|
||||
|
||||
|
||||
# Log the latest message from the coding agent (untruncated)
|
||||
logger.info(f"CODING_AGENT_LATEST_MESSAGE - Step {step_count + 1}:\n{response}")
|
||||
|
||||
logger.info(
|
||||
f"CODING_AGENT_LATEST_MESSAGE - Step {step_count + 1}:\n{response}"
|
||||
)
|
||||
|
||||
# Check if response is None or empty
|
||||
if not response or response.strip() == "":
|
||||
error_msg = f"Step {step_count + 1}: LLM returned empty response"
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
|
||||
# Parse the response to extract action
|
||||
action, thoughts = split_thinking_response(response)
|
||||
|
||||
execution_history.append({
|
||||
"step": step_count + 1,
|
||||
"action": action,
|
||||
"thoughts": thoughts
|
||||
})
|
||||
|
||||
|
||||
execution_history.append(
|
||||
{"step": step_count + 1, "action": action, "thoughts": thoughts}
|
||||
)
|
||||
|
||||
# Check for completion signals
|
||||
action_upper = action.upper().strip()
|
||||
if action_upper == "DONE":
|
||||
@@ -155,10 +160,10 @@ class CodeAgent:
|
||||
logger.info(f"Step {step_count + 1}: Task failed by agent request")
|
||||
completion_reason = "FAIL"
|
||||
break
|
||||
|
||||
|
||||
# Extract and execute code
|
||||
code_type, code = extract_code_block(action)
|
||||
|
||||
|
||||
if code:
|
||||
result = execute_code(code_type, code, env_controller)
|
||||
# Prepare formatted output and error for logging
|
||||
@@ -173,11 +178,17 @@ class CodeAgent:
|
||||
]
|
||||
|
||||
if output:
|
||||
log_lines.append("Output:\n" + ("-" * 40) + f"\n{output}\n" + ("-" * 40))
|
||||
log_lines.append(
|
||||
"Output:\n" + ("-" * 40) + f"\n{output}\n" + ("-" * 40)
|
||||
)
|
||||
if error:
|
||||
log_lines.append("Error:\n" + ("!" * 40) + f"\n{error}\n" + ("!" * 40))
|
||||
log_lines.append(
|
||||
"Error:\n" + ("!" * 40) + f"\n{error}\n" + ("!" * 40)
|
||||
)
|
||||
if message and not output and not error:
|
||||
log_lines.append("Message:\n" + ("-" * 40) + f"\n{message}\n" + ("-" * 40))
|
||||
log_lines.append(
|
||||
"Message:\n" + ("-" * 40) + f"\n{message}\n" + ("-" * 40)
|
||||
)
|
||||
|
||||
# Remove None entries and join
|
||||
formatted_log = "\n".join([line for line in log_lines if line])
|
||||
@@ -192,55 +203,57 @@ class CodeAgent:
|
||||
)
|
||||
# Add assistant's thoughts and code to message history
|
||||
self.agent.add_message(response, role="assistant")
|
||||
|
||||
|
||||
# Process result and add formatted environment results as user message
|
||||
result_context = format_result(result, step_count)
|
||||
self.agent.add_message(result_context, role="user")
|
||||
|
||||
|
||||
step_count += 1
|
||||
|
||||
|
||||
# Handle budget exhaustion
|
||||
if 'completion_reason' not in locals():
|
||||
if "completion_reason" not in locals():
|
||||
logger.info(f"Budget exhausted after {step_count} steps")
|
||||
completion_reason = f"BUDGET_EXHAUSTED_AFTER_{step_count}_STEPS"
|
||||
|
||||
|
||||
# Generate final summary
|
||||
logger.info("Generating execution summary")
|
||||
summary = self._generate_summary(execution_history, task_instruction)
|
||||
|
||||
|
||||
result = {
|
||||
"task_instruction": task_instruction,
|
||||
"completion_reason": completion_reason,
|
||||
"summary": summary,
|
||||
"execution_history": execution_history,
|
||||
"steps_executed": step_count,
|
||||
"budget": self.budget
|
||||
"budget": self.budget,
|
||||
}
|
||||
|
||||
|
||||
logger.info(f"Code execution completed: steps={step_count}")
|
||||
return result
|
||||
|
||||
def _generate_summary(self, execution_history: List[Dict], task_instruction: str) -> str:
|
||||
|
||||
def _generate_summary(
|
||||
self, execution_history: List[Dict], task_instruction: str
|
||||
) -> str:
|
||||
"""Generate summary of code execution session."""
|
||||
if not execution_history:
|
||||
logger.info("No execution history to summarize")
|
||||
return "No actions were executed."
|
||||
|
||||
|
||||
logger.info(f"Generated summary for {len(execution_history)} steps")
|
||||
|
||||
|
||||
# Build detailed execution context for summary agent
|
||||
execution_context = f"Task: {task_instruction}\n\nExecution Steps:\n"
|
||||
|
||||
|
||||
for step in execution_history:
|
||||
step_num = step['step']
|
||||
thoughts = step.get('thoughts', '')
|
||||
action = step.get('action', '')
|
||||
|
||||
step_num = step["step"]
|
||||
thoughts = step.get("thoughts", "")
|
||||
action = step.get("action", "")
|
||||
|
||||
execution_context += f"\nStep {step_num}:\n"
|
||||
if thoughts:
|
||||
execution_context += f"Thoughts: {thoughts}\n"
|
||||
execution_context += f"Code: {action}\n"
|
||||
|
||||
|
||||
# Create summary prompt with same context as coding agent
|
||||
summary_prompt = f"""
|
||||
{execution_context}
|
||||
@@ -255,24 +268,22 @@ Do not make judgments about success or failure. Simply describe what was attempt
|
||||
|
||||
Keep the summary under 150 words and use clear, factual language.
|
||||
"""
|
||||
|
||||
|
||||
# Generate summary using LLM with dedicated summary system prompt
|
||||
try:
|
||||
summary_agent = LMMAgent(
|
||||
engine_params=self.engine_params,
|
||||
system_prompt=PROCEDURAL_MEMORY.CODE_SUMMARY_AGENT_PROMPT
|
||||
engine_params=self.engine_params,
|
||||
system_prompt=PROCEDURAL_MEMORY.CODE_SUMMARY_AGENT_PROMPT,
|
||||
)
|
||||
summary_agent.add_message(summary_prompt, role="user")
|
||||
summary = call_llm_safe(summary_agent, temperature=1)
|
||||
|
||||
|
||||
if not summary or summary.strip() == "":
|
||||
summary = "Summary generation failed - no response from LLM"
|
||||
logger.warning("Summary generation failed - empty response from LLM")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
summary = f"Summary generation failed: {str(e)}"
|
||||
logger.error(f"Error generating summary: {e}")
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
return summary
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from io import BytesIO
|
||||
@@ -215,11 +214,13 @@ class OSWorldACI(ACI):
|
||||
engine_params=engine_params_for_generation,
|
||||
system_prompt=PROCEDURAL_MEMORY.PHRASE_TO_WORD_COORDS_PROMPT,
|
||||
)
|
||||
|
||||
|
||||
# Configure code agent
|
||||
code_agent_engine_params = code_agent_engine_params or engine_params_for_generation
|
||||
code_agent_engine_params = (
|
||||
code_agent_engine_params or engine_params_for_generation
|
||||
)
|
||||
self.code_agent = CodeAgent(code_agent_engine_params, code_agent_budget)
|
||||
|
||||
|
||||
# Store task instruction for code agent
|
||||
self.current_task_instruction = None
|
||||
self.last_code_agent_result = None
|
||||
@@ -382,7 +383,9 @@ class OSWorldACI(ACI):
|
||||
elif self.platform == "windows":
|
||||
return f"import pyautogui; import time; pyautogui.hotkey('win', 'd', interval=0.5); pyautogui.typewrite({repr(app_code)}); pyautogui.press('enter'); time.sleep(1.0)"
|
||||
else:
|
||||
assert False, f"Unsupported platform: {self.platform}. Supported platforms are: darwin, linux, windows."
|
||||
assert (
|
||||
False
|
||||
), f"Unsupported platform: {self.platform}. Supported platforms are: darwin, linux, windows."
|
||||
|
||||
@agent_action
|
||||
def open(self, app_or_filename: str):
|
||||
@@ -434,7 +437,7 @@ class OSWorldACI(ACI):
|
||||
|
||||
# Check if text contains Unicode characters that pyautogui.write() can't handle
|
||||
has_unicode = any(ord(char) > 127 for char in text)
|
||||
|
||||
|
||||
if has_unicode:
|
||||
# Use clipboard method for Unicode characters
|
||||
command += f"pyperclip.copy({repr(text)}); "
|
||||
@@ -486,14 +489,18 @@ class OSWorldACI(ACI):
|
||||
return command
|
||||
|
||||
@agent_action
|
||||
def highlight_text_span(self, starting_phrase: str, ending_phrase: str, button: str = "left"):
|
||||
def highlight_text_span(
|
||||
self, starting_phrase: str, ending_phrase: str, button: str = "left"
|
||||
):
|
||||
"""Highlight a text span between a provided starting phrase and ending phrase. Use this to highlight words, lines, and paragraphs.
|
||||
Args:
|
||||
starting_phrase:str, the phrase that denotes the start of the text span you want to highlight. If you only want to highlight one word, just pass in that single word.
|
||||
ending_phrase:str, the phrase that denotes the end of the text span you want to highlight. If you only want to highlight one word, just pass in that single word.
|
||||
button:str, the button to use to highlight the text span. Defaults to "left". Can be "left", "right", or "middle".
|
||||
"""
|
||||
coords1 = self.generate_text_coords(starting_phrase, self.obs, alignment="start")
|
||||
coords1 = self.generate_text_coords(
|
||||
starting_phrase, self.obs, alignment="start"
|
||||
)
|
||||
coords2 = self.generate_text_coords(ending_phrase, self.obs, alignment="end")
|
||||
x1, y1 = coords1
|
||||
x2, y2 = coords2
|
||||
@@ -523,16 +530,16 @@ class OSWorldACI(ACI):
|
||||
@agent_action
|
||||
def call_code_agent(self, task: str = None):
|
||||
"""Call the code agent to execute code for tasks or subtasks that can be completed solely with coding.
|
||||
|
||||
|
||||
Args:
|
||||
task: str, the task or subtask to execute. If None, uses the current full task instruction.
|
||||
|
||||
|
||||
**🚨 CRITICAL GUIDELINES:**
|
||||
- **ONLY pass a task parameter for SPECIFIC subtasks** (e.g., "Calculate sum of column B", "Filter data by date")
|
||||
- **NEVER pass a task parameter for full tasks** - let it default to the original task instruction
|
||||
- **NEVER rephrase or modify the original task** - this prevents hallucination corruption
|
||||
- **If unsure, omit the task parameter entirely** to use the original task instruction
|
||||
|
||||
|
||||
Use this for tasks that can be fully accomplished through code execution, particularly for:
|
||||
- Spreadsheet applications (LibreOffice Calc, Excel): data processing, filtering, sorting, calculations, formulas, data analysis
|
||||
- Document editors (LibreOffice Writer, Word): text processing, content editing, formatting, document manipulation
|
||||
@@ -544,7 +551,7 @@ class OSWorldACI(ACI):
|
||||
logger.info("=" * 50)
|
||||
logger.info("GROUNDING AGENT: Calling Code Agent")
|
||||
logger.info("=" * 50)
|
||||
|
||||
|
||||
# **CRITICAL**: Only use provided task for specific subtasks, otherwise use original task instruction
|
||||
if task is not None:
|
||||
# This is a subtask - use the provided task
|
||||
@@ -554,27 +561,29 @@ class OSWorldACI(ACI):
|
||||
# This is a full task - use the original task instruction to prevent hallucination
|
||||
task_to_execute = self.current_task_instruction
|
||||
logger.info(f"Executing FULL TASK: {task_to_execute}")
|
||||
|
||||
|
||||
if task_to_execute:
|
||||
print("obs keys: ", self.obs.keys())
|
||||
screenshot = self.obs.get('screenshot', '') if self.obs else ''
|
||||
screenshot = self.obs.get("screenshot", "") if self.obs else ""
|
||||
logger.info(f"Screenshot available: {'Yes' if screenshot else 'No'}")
|
||||
|
||||
|
||||
logger.info("Executing code agent...")
|
||||
result = self.code_agent.execute(task_to_execute, screenshot, self.env.controller)
|
||||
|
||||
result = self.code_agent.execute(
|
||||
task_to_execute, screenshot, self.env.controller
|
||||
)
|
||||
|
||||
# Store the result for the worker to access
|
||||
self.last_code_agent_result = result
|
||||
|
||||
|
||||
logger.info("Code agent execution completed")
|
||||
logger.info(f"Result - Completion reason: {result['completion_reason']}")
|
||||
logger.info(f"Steps executed: {result['steps_executed']}")
|
||||
logger.info(f"Summary: {result['summary']}")
|
||||
|
||||
|
||||
logger.info("=" * 50)
|
||||
logger.info("GROUNDING AGENT: Code Agent Call Finished")
|
||||
logger.info("=" * 50)
|
||||
|
||||
|
||||
# Return code to be executed in the environment
|
||||
return "import time; time.sleep(2.222)"
|
||||
else:
|
||||
|
||||
@@ -11,7 +11,7 @@ from gui_agents.s3.utils.common_utils import (
|
||||
call_llm_formatted,
|
||||
parse_code_from_string,
|
||||
split_thinking_response,
|
||||
create_pyautogui_code
|
||||
create_pyautogui_code,
|
||||
)
|
||||
from gui_agents.s3.utils.formatters import (
|
||||
SINGLE_ACTION_FORMATTER,
|
||||
@@ -70,7 +70,9 @@ class Worker(BaseModule):
|
||||
).replace("CURRENT_OS", self.platform)
|
||||
|
||||
self.generator_agent = self._create_agent(sys_prompt)
|
||||
self.reflection_agent = self._create_agent(PROCEDURAL_MEMORY.REFLECTION_ON_TRAJECTORY)
|
||||
self.reflection_agent = self._create_agent(
|
||||
PROCEDURAL_MEMORY.REFLECTION_ON_TRAJECTORY
|
||||
)
|
||||
|
||||
self.turn_count = 0
|
||||
self.worker_history = []
|
||||
@@ -80,7 +82,7 @@ class Worker(BaseModule):
|
||||
|
||||
def flush_messages(self):
|
||||
"""Flush messages based on the model's context limits.
|
||||
|
||||
|
||||
This method ensures that the agent's message history does not exceed the maximum trajectory length.
|
||||
|
||||
Side Effects:
|
||||
@@ -92,7 +94,8 @@ class Worker(BaseModule):
|
||||
if engine_type in ["anthropic", "openai", "gemini"]:
|
||||
max_images = self.max_trajectory_length
|
||||
for agent in [self.generator_agent, self.reflection_agent]:
|
||||
if agent is None: continue
|
||||
if agent is None:
|
||||
continue
|
||||
# keep latest k images
|
||||
img_count = 0
|
||||
for i in range(len(agent.messages) - 1, -1, -1):
|
||||
@@ -119,10 +122,10 @@ class Worker(BaseModule):
|
||||
Args:
|
||||
instruction (str): The task instruction.
|
||||
obs (Dict): The current observation containing the screenshot.
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[str, str]: The generated reflection text and thoughts, if any (turn_count > 0).
|
||||
|
||||
|
||||
Side Effects:
|
||||
- Updates reflection agent's history
|
||||
- Generates reflection response with API call
|
||||
@@ -166,7 +169,7 @@ class Worker(BaseModule):
|
||||
logger.info("REFLECTION THOUGHTS: %s", reflection_thoughts)
|
||||
logger.info("REFLECTION: %s", reflection)
|
||||
return reflection, reflection_thoughts
|
||||
|
||||
|
||||
def generate_next_action(self, instruction: str, obs: Dict) -> Tuple[Dict, List]:
|
||||
"""
|
||||
Predict the next action(s) based on the current observation.
|
||||
@@ -174,12 +177,18 @@ class Worker(BaseModule):
|
||||
|
||||
self.grounding_agent.assign_screenshot(obs)
|
||||
self.grounding_agent.set_task_instruction(instruction)
|
||||
|
||||
generator_message = "" if self.turn_count > 0 else "The initial screen is provided. No action has been taken yet."
|
||||
|
||||
generator_message = (
|
||||
""
|
||||
if self.turn_count > 0
|
||||
else "The initial screen is provided. No action has been taken yet."
|
||||
)
|
||||
|
||||
# Load the task into the system prompt
|
||||
if self.turn_count == 0:
|
||||
prompt_with_instructions = self.generator_agent.system_prompt.replace("TASK_DESCRIPTION", instruction)
|
||||
prompt_with_instructions = self.generator_agent.system_prompt.replace(
|
||||
"TASK_DESCRIPTION", instruction
|
||||
)
|
||||
self.generator_agent.add_system_prompt(prompt_with_instructions)
|
||||
|
||||
# Get the per-step reflection
|
||||
@@ -188,115 +197,140 @@ class Worker(BaseModule):
|
||||
generator_message += f"REFLECTION: You may use this reflection on the previous action and overall trajectory:\n{reflection}\n"
|
||||
|
||||
# Get the grounding agent's knowledge base buffer
|
||||
generator_message += f"\nCurrent Text Buffer = [{','.join(self.grounding_agent.notes)}]\n"
|
||||
generator_message += (
|
||||
f"\nCurrent Text Buffer = [{','.join(self.grounding_agent.notes)}]\n"
|
||||
)
|
||||
|
||||
# Add code agent result from previous step if available (from full task or subtask execution)
|
||||
if hasattr(self.grounding_agent, 'last_code_agent_result') and self.grounding_agent.last_code_agent_result is not None:
|
||||
if (
|
||||
hasattr(self.grounding_agent, "last_code_agent_result")
|
||||
and self.grounding_agent.last_code_agent_result is not None
|
||||
):
|
||||
code_result = self.grounding_agent.last_code_agent_result
|
||||
generator_message += f"\nCODE AGENT RESULT:\n"
|
||||
generator_message += f"Task/Subtask Instruction: {code_result['task_instruction']}\n"
|
||||
generator_message += (
|
||||
f"Task/Subtask Instruction: {code_result['task_instruction']}\n"
|
||||
)
|
||||
generator_message += f"Steps Completed: {code_result['steps_executed']}\n"
|
||||
generator_message += f"Max Steps: {code_result['budget']}\n"
|
||||
generator_message += f"Completion Reason: {code_result['completion_reason']}\n"
|
||||
generator_message += (
|
||||
f"Completion Reason: {code_result['completion_reason']}\n"
|
||||
)
|
||||
generator_message += f"Summary: {code_result['summary']}\n"
|
||||
if code_result['execution_history']:
|
||||
if code_result["execution_history"]:
|
||||
generator_message += f"Execution History:\n"
|
||||
for i, step in enumerate(code_result['execution_history']):
|
||||
action = step['action']
|
||||
for i, step in enumerate(code_result["execution_history"]):
|
||||
action = step["action"]
|
||||
# Format code snippets with proper backticks
|
||||
if '```python' in action:
|
||||
if "```python" in action:
|
||||
# Extract Python code and format it
|
||||
code_start = action.find('```python') + 9
|
||||
code_end = action.find('```', code_start)
|
||||
code_start = action.find("```python") + 9
|
||||
code_end = action.find("```", code_start)
|
||||
if code_end != -1:
|
||||
python_code = action[code_start:code_end].strip()
|
||||
generator_message += f"Step {i+1}: \n```python\n{python_code}\n```\n"
|
||||
generator_message += (
|
||||
f"Step {i+1}: \n```python\n{python_code}\n```\n"
|
||||
)
|
||||
else:
|
||||
generator_message += f"Step {i+1}: \n{action}\n"
|
||||
elif '```bash' in action:
|
||||
elif "```bash" in action:
|
||||
# Extract Bash code and format it
|
||||
code_start = action.find('```bash') + 7
|
||||
code_end = action.find('```', code_start)
|
||||
code_start = action.find("```bash") + 7
|
||||
code_end = action.find("```", code_start)
|
||||
if code_end != -1:
|
||||
bash_code = action[code_start:code_end].strip()
|
||||
generator_message += f"Step {i+1}: \n```bash\n{bash_code}\n```\n"
|
||||
generator_message += (
|
||||
f"Step {i+1}: \n```bash\n{bash_code}\n```\n"
|
||||
)
|
||||
else:
|
||||
generator_message += f"Step {i+1}: \n{action}\n"
|
||||
else:
|
||||
generator_message += f"Step {i+1}: \n{action}\n"
|
||||
generator_message += "\n"
|
||||
|
||||
|
||||
# Save code agent result to text file
|
||||
try:
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
logs_dir = "logs"
|
||||
if not os.path.exists(logs_dir):
|
||||
os.makedirs(logs_dir)
|
||||
|
||||
|
||||
# Generate filename with timestamp
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"logs/code_agent_result_step_{self.turn_count + 1}_{timestamp}.txt"
|
||||
|
||||
with open(filename, 'w') as f:
|
||||
filename = (
|
||||
f"logs/code_agent_result_step_{self.turn_count + 1}_{timestamp}.txt"
|
||||
)
|
||||
|
||||
with open(filename, "w") as f:
|
||||
f.write(f"CODE AGENT RESULT - Step {self.turn_count + 1}\n")
|
||||
f.write(f"Timestamp: {datetime.now().isoformat()}\n")
|
||||
f.write(f"Task/Subtask Instruction: {code_result['task_instruction']}\n")
|
||||
f.write(
|
||||
f"Task/Subtask Instruction: {code_result['task_instruction']}\n"
|
||||
)
|
||||
f.write(f"Steps Completed: {code_result['steps_executed']}\n")
|
||||
f.write(f"Max Steps: {code_result['budget']}\n")
|
||||
f.write(f"Completion Reason: {code_result['completion_reason']}\n")
|
||||
f.write(f"Summary: {code_result['summary']}\n")
|
||||
if code_result['execution_history']:
|
||||
if code_result["execution_history"]:
|
||||
f.write(f"\nExecution History:\n")
|
||||
for i, step in enumerate(code_result['execution_history']):
|
||||
for i, step in enumerate(code_result["execution_history"]):
|
||||
f.write(f"\nStep {i+1}:\n")
|
||||
f.write(f"Action: {step['action']}\n")
|
||||
if 'thoughts' in step:
|
||||
if "thoughts" in step:
|
||||
f.write(f"Thoughts: {step['thoughts']}\n")
|
||||
|
||||
|
||||
logger.info(f"Code agent result saved to: {filename}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save code agent result to file: {e}")
|
||||
|
||||
|
||||
# Log the code agent result section for debugging (truncated execution history)
|
||||
log_message = f"\nCODE AGENT RESULT:\n"
|
||||
log_message += f"Task/Subtask Instruction: {code_result['task_instruction']}\n"
|
||||
log_message += (
|
||||
f"Task/Subtask Instruction: {code_result['task_instruction']}\n"
|
||||
)
|
||||
log_message += f"Steps Completed: {code_result['steps_executed']}\n"
|
||||
log_message += f"Max Steps: {code_result['budget']}\n"
|
||||
log_message += f"Completion Reason: {code_result['completion_reason']}\n"
|
||||
log_message += f"Summary: {code_result['summary']}\n"
|
||||
if code_result['execution_history']:
|
||||
if code_result["execution_history"]:
|
||||
log_message += f"Execution History (truncated):\n"
|
||||
# Only log first 3 steps and last 2 steps to keep logs manageable
|
||||
total_steps = len(code_result['execution_history'])
|
||||
for i, step in enumerate(code_result['execution_history']):
|
||||
total_steps = len(code_result["execution_history"])
|
||||
for i, step in enumerate(code_result["execution_history"]):
|
||||
if i < 3 or i >= total_steps - 2: # First 3 and last 2 steps
|
||||
action = step['action']
|
||||
if '```python' in action:
|
||||
code_start = action.find('```python') + 9
|
||||
code_end = action.find('```', code_start)
|
||||
action = step["action"]
|
||||
if "```python" in action:
|
||||
code_start = action.find("```python") + 9
|
||||
code_end = action.find("```", code_start)
|
||||
if code_end != -1:
|
||||
python_code = action[code_start:code_end].strip()
|
||||
log_message += f"Step {i+1}: ```python\n{python_code}\n```\n"
|
||||
log_message += (
|
||||
f"Step {i+1}: ```python\n{python_code}\n```\n"
|
||||
)
|
||||
else:
|
||||
log_message += f"Step {i+1}: {action}\n"
|
||||
elif '```bash' in action:
|
||||
code_start = action.find('```bash') + 7
|
||||
code_end = action.find('```', code_start)
|
||||
elif "```bash" in action:
|
||||
code_start = action.find("```bash") + 7
|
||||
code_end = action.find("```", code_start)
|
||||
if code_end != -1:
|
||||
bash_code = action[code_start:code_end].strip()
|
||||
log_message += f"Step {i+1}: ```bash\n{bash_code}\n```\n"
|
||||
log_message += (
|
||||
f"Step {i+1}: ```bash\n{bash_code}\n```\n"
|
||||
)
|
||||
else:
|
||||
log_message += f"Step {i+1}: {action}\n"
|
||||
else:
|
||||
log_message += f"Step {i+1}: {action}\n"
|
||||
elif i == 3 and total_steps > 5:
|
||||
log_message += f"... (truncated {total_steps - 5} steps) ...\n"
|
||||
|
||||
logger.info(f"WORKER_CODE_AGENT_RESULT_SECTION - Step {self.turn_count + 1}: Code agent result added to generator message:\n{log_message}")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"WORKER_CODE_AGENT_RESULT_SECTION - Step {self.turn_count + 1}: Code agent result added to generator message:\n{log_message}"
|
||||
)
|
||||
|
||||
# Reset the code agent result after adding it to context
|
||||
self.grounding_agent.last_code_agent_result = None
|
||||
|
||||
@@ -304,10 +338,18 @@ class Worker(BaseModule):
|
||||
self.generator_agent.add_message(
|
||||
generator_message, image_content=obs["screenshot"], role="user"
|
||||
)
|
||||
|
||||
|
||||
# Generate the plan and next action
|
||||
format_checkers = [SINGLE_ACTION_FORMATTER, partial(CODE_VALID_FORMATTER, self.grounding_agent, obs)]
|
||||
plan = call_llm_formatted(self.generator_agent, format_checkers, temperature=self.temperature, use_thinking=self.use_thinking)
|
||||
format_checkers = [
|
||||
SINGLE_ACTION_FORMATTER,
|
||||
partial(CODE_VALID_FORMATTER, self.grounding_agent, obs),
|
||||
]
|
||||
plan = call_llm_formatted(
|
||||
self.generator_agent,
|
||||
format_checkers,
|
||||
temperature=self.temperature,
|
||||
use_thinking=self.use_thinking,
|
||||
)
|
||||
self.worker_history.append(plan)
|
||||
self.generator_agent.add_message(plan, role="assistant")
|
||||
logger.info("PLAN:\n %s", plan)
|
||||
@@ -318,18 +360,27 @@ class Worker(BaseModule):
|
||||
assert plan_code, "Plan code should not be empty"
|
||||
exec_code = create_pyautogui_code(self.grounding_agent, plan_code, obs)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not evaluate the following plan code:\n{plan_code}\nError: {e}")
|
||||
exec_code = self.grounding_agent.wait(1.333) # Skip a turn if the code cannot be evaluated
|
||||
|
||||
logger.error(
|
||||
f"Could not evaluate the following plan code:\n{plan_code}\nError: {e}"
|
||||
)
|
||||
exec_code = self.grounding_agent.wait(
|
||||
1.333
|
||||
) # Skip a turn if the code cannot be evaluated
|
||||
|
||||
executor_info = {
|
||||
"plan": plan,
|
||||
"plan_code": plan_code,
|
||||
"exec_code": exec_code,
|
||||
"reflection": reflection,
|
||||
"reflection_thoughts": reflection_thoughts,
|
||||
"code_agent_output": self.grounding_agent.last_code_agent_result if hasattr(self.grounding_agent, 'last_code_agent_result') and self.grounding_agent.last_code_agent_result is not None else None,
|
||||
"code_agent_output": (
|
||||
self.grounding_agent.last_code_agent_result
|
||||
if hasattr(self.grounding_agent, "last_code_agent_result")
|
||||
and self.grounding_agent.last_code_agent_result is not None
|
||||
else None
|
||||
),
|
||||
}
|
||||
self.turn_count += 1
|
||||
self.screenshot_inputs.append(obs["screenshot"])
|
||||
self.flush_messages()
|
||||
return executor_info, [exec_code]
|
||||
return executor_info, [exec_code]
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
from gui_agents.s3.core.mllm import LMMAgent
|
||||
from gui_agents.s3.memory.procedural_memory import PROCEDURAL_MEMORY
|
||||
from gui_agents.s3.utils.common_utils import call_llm_formatted, split_thinking_response, compress_image
|
||||
from gui_agents.s3.utils.common_utils import (
|
||||
call_llm_formatted,
|
||||
split_thinking_response,
|
||||
compress_image,
|
||||
)
|
||||
from gui_agents.s3.utils.formatters import (
|
||||
THOUGHTS_ANSWER_TAG_FORMATTER,
|
||||
)
|
||||
@@ -11,6 +15,7 @@ import base64
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
class BehaviorNarrator:
|
||||
def __init__(self, engine_params):
|
||||
self.judge_agent = LMMAgent(engine_params=engine_params)
|
||||
@@ -18,75 +23,101 @@ class BehaviorNarrator:
|
||||
@staticmethod
|
||||
def extract_mouse_action(action: str) -> list[str]:
|
||||
mouse_actions = []
|
||||
for sub_action in action.split(';'):
|
||||
for sub_action in action.split(";"):
|
||||
sub_action = sub_action.strip()
|
||||
if sub_action.startswith('pyautogui.click') or sub_action.startswith('pyautogui.moveTo') or sub_action.startswith('pyautogui.dragTo'):
|
||||
if (
|
||||
sub_action.startswith("pyautogui.click")
|
||||
or sub_action.startswith("pyautogui.moveTo")
|
||||
or sub_action.startswith("pyautogui.dragTo")
|
||||
):
|
||||
mouse_actions.append(sub_action)
|
||||
return mouse_actions
|
||||
|
||||
|
||||
@staticmethod
|
||||
def mark_action(mouse_actions:list[str], img: Image):
|
||||
def mark_action(mouse_actions: list[str], img: Image):
|
||||
draw = ImageDraw.Draw(img)
|
||||
font = ImageFont.load_default(25)
|
||||
|
||||
drag_start_width, drag_start_height = None, None
|
||||
|
||||
for mouse_action in mouse_actions:
|
||||
width, height = mouse_action.split('(')[1].strip(')').split(', ')[:2]
|
||||
width, height = mouse_action.split("(")[1].strip(")").split(", ")[:2]
|
||||
width, height = int(width), int(height)
|
||||
|
||||
# Clamp coordinates within bounds
|
||||
width = max(0, min(img.width - 1, width ))
|
||||
width = max(0, min(img.width - 1, width))
|
||||
height = max(0, min(img.height - 1, height))
|
||||
|
||||
def place_text(label, color):
|
||||
bbox = draw.textbbox((0, 0), label, font=font)
|
||||
text_w, text_h = bbox[2] - bbox[0], bbox[3] - bbox[1] # Measure text size
|
||||
offset_x, offset_y = -5, 5 # Default offset
|
||||
if width + offset_x + text_w > img.width: # Out of bounds on right
|
||||
text_w, text_h = (
|
||||
bbox[2] - bbox[0],
|
||||
bbox[3] - bbox[1],
|
||||
) # Measure text size
|
||||
offset_x, offset_y = -5, 5 # Default offset
|
||||
if width + offset_x + text_w > img.width: # Out of bounds on right
|
||||
offset_x = -text_w - 5
|
||||
if height + offset_y + text_h > img.height: # Out of bounds on bottom
|
||||
if height + offset_y + text_h > img.height: # Out of bounds on bottom
|
||||
offset_y = -text_h - 5
|
||||
if width + offset_x < 0: # Out of bounds on left
|
||||
if width + offset_x < 0: # Out of bounds on left
|
||||
offset_x = 5
|
||||
if height + offset_y < 0: # Out of bounds on top
|
||||
if height + offset_y < 0: # Out of bounds on top
|
||||
offset_y = 5
|
||||
draw.text((width + offset_x, height + offset_y), label, fill=color, font=font)
|
||||
|
||||
if mouse_action.startswith('pyautogui.click'):
|
||||
draw.text(
|
||||
(width + offset_x, height + offset_y), label, fill=color, font=font
|
||||
)
|
||||
|
||||
if mouse_action.startswith("pyautogui.click"):
|
||||
draw.circle((width, height), radius=3, fill=(255, 0, 0))
|
||||
place_text("Click", (255, 0, 0))
|
||||
if mouse_action.startswith('pyautogui.moveTo'):
|
||||
if mouse_action.startswith("pyautogui.moveTo"):
|
||||
draw.circle((width, height), radius=3, fill=(0, 0, 255))
|
||||
place_text("MoveTo", (0, 0, 255))
|
||||
drag_start_height, drag_start_width = height, width
|
||||
if mouse_action.startswith('pyautogui.dragTo'):
|
||||
draw.line([(drag_start_width, drag_start_height), (width, height)], fill=(0, 255, 0), width=2)
|
||||
if mouse_action.startswith("pyautogui.dragTo"):
|
||||
draw.line(
|
||||
[(drag_start_width, drag_start_height), (width, height)],
|
||||
fill=(0, 255, 0),
|
||||
width=2,
|
||||
)
|
||||
draw.circle((width, height), radius=3, fill=(0, 255, 0))
|
||||
place_text("DragTo", (0, 255, 0))
|
||||
|
||||
@staticmethod
|
||||
def get_mouse_action_representation(mouse_actions:list[str]) -> str:
|
||||
def get_mouse_action_representation(mouse_actions: list[str]) -> str:
|
||||
"""
|
||||
Returns a string representation of the mouse action for the given action.
|
||||
"""
|
||||
assert len(mouse_actions) <= 2, f"Multiple mouse action types found: {mouse_actions}"
|
||||
assert (
|
||||
len(mouse_actions) <= 2
|
||||
), f"Multiple mouse action types found: {mouse_actions}"
|
||||
if len(mouse_actions) == 1:
|
||||
action = mouse_actions[0]
|
||||
if action.startswith('pyautogui.click'):
|
||||
if action.startswith("pyautogui.click"):
|
||||
return "The red circle labeled 'Click' marks the position where the mouse was clicked."
|
||||
elif action.startswith('pyautogui.moveTo'):
|
||||
elif action.startswith("pyautogui.moveTo"):
|
||||
return "The blue circle labeled 'MoveTo' marks the position where the mouse was moved to."
|
||||
else:
|
||||
raise ValueError(f"Unknown single action type: {action}")
|
||||
else:
|
||||
assert mouse_actions[0].startswith('pyautogui.moveTo') and mouse_actions[1].startswith('pyautogui.dragTo')
|
||||
assert mouse_actions[0].startswith("pyautogui.moveTo") and mouse_actions[
|
||||
1
|
||||
].startswith("pyautogui.dragTo")
|
||||
return "The blue circle labeled 'MoveTo' marks the starting position of the mouse.\nThe green circle labeled 'DragTo' marks the ending position.\nThe green line illustrates the mouse's drag path."
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_zoomed_image(image_bytes: bytes, x: int, y: int, width: int = 300, height: int = 300, upscaling: bool = False, scale: int = 4, add_bounding_box: bool = False) -> bytes:
|
||||
def get_zoomed_image(
|
||||
image_bytes: bytes,
|
||||
x: int,
|
||||
y: int,
|
||||
width: int = 300,
|
||||
height: int = 300,
|
||||
upscaling: bool = False,
|
||||
scale: int = 4,
|
||||
add_bounding_box: bool = False,
|
||||
) -> bytes:
|
||||
"""Returns a zoomed image centered around (x, y) coordinates.
|
||||
|
||||
|
||||
Args:
|
||||
image_bytes (bytes): The original image in bytes.
|
||||
x (int): The x-coordinate of the center point.
|
||||
@@ -97,14 +128,14 @@ class BehaviorNarrator:
|
||||
upscaling (bool): Whether to upscale and enhance the zoomed image.
|
||||
scale (int): The upscaling factor if upscaling is True.
|
||||
add_bounding_box (bool): Whether to add a bounding box around the zoomed area in the original image.
|
||||
|
||||
|
||||
Returns:
|
||||
bytes: The zoomed image in bytes.
|
||||
bytes: The original image with bounding box in bytes (if add_bounding_box is True). Otherwise, returns original bytes.
|
||||
"""
|
||||
# Find zoom dimensions
|
||||
img = Image.open(BytesIO(image_bytes)).convert("RGB")
|
||||
cx, cy = x - width // 2, y - height // 2 # Center coordinates
|
||||
cx, cy = x - width // 2, y - height // 2 # Center coordinates
|
||||
W, H = img.size
|
||||
left = min(max(cx, 0), W - width)
|
||||
top = min(max(cy, 0), H - height)
|
||||
@@ -116,66 +147,127 @@ class BehaviorNarrator:
|
||||
draw_img = img.copy()
|
||||
draw = ImageDraw.Draw(draw_img)
|
||||
draw.rectangle([left, top, right, bottom], outline="red", width=3)
|
||||
original_with_box_bytes = compress_image(image=draw_img) # Compress to reduce size
|
||||
original_with_box_bytes = compress_image(
|
||||
image=draw_img
|
||||
) # Compress to reduce size
|
||||
else:
|
||||
original_with_box_bytes = image_bytes
|
||||
if upscaling:
|
||||
# Upscale and enhance zoomed image
|
||||
zoomed_img = cv2.cvtColor(np.array(zoomed_img), cv2.COLOR_RGB2BGR) # PIL -> OpenCV
|
||||
zoomed_img = cv2.resize(zoomed_img, None, fx=scale, fy=scale, interpolation=cv2.INTER_LANCZOS4)
|
||||
zoomed_img = cv2.fastNlMeansDenoisingColored(zoomed_img, None, 5, 5, 7, 21) # light denoise (helps with JPEG speckle)
|
||||
zoomed_img = Image.fromarray(cv2.cvtColor(zoomed_img, cv2.COLOR_BGR2RGB)) # OpenCV -> PIL
|
||||
zoomed_img_bytes = compress_image(image=zoomed_img) # Compress to reduce size
|
||||
zoomed_img = cv2.cvtColor(
|
||||
np.array(zoomed_img), cv2.COLOR_RGB2BGR
|
||||
) # PIL -> OpenCV
|
||||
zoomed_img = cv2.resize(
|
||||
zoomed_img, None, fx=scale, fy=scale, interpolation=cv2.INTER_LANCZOS4
|
||||
)
|
||||
zoomed_img = cv2.fastNlMeansDenoisingColored(
|
||||
zoomed_img, None, 5, 5, 7, 21
|
||||
) # light denoise (helps with JPEG speckle)
|
||||
zoomed_img = Image.fromarray(
|
||||
cv2.cvtColor(zoomed_img, cv2.COLOR_BGR2RGB)
|
||||
) # OpenCV -> PIL
|
||||
zoomed_img_bytes = compress_image(image=zoomed_img) # Compress to reduce size
|
||||
return zoomed_img_bytes, original_with_box_bytes
|
||||
|
||||
def judge(self, screenshot_num: int, before_img_bytes: bytes, after_img_bytes: bytes, pyautogui_action: str) -> Dict[str, str]:
|
||||
|
||||
def judge(
|
||||
self,
|
||||
screenshot_num: int,
|
||||
before_img_bytes: bytes,
|
||||
after_img_bytes: bytes,
|
||||
pyautogui_action: str,
|
||||
) -> Dict[str, str]:
|
||||
if pyautogui_action == "DONE":
|
||||
return {
|
||||
"fact_thoughts": "The agent has indicated that it is done with the task.",
|
||||
"fact_answer": "The agent has indicated that it is done with the task."
|
||||
"fact_answer": "The agent has indicated that it is done with the task.",
|
||||
}
|
||||
elif pyautogui_action == "FAIL":
|
||||
return {
|
||||
"fact_thoughts": "The agent has indicated that it is impossible to proceed further with the task.",
|
||||
"fact_answer": "The agent has indicated that it is impossible to proceed further with the task."
|
||||
"fact_answer": "The agent has indicated that it is impossible to proceed further with the task.",
|
||||
}
|
||||
# Prepare ANNOTATED BEFORE image
|
||||
mouse_actions = BehaviorNarrator.extract_mouse_action(pyautogui_action)
|
||||
before_img = Image.open(BytesIO(before_img_bytes))
|
||||
BehaviorNarrator.mark_action(mouse_actions, before_img)
|
||||
out_buffer = BytesIO()
|
||||
before_img.save(out_buffer, format='PNG')
|
||||
before_img.save(out_buffer, format="PNG")
|
||||
marked_before_img_bytes = out_buffer.getvalue()
|
||||
marked_before_img_message = {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64.b64encode(marked_before_img_bytes).decode('utf-8')}", "detail": "high"}}
|
||||
marked_before_img_message = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{base64.b64encode(marked_before_img_bytes).decode('utf-8')}",
|
||||
"detail": "high",
|
||||
},
|
||||
}
|
||||
if mouse_actions:
|
||||
coords = mouse_actions[-1].split('(')[1].strip(')').split(', ')
|
||||
coords = mouse_actions[-1].split("(")[1].strip(")").split(", ")
|
||||
x, y = int(coords[0]), int(coords[1])
|
||||
zoomed_after_img_bytes, marked_after_img_bytes = BehaviorNarrator.get_zoomed_image(image_bytes=after_img_bytes, x=x, y=y, width=300, height=300, scale=4, upscaling=True, add_bounding_box=True)
|
||||
after_img_message = {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64.b64encode(marked_after_img_bytes).decode('utf-8')}", "detail": "high"}}
|
||||
zoomed_after_img_message = {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64.b64encode(zoomed_after_img_bytes).decode('utf-8')}", "detail": "high"}}
|
||||
zoomed_after_img_bytes, marked_after_img_bytes = (
|
||||
BehaviorNarrator.get_zoomed_image(
|
||||
image_bytes=after_img_bytes,
|
||||
x=x,
|
||||
y=y,
|
||||
width=300,
|
||||
height=300,
|
||||
scale=4,
|
||||
upscaling=True,
|
||||
add_bounding_box=True,
|
||||
)
|
||||
)
|
||||
after_img_message = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{base64.b64encode(marked_after_img_bytes).decode('utf-8')}",
|
||||
"detail": "high",
|
||||
},
|
||||
}
|
||||
zoomed_after_img_message = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{base64.b64encode(zoomed_after_img_bytes).decode('utf-8')}",
|
||||
"detail": "high",
|
||||
},
|
||||
}
|
||||
else:
|
||||
after_img_message = {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64.b64encode(after_img_bytes).decode('utf-8')}", "detail": "high"}}
|
||||
after_img_message = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{base64.b64encode(after_img_bytes).decode('utf-8')}",
|
||||
"detail": "high",
|
||||
},
|
||||
}
|
||||
zoomed_after_img_message = None
|
||||
|
||||
fact_message = [{"role": "system", "content": PROCEDURAL_MEMORY.BEHAVIOR_NARRATOR_SYSTEM_PROMPT}]
|
||||
fact_message = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": PROCEDURAL_MEMORY.BEHAVIOR_NARRATOR_SYSTEM_PROMPT,
|
||||
}
|
||||
]
|
||||
fact_message_content = [
|
||||
{"type": "text", "text": "BEFORE:"},
|
||||
marked_before_img_message,
|
||||
{"type": "text", "text": f"Agent Action: {pyautogui_action}"},
|
||||
{"type": "text", "text": "AFTER:"},
|
||||
after_img_message
|
||||
after_img_message,
|
||||
]
|
||||
if zoomed_after_img_message:
|
||||
fact_message_content += [
|
||||
{"type": "text", "text": "ZOOMED AFTER:"},
|
||||
zoomed_after_img_message
|
||||
zoomed_after_img_message,
|
||||
]
|
||||
fact_message += [{"role": "user", "content": fact_message_content}]
|
||||
fact_response = call_llm_formatted(self.judge_agent, [THOUGHTS_ANSWER_TAG_FORMATTER], messages=fact_message, temperature=0.0)
|
||||
fact_response = call_llm_formatted(
|
||||
self.judge_agent,
|
||||
[THOUGHTS_ANSWER_TAG_FORMATTER],
|
||||
messages=fact_message,
|
||||
temperature=0.0,
|
||||
)
|
||||
fact_answer, fact_thoughts = split_thinking_response(fact_response)
|
||||
|
||||
result = {
|
||||
"fact_thoughts": fact_thoughts,
|
||||
"fact_answer": f"Fact Caption from Screenshot {screenshot_num}: {fact_answer}"
|
||||
"fact_answer": f"Fact Caption from Screenshot {screenshot_num}: {fact_answer}",
|
||||
}
|
||||
return result
|
||||
|
||||
@@ -13,42 +13,46 @@ def get_final_screenshot_file(task_dir: str) -> str:
|
||||
for filename in os.listdir(task_dir):
|
||||
if filename.startswith("step_") and filename.endswith(".png"):
|
||||
screenshot_files.append(filename)
|
||||
|
||||
|
||||
if not screenshot_files:
|
||||
return "step_0.png" # fallback
|
||||
|
||||
|
||||
# Sort by step number and get the last one
|
||||
def extract_step_num(filename):
|
||||
try:
|
||||
return int(filename.split("_")[1].split(".")[0])
|
||||
except:
|
||||
return 0
|
||||
|
||||
|
||||
screenshot_files.sort(key=extract_step_num)
|
||||
return screenshot_files[-1]
|
||||
|
||||
|
||||
def image_to_openai_message_format(image_path: str, caption: str = "") -> Optional[dict]:
|
||||
def image_to_openai_message_format(
|
||||
image_path: str, caption: str = ""
|
||||
) -> Optional[dict]:
|
||||
"""Convert an image file to OpenAI message format."""
|
||||
if not os.path.exists(image_path):
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
with open(image_path, "rb") as image_file:
|
||||
image_data = base64.b64encode(image_file.read()).decode('utf-8')
|
||||
|
||||
image_data = base64.b64encode(image_file.read()).decode("utf-8")
|
||||
|
||||
content = []
|
||||
if caption:
|
||||
content.append({"type": "text", "text": caption})
|
||||
|
||||
content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{image_data}",
|
||||
"detail": "high"
|
||||
|
||||
content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{image_data}",
|
||||
"detail": "high",
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
)
|
||||
|
||||
return {"role": "user", "content": content}
|
||||
except Exception as e:
|
||||
print(f"Error loading image {image_path}: {e}")
|
||||
@@ -58,36 +62,81 @@ def image_to_openai_message_format(image_path: str, caption: str = "") -> Option
|
||||
class ComparativeJudge:
|
||||
def __init__(self, engine_params):
|
||||
self.judge_agent = LMMAgent(engine_params=engine_params)
|
||||
|
||||
def judge(self, task_description: str, task: str, result_dirs: List[str], all_fact_captions: List[List[str]]) -> Tuple[str, str, Optional[str]]:
|
||||
|
||||
def judge(
|
||||
self,
|
||||
task_description: str,
|
||||
task: str,
|
||||
result_dirs: List[str],
|
||||
all_fact_captions: List[List[str]],
|
||||
) -> Tuple[str, str, Optional[str]]:
|
||||
"""
|
||||
Fact captions + initial/final screenshots judging.
|
||||
Pipeline: use provided fact captions → include initial/final screenshots → judge.
|
||||
"""
|
||||
num_trajectories = len(result_dirs)
|
||||
system_prompt = PROCEDURAL_MEMORY.VLM_EVALUATOR_PROMPT_COMPARATIVE_BASELINE
|
||||
system_prompt = system_prompt.replace("<TASK_DESCRIPTION_INPUT>", task_description)
|
||||
system_prompt = system_prompt.replace("<NUMBER OF TRAJECTORIES>", str(num_trajectories))
|
||||
|
||||
system_prompt = system_prompt.replace(
|
||||
"<TASK_DESCRIPTION_INPUT>", task_description
|
||||
)
|
||||
system_prompt = system_prompt.replace(
|
||||
"<NUMBER OF TRAJECTORIES>", str(num_trajectories)
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
|
||||
for i, (result_dir, fact_captions) in enumerate(zip(result_dirs, all_fact_captions)):
|
||||
|
||||
for i, (result_dir, fact_captions) in enumerate(
|
||||
zip(result_dirs, all_fact_captions)
|
||||
):
|
||||
task_dir = os.path.join(result_dir, task.split("/")[0], task.split("/")[1])
|
||||
result_initial_screenshot = os.path.join(task_dir, "step_0.png")
|
||||
result_final_screenshot = os.path.join(task_dir, get_final_screenshot_file(task_dir))
|
||||
initial_screenshot_message = image_to_openai_message_format(result_initial_screenshot, caption=f"Initial screenshot of result{i+1}")
|
||||
final_screenshot_message = image_to_openai_message_format(result_final_screenshot, caption=f"Final screenshot of result{i+1}")
|
||||
if initial_screenshot_message is not None and final_screenshot_message is not None:
|
||||
result_final_screenshot = os.path.join(
|
||||
task_dir, get_final_screenshot_file(task_dir)
|
||||
)
|
||||
initial_screenshot_message = image_to_openai_message_format(
|
||||
result_initial_screenshot, caption=f"Initial screenshot of result{i+1}"
|
||||
)
|
||||
final_screenshot_message = image_to_openai_message_format(
|
||||
result_final_screenshot, caption=f"Final screenshot of result{i+1}"
|
||||
)
|
||||
if (
|
||||
initial_screenshot_message is not None
|
||||
and final_screenshot_message is not None
|
||||
):
|
||||
messages.append(initial_screenshot_message)
|
||||
messages.append(final_screenshot_message)
|
||||
if fact_captions:
|
||||
messages.append({"role": "user", "content": [{"type": "text", "text": f"Fact captions for Trajectory {i+1}:"}] + [{"type": "text", "text": caption} for caption in fact_captions]})
|
||||
|
||||
messages.append({"role": "user", "content": [{"type": "text", "text": f"Please evaluate the {num_trajectories} trajectories based on the criteria provided in the system prompt."}]})
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Fact captions for Trajectory {i+1}:",
|
||||
}
|
||||
]
|
||||
+ [
|
||||
{"type": "text", "text": caption}
|
||||
for caption in fact_captions
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Please evaluate the {num_trajectories} trajectories based on the criteria provided in the system prompt.",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
response = call_llm_formatted(self.judge_agent, [], messages=messages)
|
||||
answer, thoughts = split_thinking_response(response)
|
||||
|
||||
|
||||
try:
|
||||
judge_choice = int(answer)
|
||||
if 1 <= judge_choice <= num_trajectories:
|
||||
@@ -96,5 +145,5 @@ class ComparativeJudge:
|
||||
selected_trajectory = None
|
||||
except ValueError:
|
||||
selected_trajectory = None
|
||||
|
||||
|
||||
return answer, thoughts, selected_trajectory
|
||||
|
||||
@@ -19,6 +19,7 @@ current_platform = platform.system().lower()
|
||||
# Global flag to track pause state for debugging
|
||||
paused = False
|
||||
|
||||
|
||||
def get_char():
|
||||
"""Get a single character from stdin without pressing Enter"""
|
||||
try:
|
||||
@@ -26,6 +27,7 @@ def get_char():
|
||||
if platform.system() in ["Darwin", "Linux"]:
|
||||
import termios
|
||||
import tty
|
||||
|
||||
fd = sys.stdin.fileno()
|
||||
old_settings = termios.tcgetattr(fd)
|
||||
try:
|
||||
@@ -37,14 +39,16 @@ def get_char():
|
||||
else:
|
||||
# Windows fallback
|
||||
import msvcrt
|
||||
return msvcrt.getch().decode('utf-8', errors='ignore')
|
||||
|
||||
return msvcrt.getch().decode("utf-8", errors="ignore")
|
||||
except:
|
||||
return input() # Fallback for non-terminal environments
|
||||
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
"""Handle Ctrl+C signal for debugging during agent execution"""
|
||||
global paused
|
||||
|
||||
|
||||
if not paused:
|
||||
print("\n\n🔸 Agent-S Workflow Paused 🔸")
|
||||
print("=" * 50)
|
||||
@@ -52,14 +56,14 @@ def signal_handler(signum, frame):
|
||||
print(" • Press Ctrl+C again to quit")
|
||||
print(" • Press Esc to resume workflow")
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
paused = True
|
||||
|
||||
|
||||
while paused:
|
||||
try:
|
||||
print("\n[PAUSED] Waiting for input... ", end="", flush=True)
|
||||
char = get_char()
|
||||
|
||||
|
||||
if ord(char) == 3: # Ctrl+C
|
||||
print("\n\n🛑 Exiting Agent-S...")
|
||||
sys.exit(0)
|
||||
@@ -69,7 +73,7 @@ def signal_handler(signum, frame):
|
||||
break
|
||||
else:
|
||||
print(f"\n Unknown command: '{char}' (ord: {ord(char)})")
|
||||
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n🛑 Exiting Agent-S...")
|
||||
sys.exit(0)
|
||||
@@ -78,6 +82,7 @@ def signal_handler(signum, frame):
|
||||
print("\n\n🛑 Exiting Agent-S...")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
# Set up signal handler for Ctrl+C
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
@@ -173,7 +178,7 @@ def run_agent(agent, instruction: str, scaled_width: int, scaled_height: int):
|
||||
time.sleep(0.1)
|
||||
|
||||
print(f"\n🔄 Step {step + 1}/15: Getting next action from agent...")
|
||||
|
||||
|
||||
# Get next action code from the agent
|
||||
info, code = agent.predict(instruction=instruction, observation=obs)
|
||||
|
||||
@@ -249,7 +254,7 @@ def main():
|
||||
"--model_temperature",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Temperature to fix the generation model at (e.g. o3 can only be run with 1.0)"
|
||||
help="Temperature to fix the generation model at (e.g. o3 can only be run with 1.0)",
|
||||
)
|
||||
|
||||
# Grounding model config: Self-hosted endpoint based (required)
|
||||
@@ -318,7 +323,7 @@ def main():
|
||||
"model": args.model,
|
||||
"base_url": args.model_url,
|
||||
"api_key": args.model_api_key,
|
||||
"temperature": getattr(args, 'model_temperature', None),
|
||||
"temperature": getattr(args, "model_temperature", None),
|
||||
}
|
||||
|
||||
# Load the grounding engine from a custom endpoint
|
||||
|
||||
@@ -18,7 +18,14 @@ class LMMEngine:
|
||||
|
||||
class LMMEngineOpenAI(LMMEngine):
|
||||
def __init__(
|
||||
self, base_url=None, api_key=None, model=None, rate_limit=-1, temperature=None, organization=None, **kwargs
|
||||
self,
|
||||
base_url=None,
|
||||
api_key=None,
|
||||
model=None,
|
||||
rate_limit=-1,
|
||||
temperature=None,
|
||||
organization=None,
|
||||
**kwargs,
|
||||
):
|
||||
assert model is not None, "model must be provided"
|
||||
self.model = model
|
||||
@@ -27,7 +34,7 @@ class LMMEngineOpenAI(LMMEngine):
|
||||
self.organization = organization
|
||||
self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
|
||||
self.llm_client = None
|
||||
self.temperature = temperature # Can force temperature to be the same (in the case of o3 requiring temperature to be 1)
|
||||
self.temperature = temperature # Can force temperature to be the same (in the case of o3 requiring temperature to be 1)
|
||||
|
||||
@backoff.on_exception(
|
||||
backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
|
||||
@@ -43,13 +50,17 @@ class LMMEngineOpenAI(LMMEngine):
|
||||
if not self.base_url:
|
||||
self.llm_client = OpenAI(api_key=api_key, organization=organization)
|
||||
else:
|
||||
self.llm_client = OpenAI(base_url=self.base_url, api_key=api_key, organization=organization)
|
||||
self.llm_client = OpenAI(
|
||||
base_url=self.base_url, api_key=api_key, organization=organization
|
||||
)
|
||||
return (
|
||||
self.llm_client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
# max_completion_tokens=max_new_tokens if max_new_tokens else 4096,
|
||||
temperature=temperature if self.temperature is None else self.temperature,
|
||||
temperature=(
|
||||
temperature if self.temperature is None else self.temperature
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
.choices[0]
|
||||
@@ -59,7 +70,13 @@ class LMMEngineOpenAI(LMMEngine):
|
||||
|
||||
class LMMEngineAnthropic(LMMEngine):
|
||||
def __init__(
|
||||
self, base_url=None, api_key=None, model=None, thinking=False, temperature=None, **kwargs
|
||||
self,
|
||||
base_url=None,
|
||||
api_key=None,
|
||||
model=None,
|
||||
thinking=False,
|
||||
temperature=None,
|
||||
**kwargs,
|
||||
):
|
||||
assert model is not None, "model must be provided"
|
||||
self.model = model
|
||||
@@ -137,7 +154,13 @@ class LMMEngineAnthropic(LMMEngine):
|
||||
|
||||
class LMMEngineGemini(LMMEngine):
|
||||
def __init__(
|
||||
self, base_url=None, api_key=None, model=None, rate_limit=-1, temperature=None, **kwargs
|
||||
self,
|
||||
base_url=None,
|
||||
api_key=None,
|
||||
model=None,
|
||||
rate_limit=-1,
|
||||
temperature=None,
|
||||
**kwargs,
|
||||
):
|
||||
assert model is not None, "model must be provided"
|
||||
self.model = model
|
||||
@@ -180,7 +203,13 @@ class LMMEngineGemini(LMMEngine):
|
||||
|
||||
class LMMEngineOpenRouter(LMMEngine):
|
||||
def __init__(
|
||||
self, base_url=None, api_key=None, model=None, rate_limit=-1, temperature=None, **kwargs
|
||||
self,
|
||||
base_url=None,
|
||||
api_key=None,
|
||||
model=None,
|
||||
rate_limit=-1,
|
||||
temperature=None,
|
||||
**kwargs,
|
||||
):
|
||||
assert model is not None, "model must be provided"
|
||||
self.model = model
|
||||
@@ -284,7 +313,13 @@ class LMMEngineAzureOpenAI(LMMEngine):
|
||||
|
||||
class LMMEnginevLLM(LMMEngine):
|
||||
def __init__(
|
||||
self, base_url=None, api_key=None, model=None, rate_limit=-1, temperature=None, **kwargs
|
||||
self,
|
||||
base_url=None,
|
||||
api_key=None,
|
||||
model=None,
|
||||
rate_limit=-1,
|
||||
temperature=None,
|
||||
**kwargs,
|
||||
):
|
||||
assert model is not None, "model must be provided"
|
||||
self.model = model
|
||||
@@ -304,7 +339,7 @@ class LMMEnginevLLM(LMMEngine):
|
||||
top_p=0.8,
|
||||
repetition_penalty=1.05,
|
||||
max_new_tokens=512,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
api_key = self.api_key or os.getenv("vLLM_API_KEY")
|
||||
if api_key is None:
|
||||
@@ -368,7 +403,9 @@ class LMMEngineHuggingFace(LMMEngine):
|
||||
|
||||
|
||||
class LMMEngineParasail(LMMEngine):
|
||||
def __init__(self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs):
|
||||
def __init__(
|
||||
self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
|
||||
):
|
||||
assert model is not None, "Parasail model id must be provided"
|
||||
self.base_url = base_url
|
||||
self.model = model
|
||||
@@ -391,15 +428,18 @@ class LMMEngineParasail(LMMEngine):
|
||||
"Parasail endpoint must be provided as base_url parameter or as an environment variable named PARASAIL_ENDPOINT_URL"
|
||||
)
|
||||
if not self.llm_client:
|
||||
self.llm_client = OpenAI(base_url=base_url if base_url else "https://api.parasail.io/v1", api_key=api_key)
|
||||
self.llm_client = OpenAI(
|
||||
base_url=base_url if base_url else "https://api.parasail.io/v1",
|
||||
api_key=api_key,
|
||||
)
|
||||
return (
|
||||
self.llm_client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
max_tokens=max_new_tokens if max_new_tokens else 4096,
|
||||
temperature=temperature,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
.choices[0].
|
||||
message.content
|
||||
.choices[0]
|
||||
.message.content
|
||||
)
|
||||
|
||||
@@ -128,7 +128,7 @@ class LMMAgent:
|
||||
LMMEngineHuggingFace,
|
||||
LMMEngineGemini,
|
||||
LMMEngineOpenRouter,
|
||||
LMMEngineParasail
|
||||
LMMEngineParasail,
|
||||
),
|
||||
):
|
||||
# infer role from previous message
|
||||
|
||||
@@ -4,12 +4,13 @@ import textwrap
|
||||
|
||||
class PROCEDURAL_MEMORY:
|
||||
|
||||
FORMATTING_FEEDBACK_PROMPT = textwrap.dedent("""
|
||||
FORMATTING_FEEDBACK_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
Your previous response was not formatted correctly. You must respond again to replace your previous response. Do not make reference to this message while fixing the response. Please address the following issues below to improve the previous response:
|
||||
FORMATTING_FEEDBACK
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def construct_simple_worker_procedural_memory(agent_class, skipped_actions):
|
||||
procedural_memory = textwrap.dedent(
|
||||
@@ -153,7 +154,8 @@ class PROCEDURAL_MEMORY:
|
||||
"""
|
||||
)
|
||||
|
||||
CODE_AGENT_PROMPT = textwrap.dedent("""\
|
||||
CODE_AGENT_PROMPT = textwrap.dedent(
|
||||
"""\
|
||||
You are a code execution agent with a limited step budget to complete tasks.
|
||||
|
||||
# Core Guidelines:
|
||||
@@ -260,9 +262,11 @@ class PROCEDURAL_MEMORY:
|
||||
- After in-place modifications, close/reopen files via GUI to show changes
|
||||
|
||||
Focus on progress within your step budget.
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
CODE_SUMMARY_AGENT_PROMPT = textwrap.dedent("""\
|
||||
CODE_SUMMARY_AGENT_PROMPT = textwrap.dedent(
|
||||
"""\
|
||||
You are a code execution summarizer. Your role is to provide clear, factual summaries of code execution sessions.
|
||||
|
||||
Key responsibilities:
|
||||
@@ -282,9 +286,11 @@ class PROCEDURAL_MEMORY:
|
||||
- This helps the GUI agent understand what to expect and verify your work properly
|
||||
|
||||
Always maintain a factual, non-judgmental tone.
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
BEHAVIOR_NARRATOR_SYSTEM_PROMPT = textwrap.dedent("""\
|
||||
BEHAVIOR_NARRATOR_SYSTEM_PROMPT = textwrap.dedent(
|
||||
"""\
|
||||
You are an expert in computer usage responsible for analyzing what happened after a computer action is taken.
|
||||
|
||||
**Reasoning Guidelines:**
|
||||
@@ -311,9 +317,11 @@ class PROCEDURAL_MEMORY:
|
||||
<answer>
|
||||
[An unordered list of the relevant changes induced by the action]
|
||||
</answer>
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
VLM_EVALUATOR_PROMPT_COMPARATIVE_BASELINE = textwrap.dedent("""\
|
||||
VLM_EVALUATOR_PROMPT_COMPARATIVE_BASELINE = textwrap.dedent(
|
||||
"""\
|
||||
You are a meticulous and impartial evaluator, tasked with judging <NUMBER OF TRAJECTORIES> sequences of OS desktop actions to determine which one better completes the user's request. Your evaluation must be strict, detailed, and adhere to the provided criteria.
|
||||
|
||||
**User Request:**
|
||||
@@ -364,4 +372,5 @@ class PROCEDURAL_MEMORY:
|
||||
<answer>
|
||||
[The index of the better sequence, a single integer from 1 to <NUMBER OF TRAJECTORIES>]
|
||||
</answer>
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -8,8 +8,10 @@ from typing import Tuple, Dict
|
||||
from gui_agents.s3.memory.procedural_memory import PROCEDURAL_MEMORY
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
|
||||
def create_pyautogui_code(agent, code: str, obs: Dict) -> str:
|
||||
"""
|
||||
Attempts to evaluate the code into a pyautogui code snippet with grounded actions using the observation screenshot.
|
||||
@@ -18,17 +20,18 @@ def create_pyautogui_code(agent, code: str, obs: Dict) -> str:
|
||||
agent (ACI): The grounding agent to use for evaluation.
|
||||
code (str): The code string to evaluate.
|
||||
obs (Dict): The current observation containing the screenshot.
|
||||
|
||||
|
||||
Returns:
|
||||
exec_code (str): The pyautogui code to execute the grounded action.
|
||||
|
||||
|
||||
Raises:
|
||||
Exception: If there is an error in evaluating the code.
|
||||
"""
|
||||
agent.assign_screenshot(obs) # Necessary for grounding
|
||||
agent.assign_screenshot(obs) # Necessary for grounding
|
||||
exec_code = eval(code)
|
||||
return exec_code
|
||||
|
||||
|
||||
def call_llm_safe(
|
||||
agent, temperature: float = 0.0, use_thinking: bool = False, **kwargs
|
||||
) -> str:
|
||||
@@ -52,78 +55,93 @@ def call_llm_safe(
|
||||
time.sleep(1.0)
|
||||
return response if response is not None else ""
|
||||
|
||||
|
||||
def call_llm_formatted(generator, format_checkers, **kwargs):
|
||||
"""
|
||||
Calls the generator agent's LLM and ensures correct formatting.
|
||||
"""
|
||||
Calls the generator agent's LLM and ensures correct formatting.
|
||||
|
||||
Args:
|
||||
generator (ACI): The generator agent to call.
|
||||
obs (Dict): The current observation containing the screenshot.
|
||||
format_checkers (Callable): Functions that take the response and return a tuple of (success, feedback).
|
||||
**kwargs: Additional keyword arguments for the LLM call.
|
||||
|
||||
Returns:
|
||||
response (str): The formatted response from the generator agent.
|
||||
"""
|
||||
max_retries = 3 # Set the maximum number of retries
|
||||
attempt = 0
|
||||
response = ""
|
||||
messages = generator.messages.copy() # Copy messages to avoid modifying the original
|
||||
while attempt < max_retries:
|
||||
response = call_llm_safe(generator, messages=messages, **kwargs)
|
||||
Args:
|
||||
generator (ACI): The generator agent to call.
|
||||
obs (Dict): The current observation containing the screenshot.
|
||||
format_checkers (Callable): Functions that take the response and return a tuple of (success, feedback).
|
||||
**kwargs: Additional keyword arguments for the LLM call.
|
||||
|
||||
# Prepare feedback messages for incorrect formatting
|
||||
feedback_msgs = []
|
||||
for format_checker in format_checkers:
|
||||
success, feedback = format_checker(response)
|
||||
if not success:
|
||||
feedback_msgs.append(feedback)
|
||||
if not feedback_msgs:
|
||||
# logger.info(f"Response formatted correctly on attempt {attempt} for {generator.engine.model}")
|
||||
break
|
||||
logger.error(f"Response formatting error on attempt {attempt} for {generator.engine.model}. Response: {response} {', '.join(feedback_msgs)}")
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": response}],
|
||||
}
|
||||
Returns:
|
||||
response (str): The formatted response from the generator agent.
|
||||
"""
|
||||
max_retries = 3 # Set the maximum number of retries
|
||||
attempt = 0
|
||||
response = ""
|
||||
messages = (
|
||||
generator.messages.copy()
|
||||
) # Copy messages to avoid modifying the original
|
||||
while attempt < max_retries:
|
||||
response = call_llm_safe(generator, messages=messages, **kwargs)
|
||||
|
||||
# Prepare feedback messages for incorrect formatting
|
||||
feedback_msgs = []
|
||||
for format_checker in format_checkers:
|
||||
success, feedback = format_checker(response)
|
||||
if not success:
|
||||
feedback_msgs.append(feedback)
|
||||
if not feedback_msgs:
|
||||
# logger.info(f"Response formatted correctly on attempt {attempt} for {generator.engine.model}")
|
||||
break
|
||||
logger.error(
|
||||
f"Response formatting error on attempt {attempt} for {generator.engine.model}. Response: {response} {', '.join(feedback_msgs)}"
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": response}],
|
||||
}
|
||||
)
|
||||
logger.info(f"Bad response: {response}")
|
||||
delimiter = "\n- "
|
||||
formatting_feedback = f"- {delimiter.join(feedback_msgs)}"
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": PROCEDURAL_MEMORY.FORMATTING_FEEDBACK_PROMPT.replace(
|
||||
"FORMATTING_FEEDBACK", formatting_feedback
|
||||
),
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
logger.info("Feedback:\n%s", formatting_feedback)
|
||||
|
||||
attempt += 1
|
||||
if attempt == max_retries:
|
||||
logger.error(
|
||||
"Max retries reached when formatting response. Handling failure."
|
||||
)
|
||||
logger.info(f"Bad response: {response}")
|
||||
delimiter = "\n- "
|
||||
formatting_feedback = f"- {delimiter.join(feedback_msgs)}"
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": PROCEDURAL_MEMORY.FORMATTING_FEEDBACK_PROMPT.replace("FORMATTING_FEEDBACK", formatting_feedback)}],
|
||||
}
|
||||
)
|
||||
logger.info("Feedback:\n%s", formatting_feedback)
|
||||
|
||||
attempt += 1
|
||||
if attempt == max_retries:
|
||||
logger.error("Max retries reached when formatting response. Handling failure.")
|
||||
time.sleep(1.0)
|
||||
return response
|
||||
time.sleep(1.0)
|
||||
return response
|
||||
|
||||
|
||||
def split_thinking_response(full_response: str) -> Tuple[str, str]:
|
||||
try:
|
||||
# Extract thoughts section
|
||||
thoughts = full_response.split("<thoughts>")[-1].split("</thoughts>")[0].strip()
|
||||
|
||||
|
||||
# Extract answer section
|
||||
answer = full_response.split("<answer>")[-1].split("</answer>")[0].strip()
|
||||
|
||||
|
||||
return answer, thoughts
|
||||
except Exception as e:
|
||||
return full_response, ""
|
||||
|
||||
|
||||
def parse_code_from_string(input_string):
|
||||
""" Parses a string to extract each line of code enclosed in triple backticks (```)
|
||||
|
||||
"""Parses a string to extract each line of code enclosed in triple backticks (```)
|
||||
|
||||
Args:
|
||||
input_string (str): The input string containing code snippets.
|
||||
|
||||
|
||||
Returns:
|
||||
str: The last code snippet found in the input string, or an empty string if no code is found.
|
||||
"""
|
||||
@@ -138,35 +156,39 @@ def parse_code_from_string(input_string):
|
||||
if len(matches) == 0:
|
||||
# return []
|
||||
return ""
|
||||
relevant_code = matches[-1] # We only care about the last match given it is the grounded action
|
||||
relevant_code = matches[
|
||||
-1
|
||||
] # We only care about the last match given it is the grounded action
|
||||
return relevant_code
|
||||
|
||||
|
||||
def extract_agent_functions(code):
|
||||
"""Extracts all agent function calls from the given code.
|
||||
|
||||
|
||||
Args:
|
||||
code (str): The code string to search for agent function calls.
|
||||
|
||||
|
||||
Returns:
|
||||
list: A list of all agent function calls found in the code.
|
||||
"""
|
||||
pattern = r'(agent\.\w+\(\s*.*\))' # Matches
|
||||
pattern = r"(agent\.\w+\(\s*.*\))" # Matches
|
||||
return re.findall(pattern, code)
|
||||
|
||||
|
||||
def compress_image(image_bytes: bytes = None, image: Image = None) -> bytes:
|
||||
"""Compresses an image represented as bytes.
|
||||
|
||||
Compression involves resizing image into half its original size and saving to webp format.
|
||||
|
||||
|
||||
Args:
|
||||
image_bytes (bytes): The image data to compress.
|
||||
|
||||
|
||||
Returns:
|
||||
bytes: The compressed image data.
|
||||
"""
|
||||
if not image:
|
||||
image = Image.open(BytesIO(image_bytes))
|
||||
output = BytesIO()
|
||||
image.save(output, format='WEBP')
|
||||
image.save(output, format="WEBP")
|
||||
compressed_image_bytes = output.getvalue()
|
||||
return compressed_image_bytes
|
||||
return compressed_image_bytes
|
||||
|
||||
@@ -1,37 +1,58 @@
|
||||
"""This file contains various formatting checks used to reprompt an agent for correctly formatted responses."""
|
||||
|
||||
from gui_agents.s3.utils.common_utils import (
|
||||
extract_agent_functions,
|
||||
parse_code_from_string,
|
||||
create_pyautogui_code,
|
||||
split_thinking_response
|
||||
split_thinking_response,
|
||||
)
|
||||
|
||||
single_action_check = lambda response: len(extract_agent_functions(parse_code_from_string(response))) == 1
|
||||
single_action_error_msg = "Incorrect code: There must be a single agent action in the code response."
|
||||
SINGLE_ACTION_FORMATTER = lambda response: (
|
||||
single_action_check(response), single_action_error_msg
|
||||
single_action_check = (
|
||||
lambda response: len(extract_agent_functions(parse_code_from_string(response))) == 1
|
||||
)
|
||||
single_action_error_msg = (
|
||||
"Incorrect code: There must be a single agent action in the code response."
|
||||
)
|
||||
SINGLE_ACTION_FORMATTER = lambda response: (
|
||||
single_action_check(response),
|
||||
single_action_error_msg,
|
||||
)
|
||||
|
||||
|
||||
def _attempt_code_creation(agent, code, obs):
|
||||
""" Attempts to create a pyautogui code snippet from the response code """
|
||||
"""Attempts to create a pyautogui code snippet from the response code"""
|
||||
try:
|
||||
return create_pyautogui_code(agent, code, obs)
|
||||
except Exception as e:
|
||||
return None
|
||||
code_valid_check = lambda agent, obs, response: _attempt_code_creation(agent, parse_code_from_string(response), obs) is not None
|
||||
|
||||
|
||||
code_valid_check = (
|
||||
lambda agent, obs, response: _attempt_code_creation(
|
||||
agent, parse_code_from_string(response), obs
|
||||
)
|
||||
is not None
|
||||
)
|
||||
code_valid_error_msg = "Incorrect code: The agent action must be a valid function and use valid parameters from the docstring list."
|
||||
CODE_VALID_FORMATTER = lambda agent, obs, response: (
|
||||
code_valid_check(agent, obs, response), code_valid_error_msg
|
||||
code_valid_check(agent, obs, response),
|
||||
code_valid_error_msg,
|
||||
)
|
||||
|
||||
thoughts_answer_tag_check = lambda response: split_thinking_response(response)[1] != ""
|
||||
thoughts_answer_tag_error_msg = "Incorrect response: The response must contain both <thoughts>...</thoughts> and <answer>...</answer> tags."
|
||||
THOUGHTS_ANSWER_TAG_FORMATTER = lambda response: (
|
||||
thoughts_answer_tag_check(response), thoughts_answer_tag_error_msg
|
||||
thoughts_answer_tag_check(response),
|
||||
thoughts_answer_tag_error_msg,
|
||||
)
|
||||
|
||||
integer_answer_check = lambda response: split_thinking_response(response)[0].strip().isdigit()
|
||||
integer_answer_error_msg = "Incorrect response: The <answer>...</answer> tag must contain a single integer."
|
||||
integer_answer_check = (
|
||||
lambda response: split_thinking_response(response)[0].strip().isdigit()
|
||||
)
|
||||
integer_answer_error_msg = (
|
||||
"Incorrect response: The <answer>...</answer> tag must contain a single integer."
|
||||
)
|
||||
INTEGER_ANSWER_FORMATTER = lambda response: (
|
||||
integer_answer_check(response), integer_answer_error_msg
|
||||
)
|
||||
integer_answer_check(response),
|
||||
integer_answer_error_msg,
|
||||
)
|
||||
|
||||
@@ -24,18 +24,17 @@ def run_single_example(
|
||||
|
||||
with open(os.path.join(example_result_dir, f"step_0.png"), "wb") as _f:
|
||||
_f.write(obs["screenshot"])
|
||||
|
||||
with open(os.path.join(example_result_dir, "instruction.txt"), "w", encoding="utf-8") as f:
|
||||
|
||||
with open(
|
||||
os.path.join(example_result_dir, "instruction.txt"), "w", encoding="utf-8"
|
||||
) as f:
|
||||
f.write(instruction)
|
||||
|
||||
|
||||
done = False
|
||||
step_idx = 0
|
||||
env.controller.start_recording()
|
||||
while not done and step_idx < max_steps:
|
||||
response, actions = agent.predict(
|
||||
instruction,
|
||||
obs
|
||||
)
|
||||
response, actions = agent.predict(instruction, obs)
|
||||
for action in actions:
|
||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
logger.info("Step %d: %s", step_idx + 1, action)
|
||||
@@ -64,11 +63,7 @@ def run_single_example(
|
||||
}
|
||||
)
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
response
|
||||
)
|
||||
)
|
||||
f.write(json.dumps(response))
|
||||
f.write("\n")
|
||||
if done:
|
||||
logger.info("The episode is done.")
|
||||
@@ -90,4 +85,4 @@ def setup_logger(example, example_result_dir):
|
||||
runtime_logger.addHandler(
|
||||
logging.FileHandler(os.path.join(example_result_dir, "runtime.log"))
|
||||
)
|
||||
return runtime_logger
|
||||
return runtime_logger
|
||||
|
||||
@@ -24,19 +24,18 @@ def run_single_example(
|
||||
|
||||
with open(os.path.join(example_result_dir, f"step_0.png"), "wb") as _f:
|
||||
_f.write(obs["screenshot"])
|
||||
|
||||
with open(os.path.join(example_result_dir, "instruction.txt"), "w", encoding="utf-8") as f:
|
||||
|
||||
with open(
|
||||
os.path.join(example_result_dir, "instruction.txt"), "w", encoding="utf-8"
|
||||
) as f:
|
||||
f.write(instruction)
|
||||
|
||||
|
||||
done = False
|
||||
step_idx = 0
|
||||
env.controller.start_recording()
|
||||
while not done and step_idx < max_steps:
|
||||
time.sleep(0.5)
|
||||
response, actions = agent.predict(
|
||||
instruction,
|
||||
obs
|
||||
)
|
||||
response, actions = agent.predict(instruction, obs)
|
||||
for action in actions:
|
||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
logger.info("Step %d: %s", step_idx + 1, action)
|
||||
@@ -65,11 +64,7 @@ def run_single_example(
|
||||
}
|
||||
)
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
response
|
||||
)
|
||||
)
|
||||
f.write(json.dumps(response))
|
||||
f.write("\n")
|
||||
if done:
|
||||
logger.info("The episode is done.")
|
||||
@@ -91,4 +86,4 @@ def setup_logger(example, example_result_dir):
|
||||
runtime_logger.addHandler(
|
||||
logging.FileHandler(os.path.join(example_result_dir, "runtime.log"))
|
||||
)
|
||||
return runtime_logger
|
||||
return runtime_logger
|
||||
|
||||
+98
-38
@@ -15,6 +15,7 @@ import lib_run_single
|
||||
from desktop_env.desktop_env import DesktopEnv
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@@ -47,6 +48,7 @@ active_environments = []
|
||||
processes = []
|
||||
is_terminating = False
|
||||
|
||||
|
||||
def distribute_tasks(test_all_meta: dict) -> list:
|
||||
all_tasks = []
|
||||
for domain, examples in test_all_meta.items():
|
||||
@@ -54,10 +56,11 @@ def distribute_tasks(test_all_meta: dict) -> list:
|
||||
all_tasks.append((domain, example_id))
|
||||
return all_tasks
|
||||
|
||||
|
||||
def process_signal_handler(signum, frame, env_idx):
|
||||
logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
|
||||
local_vars = frame.f_locals
|
||||
active_environments = local_vars.get('active_environments', [])
|
||||
active_environments = local_vars.get("active_environments", [])
|
||||
for env in active_environments:
|
||||
if env is not None:
|
||||
try:
|
||||
@@ -69,23 +72,34 @@ def process_signal_handler(signum, frame, env_idx):
|
||||
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
|
||||
sys.exit(0)
|
||||
|
||||
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list, engine_params, engine_params_for_grounding):
|
||||
|
||||
def run_env_tasks(
|
||||
task_queue: Queue,
|
||||
args: argparse.Namespace,
|
||||
shared_scores: list,
|
||||
engine_params,
|
||||
engine_params_for_grounding,
|
||||
):
|
||||
active_environments = []
|
||||
env = None
|
||||
try:
|
||||
# Use IMAGE_ID_MAP for AWS provider to get snapshot_name
|
||||
snapshot_name = None
|
||||
region = getattr(args, 'region', None)
|
||||
if args.provider_name == 'aws' and region is not None:
|
||||
region = getattr(args, "region", None)
|
||||
if args.provider_name == "aws" and region is not None:
|
||||
try:
|
||||
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
||||
|
||||
screen_size = (args.screen_width, args.screen_height)
|
||||
snapshot_name = IMAGE_ID_MAP[region].get(screen_size, IMAGE_ID_MAP[region][(1920, 1080)])
|
||||
snapshot_name = IMAGE_ID_MAP[region].get(
|
||||
screen_size, IMAGE_ID_MAP[region][(1920, 1080)]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get snapshot_name from IMAGE_ID_MAP: {e}")
|
||||
snapshot_name = None
|
||||
from gui_agents.s2_5.agents.agent_s import AgentS2_5
|
||||
from gui_agents.s2_5.agents.grounding import OSWorldACI
|
||||
|
||||
grounding_agent = OSWorldACI(
|
||||
platform="linux",
|
||||
engine_params_for_generation=engine_params,
|
||||
@@ -106,10 +120,11 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li
|
||||
snapshot_name=snapshot_name,
|
||||
screen_size=(args.screen_width, args.screen_height),
|
||||
headless=args.headless,
|
||||
os_type = "Ubuntu",
|
||||
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
os_type="Ubuntu",
|
||||
require_a11y_tree=args.observation_type
|
||||
in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
enable_proxy=True,
|
||||
client_password=getattr(args, 'client_password', ''),
|
||||
client_password=getattr(args, "client_password", ""),
|
||||
)
|
||||
active_environments.append(env)
|
||||
logger.info(f"Process {current_process().name} started.")
|
||||
@@ -151,7 +166,10 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
|
||||
|
||||
logger.error(
|
||||
f"Exception in {current_process().name} {domain}/{example_id}: {e}"
|
||||
)
|
||||
logger.error(traceback.format_exc())
|
||||
try:
|
||||
env.controller.end_recording(
|
||||
@@ -160,19 +178,17 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li
|
||||
except Exception as rec_e:
|
||||
logger.error(f"Failed to end recording: {rec_e}")
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{"Error": f"{domain}/{example_id} - {e}"}
|
||||
)
|
||||
)
|
||||
f.write(json.dumps({"Error": f"{domain}/{example_id} - {e}"}))
|
||||
f.write("\n")
|
||||
except Exception as e:
|
||||
logger.error(f"Task-level error in {current_process().name}: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
except Exception as e:
|
||||
logger.error(f"Process-level error in {current_process().name}: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
finally:
|
||||
logger.info(f"{current_process().name} cleaning up environment...")
|
||||
@@ -181,7 +197,10 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li
|
||||
env.close()
|
||||
logger.info(f"{current_process().name} environment closed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"{current_process().name} error during environment cleanup: {e}")
|
||||
logger.error(
|
||||
f"{current_process().name} error during environment cleanup: {e}"
|
||||
)
|
||||
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
global is_terminating, active_environments, processes
|
||||
@@ -209,12 +228,14 @@ def signal_handler(signum, frame):
|
||||
try:
|
||||
logger.info(f"Forcefully terminating process {p.name}...")
|
||||
import signal as sig
|
||||
|
||||
os.kill(p.pid, sig.SIGKILL)
|
||||
except Exception as e:
|
||||
logger.error(f"Error forcefully terminating process: {e}")
|
||||
logger.info("Shutdown complete. Exiting.")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def config() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run end-to-end evaluation on the benchmark"
|
||||
@@ -223,8 +244,10 @@ def config() -> argparse.Namespace:
|
||||
# environment config
|
||||
parser.add_argument("--path_to_vm", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--provider_name", type=str, default="vmware",
|
||||
help="Virtualization provider (vmware, docker, aws, azure, gcp, virtualbox)"
|
||||
"--provider_name",
|
||||
type=str,
|
||||
default="vmware",
|
||||
help="Virtualization provider (vmware, docker, aws, azure, gcp, virtualbox)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--headless", action="store_true", help="Run in headless machine"
|
||||
@@ -238,7 +261,12 @@ def config() -> argparse.Namespace:
|
||||
default="screenshot",
|
||||
help="Observation type",
|
||||
)
|
||||
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
|
||||
parser.add_argument(
|
||||
"--num_envs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of environments to run in parallel",
|
||||
)
|
||||
parser.add_argument("--screen_width", type=int, default=1920)
|
||||
parser.add_argument("--screen_height", type=int, default=1080)
|
||||
parser.add_argument("--sleep_after_execution", type=float, default=1.0)
|
||||
@@ -263,7 +291,6 @@ def config() -> argparse.Namespace:
|
||||
# agent config
|
||||
parser.add_argument("--max_trajectory_length", type=int, default=8)
|
||||
|
||||
|
||||
# lm config
|
||||
parser.add_argument("--model_provider", type=str, default="openai")
|
||||
parser.add_argument("--model", type=str, default="gpt-4o")
|
||||
@@ -279,11 +306,23 @@ def config() -> argparse.Namespace:
|
||||
default="",
|
||||
help="The API key of the main generation model.",
|
||||
)
|
||||
parser.add_argument("--model_temperature", type=float, default=None, help="Temperature to fix the generation model at (e.g. o3 can only be run with 1.0)")
|
||||
|
||||
parser.add_argument(
|
||||
"--model_temperature",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Temperature to fix the generation model at (e.g. o3 can only be run with 1.0)",
|
||||
)
|
||||
|
||||
# grounding model config
|
||||
parser.add_argument("--ground_provider", type=str, required=True, help="The provider for the grounding model")
|
||||
parser.add_argument("--ground_url", type=str, required=True, help="The URL of the grounding model")
|
||||
parser.add_argument(
|
||||
"--ground_provider",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The provider for the grounding model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ground_url", type=str, required=True, help="The URL of the grounding model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ground_api_key",
|
||||
type=str,
|
||||
@@ -291,7 +330,10 @@ def config() -> argparse.Namespace:
|
||||
help="The API key of the grounding model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ground_model", type=str, required=True, help="The model name for the grounding model"
|
||||
"--ground_model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The model name for the grounding model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grounding_width",
|
||||
@@ -320,15 +362,15 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
engine_params = {
|
||||
"engine_type": args.model_provider,
|
||||
"model": args.model,
|
||||
"base_url": getattr(args, 'model_url', ''),
|
||||
"api_key": getattr(args, 'model_api_key', ''),
|
||||
"temperature": getattr(args, 'model_temperature', None),
|
||||
"base_url": getattr(args, "model_url", ""),
|
||||
"api_key": getattr(args, "model_api_key", ""),
|
||||
"temperature": getattr(args, "model_temperature", None),
|
||||
}
|
||||
engine_params_for_grounding = {
|
||||
"engine_type": args.ground_provider,
|
||||
"model": args.ground_model,
|
||||
"base_url": getattr(args, 'ground_url', ''),
|
||||
"api_key": getattr(args, 'ground_api_key', ''),
|
||||
"base_url": getattr(args, "ground_url", ""),
|
||||
"api_key": getattr(args, "ground_api_key", ""),
|
||||
"grounding_width": args.grounding_width,
|
||||
"grounding_height": args.grounding_height,
|
||||
}
|
||||
@@ -343,8 +385,14 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
for i in range(num_envs):
|
||||
p = Process(
|
||||
target=run_env_tasks,
|
||||
args=(task_queue, args, shared_scores, engine_params, engine_params_for_grounding),
|
||||
name=f"EnvProcess-{i+1}"
|
||||
args=(
|
||||
task_queue,
|
||||
args,
|
||||
shared_scores,
|
||||
engine_params,
|
||||
engine_params_for_grounding,
|
||||
),
|
||||
name=f"EnvProcess-{i+1}",
|
||||
)
|
||||
p.daemon = True
|
||||
p.start()
|
||||
@@ -358,13 +406,21 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
logger.warning(f"Process {p.name} died, restarting...")
|
||||
new_p = Process(
|
||||
target=run_env_tasks,
|
||||
args=(task_queue, args, shared_scores, engine_params, engine_params_for_grounding),
|
||||
name=f"EnvProcess-Restart-{idx+1}"
|
||||
args=(
|
||||
task_queue,
|
||||
args,
|
||||
shared_scores,
|
||||
engine_params,
|
||||
engine_params_for_grounding,
|
||||
),
|
||||
name=f"EnvProcess-Restart-{idx+1}",
|
||||
)
|
||||
new_p.daemon = True
|
||||
new_p.start()
|
||||
processes[idx] = new_p
|
||||
logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
|
||||
logger.info(
|
||||
f"Restarted process {new_p.name} with PID {new_p.pid}"
|
||||
)
|
||||
else:
|
||||
alive_count += 1
|
||||
if task_queue.empty():
|
||||
@@ -377,10 +433,14 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
for p in processes:
|
||||
p.join()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
|
||||
logger.info(
|
||||
"Main process received KeyboardInterrupt. Initiating graceful shutdown..."
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"Unexpected error while waiting for processes: {e}", exc_info=True
|
||||
)
|
||||
for p in processes:
|
||||
if p.is_alive():
|
||||
try:
|
||||
@@ -471,7 +531,7 @@ if __name__ == "__main__":
|
||||
####### The complete version of the list of examples #######
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
args = config()
|
||||
|
||||
|
||||
# save args to json in result_dir/action_space/observation_type/model/args.json
|
||||
path_to_args = os.path.join(
|
||||
args.result_dir,
|
||||
@@ -509,4 +569,4 @@ if __name__ == "__main__":
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
test(args, test_file_list)
|
||||
test(args, test_file_list)
|
||||
|
||||
@@ -17,6 +17,7 @@ from gui_agents.s2_5.agents.agent_s import AgentS2_5
|
||||
from gui_agents.s2_5.agents.grounding import OSWorldACI
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Almost deprecated since it's not multi-env, use run_multienv_*.py instead
|
||||
@@ -71,8 +72,10 @@ def config() -> argparse.Namespace:
|
||||
# environment config
|
||||
parser.add_argument("--path_to_vm", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--provider_name", type=str, default="vmware",
|
||||
help="Virtualization provider (vmware, docker, aws, azure, gcp, virtualbox)"
|
||||
"--provider_name",
|
||||
type=str,
|
||||
default="vmware",
|
||||
help="Virtualization provider (vmware, docker, aws, azure, gcp, virtualbox)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--headless", action="store_true", help="Run in headless machine"
|
||||
@@ -115,11 +118,23 @@ def config() -> argparse.Namespace:
|
||||
default="",
|
||||
help="The API key of the main generation model.",
|
||||
)
|
||||
parser.add_argument("--model_temperature", type=float, default=None, help="Temperature to fix the generation model at (e.g. o3 can only be run with 1.0)")
|
||||
|
||||
parser.add_argument(
|
||||
"--model_temperature",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Temperature to fix the generation model at (e.g. o3 can only be run with 1.0)",
|
||||
)
|
||||
|
||||
# grounding model config
|
||||
parser.add_argument("--ground_provider", type=str, required=True, help="The provider for the grounding model")
|
||||
parser.add_argument("--ground_url", type=str, required=True, help="The URL of the grounding model")
|
||||
parser.add_argument(
|
||||
"--ground_provider",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The provider for the grounding model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ground_url", type=str, required=True, help="The URL of the grounding model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ground_api_key",
|
||||
type=str,
|
||||
@@ -127,7 +142,10 @@ def config() -> argparse.Namespace:
|
||||
help="The API key of the grounding model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ground_model", type=str, required=True, help="The model name for the grounding model"
|
||||
"--ground_model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The model name for the grounding model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grounding_width",
|
||||
@@ -182,15 +200,15 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
engine_params = {
|
||||
"engine_type": args.model_provider,
|
||||
"model": args.model,
|
||||
"base_url": getattr(args, 'model_url', ''),
|
||||
"api_key": getattr(args, 'model_api_key', ''),
|
||||
"temperature": getattr(args, 'model_temperature', None),
|
||||
"base_url": getattr(args, "model_url", ""),
|
||||
"api_key": getattr(args, "model_api_key", ""),
|
||||
"temperature": getattr(args, "model_temperature", None),
|
||||
}
|
||||
engine_params_for_grounding = {
|
||||
"engine_type": args.ground_provider,
|
||||
"model": args.ground_model,
|
||||
"base_url": getattr(args, 'ground_url', ''),
|
||||
"api_key": getattr(args, 'ground_api_key', ''),
|
||||
"base_url": getattr(args, "ground_url", ""),
|
||||
"api_key": getattr(args, "ground_api_key", ""),
|
||||
"grounding_width": args.grounding_width,
|
||||
"grounding_height": args.grounding_height,
|
||||
}
|
||||
@@ -217,11 +235,11 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
action_space=args.action_space,
|
||||
screen_size=(args.screen_width, args.screen_height),
|
||||
headless=args.headless,
|
||||
os_type = "Ubuntu",
|
||||
os_type="Ubuntu",
|
||||
require_a11y_tree=args.observation_type
|
||||
in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
enable_proxy=True,
|
||||
snapshot_name="signed_in_state_1"
|
||||
snapshot_name="signed_in_state_1",
|
||||
)
|
||||
|
||||
for domain in tqdm(test_all_meta, desc="Domain"):
|
||||
@@ -269,7 +287,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
except Exception as e:
|
||||
logger.error(f"Exception in {domain}/{example_id}: {e}")
|
||||
# Only attempt to end recording if controller exists (not Docker provider)
|
||||
if hasattr(env, 'controller') and env.controller is not None:
|
||||
if hasattr(env, "controller") and env.controller is not None:
|
||||
env.controller.end_recording(
|
||||
os.path.join(example_result_dir, "recording.mp4")
|
||||
)
|
||||
@@ -361,7 +379,7 @@ if __name__ == "__main__":
|
||||
####### The complete version of the list of examples #######
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
args = config()
|
||||
|
||||
|
||||
# save args to json in result_dir/action_space/observation_type/model/args.json
|
||||
path_to_args = os.path.join(
|
||||
args.result_dir,
|
||||
@@ -399,4 +417,4 @@ if __name__ == "__main__":
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
test(args, test_file_list)
|
||||
test(args, test_file_list)
|
||||
|
||||
@@ -11,11 +11,17 @@ from utils import get_new_tasks_classification
|
||||
load_dotenv()
|
||||
|
||||
|
||||
async def generate_single_fact_caption(task_dir: str, screenshot_files: List[str], i: int, judge: BehaviorNarrator, trajectory_lines: List[str]):
|
||||
async def generate_single_fact_caption(
|
||||
task_dir: str,
|
||||
screenshot_files: List[str],
|
||||
i: int,
|
||||
judge: BehaviorNarrator,
|
||||
trajectory_lines: List[str],
|
||||
):
|
||||
"""Generate a single fact caption for a screenshot pair."""
|
||||
before_file = os.path.join(task_dir, screenshot_files[i])
|
||||
after_file = os.path.join(task_dir, screenshot_files[i + 1])
|
||||
|
||||
|
||||
# Load action from trajectory data if available
|
||||
pyautogui_action = None
|
||||
if i < len(trajectory_lines):
|
||||
@@ -24,10 +30,10 @@ async def generate_single_fact_caption(task_dir: str, screenshot_files: List[str
|
||||
pyautogui_action = data.get("exec_code")
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
if pyautogui_action is None:
|
||||
raise ValueError(f"No pyautogui action found for step {i+1}")
|
||||
|
||||
|
||||
# Read image bytes
|
||||
try:
|
||||
with open(before_file, "rb") as f:
|
||||
@@ -36,37 +42,43 @@ async def generate_single_fact_caption(task_dir: str, screenshot_files: List[str
|
||||
after_bytes = f.read()
|
||||
except Exception as e:
|
||||
raise Exception(f"Error reading images: {e}")
|
||||
|
||||
|
||||
# Generate fact caption using behavior narrator
|
||||
result = await asyncio.to_thread(judge.judge, before_bytes, after_bytes, pyautogui_action)
|
||||
result = await asyncio.to_thread(
|
||||
judge.judge, before_bytes, after_bytes, pyautogui_action
|
||||
)
|
||||
result["screenshot_num"] = i + 1
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def generate_fact_captions_parallel(task_dir: str, judge: BehaviorNarrator, step_semaphore: Optional[asyncio.Semaphore] = None):
|
||||
async def generate_fact_captions_parallel(
|
||||
task_dir: str,
|
||||
judge: BehaviorNarrator,
|
||||
step_semaphore: Optional[asyncio.Semaphore] = None,
|
||||
):
|
||||
"""Generate fact captions for a task directory when they don't exist (parallelized version)."""
|
||||
print(f"Generating fact captions for {task_dir}...")
|
||||
|
||||
|
||||
# Find all screenshot files
|
||||
screenshot_files = []
|
||||
for filename in os.listdir(task_dir):
|
||||
if filename.startswith("step_") and filename.endswith(".png"):
|
||||
screenshot_files.append(filename)
|
||||
|
||||
|
||||
# Sort by step number
|
||||
def extract_step_num(filename):
|
||||
try:
|
||||
return int(filename.split("_")[1].split(".")[0])
|
||||
except:
|
||||
return 0
|
||||
|
||||
|
||||
screenshot_files.sort(key=extract_step_num)
|
||||
|
||||
|
||||
if len(screenshot_files) < 2:
|
||||
print(f"Not enough screenshots to generate fact captions in {task_dir}")
|
||||
return []
|
||||
|
||||
|
||||
# Load trajectory data once
|
||||
trajectory_lines = []
|
||||
trajectory_file = os.path.join(task_dir, "traj.jsonl")
|
||||
@@ -76,31 +88,38 @@ async def generate_fact_captions_parallel(task_dir: str, judge: BehaviorNarrator
|
||||
trajectory_lines = f.readlines()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
# Use shared semaphore to limit concurrent judge calls
|
||||
if step_semaphore is None:
|
||||
step_semaphore = asyncio.Semaphore(5) # Default limit
|
||||
|
||||
|
||||
async def bounded_task(task_func, *args, **kwargs):
|
||||
async with step_semaphore:
|
||||
return await task_func(*args, **kwargs)
|
||||
|
||||
|
||||
try:
|
||||
# Create bounded tasks for parallel execution
|
||||
bounded_tasks = [
|
||||
bounded_task(generate_single_fact_caption, task_dir, screenshot_files, i, judge, trajectory_lines)
|
||||
bounded_task(
|
||||
generate_single_fact_caption,
|
||||
task_dir,
|
||||
screenshot_files,
|
||||
i,
|
||||
judge,
|
||||
trajectory_lines,
|
||||
)
|
||||
for i in range(len(screenshot_files) - 1)
|
||||
]
|
||||
results = await asyncio.gather(*bounded_tasks, return_exceptions=True)
|
||||
except Exception as e:
|
||||
print(f"Error in parallel execution: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# Process results and save to file
|
||||
fact_captions = []
|
||||
successful_results = []
|
||||
fact_captions_file = os.path.join(task_dir, "fact_captions.jsonl")
|
||||
|
||||
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
print(f"Error generating fact caption for step {i+1}: {result}")
|
||||
@@ -108,94 +127,111 @@ async def generate_fact_captions_parallel(task_dir: str, judge: BehaviorNarrator
|
||||
successful_results.append(result)
|
||||
fact_caption = f"Fact Caption from Screenshot {result['screenshot_num']}: {result['fact_answer']}"
|
||||
fact_captions.append(fact_caption)
|
||||
|
||||
|
||||
# Save all results to file at once
|
||||
if successful_results:
|
||||
with open(fact_captions_file, "w") as f:
|
||||
for result in successful_results:
|
||||
f.write(json.dumps(result) + "\n")
|
||||
|
||||
|
||||
print(f"Generated {len(fact_captions)} fact captions for {task_dir}")
|
||||
return fact_captions
|
||||
|
||||
|
||||
async def main(engine_params: dict, results_dirs: List[str]):
|
||||
"""Main function to generate fact captions for multiple task directories.
|
||||
|
||||
|
||||
Args:
|
||||
engine_params: Engine parameters for BehaviorNarrator
|
||||
results_dirs: List of results directories to analyze for task classification
|
||||
"""
|
||||
# Get task IDs automatically using get_new_tasks_classification
|
||||
tasks_classification = get_new_tasks_classification(results_dirs)
|
||||
task_ids = tasks_classification['variance']
|
||||
|
||||
task_ids = tasks_classification["variance"]
|
||||
|
||||
print(f"Found {len(task_ids)} variance tasks to process")
|
||||
judge = BehaviorNarrator(engine_params=engine_params)
|
||||
|
||||
|
||||
# Get concurrency settings from environment
|
||||
per_step = int(os.getenv("DIFFCAP_PER_STEP_CONCURRENCY", "100"))
|
||||
per_taskdir = int(os.getenv("DIFFCAP_PER_TASKDIR_CONCURRENCY", "4"))
|
||||
|
||||
|
||||
# Build list of task directories to process
|
||||
task_dirs = []
|
||||
for task_id in task_ids:
|
||||
domain, example_id = task_id.split("/")
|
||||
|
||||
|
||||
# Check each results directory for this task
|
||||
for results_dir in results_dirs:
|
||||
task_dir = os.path.join(results_dir, domain, example_id)
|
||||
|
||||
|
||||
try:
|
||||
if "fact_captions.jsonl" in os.listdir(task_dir):
|
||||
print(f"Fact captions already exist for {task_dir}")
|
||||
continue
|
||||
except FileNotFoundError:
|
||||
continue
|
||||
|
||||
|
||||
task_dirs.append(task_dir)
|
||||
|
||||
|
||||
if not task_dirs:
|
||||
print("No new task directories to process.")
|
||||
return
|
||||
|
||||
|
||||
print(f"Scheduling {len(task_dirs)} task directories...")
|
||||
|
||||
|
||||
# Set up semaphores for concurrency control
|
||||
shared_step_semaphore = asyncio.Semaphore(per_step)
|
||||
taskdir_semaphore = asyncio.Semaphore(per_taskdir)
|
||||
|
||||
|
||||
async def run_one(task_dir):
|
||||
async with taskdir_semaphore:
|
||||
print(f"Processing {task_dir}")
|
||||
return await generate_fact_captions_parallel(task_dir, judge, step_semaphore=shared_step_semaphore)
|
||||
|
||||
return await generate_fact_captions_parallel(
|
||||
task_dir, judge, step_semaphore=shared_step_semaphore
|
||||
)
|
||||
|
||||
# Execute all tasks in parallel
|
||||
results = await asyncio.gather(*[run_one(d) for d in task_dirs], return_exceptions=True)
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[run_one(d) for d in task_dirs], return_exceptions=True
|
||||
)
|
||||
|
||||
# Report results
|
||||
failures = sum(1 for r in results if isinstance(r, Exception))
|
||||
if failures:
|
||||
print(f"Completed with {failures} failures out of {len(task_dirs)} task directories.")
|
||||
print(
|
||||
f"Completed with {failures} failures out of {len(task_dirs)} task directories."
|
||||
)
|
||||
else:
|
||||
print("Completed all task directories successfully.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Generate fact captions for OSWorld task directories")
|
||||
parser.add_argument("--results-dirs", nargs="+", required=True, help="List of results directories to analyze for task classification")
|
||||
parser.add_argument("--model", default="gpt-5-2025-08-07", help="Model to use for generation")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate fact captions for OSWorld task directories"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--results-dirs",
|
||||
nargs="+",
|
||||
required=True,
|
||||
help="List of results directories to analyze for task classification",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model", default="gpt-5-2025-08-07", help="Model to use for generation"
|
||||
)
|
||||
parser.add_argument("--engine-type", default="openai", help="Engine type")
|
||||
parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for generation")
|
||||
|
||||
parser.add_argument(
|
||||
"--temperature", type=float, default=1.0, help="Temperature for generation"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# Engine parameters
|
||||
engine_params = {
|
||||
"model": args.model,
|
||||
"engine_type": args.engine_type,
|
||||
"temperature": args.temperature
|
||||
"temperature": args.temperature,
|
||||
}
|
||||
|
||||
|
||||
print(f"Results directories: {args.results_dirs}")
|
||||
asyncio.run(main(engine_params, args.results_dirs))
|
||||
asyncio.run(main(engine_params, args.results_dirs))
|
||||
|
||||
@@ -10,34 +10,42 @@ from tqdm.asyncio import tqdm_asyncio
|
||||
load_dotenv()
|
||||
|
||||
from utils import (
|
||||
get_new_tasks_classification,
|
||||
evaluate_comparative_results,
|
||||
get_new_tasks_classification,
|
||||
evaluate_comparative_results,
|
||||
load_task_instruction,
|
||||
load_facts
|
||||
load_facts,
|
||||
)
|
||||
from gui_agents.s3.bbon.comparative_judge import ComparativeJudge
|
||||
|
||||
def run_judge(task: str, task_instruction: str, result_dirs: List[str], judge: ComparativeJudge) -> Tuple[str, str, Optional[str]]:
|
||||
|
||||
def run_judge(
|
||||
task: str, task_instruction: str, result_dirs: List[str], judge: ComparativeJudge
|
||||
) -> Tuple[str, str, Optional[str]]:
|
||||
"""
|
||||
Fact captions + initial/final screenshots judging.
|
||||
Pipeline: load trajectories → load existing fact captions → include initial/final screenshots → judge.
|
||||
"""
|
||||
# 1. Use provided task instruction
|
||||
# task_instruction is now a direct input parameter
|
||||
|
||||
|
||||
# 2. Load fact captions for all trajectories
|
||||
all_fact_captions = []
|
||||
for result_dir in result_dirs:
|
||||
task_dir = os.path.join(result_dir, task.split("/")[0], task.split("/")[1])
|
||||
fact_captions = load_facts(task_dir)
|
||||
all_fact_captions.append(fact_captions)
|
||||
|
||||
|
||||
# 3. Use the new Judge class method
|
||||
return judge.judge(task_instruction, task, result_dirs, all_fact_captions)
|
||||
|
||||
def evaluate_trajectories(task: str, task_instruction: str, result_dirs: List[str], judge: ComparativeJudge) -> Tuple[str, str, dict]:
|
||||
|
||||
def evaluate_trajectories(
|
||||
task: str, task_instruction: str, result_dirs: List[str], judge: ComparativeJudge
|
||||
) -> Tuple[str, str, dict]:
|
||||
"""Wrapper that runs fact-only MCQ judge and returns results."""
|
||||
answer, thoughts, selected_trajectory = run_judge(task, task_instruction, result_dirs, judge)
|
||||
answer, thoughts, selected_trajectory = run_judge(
|
||||
task, task_instruction, result_dirs, judge
|
||||
)
|
||||
|
||||
record = {
|
||||
"selected_trajectory": selected_trajectory,
|
||||
@@ -48,23 +56,40 @@ def evaluate_trajectories(task: str, task_instruction: str, result_dirs: List[st
|
||||
print(f"✅ Added task {task} (MCQ fact-only)")
|
||||
return answer, thoughts, record
|
||||
|
||||
|
||||
asyncio.get_event_loop().set_default_executor(
|
||||
concurrent.futures.ThreadPoolExecutor(max_workers=100)
|
||||
)
|
||||
|
||||
async def run_async(task: str, task_instruction: str, result_dirs: List[str], judge: ComparativeJudge):
|
||||
|
||||
async def run_async(
|
||||
task: str, task_instruction: str, result_dirs: List[str], judge: ComparativeJudge
|
||||
):
|
||||
"""Async wrapper for fact-only MCQ evaluation."""
|
||||
return await asyncio.to_thread(
|
||||
evaluate_trajectories, task=task, task_instruction=task_instruction, result_dirs=result_dirs, judge=judge
|
||||
evaluate_trajectories,
|
||||
task=task,
|
||||
task_instruction=task_instruction,
|
||||
result_dirs=result_dirs,
|
||||
judge=judge,
|
||||
)
|
||||
|
||||
|
||||
async def evaluate_and_save(result_dirs: List[str], output_file_path: str, examples_path: str, engine_params: dict):
|
||||
async def evaluate_and_save(
|
||||
result_dirs: List[str],
|
||||
output_file_path: str,
|
||||
examples_path: str,
|
||||
engine_params: dict,
|
||||
):
|
||||
"""Main evaluation function that processes tasks and saves results."""
|
||||
res = get_new_tasks_classification(results_dirs=result_dirs)
|
||||
for key in res:
|
||||
print(f"{key}: {res[key]}")
|
||||
optimal, minimum, expected_value = res["optimal"], res["minimum"], res["expected_value"]
|
||||
optimal, minimum, expected_value = (
|
||||
res["optimal"],
|
||||
res["minimum"],
|
||||
res["expected_value"],
|
||||
)
|
||||
print(f"optimal score: {optimal}, minimum score: {minimum}")
|
||||
|
||||
variance = res["variance"]
|
||||
@@ -90,13 +115,13 @@ async def evaluate_and_save(result_dirs: List[str], output_file_path: str, examp
|
||||
if str(task) in data:
|
||||
print(f"⚠️ Task {task} already exists in results — skipping.")
|
||||
continue
|
||||
|
||||
|
||||
# Load task instruction from examples path
|
||||
task_instruction = load_task_instruction(task, examples_path)
|
||||
if task_instruction is None:
|
||||
print(f"⚠️ No task instruction found for {task}, skipping...")
|
||||
continue
|
||||
|
||||
|
||||
tasks.append(run_async(task, task_instruction, result_dirs, judge))
|
||||
task_names.append(task)
|
||||
|
||||
@@ -105,11 +130,11 @@ async def evaluate_and_save(result_dirs: List[str], output_file_path: str, examp
|
||||
# Merge into existing results
|
||||
for task, (ans, thoughts, record) in zip(task_names, results):
|
||||
data[str(task)] = record
|
||||
|
||||
|
||||
os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
|
||||
with open(output_file_path, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
|
||||
res = evaluate_comparative_results(result_dirs, json_path=output_file_path)
|
||||
gain, maximum_gain = res
|
||||
data["score"] = {
|
||||
@@ -117,7 +142,7 @@ async def evaluate_and_save(result_dirs: List[str], output_file_path: str, examp
|
||||
"minimum": minimum,
|
||||
"expected_value": expected_value,
|
||||
"res": res,
|
||||
"actual score": minimum + gain
|
||||
"actual score": minimum + gain,
|
||||
}
|
||||
os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
|
||||
with open(output_file_path, "w") as f:
|
||||
@@ -125,7 +150,15 @@ async def evaluate_and_save(result_dirs: List[str], output_file_path: str, examp
|
||||
|
||||
return results
|
||||
|
||||
async def run_experiment(shuffled_runs: List[str], output_dir: str, examples_path: str, engine_params: dict, start_round: int = 2, max_rounds: int = None):
|
||||
|
||||
async def run_experiment(
|
||||
shuffled_runs: List[str],
|
||||
output_dir: str,
|
||||
examples_path: str,
|
||||
engine_params: dict,
|
||||
start_round: int = 2,
|
||||
max_rounds: int = None,
|
||||
):
|
||||
"""
|
||||
Run fact-only experiments progressively: start_round vs start_round+1, etc.
|
||||
"""
|
||||
@@ -139,12 +172,21 @@ async def run_experiment(shuffled_runs: List[str], output_dir: str, examples_pat
|
||||
output_file_path = os.path.join(output_dir, f"BoN{i}.json")
|
||||
|
||||
print(f"Running fact-only experiment with {i} dirs → {output_file_path}")
|
||||
await evaluate_and_save(test_dirs, output_file_path, examples_path, engine_params)
|
||||
await evaluate_and_save(
|
||||
test_dirs, output_file_path, examples_path, engine_params
|
||||
)
|
||||
|
||||
|
||||
async def main(shuffled_runs: List[str] = None, output_dir: str = None, examples_path: str = None, engine_params: dict = None, start_round: int = 2, max_rounds: int = None):
|
||||
async def main(
|
||||
shuffled_runs: List[str] = None,
|
||||
output_dir: str = None,
|
||||
examples_path: str = None,
|
||||
engine_params: dict = None,
|
||||
start_round: int = 2,
|
||||
max_rounds: int = None,
|
||||
):
|
||||
"""Main function to run fact-only judge experiments.
|
||||
|
||||
|
||||
Args:
|
||||
shuffled_runs: List of result directory paths to compare
|
||||
output_dir: Directory to save results
|
||||
@@ -156,57 +198,81 @@ async def main(shuffled_runs: List[str] = None, output_dir: str = None, examples
|
||||
if shuffled_runs is None:
|
||||
print("Error: shuffled_runs must be provided")
|
||||
return
|
||||
|
||||
|
||||
if output_dir is None:
|
||||
print("Error: output_dir must be provided")
|
||||
return
|
||||
|
||||
|
||||
if examples_path is None:
|
||||
print("Error: examples_path must be provided")
|
||||
return
|
||||
|
||||
|
||||
if engine_params is None:
|
||||
print("Error: engine_params must be provided")
|
||||
return
|
||||
|
||||
await run_experiment(shuffled_runs, output_dir, examples_path, engine_params, start_round, max_rounds)
|
||||
|
||||
await run_experiment(
|
||||
shuffled_runs, output_dir, examples_path, engine_params, start_round, max_rounds
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run fact-only judge experiments on OSWorld task directories")
|
||||
parser.add_argument("--results-dirs", nargs="+", required=True, help="List of results directories to analyze")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run fact-only judge experiments on OSWorld task directories"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--results-dirs",
|
||||
nargs="+",
|
||||
required=True,
|
||||
help="List of results directories to analyze",
|
||||
)
|
||||
parser.add_argument("--output-dir", required=True, help="Directory to save results")
|
||||
parser.add_argument("--examples-path", required=True, help="Path to examples directory containing task instructions")
|
||||
parser.add_argument("--start-round", type=int, default=2, help="Starting round number (default: 2)")
|
||||
parser.add_argument("--max-rounds", type=int, default=None, help="Maximum number of rounds to run (default: len(results_dirs))")
|
||||
parser.add_argument("--model", default="gpt-5-2025-08-07", help="Model to use for judging")
|
||||
parser.add_argument(
|
||||
"--examples-path",
|
||||
required=True,
|
||||
help="Path to examples directory containing task instructions",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--start-round", type=int, default=2, help="Starting round number (default: 2)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-rounds",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum number of rounds to run (default: len(results_dirs))",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model", default="gpt-5-2025-08-07", help="Model to use for judging"
|
||||
)
|
||||
parser.add_argument("--engine-type", default="openai", help="Engine type")
|
||||
parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for generation")
|
||||
|
||||
parser.add_argument(
|
||||
"--temperature", type=float, default=1.0, help="Temperature for generation"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# Engine parameters
|
||||
engine_params = {
|
||||
"model": args.model,
|
||||
"engine_type": args.engine_type,
|
||||
"temperature": args.temperature
|
||||
"temperature": args.temperature,
|
||||
}
|
||||
|
||||
|
||||
print(f"Results directories: {args.results_dirs}")
|
||||
print(f"Output directory: {args.output_dir}")
|
||||
print(f"Examples path: {args.examples_path}")
|
||||
print(f"Start round: {args.start_round}")
|
||||
print(f"Max rounds: {args.max_rounds}")
|
||||
print(f"Engine params: {engine_params}")
|
||||
|
||||
|
||||
# Run fact-only evaluation
|
||||
asyncio.run(
|
||||
main(
|
||||
shuffled_runs=args.results_dirs,
|
||||
output_dir=args.output_dir,
|
||||
examples_path=args.examples_path,
|
||||
engine_params=engine_params,
|
||||
start_round=args.start_round,
|
||||
max_rounds=args.max_rounds,
|
||||
shuffled_runs=args.results_dirs,
|
||||
output_dir=args.output_dir,
|
||||
examples_path=args.examples_path,
|
||||
engine_params=engine_params,
|
||||
start_round=args.start_round,
|
||||
max_rounds=args.max_rounds,
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -6,37 +6,42 @@ from PIL import Image
|
||||
from typing import Optional, List
|
||||
import base64
|
||||
|
||||
def image_to_openai_message_format(image_path: str, caption: str = None) -> Optional[dict]:
|
||||
|
||||
def image_to_openai_message_format(
|
||||
image_path: str, caption: str = None
|
||||
) -> Optional[dict]:
|
||||
"""Convert an image file to OpenAI message format."""
|
||||
if not os.path.exists(image_path):
|
||||
print(f"Image file not found: {image_path}")
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
with open(image_path, "rb") as f:
|
||||
image_bytes = f.read()
|
||||
|
||||
|
||||
if not image_bytes:
|
||||
print(f"Empty image file: {image_path}")
|
||||
return None
|
||||
|
||||
|
||||
base64_image = base64.b64encode(image_bytes).decode("utf-8")
|
||||
|
||||
|
||||
if not base64_image:
|
||||
print(f"Failed to encode image to base64: {image_path}")
|
||||
return None
|
||||
|
||||
|
||||
content = []
|
||||
if caption:
|
||||
content.append({"type": "text", "text": caption})
|
||||
|
||||
content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{base64_image}"}
|
||||
})
|
||||
|
||||
|
||||
content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{base64_image}"},
|
||||
}
|
||||
)
|
||||
|
||||
return {"role": "user", "content": content}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing image {image_path}: {e}")
|
||||
return None
|
||||
@@ -45,11 +50,11 @@ def image_to_openai_message_format(image_path: str, caption: str = None) -> Opti
|
||||
def load_facts(task_dir: str) -> List[str]:
|
||||
"""Load existing facts from facts.jsonl file."""
|
||||
fact_captions_file = os.path.join(task_dir, "fact_captions.jsonl")
|
||||
|
||||
|
||||
if not os.path.exists(fact_captions_file):
|
||||
print(f"fact_captions.jsonl not found at {fact_captions_file}")
|
||||
return []
|
||||
|
||||
|
||||
fact_captions = []
|
||||
with open(fact_captions_file, "r") as f:
|
||||
for line in f:
|
||||
@@ -60,39 +65,40 @@ def load_facts(task_dir: str) -> List[str]:
|
||||
|
||||
return fact_captions
|
||||
|
||||
|
||||
def load_task_instruction(task: str, examples_path: str) -> Optional[str]:
|
||||
"""
|
||||
Load task instruction from examples path.
|
||||
|
||||
|
||||
Args:
|
||||
task: Task ID in format "domain/example_id"
|
||||
examples_path: Path to the examples directory (e.g., "/home/ubuntu/Simular/OSWorld/evaluation_examples/examples")
|
||||
|
||||
|
||||
Returns:
|
||||
Task instruction string or None if not found
|
||||
"""
|
||||
domain, example_id = task.split("/", 1)
|
||||
|
||||
|
||||
# Construct path to the JSON file
|
||||
json_file_path = os.path.join(examples_path, domain, f"{example_id}.json")
|
||||
|
||||
|
||||
if not os.path.exists(json_file_path):
|
||||
logging.warning(f"Example file not found: {json_file_path}")
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
with open(json_file_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
|
||||
# Extract instruction from the JSON
|
||||
if "instruction" in data:
|
||||
instruction = data["instruction"]
|
||||
if instruction and instruction.strip():
|
||||
return instruction.strip()
|
||||
|
||||
|
||||
logging.warning(f"No 'instruction' key found in {json_file_path}")
|
||||
return None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"Error reading example file {json_file_path}: {e}")
|
||||
return None
|
||||
@@ -109,7 +115,7 @@ def get_final_screenshot_file(result_dir: str) -> str:
|
||||
# First, collect all valid step files with their indices
|
||||
step_files = {}
|
||||
pattern = re.compile(r"step[_\-]?(\d+)", re.IGNORECASE)
|
||||
|
||||
|
||||
for fname in os.listdir(result_dir):
|
||||
if not fname.lower().endswith(".png"):
|
||||
continue
|
||||
@@ -129,9 +135,12 @@ def get_final_screenshot_file(result_dir: str) -> str:
|
||||
if os.path.exists(file_path) and is_valid_image(file_path):
|
||||
return fname
|
||||
else:
|
||||
print(f"Invalid or corrupted image at step {idx}: {fname}, trying previous step...")
|
||||
print(
|
||||
f"Invalid or corrupted image at step {idx}: {fname}, trying previous step..."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def is_valid_image(file_path: str) -> bool:
|
||||
"""
|
||||
Check if an image file is valid by trying to open it with PIL.
|
||||
@@ -141,7 +150,7 @@ def is_valid_image(file_path: str) -> bool:
|
||||
# Check file size first (quick check)
|
||||
if os.path.getsize(file_path) == 0:
|
||||
return False
|
||||
|
||||
|
||||
# Try to open and verify the image
|
||||
with Image.open(file_path) as img:
|
||||
img.verify() # This will raise an exception if image is corrupted
|
||||
@@ -171,7 +180,7 @@ def get_new_tasks_classification(results_dirs: [str]):
|
||||
|
||||
constant_tasks = []
|
||||
variance_tasks = []
|
||||
constant_tasks_scores = []
|
||||
constant_tasks_scores = []
|
||||
optimal_sum = 0.0
|
||||
expected_value = 0.0
|
||||
|
||||
@@ -207,13 +216,16 @@ def get_new_tasks_classification(results_dirs: [str]):
|
||||
expected_value += sum(results) / len(results)
|
||||
|
||||
return {
|
||||
"constant": constant_tasks, #We dont evaluate constant tasks
|
||||
"variance": variance_tasks, #We evaluate variance tasks
|
||||
"minimum": sum(constant_tasks_scores), #sum of constant tasks scores (easy + hard)
|
||||
"optimal": optimal_sum, #If we get the best score, we get the optimal score
|
||||
"expected_value": expected_value, #If we get the average score across all tasks for all trajectories, we get the expected value
|
||||
"constant": constant_tasks, # We dont evaluate constant tasks
|
||||
"variance": variance_tasks, # We evaluate variance tasks
|
||||
"minimum": sum(
|
||||
constant_tasks_scores
|
||||
), # sum of constant tasks scores (easy + hard)
|
||||
"optimal": optimal_sum, # If we get the best score, we get the optimal score
|
||||
"expected_value": expected_value, # If we get the average score across all tasks for all trajectories, we get the expected value
|
||||
}
|
||||
|
||||
|
||||
def check_selected_trajectory(results_dirs: [str], selected_trajectory: str, task: str):
|
||||
"""
|
||||
results_dirs: list of directories in format results_dir/<domain>/<task_id>
|
||||
@@ -226,7 +238,8 @@ def check_selected_trajectory(results_dirs: [str], selected_trajectory: str, tas
|
||||
all_results = []
|
||||
|
||||
if not any(
|
||||
os.path.commonpath([os.path.abspath(selected_trajectory), os.path.abspath(rd)]) == os.path.abspath(rd)
|
||||
os.path.commonpath([os.path.abspath(selected_trajectory), os.path.abspath(rd)])
|
||||
== os.path.abspath(rd)
|
||||
for rd in results_dirs
|
||||
):
|
||||
return None, None
|
||||
@@ -251,9 +264,10 @@ def check_selected_trajectory(results_dirs: [str], selected_trajectory: str, tas
|
||||
optimal_val = max(all_results) if all_results else selected_val
|
||||
return selected_val, optimal_val
|
||||
|
||||
|
||||
def evaluate_comparative_results(results_dirs: [str], json_path: str = None):
|
||||
"""
|
||||
Opens comparative_judge_results.json (default) or a given path,
|
||||
Opens comparative_judge_results.json (default) or a given path,
|
||||
evaluates each task, and returns results.
|
||||
|
||||
Args:
|
||||
@@ -275,9 +289,13 @@ def evaluate_comparative_results(results_dirs: [str], json_path: str = None):
|
||||
for task, info in data.items():
|
||||
selected_trajectory = info.get("selected_trajectory")
|
||||
if selected_trajectory:
|
||||
selected_val, optimal_val = check_selected_trajectory(results_dirs, selected_trajectory, task)
|
||||
selected_val, optimal_val = check_selected_trajectory(
|
||||
results_dirs, selected_trajectory, task
|
||||
)
|
||||
if selected_val is not None and optimal_val is not None:
|
||||
print(f"task: {task}, selected_val: {selected_val}, optimal_val: {optimal_val}")
|
||||
print(
|
||||
f"task: {task}, selected_val: {selected_val}, optimal_val: {optimal_val}"
|
||||
)
|
||||
judge_score += selected_val
|
||||
optimal_score += optimal_val
|
||||
return judge_score, optimal_score
|
||||
return judge_score, optimal_score
|
||||
|
||||
@@ -24,18 +24,17 @@ def run_single_example(
|
||||
|
||||
with open(os.path.join(example_result_dir, f"step_0.png"), "wb") as _f:
|
||||
_f.write(obs["screenshot"])
|
||||
|
||||
with open(os.path.join(example_result_dir, "instruction.txt"), "w", encoding="utf-8") as f:
|
||||
|
||||
with open(
|
||||
os.path.join(example_result_dir, "instruction.txt"), "w", encoding="utf-8"
|
||||
) as f:
|
||||
f.write(instruction)
|
||||
|
||||
|
||||
done = False
|
||||
step_idx = 0
|
||||
# env.controller.start_recording()
|
||||
while not done and step_idx < max_steps:
|
||||
response, actions = agent.predict(
|
||||
instruction,
|
||||
obs
|
||||
)
|
||||
response, actions = agent.predict(instruction, obs)
|
||||
for action in actions:
|
||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
logger.info("Step %d: %s", step_idx + 1, action)
|
||||
@@ -63,7 +62,9 @@ def run_single_example(
|
||||
"screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png",
|
||||
}
|
||||
)
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
|
||||
with open(
|
||||
os.path.join(example_result_dir, "traj.jsonl"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write(json.dumps(response, ensure_ascii=False))
|
||||
f.write("\n")
|
||||
if done:
|
||||
@@ -87,4 +88,3 @@ def setup_logger(example, example_result_dir):
|
||||
logging.FileHandler(os.path.join(example_result_dir, "runtime.log"))
|
||||
)
|
||||
return runtime_logger
|
||||
|
||||
|
||||
+97
-37
@@ -19,6 +19,7 @@ import lib_run_single
|
||||
from desktop_env.desktop_env import DesktopEnv
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@@ -51,6 +52,7 @@ active_environments = []
|
||||
processes = []
|
||||
is_terminating = False
|
||||
|
||||
|
||||
def distribute_tasks(test_all_meta: dict) -> list:
|
||||
all_tasks = []
|
||||
for domain, examples in test_all_meta.items():
|
||||
@@ -58,10 +60,11 @@ def distribute_tasks(test_all_meta: dict) -> list:
|
||||
all_tasks.append((domain, example_id))
|
||||
return all_tasks
|
||||
|
||||
|
||||
def process_signal_handler(signum, frame, env_idx):
|
||||
logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
|
||||
local_vars = frame.f_locals
|
||||
active_environments = local_vars.get('active_environments', [])
|
||||
active_environments = local_vars.get("active_environments", [])
|
||||
for env in active_environments:
|
||||
if env is not None:
|
||||
try:
|
||||
@@ -73,23 +76,34 @@ def process_signal_handler(signum, frame, env_idx):
|
||||
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
|
||||
sys.exit(0)
|
||||
|
||||
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list, engine_params, engine_params_for_grounding):
|
||||
|
||||
def run_env_tasks(
|
||||
task_queue: Queue,
|
||||
args: argparse.Namespace,
|
||||
shared_scores: list,
|
||||
engine_params,
|
||||
engine_params_for_grounding,
|
||||
):
|
||||
active_environments = []
|
||||
env = None
|
||||
try:
|
||||
# Use IMAGE_ID_MAP for AWS provider to get snapshot_name
|
||||
snapshot_name = None
|
||||
region = getattr(args, 'region', None)
|
||||
if args.provider_name == 'aws' and region is not None:
|
||||
region = getattr(args, "region", None)
|
||||
if args.provider_name == "aws" and region is not None:
|
||||
try:
|
||||
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
||||
|
||||
screen_size = (args.screen_width, args.screen_height)
|
||||
snapshot_name = IMAGE_ID_MAP[region].get(screen_size, IMAGE_ID_MAP[region][(1920, 1080)])
|
||||
snapshot_name = IMAGE_ID_MAP[region].get(
|
||||
screen_size, IMAGE_ID_MAP[region][(1920, 1080)]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get snapshot_name from IMAGE_ID_MAP: {e}")
|
||||
snapshot_name = None
|
||||
from gui_agents.s3.agents.agent_s import AgentS3
|
||||
from gui_agents.s3.agents.grounding import OSWorldACI
|
||||
|
||||
env = DesktopEnv(
|
||||
path_to_vm=args.path_to_vm,
|
||||
action_space=args.action_space,
|
||||
@@ -98,10 +112,11 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li
|
||||
snapshot_name=snapshot_name,
|
||||
screen_size=(args.screen_width, args.screen_height),
|
||||
headless=args.headless,
|
||||
os_type = "Ubuntu",
|
||||
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
os_type="Ubuntu",
|
||||
require_a11y_tree=args.observation_type
|
||||
in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
enable_proxy=True,
|
||||
client_password=getattr(args, 'client_password', ''),
|
||||
client_password=getattr(args, "client_password", ""),
|
||||
)
|
||||
grounding_agent = OSWorldACI(
|
||||
env=env,
|
||||
@@ -157,7 +172,10 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
|
||||
|
||||
logger.error(
|
||||
f"Exception in {current_process().name} {domain}/{example_id}: {e}"
|
||||
)
|
||||
logger.error(traceback.format_exc())
|
||||
try:
|
||||
env.controller.end_recording(
|
||||
@@ -166,19 +184,17 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li
|
||||
except Exception as rec_e:
|
||||
logger.error(f"Failed to end recording: {rec_e}")
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{"Error": f"{domain}/{example_id} - {e}"}
|
||||
)
|
||||
)
|
||||
f.write(json.dumps({"Error": f"{domain}/{example_id} - {e}"}))
|
||||
f.write("\n")
|
||||
except Exception as e:
|
||||
logger.error(f"Task-level error in {current_process().name}: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
except Exception as e:
|
||||
logger.error(f"Process-level error in {current_process().name}: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
finally:
|
||||
logger.info(f"{current_process().name} cleaning up environment...")
|
||||
@@ -187,7 +203,10 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li
|
||||
env.close()
|
||||
logger.info(f"{current_process().name} environment closed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"{current_process().name} error during environment cleanup: {e}")
|
||||
logger.error(
|
||||
f"{current_process().name} error during environment cleanup: {e}"
|
||||
)
|
||||
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
global is_terminating, active_environments, processes
|
||||
@@ -215,12 +234,14 @@ def signal_handler(signum, frame):
|
||||
try:
|
||||
logger.info(f"Forcefully terminating process {p.name}...")
|
||||
import signal as sig
|
||||
|
||||
os.kill(p.pid, sig.SIGKILL)
|
||||
except Exception as e:
|
||||
logger.error(f"Error forcefully terminating process: {e}")
|
||||
logger.info("Shutdown complete. Exiting.")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def config() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run end-to-end evaluation on the benchmark"
|
||||
@@ -229,8 +250,10 @@ def config() -> argparse.Namespace:
|
||||
# environment config
|
||||
parser.add_argument("--path_to_vm", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--provider_name", type=str, default="vmware",
|
||||
help="Virtualization provider (vmware, docker, aws, azure, gcp, virtualbox)"
|
||||
"--provider_name",
|
||||
type=str,
|
||||
default="vmware",
|
||||
help="Virtualization provider (vmware, docker, aws, azure, gcp, virtualbox)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--headless", action="store_true", help="Run in headless machine"
|
||||
@@ -244,7 +267,12 @@ def config() -> argparse.Namespace:
|
||||
default="screenshot",
|
||||
help="Observation type",
|
||||
)
|
||||
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
|
||||
parser.add_argument(
|
||||
"--num_envs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of environments to run in parallel",
|
||||
)
|
||||
parser.add_argument("--screen_width", type=int, default=1920)
|
||||
parser.add_argument("--screen_height", type=int, default=1080)
|
||||
parser.add_argument("--sleep_after_execution", type=float, default=1.0)
|
||||
@@ -269,7 +297,6 @@ def config() -> argparse.Namespace:
|
||||
# agent config
|
||||
parser.add_argument("--max_trajectory_length", type=int, default=8)
|
||||
|
||||
|
||||
# lm config
|
||||
parser.add_argument("--model_provider", type=str, default="openai")
|
||||
parser.add_argument("--model", type=str, default="gpt-4o")
|
||||
@@ -285,11 +312,23 @@ def config() -> argparse.Namespace:
|
||||
default="",
|
||||
help="The API key of the main generation model.",
|
||||
)
|
||||
parser.add_argument("--model_temperature", type=float, default=None, help="Temperature to fix the generation model at (e.g. o3 can only be run with 1.0)")
|
||||
|
||||
parser.add_argument(
|
||||
"--model_temperature",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Temperature to fix the generation model at (e.g. o3 can only be run with 1.0)",
|
||||
)
|
||||
|
||||
# grounding model config
|
||||
parser.add_argument("--ground_provider", type=str, required=True, help="The provider for the grounding model")
|
||||
parser.add_argument("--ground_url", type=str, required=True, help="The URL of the grounding model")
|
||||
parser.add_argument(
|
||||
"--ground_provider",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The provider for the grounding model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ground_url", type=str, required=True, help="The URL of the grounding model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ground_api_key",
|
||||
type=str,
|
||||
@@ -297,7 +336,10 @@ def config() -> argparse.Namespace:
|
||||
help="The API key of the grounding model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ground_model", type=str, required=True, help="The model name for the grounding model"
|
||||
"--ground_model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The model name for the grounding model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grounding_width",
|
||||
@@ -326,15 +368,15 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
engine_params = {
|
||||
"engine_type": args.model_provider,
|
||||
"model": args.model,
|
||||
"base_url": getattr(args, 'model_url', ''),
|
||||
"api_key": getattr(args, 'model_api_key', ''),
|
||||
"temperature": getattr(args, 'model_temperature', None),
|
||||
"base_url": getattr(args, "model_url", ""),
|
||||
"api_key": getattr(args, "model_api_key", ""),
|
||||
"temperature": getattr(args, "model_temperature", None),
|
||||
}
|
||||
engine_params_for_grounding = {
|
||||
"engine_type": args.ground_provider,
|
||||
"model": args.ground_model,
|
||||
"base_url": getattr(args, 'ground_url', ''),
|
||||
"api_key": getattr(args, 'ground_api_key', ''),
|
||||
"base_url": getattr(args, "ground_url", ""),
|
||||
"api_key": getattr(args, "ground_api_key", ""),
|
||||
"grounding_width": args.grounding_width,
|
||||
"grounding_height": args.grounding_height,
|
||||
}
|
||||
@@ -349,8 +391,14 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
for i in range(num_envs):
|
||||
p = Process(
|
||||
target=run_env_tasks,
|
||||
args=(task_queue, args, shared_scores, engine_params, engine_params_for_grounding),
|
||||
name=f"EnvProcess-{i+1}"
|
||||
args=(
|
||||
task_queue,
|
||||
args,
|
||||
shared_scores,
|
||||
engine_params,
|
||||
engine_params_for_grounding,
|
||||
),
|
||||
name=f"EnvProcess-{i+1}",
|
||||
)
|
||||
p.daemon = True
|
||||
p.start()
|
||||
@@ -364,13 +412,21 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
logger.warning(f"Process {p.name} died, restarting...")
|
||||
new_p = Process(
|
||||
target=run_env_tasks,
|
||||
args=(task_queue, args, shared_scores, engine_params, engine_params_for_grounding),
|
||||
name=f"EnvProcess-Restart-{idx+1}"
|
||||
args=(
|
||||
task_queue,
|
||||
args,
|
||||
shared_scores,
|
||||
engine_params,
|
||||
engine_params_for_grounding,
|
||||
),
|
||||
name=f"EnvProcess-Restart-{idx+1}",
|
||||
)
|
||||
new_p.daemon = True
|
||||
new_p.start()
|
||||
processes[idx] = new_p
|
||||
logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
|
||||
logger.info(
|
||||
f"Restarted process {new_p.name} with PID {new_p.pid}"
|
||||
)
|
||||
else:
|
||||
alive_count += 1
|
||||
if task_queue.empty():
|
||||
@@ -383,10 +439,14 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
for p in processes:
|
||||
p.join()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
|
||||
logger.info(
|
||||
"Main process received KeyboardInterrupt. Initiating graceful shutdown..."
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"Unexpected error while waiting for processes: {e}", exc_info=True
|
||||
)
|
||||
for p in processes:
|
||||
if p.is_alive():
|
||||
try:
|
||||
@@ -477,7 +537,7 @@ if __name__ == "__main__":
|
||||
####### The complete version of the list of examples #######
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
args = config()
|
||||
|
||||
|
||||
# save args to json in result_dir/action_space/observation_type/model/args.json
|
||||
path_to_args = os.path.join(
|
||||
args.result_dir,
|
||||
|
||||
@@ -17,6 +17,7 @@ from gui_agents.s3.agents.agent_s import AgentS3
|
||||
from gui_agents.s3.agents.grounding import OSWorldACI
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Almost deprecated since it's not multi-env, use run_multienv_*.py instead
|
||||
@@ -71,8 +72,10 @@ def config() -> argparse.Namespace:
|
||||
# environment config
|
||||
parser.add_argument("--path_to_vm", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--provider_name", type=str, default="vmware",
|
||||
help="Virtualization provider (vmware, docker, aws, azure, gcp, virtualbox)"
|
||||
"--provider_name",
|
||||
type=str,
|
||||
default="vmware",
|
||||
help="Virtualization provider (vmware, docker, aws, azure, gcp, virtualbox)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--headless", action="store_true", help="Run in headless machine"
|
||||
@@ -115,11 +118,23 @@ def config() -> argparse.Namespace:
|
||||
default="",
|
||||
help="The API key of the main generation model.",
|
||||
)
|
||||
parser.add_argument("--model_temperature", type=float, default=None, help="Temperature to fix the generation model at (e.g. o3 can only be run with 1.0)")
|
||||
|
||||
parser.add_argument(
|
||||
"--model_temperature",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Temperature to fix the generation model at (e.g. o3 can only be run with 1.0)",
|
||||
)
|
||||
|
||||
# grounding model config
|
||||
parser.add_argument("--ground_provider", type=str, required=True, help="The provider for the grounding model")
|
||||
parser.add_argument("--ground_url", type=str, required=True, help="The URL of the grounding model")
|
||||
parser.add_argument(
|
||||
"--ground_provider",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The provider for the grounding model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ground_url", type=str, required=True, help="The URL of the grounding model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ground_api_key",
|
||||
type=str,
|
||||
@@ -127,7 +142,10 @@ def config() -> argparse.Namespace:
|
||||
help="The API key of the grounding model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ground_model", type=str, required=True, help="The model name for the grounding model"
|
||||
"--ground_model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The model name for the grounding model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grounding_width",
|
||||
@@ -182,15 +200,15 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
engine_params = {
|
||||
"engine_type": args.model_provider,
|
||||
"model": args.model,
|
||||
"base_url": getattr(args, 'model_url', ''),
|
||||
"api_key": getattr(args, 'model_api_key', ''),
|
||||
"temperature": getattr(args, 'model_temperature', None),
|
||||
"base_url": getattr(args, "model_url", ""),
|
||||
"api_key": getattr(args, "model_api_key", ""),
|
||||
"temperature": getattr(args, "model_temperature", None),
|
||||
}
|
||||
engine_params_for_grounding = {
|
||||
"engine_type": args.ground_provider,
|
||||
"model": args.ground_model,
|
||||
"base_url": getattr(args, 'ground_url', ''),
|
||||
"api_key": getattr(args, 'ground_api_key', ''),
|
||||
"base_url": getattr(args, "ground_url", ""),
|
||||
"api_key": getattr(args, "ground_api_key", ""),
|
||||
"grounding_width": args.grounding_width,
|
||||
"grounding_height": args.grounding_height,
|
||||
}
|
||||
@@ -201,7 +219,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
action_space=args.action_space,
|
||||
screen_size=(args.screen_width, args.screen_height),
|
||||
headless=args.headless,
|
||||
os_type = "Ubuntu",
|
||||
os_type="Ubuntu",
|
||||
require_a11y_tree=args.observation_type
|
||||
in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
enable_proxy=True,
|
||||
@@ -266,7 +284,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
except Exception as e:
|
||||
logger.error(f"Exception in {domain}/{example_id}: {e}")
|
||||
# Only attempt to end recording if controller exists (not Docker provider)
|
||||
if hasattr(env, 'controller') and env.controller is not None:
|
||||
if hasattr(env, "controller") and env.controller is not None:
|
||||
env.controller.end_recording(
|
||||
os.path.join(example_result_dir, "recording.mp4")
|
||||
)
|
||||
@@ -358,7 +376,7 @@ if __name__ == "__main__":
|
||||
####### The complete version of the list of examples #######
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
args = config()
|
||||
|
||||
|
||||
# save args to json in result_dir/action_space/observation_type/model/args.json
|
||||
path_to_args = os.path.join(
|
||||
args.result_dir,
|
||||
|
||||
+2
-2
@@ -2,7 +2,7 @@ from setuptools import find_packages, setup
|
||||
|
||||
setup(
|
||||
name="gui-agents",
|
||||
version="0.2.5.post3",
|
||||
version="0.3.0",
|
||||
description="A library for creating general purpose GUI agents using multimodal LLMs.",
|
||||
long_description=open("README.md", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
@@ -35,7 +35,7 @@ setup(
|
||||
extras_require={"dev": ["black"]}, # Code formatter for linting
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"agent_s=gui_agents.s2_5.cli_app:main",
|
||||
"agent_s=gui_agents.s3.cli_app:main",
|
||||
],
|
||||
},
|
||||
classifiers=[
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário