83 changes: 65 additions & 18 deletions backend/app/services/llm/caller.py
@@ -15,7 +15,7 @@
import json
import uuid
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from loguru import logger
from sqlalchemy import select
@@ -27,6 +27,8 @@

from .client import LLMError
from .failover import classify_error, FailoverErrorType
from .json_recovery import canonicalize_tool_arguments
from .tool_result_shaping import shape_tool_result
from .utils import LLMMessage, create_llm_client, get_max_tokens, get_model_api_key

if TYPE_CHECKING:
@@ -39,6 +41,10 @@
"send_message_to_agent", "send_feishu_message", "send_email"
})

# Cap for any single tool-result entry sent into LLM history.
# Phase 1 uses a constant; Phase 2 will make this per-agent configurable.
TOOL_RESULT_MAX_CHARS = 20_000


# ═══════════════════════════════════════════════════════════════════════════════
# Failover Guard
@@ -193,6 +199,47 @@ def _check_tool_requires_args(tool_name: str, args: dict) -> tuple[bool, str]:
return True, ""


def _canonicalize_tc_arguments(tc: dict, session_id: str) -> dict[str, Any]:
"""Canonicalize ``tc['function']['arguments']`` in place and return the parsed dict.

The canonical JSON is written back to ``tc['function']['arguments']`` so that
any subsequent LLM round receiving this ``tc`` in conversation history will
pass DashScope's ``function.arguments must be in JSON format`` validation.
Used by both in-flight tool loops (_process_tool_call and _try_model).
"""
fn = tc["function"]
tool_name = fn["name"]
raw_args = fn.get("arguments", "{}")
args, canonical_args, repair_method = canonicalize_tool_arguments(raw_args)
fn["arguments"] = canonical_args
if repair_method != "clean":
logger.warning(
f"[LLM] tool_call args repaired: tool={tool_name} method={repair_method} "
f"orig_len={len(raw_args)} new_len={len(canonical_args)} session={session_id}"
)
return args


def _shape_tool_content_for_context(tool_content, tool_name: str, session_id: str):
"""Return tool_content capped at TOOL_RESULT_MAX_CHARS (string content only).

Vision content is always a list[dict] per vision_inject.try_inject_screenshot_vision —
we pass those through unchanged to preserve base64 image data.
"""
# Invariant: vision content is always a list[dict]; see vision_inject.try_inject_screenshot_vision.
if not isinstance(tool_content, str):
return tool_content
shaped, was_truncated = shape_tool_result(tool_content, TOOL_RESULT_MAX_CHARS)
if was_truncated:
dropped = len(tool_content) - len(shaped)
logger.warning(
f"[LLM] tool_result truncated: tool={tool_name} "
f"orig_len={len(tool_content)} new_len={len(shaped)} "
f"dropped={dropped} session={session_id}"
)
return shaped


async def _process_tool_call(
tc: dict,
api_messages: list,
@@ -204,15 +251,11 @@ async def _process_tool_call(
full_reasoning_content: str,
) -> str:
"""Process a single tool call and return result."""
fn = tc["function"]
tool_name = fn["name"]
raw_args = fn.get("arguments", "{}")
logger.info(f"[LLM] Calling tool: {tool_name}({json.dumps(raw_args, ensure_ascii=False)[:100]})")
raw_args = tc["function"].get("arguments", "{}")
logger.info(f"[LLM] Calling tool: {tc['function']['name']}({json.dumps(raw_args, ensure_ascii=False)[:100]})")

try:
args = json.loads(raw_args) if raw_args else {}
except json.JSONDecodeError:
args = {}
args = _canonicalize_tc_arguments(tc, session_id)
tool_name = tc["function"]["name"]

# Guard: check if tool requires arguments
should_execute, error_msg = _check_tool_requires_args(tool_name, args)
@@ -268,6 +311,8 @@ async def _process_tool_call(
except Exception:
pass

tool_content = _shape_tool_content_for_context(tool_content, tool_name, session_id)

api_messages.append(LLMMessage(
role="tool",
tool_call_id=tc["id"],
@@ -404,6 +449,9 @@ async def call_llm(
logger.info(f"[LLM] Round {round_i+1}: {len(response.tool_calls)} tool call(s)")

# Add assistant message with tool calls
# NB: tc["function"] is shared by reference with _canonicalize_tc_arguments's
# in-place canonicalization — must stay as a reference (no deepcopy), or
# history entries will carry the pre-repair malformed arguments.
api_messages.append(LLMMessage(
role="assistant",
content=response.content or None,
@@ -734,9 +782,12 @@ async def _try_model(model: LLMModel) -> tuple[str, bool, bool]:
if agent_id and _accumulated_tokens > 0:
await record_token_usage(agent_id, _accumulated_tokens)
await client.close()
return response.content or "[Empty response]", True
return response.content or "[Empty response]", True, tool_executed

# Execute tool calls
# NB: tc["function"] is shared by reference with _canonicalize_tc_arguments's
# in-place canonicalization — must stay as a reference (no deepcopy), or
# history entries will carry the pre-repair malformed arguments.
api_messages.append(LLMMessage(
role="assistant",
content=response.content or None,
@@ -749,13 +800,8 @@ async def _try_model(model: LLMModel) -> tuple[str, bool, bool]:
))

for tc in response.tool_calls:
fn = tc["function"]
tool_name = fn["name"]
raw_args = fn.get("arguments", "{}")
try:
args = json.loads(raw_args) if raw_args else {}
except json.JSONDecodeError:
args = {}
args = _canonicalize_tc_arguments(tc, session_id)
tool_name = tc["function"]["name"]

tool_executed = True
result = await execute_tool(
@@ -764,10 +810,11 @@ async def _try_model(model: LLMModel) -> tuple[str, bool, bool]:
user_id=agent.creator_id,
session_id=session_id,
)
shaped_content = _shape_tool_content_for_context(str(result), tool_name, session_id)
api_messages.append(LLMMessage(
role="tool",
tool_call_id=tc["id"],
content=str(result),
content=shaped_content,
))

if agent_id and _accumulated_tokens > 0:
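The two NB comments above hinge on an aliasing detail that is easy to miss. Here is a minimal sketch (hypothetical tc payload; canonicalize_tool_arguments is the helper added in json_recovery.py below, import path assumed from the file path) of why the history entry must hold the same function dict rather than a deepcopy:

from app.services.llm.json_recovery import canonicalize_tool_arguments  # path assumed

# A tool call whose streamed arguments carry a trailing comma (invalid JSON).
tc = {"id": "call_1", "function": {"name": "send_email", "arguments": '{"to": "a@b.c",}'}}

# The assistant history entry keeps a reference to the very same dict (no deepcopy).
history = [{"role": "assistant", "tool_calls": [tc]}]

# The in-place write-back that _canonicalize_tc_arguments performs:
args, canonical, method = canonicalize_tool_arguments(tc["function"]["arguments"])
tc["function"]["arguments"] = canonical
assert method == "trailing_comma"

# Thanks to the shared reference, the next round replays the repaired JSON;
# a deepcopy would have frozen the pre-repair, malformed form into history.
assert history[0]["tool_calls"][0]["function"]["arguments"] == '{"to": "a@b.c"}'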
161 changes: 161 additions & 0 deletions backend/app/services/llm/json_recovery.py
@@ -0,0 +1,161 @@
"""Tool-call JSON argument recovery and canonicalization.

LLM streaming sometimes produces slightly malformed JSON for tool_call.arguments:
trailing commas, unescaped control characters in string values, truncated tokens.
DashScope validates this field strictly server-side and rejects the request with
HTTP 400 `function.arguments parameter must be in JSON format` on the NEXT round.

`canonicalize_tool_arguments` accepts any raw string and returns a parsed dict
plus a canonical JSON string that is guaranteed to round-trip through
`json.loads`. It never raises.

Repair methods reported back to callers:

- ``"clean"`` — ``json.loads`` succeeded on the raw input and it was a dict.
- ``"trailing_comma"`` — succeeded after stripping trailing commas before
``}`` or ``]`` (string-aware so commas inside string literals are kept).
- ``"control_char_escape"`` — succeeded after escaping unescaped control
characters inside JSON string values.
- ``"non_dict_coerced"`` — a parse attempt succeeded but produced a non-dict
top-level value (list, scalar, ``null``). Coerced to ``{}``. Callers
should log/alert on this because real user data was dropped.
- ``"failed"`` — every repair attempt raised ``json.JSONDecodeError``.
Returns ``{}`` / ``"{}"``.
"""
from __future__ import annotations

import json
from typing import Any


def _strip_trailing_commas(s: str) -> str:
"""Remove trailing commas before } or ] — but only when OUTSIDE a JSON string.

Walks the input char by char so that a comma inside a string literal
(e.g. `"hello,}"`) is not confused with a trailing comma in the outer
structure.
"""
out: list[str] = []
in_string = False
escape_next = False
i = 0
n = len(s)
while i < n:
ch = s[i]
if escape_next:
out.append(ch)
escape_next = False
i += 1
continue
if ch == '\\' and in_string:
out.append(ch)
escape_next = True
i += 1
continue
if ch == '"':
in_string = not in_string
out.append(ch)
i += 1
continue
if not in_string and ch == ',':
# Peek ahead past whitespace to see if next non-ws is } or ]
j = i + 1
while j < n and s[j] in ' \t\n\r':
j += 1
if j < n and s[j] in '}]':
# Drop the comma, keep the whitespace
i += 1
continue
out.append(ch)
i += 1
return ''.join(out)


def _escape_control_chars_in_strings(s: str) -> str:
"""Scan through string and escape unescaped control chars inside JSON string values.

We can't do this by simple regex because we only want to escape control
chars *inside string values*, not outside. Walk char by char tracking
whether we're inside a string.
"""
out: list[str] = []
in_string = False
escape_next = False
for ch in s:
if escape_next:
out.append(ch)
escape_next = False
continue
if ch == '\\' and in_string:
out.append(ch)
escape_next = True
continue
if ch == '"':
in_string = not in_string
out.append(ch)
continue
if in_string and ord(ch) < 0x20:
# Escape control chars per JSON spec
if ch == '\n':
out.append('\\n')
elif ch == '\r':
out.append('\\r')
elif ch == '\t':
out.append('\\t')
elif ch == '\b':
out.append('\\b')
elif ch == '\f':
out.append('\\f')
else:
out.append(f'\\u{ord(ch):04x}')
continue
out.append(ch)
return ''.join(out)


def canonicalize_tool_arguments(raw: str) -> tuple[dict[str, Any], str, str]:
"""Parse and canonicalize a raw tool_call.arguments string.

Returns:
(parsed_dict, canonical_json_string, repair_method)

repair_method is one of: "clean", "trailing_comma", "control_char_escape",
"non_dict_coerced", "failed". Never raises.
"""
if not raw:
return {}, "{}", "clean"

# Attempt 1: clean parse
try:
parsed = json.loads(raw)
if not isinstance(parsed, dict):
return {}, "{}", "non_dict_coerced"
canonical = json.dumps(parsed, ensure_ascii=False)
return parsed, canonical, "clean"
except json.JSONDecodeError:
pass

# Attempt 2: strip trailing commas
cleaned = _strip_trailing_commas(raw)
try:
parsed = json.loads(cleaned)
if not isinstance(parsed, dict):
return {}, "{}", "non_dict_coerced"
canonical = json.dumps(parsed, ensure_ascii=False)
return parsed, canonical, "trailing_comma"
except json.JSONDecodeError:
pass

# Attempt 3: escape unescaped control chars in strings, then retry
escaped = _escape_control_chars_in_strings(cleaned)
try:
parsed = json.loads(escaped)
if not isinstance(parsed, dict):
return {}, "{}", "non_dict_coerced"
canonical = json.dumps(parsed, ensure_ascii=False)
return parsed, canonical, "control_char_escape"
except json.JSONDecodeError:
pass

# Gave up — return safe empty
return {}, "{}", "failed"
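A quick usage sketch of the recovery ladder (the import path app.services.llm.json_recovery is assumed from the file path above; inputs are illustrative):

from app.services.llm.json_recovery import canonicalize_tool_arguments

# Clean input round-trips unchanged.
args, canonical, method = canonicalize_tool_arguments('{"q": "hello"}')
assert method == "clean" and args == {"q": "hello"}

# Trailing commas are stripped string-awarely: the comma inside the value survives.
args, canonical, method = canonicalize_tool_arguments('{"q": "a,}",}')
assert method == "trailing_comma" and canonical == '{"q": "a,}"}'

# A raw (unescaped) newline inside a string value gets escaped, then parses.
args, canonical, method = canonicalize_tool_arguments('{"body": "line1\nline2"}')
assert method == "control_char_escape"

# Hopeless input never raises; it degrades to an empty dict.
args, canonical, method = canonicalize_tool_arguments('{"q": tru')
assert (args, canonical, method) == ({}, "{}", "failed")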
43 changes: 43 additions & 0 deletions backend/app/services/llm/tool_result_shaping.py
@@ -0,0 +1,43 @@
"""Shape oversized tool results to stay within per-call size budget.

A single tool result (e.g. a long tool output JSON, or an
`execute_code` stdout dump) can exceed 50KB. Accumulating many such results
across 10+ tool rounds blows past Qwen3.5-plus's ~983k-char input limit and
causes `HTTP 400: Range of input length should be [1, 983616]`.

This module applies a head+tail truncation with an explicit marker so the
LLM can see that truncation happened and ask for more if needed.

A degenerate budget (``max_chars <= 0``) returns an empty string, with
``was_truncated=True`` iff the input was non-empty.
"""
from __future__ import annotations


def shape_tool_result(result, max_chars: int) -> tuple[str, bool]:
"""Return (possibly-truncated string, was_truncated).

Strategy for oversized results: keep ~60% head and ~30% tail, with a
marker in between describing how much was dropped. Total output stays
within max_chars plus a small marker overhead (~120 chars).

Edge case: if ``max_chars <= 0`` the budget is degenerate — there is no
room for any content (nor for the marker itself), so an empty string is
returned, with ``was_truncated=True`` iff the input was non-empty.
"""
s = str(result) if not isinstance(result, str) else result
if max_chars <= 0:
# Degenerate budget — treat as "drop everything", no marker (it would
# exceed max_chars itself). was_truncated reflects whether any content
# was actually dropped.
return "", len(s) > 0
if len(s) <= max_chars:
return s, False

# Budget split: 60% head, 30% tail, 10% safety
head_budget = int(max_chars * 0.60)
tail_budget = int(max_chars * 0.30)
dropped = len(s) - head_budget - tail_budget
marker = f"\n\n[... truncated: {dropped:,} chars omitted (head {head_budget:,} + tail {tail_budget:,} kept) ...]\n\n"

return s[:head_budget] + marker + s[-tail_budget:], True
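A matching usage sketch for the shaper (import path assumed from the file path above):

from app.services.llm.tool_result_shaping import shape_tool_result

# Under budget: returned untouched.
shaped, truncated = shape_tool_result("short output", max_chars=1_000)
assert shaped == "short output" and not truncated

# Over budget: 60% head and 30% tail survive, with an explicit marker between.
big = "x" * 50_000
shaped, truncated = shape_tool_result(big, max_chars=20_000)
assert truncated
assert shaped.startswith("x" * 12_000)   # head: 60% of the 20_000 budget
assert shaped.endswith("x" * 6_000)      # tail: 30% of the budget
assert "truncated" in shaped             # marker the model can see and react to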