diff --git a/.github/workflows/darglint.yml b/.github/workflows/darglint.yml index 48bbf67a..93668135 100644 --- a/.github/workflows/darglint.yml +++ b/.github/workflows/darglint.yml @@ -21,7 +21,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: '3.10' + python-version: '3.12' cache: 'pip' # caching pip dependencies - name: Pip install diff --git a/.github/workflows/python_version_compatibility.yml b/.github/workflows/python_version_compatibility.yml new file mode 100644 index 00000000..95700275 --- /dev/null +++ b/.github/workflows/python_version_compatibility.yml @@ -0,0 +1,40 @@ +name: Python Compatibility (Info Only) + +on: + push: + branches: + - main + pull_request: + +jobs: + info-check: + runs-on: ubuntu-latest + continue-on-error: true + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12"] + steps: + - uses: actions/checkout@v4 + + # Optional: Cache uv for faster runs + - name: Cache uv + uses: actions/cache@v4 + with: + path: ~/.cargo/bin/uv + key: uv-${{ runner.os }} + + - name: Install uv + run: | + if [ ! -f ~/.cargo/bin/uv ]; then + curl -LsSf https://astral.sh/uv/install.sh | sh + fi + + - name: Check Python ${{ matrix.python-version }} + continue-on-error: true + run: | + export PATH="$HOME/.cargo/bin:$PATH" + if uvx --python ${{ matrix.python-version }} --from python --with-requirements requirements.txt python -c "print('✅ Compatible')"; then + echo "✅ Python ${{ matrix.python-version }} works" + else + echo "❌ Python ${{ matrix.python-version }} incompatible" + fi \ No newline at end of file diff --git a/.gitignore b/.gitignore index aa97d987..6339cccb 100644 --- a/.gitignore +++ b/.gitignore @@ -171,3 +171,4 @@ results/ outputs/ miniwob-plusplus/ .miniwob-server.pid +debugging_results/ \ No newline at end of file diff --git a/src/agentlab/agents/agent_args.py b/src/agentlab/agents/agent_args.py index b2cd0eb6..78036c31 100644 --- a/src/agentlab/agents/agent_args.py +++ b/src/agentlab/agents/agent_args.py @@ -1,5 +1,5 @@ import bgym -from bgym import AbstractAgentArgs +from bgym import AbstractAgentArgs, Benchmark class AgentArgs(AbstractAgentArgs): @@ -14,7 +14,7 @@ class MyAgentArgs(AgentArgs): Note: for working properly with AgentXRay, the arguments need to be serializable and hasable. """ - def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode: bool): + def set_benchmark(self, benchmark: Benchmark, demo_mode: bool): """Optional method to set benchmark specific flags. This allows the agent to have minor adjustments based on the benchmark. diff --git a/src/agentlab/agents/agent_utils.py b/src/agentlab/agents/agent_utils.py new file mode 100644 index 00000000..991e27e6 --- /dev/null +++ b/src/agentlab/agents/agent_utils.py @@ -0,0 +1,267 @@ +from logging import warning +from typing import Optional, Tuple + +import numpy as np +from PIL import Image, ImageDraw +from playwright.sync_api import Page + +""" +This module contains utility functions for handling observations and actions in the context of agent interactions. +""" + + +def tag_screenshot_with_action(screenshot: Image, action: str) -> Image: + """ + If action is a coordinate action, try to render it on the screenshot. + + e.g. mouse_click(120, 130) -> draw a dot at (120, 130) on the screenshot + + Args: + screenshot: The screenshot to tag. + action: The action to tag the screenshot with. + + Returns: + The tagged screenshot. + + Raises: + ValueError: If the action parsing fails. + """ + if action.startswith("mouse_click"): + try: + coords = action[action.index("(") + 1 : action.index(")")].split(",") + coords = [c.strip() for c in coords] + if len(coords) not in [2, 3]: + raise ValueError(f"Invalid coordinate format: {coords}") + if coords[0].startswith("x="): + coords[0] = coords[0][2:] + if coords[1].startswith("y="): + coords[1] = coords[1][2:] + x, y = float(coords[0].strip()), float(coords[1].strip()) + draw = ImageDraw.Draw(screenshot) + radius = 5 + draw.ellipse( + (x - radius, y - radius, x + radius, y + radius), fill="blue", outline="blue" + ) + except (ValueError, IndexError) as e: + warning(f"Failed to parse action '{action}': {e}") + + elif action.startswith("mouse_drag_and_drop"): + try: + func_name, parsed_args = parse_func_call_string(action) + if func_name == "mouse_drag_and_drop" and parsed_args is not None: + args, kwargs = parsed_args + x1, y1, x2, y2 = None, None, None, None + + if args and len(args) >= 4: + # Positional arguments: mouse_drag_and_drop(x1, y1, x2, y2) + x1, y1, x2, y2 = map(float, args[:4]) + elif kwargs: + # Keyword arguments: mouse_drag_and_drop(from_x=x1, from_y=y1, to_x=x2, to_y=y2) + x1 = float(kwargs.get("from_x", 0)) + y1 = float(kwargs.get("from_y", 0)) + x2 = float(kwargs.get("to_x", 0)) + y2 = float(kwargs.get("to_y", 0)) + + if all(coord is not None for coord in [x1, y1, x2, y2]): + draw = ImageDraw.Draw(screenshot) + # Draw the main line + draw.line((x1, y1, x2, y2), fill="red", width=2) + # Draw arrowhead at the end point using the helper function + draw_arrowhead(draw, (x1, y1), (x2, y2)) + except (ValueError, IndexError) as e: + warning(f"Failed to parse action '{action}': {e}") + return screenshot + + +def add_mouse_pointer_from_action(screenshot: Image, action: str) -> Image.Image: + + if action.startswith("mouse_click"): + try: + coords = action[action.index("(") + 1 : action.index(")")].split(",") + coords = [c.strip() for c in coords] + if len(coords) not in [2, 3]: + raise ValueError(f"Invalid coordinate format: {coords}") + if coords[0].startswith("x="): + coords[0] = coords[0][2:] + if coords[1].startswith("y="): + coords[1] = coords[1][2:] + x, y = int(coords[0].strip()), int(coords[1].strip()) + screenshot = draw_mouse_pointer(screenshot, x, y) + except (ValueError, IndexError) as e: + warning(f"Failed to parse action '{action}': {e}") + return screenshot + + +def draw_mouse_pointer(image: Image.Image, x: int, y: int) -> Image.Image: + """ + Draws a semi-transparent mouse pointer at (x, y) on the image. + Returns a new image with the pointer drawn. + + Args: + image: The image to draw the mouse pointer on. + x: The x coordinate for the mouse pointer. + y: The y coordinate for the mouse pointer. + + Returns: + A new image with the mouse pointer drawn. + """ + pointer_size = 20 # Length of the pointer + overlay = image.convert("RGBA").copy() + draw = ImageDraw.Draw(overlay) + + # Define pointer shape (a simple arrow) + pointer_shape = [ + (x, y), + (x + pointer_size, y + pointer_size // 2), + (x + pointer_size // 2, y + pointer_size // 2), + (x + pointer_size // 2, y + pointer_size), + ] + + draw.polygon(pointer_shape, fill=(0, 0, 0, 128)) # 50% transparent black + + return Image.alpha_composite(image.convert("RGBA"), overlay) + + +def draw_arrowhead(draw, start, end, arrow_length=15, arrow_angle=30): + from math import atan2, cos, radians, sin + + angle = atan2(end[1] - start[1], end[0] - start[0]) + left = ( + end[0] - arrow_length * cos(angle - radians(arrow_angle)), + end[1] - arrow_length * sin(angle - radians(arrow_angle)), + ) + right = ( + end[0] - arrow_length * cos(angle + radians(arrow_angle)), + end[1] - arrow_length * sin(angle + radians(arrow_angle)), + ) + draw.line([end, left], fill="red", width=4) + draw.line([end, right], fill="red", width=4) + + +def draw_click_indicator(image: Image.Image, x: int, y: int) -> Image.Image: + """ + Draws a click indicator (+ shape with disconnected lines) at (x, y) on the image. + Returns a new image with the click indicator drawn. + + Args: + image: The image to draw the click indicator on. + x: The x coordinate for the click indicator. + y: The y coordinate for the click indicator. + + Returns: + A new image with the click indicator drawn. + """ + line_length = 10 # Length of each line segment + gap = 4 # Gap from center point + line_width = 2 # Thickness of lines + + overlay = image.convert("RGBA").copy() + draw = ImageDraw.Draw(overlay) + + # Draw 4 lines forming a + shape with gaps in the center + # Each line has a white outline and black center for visibility on any background + + # Top line + draw.line( + [(x, y - gap - line_length), (x, y - gap)], fill=(255, 255, 255, 200), width=line_width + 2 + ) # White outline + draw.line( + [(x, y - gap - line_length), (x, y - gap)], fill=(0, 0, 0, 255), width=line_width + ) # Black center + + # Bottom line + draw.line( + [(x, y + gap), (x, y + gap + line_length)], fill=(255, 255, 255, 200), width=line_width + 2 + ) # White outline + draw.line( + [(x, y + gap), (x, y + gap + line_length)], fill=(0, 0, 0, 255), width=line_width + ) # Black center + + # Left line + draw.line( + [(x - gap - line_length, y), (x - gap, y)], fill=(255, 255, 255, 200), width=line_width + 2 + ) # White outline + draw.line( + [(x - gap - line_length, y), (x - gap, y)], fill=(0, 0, 0, 255), width=line_width + ) # Black center + + # Right line + draw.line( + [(x + gap, y), (x + gap + line_length, y)], fill=(255, 255, 255, 200), width=line_width + 2 + ) # White outline + draw.line( + [(x + gap, y), (x + gap + line_length, y)], fill=(0, 0, 0, 255), width=line_width + ) # Black center + + return Image.alpha_composite(image.convert("RGBA"), overlay) + + +def zoom_webpage(page: Page, zoom_factor: float = 1.5): + """ + Zooms the webpage to the specified zoom factor. + + NOTE: Click actions with bid doesn't work properly when zoomed in. + + Args: + page: The Playwright Page object. + zoom_factor: The zoom factor to apply (default is 1.5). + + Returns: + Page: The modified Playwright Page object. + + Raises: + ValueError: If zoom_factor is less than or equal to 0. + """ + + if zoom_factor <= 0: + raise ValueError("Zoom factor must be greater than 0.") + + page.evaluate(f"document.documentElement.style.zoom='{zoom_factor*100}%'") + return page + + +def parse_func_call_string(call_str: str) -> Tuple[Optional[str], Optional[Tuple[list, dict]]]: + """ + Parse a function call string and extract the function name and arguments. + + Args: + call_str (str): A string like "mouse_click(100, 200)" or "mouse_drag_and_drop(x=10, y=20)" + + Returns: + Tuple (func_name, (args, kwargs)), or (None, None) if parsing fails + """ + import ast + + try: + tree = ast.parse(call_str.strip(), mode="eval") + if not isinstance(tree.body, ast.Call): + return None, None + + call_node = tree.body + + # Function name + if isinstance(call_node.func, ast.Name): + func_name = call_node.func.id + else: + return None, None + + # Positional arguments + args = [] + for arg in call_node.args: + try: + args.append(ast.literal_eval(arg)) + except (ValueError, TypeError): + return None, None + + # Keyword arguments + kwargs = {} + for kw in call_node.keywords: + try: + kwargs[kw.arg] = ast.literal_eval(kw.value) + except (ValueError, TypeError): + return None, None + + return func_name, (args, kwargs) + + except (SyntaxError, ValueError, TypeError): + return None, None diff --git a/src/agentlab/agents/dynamic_prompting.py b/src/agentlab/agents/dynamic_prompting.py index 2cb474e9..92ad25b9 100644 --- a/src/agentlab/agents/dynamic_prompting.py +++ b/src/agentlab/agents/dynamic_prompting.py @@ -9,13 +9,9 @@ from warnings import warn import bgym +from bgym import HighLevelActionSetArgs from browsergym.core.action.base import AbstractActionSet -from browsergym.utils.obs import ( - flatten_axtree_to_str, - flatten_dom_to_str, - overlay_som, - prune_html, -) +from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, overlay_som, prune_html from agentlab.llm.llm_utils import ( BaseMessage, @@ -99,7 +95,7 @@ class ObsFlags(Flags): @dataclass class ActionFlags(Flags): - action_set: bgym.HighLevelActionSetArgs = None # should be set by the set_benchmark method + action_set: HighLevelActionSetArgs = None # should be set by the set_benchmark method long_description: bool = True individual_examples: bool = False diff --git a/src/agentlab/agents/generic_agent/agent_configs.py b/src/agentlab/agents/generic_agent/agent_configs.py index 914e3249..f50367d8 100644 --- a/src/agentlab/agents/generic_agent/agent_configs.py +++ b/src/agentlab/agents/generic_agent/agent_configs.py @@ -3,6 +3,7 @@ """ import bgym +from bgym import HighLevelActionSetArgs from agentlab.agents import dynamic_prompting as dp from agentlab.experiments import args @@ -32,7 +33,7 @@ filter_visible_elements_only=False, ), action=dp.ActionFlags( - action_set=bgym.HighLevelActionSetArgs( + action_set=HighLevelActionSetArgs( subsets=["bid"], multiaction=False, ), @@ -80,7 +81,7 @@ filter_visible_elements_only=False, ), action=dp.ActionFlags( - action_set=bgym.HighLevelActionSetArgs( + action_set=HighLevelActionSetArgs( subsets=["bid"], multiaction=False, ), @@ -127,7 +128,7 @@ filter_visible_elements_only=False, ), action=dp.ActionFlags( - action_set=bgym.HighLevelActionSetArgs( + action_set=HighLevelActionSetArgs( subsets=["bid"], multiaction=False, ), @@ -177,7 +178,7 @@ filter_visible_elements_only=False, ), action=dp.ActionFlags( - action_set=bgym.HighLevelActionSetArgs( + action_set=HighLevelActionSetArgs( subsets=["bid"], multiaction=True, ), @@ -232,7 +233,7 @@ filter_visible_elements_only=False, ), action=dp.ActionFlags( - action_set=bgym.HighLevelActionSetArgs( + action_set=HighLevelActionSetArgs( subsets=["bid"], multiaction=False, ), @@ -323,7 +324,7 @@ filter_visible_elements_only=args.Choice([True, False], p=[0.3, 0.7]), ), action=dp.ActionFlags( - action_set=bgym.HighLevelActionSetArgs( + action_set=HighLevelActionSetArgs( subsets=args.Choice([["bid"], ["bid", "coord"]]), multiaction=args.Choice([True, False], p=[0.7, 0.3]), ), diff --git a/src/agentlab/agents/generic_agent/generic_agent.py b/src/agentlab/agents/generic_agent/generic_agent.py index a65b3eb3..d1f48f76 100644 --- a/src/agentlab/agents/generic_agent/generic_agent.py +++ b/src/agentlab/agents/generic_agent/generic_agent.py @@ -10,9 +10,11 @@ from copy import deepcopy from dataclasses import asdict, dataclass +from functools import partial from warnings import warn import bgym +from bgym import Benchmark from browsergym.experiments.agent import Agent, AgentInfo from agentlab.agents import dynamic_prompting as dp @@ -22,7 +24,6 @@ from agentlab.llm.tracking import cost_tracker_decorator from .generic_agent_prompt import GenericPromptFlags, MainPrompt -from functools import partial @dataclass @@ -37,7 +38,7 @@ def __post_init__(self): except AttributeError: pass - def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode): + def set_benchmark(self, benchmark: Benchmark, demo_mode): """Override Some flags based on the benchmark.""" if benchmark.name.startswith("miniwob"): self.flags.obs.use_html = True diff --git a/src/agentlab/agents/generic_agent/reproducibility_agent.py b/src/agentlab/agents/generic_agent/reproducibility_agent.py index bf1f01c9..154aeae5 100644 --- a/src/agentlab/agents/generic_agent/reproducibility_agent.py +++ b/src/agentlab/agents/generic_agent/reproducibility_agent.py @@ -19,6 +19,7 @@ from pathlib import Path import bgym +from bgym import HighLevelActionSetArgs from browsergym.experiments.agent import AgentInfo from bs4 import BeautifulSoup @@ -144,7 +145,7 @@ def _make_backward_compatible(agent_args: GenericAgentArgs): if isinstance(action_set, str): action_set = action_set.split("+") - agent_args.flags.action.action_set = bgym.HighLevelActionSetArgs( + agent_args.flags.action.action_set = HighLevelActionSetArgs( subsets=action_set, multiaction=agent_args.flags.action.multi_actions, ) diff --git a/src/agentlab/agents/tool_use_agent/__init__.py b/src/agentlab/agents/tool_use_agent/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/agentlab/agents/tool_use_agent/hint_db.csv b/src/agentlab/agents/tool_use_agent/hint_db.csv new file mode 100644 index 00000000..86020033 --- /dev/null +++ b/src/agentlab/agents/tool_use_agent/hint_db.csv @@ -0,0 +1,19 @@ +time_stamp,task_name,task_seed,base_llm,agent_name,domain_name,user_name,source,semantic_keys,hint +June 4,miniwob.book-flight,2,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,drop down menu,"some drop down menu will have a list of choice to select from, after typing. Make sure you select an element from that list." +June 4,miniwob.book-flight,2,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,Filling up form,Make sure the correct field is activated with a blue rectangle around it before writing into it. Not because the mouse is over it that it is active. +June 4,miniwob.book-flight,2,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,Filling up form,GUI elements surrounded by a red rectangle often means there is an error in the content +June 4,miniwob.book-flight,2,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,Search results,The scroll bar indicates that there is more than 1 flights available in the search. Make sure to select the one matching the task goal among all possible flights. +June 7,miniwob.book-flight,2,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,Filling up form,"If you suspect an error in an element of a form, ""ControlOrMeta+a"" to select all and overwrite the content" +June 9,miniwob.book-flight,2,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,Filling up form,"The main mistake is to type in the wrong field. Make sure the correct field is activated by clicking into it and seeing it activated with a blue rectangle before proceeding with writing. Otherwise, it will append to the currently activated field." +June 11,miniwob.drag-items-grid,30,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,aman,Dragging items,"For dragging tasks, use only mouse_down, move_mouse, and mouse_up. Do not perform a single continuous drag. Instead, move the mouse in steps using move_mouse, and then release with mouse_up when the target is reached." +June 11,miniwob.drag-items-grid,30,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,aman,Dragging items,"Avoid dragging items outside the grid. Ensure the drop coordinates are within valid grid boundaries, or the task will fail." +June 11,miniwob.drag-items-grid,30,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,aman,Dragging items,"Grid items are reactive: dragging an item over another may cause them to swap. Only perform mouse_down when the dragged item is correctly positioned. Intermediate movement helps trigger these reactive behaviors reliably." +June 11,miniwob.drag-items-grid,30,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,aman,Dragging items,"Use intermediate positions when moving items. This often leads to more stable behavior than dragging directly to the final target." +June 11,miniwob.drag-items-grid,30,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,aman,Dragging items,"Hovering near a target item can prompt it to move to an adjacent grid cell. Use this behavior to clear the destination before dropping the dragged item" +June 11,miniwob.drag-items-grid,30,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,aman,Dragging items,"Move items along slightly curved instead of straight lines to better imitate human-like dragging behavior." +June 11,miniwob.drag-items,30,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,aman,Dragging items,"Move items along slightly curved instead of straight lines to better imitate human-like dragging behavior." +June 11,miniwob.drag-items,30,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,aman,Dragging items,"For dragging tasks, use only mouse_down, move_mouse, and mouse_up. Do not perform a single continuous drag. Instead, move the mouse in steps using move_mouse, and then release with mouse_up when the target is reached." +June 18,miniwob.count-shape,23,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,Shape and letters size comparison in miniwob,"Shapes or items have different colors and different size. Size is relative to the other objects in the white area and is either ""large"" or ""small"". Shapes that are larger than the average shape or letter are considered ""large"". Others are ""small""." +June 18,miniwob.count-shape,23,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,communicate answer in miniwob,Answer by clicking one of the buttons describing multiple choices. +June 18,miniwob.count-shape,23,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,Simbols of colors in miniwob,"Colors a distinct in this task, e.g., cyan is not a type of blue. " +June 18,miniwob.form-sequence-2,23,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,Reporting results in miniwob,Make sure to click submit to finish the task. \ No newline at end of file diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py new file mode 100644 index 00000000..86140d02 --- /dev/null +++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py @@ -0,0 +1,544 @@ +import fnmatch +import json +from abc import ABC, abstractmethod +from copy import copy +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any + +import bgym +import numpy as np +import pandas as pd +from browsergym.core.observation import extract_screenshot +from browsergym.utils.obs import ( + flatten_axtree_to_str, + flatten_dom_to_str, + overlay_som, + prune_html, +) +from PIL import Image + +from agentlab.agents import agent_utils +from agentlab.agents.agent_args import AgentArgs +from agentlab.llm.llm_utils import image_to_png_base64_url +from agentlab.llm.response_api import ( + ClaudeResponseModelArgs, + LLMOutput, + MessageBuilder, + OpenAIChatModelArgs, + OpenAIResponseModelArgs, +) +from agentlab.llm.tracking import cost_tracker_decorator + + +@dataclass +class Block(ABC): + + def _init(self): + """Initialize the block.""" + pass + + def make(self) -> "Block": + """Returns a copy so the init can start adding some stuff to `self` without changing the + original datatclass that should only contain a config. + The aim is avoid having 2 calss definition for each block, e.g. Block and BlockArgs. + + Returns: + Block: A copy of the current block instance with initialization applied. + """ + block = self.__class__(**asdict(self)) + block._init() + return block + + @abstractmethod + def apply(self, llm, messages: list[MessageBuilder], **kwargs): + pass + + +@dataclass +class MsgGroup: + name: str = None + messages: list[MessageBuilder] = field(default_factory=list) + summary: MessageBuilder = None + + +class StructuredDiscussion: + """ + A structured discussion that groups messages into named groups with a potential summary for each group. + + When the discussion is flattened, only the last `keep_last_n_obs` groups are kept in the final list, + the other groups are replaced by their summaries if they have one. + """ + + def __init__(self, keep_last_n_obs=None): + self.groups: list[MsgGroup] = [] + self.keep_last_n_obs: int | None = keep_last_n_obs + + def append(self, message: MessageBuilder): + """Append a message to the last group.""" + self.groups[-1].messages.append(message) + + def new_group(self, name: str = None): + """Start a new group of messages.""" + if name is None: + name = f"group_{len(self.groups)}" + self.groups.append(MsgGroup(name)) + + def flatten(self) -> list[MessageBuilder]: + """Flatten the groups into a single list of messages.""" + + keep_last_n_obs = self.keep_last_n_obs or len(self.groups) + messages = [] + for i, group in enumerate(self.groups): + is_tail = i >= len(self.groups) - keep_last_n_obs + + if not is_tail and group.summary is not None: + messages.append(group.summary) + else: + messages.extend(group.messages) + # Mark all summarized messages for caching + if i == len(self.groups) - keep_last_n_obs: + messages[i].mark_all_previous_msg_for_caching() + return messages + + def set_last_summary(self, summary: MessageBuilder): + # append None to summaries until we reach the current group index + self.groups[-1].summary = summary + + def get_last_summary(self) -> MessageBuilder | None: + """Get the last summary message.""" + if len(self.groups) == 0: + return None + return self.groups[-1].summary + + def is_goal_set(self) -> bool: + """Check if the goal is set in the first group.""" + return len(self.groups) > 0 + + +SYS_MSG = """You are a web agent. Based on the observation, you will decide which action to take to accomplish your goal. +You strive for excellence and need to be as meticulous as possible. Make sure to explore when not sure. +""" + + +@dataclass +class Goal(Block): + """Block to add the goal to the messages.""" + + goal_as_system_msg: bool = True + + def apply(self, llm, discussion: StructuredDiscussion, obs: dict) -> dict: + system_message = llm.msg.system().add_text(SYS_MSG) + discussion.append(system_message) + + if self.goal_as_system_msg: + goal_message = llm.msg.system() + else: + goal_message = llm.msg.user() + + goal_message.add_text("# Goal:\n") + for content in obs["goal_object"]: + if content["type"] == "text": + goal_message.add_text(content["text"]) + elif content["type"] == "image_url": + goal_message.add_image(content["image_url"]) + discussion.append(goal_message) + + +AXTREE_NOTE = """ +AXTree extracts most of the interactive elements of the DOM in a tree structure. It may also contain information that is not visible in the screenshot. +A line starting with [bid] is a node in the AXTree. It is a unique alpha-numeric identifier to be used when calling tools. +""" + + +@dataclass +class Obs(Block): + """Block to add the observation to the messages.""" + + use_last_error: bool = True + use_screenshot: bool = True + use_axtree: bool = False + use_dom: bool = False + use_som: bool = False + use_tabs: bool = False + add_mouse_pointer: bool = False + use_zoomed_webpage: bool = False + + def apply( + self, llm, discussion: StructuredDiscussion, obs: dict, last_llm_output: LLMOutput + ) -> dict: + + if last_llm_output.tool_calls is None: + obs_msg = llm.msg.user() # type: MessageBuilder + else: + obs_msg = llm.msg.tool(last_llm_output.raw_response) # type: MessageBuilder + + if self.use_last_error: + if obs["last_action_error"] != "": + obs_msg.add_text(f"Last action error:\n{obs['last_action_error']}") + + if self.use_screenshot: + + if self.use_som: + screenshot = obs["screenshot_som"] + else: + screenshot = obs["screenshot"] + + if self.add_mouse_pointer: + # TODO this mouse pointer should be added at the browsergym level + screenshot = np.array( + agent_utils.add_mouse_pointer_from_action( + Image.fromarray(obs["screenshot"]), obs["last_action"] + ) + ) + + obs_msg.add_image(image_to_png_base64_url(screenshot)) + if self.use_axtree: + obs_msg.add_text(f"AXTree:\n{AXTREE_NOTE}\n{obs['axtree_txt']}") + if self.use_dom: + obs_msg.add_text(f"DOM:\n{obs['pruned_html']}") + if self.use_tabs: + obs_msg.add_text(_format_tabs(obs)) + + discussion.append(obs_msg) + return obs_msg + + +def _format_tabs(obs): + """Format the open tabs in a llm-readable way.""" + prompt_pieces = ["Currently open tabs:"] + for page_index, (page_url, page_title) in enumerate( + zip(obs["open_pages_urls"], obs["open_pages_titles"]) + ): + active_or_not = " (active tab)" if page_index == obs["active_page_index"] else "" + prompt_piece = f"""\ +Tab {page_index}{active_or_not}: + Title: {page_title} + URL: {page_url} +""" + prompt_pieces.append(prompt_piece) + return "\n".join(prompt_pieces) + + +@dataclass +class GeneralHints(Block): + + use_hints: bool = True + + def apply(self, llm, discussion: StructuredDiscussion) -> dict: + if not self.use_hints: + return + + hints = [] + + hints.append( + """Use ControlOrMeta instead of Control and Meta for keyboard shortcuts, to be cross-platform compatible. E.g. use ControlOrMeta for mutliple selection in lists.\n""" + ) + + discussion.append(llm.msg.user().add_text("\n".join(hints))) + + +@dataclass +class Summarizer(Block): + """Block to summarize the last action and the current state of the environment.""" + + do_summary: bool = False + high_details: bool = True + + def apply(self, llm, discussion: StructuredDiscussion) -> dict: + if not self.do_summary: + return + + msg = llm.msg.user().add_text("""Summarize\n""") + + discussion.append(msg) + # TODO need to make sure we don't force tool use here + summary_response = llm(messages=discussion.flatten(), tool_choice="none") + + summary_msg = llm.msg.assistant().add_text(summary_response.think) + discussion.append(summary_msg) + discussion.set_last_summary(summary_msg) + return summary_msg + + def apply_init(self, llm, discussion: StructuredDiscussion) -> dict: + """Initialize the summarizer block.""" + if not self.do_summary: + return + + system_msg = llm.msg.system() + if self.high_details: + # Add a system message to the LLM to indicate that it should summarize + system_msg.add_text( + """# Summarizer instructions:\nWhen asked to summarize, do the following: +1) Summarize the effect of the last action, with attention to details. +2) Give a semantic description of the current state of the environment, with attention to details. If there was a repeating mistake, mention the cause of it. +3) Reason about the overall task at a high level. +4) What hint can be relevant for the next action? Only chose from the hints provided in the task description. Or select none. +5) Reason about the next action to take, based on the current state and the goal. +""" + ) + else: + system_msg.add_text( + """When asked to summarize, give a semantic description of the current state of the environment.""" + ) + discussion.append(system_msg) + + +@dataclass +class TaskHint(Block): + use_task_hint: bool = True + hint_db_rel_path: str = "hint_db.csv" + + def _init(self): + """Initialize the block.""" + hint_db_path = Path(__file__).parent / self.hint_db_rel_path + self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str) + + def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict: + if not self.use_task_hint: + return + + task_hints = self.hint_db[ + self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name)) + ] + + hints = [] + for hint in task_hints["hint"]: + hint = hint.strip() + if hint: + hints.append(f"- {hint}") + + if len(hints) > 0: + hints_str = ( + "# Hints:\nHere are some hints for the task you are working on:\n" + + "\n".join(hints) + ) + msg = llm.msg.user().add_text(hints_str) + + discussion.append(msg) + + +class ToolCall(Block): + + def __init__(self, tool_server): + self.tool_server = tool_server + + def apply(self, llm, messages: list[MessageBuilder], obs: dict) -> dict: + # build the message by adding components to obs + response: LLMOutput = llm(messages=self.messages) + + messages.append(response.assistant_message) # this is tool call + + tool_answer = self.tool_server.call_tool(response) + tool_msg = llm.msg.tool() # type: MessageBuilder + tool_msg.add_tool_id(response.last_computer_call_id) + tool_msg.update_last_raw_response(response) + tool_msg.add_text(str(tool_answer)) + messages.append(tool_msg) + + +@dataclass +class PromptConfig: + tag_screenshot: bool = True # Whether to tag the screenshot with the last action. + goal: Goal = None + obs: Obs = None + summarizer: Summarizer = None + general_hints: GeneralHints = None + task_hint: TaskHint = None + keep_last_n_obs: int = 1 + multiaction: bool = False + action_subsets: tuple[str] = field(default_factory=lambda: ("coord",)) + + +@dataclass +class ToolUseAgentArgs(AgentArgs): + model_args: OpenAIResponseModelArgs = None + config: PromptConfig = None + use_raw_page_output: bool = False # This attribute is used in loop.py to setup the env. + + def __post_init__(self): + try: + self.agent_name = f"ToolUse-{self.model_args.model_name}".replace("/", "_") + except AttributeError: + pass + + def make_agent(self) -> bgym.Agent: + if self.config is None: + self.config = DEFAULT_PROMPT_CONFIG + return ToolUseAgent( + model_args=self.model_args, + config=self.config, + ) + + def prepare(self): + return self.model_args.prepare_server() + + def close(self): + return self.model_args.close_server() + + +class ToolUseAgent(bgym.Agent): + def __init__( + self, + model_args: OpenAIResponseModelArgs, + config: PromptConfig = None, + ): + self.model_args = model_args + self.config = config + self.action_set = bgym.HighLevelActionSet( + self.config.action_subsets, multiaction=self.config.multiaction + ) + self.tools = self.action_set.to_tool_description(api=model_args.api) + + self.call_ids = [] + + self.llm = model_args.make_model(extra_kwargs={"tools": self.tools}) + self.msg_builder = model_args.get_message_builder() + self.llm.msg = self.msg_builder + + self.task_hint = self.config.task_hint.make() + self.obs_block = self.config.obs.make() + + self.discussion = StructuredDiscussion(self.config.keep_last_n_obs) + self.last_response: LLMOutput = LLMOutput() + self._responses: list[LLMOutput] = [] + + def obs_preprocessor(self, obs): + obs = copy(obs) + + page = obs.pop("page", None) + if page is not None: + obs["screenshot"] = extract_screenshot(page) + else: + if self.config.obs.use_dom: + obs["dom_txt"] = flatten_dom_to_str( + obs["dom_object"], + extra_properties=obs["extra_element_properties"], + ) + obs["pruned_html"] = prune_html(obs["dom_txt"]) + + if self.config.obs.use_axtree: + obs["axtree_txt"] = flatten_axtree_to_str( + obs["axtree_object"], + extra_properties=obs["extra_element_properties"], + ) + + if self.config.obs.use_som: + obs["screenshot_som"] = overlay_som( + obs["screenshot"], extra_properties=obs["extra_element_properties"] + ) + if self.config.obs.use_zoomed_webpage: + pass + + return obs + + def set_task_name(self, task_name: str): + """Cheater function that is supposed to be called by loop.py before callling get_action""" + self.task_name = task_name + + @cost_tracker_decorator + def get_action(self, obs: Any) -> float: + self.llm.reset_stats() + if not self.discussion.is_goal_set(): + self.discussion.new_group("goal") + self.config.goal.apply(self.llm, self.discussion, obs) + self.config.summarizer.apply_init(self.llm, self.discussion) + self.config.general_hints.apply(self.llm, self.discussion) + self.task_hint.apply(self.llm, self.discussion, self.task_name) + + self.discussion.new_group() + + self.obs_block.apply(self.llm, self.discussion, obs, last_llm_output=self.last_response) + + self.config.summarizer.apply(self.llm, self.discussion) + + messages = self.discussion.flatten() + response: LLMOutput = self.llm( + messages=messages, + tool_choice="any", + cache_tool_definition=True, + cache_complete_prompt=False, + use_cache_breakpoints=True, + ) + + action = response.action + think = response.think + last_summary = self.discussion.get_last_summary() + if last_summary is not None: + think = last_summary.content[0]["text"] + "\n" + think + + self.discussion.new_group() + self.discussion.append(response.tool_calls) + + self.last_response = response + self._responses.append(response) # may be useful for debugging + # self.messages.append(response.assistant_message) # this is tool call + + tools_str = json.dumps(self.tools, indent=2) + tools_msg = MessageBuilder("tool_description").add_text(tools_str) + + # Adding these extra messages to visualize in gradio + messages.insert(0, tools_msg) # insert at the beginning of the messages + messages.append(response.tool_calls) + + agent_info = bgym.AgentInfo( + think=think, + chat_messages=messages, + stats=self.llm.stats.stats_dict, + ) + return action, agent_info + + +OPENAI_MODEL_CONFIG = OpenAIResponseModelArgs( + model_name="gpt-4.1", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=2_000, + temperature=0.1, + vision_support=True, +) + +OPENAI_CHATAPI_MODEL_CONFIG = OpenAIChatModelArgs( + model_name="gpt-4o-2024-08-06", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=2_000, + temperature=0.1, + vision_support=True, +) + +CLAUDE_MODEL_CONFIG = ClaudeResponseModelArgs( + model_name="claude-3-7-sonnet-20250219", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=2_000, + temperature=0.1, + vision_support=True, +) + + +DEFAULT_PROMPT_CONFIG = PromptConfig( + tag_screenshot=True, + goal=Goal(goal_as_system_msg=True), + obs=Obs( + use_last_error=True, + use_screenshot=True, + use_axtree=True, + use_dom=False, + use_som=False, + use_tabs=False, + ), + summarizer=Summarizer(do_summary=True), + general_hints=GeneralHints(use_hints=False), + task_hint=TaskHint(use_task_hint=True), + keep_last_n_obs=None, # keep only the last observation in the discussion + multiaction=False, # whether to use multi-action or not + # action_subsets=("bid",), + action_subsets=("coord"), + # action_subsets=("coord", "bid"), +) + +AGENT_CONFIG = ToolUseAgentArgs( + model_args=CLAUDE_MODEL_CONFIG, + config=DEFAULT_PROMPT_CONFIG, +) diff --git a/src/agentlab/agents/visual_agent/agent_configs.py b/src/agentlab/agents/visual_agent/agent_configs.py index 404afaec..df8d819b 100644 --- a/src/agentlab/agents/visual_agent/agent_configs.py +++ b/src/agentlab/agents/visual_agent/agent_configs.py @@ -1,9 +1,11 @@ +import bgym +from bgym import HighLevelActionSetArgs + +import agentlab.agents.dynamic_prompting as dp from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT from .visual_agent import VisualAgentArgs from .visual_agent_prompts import PromptFlags -import agentlab.agents.dynamic_prompting as dp -import bgym # the other flags are ignored for this agent. DEFAULT_OBS_FLAGS = dp.ObsFlags( @@ -16,7 +18,7 @@ ) DEFAULT_ACTION_FLAGS = dp.ActionFlags( - action_set=bgym.HighLevelActionSetArgs(subsets=["coord"]), + action_set=HighLevelActionSetArgs(subsets=["coord"]), long_description=True, individual_examples=False, ) diff --git a/src/agentlab/agents/visual_agent/visual_agent.py b/src/agentlab/agents/visual_agent/visual_agent.py index 8efee11d..d76cedf3 100644 --- a/src/agentlab/agents/visual_agent/visual_agent.py +++ b/src/agentlab/agents/visual_agent/visual_agent.py @@ -11,6 +11,7 @@ from dataclasses import asdict, dataclass import bgym +from bgym import Benchmark from browsergym.experiments.agent import Agent, AgentInfo from agentlab.agents import dynamic_prompting as dp @@ -19,7 +20,7 @@ from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry from agentlab.llm.tracking import cost_tracker_decorator -from .visual_agent_prompts import PromptFlags, MainPrompt +from .visual_agent_prompts import MainPrompt, PromptFlags @dataclass @@ -34,7 +35,7 @@ def __post_init__(self): except AttributeError: pass - def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode): + def set_benchmark(self, benchmark: Benchmark, demo_mode): """Override Some flags based on the benchmark.""" self.flags.obs.use_tabs = benchmark.is_multi_tab diff --git a/src/agentlab/analyze/agent_xray.py b/src/agentlab/analyze/agent_xray.py index 6154007e..e09b4af8 100644 --- a/src/agentlab/analyze/agent_xray.py +++ b/src/agentlab/analyze/agent_xray.py @@ -14,8 +14,9 @@ from attr import dataclass from langchain.schema import BaseMessage, HumanMessage from openai import OpenAI -from PIL import Image, ImageDraw +from PIL import Image +from agentlab.agents import agent_utils from agentlab.analyze import inspect_results from agentlab.experiments.exp_utils import RESULTS_DIR from agentlab.experiments.loop import ExpResult, StepInfo @@ -23,6 +24,7 @@ from agentlab.llm.chat_api import make_system_message, make_user_message from agentlab.llm.llm_utils import BaseMessage as AgentLabBaseMessage from agentlab.llm.llm_utils import Discussion +from agentlab.llm.response_api import MessageBuilder select_dir_instructions = "Select Experiment Directory" AGENT_NAME_KEY = "agent.agent_name" @@ -497,6 +499,8 @@ def run_gradio(results_dir: Path): # keep track of active tab tabs.select(tab_select) + demo.load(fn=refresh_exp_dir_choices, inputs=exp_dir_choice, outputs=exp_dir_choice) + demo.queue() do_share = os.getenv("AGENTXRAY_SHARE_GRADIO", "false").lower() == "true" @@ -530,47 +534,12 @@ def wrapper(*args, **kwargs): return decorator -def tag_screenshot_with_action(screenshot: Image, action: str) -> Image: - """ - If action is a coordinate action, try to render it on the screenshot. - - e.g. mouse_click(120, 130) -> draw a dot at (120, 130) on the screenshot - - Args: - screenshot: The screenshot to tag. - action: The action to tag the screenshot with. - - Returns: - The tagged screenshot. - - Raises: - ValueError: If the action parsing fails. - """ - if action.startswith("mouse_click"): - try: - coords = action[action.index("(") + 1 : action.index(")")].split(",") - coords = [c.strip() for c in coords] - if len(coords) not in [2, 3]: - raise ValueError(f"Invalid coordinate format: {coords}") - if coords[0].startswith("x="): - coords[0] = coords[0][2:] - if coords[1].startswith("y="): - coords[1] = coords[1][2:] - x, y = float(coords[0].strip()), float(coords[1].strip()) - draw = ImageDraw.Draw(screenshot) - radius = 5 - draw.ellipse( - (x - radius, y - radius, x + radius, y + radius), fill="red", outline="red" - ) - except (ValueError, IndexError) as e: - warning(f"Failed to parse action '{action}': {e}") - return screenshot - - def update_screenshot(som_or_not: str): global info action = info.exp_result.steps_info[info.step].action - return tag_screenshot_with_action(get_screenshot(info, som_or_not=som_or_not), action) + return agent_utils.tag_screenshot_with_action( + get_screenshot(info, som_or_not=som_or_not), action + ) def get_screenshot(info: Info, step: int = None, som_or_not: str = "Raw Screenshots"): @@ -589,7 +558,9 @@ def update_screenshot_pair(som_or_not: str): s2 = get_screenshot(info, info.step + 1, som_or_not) if s1 is not None: - s1 = tag_screenshot_with_action(s1, info.exp_result.steps_info[info.step].action) + s1 = agent_utils.tag_screenshot_with_action( + s1, info.exp_result.steps_info[info.step].action + ) return s1, s2 @@ -627,12 +598,76 @@ def update_axtree(): return get_obs(key="axtree_txt", default="No AXTree") +def dict_to_markdown(d: dict): + """ + Convert a dictionary to a clean markdown representation, recursively. + + Args: + d (dict): A dictionary where keys are strings and values can be strings, + lists of dictionaries, or nested dictionaries. + + Returns: + str: A markdown-formatted string representation of the dictionary. + """ + if not isinstance(d, dict): + warning(f"Expected dict, got {type(d)}") + return repr(d) + if not d: + return "No Data" + res = "" + for k, v in d.items(): + if isinstance(v, dict): + res += f"### {k}\n{dict_to_markdown(v)}\n" + elif isinstance(v, list): + res += f"### {k}\n" + for i, item in enumerate(v): + if isinstance(item, dict): + res += f"#### Item {i}\n{dict_to_markdown(item)}\n" + else: + res += f"- {item}\n" + else: + res += f"- **{k}**: {v}\n" + return res + + +def dict_msg_to_markdown(d: dict): + if "role" not in d: + return dict_to_markdown(d) + parts = [] + for item in d["content"]: + + if hasattr(item, "dict"): + item = item.dict() + + match item["type"]: + case "image": + parts.append(f"![Image]({item['image']})") + case "text": + parts.append(f"\n```\n{item['text']}\n```\n") + case "tool_use": + tool_use = f"Tool Use: {item['name']} {item['input']} (id = {item['id']})" + parts.append(f"\n```\n{tool_use}\n```\n") + case _: + parts.append(f"\n```\n{str(item)}\n```\n") + + markdown = f"### {d["role"].capitalize()}\n" + markdown += "\n".join(parts) + return markdown + + def update_chat_messages(): global info agent_info = info.exp_result.steps_info[info.step].agent_info chat_messages = agent_info.get("chat_messages", ["No Chat Messages"]) if isinstance(chat_messages, Discussion): return chat_messages.to_markdown() + + if isinstance(chat_messages, list) and isinstance(chat_messages[0], MessageBuilder): + chat_messages = [ + m.to_markdown() if isinstance(m, MessageBuilder) else dict_msg_to_markdown(m) + for m in chat_messages + ] + return "\n\n".join(chat_messages) messages = [] # TODO(ThibaultLSDC) remove this at some point for i, m in enumerate(chat_messages): if isinstance(m, BaseMessage): # TODO remove once langchain is deprecated @@ -968,9 +1003,16 @@ def get_agent_report(result_df: pd.DataFrame): def update_global_stats(): - stats = inspect_results.global_report(info.result_df, reduce_fn=inspect_results.summarize_stats) - stats.reset_index(inplace=True) - return stats + try: + stats = inspect_results.global_report( + info.result_df, reduce_fn=inspect_results.summarize_stats + ) + stats.reset_index(inplace=True) + return stats + + except Exception as e: + warning(f"Error while updating global stats: {e}") + return None def update_error_report(): diff --git a/src/agentlab/analyze/inspect_results.py b/src/agentlab/analyze/inspect_results.py index 7d043fce..a6684835 100644 --- a/src/agentlab/analyze/inspect_results.py +++ b/src/agentlab/analyze/inspect_results.py @@ -251,7 +251,11 @@ def summarize(sub_df): ) else: err = sub_df["err_msg"].notnull() - n_completed = (err | sub_df["truncated"] | sub_df["terminated"]).sum() + n_completed = err.copy() + for col in ["truncated", "terminated"]: + if col in sub_df: + n_completed = n_completed | sub_df[col] + n_completed = n_completed.sum() if n_completed == 0: return None @@ -271,6 +275,11 @@ def summarize(sub_df): ) if "stats.cum_cost" in sub_df: record["cum_cost"] = sub_df["stats.cum_cost"].sum(skipna=True).round(4) + if "stats.cum_effective_cost" in sub_df: + record["cum_effective_cost"] = ( + sub_df["stats.cum_effective_cost"].sum(skipna=True).round(4) + ) + record.pop("cum_cost", None) return pd.Series(record) @@ -280,7 +289,12 @@ def summarize_stats(sub_df): # make sure there are completed runs err = sub_df["err_msg"].notnull() - n_completed = (err | sub_df["truncated"] | sub_df["terminated"]).sum() + n_completed = err.copy() + for col in ["truncated", "terminated"]: + if col in sub_df: + n_completed = n_completed | sub_df[col] + n_completed = n_completed.sum() + if n_completed == 0: return None diff --git a/src/agentlab/experiments/loop.py b/src/agentlab/experiments/loop.py index 5a9580ca..ac77e01f 100644 --- a/src/agentlab/experiments/loop.py +++ b/src/agentlab/experiments/loop.py @@ -45,7 +45,9 @@ class EnvArgs(DataClassJsonMixin): storage_state: Optional[str | Path | dict] = None task_kwargs: Optional[dict] = None # use default value from BrowserGym - def make_env(self, action_mapping, exp_dir, exp_task_kwargs: dict = {}): + def make_env( + self, action_mapping, exp_dir, exp_task_kwargs: dict = {}, use_raw_page_output=True + ): """ Instantiates the BrowserGym environment corresponding to the arguments (with some tweaks). @@ -53,6 +55,7 @@ def make_env(self, action_mapping, exp_dir, exp_task_kwargs: dict = {}): action_mapping: overrides the action mapping of the environment. exp_dir: will set some environment parameters (e.g., record_video_dir) with respect to the directory where the experiment is running. exp_task_kwargs: use with caution! Will override task parameters to experiment-specific values. Useful to set different server configs for different experiments, or output file paths within the experiment's folder (e.g., assistantbench). + use_raw_page_output: if True, the environment will also return raw page output in the observation. Returns: env: the gym environment. @@ -85,6 +88,7 @@ def make_env(self, action_mapping, exp_dir, exp_task_kwargs: dict = {}): headless=self.headless, wait_for_user_message=self.wait_for_user_message, action_mapping=action_mapping, # action mapping is provided by the agent + use_raw_page_output=use_raw_page_output, **extra_kwargs, ) @@ -233,10 +237,6 @@ def make_stats(self): stats = {} stats.update(self.agent_info.pop("stats", {})) - messages = self.agent_info.get("chat_messages", None) - if messages is not None: - stats["n_token_agent_messages"] = count_messages_token(messages) - t = self.profiling stats["step_elapsed"] = t.env_stop - t.env_start stats["agent_elapsed"] = t.agent_stop - t.agent_start @@ -396,11 +396,15 @@ def run(self): try: logger.info(f"Running experiment {self.exp_name} in:\n {self.exp_dir}") agent = self.agent_args.make_agent() + if hasattr(agent, "set_task_name"): + agent.set_task_name(self.env_args.task_name) + logger.debug("Agent created.") env = self.env_args.make_env( action_mapping=agent.action_set.to_python_code, exp_dir=self.exp_dir, + use_raw_page_output=getattr(self.agent_args, "use_raw_page_output", False), ) logger.debug("Environment created.") @@ -875,7 +879,7 @@ def _move_old_exp(exp_dir): def _get_env_name(task_name: str): """Register tasks if needed (lazy import) and return environment name.""" - # lazy benchmark import + # lazy import if task_name.startswith("miniwob"): import browsergym.miniwob elif task_name.startswith("workarena"): diff --git a/src/agentlab/experiments/reproducibility_util.py b/src/agentlab/experiments/reproducibility_util.py index 01f3fdc9..0b0f91b4 100644 --- a/src/agentlab/experiments/reproducibility_util.py +++ b/src/agentlab/experiments/reproducibility_util.py @@ -8,6 +8,7 @@ import bgym import pandas as pd +from bgym import Benchmark from git import InvalidGitRepositoryError, Repo from git.config import GitConfigParser @@ -20,7 +21,7 @@ def _get_repo(module): def _get_benchmark_version( - benchmark: bgym.Benchmark, allow_bypass_benchmark_version: bool = False + benchmark: Benchmark, allow_bypass_benchmark_version: bool = False ) -> str: benchmark_name = benchmark.name @@ -178,7 +179,7 @@ def _get_git_info(module, changes_white_list=()) -> tuple[str, list[tuple[str, P def get_reproducibility_info( agent_names: str | list[str], - benchmark: bgym.Benchmark, + benchmark: Benchmark, study_id: str = "", comment=None, changes_white_list=( # Files that are often modified during experiments but do not affect reproducibility diff --git a/src/agentlab/experiments/study.py b/src/agentlab/experiments/study.py index 7de3db98..810b8bc2 100644 --- a/src/agentlab/experiments/study.py +++ b/src/agentlab/experiments/study.py @@ -6,13 +6,13 @@ import uuid from abc import ABC, abstractmethod from concurrent.futures import ProcessPoolExecutor -from dataclasses import dataclass +from dataclasses import asdict, dataclass from datetime import datetime from multiprocessing import Manager, Pool, Queue from pathlib import Path import bgym -from bgym import Benchmark +from bgym import DEFAULT_BENCHMARKS, Benchmark from slugify import slugify from agentlab.agents.agent_args import AgentArgs @@ -32,7 +32,7 @@ def make_study( agent_args: list[AgentArgs] | AgentArgs, - benchmark: bgym.Benchmark | str, + benchmark: Benchmark | str, logging_level=logging.WARNING, logging_level_stdout=logging.WARNING, suffix="", @@ -47,8 +47,8 @@ def make_study( The agent configuration(s) to run. *IMPORTANT*: these objects will be pickled and unpickled. Make sure they are imported from a package that is accessible from PYTHONPATH. Otherwise, it won't load in agentlab-xray. - benchmark: bgym.Benchmark | str - The benchmark to run the agents on. See bgym.DEFAULT_BENCHMARKS for the main ones. You + benchmark: Benchmark | str + The benchmark to run the agents on. See DEFAULT_BENCHMARKS for the main ones. You can also make your own by modifying an existing one. logging_level: int The logging level for file log. @@ -89,7 +89,7 @@ def make_study( agent_args = [agent_args] if isinstance(benchmark, str): - benchmark = bgym.DEFAULT_BENCHMARKS[benchmark.lower()]() + benchmark = DEFAULT_BENCHMARKS[benchmark.lower()]() if len(agent_args) > 1 and ("webarena" in benchmark.name or parallel_servers is not None): logger.warning( @@ -184,8 +184,8 @@ class Study(AbstractStudy): The agent configuration(s) to run. *IMPORTANT*: these objects will be pickled and unpickled. Make sure they are imported from a package that is accessible from PYTHONPATH. Otherwise, it won't load in agentlab-xray. - benchmark: bgym.Benchmark | str - The benchmark to run the agents on. See bgym.DEFAULT_BENCHMARKS for the main ones. You + benchmark: Benchmark | str + The benchmark to run the agents on. See DEFAULT_BENCHMARKS for the main ones. You can also make your own by modifying an existing one. dir: Path The directory where the study will be saved. If None, a directory will be created in @@ -241,7 +241,10 @@ def __post_init__(self): """Initialize the study. Set the uuid, and generate the exp_args_list.""" self.uuid = uuid.uuid4() if isinstance(self.benchmark, str): - self.benchmark = bgym.DEFAULT_BENCHMARKS[self.benchmark.lower()]() + self.benchmark = DEFAULT_BENCHMARKS[self.benchmark.lower()]() + + self.benchmark.env_args_list = _convert_env_args(self.benchmark.env_args_list) + if isinstance(self.dir, str): self.dir = Path(self.dir) self.make_exp_args_list() @@ -328,28 +331,31 @@ def run( self._run(n_jobs, parallel_backend, strict_reproducibility) suffix = f"trial_{i + 1}_of_{n_relaunch}" - _, summary_df, _ = self.get_results(suffix=suffix) + _, summary_df, error_report = self.get_results(suffix=suffix) logger.info("\n" + str(summary_df)) n_incomplete, n_error = self.find_incomplete(include_errors=relaunch_errors) if n_error / n_exp > 0.3: - logger.warning("More than 30% of the experiments errored. Stopping the study.") - return + logger.warning("More than 30% of the experiments errored. Stopping the retries.") + break if last_error_count is not None and n_error >= last_error_count: logger.warning( - "Last trial did not reduce the number of errors. Stopping the study." + "Last trial did not reduce the number of errors. Stopping the retries." ) - return + break if n_incomplete == 0: logger.info(f"Study {self.name} finished.") - return + break - logger.warning( - f"Study {self.name} did not finish after {n_relaunch} trials. There are {n_incomplete} incomplete experiments." - ) + logger.info("# Error Report:\n-------------\n\n" + error_report) + + if n_incomplete != 0: + logger.warning( + f"Study {self.name} did not finish after {n_relaunch} trials. There are {n_incomplete} incomplete experiments." + ) def _run(self, n_jobs=1, parallel_backend="joblib", strict_reproducibility=False): """Run all experiments in the study in parallel when possible. @@ -358,7 +364,7 @@ def _run(self, n_jobs=1, parallel_backend="joblib", strict_reproducibility=False n_jobs: int Number of parallel jobs. parallel_backend: str - Parallel backend to use. Either "joblib", "dask" or "sequential". + Parallel backend to use. Either "joblib", "ray" or "sequential". strict_reproducibility: bool If True, all modifications have to be committed before running the experiments. Also, if relaunching a study, it will not be possible if the code has changed. @@ -436,7 +442,7 @@ def load_most_recent(root_dir: Path = None, contains=None) -> "Study": def agents_on_benchmark( self, agents: list[AgentArgs] | AgentArgs, - benchmark: bgym.Benchmark, + benchmark: Benchmark, demo_mode=False, logging_level: int = logging.INFO, logging_level_stdout: int = logging.INFO, @@ -447,7 +453,7 @@ def agents_on_benchmark( Args: agents: list[AgentArgs] | AgentArgs The agent configuration(s) to run. - benchmark: bgym.Benchmark + benchmark: Benchmark The benchmark to run the agents on. demo_mode: bool If True, the experiments will be run in demo mode. @@ -719,6 +725,35 @@ def set_demo_mode(env_args_list: list[EnvArgs]): env_args.slow_mo = 1000 +def _convert_env_args(env_args_list): + """Return a list where every element is the *new* EnvArgs. + + For backward compatibility, we need to convert the old EnvArgs to the new one. + + Args: + env_args_list (list): list of EnvArgs objects to convert + + Returns: + list: list of converted EnvArgs objects + + Raises: + TypeError: If an element in env_args_list is not of expected type. + """ + from bgym import EnvArgs as BGymEnvArgs + + new_list = [] + for ea in env_args_list: + # already new → keep as‑is + if isinstance(ea, EnvArgs): + new_list.append(ea) + # old → convert + elif isinstance(ea, BGymEnvArgs): + new_list.append(EnvArgs(**asdict(ea))) + else: + raise TypeError(f"Unexpected type: {type(ea)}") + return new_list + + # def _flag_sequential_exp(exp_args_list: list[ExpArgs], benchmark: Benchmark): # if benchmark.name.startswith("visualwebarena"): # sequential_subset = benchmark.subset_from_glob("requires_reset", "True") diff --git a/src/agentlab/experiments/view_dep_graph.py b/src/agentlab/experiments/view_dep_graph.py index abbf7f87..2f4058ae 100644 --- a/src/agentlab/experiments/view_dep_graph.py +++ b/src/agentlab/experiments/view_dep_graph.py @@ -2,11 +2,12 @@ etc. You may have to detust it to make it work for you.""" import math + import bgym import matplotlib.pyplot as plt - import networkx as nx import numpy as np +from bgym import DEFAULT_BENCHMARKS def clean_dict(dependency_dict: dict[str, list[str]]) -> dict[str, list[str]]: @@ -308,8 +309,8 @@ def compress_chains(G): return G_compressed -# benchmark = bgym.DEFAULT_BENCHMARKS["webarena"]() -benchmark = bgym.DEFAULT_BENCHMARKS["visualwebarena"]() +# benchmark = DEFAULT_BENCHMARKS["webarena"]() +benchmark = DEFAULT_BENCHMARKS["visualwebarena"]() dep_graph = benchmark.dependency_graph_over_tasks() dep_graph = clean_dict(dep_graph) diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index 2536200e..5a68c16f 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -11,7 +11,9 @@ from typing import TYPE_CHECKING, Any, Union from warnings import warn +import anthropic import numpy as np +import openai import tiktoken import yaml from langchain.schema import BaseMessage @@ -90,6 +92,332 @@ def retry( raise ParseError(f"Could not parse a valid value after {n_retry} retries.") +def generic_call_api_with_retries( + client_function, + api_params, + is_response_valid_fn, + rate_limit_exceptions, + api_error_exceptions, + get_status_code_fn=None, + max_retries=10, + initial_retry_delay_seconds=20, + max_retry_delay_seconds=60 * 5, +): + """ + Makes an API call with retries for transient failures, rate limiting, + and responses deemed invalid by a custom validation function. + (Refactored for improved readability with helper functions) + + Args: + client_function: The API client function to call. + api_params: Parameters to pass to the client function. + is_response_valid_fn: Function to validate if the response is valid. + rate_limit_exceptions: Tuple of exception types for rate limiting. + api_error_exceptions: Tuple of exception types for API errors. + get_status_code_fn: Optional function to extract status code from exceptions. + max_retries: Maximum number of retry attempts. + initial_retry_delay_seconds: Initial delay between retries in seconds. + max_retry_delay_seconds: Maximum delay between retries in seconds. + + Returns: + The API response if successful. + + Raises: + Exception: For unexpected errors that are immediately re-raised. + RuntimeError: If API call fails after maximum retries. + """ + + def _calculate_delay( + current_attempt, initial_delay, max_delay, is_first_attempt_for_type=False + ): + """Calculates exponential backoff delay.""" + # For invalid response content (not an exception), the first "attempt" at retrying this specific issue + # might use a slightly different delay calculation if desired (e.g. attempt-1 for the exponent). + # For exceptions, the attempt number directly applies. + # Here, we use 'current_attempt' for exception-driven retries, + # and 'current_attempt -1' for the first retry due to invalid content (is_first_attempt_for_type). + if is_first_attempt_for_type: # First retry due to invalid content + # The first retry after an invalid response (attempt 1 for this *type* of failure) + effective_attempt = current_attempt - 1 # Use 0 for the first exponent + else: # Retries due to exceptions or subsequent invalid content retries + effective_attempt = current_attempt # Use current_attempt for exponent + + # Ensure effective_attempt for exponent is at least 0 + exponent_attempt = max( + 0, effective_attempt if not is_first_attempt_for_type else current_attempt - 1 + ) + + return min(initial_delay * (2**exponent_attempt), max_delay) + + def _handle_invalid_response_content(attempt): + logging.warning( + f"[Attempt {attempt}/{max_retries}] API response deemed invalid by validation function. Retrying after delay..." + ) + if attempt < max_retries: + # For the first retry due to invalid content, use attempt-1 for exponent + delay = _calculate_delay( + attempt, + initial_retry_delay_seconds, + max_retry_delay_seconds, + is_first_attempt_for_type=True, + ) + logging.debug(f"Sleeping for {delay:.2f} seconds due to invalid response content.") + time.sleep(delay) + return True # Indicate retry + return False # Max retries reached for this path + + def _handle_rate_limit_error(e, attempt): + logging.warning( + f"[Attempt {attempt}/{max_retries}] Rate limit error: {e}. Retrying after delay..." + ) + if attempt < max_retries: + delay = _calculate_delay(attempt, initial_retry_delay_seconds, max_retry_delay_seconds) + logging.debug(f"Sleeping for {delay:.2f} seconds due to rate limit.") + time.sleep(delay) + return True # Indicate retry + return False # Max retries reached for this path + + def _handle_api_error(e, attempt): + logging.error(f"[Attempt {attempt}/{max_retries}] APIError: {e}") + status_code = None + if get_status_code_fn: + try: + status_code = get_status_code_fn(e) + except Exception as ex_status_fn: + logging.warning( + f"Could not get status code from exception {type(e)} using get_status_code_fn: {ex_status_fn}" + ) + + if status_code == 429 or (status_code and status_code >= 500): + log_msg = "Rate limit (429)" if status_code == 429 else f"Server error ({status_code})" + logging.warning(f"{log_msg} indicated by status code. Retrying after delay...") + if attempt < max_retries: + delay = _calculate_delay( + attempt, initial_retry_delay_seconds, max_retry_delay_seconds + ) + logging.debug( + f"Sleeping for {delay:.2f} seconds due to API error status {status_code}." + ) + time.sleep(delay) + return True # Indicate retry + return False # Max retries reached for this path + else: + logging.error( + f"Non-retriable or unrecognized API error occurred (status: {status_code}). Raising." + ) + raise e # Re-raise non-retriable error + + # Main retry loop + for attempt in range(1, max_retries + 1): + try: + response = client_function(**api_params) + + if is_response_valid_fn(response): + logging.info(f"[Attempt {attempt}/{max_retries}] API call succeeded.") + return response + else: + if _handle_invalid_response_content(attempt): + continue + else: # Max retries reached after invalid content + break + + except rate_limit_exceptions as e: + if _handle_rate_limit_error(e, attempt): + continue + else: # Max retries reached after rate limit + break + + except api_error_exceptions as e: + # _handle_api_error will raise if non-retriable, or return True to continue + if _handle_api_error(e, attempt): + continue + else: # Max retries reached for retriable API error + break + + except Exception as e: # Catch-all for truly unexpected errors + logging.exception( + f"[Attempt {attempt}/{max_retries}] Unexpected exception: {e}. Raising." + ) + raise e # Re-raise unexpected errors immediately + + logging.error(f"Exceeded maximum {max_retries} retry attempts. API call failed.") + raise RuntimeError(f"API call failed after {max_retries} retries.") + + +def call_openai_api_with_retries(client_function, api_params, max_retries=10): + """ + Makes an OpenAI API call with retries for transient failures, + rate limiting, and invalid or error-containing responses. + (This is now a wrapper around generic_call_api_with_retries for OpenAI) + + Args: + client_function: The OpenAI API client function to call. + api_params: Parameters to pass to the client function. + max_retries: Maximum number of retry attempts. + + Returns: + The OpenAI API response if successful. + """ + + def is_openai_response_valid(response): + # Check for explicit error field in response object first + if getattr(response, "error", None): + logging.warning(f"OpenAI API response contains an error attribute: {response.error}") + return False # Treat as invalid for retry purposes + if hasattr(response, "choices") and response.choices: # Chat Completion API + return True + if hasattr(response, "output") and response.output: # Response API + return True + logging.warning("OpenAI API response is missing 'choices' or 'output' is empty.") + return False + + def get_openai_status_code(exception): + return getattr(exception, "http_status", None) + + return generic_call_api_with_retries( + client_function=client_function, + api_params=api_params, + is_response_valid_fn=is_openai_response_valid, + rate_limit_exceptions=(openai.RateLimitError,), + api_error_exceptions=(openai.APIError,), # openai.RateLimitError is caught first + get_status_code_fn=get_openai_status_code, + max_retries=max_retries, + # You can also pass initial_retry_delay_seconds and max_retry_delay_seconds + # if you want to customize them from their defaults in the generic function. + ) + + +def call_anthropic_api_with_retries(client_function, api_params, max_retries=10): + """ + Makes an Anthropic API call with retries for transient failures, + rate limiting, and invalid responses. + (This is a wrapper around generic_call_api_with_retries for Anthropic) + + Args: + client_function: The Anthropic API client function to call. + api_params: Parameters to pass to the client function. + max_retries: Maximum number of retry attempts. + + Returns: + The Anthropic API response if successful. + """ + + def is_anthropic_response_valid(response): + """Checks if the Anthropic response is valid.""" + # A successful Anthropic message response typically has: + # - a 'type' attribute equal to 'message' (for message creation) + # - a 'content' attribute which is a list of blocks + # - no 'error' attribute at the top level of the response object itself + # (errors are usually raised as exceptions by the client) + + if not response: + logging.warning("Anthropic API response is None or empty.") + return False + + # Check for explicit error type if the API might return it in a 200 OK + # For anthropic.types.Message, an error would typically be an exception. + # However, if the client_function could return a dict with an 'error' key: + if isinstance(response, dict) and response.get("type") == "error": + logging.warning(f"Anthropic API response indicates an error: {response.get('error')}") + return False + + # For anthropic.types.Message objects from client.messages.create + if hasattr(response, "type") and response.type == "message": + if hasattr(response, "content") and isinstance(response.content, list): + # Optionally, check if content is not empty, though an empty content list + # might be valid for some assistant stop reasons. + return True + else: + logging.warning( + "Anthropic API response is of type 'message' but missing valid 'content'." + ) + return False + + logging.warning( + f"Anthropic API response does not appear to be a valid message object. Type: {getattr(response, 'type', 'N/A')}" + ) + return False + + def get_anthropic_status_code(exception): + """Extracts HTTP status code from an Anthropic exception.""" + # anthropic.APIStatusError has a 'status_code' attribute + return getattr(exception, "status_code", None) + + # Define Anthropic specific exceptions. + # anthropic.RateLimitError for specific rate limit errors. + # anthropic.APIError is a base class for many errors. + # anthropic.APIStatusError provides status_code. + # anthropic.APIConnectionError for network issues. + # Order can matter if there's inheritance; specific ones first. + + # Ensure these are the correct exception types from your installed anthropic library version. + anthropic_rate_limit_exception = anthropic.RateLimitError + # Broader API errors, APIStatusError is more specific for HTTP status related issues. + # APIConnectionError for network problems. APIError as a general catch-all. + anthropic_api_error_exceptions = ( + anthropic.APIStatusError, # Catches errors with a status_code + anthropic.APIConnectionError, # Catches network-related issues + anthropic.APIError, # General base class for other Anthropic API errors + ) + + return generic_call_api_with_retries( + client_function=client_function, + api_params=api_params, + is_response_valid_fn=is_anthropic_response_valid, + rate_limit_exceptions=(anthropic_rate_limit_exception,), + api_error_exceptions=anthropic_api_error_exceptions, + get_status_code_fn=get_anthropic_status_code, + max_retries=max_retries, + # You can also pass initial_retry_delay_seconds and max_retry_delay_seconds + # if you want to customize them from their defaults in the generic function. + ) + + +def supports_tool_calling_for_openrouter( + model_name: str, +) -> bool: + """ + Check if the openrouter model supports tool calling. + + Args: + model_name (str): The name of the model. + + Returns: + bool: True if the model supports tool calling, False otherwise. + """ + import os + + import openai + + client = openai.Client( + api_key=os.getenv("OPENROUTER_API_KEY"), base_url="https://openrouter.ai/api/v1" + ) + try: + response = client.chat.completions.create( + model=model_name, + messages=[{"role": "user", "content": "Call the test tool"}], + tools=[ + { + "type": "function", + "function": { + "name": "dummy_tool", + "description": "Just a test tool", + "parameters": { + "type": "object", + "properties": {}, + }, + }, + } + ], + tool_choice="required", + ) + response = response.to_dict() + return "tool_calls" in response["choices"][0]["message"] + except Exception as e: + print(f"Skipping tool callign support check in openrouter for {model_name}: {e}") + return True + + def retry_multiple( chat: "ChatModel", messages: "Discussion", @@ -381,6 +709,17 @@ def image_to_jpg_base64_url(image: np.ndarray | Image.Image): return f"data:image/jpeg;base64,{image_base64}" +def image_to_png_base64_url(image: np.ndarray | Image.Image): + if isinstance(image, np.ndarray): + image = Image.fromarray(image) + if image.mode in ("RGBA", "LA"): + image = image.convert("RGB") + buffered = io.BytesIO() + image.save(buffered, "PNG") + image_base64 = base64.b64encode(buffered.getvalue()).decode() + return f"data:image/png;base64,{image_base64}" + + class BaseMessage(dict): def __init__(self, role: str, content: Union[str, list[dict]], **kwargs): allowed_attrs = {"log_probs"} @@ -401,7 +740,13 @@ def __str__(self, warn_if_image=False) -> str: else: logging.info(msg) - return "\n".join([elem["text"] for elem in self["content"] if elem["type"] == "text"]) + return "\n".join( + [ + elem["text"] + for elem in self["content"] + if elem["type"] == "text" or elem["type"] == "input_text" + ] + ) def add_content(self, type: str, content: Any): if isinstance(self["content"], str): diff --git a/src/agentlab/llm/response_api.py b/src/agentlab/llm/response_api.py new file mode 100644 index 00000000..755886f8 --- /dev/null +++ b/src/agentlab/llm/response_api.py @@ -0,0 +1,735 @@ +import json +import logging +import os +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Type, Union + +import openai +from anthropic import Anthropic +from openai import OpenAI + +from agentlab.llm.llm_utils import image_to_png_base64_url + +from .base_api import BaseModelArgs +from .llm_utils import ( + call_anthropic_api_with_retries, + call_openai_api_with_retries, +) +from .tracking import TrackAPIPricingMixin + +"""This module contains utlity classes for building input messages and interacting with LLM APIs. +It includes: + 1. Message Builder for building input messages + 2. Base Reponse class for different LLM APIs (OpenAI, Anthropic, etc.) + 3. Factory classes (inherits from BaseModelArgs) for creating instances of LLM Response models. +""" + + +ContentItem = Dict[str, Any] +Message = Dict[str, Union[str, List[ContentItem]]] + + +@dataclass +class LLMOutput: + """Serializable object for the output of a response LLM.""" + + raw_response: Any = field(default_factory=dict) + think: str = field(default="") + action: str = field(default=None) # Default action if no tool call is made + tool_calls: Any = field(default=None) # This will hold the tool call response if any + + +class MessageBuilder: + def __init__(self, role: str): + + self.role = role + self.last_raw_response: LLMOutput = None + self.content: List[ContentItem] = [] + self.tool_call_id: Optional[str] = None + + @classmethod + def system(cls) -> "MessageBuilder": + return cls("system") + + @classmethod + def user(cls) -> "MessageBuilder": + return cls("user") + + @classmethod + def assistant(cls) -> "MessageBuilder": + return cls("assistant") + + @classmethod + def tool(cls, last_raw_response) -> "MessageBuilder": + return cls("tool").update_last_raw_response(last_raw_response) + + @abstractmethod + def prepare_message(self) -> List[Message]: + """Prepare the message for the API call.""" + raise NotImplementedError("Subclasses must implement this method.") + + def update_last_raw_response(self, last_raw_response: Any) -> "MessageBuilder": + self.last_raw_response = last_raw_response + return self + + def add_text(self, text: str) -> "MessageBuilder": + self.content.append({"text": text}) + return self + + def add_image(self, image: str) -> "MessageBuilder": + self.content.append({"image": image}) + return self + + def to_markdown(self) -> str: + parts = [] + for item in self.content: + if "text" in item: + parts.append(f"\n```\n{item['text']}\n```\n") + elif "image" in item: + parts.append(f"![Image]({item['image']})") + + markdown = f"### {self.role.capitalize()}\n" + markdown += "\n".join(parts) + + return markdown + + def add_image_url(self, image_url: str) -> "MessageBuilder": + """Add an image URL to the message content.""" + self.content.append({"image": image_to_png_base64_url(image_url)}) + return self + + def mark_all_previous_msg_for_caching(self): + """Insert a cache breakpoint in the message content.""" + # This is a placeholder for future implementation. + raise NotImplementedError + + +# TODO: Support parallel tool calls. + + +class OpenAIResponseAPIMessageBuilder(MessageBuilder): + @classmethod + def system(cls) -> "OpenAIResponseAPIMessageBuilder": + # OpenAI Responses API uses 'developer' role for system messages + return cls("developer") + + def prepare_message(self) -> List[Message]: + content = [] + for item in self.content: + if "text" in item: + content_type = "input_text" if self.role != "assistant" else "output_text" + content.append({"type": content_type, "text": item["text"]}) + + elif "image" in item: + content.append({"type": "input_image", "image_url": item["image"]}) + + output = [{"role": self.role, "content": content}] + if self.role != "tool": + return output + else: + tool_call_response = self.handle_tool_call(content) + return tool_call_response + + def handle_tool_call(self, content): + """Handle the tool call response from the last raw response.""" + output = [] + head_content, *tail_content = content + api_response = self.last_raw_response + fn_calls = [content for content in api_response.output if content.type == "function_call"] + assert len(fn_calls) > 0, "No function calls found in the last response" + if len(fn_calls) > 1: + logging.warning("Using only the first tool call from many.") + + first_fn_call_id = fn_calls[0].call_id + fn_output = head_content.get("text", "Function call answer in next message") + fn_call_response = { + "type": "function_call_output", + "call_id": first_fn_call_id, + "output": fn_output, + } + output.append(fn_call_response) + if tail_content: + # if there are more content items, add them as a new user message + output.append({"role": "user", "content": tail_content}) + return output + + +class AnthropicAPIMessageBuilder(MessageBuilder): + + def prepare_message(self) -> List[Message]: + content = [self.transform_content(item) for item in self.content] + output = {"role": self.role, "content": content} + + if self.role == "system": + logging.info( + "Treating system message as 'user'. In the Anthropic API, system messages should be passed as a direct input to the client." + ) + output["role"] = "user" + + if self.role == "tool": + + api_response = self.last_raw_response + fn_calls = [content for content in api_response.content if content.type == "tool_use"] + assert len(fn_calls) > 0, "No tool calls found in the last response" + if len(fn_calls) > 1: + logging.warning("Using only the first tool call from many.") + tool_call_id = fn_calls[0].id # Using the first tool call ID + + output["role"] = "user" + output["content"] = [ + { + "type": "tool_result", + "tool_use_id": tool_call_id, + "content": output["content"], + } + ] + if self.role == "assistant": + # Strip whitespace from assistant text responses. See anthropic error code 400. + for c in output["content"]: + if "text" in c: + c["text"] = c["text"].strip() + return [output] + + def transform_content(self, content: ContentItem) -> ContentItem: + """Transform content item to the format expected by Anthropic API.""" + if "text" in content: + return {"type": "text", "text": content["text"]} + elif "image" in content: + img_str: str = content["image"] + # make sure to get rid of the image type for anthropic + # e.g. "data:image/png;base64" + if img_str.startswith("data:image/png;base64,"): + img_str = img_str[len("data:image/png;base64,") :] + return { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": img_str, + }, + } + else: + raise ValueError(f"Unsupported content type: {content}") + + def mark_all_previous_msg_for_caching(self) -> List[Message]: + """Insert a cache breakpoint in the message content to mark all previous messages for caching.""" + self._cache_breakpoint = True + + +class OpenAIChatCompletionAPIMessageBuilder(MessageBuilder): + + def prepare_message(self) -> List[Message]: + """Prepare the message for the OpenAI API.""" + content = [self.transform_content(item) for item in self.content] + if self.role == "tool": + return self.handle_tool_call(content) + else: + return [{"role": self.role, "content": content}] + + def transform_content(self, content: ContentItem) -> ContentItem: + """Transform content item to the format expected by OpenAI ChatCompletion.""" + if "text" in content: + return {"type": "text", "text": content["text"]} + elif "image" in content: + return {"type": "image_url", "image_url": {"url": content["image"]}} + else: + raise ValueError(f"Unsupported content type: {content}") + + def handle_tool_call(self, content) -> List[Message]: + """Handle the tool call response from the last raw response.""" + output = [] + content_head, *content_tail = content + api_response = self.last_raw_response.choices[0].message + fn_calls = getattr(api_response, "tool_calls", None) + assert fn_calls is not None, "Tool calls not found in the last response" + if len(fn_calls) > 1: + logging.warning("Using only the first tool call from many.") + + # a function_call_output dict has keys "role", "tool_call_id" and "content" + tool_call_reponse = { + "role": "tool", + "tool_call_id": fn_calls[0].id, # using the first tool call ID + "content": content_head.get("text", "Tool call answer in next message"), + "name": fn_calls[0].function.name, # required with OpenRouter + } + + output.append(tool_call_reponse) + if content_tail: + # if there are more content items, add them as a new user message + output.append({"role": "user", "content": content_tail}) + return output + + +# # Base class for all API Endpoints +class BaseResponseModel(ABC): + def __init__( + self, + model_name: str, + api_key: Optional[str] = None, + temperature: float = 0.5, + max_tokens: int = 100, + extra_kwargs: Optional[Dict[str, Any]] = None, + ): + self.model_name = model_name + self.api_key = api_key + self.temperature = temperature + self.max_tokens = max_tokens + self.extra_kwargs = extra_kwargs or {} + + super().__init__() + + def __call__(self, messages: list[dict | MessageBuilder], **kwargs) -> dict: + """Make a call to the model and return the parsed response.""" + response = self._call_api(messages, **kwargs) + return self._parse_response(response) + + @abstractmethod + def _call_api(self, messages: list[dict | MessageBuilder], **kwargs) -> Any: + """Make a call to the model API and return the raw response.""" + pass + + @abstractmethod + def _parse_response(self, response: Any) -> LLMOutput: + """Parse the raw response from the model API and return a structured response.""" + pass + + +class BaseModelWithPricing(TrackAPIPricingMixin, BaseResponseModel): + pass + + +class OpenAIResponseModel(BaseModelWithPricing): + def __init__( + self, + model_name: str, + api_key: Optional[str] = None, + temperature: float = 0.5, + max_tokens: int = 100, + extra_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ): + self.tools = kwargs.pop("tools", None) + self.tool_choice = kwargs.pop("tool_choice", None) + super().__init__( + model_name=model_name, + api_key=api_key, + temperature=temperature, + max_tokens=max_tokens, + extra_kwargs=extra_kwargs, + **kwargs, + ) + self.client = OpenAI(api_key=api_key) + + def _call_api(self, messages: list[Any | MessageBuilder], **kwargs) -> dict: + input = [] + for msg in messages: + input.extend(msg.prepare_message() if isinstance(msg, MessageBuilder) else [msg]) + + api_params: Dict[str, Any] = { + "model": self.model_name, + "input": input, + "temperature": self.temperature, + "max_output_tokens": self.max_tokens, + **self.extra_kwargs, + } + + if self.tools is not None: + api_params["tools"] = self.tools + if self.tool_choice is not None: + api_params["tool_choice"] = self.tool_choice + + # api_params |= kwargs # Merge any additional parameters passed + response = call_openai_api_with_retries( + self.client.responses.create, + api_params, + ) + + return response + + def _parse_response(self, response: dict) -> dict: + result = LLMOutput( + raw_response=response, + think="", + action=None, + tool_calls=None, + ) + interesting_keys = ["output_text"] + for output in response.output: + if output.type == "function_call": + arguments = json.loads(output.arguments) + func_args_str = ", ".join( + [ + f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}" + for k, v in arguments.items() + ] + ) + result.action = f"{output.name}({func_args_str})" + result.tool_calls = output + break + elif output.type == "reasoning": + if len(output.summary) > 0: + result.think += output.summary[0].text + "\n" + + elif output.type == "message" and output.content: + result.think += output.content[0].text + "\n" + for key in interesting_keys: + if key_content := getattr(output, "output_text", None) is not None: + result.think += f"<{key}>{key_content}" + return result + + +class OpenAIChatCompletionModel(BaseModelWithPricing): + def __init__( + self, + model_name: str, + client_args: Optional[Dict[str, Any]] = {}, + temperature: float = 0.5, + max_tokens: int = 100, + extra_kwargs: Optional[Dict[str, Any]] = None, + *args, + **kwargs, + ): + + self.tools = self.format_tools_for_chat_completion(kwargs.pop("tools", None)) + self.tool_choice = kwargs.pop("tool_choice", None) + + super().__init__( + model_name=model_name, + temperature=temperature, + max_tokens=max_tokens, + extra_kwargs=extra_kwargs, + *args, + **kwargs, + ) + + self.client = OpenAI( + **client_args + ) # Ensures client_args is a dict or defaults to an empty dict + + def _call_api(self, messages: list[dict | MessageBuilder]) -> openai.types.chat.ChatCompletion: + input = [] + for msg in messages: + input.extend(msg.prepare_message() if isinstance(msg, MessageBuilder) else [msg]) + api_params: Dict[str, Any] = { + "model": self.model_name, + "messages": input, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + **self.extra_kwargs, # Pass tools, tool_choice, etc. here + } + if self.tools is not None: + api_params["tools"] = self.tools + if self.tool_choice is not None: + api_params["tool_choice"] = self.tool_choice + + response = call_openai_api_with_retries(self.client.chat.completions.create, api_params) + + return response + + def _parse_response(self, response: openai.types.chat.ChatCompletion) -> LLMOutput: + + output = LLMOutput( + raw_response=response, + think="", + action=None, # Default if no tool call + tool_calls=None, + ) + message = response.choices[0].message.to_dict() + output.think = self.extract_content_with_reasoning(message) + + if tool_calls := message.get("tool_calls", None): + for tool_call in tool_calls: + function = tool_call["function"] + arguments = json.loads(function["arguments"]) + func_args_str = ", ".join( + [ + f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}" + for k, v in arguments.items() + ] + ) + output.action = f"{function['name']}({func_args_str})" + output.tool_calls = { + "role": "assistant", + "tool_calls": [message["tool_calls"][0]], # Use only the first tool call + } + break + return output + + @staticmethod + def format_tools_for_chat_completion(tools): + """Formats response tools format for OpenAI Chat Completion API. + + Why we need this? + Ans: actionset.to_tool_description() in bgym only returns description + format valid for OpenAI Response API. + + Args: + tools: List of tool descriptions to format for Chat Completion API. + + Returns: + Formatted tools list compatible with OpenAI Chat Completion API, or None if tools is None. + """ + formatted_tools = None + if tools is not None: + formatted_tools = [ + { + "type": tool["type"], + "function": {k: tool[k] for k in ("name", "description", "parameters")}, + } + for tool in tools + ] + return formatted_tools + + @staticmethod + def extract_content_with_reasoning(message, wrap_tag="think"): + """Extracts the content from the message, including reasoning if available. + It wraps the reasoning around ... for easy identification of reasoning content, + When LLM produces 'text' and 'reasoning' in the same message. + Note: The wrapping of 'thinking' content may not be nedeed and may be reconsidered. + + Args: + message: The message object or dict containing content and reasoning. + wrap_tag: The tag name to wrap reasoning content (default: "think"). + + Returns: + str: The extracted content with reasoning wrapped in specified tags. + """ + if not isinstance(message, dict): + message = message.to_dict() + + reasoning_content = message.get("reasoning", None) + msg_content = message.get("text", "") # works for OR + + if reasoning_content: + # Wrap reasoning in tags with newlines for clarity + reasoning_content = f"<{wrap_tag}>{reasoning_content}\n" + logging.debug("Extracting content from response.choices[i].message.reasoning") + else: + reasoning_content = "" + return f"{reasoning_content}{msg_content}{message.get('content', '')}" + + +class ClaudeResponseModel(BaseModelWithPricing): + def __init__( + self, + model_name: str, + api_key: Optional[str] = None, + temperature: float = 0.5, + max_tokens: int = 100, + extra_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ): + self.tools = kwargs.pop("tools", None) + self.tool_choice = kwargs.pop("tool_choice", None) + + super().__init__( + model_name=model_name, + api_key=api_key, + temperature=temperature, + max_tokens=max_tokens, + extra_kwargs=extra_kwargs, + **kwargs, + ) + + self.client = Anthropic(api_key=api_key) + + def _call_api( + self, messages: list[dict | MessageBuilder], tool_choice="auto", **kwargs + ) -> dict: + input = [] + + sys_msg, other_msgs = self.filter_system_messages(messages) + sys_msg_text = "\n".join(c["text"] for m in sys_msg for c in m.content) + for msg in other_msgs: + temp = msg.prepare_message() if isinstance(msg, MessageBuilder) else [msg] + if kwargs.pop("use_cache_breakpoints", False): + temp = self.apply_cache_breakpoints(msg, temp) + input.extend(temp) + + api_params: Dict[str, Any] = { + "model": self.model_name, + "messages": input, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + "system": sys_msg_text, # Anthropic API expects system message as a string + "tool_choice": {"type": tool_choice}, # Tool choice for Claude API + **self.extra_kwargs, # Pass tools, tool_choice, etc. here + } + if self.tools is not None: + api_params["tools"] = self.tools + if kwargs.pop("cache_tool_definition", False): + # Indicating cache control for the last tool enables caching of all previous tool definitions. + api_params["tools"][-1]["cache_control"] = {"type": "ephemeral"} + if kwargs.pop("cache_complete_prompt", False): + # Indicating cache control for the last message enables caching of the complete prompt. + api_params["messages"][-1]["content"][-1]["cache_control"] = {"type": "ephemeral"} + if self.extra_kwargs.get("reasoning", None) is not None: + api_params["reasoning"] = self.extra_kwargs["reasoning"] + + response = call_anthropic_api_with_retries(self.client.messages.create, api_params) + + return response + + @staticmethod + def filter_system_messages(messages: list[dict | MessageBuilder]) -> tuple[MessageBuilder]: + """Filter system messages from the list of messages.""" + # System message cannot have an image in the middle of the text sequences. + # Images can be appended in the end of the system message. + + sys_msgs, other_msgs = [], [] + for msg in messages: + if isinstance(msg, MessageBuilder) and msg.role == "system": + sys_msgs.append(msg) + for c in msg.content: + if c.get("type") == "image": + raise TypeError("System messages cannot contain images.") + else: + other_msgs.append(msg) + return sys_msgs, other_msgs + + def _parse_response(self, response: dict) -> dict: + result = LLMOutput( + raw_response=response, + think="", + action=None, + tool_calls={ + "role": "assistant", + "content": response.content, + }, + ) + for output in response.content: + if output.type == "tool_use": + func_args_str = ", ".join( + [ + f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}" + for k, v in output.input.items() + ] + ) + result.action = f"{output.name}({func_args_str})" + elif output.type == "text": + result.think += output.text + return result + + # def ensure_cache_conditions(self, msgs: List[Message]) -> bool: + # """Ensure API specific cache conditions are met.""" + # assert sum(getattr(msg, "_cache_breakpoint", 0) for msg in msgs) <= 4, "Too many cache breakpoints in the message." + + def apply_cache_breakpoints(self, msg: Message, prepared_msg: dict) -> List[Message]: + """Apply cache breakpoints to the messages.""" + if getattr(msg, "_cache_breakpoint", False): + prepared_msg[-1]["content"][-1]["cache_control"] = {"type": "ephemeral"} + return prepared_msg + + +# Factory classes to create the appropriate model based on the API endpoint. +@dataclass +class OpenAIResponseModelArgs(BaseModelArgs): + """Serializable object for instantiating a generic chat model with an OpenAI + model.""" + + api = "openai" + + def make_model(self, extra_kwargs=None, **kwargs): + return OpenAIResponseModel( + model_name=self.model_name, + temperature=self.temperature, + max_tokens=self.max_new_tokens, + extra_kwargs=extra_kwargs, + pricing_api="openai", + **kwargs, + ) + + def get_message_builder(self) -> MessageBuilder: + return OpenAIResponseAPIMessageBuilder + + +@dataclass +class ClaudeResponseModelArgs(BaseModelArgs): + """Serializable object for instantiating a generic chat model with an OpenAI + model.""" + + api = "anthropic" + + def make_model(self, extra_kwargs=None, **kwargs): + return ClaudeResponseModel( + model_name=self.model_name, + temperature=self.temperature, + max_tokens=self.max_new_tokens, + extra_kwargs=extra_kwargs, + pricing_api="anthropic", + **kwargs, + ) + + def get_message_builder(self) -> MessageBuilder: + return AnthropicAPIMessageBuilder + + +@dataclass +class OpenAIChatModelArgs(BaseModelArgs): + """Serializable object for instantiating a generic chat model with an OpenAI + model.""" + + api = "openai" + + def make_model(self, extra_kwargs=None, **kwargs): + return OpenAIChatCompletionModel( + model_name=self.model_name, + temperature=self.temperature, + max_tokens=self.max_new_tokens, + extra_kwargs=extra_kwargs, + pricing_api="openai", + **kwargs, + ) + + def get_message_builder(self) -> MessageBuilder: + return OpenAIChatCompletionAPIMessageBuilder + + +@dataclass +class OpenRouterModelArgs(BaseModelArgs): + """Serializable object for instantiating a generic chat model with an OpenRouter + model.""" + + api: str = "openai" # tool description format used by actionset.to_tool_description() in bgym + + def make_model(self, extra_kwargs=None, **kwargs): + return OpenAIChatCompletionModel( + client_args={ + "base_url": "https://openrouter.ai/api/v1", + "api_key": os.getenv("OPENROUTER_API_KEY"), + }, + model_name=self.model_name, + temperature=self.temperature, + max_tokens=self.max_new_tokens, + extra_kwargs=extra_kwargs, + pricing_api="openrouter", + **kwargs, + ) + + def get_message_builder(self) -> MessageBuilder: + return OpenAIChatCompletionAPIMessageBuilder + + +class VLLMModelArgs(BaseModelArgs): + """Serializable object for instantiating a generic chat model with a VLLM + model.""" + + api = "openai" # tool description format used by actionset.to_tool_description() in bgym + + def make_model(self, extra_kwargs=None, **kwargs): + return OpenAIChatCompletionModel( + client_args={ + "base_url": "http://localhost:8000/v1", + "api_key": os.getenv("VLLM_API_KEY", "EMPTY"), + }, + model_name=self.model_name, # this needs to be set + temperature=self.temperature, + max_tokens=self.max_new_tokens, + extra_kwargs=extra_kwargs, + pricing_api="vllm", + **kwargs, + ) + + def get_message_builder(self) -> MessageBuilder: + return OpenAIChatCompletionAPIMessageBuilder diff --git a/src/agentlab/llm/tracking.py b/src/agentlab/llm/tracking.py index 6a08839b..ad846a71 100644 --- a/src/agentlab/llm/tracking.py +++ b/src/agentlab/llm/tracking.py @@ -1,13 +1,28 @@ +import logging import os +import re import threading +from collections import defaultdict from contextlib import contextmanager +from dataclasses import dataclass, field from functools import cache +from typing import Optional import requests -from langchain_community.callbacks.openai_info import MODEL_COST_PER_1K_TOKENS +from langchain_community.callbacks import bedrock_anthropic_callback, openai_info TRACKER = threading.local() +ANTHROPHIC_CACHE_PRICING_FACTOR = { + "cache_read_tokens": 0.1, # Cost for 5 min ephemeral cache. See Pricing Here: https://docs.anthropic.com/en/docs/about-claude/pricing#model-pricing + "cache_write_tokens": 1.25, +} + +OPENAI_CACHE_PRICING_FACTOR = { + "cache_read_tokens": 0.5, # This is a an upper bound. See Pricing Here: https://platform.openai.com/docs/pricing + "cache_write_tokens": 1, +} + class LLMTracker: def __init__(self, suffix=""): @@ -67,6 +82,7 @@ def wrapper(self, obs): @cache def get_pricing_openrouter(): + """Returns a dictionary of model pricing for OpenRouter models.""" api_key = os.getenv("OPENROUTER_API_KEY") assert api_key, "OpenRouter API key is required" # query api to get model metadata @@ -85,7 +101,8 @@ def get_pricing_openrouter(): def get_pricing_openai(): - cost_dict = MODEL_COST_PER_1K_TOKENS + """Returns a dictionary of model pricing for OpenAI models.""" + cost_dict = openai_info.MODEL_COST_PER_1K_TOKENS cost_dict = {k: v / 1000 for k, v in cost_dict.items()} res = {} for k in cost_dict: @@ -99,3 +116,213 @@ def get_pricing_openai(): "completion": cost_dict[completion_key], } return res + + +def _remove_version_suffix(model_name): + no_version = re.sub(r"-v\d+(?:[.:]\d+)?$", "", model_name) + return re.sub(r"anthropic.", "", no_version) + + +def get_pricing_anthropic(): + """Returns a dictionary of model pricing for Anthropic models.""" + input_cost_dict = bedrock_anthropic_callback.MODEL_COST_PER_1K_INPUT_TOKENS + output_cost_dict = bedrock_anthropic_callback.MODEL_COST_PER_1K_OUTPUT_TOKENS + + res = {} + for k, v in input_cost_dict.items(): + k = _remove_version_suffix(k) + res[k] = {"prompt": v / 1000} + + for k, v in output_cost_dict.items(): + k = _remove_version_suffix(k) + if k not in res: + res[k] = {} + res[k]["completion"] = v / 1000 + return res + + +class TrackAPIPricingMixin: + """Mixin class to handle pricing information for different models. + This populates the tracker.stats used by the cost_tracker_decorator + + Usage: provide the pricing_api to use in the constructor. + """ + + def reset_stats(self): + self.stats = Stats() + + def __init__(self, *args, **kwargs): + pricing_api = kwargs.pop("pricing_api", None) + self._pricing_api = pricing_api + super().__init__(*args, **kwargs) + self.set_pricing_attributes() + self.reset_stats() + + def __call__(self, *args, **kwargs): + """Call the API and update the pricing tracker.""" + response = self._call_api(*args, **kwargs) + + usage = dict(getattr(response, "usage", {})) + usage = {f"usage_{k}": v for k, v in usage.items() if isinstance(v, (int, float))} + usage |= {"n_api_calls": 1} + usage |= {"effective_cost": self.get_effective_cost(response)} + self.stats.increment_stats_dict(usage) + self.update_pricing_tracker(response) + return self._parse_response(response) + + def fetch_pricing_information_from_provider(self) -> Optional[dict]: + """ + Fetch the pricing information dictionary for the given provider. + + Returns: + Optional[dict]: A dict mapping model names to pricing info, or None if not found. + """ + pricing_fn_map = { + "openai": get_pricing_openai, + "anthropic": get_pricing_anthropic, + "openrouter": get_pricing_openrouter, + } + pricing_fn = pricing_fn_map.get(self._pricing_api, None) + if pricing_fn is None: + logging.warning( + f"Unsupported provider: {self._pricing_api}. Supported providers are: {list(pricing_fn_map.keys())}" + ) + return None + return pricing_fn() + + def set_pricing_attributes(self) -> None: + """Set the pricing attributes for the model based on the provider.""" + model_to_price_dict = self.fetch_pricing_information_from_provider() + model_costs = model_to_price_dict.get(self.model_name) if model_to_price_dict else None + if model_costs: + self.input_cost = float(model_costs["prompt"]) + self.output_cost = float(model_costs["completion"]) + else: + logging.warning(f"Model {self.model_name} not found in the pricing information.") + self.input_cost = 0.0 + self.output_cost = 0.0 + + def update_pricing_tracker(self, raw_response) -> None: + """Update the pricing tracker with the input and output tokens and cost.""" + + input_tokens, output_tokens = self.get_tokens_counts_from_response(raw_response) + cost = input_tokens * self.input_cost + output_tokens * self.output_cost + + if hasattr(TRACKER, "instance") and isinstance(TRACKER.instance, LLMTracker): + TRACKER.instance(input_tokens, output_tokens, cost) + + def get_tokens_counts_from_response(self, response) -> tuple: + """Get the input and output tokens counts from the response, provider-agnostic.""" + # Try OpenAI/Anthropic style + usage = getattr(response, "usage", None) + if usage: + input_tokens = getattr(usage, "input_tokens", None) or getattr( + usage, "prompt_tokens", None + ) + output_tokens = getattr(usage, "output_tokens", None) or getattr( + usage, "completion_tokens", None + ) + if input_tokens is not None and output_tokens is not None: + return input_tokens, output_tokens + + # Try dict style + if isinstance(response, dict) and "usage" in response: + usage = response["usage"] + input_tokens = usage.get("input_tokens") or usage.get("prompt_tokens") + output_tokens = usage.get("output_tokens") or usage.get("completion_tokens") + if input_tokens is not None and output_tokens is not None: + return input_tokens, output_tokens + + logging.warning( + "Unable to extract input and output tokens from the response. Defaulting to 0." + ) + return 0, 0 + + def get_effective_cost(self, response): + """Get the effective cost from the response based on the provider.""" + if self._pricing_api == "anthropic": + return self.get_effective_cost_from_antrophic_api(response) + elif self._pricing_api == "openai": + return self.get_effective_cost_from_openai_api(response) + else: + logging.warning( + f"Unsupported provider: {self._pricing_api}. No effective cost calculated." + ) + return 0.0 + + def get_effective_cost_from_antrophic_api(self, response) -> float: + """ + Get the effective cost from the Anthropic API response. + + Anthropic usage 'input_tokens' are new input tokens (tokens that are not cached). + Anthropic has different pricing for cache write and cache read tokens. + See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#tracking-cache-performance + + Args: + response: The response object from the Anthropic API. + + Returns: + float: The effective cost calculated from the response. + """ + usage = getattr(response, "usage", {}) + new_input_tokens = getattr(usage, "input_tokens", 0) # new input tokens + output_tokens = getattr(usage, "output_tokens", 0) + cache_read_tokens = getattr(usage, "cache_input_tokens", 0) + cache_write_tokens = getattr(usage, "cache_creation_input_tokens", 0) + + cache_read_cost = self.input_cost * ANTHROPHIC_CACHE_PRICING_FACTOR["cache_read_tokens"] + cache_write_cost = self.input_cost * ANTHROPHIC_CACHE_PRICING_FACTOR["cache_write_tokens"] + + # Calculate the effective cost + effective_cost = ( + new_input_tokens * self.input_cost + + output_tokens * self.output_cost + + cache_read_tokens * cache_read_cost + + cache_write_tokens * cache_write_cost + ) + return effective_cost + + def get_effective_cost_from_openai_api(self, response) -> float: + """ + Get the effective cost from the OpenAI API response. + + OpenAI usage 'prompt_tokens' are the total input tokens (cache read tokens + new input tokens). + See https://openai.com/index/api-prompt-caching/ + OpenAI has only one price for cache tokens, i.e., cache read price (generally 50% cheaper). + OpenAI has no extra charge for cache write tokens. + See Pricing Here: https://platform.openai.com/docs/pricing + + Args: + response: The response object from the OpenAI API. + + Returns: + float: The effective cost calculated from the response. + """ + usage = getattr(response, "usage", {}) + prompt_token_details = getattr(response, "prompt_tokens_details", {}) + + total_input_tokens = getattr( + prompt_token_details, "prompt_tokens", 0 + ) # Cache read tokens + new input tokens + output_tokens = getattr(usage, "completion_tokens", 0) + cache_read_tokens = getattr(prompt_token_details, "cached_tokens", 0) + + non_cached_input_tokens = total_input_tokens - cache_read_tokens + cache_read_cost = self.input_cost * OPENAI_CACHE_PRICING_FACTOR["cache_read_tokens"] + + effective_cost = ( + self.input_cost * non_cached_input_tokens + + cache_read_tokens * cache_read_cost + + self.output_cost * output_tokens + ) + return effective_cost + + +@dataclass +class Stats: + stats_dict: dict = field(default_factory=lambda: defaultdict(float)) + + def increment_stats_dict(self, stats_dict: dict): + """increment the stats_dict with the given values.""" + for k, v in stats_dict.items(): + self.stats_dict[k] += v diff --git a/tests/agents/test_generic_prompt.py b/tests/agents/test_generic_prompt.py index cc1f9036..5e89799a 100644 --- a/tests/agents/test_generic_prompt.py +++ b/tests/agents/test_generic_prompt.py @@ -2,13 +2,11 @@ import bgym import pytest +from bgym import HighLevelActionSet, HighLevelActionSetArgs from agentlab.agents import dynamic_prompting as dp from agentlab.agents.generic_agent.agent_configs import FLAGS_GPT_3_5 -from agentlab.agents.generic_agent.generic_agent_prompt import ( - GenericPromptFlags, - MainPrompt, -) +from agentlab.agents.generic_agent.generic_agent_prompt import GenericPromptFlags, MainPrompt from agentlab.llm.llm_utils import count_tokens html_template = """ @@ -76,7 +74,7 @@ filter_visible_elements_only=True, ), action=dp.ActionFlags( - action_set=bgym.HighLevelActionSetArgs( + action_set=HighLevelActionSetArgs( subsets=["bid"], multiaction=True, ), @@ -171,7 +169,7 @@ def test_shrinking_observation(): flags.obs.use_html = True prompt_maker = MainPrompt( - action_set=bgym.HighLevelActionSet(), + action_set=HighLevelActionSet(), obs_history=OBS_HISTORY, actions=ACTIONS, memories=MEMORIES, @@ -237,7 +235,7 @@ def test_main_prompt_elements_present(): # Initialize MainPrompt prompt = str( MainPrompt( - action_set=bgym.HighLevelActionSet(), + action_set=HighLevelActionSet(), obs_history=OBS_HISTORY, actions=ACTIONS, memories=MEMORIES, diff --git a/tests/experiments/test_reproducibility_util.py b/tests/experiments/test_reproducibility_util.py index aa10ff47..e5e771b3 100644 --- a/tests/experiments/test_reproducibility_util.py +++ b/tests/experiments/test_reproducibility_util.py @@ -5,6 +5,7 @@ import bgym import pytest +from bgym import DEFAULT_BENCHMARKS from agentlab.agents.generic_agent import AGENT_4o_MINI from agentlab.analyze import inspect_results @@ -17,7 +18,7 @@ ) def test_get_reproducibility_info(benchmark_name): - benchmark = bgym.DEFAULT_BENCHMARKS[benchmark_name]() + benchmark = DEFAULT_BENCHMARKS[benchmark_name]() info = reproducibility_util.get_reproducibility_info( "test_agent", benchmark, "test_id", ignore_changes=True diff --git a/tests/llm/test_response_api.py b/tests/llm/test_response_api.py new file mode 100644 index 00000000..16316a92 --- /dev/null +++ b/tests/llm/test_response_api.py @@ -0,0 +1,718 @@ +import os +from typing import Any, Dict, List, Optional +from unittest.mock import MagicMock, patch + +import anthropic +import openai +import pytest + +from agentlab.llm import tracking +from agentlab.llm.response_api import ( + AnthropicAPIMessageBuilder, + ClaudeResponseModelArgs, + LLMOutput, + OpenAIChatCompletionAPIMessageBuilder, + OpenAIChatModelArgs, + OpenAIResponseAPIMessageBuilder, + OpenAIResponseModelArgs, +) + + +# Helper to create a mock OpenAI ChatCompletion response +def create_mock_openai_chat_completion( + content=None, tool_calls=None, prompt_tokens=10, completion_tokens=20 +): + completion = MagicMock(spec=openai.types.chat.ChatCompletion) + choice = MagicMock() + message = MagicMock(spec=openai.types.chat.ChatCompletionMessage) + message.content = content + message.tool_calls = None + if tool_calls: + message.tool_calls = [] + for tc in tool_calls: + tool_call_mock = MagicMock( + spec=openai.types.chat.chat_completion_message_tool_call.ChatCompletionMessageToolCall + ) + tool_call_mock.id = tc["id"] + tool_call_mock.type = tc["type"] + tool_call_mock.function = MagicMock( + spec=openai.types.chat.chat_completion_message_tool_call.Function + ) + tool_call_mock.function.name = tc["function"]["name"] + tool_call_mock.function.arguments = tc["function"]["arguments"] + message.tool_calls.append(tool_call_mock) + + choice.message = message + completion.choices = [choice] + + completion.usage = MagicMock() + # Explicitly set the attributes that get_tokens_counts_from_response will try first. + # These are the generic names. + completion.usage.input_tokens = prompt_tokens + completion.usage.output_tokens = completion_tokens + + # Also set the OpenAI-specific names if any other part of the code might look for them directly, + # or if get_tokens_counts_from_response had different fallback logic. + completion.usage.prompt_tokens = prompt_tokens + completion.usage.completion_tokens = completion_tokens + + completion.model_dump.return_value = { + "id": "chatcmpl-xxxx", + "choices": [ + {"message": {"role": "assistant", "content": content, "tool_calls": tool_calls}} + ], + # Ensure the usage dict in model_dump also reflects the token counts accurately. + # The get_tokens_counts_from_response also has a path for dict style. + "usage": { + "input_tokens": prompt_tokens, # Generic name + "output_tokens": completion_tokens, # Generic name + "prompt_tokens": prompt_tokens, # OpenAI specific + "completion_tokens": completion_tokens, # OpenAI specific + }, + } + message.to_dict.return_value = { + "role": "assistant", + "content": content, + "tool_calls": tool_calls, + } + return completion + + +# Helper to create a mock Anthropic response +def create_mock_anthropic_response( + text_content=None, tool_use=None, input_tokens=15, output_tokens=25 +): + + response = MagicMock(spec=anthropic.types.Message) + response.type = "message" # Explicitly set the type attribute + response.content = [] + response.content = [] + if text_content: + text_block = MagicMock(spec=anthropic.types.TextBlock) + text_block.type = "text" + text_block.text = text_content + response.content.append(text_block) + if tool_use: + tool_use_block = MagicMock(spec=anthropic.types.ToolUseBlock) + tool_use_block.type = "tool_use" + tool_use_block.id = tool_use["id"] + tool_use_block.name = tool_use["name"] + tool_use_block.input = tool_use["input"] + response.content.append(tool_use_block) + response.usage = MagicMock() + response.usage.input_tokens = input_tokens + response.usage.output_tokens = output_tokens + return response + + +def create_mock_openai_responses_api_response( + outputs: Optional[List[Dict[str, Any]]] = None, input_tokens: int = 10, output_tokens: int = 20 +) -> MagicMock: + """ + Helper to create a mock response object similar to what + openai.resources.Responses.create() would return. + Compatible with OpenAIResponseModel and TrackAPIPricingMixin. + """ + + response_mock = MagicMock(openai.types.responses.response) + response_mock.type = "response" + response_mock.output = [] + + if outputs: + for out_data in outputs: + output_item_mock = MagicMock() + output_item_mock.type = out_data.get("type") + + if output_item_mock.type == "function_call": + # You can adapt this depending on your expected object structure + output_item_mock.name = out_data.get("name") + output_item_mock.arguments = out_data.get("arguments") + output_item_mock.call_id = out_data.get("call_id") + elif output_item_mock.type == "reasoning": + output_item_mock.summary = [] + for text_content in out_data.get("summary", []): + summary_text_mock = MagicMock() + summary_text_mock.text = text_content + output_item_mock.summary.append(summary_text_mock) + + response_mock.output.append(output_item_mock) + + # Token usage for pricing tracking + response_mock.usage = MagicMock() + response_mock.usage.input_tokens = input_tokens + response_mock.usage.output_tokens = output_tokens + response_mock.usage.prompt_tokens = input_tokens + response_mock.usage.completion_tokens = output_tokens + + return response_mock + + +# --- Test MessageBuilders --- + + +def test_openai_response_api_message_builder_text(): + builder = OpenAIResponseAPIMessageBuilder.user() + builder.add_text("Hello, world!") + messages = builder.prepare_message() + assert len(messages) == 1 + assert messages[0]["role"] == "user" + assert messages[0]["content"] == [{"type": "input_text", "text": "Hello, world!"}] + + +def test_openai_response_api_message_builder_image(): + builder = OpenAIResponseAPIMessageBuilder.user() + builder.add_image("") + messages = builder.prepare_message() + assert len(messages) == 1 + assert messages[0]["role"] == "user" + assert messages[0]["content"] == [ + {"type": "input_image", "image_url": ""} + ] + + +def test_anthropic_api_message_builder_text(): + builder = AnthropicAPIMessageBuilder.user() + builder.add_text("Hello, Anthropic!") + messages = builder.prepare_message() + assert len(messages) == 1 + assert messages[0]["role"] == "user" + assert messages[0]["content"] == [{"type": "text", "text": "Hello, Anthropic!"}] + + +def test_anthropic_api_message_builder_image(): + builder = AnthropicAPIMessageBuilder.user() + builder.add_image("") + messages = builder.prepare_message() + assert len(messages) == 1 + assert messages[0]["role"] == "user" + assert len(messages[0]["content"]) == 1 + image_content = messages[0]["content"][0] + assert image_content["type"] == "image" + assert image_content["source"]["type"] == "base64" + assert image_content["source"]["media_type"] == "image/png" + assert image_content["source"]["data"] == "ANTHROPICBASE64" # Base64 prefix should be stripped + + +def test_openai_chat_completion_api_message_builder_text(): + builder = OpenAIChatCompletionAPIMessageBuilder.user() + builder.add_text("Hello, ChatCompletion!") + # Mock last_response as it's used by tool role + builder.last_raw_response = MagicMock(spec=LLMOutput) + builder.last_raw_response.raw_response = MagicMock() + builder.last_raw_response.raw_response.choices = [MagicMock()] + builder.last_raw_response.raw_response.choices[0].message.to_dict.return_value = { + "tool_calls": [{"function": {"name": "some_function"}}] + } + messages = builder.prepare_message() + + assert len(messages) == 1 + assert messages[0]["role"] == "user" + assert messages[0]["content"] == [{"type": "text", "text": "Hello, ChatCompletion!"}] + + +def test_openai_chat_completion_api_message_builder_image(): + builder = OpenAIChatCompletionAPIMessageBuilder.user() + builder.add_image("") + # Mock last_response + builder.last_raw_response = MagicMock(spec=LLMOutput) + builder.last_raw_response.raw_response = MagicMock() + builder.last_raw_response.raw_response.choices = [MagicMock()] + builder.last_raw_response.raw_response.choices[0].message.to_dict.return_value = { + "tool_calls": [{"function": {"name": "some_function"}}] + } + messages = builder.prepare_message() + + assert len(messages) == 1 + assert messages[0]["role"] == "user" + assert messages[0]["content"] == [ + {"type": "image_url", "image_url": {"url": ""}} + ] + + +def test_openai_chat_completion_model_parse_and_cost(): + args = OpenAIChatModelArgs(model_name="gpt-3.5-turbo") # A cheap model for testing + # Mock the OpenAI client to avoid needing OPENAI_API_KEY + with patch("agentlab.llm.response_api.OpenAI") as mock_openai_class: + mock_client = MagicMock() + mock_openai_class.return_value = mock_client + model = args.make_model() + + # Mock the API call + mock_response = create_mock_openai_chat_completion( + content="This is a test thought.", + tool_calls=[ + { + "id": "call_123", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"location": "Paris"}'}, + } + ], + prompt_tokens=50, + completion_tokens=30, + ) + + with patch.object( + model.client.chat.completions, "create", return_value=mock_response + ) as mock_create: + with tracking.set_tracker() as global_tracker: # Use your global tracker + messages = [ + OpenAIChatCompletionAPIMessageBuilder.user() + .add_text("What's the weather in Paris?") + .prepare_message()[0] + ] + parsed_output = model(messages) + + mock_create.assert_called_once() + assert parsed_output.raw_response.choices[0].message.content == "This is a test thought." + assert parsed_output.action == 'get_weather(location="Paris")' + assert parsed_output.raw_response.choices[0].message.tool_calls[0].id == "call_123" + # Check cost tracking (token counts) + assert global_tracker.stats["input_tokens"] == 50 + assert global_tracker.stats["output_tokens"] == 30 + assert global_tracker.stats["cost"] > 0 + + +def test_claude_response_model_parse_and_cost(): + args = ClaudeResponseModelArgs(model_name="claude-3-haiku-20240307") # A cheap model + model = args.make_model() + + mock_anthropic_api_response = create_mock_anthropic_response( + text_content="Thinking about the request.", + tool_use={"id": "tool_abc", "name": "search_web", "input": {"query": "latest news"}}, + input_tokens=40, + output_tokens=20, + ) + + with patch.object( + model.client.messages, "create", return_value=mock_anthropic_api_response + ) as mock_create: + with tracking.set_tracker() as global_tracker: + messages = [ + AnthropicAPIMessageBuilder.user() + .add_text("Search for latest news") + .prepare_message()[0] + ] + parsed_output = model(messages) + + mock_create.assert_called_once() + fn_calls = [ + content for content in parsed_output.raw_response.content if content.type == "tool_use" + ] + assert "Thinking about the request." in parsed_output.think + assert parsed_output.action == 'search_web(query="latest news")' + assert fn_calls[0].id == "tool_abc" + assert global_tracker.stats["input_tokens"] == 40 + assert global_tracker.stats["output_tokens"] == 20 + # assert global_tracker.stats["cost"] > 0 # Verify cost is calculated + + +def test_openai_response_model_parse_and_cost(): + """ + Tests OpenAIResponseModel output parsing and cost tracking with both + function_call and reasoning outputs. + """ + args = OpenAIResponseModelArgs(model_name="gpt-4.1") + + # Mock outputs + mock_function_call_output = { + "type": "function_call", + "name": "get_current_weather", + "arguments": '{"location": "Boston, MA", "unit": "celsius"}', + "call_id": "call_abc123", + } + + mock_api_resp = create_mock_openai_responses_api_response( + outputs=[mock_function_call_output], + input_tokens=70, + output_tokens=40, + ) + + # Mock the OpenAI client to avoid needing OPENAI_API_KEY + with patch("agentlab.llm.response_api.OpenAI") as mock_openai_class: + mock_client = MagicMock() + mock_openai_class.return_value = mock_client + model = args.make_model() + + with patch.object( + model.client.responses, "create", return_value=mock_api_resp + ) as mock_create_method: + with tracking.set_tracker() as global_tracker: + messages = [ + OpenAIResponseAPIMessageBuilder.user() + .add_text("What's the weather in Boston?") + .prepare_message()[0] + ] + parsed_output = model(messages) + + mock_create_method.assert_called_once() + fn_calls = [ + content for content in parsed_output.raw_response.output if content.type == "function_call" + ] + assert parsed_output.action == 'get_current_weather(location="Boston, MA", unit="celsius")' + assert fn_calls[0].call_id == "call_abc123" + assert parsed_output.raw_response == mock_api_resp + assert global_tracker.stats["input_tokens"] == 70 + assert global_tracker.stats["output_tokens"] == 40 + + +# --- Test Response Models (Pricy - require API keys and actual calls) --- + + +@pytest.mark.pricy +@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") +def test_openai_chat_completion_model_pricy_call(): + """Tests OpenAIChatCompletionModel with a real API call.""" + args = OpenAIChatModelArgs( + model_name="gpt-4.1", + temperature=1e-5, + max_new_tokens=100, + ) + + tools = [ + { + "type": "function", + "name": "get_weather", + "description": "Get the current weather in a given location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the weather for.", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The unit of temperature.", + }, + }, + "required": ["location"], + }, + } + ] + + model = args.make_model(tools=tools, tool_choice="required") + + with tracking.set_tracker() as global_tracker: + messages = [ + OpenAIChatCompletionAPIMessageBuilder.user() + .add_text("What is the weather in Paris?") + .prepare_message()[0] + ] + parsed_output = model(messages) + + assert parsed_output.raw_response is not None + assert ( + parsed_output.action == 'get_weather(location="Paris")' + ), f""" Expected get_weather(location="Paris") but got {parsed_output.action}""" + assert global_tracker.stats["input_tokens"] > 0 + assert global_tracker.stats["output_tokens"] > 0 + assert global_tracker.stats["cost"] > 0 + + +@pytest.mark.pricy +@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="ANTHROPIC_API_KEY not set") +def test_claude_response_model_pricy_call(): + """Tests ClaudeResponseModel with a real API call.""" + + args = ClaudeResponseModelArgs( + model_name="claude-3-haiku-20240307", + temperature=1e-5, + max_new_tokens=100, + ) + tools = [ + { + "name": "get_weather", + "description": "Get the current weather in a given location.", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the weather for.", + }, + }, + "required": ["location"], + }, + } + ] + model = args.make_model(tools=tools) + + with tracking.set_tracker() as global_tracker: + messages = [ + AnthropicAPIMessageBuilder.user() + .add_text("What is the weather in Paris?") + .prepare_message()[0] + ] + parsed_output = model(messages) + + assert parsed_output.raw_response is not None + assert ( + parsed_output.action == 'get_weather(location="Paris")' + ), f'Expected get_weather("Paris") but got {parsed_output.action}' + assert global_tracker.stats["input_tokens"] > 0 + assert global_tracker.stats["output_tokens"] > 0 + assert global_tracker.stats["cost"] > 0 + + +@pytest.mark.pricy +@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") +def test_openai_response_model_pricy_call(): + """ + Tests OpenAIResponseModel output parsing and cost tracking with both + function_call and reasoning outputs. + """ + args = OpenAIResponseModelArgs(model_name="gpt-4.1", temperature=1e-5, max_new_tokens=100) + + tools = [ + { + "type": "function", + "name": "get_weather", + "description": "Get the current weather in a given location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the weather for.", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The unit of temperature.", + }, + }, + "required": ["location"], + }, + } + ] + model = args.make_model(tools=tools) + + with tracking.set_tracker() as global_tracker: + messages = [ + OpenAIResponseAPIMessageBuilder.user() + .add_text("What is the weather in Paris?") + .prepare_message()[0] + ] + parsed_output = model(messages) + + assert parsed_output.raw_response is not None + assert ( + parsed_output.action == """get_weather(location="Paris")""" + ), f""" Expected get_weather(location="Paris") but got {parsed_output.action}""" + assert global_tracker.stats["input_tokens"] > 0 + assert global_tracker.stats["output_tokens"] > 0 + assert global_tracker.stats["cost"] > 0 + + +@pytest.mark.pricy +@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") +def test_openai_response_model_with_multiple_messages_and_cost_tracking(): + """ + Test OpenAIResponseModel's output parsing and cost tracking + with a tool-using assistant and follow-up interaction. + """ + args = OpenAIResponseModelArgs(model_name="gpt-4.1", temperature=1e-5, max_new_tokens=100) + + tools = [ + { + "type": "function", + "name": "get_weather", + "description": "Get the current weather in a given location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the weather for.", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The unit of temperature.", + }, + }, + "required": ["location"], + }, + } + ] + + model = args.make_model(tools=tools, tool_choice="required") + builder = args.get_message_builder() + + messages = [builder.user().add_text("What is the weather in Paris?")] + + with tracking.set_tracker() as tracker: + # First turn: get initial tool call + parsed = model(messages) + prev_input = tracker.stats["input_tokens"] + prev_output = tracker.stats["output_tokens"] + prev_cost = tracker.stats["cost"] + + # Simulate tool execution and user follow-up + messages += [ + parsed.tool_calls, # Add tool call from the model + builder.tool(parsed.raw_response).add_text("Its sunny! 25°C"), + builder.user().add_text("What is the weather in Delhi?"), + ] + + parsed = model(messages) + + # Token and cost deltas + delta_input = tracker.stats["input_tokens"] - prev_input + delta_output = tracker.stats["output_tokens"] - prev_output + delta_cost = tracker.stats["cost"] - prev_cost + + # Assertions + assert prev_input > 0 + assert prev_output > 0 + assert prev_cost > 0 + assert parsed.raw_response is not None + assert parsed.action == 'get_weather(location="Delhi")', f"Unexpected action: {parsed.action}" + assert delta_input > 0 + assert delta_output > 0 + assert delta_cost > 0 + assert tracker.stats["input_tokens"] == prev_input + delta_input + assert tracker.stats["output_tokens"] == prev_output + delta_output + assert tracker.stats["cost"] == pytest.approx(prev_cost + delta_cost) + + +@pytest.mark.pricy +@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") +def test_openai_chat_completion_model_with_multiple_messages_and_cost_tracking(): + """ + Test OpenAIResponseModel's output parsing and cost tracking + with a tool-using assistant and follow-up interaction. + """ + args = OpenAIChatModelArgs(model_name="gpt-4.1", temperature=1e-5, max_new_tokens=100) + + tools = [ + { + "type": "function", + "name": "get_weather", + "description": "Get the current weather in a given location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the weather for.", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The unit of temperature.", + }, + }, + "required": ["location"], + }, + } + ] + + model = args.make_model(tools=tools, tool_choice="required") + builder = args.get_message_builder() + + messages = [builder.user().add_text("What is the weather in Paris?")] + + with tracking.set_tracker() as tracker: + # First turn: get initial tool call + parsed = model(messages) + prev_input = tracker.stats["input_tokens"] + prev_output = tracker.stats["output_tokens"] + prev_cost = tracker.stats["cost"] + + # Simulate tool execution and user follow-up + messages += [ + parsed.tool_calls, # Add tool call from the model + builder.tool(parsed.raw_response).add_text("Its sunny! 25°C"), + builder.user().add_text("What is the weather in Delhi?"), + ] + + parsed = model(messages) + + # Token and cost deltas + delta_input = tracker.stats["input_tokens"] - prev_input + delta_output = tracker.stats["output_tokens"] - prev_output + delta_cost = tracker.stats["cost"] - prev_cost + + # Assertions + assert prev_input > 0 + assert prev_output > 0 + assert prev_cost > 0 + assert parsed.raw_response is not None + assert parsed.action == 'get_weather(location="Delhi")', f"Unexpected action: {parsed.action}" + assert delta_input > 0 + assert delta_output > 0 + assert delta_cost > 0 + assert tracker.stats["input_tokens"] == prev_input + delta_input + assert tracker.stats["output_tokens"] == prev_output + delta_output + assert tracker.stats["cost"] == pytest.approx(prev_cost + delta_cost) + + +@pytest.mark.pricy +@pytest.mark.skipif(not os.getenv("ANTHROPIC_API_KEY"), reason="ANTHROPIC_API_KEY not set") +def test_claude_model_with_multiple_messages_pricy_call(): + model_factory = ClaudeResponseModelArgs( + model_name="claude-3-haiku-20240307", temperature=1e-5, max_new_tokens=100 + ) + tools = [ + { + "name": "get_weather", + "description": "Get the current weather in a given location.", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the weather for.", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The unit of temperature.", + }, + }, + "required": ["location"], + }, + } + ] + model = model_factory.make_model(tools=tools) + msg_builder = model_factory.get_message_builder() + messages = [] + + messages.append(msg_builder.user().add_text("What is the weather in Paris?")) + with tracking.set_tracker() as global_tracker: + llm_output1 = model(messages) + + prev_input = global_tracker.stats["input_tokens"] + prev_output = global_tracker.stats["output_tokens"] + prev_cost = global_tracker.stats["cost"] + + messages.append(llm_output1.tool_calls) + messages.append(msg_builder.tool(llm_output1.raw_response).add_text("Its sunny! 25°C")) + messages.append(msg_builder.user().add_text("What is the weather in Delhi?")) + llm_output2 = model(messages) + # Token and cost deltas + delta_input = global_tracker.stats["input_tokens"] - prev_input + delta_output = global_tracker.stats["output_tokens"] - prev_output + delta_cost = global_tracker.stats["cost"] - prev_cost + + # Assertions + assert prev_input > 0, "Expected previous input tokens to be greater than 0" + assert prev_output > 0, "Expected previous output tokens to be greater than 0" + assert prev_cost > 0, "Expected previous cost value to be greater than 0" + assert llm_output2.raw_response is not None + assert ( + llm_output2.action == 'get_weather(location="Delhi", unit="celsius")' + ), f'Expected get_weather("Delhi") but got {llm_output2.action}' + assert delta_input > 0, "Expected new input tokens to be greater than 0" + assert delta_output > 0, "Expected new output tokens to be greater than 0" + assert delta_cost > 0, "Expected new cost value to be greater than 0" + assert global_tracker.stats["input_tokens"] == prev_input + delta_input + assert global_tracker.stats["output_tokens"] == prev_output + delta_output + assert global_tracker.stats["cost"] == pytest.approx(prev_cost + delta_cost) + + +# TODO: Add tests for image token costing (this is complex and model-specific) +# - For OpenAI, you'd need to know how they bill for images (e.g., fixed cost per image + tokens for text parts) +# - You'd likely need to mock the response from client.chat.completions.create to include specific usage for images.