diff --git a/AGENTS.md b/AGENTS.md index 2ad76a28b..9f3617ce9 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -26,6 +26,7 @@ Runs on `http://localhost:3000` by default. 3. After finishing, use `ruff format .` and `ruff check .` to format and check the code. 4. When committing, ensure to use conventional commits messages, such as `feat: add new agent for data analysis` or `fix: resolve bug in provider manager`. 5. Use English for all new comments. +6. For path handling, use `pathlib.Path` instead of string paths, and use `astrbot.core.utils.path_utils` to get the AstrBot data and temp directory. ## PR instructions diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index f1e0688f2..e43fcbda6 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -3,6 +3,7 @@ import time import traceback import typing as T +from dataclasses import dataclass from mcp.types import ( BlobResourceContents, @@ -14,8 +15,9 @@ ) from astrbot import logger -from astrbot.core.agent.message import TextPart, ThinkPart +from astrbot.core.agent.message import ImageURLPart, TextPart, ThinkPart from astrbot.core.agent.tool import ToolSet +from astrbot.core.agent.tool_image_cache import tool_image_cache from astrbot.core.message.components import Json from astrbot.core.message.message_event_result import ( MessageChain, @@ -44,6 +46,28 @@ from typing_extensions import override +@dataclass(slots=True) +class _HandleFunctionToolsResult: + kind: T.Literal["message_chain", "tool_call_result_blocks", "cached_image"] + message_chain: MessageChain | None = None + tool_call_result_blocks: list[ToolCallMessageSegment] | None = None + cached_image: T.Any = None + + @classmethod + def from_message_chain(cls, chain: MessageChain) -> "_HandleFunctionToolsResult": + return cls(kind="message_chain", message_chain=chain) + + @classmethod + def from_tool_call_result_blocks( + cls, blocks: list[ToolCallMessageSegment] + ) -> "_HandleFunctionToolsResult": + return cls(kind="tool_call_result_blocks", tool_call_result_blocks=blocks) + + @classmethod + def from_cached_image(cls, image: T.Any) -> "_HandleFunctionToolsResult": + return cls(kind="cached_image", cached_image=image) + + class ToolLoopAgentRunner(BaseAgentRunner[TContext]): @override async def reset( @@ -286,20 +310,27 @@ async def step(self): llm_resp, _ = await self._resolve_tool_exec(llm_resp) tool_call_result_blocks = [] + cached_images = [] # Collect cached images for LLM visibility async for result in self._handle_function_tools(self.req, llm_resp): - if isinstance(result, list): - tool_call_result_blocks = result - elif isinstance(result, MessageChain): - if result.type is None: + if result.kind == "tool_call_result_blocks": + if result.tool_call_result_blocks is not None: + tool_call_result_blocks = result.tool_call_result_blocks + elif result.kind == "cached_image": + if result.cached_image is not None: + # Collect cached image info + cached_images.append(result.cached_image) + elif result.kind == "message_chain": + chain = result.message_chain + if chain is None or chain.type is None: # should not happen continue - if result.type == "tool_direct_result": + if chain.type == "tool_direct_result": ar_type = "tool_call_result" else: - ar_type = result.type + ar_type = chain.type yield AgentResponse( type=ar_type, - data=AgentResponseData(chain=result), + data=AgentResponseData(chain=chain), ) # 将结果添加到上下文中 @@ -327,6 +358,41 @@ async def step(self): tool_calls_result.to_openai_messages_model() ) + # If there are cached images and the model supports image input, + # append a user message with images so LLM can see them + if cached_images: + modalities = self.provider.provider_config.get("modalities", []) + supports_image = "image" in modalities + if supports_image: + # Build user message with images for LLM to review + image_parts = [] + for cached_img in cached_images: + img_data = tool_image_cache.get_image_base64_by_path( + cached_img.file_path, cached_img.mime_type + ) + if img_data: + base64_data, mime_type = img_data + image_parts.append( + TextPart( + text=f"[Image from tool '{cached_img.tool_name}', path='{cached_img.file_path}']" + ) + ) + image_parts.append( + ImageURLPart( + image_url=ImageURLPart.ImageURL( + url=f"data:{mime_type};base64,{base64_data}", + id=cached_img.file_path, + ) + ) + ) + if image_parts: + self.run_context.messages.append( + Message(role="user", content=image_parts) + ) + logger.debug( + f"Appended {len(cached_images)} cached image(s) to context for LLM review" + ) + self.req.append_tool_calls_result(tool_calls_result) async def step_until_done( @@ -362,7 +428,7 @@ async def _handle_function_tools( self, req: ProviderRequest, llm_response: LLMResponse, - ) -> T.AsyncGenerator[MessageChain | list[ToolCallMessageSegment], None]: + ) -> T.AsyncGenerator[_HandleFunctionToolsResult, None]: """处理函数工具调用。""" tool_call_result_blocks: list[ToolCallMessageSegment] = [] logger.info(f"Agent 使用工具: {llm_response.tools_call_name}") @@ -373,18 +439,20 @@ async def _handle_function_tools( llm_response.tools_call_args, llm_response.tools_call_ids, ): - yield MessageChain( - type="tool_call", - chain=[ - Json( - data={ - "id": func_tool_id, - "name": func_tool_name, - "args": func_tool_args, - "ts": time.time(), - } - ) - ], + yield _HandleFunctionToolsResult.from_message_chain( + MessageChain( + type="tool_call", + chain=[ + Json( + data={ + "id": func_tool_id, + "name": func_tool_name, + "args": func_tool_args, + "ts": time.time(), + } + ) + ], + ) ) try: if not req.func_tool: @@ -470,15 +538,28 @@ async def _handle_function_tools( ), ) elif isinstance(res.content[0], ImageContent): + # Cache the image instead of sending directly + cached_img = tool_image_cache.save_image( + base64_data=res.content[0].data, + tool_call_id=func_tool_id, + tool_name=func_tool_name, + index=0, + mime_type=res.content[0].mimeType or "image/png", + ) tool_call_result_blocks.append( ToolCallMessageSegment( role="tool", tool_call_id=func_tool_id, - content="The tool has successfully returned an image and sent directly to the user. You can describe it in your next response.", + content=( + f"Image returned and cached at path='{cached_img.file_path}'. " + f"Review the image below. Use send_message_to_user to send it to the user if satisfied, " + f"with type='image' and path='{cached_img.file_path}'." + ), ), ) - yield MessageChain(type="tool_direct_result").base64_image( - res.content[0].data, + # Yield image info for LLM visibility (will be handled in step()) + yield _HandleFunctionToolsResult.from_cached_image( + cached_img ) elif isinstance(res.content[0], EmbeddedResource): resource = res.content[0].resource @@ -495,16 +576,29 @@ async def _handle_function_tools( and resource.mimeType and resource.mimeType.startswith("image/") ): + # Cache the image instead of sending directly + cached_img = tool_image_cache.save_image( + base64_data=resource.blob, + tool_call_id=func_tool_id, + tool_name=func_tool_name, + index=0, + mime_type=resource.mimeType, + ) tool_call_result_blocks.append( ToolCallMessageSegment( role="tool", tool_call_id=func_tool_id, - content="The tool has successfully returned an image and sent directly to the user. You can describe it in your next response.", + content=( + f"Image returned and cached at path='{cached_img.file_path}'. " + f"Review the image below. Use send_message_to_user to send it to the user if satisfied, " + f"with type='image' and path='{cached_img.file_path}'." + ), ), ) - yield MessageChain( - type="tool_direct_result", - ).base64_image(resource.blob) + # Yield image info for LLM visibility + yield _HandleFunctionToolsResult.from_cached_image( + cached_img + ) else: tool_call_result_blocks.append( ToolCallMessageSegment( @@ -565,23 +659,27 @@ async def _handle_function_tools( # yield the last tool call result if tool_call_result_blocks: last_tcr_content = str(tool_call_result_blocks[-1].content) - yield MessageChain( - type="tool_call_result", - chain=[ - Json( - data={ - "id": func_tool_id, - "ts": time.time(), - "result": last_tcr_content, - } - ) - ], + yield _HandleFunctionToolsResult.from_message_chain( + MessageChain( + type="tool_call_result", + chain=[ + Json( + data={ + "id": func_tool_id, + "ts": time.time(), + "result": last_tcr_content, + } + ) + ], + ) ) logger.info(f"Tool `{func_tool_name}` Result: {last_tcr_content}") # 处理函数调用响应 if tool_call_result_blocks: - yield tool_call_result_blocks + yield _HandleFunctionToolsResult.from_tool_call_result_blocks( + tool_call_result_blocks + ) def _build_tool_requery_context( self, tool_names: list[str] diff --git a/astrbot/core/agent/tool_image_cache.py b/astrbot/core/agent/tool_image_cache.py new file mode 100644 index 000000000..72e22dd52 --- /dev/null +++ b/astrbot/core/agent/tool_image_cache.py @@ -0,0 +1,162 @@ +"""Tool image cache module for storing and retrieving images returned by tools. + +This module allows LLM to review images before deciding whether to send them to users. +""" + +import base64 +import os +import time +from dataclasses import dataclass, field +from typing import ClassVar + +from astrbot import logger +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path + + +@dataclass +class CachedImage: + """Represents a cached image from a tool call.""" + + tool_call_id: str + """The tool call ID that produced this image.""" + tool_name: str + """The name of the tool that produced this image.""" + file_path: str + """The file path where the image is stored.""" + mime_type: str + """The MIME type of the image.""" + created_at: float = field(default_factory=time.time) + """Timestamp when the image was cached.""" + + +class ToolImageCache: + """Manages cached images from tool calls. + + Images are stored in data/temp/tool_images/ and can be retrieved by file path. + """ + + _instance: ClassVar["ToolImageCache | None"] = None + CACHE_DIR_NAME: ClassVar[str] = "tool_images" + # Cache expiry time in seconds (1 hour) + CACHE_EXPIRY: ClassVar[int] = 3600 + + def __new__(cls) -> "ToolImageCache": + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self) -> None: + if self._initialized: + return + self._initialized = True + self._cache_dir = os.path.join(get_astrbot_temp_path(), self.CACHE_DIR_NAME) + os.makedirs(self._cache_dir, exist_ok=True) + logger.debug(f"ToolImageCache initialized, cache dir: {self._cache_dir}") + + def _get_file_extension(self, mime_type: str) -> str: + """Get file extension from MIME type.""" + mime_to_ext = { + "image/png": ".png", + "image/jpeg": ".jpg", + "image/jpg": ".jpg", + "image/gif": ".gif", + "image/webp": ".webp", + "image/bmp": ".bmp", + "image/svg+xml": ".svg", + } + return mime_to_ext.get(mime_type.lower(), ".png") + + def save_image( + self, + base64_data: str, + tool_call_id: str, + tool_name: str, + index: int = 0, + mime_type: str = "image/png", + ) -> CachedImage: + """Save an image to cache and return the cached image info. + + Args: + base64_data: Base64 encoded image data. + tool_call_id: The tool call ID that produced this image. + tool_name: The name of the tool that produced this image. + index: The index of the image (for multiple images from same tool call). + mime_type: The MIME type of the image. + + Returns: + CachedImage object with file path. + """ + ext = self._get_file_extension(mime_type) + file_name = f"{tool_call_id}_{index}{ext}" + file_path = os.path.join(self._cache_dir, file_name) + + # Decode and save the image + try: + image_bytes = base64.b64decode(base64_data) + with open(file_path, "wb") as f: + f.write(image_bytes) + logger.debug(f"Saved tool image to: {file_path}") + except Exception as e: + logger.error(f"Failed to save tool image: {e}") + raise + + return CachedImage( + tool_call_id=tool_call_id, + tool_name=tool_name, + file_path=file_path, + mime_type=mime_type, + ) + + def get_image_base64_by_path( + self, file_path: str, mime_type: str = "image/png" + ) -> tuple[str, str] | None: + """Read an image file and return its base64 encoded data. + + Args: + file_path: The file path of the cached image. + mime_type: The MIME type of the image. + + Returns: + Tuple of (base64_data, mime_type) if found, None otherwise. + """ + if not os.path.exists(file_path): + return None + + try: + with open(file_path, "rb") as f: + image_bytes = f.read() + base64_data = base64.b64encode(image_bytes).decode("utf-8") + return base64_data, mime_type + except Exception as e: + logger.error(f"Failed to read cached image {file_path}: {e}") + return None + + def cleanup_expired(self) -> int: + """Clean up expired cached images. + + Returns: + Number of images cleaned up. + """ + now = time.time() + cleaned = 0 + + try: + for file_name in os.listdir(self._cache_dir): + file_path = os.path.join(self._cache_dir, file_name) + if os.path.isfile(file_path): + file_age = now - os.path.getmtime(file_path) + if file_age > self.CACHE_EXPIRY: + os.remove(file_path) + cleaned += 1 + except Exception as e: + logger.warning(f"Error during cache cleanup: {e}") + + if cleaned: + logger.info(f"Cleaned up {cleaned} expired cached images") + + return cleaned + + +# Global singleton instance +tool_image_cache = ToolImageCache()