From fa7582086393449be671159ff2b1e81ef079126b Mon Sep 17 00:00:00 2001 From: Alem Tuzlak Date: Thu, 26 Feb 2026 14:35:45 +0100 Subject: [PATCH 01/16] feat: tanstack ai middleware --- docs/config.json | 4 + docs/guides/middleware.md | 650 +++++ .../ai/src/activities/chat/index.ts | 266 +- .../src/activities/chat/middleware/compose.ts | 185 ++ .../src/activities/chat/middleware/index.ts | 22 + .../chat/middleware/tool-cache-middleware.ts | 178 ++ .../src/activities/chat/middleware/types.ts | 311 +++ .../src/activities/chat/tools/tool-calls.ts | 306 ++- packages/typescript/ai/src/index.ts | 20 + packages/typescript/ai/tests/chat.test.ts | 140 +- .../typescript/ai/tests/middleware.test.ts | 2404 +++++++++++++++++ packages/typescript/ai/tests/test-utils.ts | 181 ++ .../ai/tests/tool-cache-middleware.test.ts | 722 +++++ 13 files changed, 5157 insertions(+), 232 deletions(-) create mode 100644 docs/guides/middleware.md create mode 100644 packages/typescript/ai/src/activities/chat/middleware/compose.ts create mode 100644 packages/typescript/ai/src/activities/chat/middleware/index.ts create mode 100644 packages/typescript/ai/src/activities/chat/middleware/tool-cache-middleware.ts create mode 100644 packages/typescript/ai/src/activities/chat/middleware/types.ts create mode 100644 packages/typescript/ai/tests/middleware.test.ts create mode 100644 packages/typescript/ai/tests/test-utils.ts create mode 100644 packages/typescript/ai/tests/tool-cache-middleware.test.ts diff --git a/docs/config.json b/docs/config.json index 1439169ae..a6ac175d9 100644 --- a/docs/config.json +++ b/docs/config.json @@ -50,6 +50,10 @@ "label": "Agentic Cycle", "to": "guides/agentic-cycle" }, + { + "label": "Middleware", + "to": "guides/middleware" + }, { "label": "Structured Outputs", "to": "guides/structured-outputs" diff --git a/docs/guides/middleware.md b/docs/guides/middleware.md new file mode 100644 index 000000000..d2e469d56 --- /dev/null +++ b/docs/guides/middleware.md @@ -0,0 +1,650 @@ +--- 
+title: Middleware +id: middleware +order: 7 +--- + +Middleware lets you hook into every stage of the `chat()` lifecycle — from configuration to streaming, tool execution, usage tracking, and completion. You can observe, transform, or short-circuit behavior at each stage without modifying your adapter or tool implementations. + +Common use cases include: + +- **Logging and observability** — track token usage, tool execution timing, errors +- **Configuration transforms** — inject system prompts, adjust temperature per iteration, filter tools +- **Stream processing** — redact sensitive content, transform chunks, drop unwanted events +- **Tool call interception** — validate arguments, cache results, abort on dangerous calls +- **Side effects** — send analytics, update databases, trigger notifications + +## Quick Start + +Pass an array of middleware to the `chat()` function: + +```typescript +import { chat, type ChatMiddleware } from "@tanstack/ai"; +import { openaiText } from "@tanstack/ai-openai"; + +const logger: ChatMiddleware = { + name: "logger", + onStart: (ctx) => { + console.log(`[${ctx.requestId}] Chat started`); + }, + onFinish: (ctx, info) => { + console.log(`[${ctx.requestId}] Finished in ${info.duration}ms`); + }, +}; + +const stream = chat({ + adapter: openaiText("gpt-4o"), + messages: [{ role: "user", content: "Hello" }], + middleware: [logger], +}); +``` + +## Lifecycle Overview + +Every `chat()` invocation follows a predictable lifecycle. 
Middleware hooks fire at specific phases: + +```mermaid +graph TD + A["chat() called"] --> B["onConfig (phase: init)"] + B --> C[onStart] + C --> D["onConfig (phase: beforeModel)"] + D --> E["Adapter streams response"] + E --> F["onChunk (for each chunk)"] + F --> G{Tool calls?} + G -->|No| H[onUsage] + G -->|Yes| I[onBeforeToolCall] + I --> J[Tool executes] + J --> K[onAfterToolCall] + K --> L{Continue loop?} + L -->|Yes| D + L -->|No| H + H --> M{Outcome} + M -->|Success| N[onFinish] + M -->|Abort| O[onAbort] + M -->|Error| P[onError] + + style I fill:#e1f5ff + style J fill:#ffe1e1 + style N fill:#e1ffe1 + style O fill:#fff4e1 + style P fill:#ffe1e1 +``` + +### Phase Transitions + +The context's `phase` field tracks where you are in the lifecycle: + +| Phase | When | Hooks Called | +|-------|------|-------------| +| `init` | Once at startup | `onConfig` | +| `beforeModel` | Before each model call (per iteration) | `onConfig` | +| `modelStream` | While adapter streams chunks | `onChunk`, `onUsage` | +| `beforeTools` | Before tool execution | `onBeforeToolCall` | +| `afterTools` | After tool execution | `onAfterToolCall` | + +## Hooks Reference + +### onConfig + +Called twice per iteration: once during `init` (startup) and once during `beforeModel` (before each model call). Use it to transform the configuration that the model receives. + +Return a **partial** config object with only the fields you want to change — they are shallow-merged with the current config automatically. No need to spread the existing config. 
+ +```typescript +const dynamicTemperature: ChatMiddleware = { + name: "dynamic-temperature", + onConfig: (ctx, config) => { + if (ctx.phase === "init") { + // Add a system prompt at startup — only systemPrompts is overwritten + return { + systemPrompts: [ + ...config.systemPrompts, + "You are a helpful assistant.", + ], + }; + } + + if (ctx.phase === "beforeModel" && ctx.iteration > 0) { + // Increase temperature on retries — other fields stay unchanged + return { + temperature: Math.min((config.temperature ?? 0.7) + 0.1, 1.0), + }; + } + }, +}; +``` + +**Config fields you can transform:** + +| Field | Type | Description | +|-------|------|-------------| +| `messages` | `ModelMessage[]` | Conversation history | +| `systemPrompts` | `string[]` | System prompts | +| `tools` | `Tool[]` | Available tools | +| `temperature` | `number` | Sampling temperature | +| `topP` | `number` | Nucleus sampling | +| `maxTokens` | `number` | Token limit | +| `metadata` | `Record` | Request metadata | +| `modelOptions` | `Record` | Provider-specific options | + +When multiple middleware define `onConfig`, the config is **piped** through them in order — each receives the merged config from the previous middleware. + +### onStart + +Called once after the initial `onConfig` completes. Use it for setup tasks like initializing timers or logging. + +```typescript +const timer: ChatMiddleware = { + name: "timer", + onStart: (ctx) => { + console.log(`Request ${ctx.requestId} started at iteration ${ctx.iteration}`); + }, +}; +``` + +### onChunk + +Called for every chunk streamed from the adapter. You can observe, transform, expand, or drop chunks. 
+ +```typescript +const redactor: ChatMiddleware = { + name: "redactor", + onChunk: (ctx, chunk) => { + if (chunk.type === "TEXT_MESSAGE_CONTENT") { + // Transform: redact sensitive content + return { + ...chunk, + delta: chunk.delta.replace(/\b\d{3}-\d{2}-\d{4}\b/g, "[REDACTED]"), + }; + } + // Return void to pass through unchanged + }, +}; +``` + +**Return values:** + +| Return | Effect | +|--------|--------| +| `void` / `undefined` | Chunk passes through unchanged | +| `StreamChunk` | Replaces the original chunk | +| `StreamChunk[]` | Expands into multiple chunks | +| `null` | Drops the chunk entirely | + +When multiple middleware define `onChunk`, chunks flow through them in order. If one middleware drops a chunk (returns `null`), subsequent middleware never see it. + +### onBeforeToolCall + +Called before each tool executes. The first middleware that returns a non-void decision short-circuits — remaining middleware are skipped for that tool call. + +```typescript +const guard: ChatMiddleware = { + name: "guard", + onBeforeToolCall: (ctx, hookCtx) => { + // Block dangerous tools + if (hookCtx.toolName === "deleteDatabase") { + return { type: "abort", reason: "Dangerous operation blocked" }; + } + + // Validate and transform arguments + if (hookCtx.toolName === "search" && !hookCtx.args.limit) { + return { + type: "transformArgs", + args: { ...hookCtx.args, limit: 10 }, + }; + } + }, +}; +``` + +**Decision types:** + +| Decision | Effect | +|----------|--------| +| `void` / `undefined` | Continue normally, next middleware can decide | +| `{ type: 'transformArgs', args }` | Replace tool arguments before execution | +| `{ type: 'skip', result }` | Skip execution entirely, use provided result | +| `{ type: 'abort', reason? 
}` | Abort the entire chat run | + +The `hookCtx` provides: + +| Field | Type | Description | +|-------|------|-------------| +| `toolCall` | `ToolCall` | Raw tool call object | +| `tool` | `Tool \| undefined` | Resolved tool definition | +| `args` | `unknown` | Parsed arguments | +| `toolName` | `string` | Tool name | +| `toolCallId` | `string` | Tool call ID | + +### onAfterToolCall + +Called after each tool execution (or skip). All middleware run — there is no short-circuiting. + +```typescript +const toolLogger: ChatMiddleware = { + name: "tool-logger", + onAfterToolCall: (ctx, info) => { + if (info.ok) { + console.log(`${info.toolName} completed in ${info.duration}ms`); + } else { + console.error(`${info.toolName} failed:`, info.error); + } + }, +}; +``` + +The `info` object provides: + +| Field | Type | Description | +|-------|------|-------------| +| `toolCall` | `ToolCall` | Raw tool call object | +| `tool` | `Tool \| undefined` | Resolved tool definition | +| `toolName` | `string` | Tool name | +| `toolCallId` | `string` | Tool call ID | +| `ok` | `boolean` | Whether execution succeeded | +| `duration` | `number` | Execution time in milliseconds | +| `result` | `unknown` | Result (when `ok` is true) | +| `error` | `unknown` | Error (when `ok` is false) | + +### onUsage + +Called once per model iteration when the `RUN_FINISHED` chunk includes usage data. Receives the usage object directly. + +```typescript +const usageTracker: ChatMiddleware = { + name: "usage-tracker", + onUsage: (ctx, usage) => { + console.log( + `Iteration ${ctx.iteration}: ${usage.totalTokens} tokens` + ); + }, +}; +``` + +The `usage` object: + +| Field | Type | Description | +|-------|------|-------------| +| `promptTokens` | `number` | Input tokens | +| `completionTokens` | `number` | Output tokens | +| `totalTokens` | `number` | Total tokens | + +### Terminal Hooks: onFinish, onAbort, onError + +Exactly **one** terminal hook fires per `chat()` invocation. 
They are mutually exclusive: + +| Hook | When it fires | +|------|--------------| +| `onFinish` | Run completed normally | +| `onAbort` | Run was aborted (via `ctx.abort()`, an external `AbortSignal`, or a `{ type: 'abort' }` decision from `onBeforeToolCall`) | +| `onError` | An unhandled error occurred | + +```typescript +const terminal: ChatMiddleware = { + name: "terminal", + onFinish: (ctx, info) => { + console.log(`Finished: ${info.finishReason}, ${info.duration}ms`); + console.log(`Content: ${info.content}`); + if (info.usage) { + console.log(`Tokens: ${info.usage.totalTokens}`); + } + }, + onAbort: (ctx, info) => { + console.log(`Aborted: ${info.reason}, ${info.duration}ms`); + }, + onError: (ctx, info) => { + console.error(`Error after ${info.duration}ms:`, info.error); + }, +}; +``` + +## Context Object + +Every hook receives a `ChatMiddlewareContext` as its first argument. It provides request-scoped information and control functions: + +| Field | Type | Description | +|-------|------|-------------| +| `requestId` | `string` | Unique ID for this chat request | +| `streamId` | `string` | Unique ID for this stream | +| `conversationId` | `string \| undefined` | User-provided conversation ID | +| `phase` | `ChatMiddlewarePhase` | Current lifecycle phase | +| `iteration` | `number` | Agent loop iteration (0-indexed) | +| `chunkIndex` | `number` | Running count of chunks yielded | +| `signal` | `AbortSignal \| undefined` | External abort signal | +| `abort(reason?)` | `function` | Abort the run from within middleware | +| `context` | `unknown` | User-provided context value | +| `defer(promise)` | `function` | Register a non-blocking side-effect | + +### Aborting from Middleware + +Call `ctx.abort()` to gracefully stop the run. 
This triggers the `onAbort` terminal hook: + +```typescript +const timeout: ChatMiddleware = { + name: "timeout", + onChunk: (ctx) => { + if (ctx.chunkIndex > 1000) { + ctx.abort("Too many chunks"); + } + }, +}; +``` + +### Deferred Side Effects + +Use `ctx.defer()` to register promises that run after the terminal hook without blocking the stream: + +```typescript +const analytics: ChatMiddleware = { + name: "analytics", + onFinish: (ctx, info) => { + ctx.defer( + fetch("/api/analytics", { + method: "POST", + body: JSON.stringify({ + requestId: ctx.requestId, + duration: info.duration, + tokens: info.usage?.totalTokens, + }), + }) + ); + }, +}; +``` + +## Composing Multiple Middleware + +Middleware execute in array order. The ordering matters for hooks that pipe or short-circuit: + +```typescript +const stream = chat({ + adapter: openaiText("gpt-4o"), + messages, + middleware: [authMiddleware, loggingMiddleware, cachingMiddleware], +}); +``` + +### Composition Rules + +| Hook | Composition | Effect of Order | +|------|------------|----------------| +| `onConfig` | **Piped** — each receives previous output | Earlier middleware transforms first | +| `onStart` | Sequential | All run in order | +| `onChunk` | **Piped** — chunks flow through each middleware | If first drops a chunk, later middleware never see it | +| `onBeforeToolCall` | **First-win** — first non-void decision wins | Earlier middleware has priority | +| `onAfterToolCall` | Sequential | All run in order | +| `onUsage` | Sequential | All run in order | +| `onFinish/onAbort/onError` | Sequential | All run in order | + +## Built-in Middleware + +### toolCacheMiddleware + +Caches tool call results based on tool name and arguments. When a tool is called with the same name and arguments as a previous call, the cached result is returned immediately without re-executing the tool. 
+ +```typescript +import { chat, toolCacheMiddleware } from "@tanstack/ai"; + +const stream = chat({ + adapter: openaiText("gpt-4o"), + messages, + tools: [weatherTool, stockTool], + middleware: [ + toolCacheMiddleware({ + ttl: 60_000, // Cache entries expire after 60 seconds + maxSize: 50, // Keep at most 50 entries (LRU eviction) + toolNames: ["getWeather"], // Only cache specific tools + }), + ], +}); +``` + +**Options:** + +| Option | Type | Default | Description | +|--------|------|---------|-------------| +| `maxSize` | `number` | `100` | Maximum cache entries. Oldest evicted first (LRU). Only applies to the default in-memory storage. | +| `ttl` | `number` | `Infinity` | Time-to-live in milliseconds. Expired entries are not served. | +| `toolNames` | `string[]` | All tools | Only cache these tools. Others pass through. | +| `keyFn` | `(toolName, args) => string` | `JSON.stringify([toolName, args])` | Custom cache key derivation. | +| `storage` | `ToolCacheStorage` | In-memory Map | Custom storage backend. When provided, `maxSize` is ignored — the storage manages its own capacity. | + +**Behaviors:** + +- Only successful tool calls are cached — errors are never stored +- Cache hits trigger `{ type: 'skip', result }` via `onBeforeToolCall` +- LRU eviction: when `maxSize` is reached, the oldest entry is removed (default storage only) +- Cache hits refresh the entry's LRU position (moved to most-recently-used) + +**Custom key function** — useful when you want to ignore certain arguments: + +```typescript +toolCacheMiddleware({ + keyFn: (toolName, args) => { + // Ignore pagination, cache by query only + const { page, ...rest } = args as Record; + return JSON.stringify([toolName, rest]); + }, +}); +``` + +#### Custom Storage + +By default the cache lives in-memory and is scoped to a single `toolCacheMiddleware()` instance. Pass a `storage` option to use an external backend like Redis, localStorage, or a database. 
This also enables **sharing a cache across multiple `chat()` calls**. + +The storage interface: + +```typescript +import type { ToolCacheStorage, ToolCacheEntry } from "@tanstack/ai"; + +interface ToolCacheStorage { + getItem: (key: string) => ToolCacheEntry | undefined | Promise; + setItem: (key: string, value: ToolCacheEntry) => void | Promise; + deleteItem: (key: string) => void | Promise; +} + +// ToolCacheEntry is { result: unknown, timestamp: number } +``` + +All methods may return a `Promise` for async backends. The middleware handles TTL checking — your storage just needs to store and retrieve entries. + +**Redis example:** + +```typescript +import { createClient } from "redis"; +import { toolCacheMiddleware, type ToolCacheStorage } from "@tanstack/ai"; + +const redis = createClient(); + +const redisStorage: ToolCacheStorage = { + getItem: async (key) => { + const raw = await redis.get(`tool-cache:${key}`); + return raw ? JSON.parse(raw) : undefined; + }, + setItem: async (key, value) => { + await redis.set(`tool-cache:${key}`, JSON.stringify(value)); + }, + deleteItem: async (key) => { + await redis.del(`tool-cache:${key}`); + }, +}; + +const stream = chat({ + adapter, + messages, + tools: [weatherTool], + middleware: [toolCacheMiddleware({ storage: redisStorage, ttl: 60_000 })], +}); +``` + +**Sharing a cache across requests:** + +```typescript +// Create storage once, reuse across chat() calls +const sharedStorage: ToolCacheStorage = { + getItem: (key) => globalCache.get(key), + setItem: (key, value) => { globalCache.set(key, value); }, + deleteItem: (key) => { globalCache.delete(key); }, +}; + +// Both requests share the same cache +app.post("/api/chat", async (req) => { + const stream = chat({ + adapter, + messages: req.body.messages, + tools: [weatherTool], + middleware: [toolCacheMiddleware({ storage: sharedStorage })], + }); + return toServerSentEventsResponse(stream); +}); +``` + +## Recipes + +### Rate Limiting + +Limit the number of tool calls 
per request: + +```typescript +function rateLimitMiddleware(maxCalls: number): ChatMiddleware { + return { + name: "rate-limit", + onBeforeToolCall: (ctx, hookCtx) => { + if (ctx.iteration >= maxCalls) { + return { + type: "abort", + reason: `Rate limit: exceeded ${maxCalls} tool calls`, + }; + } + }, + }; +} +``` + +### Audit Trail + +Log every action for compliance: + +```typescript +const auditTrail: ChatMiddleware = { + name: "audit-trail", + onStart: (ctx) => { + ctx.defer( + db.auditLog.create({ + requestId: ctx.requestId, + event: "chat_started", + timestamp: Date.now(), + }) + ); + }, + onAfterToolCall: (ctx, info) => { + ctx.defer( + db.auditLog.create({ + requestId: ctx.requestId, + event: "tool_executed", + toolName: info.toolName, + success: info.ok, + duration: info.duration, + timestamp: Date.now(), + }) + ); + }, + onFinish: (ctx, info) => { + ctx.defer( + db.auditLog.create({ + requestId: ctx.requestId, + event: "chat_finished", + duration: info.duration, + tokens: info.usage?.totalTokens, + timestamp: Date.now(), + }) + ); + }, +}; +``` + +### Per-Iteration Tool Swapping + +Expose different tools at different stages of the agent loop: + +```typescript +const toolSwapper: ChatMiddleware = { + name: "tool-swapper", + onConfig: (ctx, config) => { + if (ctx.phase !== "beforeModel") return; + + if (ctx.iteration === 0) { + // First iteration: only allow search + return { + tools: config.tools.filter((t) => t.name === "search"), + }; + } + // Later iterations: allow all tools + }, +}; +``` + +### Content Filtering + +Drop or transform chunks before they reach the consumer: + +```typescript +const contentFilter: ChatMiddleware = { + name: "content-filter", + onChunk: (ctx, chunk) => { + if (chunk.type === "TEXT_MESSAGE_CONTENT") { + if (containsProfanity(chunk.delta)) { + // Drop the chunk entirely + return null; + } + } + }, +}; +``` + +### Error Recovery with Retry Logging + +```typescript +const errorRecovery: ChatMiddleware = { + name: 
"error-recovery", + onError: (ctx, info) => { + ctx.defer( + alertService.send({ + level: "error", + message: `Chat ${ctx.requestId} failed after ${info.duration}ms`, + error: String(info.error), + }) + ); + }, +}; +``` + +## TypeScript Types + +All middleware types are exported from `@tanstack/ai`: + +```typescript +import type { + ChatMiddleware, + ChatMiddlewareContext, + ChatMiddlewarePhase, + ChatMiddlewareConfig, + ToolCallHookContext, + BeforeToolCallDecision, + AfterToolCallInfo, + UsageInfo, + FinishInfo, + AbortInfo, + ErrorInfo, + ToolCacheMiddlewareOptions, + ToolCacheStorage, + ToolCacheEntry, +} from "@tanstack/ai"; +``` + +## Next Steps + +- [Tools](./tools) — Learn about the isomorphic tool system +- [Agentic Cycle](./agentic-cycle) — Understand the multi-step agent loop +- [Observability](./observability) — Event-driven observability with the event client +- [Streaming](./streaming) — How streaming works in TanStack AI diff --git a/packages/typescript/ai/src/activities/chat/index.ts b/packages/typescript/ai/src/activities/chat/index.ts index f7f4b7a66..59339287e 100644 --- a/packages/typescript/ai/src/activities/chat/index.ts +++ b/packages/typescript/ai/src/activities/chat/index.ts @@ -7,7 +7,11 @@ import { aiEventClient } from '../../event-client.js' import { streamToText } from '../../stream-to-response.js' -import { ToolCallManager, executeToolCalls } from './tools/tool-calls' +import { + MiddlewareAbortError, + ToolCallManager, + executeToolCalls, +} from './tools/tool-calls' import { convertSchemaToJsonSchema, isStandardSchema, @@ -15,6 +19,7 @@ import { } from './tools/schema-converter' import { maxIterations as maxIterationsStrategy } from './agent-loop-strategies' import { convertMessagesToModelMessages } from './messages' +import { MiddlewareRunner } from './middleware/compose' import type { ApprovalRequest, ClientToolRequest, @@ -38,6 +43,12 @@ import type { ToolCallEndEvent, ToolCallStartEvent, } from '../../types' +import type { + 
ChatMiddleware, + ChatMiddlewareConfig, + ChatMiddlewareContext, + ChatMiddlewarePhase, +} from './middleware/types' // =========================== // Activity Kind @@ -132,6 +143,25 @@ export interface TextActivityOptions< * ``` */ stream?: TStream + /** + * Optional middleware array for observing/transforming chat behavior. + * Middleware hooks are called in array order. See {@link ChatMiddleware} for available hooks. + * + * @example + * ```ts + * const stream = chat({ + * adapter: openaiText('gpt-4o'), + * messages: [...], + * middleware: [loggingMiddleware, redactionMiddleware], + * }) + * ``` + */ + middleware?: Array + /** + * Opaque user-provided context value passed to middleware hooks. + * Can be used to pass request-scoped data (e.g., user ID, request context). + */ + context?: unknown } // =========================== @@ -191,6 +221,8 @@ interface TextEngineConfig< adapter: TAdapter systemPrompts?: Array params: TParams + middleware?: Array + context?: unknown } type ToolPhaseResult = 'continue' | 'stop' | 'wait' @@ -201,9 +233,9 @@ class TextEngine< TParams extends TextOptions = TextOptions, > { private readonly adapter: TAdapter - private readonly params: TParams - private readonly systemPrompts: Array - private readonly tools: ReadonlyArray + private params: TParams + private systemPrompts: Array + private tools: Array private readonly loopStrategy: AgentLoopStrategy private readonly toolCallManager: ToolCallManager private readonly initialMessageCount: number @@ -230,6 +262,14 @@ class TextEngine< private readonly initialApprovals: Map private readonly initialClientToolResults: Map + // Middleware support + private readonly middlewareRunner: MiddlewareRunner + private readonly middlewareCtx: ChatMiddlewareContext + private readonly deferredPromises: Array> = [] + private abortReason?: string + private middlewareAbortController?: AbortController + private terminalHookCalled = false + constructor(config: TextEngineConfig) { this.adapter = 
config.adapter this.params = config.params @@ -260,6 +300,27 @@ class TextEngine< ? { signal: config.params.abortController.signal } : undefined this.effectiveSignal = config.params.abortController?.signal + + // Initialize middleware + this.middlewareRunner = new MiddlewareRunner(config.middleware || []) + this.middlewareAbortController = new AbortController() + this.middlewareCtx = { + requestId: this.requestId, + streamId: this.streamId, + conversationId: config.params.conversationId, + phase: 'init' as ChatMiddlewarePhase, + iteration: 0, + chunkIndex: 0, + signal: this.effectiveSignal, + abort: (reason?: string) => { + this.abortReason = reason + this.middlewareAbortController?.abort(reason) + }, + context: config.context, + defer: (promise: Promise) => { + this.deferredPromises.push(promise) + }, + } } /** Get the accumulated content after the chat loop completes */ @@ -276,19 +337,48 @@ class TextEngine< this.beforeRun() try { + // Run initial onConfig (phase = init) + if (this.middlewareRunner.hasMiddleware) { + this.middlewareCtx.phase = 'init' + const initialConfig = this.buildMiddlewareConfig() + const transformedConfig = await this.middlewareRunner.runOnConfig( + this.middlewareCtx, + initialConfig, + ) + this.applyMiddlewareConfig(transformedConfig) + + // Run onStart + await this.middlewareRunner.runOnStart(this.middlewareCtx) + } + const pendingPhase = yield* this.checkForPendingToolCalls() if (pendingPhase === 'wait') { return } do { - if (this.earlyTermination || this.isAborted()) { + if ( + this.earlyTermination || + this.isCancelled() + ) { return } this.beginCycle() if (this.cyclePhase === 'processText') { + // Run onConfig before each model call (phase = beforeModel) + if (this.middlewareRunner.hasMiddleware) { + this.middlewareCtx.phase = 'beforeModel' + this.middlewareCtx.iteration = this.iterationCount + const iterConfig = this.buildMiddlewareConfig() + const transformedConfig = await this.middlewareRunner.runOnConfig( + this.middlewareCtx, 
+ iterConfig, + ) + this.applyMiddlewareConfig(transformedConfig) + } + yield* this.streamModelResponse() } else { yield* this.processToolCalls() @@ -296,8 +386,59 @@ class TextEngine< this.endCycle() } while (this.shouldContinue()) + + // Call terminal onFinish hook + if (this.middlewareRunner.hasMiddleware && !this.terminalHookCalled) { + this.terminalHookCalled = true + await this.middlewareRunner.runOnFinish(this.middlewareCtx, { + finishReason: this.lastFinishReason, + duration: Date.now() - this.streamStartTime, + content: this.accumulatedContent, + usage: this.finishedEvent?.usage, + }) + } + } catch (error: unknown) { + if (this.middlewareRunner.hasMiddleware && !this.terminalHookCalled) { + this.terminalHookCalled = true + if (error instanceof MiddlewareAbortError) { + // Middleware abort decision — call onAbort, not onError + this.abortReason = error.message + await this.middlewareRunner.runOnAbort(this.middlewareCtx, { + reason: error.message, + duration: Date.now() - this.streamStartTime, + }) + } else { + // Genuine error — call onError + await this.middlewareRunner.runOnError(this.middlewareCtx, { + error, + duration: Date.now() - this.streamStartTime, + }) + } + } + // Don't rethrow middleware abort errors — the run just stops gracefully + if (!(error instanceof MiddlewareAbortError)) { + throw error + } } finally { + // Check for abort terminal hook + if ( + this.middlewareRunner.hasMiddleware && + !this.terminalHookCalled && + this.isCancelled() + ) { + this.terminalHookCalled = true + await this.middlewareRunner.runOnAbort(this.middlewareCtx, { + reason: this.abortReason, + duration: Date.now() - this.streamStartTime, + }) + } + this.afterRun() + + // Await deferred promises (non-blocking side effects) + if (this.deferredPromises.length > 0) { + await Promise.allSettled(this.deferredPromises) + } } } @@ -422,6 +563,10 @@ class TextEngine< : undefined, })) + if (this.middlewareRunner.hasMiddleware) { + this.middlewareCtx.phase = 'modelStream' + } + 
for await (const chunk of this.adapter.chatStream({ model: this.params.model, messages: this.messages, @@ -434,14 +579,35 @@ class TextEngine< modelOptions, systemPrompts: this.systemPrompts, })) { - if (this.isAborted()) { + if (this.isCancelled()) { break } this.totalChunkCount++ - yield chunk - this.handleStreamChunk(chunk) + // Pipe chunk through middleware + if (this.middlewareRunner.hasMiddleware) { + const outputChunks = await this.middlewareRunner.runOnChunk( + this.middlewareCtx, + chunk, + ) + for (const outputChunk of outputChunks) { + yield outputChunk + this.handleStreamChunk(outputChunk) + this.middlewareCtx.chunkIndex++ + } + + // Handle usage via middleware + if (chunk.type === 'RUN_FINISHED' && chunk.usage) { + await this.middlewareRunner.runOnUsage( + this.middlewareCtx, + chunk.usage, + ) + } + } else { + yield chunk + this.handleStreamChunk(chunk) + } if (this.earlyTermination) { break @@ -663,6 +829,10 @@ class TextEngine< this.addAssistantToolCallMessage(toolCalls) + if (this.middlewareRunner.hasMiddleware) { + this.middlewareCtx.phase = 'beforeTools' + } + const { approvals, clientToolResults } = this.collectClientState() const generator = executeToolCalls( @@ -671,11 +841,44 @@ class TextEngine< approvals, clientToolResults, (eventName, data) => this.createCustomEventChunk(eventName, data), + this.middlewareRunner.hasMiddleware + ? 
{ + onBeforeToolCall: async (toolCall, tool, args) => { + const hookCtx = { + toolCall, + tool, + args, + toolName: toolCall.function.name, + toolCallId: toolCall.id, + } + return this.middlewareRunner.runOnBeforeToolCall( + this.middlewareCtx, + hookCtx, + ) + }, + onAfterToolCall: async (info) => { + await this.middlewareRunner.runOnAfterToolCall( + this.middlewareCtx, + info, + ) + }, + } + : undefined, ) // Consume the async generator, yielding custom events and collecting the return value const executionResult = yield* this.drainToolCallGenerator(generator) + if (this.middlewareRunner.hasMiddleware) { + this.middlewareCtx.phase = 'afterTools' + } + + // Check if middleware aborted during tool execution + if (this.isMiddlewareAborted()) { + this.setToolPhase('stop') + return + } + if ( executionResult.needsApproval.length > 0 || executionResult.needsClientExecution.length > 0 @@ -1010,6 +1213,41 @@ class TextEngine< return !!this.effectiveSignal?.aborted } + private isMiddlewareAborted(): boolean { + return !!this.middlewareAbortController?.signal.aborted + } + + private isCancelled(): boolean { + return this.isAborted() || this.isMiddlewareAborted() + } + + private buildMiddlewareConfig(): ChatMiddlewareConfig { + return { + messages: this.messages, + systemPrompts: [...this.systemPrompts], + tools: [...this.tools], + temperature: this.params.temperature, + topP: this.params.topP, + maxTokens: this.params.maxTokens, + metadata: this.params.metadata, + modelOptions: this.params.modelOptions, + } + } + + private applyMiddlewareConfig(config: ChatMiddlewareConfig): void { + this.messages = config.messages + this.systemPrompts = config.systemPrompts + this.tools = config.tools + this.params = { + ...this.params, + temperature: config.temperature, + topP: config.topP, + maxTokens: config.maxTokens, + metadata: config.metadata, + modelOptions: config.modelOptions, + } + } + private buildTextEventContext(): { requestId: string streamId: string @@ -1036,9 +1274,7 @@ 
class TextEngine< this.systemPrompts.length > 0 ? this.systemPrompts : undefined, toolNames: this.eventToolNames, options: this.eventOptions, - modelOptions: this.params.modelOptions as - | Record - | undefined, + modelOptions: this.params.modelOptions, messageCount: this.initialMessageCount, hasTools: this.tools.length > 0, streaming: true, @@ -1218,7 +1454,7 @@ export function chat< async function* runStreamingText( options: TextActivityOptions, ): AsyncIterable { - const { adapter, ...textOptions } = options + const { adapter, middleware, context, ...textOptions } = options const model = adapter.model const engine = new TextEngine({ @@ -1227,6 +1463,8 @@ async function* runStreamingText( Record, Record >, + middleware, + context, }) for await (const chunk of engine.run()) { @@ -1258,7 +1496,7 @@ function runNonStreamingText( async function runAgenticStructuredOutput( options: TextActivityOptions, ): Promise> { - const { adapter, outputSchema, ...textOptions } = options + const { adapter, outputSchema, middleware, context, ...textOptions } = options const model = adapter.model if (!outputSchema) { @@ -1272,6 +1510,8 @@ async function runAgenticStructuredOutput( Record, Record >, + middleware, + context, }) // Consume the stream to run the agentic loop diff --git a/packages/typescript/ai/src/activities/chat/middleware/compose.ts b/packages/typescript/ai/src/activities/chat/middleware/compose.ts new file mode 100644 index 000000000..025caeb8f --- /dev/null +++ b/packages/typescript/ai/src/activities/chat/middleware/compose.ts @@ -0,0 +1,185 @@ +import type { StreamChunk } from '../../../types' +import type { + AbortInfo, + AfterToolCallInfo, + BeforeToolCallDecision, + ChatMiddleware, + ChatMiddlewareConfig, + ChatMiddlewareContext, + ErrorInfo, + FinishInfo, + ToolCallHookContext, + UsageInfo, +} from './types' + +/** + * Internal middleware runner that manages composed execution of middleware hooks. + * Created once per chat() invocation. 
+ */ +export class MiddlewareRunner { + private readonly middlewares: ReadonlyArray + + constructor(middlewares: ReadonlyArray) { + this.middlewares = middlewares + } + + get hasMiddleware(): boolean { + return this.middlewares.length > 0 + } + + /** + * Pipe config through all middleware onConfig hooks in order. + * Each middleware receives the merged config from previous middleware. + * Partial returns are shallow-merged with the current config. + */ + async runOnConfig( + ctx: ChatMiddlewareContext, + config: ChatMiddlewareConfig, + ): Promise { + let current = config + for (const mw of this.middlewares) { + if (mw.onConfig) { + const result = await mw.onConfig(ctx, current) + if (result !== undefined && result !== null) { + current = { ...current, ...result } + } + } + } + return current + } + + /** + * Call onStart on all middleware in order. + */ + async runOnStart(ctx: ChatMiddlewareContext): Promise { + for (const mw of this.middlewares) { + if (mw.onStart) { + await mw.onStart(ctx) + } + } + } + + /** + * Pipe a single chunk through all middleware onChunk hooks in order. + * Returns the resulting chunks (0..N) to yield to the consumer. + * + * - void: pass through unchanged + * - chunk: replace with this chunk + * - chunk[]: expand to multiple chunks + * - null: drop the chunk entirely + */ + async runOnChunk( + ctx: ChatMiddlewareContext, + chunk: StreamChunk, + ): Promise> { + let chunks: Array = [chunk] + + for (const mw of this.middlewares) { + if (!mw.onChunk) continue + + const nextChunks: Array = [] + for (const c of chunks) { + const result = await mw.onChunk(ctx, c) + if (result === null) { + // Drop this chunk + continue + } else if (result === undefined) { + // Pass through + nextChunks.push(c) + } else if (Array.isArray(result)) { + // Expand + nextChunks.push(...result) + } else { + // Replace + nextChunks.push(result) + } + } + chunks = nextChunks + } + + return chunks + } + + /** + * Run onBeforeToolCall through middleware in order. 
+ * Returns the first non-void decision, or undefined to continue normally. + */ + async runOnBeforeToolCall( + ctx: ChatMiddlewareContext, + hookCtx: ToolCallHookContext, + ): Promise { + for (const mw of this.middlewares) { + if (mw.onBeforeToolCall) { + const decision = await mw.onBeforeToolCall(ctx, hookCtx) + if (decision !== undefined && decision !== null) { + return decision + } + } + } + return undefined + } + + /** + * Run onAfterToolCall on all middleware in order. + */ + async runOnAfterToolCall( + ctx: ChatMiddlewareContext, + info: AfterToolCallInfo, + ): Promise { + for (const mw of this.middlewares) { + if (mw.onAfterToolCall) { + await mw.onAfterToolCall(ctx, info) + } + } + } + + /** + * Run onUsage on all middleware in order. + */ + async runOnUsage( + ctx: ChatMiddlewareContext, + usage: UsageInfo, + ): Promise { + for (const mw of this.middlewares) { + if (mw.onUsage) { + await mw.onUsage(ctx, usage) + } + } + } + + /** + * Run onFinish on all middleware in order. + */ + async runOnFinish( + ctx: ChatMiddlewareContext, + info: FinishInfo, + ): Promise { + for (const mw of this.middlewares) { + if (mw.onFinish) { + await mw.onFinish(ctx, info) + } + } + } + + /** + * Run onAbort on all middleware in order. + */ + async runOnAbort(ctx: ChatMiddlewareContext, info: AbortInfo): Promise { + for (const mw of this.middlewares) { + if (mw.onAbort) { + await mw.onAbort(ctx, info) + } + } + } + + /** + * Run onError on all middleware in order. 
+ */ + async runOnError(ctx: ChatMiddlewareContext, info: ErrorInfo): Promise { + for (const mw of this.middlewares) { + if (mw.onError) { + await mw.onError(ctx, info) + } + } + } +} diff --git a/packages/typescript/ai/src/activities/chat/middleware/index.ts b/packages/typescript/ai/src/activities/chat/middleware/index.ts new file mode 100644 index 000000000..d2175bb14 --- /dev/null +++ b/packages/typescript/ai/src/activities/chat/middleware/index.ts @@ -0,0 +1,22 @@ +export type { + ChatMiddleware, + ChatMiddlewareContext, + ChatMiddlewarePhase, + ChatMiddlewareConfig, + ToolCallHookContext, + BeforeToolCallDecision, + AfterToolCallInfo, + UsageInfo, + FinishInfo, + AbortInfo, + ErrorInfo, +} from './types' + +export { MiddlewareRunner } from './compose' + +export { toolCacheMiddleware } from './tool-cache-middleware' +export type { + ToolCacheMiddlewareOptions, + ToolCacheStorage, + ToolCacheEntry, +} from './tool-cache-middleware' diff --git a/packages/typescript/ai/src/activities/chat/middleware/tool-cache-middleware.ts b/packages/typescript/ai/src/activities/chat/middleware/tool-cache-middleware.ts new file mode 100644 index 000000000..475607758 --- /dev/null +++ b/packages/typescript/ai/src/activities/chat/middleware/tool-cache-middleware.ts @@ -0,0 +1,178 @@ +import type { ChatMiddleware } from './types' + +/** + * A cache entry stored by the tool cache middleware. + */ +export interface ToolCacheEntry { + result: unknown + timestamp: number +} + +/** + * Custom storage backend for the tool cache middleware. + * + * When provided, the middleware delegates all cache operations to this storage + * instead of using the built-in in-memory Map. This enables external storage + * backends like Redis, localStorage, databases, etc. + * + * All methods may return a Promise for async storage backends. 
+ */ +export interface ToolCacheStorage { + getItem: ( + key: string, + ) => ToolCacheEntry | undefined | Promise + setItem: (key: string, value: ToolCacheEntry) => void | Promise + deleteItem: (key: string) => void | Promise +} + +/** + * Options for the tool cache middleware. + */ +export interface ToolCacheMiddlewareOptions { + /** + * Maximum number of entries in the cache. + * When exceeded, the oldest entry is evicted (LRU). + * + * Only applies to the default in-memory storage. + * When a custom `storage` is provided, capacity management is the storage's responsibility. + * + * @default 100 + */ + maxSize?: number + + /** + * Time-to-live in milliseconds. Entries older than this are not served from cache. + * @default Infinity (no expiry) + */ + ttl?: number + + /** + * Tool names to cache. If not provided, all tools are cached. + */ + toolNames?: Array + + /** + * Custom function to generate a cache key from tool name and args. + * Defaults to `JSON.stringify([toolName, args])`. + */ + keyFn?: (toolName: string, args: unknown) => string + + /** + * Custom storage backend. When provided, the middleware uses this instead of + * the built-in in-memory Map. The storage is responsible for its own capacity + * management — the `maxSize` option is ignored. + * + * @example + * ```ts + * toolCacheMiddleware({ + * storage: { + * getItem: (key) => redisClient.get(key).then(v => v ? 
JSON.parse(v) : undefined), + * setItem: (key, value) => redisClient.set(key, JSON.stringify(value)), + * deleteItem: (key) => redisClient.del(key), + * }, + * }) + * ``` + */ + storage?: ToolCacheStorage +} + +function defaultKeyFn(toolName: string, args: unknown): string { + return JSON.stringify([toolName, args]) +} + +function createDefaultStorage(maxSize: number): ToolCacheStorage { + const cache = new Map() + + return { + getItem: (key) => cache.get(key), + setItem: (key, value) => { + if (cache.size >= maxSize && !cache.has(key)) { + // LRU eviction: Map iteration order is insertion order — first key is oldest + const firstKey = cache.keys().next().value + if (firstKey !== undefined) { + cache.delete(firstKey) + } + } + cache.set(key, value) + }, + deleteItem: (key) => { + cache.delete(key) + }, + } +} + +/** + * Creates a middleware that caches tool call results based on tool name + arguments. + * + * When a tool is called with the same name and arguments as a previous call, + * the cached result is returned immediately without executing the tool. 
+ * + * @example + * ```ts + * import { chat, toolCacheMiddleware } from '@tanstack/ai' + * + * const stream = chat({ + * adapter, + * messages, + * tools: [weatherTool, stockTool], + * middleware: [ + * toolCacheMiddleware({ ttl: 60_000, toolNames: ['getWeather'] }), + * ], + * }) + * ``` + */ +export function toolCacheMiddleware( + options: ToolCacheMiddlewareOptions = {}, +): ChatMiddleware { + const { + maxSize = 100, + ttl = Infinity, + toolNames, + keyFn = defaultKeyFn, + storage = createDefaultStorage(maxSize), + } = options + + return { + name: 'tool-cache-middleware', + + onBeforeToolCall: async (_ctx, hookCtx) => { + if (toolNames && !toolNames.includes(hookCtx.toolName)) { + return undefined + } + + const key = keyFn(hookCtx.toolName, hookCtx.args) + const entry = await storage.getItem(key) + + if (entry) { + const age = Date.now() - entry.timestamp + if (age < ttl) { + return { type: 'skip', result: entry.result } + } + // Expired — remove + await storage.deleteItem(key) + } + + return undefined + }, + + onAfterToolCall: async (_ctx, info) => { + if (!info.ok) return + if (toolNames && !toolNames.includes(info.toolName)) return + + // Re-derive the key from the raw arguments to match what onBeforeToolCall produces + let parsedArgs: unknown + try { + parsedArgs = JSON.parse(info.toolCall.function.arguments.trim() || '{}') + } catch { + return + } + + const key = keyFn(info.toolName, parsedArgs) + + await storage.setItem(key, { + result: info.result, + timestamp: Date.now(), + }) + }, + } +} diff --git a/packages/typescript/ai/src/activities/chat/middleware/types.ts b/packages/typescript/ai/src/activities/chat/middleware/types.ts new file mode 100644 index 000000000..3f4d8cb33 --- /dev/null +++ b/packages/typescript/ai/src/activities/chat/middleware/types.ts @@ -0,0 +1,311 @@ +import type { ModelMessage, StreamChunk, Tool, ToolCall } from '../../../types' + +// =========================== +// Middleware Context +// =========================== + +/** + * 
Phase of the chat middleware lifecycle. + * - 'init': Initial config transform before the chat engine starts + * - 'beforeModel': Before each adapter chatStream call (per agent iteration) + * - 'modelStream': During model streaming + * - 'beforeTools': Before tool execution phase + * - 'afterTools': After tool execution phase + */ +export type ChatMiddlewarePhase = + | 'init' + | 'beforeModel' + | 'modelStream' + | 'beforeTools' + | 'afterTools' + +/** + * Stable context object passed to all middleware hooks. + * Created once per chat() invocation and shared across all hooks. + */ +export interface ChatMiddlewareContext { + /** Unique identifier for this chat request */ + requestId: string + /** Unique identifier for this stream */ + streamId: string + /** Conversation identifier, if provided by the caller */ + conversationId?: string + /** Current lifecycle phase */ + phase: ChatMiddlewarePhase + /** Current agent loop iteration (0-indexed) */ + iteration: number + /** Running count of chunks yielded so far */ + chunkIndex: number + /** Abort signal from the chat request */ + signal?: AbortSignal + /** Abort the chat run with a reason */ + abort: (reason?: string) => void + /** Opaque user-provided value from chat() options */ + context: unknown + /** + * Defer a non-blocking side-effect promise. + * Deferred promises do not block streaming and are awaited + * after the terminal hook (onFinish/onAbort/onError). + */ + defer: (promise: Promise) => void +} + +// =========================== +// Config passed to onConfig +// =========================== + +/** + * Chat configuration that middleware can observe or transform. + * This is a subset of the chat engine's effective configuration + * that middleware is allowed to modify. 
+ */ +export interface ChatMiddlewareConfig { + messages: Array + systemPrompts: Array + tools: Array + temperature?: number + topP?: number + maxTokens?: number + metadata?: Record + modelOptions?: Record +} + +// =========================== +// Tool Call Hook Context +// =========================== + +/** + * Context provided to tool call hooks (onBeforeToolCall / onAfterToolCall). + */ +export interface ToolCallHookContext { + /** The tool call being executed */ + toolCall: ToolCall + /** The resolved tool definition, if found */ + tool: Tool | undefined + /** Parsed arguments for the tool call */ + args: unknown + /** Name of the tool */ + toolName: string + /** ID of the tool call */ + toolCallId: string +} + +/** + * Decision returned from onBeforeToolCall. + * - undefined/void: continue with normal execution + * - { type: 'transformArgs', args }: replace args used for execution + * - { type: 'skip', result }: skip execution, use provided result + * - { type: 'abort', reason }: abort the entire chat run + */ +export type BeforeToolCallDecision = + | void + | undefined + | null + | { type: 'transformArgs'; args: unknown } + | { type: 'skip'; result: unknown } + | { type: 'abort'; reason?: string } + +/** + * Outcome information provided to onAfterToolCall. + */ +export interface AfterToolCallInfo { + /** The tool call that was executed */ + toolCall: ToolCall + /** The resolved tool definition */ + tool: Tool | undefined + /** Name of the tool */ + toolName: string + /** ID of the tool call */ + toolCallId: string + /** Whether the execution succeeded */ + ok: boolean + /** Duration of tool execution in milliseconds */ + duration: number + /** The result (if ok) or error (if not ok) */ + result?: unknown + error?: unknown +} + +// =========================== +// Usage Info +// =========================== + +/** + * Token usage statistics passed to the onUsage hook. + * Extracted from the RUN_FINISHED chunk when usage data is present. 
+ */ +export interface UsageInfo { + promptTokens: number + completionTokens: number + totalTokens: number +} + +// =========================== +// Terminal Hook Info +// =========================== + +/** + * Information passed to onFinish. + */ +export interface FinishInfo { + /** The finish reason from the last model response */ + finishReason: string | null + /** Total duration of the chat run in milliseconds */ + duration: number + /** Final accumulated text content */ + content: string + /** Final usage totals, if available */ + usage?: { + promptTokens: number + completionTokens: number + totalTokens: number + } +} + +/** + * Information passed to onAbort. + */ +export interface AbortInfo { + /** The reason for the abort, if provided */ + reason?: string + /** Duration until abort in milliseconds */ + duration: number +} + +/** + * Information passed to onError. + */ +export interface ErrorInfo { + /** The error that caused the failure */ + error: unknown + /** Duration until error in milliseconds */ + duration: number +} + +// =========================== +// Middleware Interface +// =========================== + +/** + * Chat middleware interface. + * + * All hooks are optional. 
Middleware is composed in array order: + * - `onConfig`: config piped through middlewares in order (first transform influences later) + * - `onChunk`: each output chunk is fed into the next middleware in order + * + * @example Logging middleware + * ```ts + * const loggingMiddleware: ChatMiddleware = { + * name: 'logging', + * onStart(ctx) { console.log('Chat started', ctx.requestId) }, + * onChunk(ctx, chunk) { console.log('Chunk:', chunk.type) }, + * onFinish(ctx, info) { console.log('Done:', info.duration, 'ms') }, + * } + * ``` + * + * @example Redaction middleware + * ```ts + * const redactionMiddleware: ChatMiddleware = { + * name: 'redaction', + * onChunk(ctx, chunk) { + * if (chunk.type === 'TEXT_MESSAGE_CONTENT') { + * return { ...chunk, delta: redact(chunk.delta) } + * } + * }, + * } + * ``` + */ +export interface ChatMiddleware { + /** Optional name for debugging and identification */ + name?: string + + /** + * Called to observe or transform the chat configuration. + * Called at init and at the beginning of each agent iteration. + * + * Return a partial config to merge with the current config, or void to pass through. + * Only the fields you return are overwritten — everything else is preserved. + */ + onConfig?: ( + ctx: ChatMiddlewareContext, + config: ChatMiddlewareConfig, + ) => + | void + | null + | Partial + | Promise> + + /** + * Called when the chat run starts (after initial onConfig). + */ + onStart?: (ctx: ChatMiddlewareContext) => void | Promise + + /** + * Called for every chunk yielded by chat(). + * Can observe, transform, expand, or drop chunks. + * + * @returns void (pass through), chunk (replace), chunk[] (expand), null (drop) + */ + onChunk?: ( + ctx: ChatMiddlewareContext, + chunk: StreamChunk, + ) => + | void + | StreamChunk + | Array + | null + | Promise | null> + + /** + * Called before a tool is executed. + * Can observe, transform args, skip execution, or abort the run. 
+ */ + onBeforeToolCall?: ( + ctx: ChatMiddlewareContext, + hookCtx: ToolCallHookContext, + ) => BeforeToolCallDecision | Promise + + /** + * Called after a tool execution completes (success or failure). + */ + onAfterToolCall?: ( + ctx: ChatMiddlewareContext, + info: AfterToolCallInfo, + ) => void | Promise + + /** + * Called when usage data is available from a RUN_FINISHED chunk. + * Called once per model iteration that reports usage. + */ + onUsage?: ( + ctx: ChatMiddlewareContext, + usage: UsageInfo, + ) => void | Promise + + /** + * Called when the chat run completes normally. + * Exactly one of onFinish/onAbort/onError will be called per run. + */ + onFinish?: ( + ctx: ChatMiddlewareContext, + info: FinishInfo, + ) => void | Promise + + /** + * Called when the chat run is aborted. + * Exactly one of onFinish/onAbort/onError will be called per run. + */ + onAbort?: ( + ctx: ChatMiddlewareContext, + info: AbortInfo, + ) => void | Promise + + /** + * Called when the chat run encounters an unhandled error. + * Exactly one of onFinish/onAbort/onError will be called per run. + */ + onError?: ( + ctx: ChatMiddlewareContext, + info: ErrorInfo, + ) => void | Promise +} diff --git a/packages/typescript/ai/src/activities/chat/tools/tool-calls.ts b/packages/typescript/ai/src/activities/chat/tools/tool-calls.ts index 594274110..c81cc09f2 100644 --- a/packages/typescript/ai/src/activities/chat/tools/tool-calls.ts +++ b/packages/typescript/ai/src/activities/chat/tools/tool-calls.ts @@ -10,6 +10,33 @@ import type { ToolCallStartEvent, ToolExecutionContext, } from '../../../types' +import type { + AfterToolCallInfo, + BeforeToolCallDecision, +} from '../middleware/types' + +/** + * Optional middleware hooks for tool execution. + * When provided, these callbacks are invoked before/after each tool execution. 
+ */ +export interface ToolExecutionMiddlewareHooks { + onBeforeToolCall?: ( + toolCall: ToolCall, + tool: Tool | undefined, + args: unknown, + ) => Promise + onAfterToolCall?: (info: AfterToolCallInfo) => Promise +} + +/** + * Error thrown when middleware decides to abort the chat run during tool execution. + */ +export class MiddlewareAbortError extends Error { + constructor(reason: string) { + super(reason) + this.name = 'MiddlewareAbortError' + } +} /** * Manages tool call accumulation and execution for the chat() method's automatic tool execution loop. @@ -289,6 +316,154 @@ async function* executeWithEventPolling( return state.result } +/** + * Apply a middleware onBeforeToolCall decision. + * Returns the (possibly transformed) input if execution should proceed, + * or undefined if the tool call was skipped (result already pushed). + * Throws MiddlewareAbortError if the decision is 'abort'. + */ +async function applyBeforeToolCallDecision( + toolCall: ToolCall, + tool: Tool, + input: unknown, + toolName: string, + middlewareHooks: ToolExecutionMiddlewareHooks, + results: Array, +): Promise<{ proceed: true; input: unknown } | { proceed: false }> { + if (!middlewareHooks.onBeforeToolCall) { + return { proceed: true, input } + } + + const decision = await middlewareHooks.onBeforeToolCall(toolCall, tool, input) + if (!decision) { + return { proceed: true, input } + } + + if (decision.type === 'abort') { + throw new MiddlewareAbortError(decision.reason || 'Aborted by middleware') + } + + if (decision.type === 'skip') { + const skipResult = decision.result + results.push({ + toolCallId: toolCall.id, + toolName, + result: + typeof skipResult === 'string' + ? 
JSON.parse(skipResult) + : skipResult || null, + duration: 0, + }) + if (middlewareHooks.onAfterToolCall) { + await middlewareHooks.onAfterToolCall({ + toolCall, + tool, + toolName, + toolCallId: toolCall.id, + ok: true, + duration: 0, + result: skipResult, + }) + } + return { proceed: false } + } + + + return { proceed: true, input: decision.args } + +} + +/** + * Execute a server-side tool with event polling, output validation, and middleware hooks. + * Yields CustomEvent chunks during execution and pushes the result to the results array. + */ +async function* executeServerTool( + toolCall: ToolCall, + tool: Tool, + toolName: string, + input: unknown, + context: ToolExecutionContext, + pendingEvents: Array, + results: Array, + middlewareHooks?: ToolExecutionMiddlewareHooks, +): AsyncGenerator { + const startTime = Date.now() + try { + const executionPromise = Promise.resolve(tool.execute!(input, context)) + let result = yield* executeWithEventPolling(executionPromise, pendingEvents) + const duration = Date.now() - startTime + + // Flush remaining events + while (pendingEvents.length > 0) { + yield pendingEvents.shift()! + } + + // Validate output against outputSchema if provided + if ( + tool.outputSchema && + isStandardSchema(tool.outputSchema) && + result !== undefined && + result !== null + ) { + result = parseWithStandardSchema(tool.outputSchema, result) + } + + const finalResult = + typeof result === 'string' ? JSON.parse(result) : result || null + + results.push({ + toolCallId: toolCall.id, + toolName, + result: finalResult, + duration, + }) + + if (middlewareHooks?.onAfterToolCall) { + await middlewareHooks.onAfterToolCall({ + toolCall, + tool, + toolName, + toolCallId: toolCall.id, + ok: true, + duration, + result: finalResult, + }) + } + } catch (error: unknown) { + const duration = Date.now() - startTime + + // Flush remaining events + while (pendingEvents.length > 0) { + yield pendingEvents.shift()! 
+ } + + if (error instanceof MiddlewareAbortError) { + throw error + } + + const message = error instanceof Error ? error.message : 'Unknown error' + results.push({ + toolCallId: toolCall.id, + toolName, + result: { error: message }, + state: 'output-error', + duration, + }) + + if (middlewareHooks?.onAfterToolCall) { + await middlewareHooks.onAfterToolCall({ + toolCall, + tool, + toolName, + toolCallId: toolCall.id, + ok: false, + duration, + error, + }) + } + } +} + /** * Execute tool calls based on their configuration. * Yields CustomEvent chunks during tool execution for real-time progress updates. @@ -313,6 +488,7 @@ export async function* executeToolCalls( eventName: string, value: Record, ) => CustomEvent, + middlewareHooks?: ToolExecutionMiddlewareHooks, ): AsyncGenerator { const results: Array = [] const needsApproval: Array = [] @@ -388,13 +564,6 @@ export async function* executeToolCalls( }, } - // Helper to flush any pending events - function* flushEvents(): Generator { - while (pendingEvents.length > 0) { - yield pendingEvents.shift()! 
- } - } - // CASE 1: Client-side tool (no execute function) if (!tool.execute) { // Check if tool needs approval @@ -468,51 +637,30 @@ export async function* executeToolCalls( const approved = approvals.get(approvalId) if (approved) { - // Execute after approval - const startTime = Date.now() - try { - const executionPromise = Promise.resolve( - tool.execute(input, context), - ) - let result = yield* executeWithEventPolling( - executionPromise, - pendingEvents, - ) - const duration = Date.now() - startTime - yield* flushEvents() - - // Validate output against outputSchema if provided (for Standard Schema compliant schemas) - if ( - tool.outputSchema && - isStandardSchema(tool.outputSchema) && - result !== undefined && - result !== null - ) { - result = parseWithStandardSchema(tool.outputSchema, result) - } - - results.push({ - toolCallId: toolCall.id, + // Apply middleware before-hook for approved tools + if (middlewareHooks) { + const decision = await applyBeforeToolCallDecision( + toolCall, + tool, + input, toolName, - result: - typeof result === 'string' - ? JSON.parse(result) - : result || null, - duration, - }) - } catch (error: unknown) { - const duration = Date.now() - startTime - yield* flushEvents() - const message = - error instanceof Error ? 
error.message : 'Unknown error' - results.push({ - toolCallId: toolCall.id, - toolName, - result: { error: message }, - state: 'output-error', - duration, - }) + middlewareHooks, + results, + ) + if (!decision.proceed) continue + input = decision.input } + + yield* executeServerTool( + toolCall, + tool, + toolName, + input, + context, + pendingEvents, + results, + middlewareHooks, + ) } else { // User declined results.push({ @@ -535,45 +683,29 @@ export async function* executeToolCalls( } // CASE 3: Normal server tool - execute immediately - const startTime = Date.now() - try { - const executionPromise = Promise.resolve(tool.execute(input, context)) - let result = yield* executeWithEventPolling( - executionPromise, - pendingEvents, - ) - const duration = Date.now() - startTime - yield* flushEvents() - - // Validate output against outputSchema if provided (for Standard Schema compliant schemas) - if ( - tool.outputSchema && - isStandardSchema(tool.outputSchema) && - result !== undefined && - result !== null - ) { - result = parseWithStandardSchema(tool.outputSchema, result) - } - - results.push({ - toolCallId: toolCall.id, - toolName, - result: - typeof result === 'string' ? JSON.parse(result) : result || null, - duration, - }) - } catch (error: unknown) { - const duration = Date.now() - startTime - yield* flushEvents() - const message = error instanceof Error ? 
error.message : 'Unknown error' - results.push({ - toolCallId: toolCall.id, + if (middlewareHooks) { + const decision = await applyBeforeToolCallDecision( + toolCall, + tool, + input, toolName, - result: { error: message }, - state: 'output-error', - duration, - }) + middlewareHooks, + results, + ) + if (!decision.proceed) continue + input = decision.input } + + yield* executeServerTool( + toolCall, + tool, + toolName, + input, + context, + pendingEvents, + results, + middlewareHooks, + ) } return { results, needsApproval, needsClientExecution } diff --git a/packages/typescript/ai/src/index.ts b/packages/typescript/ai/src/index.ts index 7f0d4fece..d44ed4fec 100644 --- a/packages/typescript/ai/src/index.ts +++ b/packages/typescript/ai/src/index.ts @@ -70,6 +70,26 @@ export { combineStrategies, } from './activities/chat/agent-loop-strategies' +// Chat middleware +export type { + ChatMiddleware, + ChatMiddlewareContext, + ChatMiddlewarePhase, + ChatMiddlewareConfig, + ToolCallHookContext, + BeforeToolCallDecision, + AfterToolCallInfo, + UsageInfo, + FinishInfo, + AbortInfo, + ErrorInfo, + ToolCacheMiddlewareOptions, + ToolCacheStorage, + ToolCacheEntry, +} from './activities/chat/middleware/index' + +export { toolCacheMiddleware } from './activities/chat/middleware/index' + // All types export * from './types' diff --git a/packages/typescript/ai/tests/chat.test.ts b/packages/typescript/ai/tests/chat.test.ts index 65bda0067..b4408f140 100644 --- a/packages/typescript/ai/tests/chat.test.ts +++ b/packages/typescript/ai/tests/chat.test.ts @@ -1,137 +1,13 @@ import { describe, expect, it, vi } from 'vitest' import { chat, createChatOptions } from '../src/activities/chat/index' -import type { AnyTextAdapter } from '../src/activities/chat/adapter' -import type { StreamChunk, Tool } from '../src/types' - -// ============================================================================ -// Helpers -// ============================================================================ 
- -/** Create a typed StreamChunk with minimal boilerplate. */ -function chunk( - type: T, - fields: Omit, 'type' | 'timestamp'>, -): StreamChunk { - return { type, timestamp: Date.now(), ...fields } as unknown as StreamChunk -} - -/** Shorthand chunk factories for common AG-UI events. */ -const ev = { - runStarted: (runId = 'run-1') => chunk('RUN_STARTED', { runId }), - textStart: (messageId = 'msg-1') => - chunk('TEXT_MESSAGE_START', { messageId, role: 'assistant' as const }), - textContent: (delta: string, messageId = 'msg-1') => - chunk('TEXT_MESSAGE_CONTENT', { messageId, delta }), - textEnd: (messageId = 'msg-1') => chunk('TEXT_MESSAGE_END', { messageId }), - toolStart: (toolCallId: string, toolName: string, index?: number) => - chunk('TOOL_CALL_START', { - toolCallId, - toolName, - ...(index !== undefined ? { index } : {}), - }), - toolArgs: (toolCallId: string, delta: string) => - chunk('TOOL_CALL_ARGS', { toolCallId, delta }), - toolEnd: ( - toolCallId: string, - toolName: string, - opts?: { input?: unknown; result?: string }, - ) => chunk('TOOL_CALL_END', { toolCallId, toolName, ...opts }), - runFinished: ( - finishReason: - | 'stop' - | 'length' - | 'content_filter' - | 'tool_calls' - | null = 'stop', - runId = 'run-1', - ) => chunk('RUN_FINISHED', { runId, finishReason }), - runError: (message: string, runId = 'run-1') => - chunk('RUN_ERROR', { runId, error: { message } }), - stepFinished: (delta: string, stepId = 'step-1') => - chunk('STEP_FINISHED', { stepId, delta }), -} - -/** - * Create a mock adapter that satisfies AnyTextAdapter. - * `chatStreamFn` receives the options and returns an AsyncIterable of chunks. - * Multiple invocations can be tracked via the returned `calls` array. - */ -function createMockAdapter(options: { - chatStreamFn?: (opts: any) => AsyncIterable - /** Array of chunk sequences: chatStream returns iterations[0] on first call, iterations[1] on second, etc. 
*/ - iterations?: Array> - structuredOutput?: (opts: any) => Promise<{ data: unknown; rawText: string }> -}) { - const calls: Array = [] - let callIndex = 0 - - const adapter: AnyTextAdapter = { - kind: 'text' as const, - name: 'mock', - model: 'test-model' as const, - '~types': { - providerOptions: {} as Record, - inputModalities: ['text'] as readonly ['text'], - messageMetadataByModality: { - text: undefined as unknown, - image: undefined as unknown, - audio: undefined as unknown, - video: undefined as unknown, - document: undefined as unknown, - }, - }, - chatStream: (opts: any) => { - calls.push(opts) - - if (options.chatStreamFn) { - return options.chatStreamFn(opts) - } - - if (options.iterations) { - const chunks = options.iterations[callIndex] || [] - callIndex++ - return (async function* () { - for (const c of chunks) yield c - })() - } - - return (async function* () {})() - }, - structuredOutput: - options.structuredOutput ?? (async () => ({ data: {}, rawText: '{}' })), - } - - return { adapter, calls } -} - -/** Collect all chunks from an async iterable. */ -async function collectChunks( - stream: AsyncIterable, -): Promise> { - const chunks: Array = [] - for await (const c of stream) { - chunks.push(c) - } - return chunks -} - -/** Simple server tool for testing. */ -function serverTool(name: string, executeFn: (args: any) => any): Tool { - return { - name, - description: `Test tool: ${name}`, - execute: executeFn, - } -} - -/** Client tool (no execute function). 
*/ -function clientTool(name: string, opts?: { needsApproval?: boolean }): Tool { - return { - name, - description: `Client tool: ${name}`, - needsApproval: opts?.needsApproval, - } -} +import type { StreamChunk } from '../src/types' +import { + ev, + createMockAdapter, + collectChunks, + serverTool, + clientTool, +} from './test-utils' // ============================================================================ // Tests diff --git a/packages/typescript/ai/tests/middleware.test.ts b/packages/typescript/ai/tests/middleware.test.ts new file mode 100644 index 000000000..5b0701ea8 --- /dev/null +++ b/packages/typescript/ai/tests/middleware.test.ts @@ -0,0 +1,2404 @@ +/* eslint-disable @typescript-eslint/require-await */ +import { describe, expect, it, vi } from 'vitest' +import { chat } from '../src/activities/chat/index' +import { + collectChunks, + createMockAdapter, + ev, + getDeltas, + serverTool, +} from './test-utils' +import type { StreamChunk } from '../src/types' +import type { + ChatMiddleware, + ChatMiddlewareContext, +} from '../src/activities/chat/middleware/types' + +// ============================================================================ +// Tests +// ============================================================================ + +describe('chat() middleware', () => { + // ========================================================================== + // Basic lifecycle hooks + // ========================================================================== + describe('lifecycle hooks', () => { + it('should call onStart and onFinish for a simple stream', async () => { + const onStart = vi.fn() + const onFinish = vi.fn() + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.textStart(), + ev.textContent('Hello'), + ev.textEnd(), + ev.runFinished('stop'), + ], + ], + }) + + const middleware: ChatMiddleware = { + name: 'test', + onStart, + onFinish, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', 
content: 'Hi' }], + middleware: [middleware], + }) + + await collectChunks(stream as AsyncIterable) + + expect(onStart).toHaveBeenCalledOnce() + expect(onFinish).toHaveBeenCalledOnce() + expect(onFinish.mock.calls[0]![1]).toMatchObject({ + finishReason: 'stop', + content: 'Hello', + }) + }) + + it('should call onError on adapter errors', async () => { + const onError = vi.fn() + const onFinish = vi.fn() + + const { adapter } = createMockAdapter({ + // eslint-disable-next-line require-yield + chatStreamFn: async function* () { + throw new Error('Adapter failure') + }, + }) + + const middleware: ChatMiddleware = { + name: 'test', + onError, + onFinish, + } + + try { + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + } catch { + // Expected + } + + expect(onError).toHaveBeenCalledOnce() + expect(onFinish).not.toHaveBeenCalled() + }) + + it('should call exactly one terminal hook per run', async () => { + const onStart = vi.fn() + const onFinish = vi.fn() + const onAbort = vi.fn() + const onError = vi.fn() + + const { adapter } = createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('hi'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'test', + onStart, + onFinish, + onAbort, + onError, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + // Exactly one terminal hook + const terminalCalls = + onFinish.mock.calls.length + + onAbort.mock.calls.length + + onError.mock.calls.length + expect(terminalCalls).toBe(1) + expect(onFinish).toHaveBeenCalledOnce() + }) + }) + + // ========================================================================== + // onConfig + // ========================================================================== + describe('onConfig', () => { + it('should call 
onConfig at init phase', async () => { + const phases: Array = [] + + const { adapter } = createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('hi'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'test', + onConfig: (ctx) => { + // Capture phase at call time since context is mutable + phases.push(ctx.phase) + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + // Called at least twice: init + beforeModel + expect(phases).toEqual(['init', 'beforeModel']) + }) + + it('should allow onConfig to transform system prompts', async () => { + const { adapter, calls } = createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('hi'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'test', + onConfig: (_ctx, config) => ({ + systemPrompts: [...config.systemPrompts, 'Added by middleware'], + }), + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + systemPrompts: ['Original'], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + // The adapter should receive the transformed system prompts + expect(calls[0]!.systemPrompts).toContain('Added by middleware') + expect(calls[0]!.systemPrompts).toContain('Original') + }) + + it('should pipe config through multiple middlewares in order', async () => { + const { adapter, calls } = createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('hi'), ev.runFinished('stop')], + ], + }) + + const mw1: ChatMiddleware = { + name: 'first', + onConfig: () => ({ + maxTokens: 100, + }), + } + + const mw2: ChatMiddleware = { + name: 'second', + onConfig: (_ctx, config) => ({ + // Can see what first middleware set + maxTokens: (config.maxTokens ?? 
0) + 50, + }), + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [mw1, mw2], + }) + await collectChunks(stream as AsyncIterable) + + // Adapter should get maxTokens = 150 (100 + 50) + expect(calls[0]!.maxTokens).toBe(150) + }) + }) + + // ========================================================================== + // onChunk + // ========================================================================== + describe('onChunk', () => { + it('should observe all chunks', async () => { + const chunkTypes: Array = [] + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.textStart(), + ev.textContent('Hello'), + ev.textEnd(), + ev.runFinished('stop'), + ], + ], + }) + + const middleware: ChatMiddleware = { + name: 'observer', + onChunk: (_ctx, chunk) => { + chunkTypes.push(chunk.type) + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + expect(chunkTypes).toEqual([ + 'RUN_STARTED', + 'TEXT_MESSAGE_START', + 'TEXT_MESSAGE_CONTENT', + 'TEXT_MESSAGE_END', + 'RUN_FINISHED', + ]) + }) + + it('should allow chunk transformation', async () => { + const { adapter } = createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('secret'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'redact', + onChunk: (_ctx, chunk) => { + if (chunk.type === 'TEXT_MESSAGE_CONTENT') { + return { ...chunk, delta: '***' } + } + return undefined + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + }) + const chunks = await collectChunks(stream as AsyncIterable) + + const deltas = getDeltas(chunks) + expect(deltas).toEqual(['***']) + }) + + it('should allow chunk dropping (null)', async () => { + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + 
ev.textContent('Hello'), + ev.textContent('World'), + ev.runFinished('stop'), + ], + ], + }) + + const middleware: ChatMiddleware = { + name: 'dropper', + onChunk: (_ctx, chunk) => { + if (chunk.type === 'TEXT_MESSAGE_CONTENT' && chunk.delta === 'Hello') { + return null + } + return undefined + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + }) + const chunks = await collectChunks(stream as AsyncIterable) + + expect(getDeltas(chunks)).toEqual(['World']) + }) + + it('should allow chunk expansion (chunk[])', async () => { + const { adapter } = createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('AB'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'expander', + onChunk: (_ctx, chunk) => { + if (chunk.type === 'TEXT_MESSAGE_CONTENT') { + return [ + { ...chunk, delta: 'A' } as StreamChunk, + { ...chunk, delta: 'B' } as StreamChunk, + ] + } + return undefined + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + }) + const chunks = await collectChunks(stream as AsyncIterable) + + expect(getDeltas(chunks)).toEqual(['A', 'B']) + }) + + it('should pipe chunks through multiple middleware in order', async () => { + const { adapter } = createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('hello'), ev.runFinished('stop')], + ], + }) + + const mw1: ChatMiddleware = { + name: 'upper', + onChunk: (_ctx, chunk) => { + if (chunk.type === 'TEXT_MESSAGE_CONTENT') { + return { ...chunk, delta: chunk.delta.toUpperCase() } + } + return undefined + }, + } + + const mw2: ChatMiddleware = { + name: 'exclaim', + onChunk: (_ctx, chunk) => { + if (chunk.type === 'TEXT_MESSAGE_CONTENT') { + return { ...chunk, delta: chunk.delta + '!' 
} + } + return undefined + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [mw1, mw2], + }) + const chunks = await collectChunks(stream as AsyncIterable) + + // First middleware uppercases, second appends ! + expect(getDeltas(chunks)).toEqual(['HELLO!']) + }) + + it('should handle expansion followed by another middleware', async () => { + const { adapter } = createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('AB'), ev.runFinished('stop')], + ], + }) + + const mw1: ChatMiddleware = { + name: 'expander', + onChunk: (_ctx, chunk) => { + if (chunk.type === 'TEXT_MESSAGE_CONTENT') { + return [ + { ...chunk, delta: 'A' } as StreamChunk, + { ...chunk, delta: 'B' } as StreamChunk, + ] + } + return undefined + }, + } + + const mw2: ChatMiddleware = { + name: 'suffix', + onChunk: (_ctx, chunk) => { + if (chunk.type === 'TEXT_MESSAGE_CONTENT') { + return { ...chunk, delta: chunk.delta + '!' } + } + return undefined + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [mw1, mw2], + }) + const chunks = await collectChunks(stream as AsyncIterable) + + // Both expanded chunks go through suffix middleware + expect(getDeltas(chunks)).toEqual(['A!', 'B!']) + }) + }) + + // ========================================================================== + // onBeforeToolCall / onAfterToolCall + // ========================================================================== + describe('tool call hooks', () => { + it('should call onBeforeToolCall and onAfterToolCall', async () => { + const onBeforeToolCall = vi.fn() + const onAfterToolCall = vi.fn() + + const tool = serverTool('myTool', () => ({ result: 'done' })) + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.toolStart('tc-1', 'myTool'), + ev.toolArgs('tc-1', '{}'), + ev.toolEnd('tc-1', 'myTool', { input: {} }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), 
ev.textContent('ok'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'test', + onBeforeToolCall, + onAfterToolCall, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [tool], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + expect(onBeforeToolCall).toHaveBeenCalledOnce() + expect(onBeforeToolCall.mock.calls[0]![1].toolName).toBe('myTool') + expect(onAfterToolCall).toHaveBeenCalledOnce() + expect(onAfterToolCall.mock.calls[0]![1].ok).toBe(true) + }) + + it('should support transformArgs decision', async () => { + const executeFn = vi.fn(() => ({ result: 'done' })) + const tool = serverTool('myTool', executeFn) + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.toolStart('tc-1', 'myTool'), + ev.toolArgs('tc-1', '{"x":1}'), + ev.toolEnd('tc-1', 'myTool', { input: { x: 1 } }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), ev.textContent('ok'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'test', + onBeforeToolCall: () => ({ + type: 'transformArgs' as const, + args: { x: 42 }, + }), + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [tool], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + // Tool should be called with transformed args + expect(executeFn).toHaveBeenCalledWith({ x: 42 }, expect.anything()) + }) + + it('should support skip decision to return a cached result', async () => { + const executeFn = vi.fn(() => ({ result: 'should not be called' })) + const onAfterToolCall = vi.fn() + const tool = serverTool('myTool', executeFn) + + const cachedResult = { weather: 'sunny', temp: 72 } + + const { adapter, calls } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.toolStart('tc-1', 'myTool'), + ev.toolArgs('tc-1', '{"city":"NYC"}'), + ev.toolEnd('tc-1', 'myTool', { input: 
{ city: 'NYC' } }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), ev.textContent('ok'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'cache', + onBeforeToolCall: () => ({ + type: 'skip' as const, + result: cachedResult, + }), + onAfterToolCall, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [tool], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + // Tool execute should NOT be called + expect(executeFn).not.toHaveBeenCalled() + + // onAfterToolCall should receive the cached result with ok=true, duration=0 + expect(onAfterToolCall).toHaveBeenCalledOnce() + const afterInfo = onAfterToolCall.mock.calls[0]![1] + expect(afterInfo.ok).toBe(true) + expect(afterInfo.duration).toBe(0) + expect(afterInfo.result).toEqual(cachedResult) + expect(afterInfo.toolName).toBe('myTool') + expect(afterInfo.toolCallId).toBe('tc-1') + + // The cached result should be fed back to the model in the next iteration + // (the adapter is called a second time with tool result messages) + expect(calls.length).toBe(2) + const secondCallMessages = calls[1]!.messages as Array<{ + role: string + content?: unknown + }> + const toolResultMsg = secondCallMessages.find( + (m) => m.role === 'tool', + ) + expect(toolResultMsg).toBeDefined() + }) + + it('should skip multiple tools independently based on cache', async () => { + const executeA = vi.fn(() => ({ a: 'executed' })) + const executeB = vi.fn(() => ({ b: 'executed' })) + const toolA = serverTool('toolA', executeA) + const toolB = serverTool('toolB', executeB) + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.toolStart('tc-1', 'toolA'), + ev.toolArgs('tc-1', '{}'), + ev.toolEnd('tc-1', 'toolA', { input: {} }), + ev.toolStart('tc-2', 'toolB'), + ev.toolArgs('tc-2', '{}'), + ev.toolEnd('tc-2', 'toolB', { input: {} }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), 
ev.textContent('done'), ev.runFinished('stop')], + ], + }) + + // Cache only toolA, let toolB execute normally + const middleware: ChatMiddleware = { + name: 'selective-cache', + onBeforeToolCall: (_ctx, hookCtx) => { + if (hookCtx.toolName === 'toolA') { + return { + type: 'skip' as const, + result: { a: 'cached' }, + } + } + return undefined + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [toolA, toolB], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + // toolA skipped, toolB executed + expect(executeA).not.toHaveBeenCalled() + expect(executeB).toHaveBeenCalledOnce() + }) + + it('should support abort decision from onBeforeToolCall', async () => { + const onAbort = vi.fn() + const onFinish = vi.fn() + const executeFn = vi.fn(() => ({ result: 'should not be called' })) + const tool = serverTool('myTool', executeFn) + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.toolStart('tc-1', 'myTool'), + ev.toolArgs('tc-1', '{}'), + ev.toolEnd('tc-1', 'myTool', { input: {} }), + ev.runFinished('tool_calls'), + ], + ], + }) + + const middleware: ChatMiddleware = { + name: 'test', + onBeforeToolCall: () => ({ + type: 'abort' as const, + reason: 'policy violation', + }), + onAbort, + onFinish, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [tool], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + // Tool should NOT be called + expect(executeFn).not.toHaveBeenCalled() + // onAbort should be called, not onFinish + expect(onAbort).toHaveBeenCalledOnce() + expect(onAbort.mock.calls[0]![1].reason).toBe('policy violation') + expect(onFinish).not.toHaveBeenCalled() + }) + }) + + // ========================================================================== + // onUsage + // ========================================================================== + describe('onUsage', () => { + 
it('should call onUsage once with usage from RUN_FINISHED', async () => { + const onUsage = vi.fn() + const usage = { promptTokens: 10, completionTokens: 5, totalTokens: 15 } + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.textContent('hi'), + ev.runFinished('stop', 'run-1', usage), + ], + ], + }) + + const middleware: ChatMiddleware = { + name: 'test', + onUsage, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + // Called once when the RUN_FINISHED chunk has usage + expect(onUsage).toHaveBeenCalledTimes(1) + expect(onUsage.mock.calls[0]![1]).toEqual(usage) + }) + }) + + // ========================================================================== + // Context + // ========================================================================== + describe('context', () => { + it('should pass user context to middleware hooks', async () => { + const receivedCtx: Array = [] + + const { adapter } = createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('hi'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'test', + onStart: (ctx) => { + receivedCtx.push(ctx.context) + }, + onChunk: (ctx) => { + receivedCtx.push(ctx.context) + }, + } + + const userContext = { userId: 'u-123', sessionId: 's-abc' } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + context: userContext, + }) + await collectChunks(stream as AsyncIterable) + + // All hooks should receive the same context object + for (const ctx of receivedCtx) { + expect(ctx).toBe(userContext) + } + }) + + it('should provide requestId and streamId in context', async () => { + let capturedCtx: ChatMiddlewareContext | undefined + + const { adapter } = createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('hi'), ev.runFinished('stop')], + ], + 
}) + + const middleware: ChatMiddleware = { + name: 'test', + onStart: (ctx) => { + capturedCtx = ctx + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + expect(capturedCtx).toBeDefined() + expect(capturedCtx!.requestId).toBeTruthy() + expect(capturedCtx!.streamId).toBeTruthy() + }) + }) + + // ========================================================================== + // Backward compatibility + // ========================================================================== + describe('backward compatibility', () => { + it('should work exactly the same without middleware', async () => { + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.textStart(), + ev.textContent('Hello'), + ev.textEnd(), + ev.runFinished('stop'), + ], + ], + }) + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + }) + + const chunks = await collectChunks(stream as AsyncIterable) + + expect(chunks.length).toBe(5) + expect(chunks.map((c) => c.type)).toEqual([ + 'RUN_STARTED', + 'TEXT_MESSAGE_START', + 'TEXT_MESSAGE_CONTENT', + 'TEXT_MESSAGE_END', + 'RUN_FINISHED', + ]) + }) + + it('should work with empty middleware array', async () => { + const { adapter } = createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('Hello'), ev.runFinished('stop')], + ], + }) + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [], + }) + + const chunks = await collectChunks(stream as AsyncIterable) + + expect(chunks.length).toBe(3) + }) + }) + + // ========================================================================== + // Defer + // ========================================================================== + describe('defer', () => { + it('should await deferred promises after terminal hook', async () => { + const order: Array = [] + + const { adapter } = 
createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('hi'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'test', + onStart: (ctx) => { + ctx.defer( + new Promise((resolve) => { + setTimeout(() => { + order.push('deferred') + resolve() + }, 10) + }), + ) + }, + onFinish: () => { + order.push('finish') + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + // Both should have completed + expect(order).toContain('finish') + expect(order).toContain('deferred') + }) + }) + + // ========================================================================== + // Async middleware + // ========================================================================== + describe('async middleware', () => { + it('should handle async onChunk without reordering chunks', async () => { + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.textContent('A'), + ev.textContent('B'), + ev.textContent('C'), + ev.runFinished('stop'), + ], + ], + }) + + const middleware: ChatMiddleware = { + name: 'slow', + onChunk: async (_ctx, chunk) => { + // Simulate async delay + await new Promise((r) => setTimeout(r, 5)) + if (chunk.type === 'TEXT_MESSAGE_CONTENT') { + return { + ...chunk, + delta: chunk.delta.toLowerCase(), + } + } + return undefined + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + }) + const chunks = await collectChunks(stream as AsyncIterable) + + // Order should be preserved despite async middleware + expect(getDeltas(chunks)).toEqual(['a', 'b', 'c']) + }) + }) + + // ========================================================================== + // Per-iteration onConfig + // ========================================================================== + describe('per-iteration onConfig', () => { + it('should call 
onConfig before each model iteration', async () => { + const phases: Array = [] + const iterations: Array = [] + + const tool = serverTool('myTool', () => ({ result: 'done' })) + + const { adapter } = createMockAdapter({ + iterations: [ + // First iteration: model calls a tool + [ + ev.runStarted(), + ev.toolStart('tc-1', 'myTool'), + ev.toolArgs('tc-1', '{}'), + ev.toolEnd('tc-1', 'myTool', { input: {} }), + ev.runFinished('tool_calls'), + ], + // Second iteration: model responds with text + [ev.runStarted(), ev.textContent('Done'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'test', + onConfig: (ctx) => { + phases.push(ctx.phase) + iterations.push(ctx.iteration) + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [tool], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + // init + beforeModel (iter 0) + beforeModel (iter 1) + expect(phases).toEqual(['init', 'beforeModel', 'beforeModel']) + expect(iterations).toEqual([0, 0, 1]) + }) + }) + + // ========================================================================== + // ctx.abort() from middleware hooks + // ========================================================================== + describe('ctx.abort()', () => { + it('should abort the run when ctx.abort() is called from onChunk', async () => { + const onAbort = vi.fn() + const onFinish = vi.fn() + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.textContent('A'), + ev.textContent('B'), + ev.textContent('C'), + ev.runFinished('stop'), + ], + ], + }) + + const middleware: ChatMiddleware = { + name: 'aborter', + onChunk: (ctx, chunk) => { + if (chunk.type === 'TEXT_MESSAGE_CONTENT' && chunk.delta === 'B') { + ctx.abort('seen enough') + } + return undefined + }, + onAbort, + onFinish, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + }) + 
await collectChunks(stream as AsyncIterable) + + expect(onAbort).toHaveBeenCalledOnce() + expect(onAbort.mock.calls[0]![1].reason).toBe('seen enough') + expect(onFinish).not.toHaveBeenCalled() + }) + }) + + // ========================================================================== + // onAfterToolCall error path + // ========================================================================== + describe('onAfterToolCall error handling', () => { + it('should report tool execution errors in onAfterToolCall', async () => { + const afterCalls: Array<{ ok: boolean; error?: unknown }> = [] + + const tool = serverTool('failTool', () => { + throw new Error('tool exploded') + }) + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.toolStart('tc-1', 'failTool'), + ev.toolArgs('tc-1', '{}'), + ev.toolEnd('tc-1', 'failTool', { input: {} }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), ev.textContent('recovered'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'error-observer', + onAfterToolCall: (_ctx, info) => { + afterCalls.push({ ok: info.ok, error: info.error }) + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [tool], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + expect(afterCalls).toHaveLength(1) + expect(afterCalls[0]!.ok).toBe(false) + expect(afterCalls[0]!.error).toBeInstanceOf(Error) + }) + }) + + // ========================================================================== + // Multiple middleware onBeforeToolCall composition + // ========================================================================== + describe('onBeforeToolCall composition', () => { + it('should use first non-void decision from multiple middleware', async () => { + const executeFn = vi.fn(() => ({ result: 'done' })) + const tool = serverTool('myTool', executeFn) + + const { adapter } = createMockAdapter({ + iterations: [ + 
[ + ev.runStarted(), + ev.toolStart('tc-1', 'myTool'), + ev.toolArgs('tc-1', '{"x":1}'), + ev.toolEnd('tc-1', 'myTool', { input: { x: 1 } }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), ev.textContent('ok'), ev.runFinished('stop')], + ], + }) + + // First middleware returns void (passes through) + const mw1: ChatMiddleware = { + name: 'pass-through', + onBeforeToolCall: vi.fn(), + } + + // Second middleware transforms args + const mw2: ChatMiddleware = { + name: 'transformer', + onBeforeToolCall: () => ({ + type: 'transformArgs' as const, + args: { x: 99 }, + }), + } + + // Third middleware should NOT be reached since mw2 returned a decision + const mw3OnBefore = vi.fn() + const mw3: ChatMiddleware = { + name: 'should-not-run', + onBeforeToolCall: mw3OnBefore, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [tool], + middleware: [mw1, mw2, mw3], + }) + await collectChunks(stream as AsyncIterable) + + // Tool should be called with mw2's transformed args + expect(executeFn).toHaveBeenCalledWith({ x: 99 }, expect.anything()) + // mw3's onBeforeToolCall should NOT have been called + expect(mw3OnBefore).not.toHaveBeenCalled() + }) + }) + + // ========================================================================== + // onConfig tools transform + // ========================================================================== + describe('onConfig tools transform', () => { + it('should allow middleware to add tools via onConfig', async () => { + const addedToolExecute = vi.fn(() => ({ added: true })) + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.toolStart('tc-1', 'addedTool'), + ev.toolArgs('tc-1', '{}'), + ev.toolEnd('tc-1', 'addedTool', { input: {} }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), ev.textContent('ok'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'tool-injector', + onConfig: (_ctx, config) => ({ + 
tools: [ + ...config.tools, + { + name: 'addedTool', + description: 'Added by middleware', + execute: addedToolExecute, + }, + ], + }), + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + expect(addedToolExecute).toHaveBeenCalledOnce() + }) + }) + + // ========================================================================== + // Async onStart + // ========================================================================== + describe('async onStart', () => { + it('should await async onStart before streaming begins', async () => { + const order: Array = [] + + const { adapter } = createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('hi'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'async-start', + onStart: async () => { + await new Promise((r) => setTimeout(r, 20)) + order.push('onStart-done') + }, + onChunk: () => { + order.push('chunk') + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + // onStart should complete before any chunks are processed + expect(order[0]).toBe('onStart-done') + expect(order.filter((o) => o === 'chunk').length).toBeGreaterThan(0) + }) + }) + + // ========================================================================== + // chunkIndex tracking + // ========================================================================== + describe('chunkIndex tracking', () => { + it('should increment chunkIndex for each yielded chunk', async () => { + const indices: Array = [] + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.textContent('A'), + ev.textContent('B'), + ev.runFinished('stop'), + ], + ], + }) + + const middleware: ChatMiddleware = { + name: 'index-tracker', + onChunk: (ctx) => { + 
indices.push(ctx.chunkIndex) + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + // Should be sequential indices starting from 0 + expect(indices).toEqual([0, 1, 2, 3]) + }) + }) + + // ========================================================================== + // conversationId propagation + // ========================================================================== + describe('conversationId', () => { + it('should propagate conversationId to middleware context', async () => { + let capturedConvId: string | undefined + + const { adapter } = createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('hi'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'test', + onStart: (ctx) => { + capturedConvId = ctx.conversationId + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + conversationId: 'conv-42', + }) + await collectChunks(stream as AsyncIterable) + + expect(capturedConvId).toBe('conv-42') + }) + }) + + // ========================================================================== + // End-to-end hook ordering with tool loop + // ========================================================================== + describe('full hook ordering', () => { + it('should fire all hooks in correct order during a tool-call loop', async () => { + const events: Array = [] + + const tool = serverTool('myTool', () => ({ done: true })) + + const usage = { + promptTokens: 10, + completionTokens: 5, + totalTokens: 15, + } + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.toolStart('tc-1', 'myTool'), + ev.toolArgs('tc-1', '{}'), + ev.toolEnd('tc-1', 'myTool', { input: {} }), + ev.runFinished('tool_calls', 'run-1', usage), + ], + [ + ev.runStarted(), + ev.textContent('Done'), + ev.runFinished('stop', 'run-2', 
usage), + ], + ], + }) + + const middleware: ChatMiddleware = { + name: 'order-tracker', + onConfig: (ctx) => { + events.push(`onConfig:${ctx.phase}`) + }, + onStart: () => { + events.push('onStart') + }, + onChunk: (_ctx, chunk) => { + events.push(`onChunk:${chunk.type}`) + }, + onBeforeToolCall: (_ctx, hookCtx) => { + events.push(`onBeforeToolCall:${hookCtx.toolName}`) + }, + onAfterToolCall: (_ctx, info) => { + events.push(`onAfterToolCall:${info.toolName}:${info.ok}`) + }, + onUsage: () => { + events.push('onUsage') + }, + onFinish: () => { + events.push('onFinish') + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [tool], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + // Verify the expected ordering + expect(events).toEqual([ + // Init phase + 'onConfig:init', + 'onStart', + // First model call (beforeModel phase) + 'onConfig:beforeModel', + 'onChunk:RUN_STARTED', + 'onChunk:TOOL_CALL_START', + 'onChunk:TOOL_CALL_ARGS', + 'onChunk:TOOL_CALL_END', + 'onChunk:RUN_FINISHED', + 'onUsage', + // Tool execution phase + 'onBeforeToolCall:myTool', + 'onAfterToolCall:myTool:true', + // Second model call (beforeModel phase) + 'onConfig:beforeModel', + 'onChunk:RUN_STARTED', + 'onChunk:TEXT_MESSAGE_CONTENT', + 'onChunk:RUN_FINISHED', + 'onUsage', + // Terminal + 'onFinish', + ]) + }) + }) + + // ========================================================================== + // Phase transitions in context + // ========================================================================== + describe('phase transitions', () => { + it('should set correct phase during each hook', async () => { + const phaseLog: Array<{ hook: string; phase: string }> = [] + + const tool = serverTool('myTool', () => ({ ok: true })) + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.toolStart('tc-1', 'myTool'), + ev.toolArgs('tc-1', '{}'), + ev.toolEnd('tc-1', 'myTool', { input: 
{} }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), ev.textContent('ok'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'phase-tracker', + onConfig: (ctx) => { + phaseLog.push({ hook: 'onConfig', phase: ctx.phase }) + }, + onStart: (ctx) => { + phaseLog.push({ hook: 'onStart', phase: ctx.phase }) + }, + onChunk: (ctx) => { + phaseLog.push({ hook: 'onChunk', phase: ctx.phase }) + }, + onBeforeToolCall: (ctx) => { + phaseLog.push({ hook: 'onBeforeToolCall', phase: ctx.phase }) + }, + onAfterToolCall: (ctx) => { + phaseLog.push({ hook: 'onAfterToolCall', phase: ctx.phase }) + }, + onFinish: (ctx) => { + phaseLog.push({ hook: 'onFinish', phase: ctx.phase }) + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [tool], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + // Verify phases for key hooks + const configPhases = phaseLog + .filter((e) => e.hook === 'onConfig') + .map((e) => e.phase) + expect(configPhases).toEqual(['init', 'beforeModel', 'beforeModel']) + + // onChunk should be in 'modelStream' phase + const chunkPhases = phaseLog + .filter((e) => e.hook === 'onChunk') + .map((e) => e.phase) + expect(chunkPhases.every((p) => p === 'modelStream')).toBe(true) + + // onBeforeToolCall should be in 'beforeTools' phase + const beforeToolPhases = phaseLog + .filter((e) => e.hook === 'onBeforeToolCall') + .map((e) => e.phase) + expect(beforeToolPhases).toEqual(['beforeTools']) + + // onAfterToolCall should be in 'beforeTools' phase (still in tool execution) + const afterToolPhases = phaseLog + .filter((e) => e.hook === 'onAfterToolCall') + .map((e) => e.phase) + expect(afterToolPhases).toEqual(['beforeTools']) + }) + }) + + // ========================================================================== + // onConfig transforms messages + // ========================================================================== + describe('onConfig message 
transform', () => { + it('should allow middleware to prepend messages via onConfig', async () => { + const { adapter, calls } = createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('hi'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'message-injector', + onConfig: (_ctx, config) => ({ + messages: [ + { role: 'user' as const, content: 'Context: you are helpful' }, + ...config.messages, + ], + }), + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + // Adapter should receive the extra message prepended + const adapterMessages = calls[0]!.messages as Array<{ content: string }> + expect(adapterMessages.length).toBeGreaterThanOrEqual(2) + expect(adapterMessages[0]!.content).toBe('Context: you are helpful') + }) + }) + + // ========================================================================== + // onConfig transforms temperature/topP/maxTokens + // ========================================================================== + describe('onConfig parameter transforms', () => { + it('should allow middleware to transform temperature, topP, and maxTokens', async () => { + const { adapter, calls } = createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('hi'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'param-override', + onConfig: () => ({ + temperature: 0.9, + topP: 0.8, + maxTokens: 500, + }), + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + temperature: 0.1, + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + expect(calls[0]!.temperature).toBe(0.9) + expect(calls[0]!.topP).toBe(0.8) + expect(calls[0]!.maxTokens).toBe(500) + }) + }) + + // ========================================================================== + // onFinish info fields + // 
========================================================================== + describe('onFinish info', () => { + it('should provide duration and usage in FinishInfo', async () => { + const onFinish = vi.fn() + const usage = { promptTokens: 20, completionTokens: 10, totalTokens: 30 } + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.textContent('Hello world'), + ev.runFinished('stop', 'run-1', usage), + ], + ], + }) + + const middleware: ChatMiddleware = { + name: 'test', + onFinish, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + expect(onFinish).toHaveBeenCalledOnce() + const info = onFinish.mock.calls[0]![1] + expect(info.finishReason).toBe('stop') + expect(info.content).toBe('Hello world') + expect(info.duration).toBeGreaterThanOrEqual(0) + expect(info.usage).toEqual(usage) + }) + + it('should accumulate content across multiple text chunks', async () => { + const onFinish = vi.fn() + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.textContent('Hello'), + ev.textContent(' '), + ev.textContent('world'), + ev.runFinished('stop'), + ], + ], + }) + + const middleware: ChatMiddleware = { + name: 'test', + onFinish, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + expect(onFinish.mock.calls[0]![1].content).toBe('Hello world') + }) + }) + + // ========================================================================== + // onError info fields + // ========================================================================== + describe('onError info', () => { + it('should provide the error object and duration in ErrorInfo', async () => { + const onError = vi.fn() + + const { adapter } = createMockAdapter({ + // eslint-disable-next-line require-yield + 
chatStreamFn: async function* () { + throw new Error('network timeout') + }, + }) + + const middleware: ChatMiddleware = { + name: 'test', + onError, + } + + try { + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + } catch { + // Expected + } + + expect(onError).toHaveBeenCalledOnce() + const info = onError.mock.calls[0]![1] + expect(info.error).toBeInstanceOf(Error) + expect((info.error as Error).message).toBe('network timeout') + expect(info.duration).toBeGreaterThanOrEqual(0) + }) + }) + + // ========================================================================== + // Multiple tools with middleware + // ========================================================================== + describe('multiple tools', () => { + it('should call onBeforeToolCall and onAfterToolCall for each tool', async () => { + const beforeNames: Array = [] + const afterNames: Array = [] + + const toolA = serverTool('toolA', () => ({ a: true })) + const toolB = serverTool('toolB', () => ({ b: true })) + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.toolStart('tc-1', 'toolA'), + ev.toolArgs('tc-1', '{}'), + ev.toolEnd('tc-1', 'toolA', { input: {} }), + ev.toolStart('tc-2', 'toolB'), + ev.toolArgs('tc-2', '{}'), + ev.toolEnd('tc-2', 'toolB', { input: {} }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), ev.textContent('ok'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'multi-tool-observer', + onBeforeToolCall: (_ctx, hookCtx) => { + beforeNames.push(hookCtx.toolName) + }, + onAfterToolCall: (_ctx, info) => { + afterNames.push(info.toolName) + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [toolA, toolB], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + expect(beforeNames).toEqual(['toolA', 'toolB']) + 
expect(afterNames).toEqual(['toolA', 'toolB']) + }) + }) + + // ========================================================================== + // Deferred promise rejection doesn't crash + // ========================================================================== + describe('deferred rejection', () => { + it('should not crash when a deferred promise rejects', async () => { + const onFinish = vi.fn() + + const { adapter } = createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('hi'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'test', + onStart: (ctx) => { + ctx.defer(Promise.reject(new Error('fire and forget failure'))) + }, + onFinish, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + }) + + // Should not throw + await collectChunks(stream as AsyncIterable) + + expect(onFinish).toHaveBeenCalledOnce() + }) + }) + + // ========================================================================== + // onAfterToolCall result and duration + // ========================================================================== + describe('onAfterToolCall details', () => { + it('should provide result and positive duration on success', async () => { + const afterCalls: Array<{ + result: unknown + duration: number + toolCallId: string + }> = [] + + const tool = serverTool('slowTool', async () => { + await new Promise((r) => setTimeout(r, 15)) + return { data: 42 } + }) + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.toolStart('tc-1', 'slowTool'), + ev.toolArgs('tc-1', '{}'), + ev.toolEnd('tc-1', 'slowTool', { input: {} }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), ev.textContent('ok'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'test', + onAfterToolCall: (_ctx, info) => { + afterCalls.push({ + result: info.result, + duration: info.duration, + toolCallId: 
info.toolCallId, + }) + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [tool], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + expect(afterCalls).toHaveLength(1) + expect(afterCalls[0]!.result).toEqual({ data: 42 }) + expect(afterCalls[0]!.duration).toBeGreaterThanOrEqual(10) + expect(afterCalls[0]!.toolCallId).toBe('tc-1') + }) + }) + + // ========================================================================== + // Multiple middleware hook ordering + // ========================================================================== + describe('multiple middleware ordering', () => { + it('should call hooks on all middlewares in array order', async () => { + const order: Array = [] + + const { adapter } = createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('hi'), ev.runFinished('stop')], + ], + }) + + const mw1: ChatMiddleware = { + name: 'first', + onStart: () => { + order.push('mw1:onStart') + }, + onChunk: () => { + order.push('mw1:onChunk') + }, + onFinish: () => { + order.push('mw1:onFinish') + }, + } + + const mw2: ChatMiddleware = { + name: 'second', + onStart: () => { + order.push('mw2:onStart') + }, + onChunk: () => { + order.push('mw2:onChunk') + }, + onFinish: () => { + order.push('mw2:onFinish') + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [mw1, mw2], + }) + await collectChunks(stream as AsyncIterable) + + // onStart: mw1 then mw2 + const startEvents = order.filter((e) => e.includes('onStart')) + expect(startEvents).toEqual(['mw1:onStart', 'mw2:onStart']) + + // onChunk: for each chunk, mw1 then mw2 + // With 3 chunks (RUN_STARTED, TEXT_MESSAGE_CONTENT, RUN_FINISHED), + // each should go through mw1 then mw2 + const chunkEvents = order.filter((e) => e.includes('onChunk')) + for (let i = 0; i < chunkEvents.length; i += 2) { + expect(chunkEvents[i]).toBe('mw1:onChunk') + expect(chunkEvents[i 
+ 1]).toBe('mw2:onChunk') + } + + // onFinish: mw1 then mw2 + const finishEvents = order.filter((e) => e.includes('onFinish')) + expect(finishEvents).toEqual(['mw1:onFinish', 'mw2:onFinish']) + }) + }) + + // ========================================================================== + // ctx.abort() from onStart + // ========================================================================== + describe('abort from onStart', () => { + it('should abort before streaming when ctx.abort() is called in onStart', async () => { + const onAbort = vi.fn() + const onChunk = vi.fn() + + const { adapter } = createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('hi'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'early-aborter', + onStart: (ctx) => { + ctx.abort('not allowed') + }, + onChunk, + onAbort, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + expect(onAbort).toHaveBeenCalledOnce() + expect(onAbort.mock.calls[0]![1].reason).toBe('not allowed') + // No chunks should have been processed since we aborted before streaming + expect(onChunk).not.toHaveBeenCalled() + }) + }) + + // ========================================================================== + // onConfig returning void passes through + // ========================================================================== + describe('onConfig passthrough', () => { + it('should pass through config unchanged when onConfig returns void', async () => { + const { adapter, calls } = createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('hi'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'observer-only', + onConfig: () => { + // observe but don't return anything + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + systemPrompts: ['Be helpful'], + temperature: 
0.5, + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + // Original config should reach the adapter untouched + expect(calls[0]!.systemPrompts).toEqual(['Be helpful']) + expect(calls[0]!.temperature).toBe(0.5) + }) + }) + + // ========================================================================== + // onConfig metadata and modelOptions transform + // ========================================================================== + describe('onConfig metadata transform', () => { + it('should allow middleware to set metadata and modelOptions', async () => { + const { adapter, calls } = createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('hi'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'meta-injector', + onConfig: (_ctx, config) => ({ + metadata: { ...config.metadata, injected: true }, + }), + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + metadata: { original: true }, + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + expect(calls[0]!.metadata).toEqual({ original: true, injected: true }) + }) + }) + + // ========================================================================== + // onUsage not called when no usage data + // ========================================================================== + describe('onUsage without usage data', () => { + it('should not call onUsage when RUN_FINISHED has no usage', async () => { + const onUsage = vi.fn() + + const { adapter } = createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('hi'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'test', + onUsage, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + expect(onUsage).not.toHaveBeenCalled() + }) + }) + + + // 
========================================================================== + // All hooks from a single middleware object + // ========================================================================== + describe('single middleware with all hooks', () => { + it('should invoke every hook on one middleware in a full tool-call flow', async () => { + const hooksCalled = new Set() + + const tool = serverTool('myTool', () => ({ ok: true })) + const usage = { promptTokens: 5, completionTokens: 3, totalTokens: 8 } + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.toolStart('tc-1', 'myTool'), + ev.toolArgs('tc-1', '{}'), + ev.toolEnd('tc-1', 'myTool', { input: {} }), + ev.runFinished('tool_calls', 'run-1', usage), + ], + [ + ev.runStarted(), + ev.textContent('Done'), + ev.runFinished('stop', 'run-2', usage), + ], + ], + }) + + const middleware: ChatMiddleware = { + name: 'all-hooks', + onConfig: () => { + hooksCalled.add('onConfig') + }, + onStart: () => { + hooksCalled.add('onStart') + }, + onChunk: () => { + hooksCalled.add('onChunk') + }, + onBeforeToolCall: () => { + hooksCalled.add('onBeforeToolCall') + }, + onAfterToolCall: () => { + hooksCalled.add('onAfterToolCall') + }, + onUsage: () => { + hooksCalled.add('onUsage') + }, + onFinish: () => { + hooksCalled.add('onFinish') + }, + onAbort: () => { + hooksCalled.add('onAbort') + }, + onError: () => { + hooksCalled.add('onError') + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [tool], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + // All hooks except onAbort and onError should have been called + expect(hooksCalled).toContain('onConfig') + expect(hooksCalled).toContain('onStart') + expect(hooksCalled).toContain('onChunk') + expect(hooksCalled).toContain('onBeforeToolCall') + expect(hooksCalled).toContain('onAfterToolCall') + expect(hooksCalled).toContain('onUsage') + 
expect(hooksCalled).toContain('onFinish') + // These are exclusive terminal hooks — only onFinish should fire + expect(hooksCalled).not.toContain('onAbort') + expect(hooksCalled).not.toContain('onError') + }) + }) + + // ========================================================================== + // onConfig removing tools + // ========================================================================== + describe('onConfig tool removal', () => { + it('should allow middleware to remove tools via onConfig', async () => { + const executeFn = vi.fn(() => ({ result: 'done' })) + const toolToRemove = serverTool('blocked', executeFn) + const toolToKeep = serverTool('allowed', () => ({ ok: true })) + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.toolStart('tc-1', 'blocked'), + ev.toolArgs('tc-1', '{}'), + ev.toolEnd('tc-1', 'blocked', { input: {} }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), ev.textContent('ok'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'tool-filter', + onConfig: (_ctx, config) => ({ + tools: config.tools.filter((t) => t.name !== 'blocked'), + }), + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [toolToRemove, toolToKeep], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + // 'blocked' tool was removed by middleware, so it should NOT execute + expect(executeFn).not.toHaveBeenCalled() + }) + }) + + // ========================================================================== + // Multiple deferred promises + // ========================================================================== + describe('multiple deferred promises', () => { + it('should await all deferred promises from multiple hooks', async () => { + const completed: Array = [] + + const { adapter } = createMockAdapter({ + iterations: [ + [ev.runStarted(), ev.textContent('hi'), ev.runFinished('stop')], + ], + }) + + const 
middleware: ChatMiddleware = { + name: 'multi-defer', + onStart: (ctx) => { + ctx.defer( + new Promise((resolve) => { + setTimeout(() => { + completed.push('defer-1') + resolve() + }, 5) + }), + ) + ctx.defer( + new Promise((resolve) => { + setTimeout(() => { + completed.push('defer-2') + resolve() + }, 10) + }), + ) + }, + onChunk: (ctx) => { + ctx.defer( + new Promise((resolve) => { + setTimeout(() => { + completed.push('defer-3') + resolve() + }, 5) + }), + ) + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + // All deferred promises should have settled + expect(completed).toContain('defer-1') + expect(completed).toContain('defer-2') + expect(completed).toContain('defer-3') + }) + }) + + // ========================================================================== + // Per-iteration config transforms (multi-step agent loop) + // ========================================================================== + describe('per-iteration config transforms', () => { + it('should apply different config transforms at init, iteration 0, and iteration 1', async () => { + const tool = serverTool('myTool', () => ({ ok: true })) + + const { adapter, calls } = createMockAdapter({ + iterations: [ + // Iteration 0: model calls a tool + [ + ev.runStarted(), + ev.toolStart('tc-1', 'myTool'), + ev.toolArgs('tc-1', '{}'), + ev.toolEnd('tc-1', 'myTool', { input: {} }), + ev.runFinished('tool_calls'), + ], + // Iteration 1: model responds with text + [ev.runStarted(), ev.textContent('Done'), ev.runFinished('stop')], + ], + }) + + const middleware: ChatMiddleware = { + name: 'per-step-config', + onConfig: (ctx) => { + if (ctx.phase === 'init') { + return { + systemPrompts: ['init-prompt'], + temperature: 0.1, + } + } + if (ctx.phase === 'beforeModel' && ctx.iteration === 0) { + return { + systemPrompts: ['iter-0-prompt'], + temperature: 0.5, + maxTokens: 100, + } + 
} + if (ctx.phase === 'beforeModel' && ctx.iteration === 1) { + return { + systemPrompts: ['iter-1-prompt'], + temperature: 0.9, + maxTokens: 200, + } + } + return undefined + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [tool], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + // Iteration 0: adapter receives iter-0 config (overrides init) + expect(calls[0]!.systemPrompts).toEqual(['iter-0-prompt']) + expect(calls[0]!.temperature).toBe(0.5) + expect(calls[0]!.maxTokens).toBe(100) + + // Iteration 1: adapter receives iter-1 config + expect(calls[1]!.systemPrompts).toEqual(['iter-1-prompt']) + expect(calls[1]!.temperature).toBe(0.9) + expect(calls[1]!.maxTokens).toBe(200) + }) + + it('should accumulate config changes across multiple middleware per iteration', async () => { + const tool = serverTool('myTool', () => ({ ok: true })) + + const { adapter, calls } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.toolStart('tc-1', 'myTool'), + ev.toolArgs('tc-1', '{}'), + ev.toolEnd('tc-1', 'myTool', { input: {} }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), ev.textContent('Done'), ev.runFinished('stop')], + ], + }) + + // First middleware: adds a system prompt each iteration + const mw1: ChatMiddleware = { + name: 'prompt-adder', + onConfig: (ctx, config) => { + if (ctx.phase === 'beforeModel') { + return { + systemPrompts: [ + ...config.systemPrompts, + `added-by-mw1-iter-${ctx.iteration}`, + ], + } + } + return undefined + }, + } + + // Second middleware: doubles maxTokens on iteration 1 + const mw2: ChatMiddleware = { + name: 'token-scaler', + onConfig: (ctx, config) => { + if (ctx.phase === 'beforeModel' && ctx.iteration === 1) { + return { + maxTokens: (config.maxTokens ?? 
100) * 2, + } + } + return undefined + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [tool], + systemPrompts: ['base'], + maxTokens: 100, + middleware: [mw1, mw2], + }) + await collectChunks(stream as AsyncIterable) + + // Iteration 0: mw1 adds prompt, mw2 does nothing + expect(calls[0]!.systemPrompts).toEqual(['base', 'added-by-mw1-iter-0']) + expect(calls[0]!.maxTokens).toBe(100) + + // Iteration 1: mw1 adds prompt, mw2 doubles maxTokens + // Note: mw1's change from iter-0 persists since applyMiddlewareConfig updates the engine + expect(calls[1]!.systemPrompts).toContain('added-by-mw1-iter-1') + expect(calls[1]!.maxTokens).toBe(200) + }) + + it('should let middleware observe config changes from the previous iteration', async () => { + const configSnapshots: Array<{ + phase: string + iteration: number + maxTokens?: number + systemPrompts: Array + }> = [] + + const tool = serverTool('myTool', () => ({ ok: true })) + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.toolStart('tc-1', 'myTool'), + ev.toolArgs('tc-1', '{}'), + ev.toolEnd('tc-1', 'myTool', { input: {} }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), ev.textContent('Done'), ev.runFinished('stop')], + ], + }) + + // Middleware that modifies maxTokens and observes state + const middleware: ChatMiddleware = { + name: 'observer-mutator', + onConfig: (ctx, config) => { + configSnapshots.push({ + phase: ctx.phase, + iteration: ctx.iteration, + maxTokens: config.maxTokens, + systemPrompts: [...config.systemPrompts], + }) + + // On each beforeModel call, bump maxTokens by 50 + if (ctx.phase === 'beforeModel') { + return { + maxTokens: (config.maxTokens ?? 
0) + 50, + } + } + return undefined + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [tool], + maxTokens: 100, + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + // init: sees original maxTokens=100 + expect(configSnapshots[0]).toMatchObject({ + phase: 'init', + iteration: 0, + maxTokens: 100, + }) + + // beforeModel iter 0: sees maxTokens=100 (init didn't modify it) + expect(configSnapshots[1]).toMatchObject({ + phase: 'beforeModel', + iteration: 0, + maxTokens: 100, + }) + + // beforeModel iter 1: sees maxTokens=150 (iter 0 added 50) + expect(configSnapshots[2]).toMatchObject({ + phase: 'beforeModel', + iteration: 1, + maxTokens: 150, + }) + }) + + it('should let middleware change tools between iterations', async () => { + const executedTools: Array = [] + let onConfigCallCount = 0 + + const toolA = serverTool('toolA', () => { + executedTools.push('toolA') + return { a: true } + }) + + const toolB = serverTool('toolB', () => { + executedTools.push('toolB') + return { b: true } + }) + + const { adapter } = createMockAdapter({ + iterations: [ + // Iteration 0: model calls toolA + [ + ev.runStarted(), + ev.toolStart('tc-1', 'toolA'), + ev.toolArgs('tc-1', '{}'), + ev.toolEnd('tc-1', 'toolA', { input: {} }), + ev.runFinished('tool_calls'), + ], + // Iteration 1: model calls toolB + [ + ev.runStarted(), + ev.toolStart('tc-2', 'toolB'), + ev.toolArgs('tc-2', '{}'), + ev.toolEnd('tc-2', 'toolB', { input: {} }), + ev.runFinished('tool_calls'), + ], + // Iteration 2: model responds with text + [ev.runStarted(), ev.textContent('Done'), ev.runFinished('stop')], + ], + }) + + // Middleware swaps the tools on iteration 1 + const middleware: ChatMiddleware = { + name: 'tool-swapper', + onConfig: (ctx) => { + onConfigCallCount++ + if (ctx.phase === 'beforeModel' && ctx.iteration === 0) { + // Only expose toolA + return { tools: [toolA] } + } + if (ctx.phase === 'beforeModel' && ctx.iteration 
>= 1) { + // Expose both tools + return { tools: [toolA, toolB] } + } + return undefined + }, + } + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [toolA, toolB], + middleware: [middleware], + }) + await collectChunks(stream as AsyncIterable) + + // toolA executed on iteration 0, toolB on iteration 1 + expect(executedTools).toContain('toolA') + expect(executedTools).toContain('toolB') + // onConfig called at init + 3 beforeModel iterations + expect(onConfigCallCount).toBe(4) + }) + }) +}) diff --git a/packages/typescript/ai/tests/test-utils.ts b/packages/typescript/ai/tests/test-utils.ts new file mode 100644 index 000000000..66bfa5ba5 --- /dev/null +++ b/packages/typescript/ai/tests/test-utils.ts @@ -0,0 +1,181 @@ +import type { AnyTextAdapter } from '../src/activities/chat/adapter' +import type { + StreamChunk, + TextMessageContentEvent, + Tool, +} from '../src/types' + +// ============================================================================ +// Chunk factory +// ============================================================================ + +/** Create a typed StreamChunk with minimal boilerplate. */ +export function chunk( + type: T, + fields: Omit, 'type' | 'timestamp'>, +): StreamChunk { + return { type, timestamp: Date.now(), ...fields } as unknown as StreamChunk +} + +// ============================================================================ +// Event shorthand builders +// ============================================================================ + +/** Shorthand chunk factories for common AG-UI events. 
 */
export const ev = {
  // --- Run lifecycle ---
  runStarted: (runId = 'run-1') => chunk('RUN_STARTED', { runId }),
  // --- Assistant text-message lifecycle ---
  textStart: (messageId = 'msg-1') =>
    chunk('TEXT_MESSAGE_START', { messageId, role: 'assistant' as const }),
  textContent: (delta: string, messageId = 'msg-1') =>
    chunk('TEXT_MESSAGE_CONTENT', { messageId, delta }),
  textEnd: (messageId = 'msg-1') => chunk('TEXT_MESSAGE_END', { messageId }),
  // --- Tool-call lifecycle ---
  // `index` is only attached to the chunk when explicitly provided, so its
  // absence in fixtures matches a provider that omits the field entirely.
  toolStart: (toolCallId: string, toolName: string, index?: number) =>
    chunk('TOOL_CALL_START', {
      toolCallId,
      toolName,
      ...(index !== undefined ? { index } : {}),
    }),
  toolArgs: (toolCallId: string, delta: string) =>
    chunk('TOOL_CALL_ARGS', { toolCallId, delta }),
  toolEnd: (
    toolCallId: string,
    toolName: string,
    opts?: { input?: unknown; result?: string },
  ) => chunk('TOOL_CALL_END', { toolCallId, toolName, ...opts }),
  // --- Terminal events ---
  // `usage` is only spread into the chunk when provided; tests rely on its
  // absence to verify onUsage is not fired.
  runFinished: (
    finishReason:
      | 'stop'
      | 'length'
      | 'content_filter'
      | 'tool_calls'
      | null = 'stop',
    runId = 'run-1',
    usage?: {
      promptTokens: number
      completionTokens: number
      totalTokens: number
    },
  ) =>
    chunk('RUN_FINISHED', { runId, finishReason, ...(usage ? { usage } : {}) }),
  runError: (message: string, runId = 'run-1') =>
    chunk('RUN_ERROR', { runId, error: { message } }),
  stepFinished: (delta: string, stepId = 'step-1') =>
    chunk('STEP_FINISHED', { stepId, delta }),
}

// ============================================================================
// Mock adapter
// ============================================================================

/**
 * Create a mock adapter that satisfies AnyTextAdapter.
 * `chatStreamFn` receives the options and returns an AsyncIterable of chunks.
 * Multiple invocations can be tracked via the returned `calls` array.
+ */ +// eslint-disable-next-line @typescript-eslint/no-explicit-any -- mock adapter callbacks receive internal SDK types +export function createMockAdapter(options: { + chatStreamFn?: (opts: any) => AsyncIterable + /** Array of chunk sequences: chatStream returns iterations[0] on first call, iterations[1] on second, etc. */ + iterations?: Array> + structuredOutput?: (opts: any) => Promise<{ data: unknown; rawText: string }> +}) { + const calls: Array> = [] + let callIndex = 0 + + const adapter: AnyTextAdapter = { + kind: 'text' as const, + name: 'mock', + model: 'test-model' as const, + '~types': { + providerOptions: {} as Record, + inputModalities: ['text'] as readonly ['text'], + messageMetadataByModality: { + text: undefined as unknown, + image: undefined as unknown, + audio: undefined as unknown, + video: undefined as unknown, + document: undefined as unknown, + }, + }, + chatStream: (opts: any) => { + calls.push(opts) + + if (options.chatStreamFn) { + return options.chatStreamFn(opts) + } + + if (options.iterations) { + const chunks = options.iterations[callIndex] || [] + callIndex++ + return (async function* () { + for (const c of chunks) yield c + })() + } + + return (async function* () {})() + }, + structuredOutput: + options.structuredOutput ?? (async () => ({ data: {}, rawText: '{}' })), + } + + return { adapter, calls } +} + +// ============================================================================ +// Stream collection +// ============================================================================ + +/** Collect all chunks from an async iterable. 
*/ +export async function collectChunks( + stream: AsyncIterable, +): Promise> { + const chunks: Array = [] + for await (const c of stream) { + chunks.push(c) + } + return chunks +} + +// ============================================================================ +// Type guards & extraction helpers +// ============================================================================ + +/** Type guard for TEXT_MESSAGE_CONTENT chunks. */ +export function isTextContent( + c: StreamChunk, +): c is TextMessageContentEvent { + return c.type === 'TEXT_MESSAGE_CONTENT' +} + +/** Extract all text deltas from a chunk array. */ +export function getDeltas(chunks: Array): Array { + return chunks.filter(isTextContent).map((c) => c.delta) +} + +// ============================================================================ +// Tool helpers +// ============================================================================ + +/** Simple server tool for testing. */ +export function serverTool( + name: string, + executeFn: (args: unknown) => unknown, +): Tool { + return { + name, + description: `Test tool: ${name}`, + execute: executeFn, + } +} + +/** Client tool (no execute function). 
*/ +export function clientTool( + name: string, + opts?: { needsApproval?: boolean }, +): Tool { + return { + name, + description: `Client tool: ${name}`, + needsApproval: opts?.needsApproval, + } +} diff --git a/packages/typescript/ai/tests/tool-cache-middleware.test.ts b/packages/typescript/ai/tests/tool-cache-middleware.test.ts new file mode 100644 index 000000000..1e96a7cce --- /dev/null +++ b/packages/typescript/ai/tests/tool-cache-middleware.test.ts @@ -0,0 +1,722 @@ +import { describe, expect, it, vi } from 'vitest' +import { chat } from '../src/activities/chat/index' +import { toolCacheMiddleware } from '../src/activities/chat/middleware/tool-cache-middleware' +import type { ToolCacheEntry, ToolCacheStorage } from '../src/activities/chat/middleware/tool-cache-middleware' +import type { StreamChunk } from '../src/types' +import { + ev, + createMockAdapter, + collectChunks, + serverTool, +} from './test-utils' + +// ============================================================================ +// Tests +// ============================================================================ + +describe('toolCacheMiddleware', () => { + it('should cache tool results and skip execution on cache hit', async () => { + let callCount = 0 + const tool = serverTool('getWeather', () => { + callCount++ + return { temp: 72, condition: 'sunny' } + }) + + const { adapter } = createMockAdapter({ + iterations: [ + // Iteration 0: model calls getWeather("NYC") + [ + ev.runStarted(), + ev.toolStart('tc-1', 'getWeather'), + ev.toolArgs('tc-1', '{"city":"NYC"}'), + ev.toolEnd('tc-1', 'getWeather', { input: { city: 'NYC' } }), + ev.runFinished('tool_calls'), + ], + // Iteration 1: model calls getWeather("NYC") again (same args) + [ + ev.runStarted(), + ev.toolStart('tc-2', 'getWeather'), + ev.toolArgs('tc-2', '{"city":"NYC"}'), + ev.toolEnd('tc-2', 'getWeather', { input: { city: 'NYC' } }), + ev.runFinished('tool_calls'), + ], + // Iteration 2: model responds with text + [ev.runStarted(), 
ev.textContent('Done'), ev.runFinished('stop')], + ], + }) + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [tool], + middleware: [toolCacheMiddleware()], + }) + await collectChunks(stream as AsyncIterable) + + // Tool should only be executed once — second call served from cache + expect(callCount).toBe(1) + }) + + it('should not cache when args differ', async () => { + let callCount = 0 + const tool = serverTool('getWeather', (args) => { + callCount++ + return { city: (args as { city: string }).city, temp: 72 } + }) + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.toolStart('tc-1', 'getWeather'), + ev.toolArgs('tc-1', '{"city":"NYC"}'), + ev.toolEnd('tc-1', 'getWeather', { input: { city: 'NYC' } }), + ev.runFinished('tool_calls'), + ], + // Different args — should NOT hit cache + [ + ev.runStarted(), + ev.toolStart('tc-2', 'getWeather'), + ev.toolArgs('tc-2', '{"city":"LA"}'), + ev.toolEnd('tc-2', 'getWeather', { input: { city: 'LA' } }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), ev.textContent('Done'), ev.runFinished('stop')], + ], + }) + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [tool], + middleware: [toolCacheMiddleware()], + }) + await collectChunks(stream as AsyncIterable) + + // Both calls should execute — different args + expect(callCount).toBe(2) + }) + + it('should respect toolNames filter', async () => { + let weatherCalls = 0 + let stockCalls = 0 + const weatherTool = serverTool('getWeather', () => { + weatherCalls++ + return { temp: 72 } + }) + const stockTool = serverTool('getStock', () => { + stockCalls++ + return { price: 100 } + }) + + const { adapter } = createMockAdapter({ + iterations: [ + // Both tools called + [ + ev.runStarted(), + ev.toolStart('tc-1', 'getWeather'), + ev.toolArgs('tc-1', '{}'), + ev.toolEnd('tc-1', 'getWeather', { input: {} }), + ev.toolStart('tc-2', 'getStock'), + 
ev.toolArgs('tc-2', '{}'), + ev.toolEnd('tc-2', 'getStock', { input: {} }), + ev.runFinished('tool_calls'), + ], + // Both called again + [ + ev.runStarted(), + ev.toolStart('tc-3', 'getWeather'), + ev.toolArgs('tc-3', '{}'), + ev.toolEnd('tc-3', 'getWeather', { input: {} }), + ev.toolStart('tc-4', 'getStock'), + ev.toolArgs('tc-4', '{}'), + ev.toolEnd('tc-4', 'getStock', { input: {} }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), ev.textContent('Done'), ev.runFinished('stop')], + ], + }) + + // Only cache getWeather + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [weatherTool, stockTool], + middleware: [toolCacheMiddleware({ toolNames: ['getWeather'] })], + }) + await collectChunks(stream as AsyncIterable) + + // getWeather: 1 execute + 1 cache hit = 1 call + expect(weatherCalls).toBe(1) + // getStock: not cached, both calls execute + expect(stockCalls).toBe(2) + }) + + it('should respect TTL and not serve expired entries', async () => { + vi.useFakeTimers() + + try { + let callCount = 0 + const tool = serverTool('getData', () => { + callCount++ + return { data: callCount } + }) + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.toolStart('tc-1', 'getData'), + ev.toolArgs('tc-1', '{}'), + ev.toolEnd('tc-1', 'getData', { input: {} }), + ev.runFinished('tool_calls'), + ], + // Second call — same args but after TTL expires + [ + ev.runStarted(), + ev.toolStart('tc-2', 'getData'), + ev.toolArgs('tc-2', '{}'), + ev.toolEnd('tc-2', 'getData', { input: {} }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), ev.textContent('Done'), ev.runFinished('stop')], + ], + }) + + const cacheMiddleware = toolCacheMiddleware({ ttl: 5000 }) + + // Manually simulate the cache flow with controlled time + // First tool call: cache miss, execute, store + const beforeResult1 = await cacheMiddleware.onBeforeToolCall!( + {} as Parameters>[0], + { + toolCall: { + id: 'tc-1', + function: { 
name: 'getData', arguments: '{}' }, + }, + tool: tool, + args: {}, + toolName: 'getData', + toolCallId: 'tc-1', + }, + ) + // No cache entry — should proceed + expect(beforeResult1).toBeUndefined() + + // Simulate successful execution and store result + await cacheMiddleware.onAfterToolCall!( + {} as Parameters>[0], + { + toolCall: { + id: 'tc-1', + function: { name: 'getData', arguments: '{}' }, + }, + tool: tool, + toolName: 'getData', + toolCallId: 'tc-1', + ok: true, + duration: 10, + result: { data: 1 }, + }, + ) + + // Second call immediately — should hit cache + const beforeResult2 = await cacheMiddleware.onBeforeToolCall!( + {} as Parameters>[0], + { + toolCall: { + id: 'tc-2', + function: { name: 'getData', arguments: '{}' }, + }, + tool: tool, + args: {}, + toolName: 'getData', + toolCallId: 'tc-2', + }, + ) + expect(beforeResult2).toEqual({ type: 'skip', result: { data: 1 } }) + + // Advance time past TTL + vi.advanceTimersByTime(6000) + + // Third call after TTL — should miss cache + const beforeResult3 = await cacheMiddleware.onBeforeToolCall!( + {} as Parameters>[0], + { + toolCall: { + id: 'tc-3', + function: { name: 'getData', arguments: '{}' }, + }, + tool: tool, + args: {}, + toolName: 'getData', + toolCallId: 'tc-3', + }, + ) + // Expired — should proceed with execution + expect(beforeResult3).toBeUndefined() + } finally { + vi.useRealTimers() + } + }) + + it('should respect maxSize and evict oldest entries', async () => { + const results: Array = [] + const tool = serverTool('lookup', (args) => { + const key = (args as { key: string }).key + return { value: `result-${key}` } + }) + + const { adapter } = createMockAdapter({ + iterations: [ + // Fill cache: key=a, key=b (maxSize=2) + [ + ev.runStarted(), + ev.toolStart('tc-1', 'lookup'), + ev.toolArgs('tc-1', '{"key":"a"}'), + ev.toolEnd('tc-1', 'lookup', { input: { key: 'a' } }), + ev.toolStart('tc-2', 'lookup'), + ev.toolArgs('tc-2', '{"key":"b"}'), + ev.toolEnd('tc-2', 'lookup', { input: { key: 
'b' } }), + ev.runFinished('tool_calls'), + ], + // Add key=c — should evict key=a + [ + ev.runStarted(), + ev.toolStart('tc-3', 'lookup'), + ev.toolArgs('tc-3', '{"key":"c"}'), + ev.toolEnd('tc-3', 'lookup', { input: { key: 'c' } }), + ev.runFinished('tool_calls'), + ], + // key=b should still be cached, key=a should miss + [ + ev.runStarted(), + ev.toolStart('tc-4', 'lookup'), + ev.toolArgs('tc-4', '{"key":"b"}'), + ev.toolEnd('tc-4', 'lookup', { input: { key: 'b' } }), + ev.toolStart('tc-5', 'lookup'), + ev.toolArgs('tc-5', '{"key":"a"}'), + ev.toolEnd('tc-5', 'lookup', { input: { key: 'a' } }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), ev.textContent('Done'), ev.runFinished('stop')], + ], + }) + + const executeSpy = vi.fn((args: unknown) => { + const key = (args as { key: string }).key + const result = { value: `result-${key}` } + results.push(result) + return result + }) + const spyTool = serverTool('lookup', executeSpy) + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [spyTool], + middleware: [toolCacheMiddleware({ maxSize: 2 })], + }) + await collectChunks(stream as AsyncIterable) + + // Calls: a(exec), b(exec), c(exec, evicts a), b(cache hit), a(exec, was evicted) + expect(executeSpy).toHaveBeenCalledTimes(4) + const executedKeys = executeSpy.mock.calls.map( + (call) => (call[0] as { key: string }).key, + ) + expect(executedKeys).toEqual(['a', 'b', 'c', 'a']) + }) + + it('should support custom keyFn', async () => { + let callCount = 0 + const tool = serverTool('search', () => { + callCount++ + return { results: [] } + }) + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.toolStart('tc-1', 'search'), + ev.toolArgs('tc-1', '{"query":"hello","page":1}'), + ev.toolEnd('tc-1', 'search', { + input: { query: 'hello', page: 1 }, + }), + ev.runFinished('tool_calls'), + ], + // Same query but different page — custom keyFn ignores page + [ + ev.runStarted(), + 
ev.toolStart('tc-2', 'search'), + ev.toolArgs('tc-2', '{"query":"hello","page":2}'), + ev.toolEnd('tc-2', 'search', { + input: { query: 'hello', page: 2 }, + }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), ev.textContent('Done'), ev.runFinished('stop')], + ], + }) + + // Custom keyFn that only uses the query field, ignoring page + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [tool], + middleware: [ + toolCacheMiddleware({ + keyFn: (toolName, args) => { + const query = (args as { query: string }).query + return `${toolName}:${query}` + }, + }), + ], + }) + await collectChunks(stream as AsyncIterable) + + // Second call should hit cache since keyFn ignores page + expect(callCount).toBe(1) + }) + + it('should not cache failed tool executions', async () => { + let callCount = 0 + const tool = serverTool('flaky', () => { + callCount++ + if (callCount === 1) { + throw new Error('temporary failure') + } + return { ok: true } + }) + + const { adapter } = createMockAdapter({ + iterations: [ + // First call — will fail + [ + ev.runStarted(), + ev.toolStart('tc-1', 'flaky'), + ev.toolArgs('tc-1', '{}'), + ev.toolEnd('tc-1', 'flaky', { input: {} }), + ev.runFinished('tool_calls'), + ], + // Second call — same args, should NOT be cached (first failed) + [ + ev.runStarted(), + ev.toolStart('tc-2', 'flaky'), + ev.toolArgs('tc-2', '{}'), + ev.toolEnd('tc-2', 'flaky', { input: {} }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), ev.textContent('Done'), ev.runFinished('stop')], + ], + }) + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [tool], + middleware: [toolCacheMiddleware()], + }) + await collectChunks(stream as AsyncIterable) + + // Both calls should execute — failed results are not cached + expect(callCount).toBe(2) + }) + + it('should return the middleware with name "tool-cache"', () => { + const mw = toolCacheMiddleware() + 
expect(mw.name).toBe('tool-cache-middleware') + }) + + // ========================================================================== + // Custom storage + // ========================================================================== + + describe('custom storage', () => { + function createMapStorage(): ToolCacheStorage & { + store: Map + } { + const store = new Map() + return { + store, + getItem: (key) => store.get(key), + setItem: (key, value) => { + store.set(key, value) + }, + deleteItem: (key) => { + store.delete(key) + }, + } + } + + it('should use custom storage for cache hits', async () => { + let callCount = 0 + const tool = serverTool('getWeather', () => { + callCount++ + return { temp: 72 } + }) + + const storage = createMapStorage() + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.toolStart('tc-1', 'getWeather'), + ev.toolArgs('tc-1', '{"city":"NYC"}'), + ev.toolEnd('tc-1', 'getWeather', { input: { city: 'NYC' } }), + ev.runFinished('tool_calls'), + ], + [ + ev.runStarted(), + ev.toolStart('tc-2', 'getWeather'), + ev.toolArgs('tc-2', '{"city":"NYC"}'), + ev.toolEnd('tc-2', 'getWeather', { input: { city: 'NYC' } }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), ev.textContent('Done'), ev.runFinished('stop')], + ], + }) + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [tool], + middleware: [toolCacheMiddleware({ storage })], + }) + await collectChunks(stream as AsyncIterable) + + expect(callCount).toBe(1) + expect(storage.store.size).toBe(1) + }) + + it('should work with async storage', async () => { + let callCount = 0 + const tool = serverTool('getData', () => { + callCount++ + return { value: callCount } + }) + + const store = new Map() + const asyncStorage: ToolCacheStorage = { + getItem: async (key) => store.get(key), + setItem: async (key, value) => { + store.set(key, value) + }, + deleteItem: async (key) => { + store.delete(key) + }, + } + + const { adapter 
} = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.toolStart('tc-1', 'getData'), + ev.toolArgs('tc-1', '{}'), + ev.toolEnd('tc-1', 'getData', { input: {} }), + ev.runFinished('tool_calls'), + ], + [ + ev.runStarted(), + ev.toolStart('tc-2', 'getData'), + ev.toolArgs('tc-2', '{}'), + ev.toolEnd('tc-2', 'getData', { input: {} }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), ev.textContent('Done'), ev.runFinished('stop')], + ], + }) + + const stream = chat({ + adapter, + messages: [{ role: 'user', content: 'Hi' }], + tools: [tool], + middleware: [toolCacheMiddleware({ storage: asyncStorage })], + }) + await collectChunks(stream as AsyncIterable) + + expect(callCount).toBe(1) + expect(store.size).toBe(1) + }) + + it('should call deleteItem for expired entries', async () => { + vi.useFakeTimers() + + try { + const tool = serverTool('getData', () => ({ data: 1 })) + + const storage = createMapStorage() + const deleteSpy = vi.fn(storage.deleteItem.bind(storage)) + storage.deleteItem = deleteSpy + + const cacheMiddleware = toolCacheMiddleware({ + ttl: 5000, + storage, + }) + + // Store an entry via onAfterToolCall + await cacheMiddleware.onAfterToolCall!( + {} as Parameters< + NonNullable + >[0], + { + toolCall: { + id: 'tc-1', + function: { name: 'getData', arguments: '{}' }, + }, + tool: tool, + toolName: 'getData', + toolCallId: 'tc-1', + ok: true, + duration: 10, + result: { data: 1 }, + }, + ) + + // Advance time past TTL + vi.advanceTimersByTime(6000) + + // Try to get — should be expired + const result = await cacheMiddleware.onBeforeToolCall!( + {} as Parameters< + NonNullable + >[0], + { + toolCall: { + id: 'tc-2', + function: { name: 'getData', arguments: '{}' }, + }, + tool: tool, + args: {}, + toolName: 'getData', + toolCallId: 'tc-2', + }, + ) + + expect(result).toBeUndefined() + // deleteItem called for expired entry + expect(deleteSpy).toHaveBeenCalled() + } finally { + vi.useRealTimers() + } + }) + + it('should call storage 
methods with correct keys', async () => { + const tool = serverTool('search', () => ({ results: [] })) + + const storage = createMapStorage() + const getSpy = vi.fn(storage.getItem.bind(storage)) + const setSpy = vi.fn(storage.setItem.bind(storage)) + storage.getItem = getSpy + storage.setItem = setSpy + + const cacheMiddleware = toolCacheMiddleware({ storage }) + + // Trigger a before check (miss) + await cacheMiddleware.onBeforeToolCall!( + {} as Parameters< + NonNullable + >[0], + { + toolCall: { + id: 'tc-1', + function: { name: 'search', arguments: '{"q":"hello"}' }, + }, + tool: tool, + args: { q: 'hello' }, + toolName: 'search', + toolCallId: 'tc-1', + }, + ) + + const expectedKey = JSON.stringify(['search', { q: 'hello' }]) + expect(getSpy).toHaveBeenCalledWith(expectedKey) + + // Store a result + await cacheMiddleware.onAfterToolCall!( + {} as Parameters< + NonNullable + >[0], + { + toolCall: { + id: 'tc-1', + function: { name: 'search', arguments: '{"q":"hello"}' }, + }, + tool: tool, + toolName: 'search', + toolCallId: 'tc-1', + ok: true, + duration: 5, + result: { results: [] }, + }, + ) + + expect(setSpy).toHaveBeenCalledWith(expectedKey, { + result: { results: [] }, + timestamp: expect.any(Number), + }) + }) + + it('should share cache across multiple chat() calls via shared storage', async () => { + let callCount = 0 + const tool = serverTool('getWeather', () => { + callCount++ + return { temp: 72 } + }) + + const storage = createMapStorage() + + // First chat() call — populates cache + const { adapter: adapter1 } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.toolStart('tc-1', 'getWeather'), + ev.toolArgs('tc-1', '{"city":"NYC"}'), + ev.toolEnd('tc-1', 'getWeather', { input: { city: 'NYC' } }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), ev.textContent('Done'), ev.runFinished('stop')], + ], + }) + + await collectChunks( + chat({ + adapter: adapter1, + messages: [{ role: 'user', content: 'Hi' }], + tools: [tool], + 
middleware: [toolCacheMiddleware({ storage })], + }) as AsyncIterable, + ) + + expect(callCount).toBe(1) + + // Second chat() call — should hit shared storage cache + const { adapter: adapter2 } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.toolStart('tc-2', 'getWeather'), + ev.toolArgs('tc-2', '{"city":"NYC"}'), + ev.toolEnd('tc-2', 'getWeather', { input: { city: 'NYC' } }), + ev.runFinished('tool_calls'), + ], + [ev.runStarted(), ev.textContent('Done'), ev.runFinished('stop')], + ], + }) + + await collectChunks( + chat({ + adapter: adapter2, + messages: [{ role: 'user', content: 'Hi' }], + tools: [tool], + middleware: [toolCacheMiddleware({ storage })], + }) as AsyncIterable, + ) + + // Still 1 — second call served from shared storage + expect(callCount).toBe(1) + }) + }) +}) From 8739a2e22fd2bf9b18c9da17ecc7e7e34eeb54f8 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Thu, 26 Feb 2026 13:37:31 +0000 Subject: [PATCH 02/16] ci: apply automated fixes --- .../ai/src/activities/chat/index.ts | 5 +--- .../src/activities/chat/tools/tool-calls.ts | 2 -- .../typescript/ai/tests/middleware.test.ts | 16 +++++++---- packages/typescript/ai/tests/test-utils.ts | 10 ++----- .../ai/tests/tool-cache-middleware.test.ts | 28 +++++++++++-------- 5 files changed, 30 insertions(+), 31 deletions(-) diff --git a/packages/typescript/ai/src/activities/chat/index.ts b/packages/typescript/ai/src/activities/chat/index.ts index 59339287e..16bc48071 100644 --- a/packages/typescript/ai/src/activities/chat/index.ts +++ b/packages/typescript/ai/src/activities/chat/index.ts @@ -357,10 +357,7 @@ class TextEngine< } do { - if ( - this.earlyTermination || - this.isCancelled() - ) { + if (this.earlyTermination || this.isCancelled()) { return } diff --git a/packages/typescript/ai/src/activities/chat/tools/tool-calls.ts b/packages/typescript/ai/src/activities/chat/tools/tool-calls.ts index c81cc09f2..c542c88c3 100644 --- 
a/packages/typescript/ai/src/activities/chat/tools/tool-calls.ts +++ b/packages/typescript/ai/src/activities/chat/tools/tool-calls.ts @@ -368,9 +368,7 @@ async function applyBeforeToolCallDecision( return { proceed: false } } - return { proceed: true, input: decision.args } - } /** diff --git a/packages/typescript/ai/tests/middleware.test.ts b/packages/typescript/ai/tests/middleware.test.ts index 5b0701ea8..61c9d6ea9 100644 --- a/packages/typescript/ai/tests/middleware.test.ts +++ b/packages/typescript/ai/tests/middleware.test.ts @@ -308,7 +308,10 @@ describe('chat() middleware', () => { const middleware: ChatMiddleware = { name: 'dropper', onChunk: (_ctx, chunk) => { - if (chunk.type === 'TEXT_MESSAGE_CONTENT' && chunk.delta === 'Hello') { + if ( + chunk.type === 'TEXT_MESSAGE_CONTENT' && + chunk.delta === 'Hello' + ) { return null } return undefined @@ -571,9 +574,7 @@ describe('chat() middleware', () => { role: string content?: unknown }> - const toolResultMsg = secondCallMessages.find( - (m) => m.role === 'tool', - ) + const toolResultMsg = secondCallMessages.find((m) => m.role === 'tool') expect(toolResultMsg).toBeDefined() }) @@ -1027,7 +1028,11 @@ describe('chat() middleware', () => { ev.toolEnd('tc-1', 'failTool', { input: {} }), ev.runFinished('tool_calls'), ], - [ev.runStarted(), ev.textContent('recovered'), ev.runFinished('stop')], + [ + ev.runStarted(), + ev.textContent('recovered'), + ev.runFinished('stop'), + ], ], }) @@ -1942,7 +1947,6 @@ describe('chat() middleware', () => { }) }) - // ========================================================================== // All hooks from a single middleware object // ========================================================================== diff --git a/packages/typescript/ai/tests/test-utils.ts b/packages/typescript/ai/tests/test-utils.ts index 66bfa5ba5..380ca3ac9 100644 --- a/packages/typescript/ai/tests/test-utils.ts +++ b/packages/typescript/ai/tests/test-utils.ts @@ -1,9 +1,5 @@ import type { 
AnyTextAdapter } from '../src/activities/chat/adapter' -import type { - StreamChunk, - TextMessageContentEvent, - Tool, -} from '../src/types' +import type { StreamChunk, TextMessageContentEvent, Tool } from '../src/types' // ============================================================================ // Chunk factory @@ -141,9 +137,7 @@ export async function collectChunks( // ============================================================================ /** Type guard for TEXT_MESSAGE_CONTENT chunks. */ -export function isTextContent( - c: StreamChunk, -): c is TextMessageContentEvent { +export function isTextContent(c: StreamChunk): c is TextMessageContentEvent { return c.type === 'TEXT_MESSAGE_CONTENT' } diff --git a/packages/typescript/ai/tests/tool-cache-middleware.test.ts b/packages/typescript/ai/tests/tool-cache-middleware.test.ts index 1e96a7cce..bed85fbe2 100644 --- a/packages/typescript/ai/tests/tool-cache-middleware.test.ts +++ b/packages/typescript/ai/tests/tool-cache-middleware.test.ts @@ -1,14 +1,12 @@ import { describe, expect, it, vi } from 'vitest' import { chat } from '../src/activities/chat/index' import { toolCacheMiddleware } from '../src/activities/chat/middleware/tool-cache-middleware' -import type { ToolCacheEntry, ToolCacheStorage } from '../src/activities/chat/middleware/tool-cache-middleware' +import type { + ToolCacheEntry, + ToolCacheStorage, +} from '../src/activities/chat/middleware/tool-cache-middleware' import type { StreamChunk } from '../src/types' -import { - ev, - createMockAdapter, - collectChunks, - serverTool, -} from './test-utils' +import { ev, createMockAdapter, collectChunks, serverTool } from './test-utils' // ============================================================================ // Tests @@ -188,7 +186,9 @@ describe('toolCacheMiddleware', () => { // Manually simulate the cache flow with controlled time // First tool call: cache miss, execute, store const beforeResult1 = await cacheMiddleware.onBeforeToolCall!( - {} 
as Parameters>[0], + {} as Parameters< + NonNullable + >[0], { toolCall: { id: 'tc-1', @@ -205,7 +205,9 @@ describe('toolCacheMiddleware', () => { // Simulate successful execution and store result await cacheMiddleware.onAfterToolCall!( - {} as Parameters>[0], + {} as Parameters< + NonNullable + >[0], { toolCall: { id: 'tc-1', @@ -222,7 +224,9 @@ describe('toolCacheMiddleware', () => { // Second call immediately — should hit cache const beforeResult2 = await cacheMiddleware.onBeforeToolCall!( - {} as Parameters>[0], + {} as Parameters< + NonNullable + >[0], { toolCall: { id: 'tc-2', @@ -241,7 +245,9 @@ describe('toolCacheMiddleware', () => { // Third call after TTL — should miss cache const beforeResult3 = await cacheMiddleware.onBeforeToolCall!( - {} as Parameters>[0], + {} as Parameters< + NonNullable + >[0], { toolCall: { id: 'tc-3', From 76a354e51825f011fcbd9f1e374132df3fe56d6c Mon Sep 17 00:00:00 2001 From: Alem Tuzlak Date: Thu, 26 Feb 2026 14:55:50 +0100 Subject: [PATCH 03/16] chore: fix --- packages/typescript/ai/tests/chat.test.ts | 26 +++++++++---------- .../ai/tests/tool-cache-middleware.test.ts | 14 ++++++---- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/packages/typescript/ai/tests/chat.test.ts b/packages/typescript/ai/tests/chat.test.ts index b4408f140..8ab8816b9 100644 --- a/packages/typescript/ai/tests/chat.test.ts +++ b/packages/typescript/ai/tests/chat.test.ts @@ -63,8 +63,8 @@ describe('chat()', () => { await collectChunks(stream as AsyncIterable) expect(calls).toHaveLength(1) - expect(calls[0].messages).toBeDefined() - expect(calls[0].messages[0].role).toBe('user') + expect(calls[0]!.messages).toBeDefined() + expect((calls[0]!.messages as Array<{ role: string }>)[0]!.role).toBe('user') }) it('should pass systemPrompts to the adapter', async () => { @@ -80,7 +80,7 @@ describe('chat()', () => { await collectChunks(stream as AsyncIterable) - expect(calls[0].systemPrompts).toEqual(['You are a helpful assistant']) + 
expect(calls[0]!.systemPrompts).toEqual(['You are a helpful assistant']) }) it('should pass temperature, topP, maxTokens to the adapter', async () => { @@ -98,9 +98,9 @@ describe('chat()', () => { await collectChunks(stream as AsyncIterable) - expect(calls[0].temperature).toBe(0.5) - expect(calls[0].topP).toBe(0.9) - expect(calls[0].maxTokens).toBe(100) + expect(calls[0]!.temperature).toBe(0.5) + expect(calls[0]!.topP).toBe(0.9) + expect(calls[0]!.maxTokens).toBe(100) }) }) @@ -221,9 +221,9 @@ describe('chat()', () => { expect(calls).toHaveLength(2) // Second call should have tool result in messages - const secondCallMessages = calls[1].messages + const secondCallMessages = calls[1]!.messages as Array<{ role: string }> const toolResultMsg = secondCallMessages.find( - (m: any) => m.role === 'tool', + (m) => m.role === 'tool', ) expect(toolResultMsg).toBeDefined() }) @@ -313,9 +313,9 @@ describe('chat()', () => { expect(timeSpy).toHaveBeenCalledTimes(1) // Second adapter call should have both tool results - const secondCallMessages = calls[1].messages + const secondCallMessages = calls[1]!.messages as Array<{ role: string }> const toolResultMsgs = secondCallMessages.filter( - (m: any) => m.role === 'tool', + (m) => m.role === 'tool', ) expect(toolResultMsgs).toHaveLength(2) }) @@ -477,8 +477,8 @@ describe('chat()', () => { // Adapter should have been called with the tool result in messages expect(calls).toHaveLength(1) - const adapterMessages = calls[0].messages - const toolMsg = adapterMessages.find((m: any) => m.role === 'tool') + const adapterMessages = calls[0]!.messages as Array<{ role: string }> + const toolMsg = adapterMessages.find((m) => m.role === 'tool') expect(toolMsg).toBeDefined() }) @@ -924,7 +924,7 @@ describe('chat()', () => { }) await collectChunks(stream as AsyncIterable) - expect(calls[0].modelOptions).toEqual({ customParam: 'value' }) + expect(calls[0]!.modelOptions).toEqual({ customParam: 'value' }) }) it('should handle TEXT_MESSAGE_CONTENT with 
content field', async () => { diff --git a/packages/typescript/ai/tests/tool-cache-middleware.test.ts b/packages/typescript/ai/tests/tool-cache-middleware.test.ts index bed85fbe2..9293ebd11 100644 --- a/packages/typescript/ai/tests/tool-cache-middleware.test.ts +++ b/packages/typescript/ai/tests/tool-cache-middleware.test.ts @@ -160,7 +160,7 @@ describe('toolCacheMiddleware', () => { return { data: callCount } }) - const { adapter } = createMockAdapter({ + const { adapter: _adapter } = createMockAdapter({ iterations: [ [ ev.runStarted(), @@ -192,6 +192,7 @@ describe('toolCacheMiddleware', () => { { toolCall: { id: 'tc-1', + type: 'function', function: { name: 'getData', arguments: '{}' }, }, tool: tool, @@ -211,6 +212,7 @@ describe('toolCacheMiddleware', () => { { toolCall: { id: 'tc-1', + type: 'function', function: { name: 'getData', arguments: '{}' }, }, tool: tool, @@ -230,6 +232,7 @@ describe('toolCacheMiddleware', () => { { toolCall: { id: 'tc-2', + type: 'function', function: { name: 'getData', arguments: '{}' }, }, tool: tool, @@ -251,6 +254,7 @@ describe('toolCacheMiddleware', () => { { toolCall: { id: 'tc-3', + type: 'function', function: { name: 'getData', arguments: '{}' }, }, tool: tool, @@ -268,10 +272,6 @@ describe('toolCacheMiddleware', () => { it('should respect maxSize and evict oldest entries', async () => { const results: Array = [] - const tool = serverTool('lookup', (args) => { - const key = (args as { key: string }).key - return { value: `result-${key}` } - }) const { adapter } = createMockAdapter({ iterations: [ @@ -569,6 +569,7 @@ describe('toolCacheMiddleware', () => { { toolCall: { id: 'tc-1', + type: 'function', function: { name: 'getData', arguments: '{}' }, }, tool: tool, @@ -591,6 +592,7 @@ describe('toolCacheMiddleware', () => { { toolCall: { id: 'tc-2', + type: 'function', function: { name: 'getData', arguments: '{}' }, }, tool: tool, @@ -627,6 +629,7 @@ describe('toolCacheMiddleware', () => { { toolCall: { id: 'tc-1', + type: 
'function', function: { name: 'search', arguments: '{"q":"hello"}' }, }, tool: tool, @@ -647,6 +650,7 @@ describe('toolCacheMiddleware', () => { { toolCall: { id: 'tc-1', + type: 'function', function: { name: 'search', arguments: '{"q":"hello"}' }, }, tool: tool, From 82174a70ba137252a62697e9ee39cfdfb1542a4d Mon Sep 17 00:00:00 2001 From: Alem Tuzlak Date: Fri, 27 Feb 2026 12:08:20 +0100 Subject: [PATCH 04/16] Add test utilities and tool cache middleware tests - Introduced `test-utils.ts` with utility functions for creating mock chunks, adapters, and collecting stream chunks. - Added comprehensive tests for `tool-cache-middleware` in `tool-cache-middleware.test.ts` to verify caching behavior, including: - Caching tool results and skipping execution on cache hits. - Handling different arguments and ensuring cache misses. - Respecting tool name filters and TTL for cache entries. - Implementing custom storage solutions for cache management. - Validating that failed tool executions are not cached. --- .../ai/src/activities/chat/index.ts | 463 +++++------------- .../src/activities/chat/middleware/compose.ts | 32 ++ .../chat/middleware/devtools-middleware.ts | 289 +++++++++++ .../src/activities/chat/middleware/index.ts | 4 + .../src/activities/chat/middleware/types.ts | 108 ++++ packages/typescript/ai/src/index.ts | 7 +- 6 files changed, 574 insertions(+), 329 deletions(-) create mode 100644 packages/typescript/ai/src/activities/chat/middleware/devtools-middleware.ts diff --git a/packages/typescript/ai/src/activities/chat/index.ts b/packages/typescript/ai/src/activities/chat/index.ts index 16bc48071..e8f78f9e4 100644 --- a/packages/typescript/ai/src/activities/chat/index.ts +++ b/packages/typescript/ai/src/activities/chat/index.ts @@ -5,7 +5,6 @@ * This is a self-contained module with implementation, types, and JSDoc. 
*/ -import { aiEventClient } from '../../event-client.js' import { streamToText } from '../../stream-to-response.js' import { MiddlewareAbortError, @@ -20,6 +19,7 @@ import { import { maxIterations as maxIterationsStrategy } from './agent-loop-strategies' import { convertMessagesToModelMessages } from './messages' import { MiddlewareRunner } from './middleware/compose' +import { devtoolsMiddleware } from './middleware/devtools-middleware' import type { ApprovalRequest, ClientToolRequest, @@ -254,7 +254,6 @@ class TextEngine< private eventOptions?: Record private eventToolNames?: Array private finishedEvent: RunFinishedEvent | null = null - private shouldEmitStreamEnd = true private earlyTermination = false private toolPhase: ToolPhaseResult = 'continue' private cyclePhase: CyclePhase = 'processText' @@ -301,8 +300,12 @@ class TextEngine< : undefined this.effectiveSignal = config.params.abortController?.signal - // Initialize middleware - this.middlewareRunner = new MiddlewareRunner(config.middleware || []) + // Initialize middleware — devtools middleware is always first + const allMiddleware = [ + devtoolsMiddleware(), + ...(config.middleware || []), + ] + this.middlewareRunner = new MiddlewareRunner(allMiddleware) this.middlewareAbortController = new AbortController() this.middlewareCtx = { requestId: this.requestId, @@ -320,6 +323,25 @@ class TextEngine< defer: (promise: Promise) => { this.deferredPromises.push(promise) }, + // Provider / adapter info + provider: config.adapter.name, + model: config.params.model, + source: 'server', + streaming: true, + // Config-derived (updated in beforeRun and applyMiddlewareConfig) + systemPrompts: this.systemPrompts, + toolNames: undefined, + options: undefined, + modelOptions: config.params.modelOptions, + // Computed + messageCount: this.initialMessageCount, + hasTools: this.tools.length > 0, + // Mutable per-iteration + currentMessageId: null, + accumulatedContent: '', + // References + messages: this.messages, + 
createId: (prefix: string) => this.createId(prefix), } } @@ -338,18 +360,16 @@ class TextEngine< try { // Run initial onConfig (phase = init) - if (this.middlewareRunner.hasMiddleware) { - this.middlewareCtx.phase = 'init' - const initialConfig = this.buildMiddlewareConfig() - const transformedConfig = await this.middlewareRunner.runOnConfig( - this.middlewareCtx, - initialConfig, - ) - this.applyMiddlewareConfig(transformedConfig) + this.middlewareCtx.phase = 'init' + const initialConfig = this.buildMiddlewareConfig() + const transformedConfig = await this.middlewareRunner.runOnConfig( + this.middlewareCtx, + initialConfig, + ) + this.applyMiddlewareConfig(transformedConfig) - // Run onStart - await this.middlewareRunner.runOnStart(this.middlewareCtx) - } + // Run onStart (devtools middleware emits text:request:started and initial messages here) + await this.middlewareRunner.runOnStart(this.middlewareCtx) const pendingPhase = yield* this.checkForPendingToolCalls() if (pendingPhase === 'wait') { @@ -361,20 +381,18 @@ class TextEngine< return } - this.beginCycle() + await this.beginCycle() if (this.cyclePhase === 'processText') { // Run onConfig before each model call (phase = beforeModel) - if (this.middlewareRunner.hasMiddleware) { - this.middlewareCtx.phase = 'beforeModel' - this.middlewareCtx.iteration = this.iterationCount - const iterConfig = this.buildMiddlewareConfig() - const transformedConfig = await this.middlewareRunner.runOnConfig( - this.middlewareCtx, - iterConfig, - ) - this.applyMiddlewareConfig(transformedConfig) - } + this.middlewareCtx.phase = 'beforeModel' + this.middlewareCtx.iteration = this.iterationCount + const iterConfig = this.buildMiddlewareConfig() + const transformedConfig = await this.middlewareRunner.runOnConfig( + this.middlewareCtx, + iterConfig, + ) + this.applyMiddlewareConfig(transformedConfig) yield* this.streamModelResponse() } else { @@ -384,8 +402,8 @@ class TextEngine< this.endCycle() } while (this.shouldContinue()) - // 
Call terminal onFinish hook - if (this.middlewareRunner.hasMiddleware && !this.terminalHookCalled) { + // Call terminal onFinish hook (skip when waiting for client — stream is paused, not finished) + if (!this.terminalHookCalled && this.toolPhase !== 'wait') { this.terminalHookCalled = true await this.middlewareRunner.runOnFinish(this.middlewareCtx, { finishReason: this.lastFinishReason, @@ -395,7 +413,7 @@ class TextEngine< }) } } catch (error: unknown) { - if (this.middlewareRunner.hasMiddleware && !this.terminalHookCalled) { + if (!this.terminalHookCalled) { this.terminalHookCalled = true if (error instanceof MiddlewareAbortError) { // Middleware abort decision — call onAbort, not onError @@ -418,11 +436,7 @@ class TextEngine< } } finally { // Check for abort terminal hook - if ( - this.middlewareRunner.hasMiddleware && - !this.terminalHookCalled && - this.isCancelled() - ) { + if (!this.terminalHookCalled && this.isCancelled()) { this.terminalHookCalled = true await this.middlewareRunner.runOnAbort(this.middlewareCtx, { reason: this.abortReason, @@ -430,8 +444,6 @@ class TextEngine< }) } - this.afterRun() - // Await deferred promises (non-blocking side effects) if (this.deferredPromises.length > 0) { await Promise.allSettled(this.deferredPromises) @@ -443,7 +455,7 @@ class TextEngine< this.streamStartTime = Date.now() const { tools, temperature, topP, maxTokens, metadata } = this.params - // Gather flattened options into an object for event emission + // Gather flattened options into an object for context const options: Record = {} if (temperature !== undefined) options.temperature = temperature if (topP !== undefined) options.topP = topP @@ -453,70 +465,14 @@ class TextEngine< this.eventOptions = Object.keys(options).length > 0 ? 
options : undefined this.eventToolNames = tools?.map((t) => t.name) - aiEventClient.emit('text:request:started', { - ...this.buildTextEventContext(), - timestamp: Date.now(), - }) - - // Always emit messages for tracking: - // - For existing conversations (with conversationId): only emit the latest user message - // - For new conversations (without conversationId): emit all messages for reconstruction - const messagesToEmit = this.params.conversationId - ? this.messages.slice(-1).filter((m) => m.role === 'user') - : this.messages - - messagesToEmit.forEach((message, index) => { - const messageIndex = this.params.conversationId - ? this.messages.length - 1 - : index - const messageId = this.createId('msg') - const baseContext = this.buildTextEventContext() - const content = this.getContentString(message.content) - - aiEventClient.emit('text:message:created', { - ...baseContext, - messageId, - role: message.role, - content, - toolCalls: message.toolCalls, - messageIndex, - timestamp: Date.now(), - }) - - if (message.role === 'user') { - aiEventClient.emit('text:message:user', { - ...baseContext, - messageId, - role: 'user', - content, - messageIndex, - timestamp: Date.now(), - }) - } - }) + // Update middleware context with computed fields + this.middlewareCtx.options = this.eventOptions + this.middlewareCtx.toolNames = this.eventToolNames } - private afterRun(): void { - if (!this.shouldEmitStreamEnd) { - return - } - - const now = Date.now() - // Emit text:request:completed with final state - aiEventClient.emit('text:request:completed', { - ...this.buildTextEventContext(), - content: this.accumulatedContent, - messageId: this.currentMessageId || undefined, - finishReason: this.lastFinishReason || undefined, - usage: this.finishedEvent?.usage, - duration: now - this.streamStartTime, - timestamp: now, - }) - } - - private beginCycle(): void { + private async beginCycle(): Promise { if (this.cyclePhase === 'processText') { - this.beginIteration() + await 
this.beginIteration() } } @@ -530,18 +486,19 @@ class TextEngine< this.iterationCount++ } - private beginIteration(): void { + private async beginIteration(): Promise { this.currentMessageId = this.createId('msg') this.accumulatedContent = '' this.finishedEvent = null - const baseContext = this.buildTextEventContext() - aiEventClient.emit('text:message:created', { - ...baseContext, + // Update mutable context fields + this.middlewareCtx.currentMessageId = this.currentMessageId + this.middlewareCtx.accumulatedContent = '' + + // Notify middleware of new iteration (devtools emits assistant message:created here) + await this.middlewareRunner.runOnIteration(this.middlewareCtx, { + iteration: this.iterationCount, messageId: this.currentMessageId, - role: 'assistant', - content: '', - timestamp: Date.now(), }) } @@ -560,9 +517,7 @@ class TextEngine< : undefined, })) - if (this.middlewareRunner.hasMiddleware) { - this.middlewareCtx.phase = 'modelStream' - } + this.middlewareCtx.phase = 'modelStream' for await (const chunk of this.adapter.chatStream({ model: this.params.model, @@ -582,28 +537,23 @@ class TextEngine< this.totalChunkCount++ - // Pipe chunk through middleware - if (this.middlewareRunner.hasMiddleware) { - const outputChunks = await this.middlewareRunner.runOnChunk( + // Pipe chunk through middleware (devtools middleware observes and emits events) + const outputChunks = await this.middlewareRunner.runOnChunk( + this.middlewareCtx, + chunk, + ) + for (const outputChunk of outputChunks) { + yield outputChunk + this.handleStreamChunk(outputChunk) + this.middlewareCtx.chunkIndex++ + } + + // Handle usage via middleware + if (chunk.type === 'RUN_FINISHED' && chunk.usage) { + await this.middlewareRunner.runOnUsage( this.middlewareCtx, - chunk, + chunk.usage, ) - for (const outputChunk of outputChunks) { - yield outputChunk - this.handleStreamChunk(outputChunk) - this.middlewareCtx.chunkIndex++ - } - - // Handle usage via middleware - if (chunk.type === 
'RUN_FINISHED' && chunk.usage) { - await this.middlewareRunner.runOnUsage( - this.middlewareCtx, - chunk.usage, - ) - } - } else { - yield chunk - this.handleStreamChunk(chunk) } if (this.earlyTermination) { @@ -655,100 +605,35 @@ class TextEngine< } else { this.accumulatedContent += chunk.delta } - aiEventClient.emit('text:chunk:content', { - ...this.buildTextEventContext(), - messageId: this.currentMessageId || undefined, - content: this.accumulatedContent, - delta: chunk.delta, - timestamp: Date.now(), - }) } private handleToolCallStartEvent(chunk: ToolCallStartEvent): void { this.toolCallManager.addToolCallStartEvent(chunk) - aiEventClient.emit('text:chunk:tool-call', { - ...this.buildTextEventContext(), - messageId: this.currentMessageId || undefined, - toolCallId: chunk.toolCallId, - toolName: chunk.toolName, - index: chunk.index ?? 0, - arguments: '', - timestamp: Date.now(), - }) } private handleToolCallArgsEvent(chunk: ToolCallArgsEvent): void { this.toolCallManager.addToolCallArgsEvent(chunk) - aiEventClient.emit('text:chunk:tool-call', { - ...this.buildTextEventContext(), - messageId: this.currentMessageId || undefined, - toolCallId: chunk.toolCallId, - toolName: '', - index: 0, - arguments: chunk.delta, - timestamp: Date.now(), - }) } private handleToolCallEndEvent(chunk: ToolCallEndEvent): void { this.toolCallManager.completeToolCall(chunk) - aiEventClient.emit('text:chunk:tool-result', { - ...this.buildTextEventContext(), - messageId: this.currentMessageId || undefined, - toolCallId: chunk.toolCallId, - result: chunk.result || '', - timestamp: Date.now(), - }) } private handleRunFinishedEvent(chunk: RunFinishedEvent): void { - aiEventClient.emit('text:chunk:done', { - ...this.buildTextEventContext(), - messageId: this.currentMessageId || undefined, - finishReason: chunk.finishReason, - usage: chunk.usage, - timestamp: Date.now(), - }) - - if (chunk.usage) { - aiEventClient.emit('text:usage', { - ...this.buildTextEventContext(), - messageId: 
this.currentMessageId || undefined, - usage: chunk.usage, - timestamp: Date.now(), - }) - } - this.finishedEvent = chunk this.lastFinishReason = chunk.finishReason } private handleRunErrorEvent( - chunk: Extract, + _chunk: Extract, ): void { - aiEventClient.emit('text:chunk:error', { - ...this.buildTextEventContext(), - messageId: this.currentMessageId || undefined, - error: chunk.error.message, - timestamp: Date.now(), - }) this.earlyTermination = true - this.shouldEmitStreamEnd = false } private handleStepFinishedEvent( - chunk: Extract, + _chunk: Extract, ): void { - // Handle thinking/reasoning content from STEP_FINISHED events - if (chunk.content || chunk.delta) { - aiEventClient.emit('text:chunk:thinking', { - ...this.buildTextEventContext(), - messageId: this.currentMessageId || undefined, - content: chunk.content || '', - delta: chunk.delta, - timestamp: Date.now(), - }) - } + // State tracking for STEP_FINISHED is handled by middleware } private async *checkForPendingToolCalls(): AsyncGenerator< @@ -776,29 +661,37 @@ class TextEngine< // Consume the async generator, yielding custom events and collecting the return value const executionResult = yield* this.drainToolCallGenerator(generator) + // Notify middleware of tool phase completion (devtools emits aggregate events here) + await this.middlewareRunner.runOnToolPhaseComplete(this.middlewareCtx, { + toolCalls: pendingToolCalls, + results: executionResult.results, + needsApproval: executionResult.needsApproval, + needsClientExecution: executionResult.needsClientExecution, + }) + if ( executionResult.needsApproval.length > 0 || executionResult.needsClientExecution.length > 0 ) { - for (const chunk of this.emitApprovalRequests( + for (const chunk of this.buildApprovalChunks( executionResult.needsApproval, finishEvent, )) { yield chunk } - for (const chunk of this.emitClientToolInputs( + for (const chunk of this.buildClientToolChunks( executionResult.needsClientExecution, finishEvent, )) { yield chunk } - 
this.shouldEmitStreamEnd = false + this.setToolPhase('wait') return 'wait' } - const toolResultChunks = this.emitToolResults( + const toolResultChunks = this.buildToolResultChunks( executionResult.results, finishEvent, ) @@ -826,9 +719,7 @@ class TextEngine< this.addAssistantToolCallMessage(toolCalls) - if (this.middlewareRunner.hasMiddleware) { - this.middlewareCtx.phase = 'beforeTools' - } + this.middlewareCtx.phase = 'beforeTools' const { approvals, clientToolResults } = this.collectClientState() @@ -838,37 +729,33 @@ class TextEngine< approvals, clientToolResults, (eventName, data) => this.createCustomEventChunk(eventName, data), - this.middlewareRunner.hasMiddleware - ? { - onBeforeToolCall: async (toolCall, tool, args) => { - const hookCtx = { - toolCall, - tool, - args, - toolName: toolCall.function.name, - toolCallId: toolCall.id, - } - return this.middlewareRunner.runOnBeforeToolCall( - this.middlewareCtx, - hookCtx, - ) - }, - onAfterToolCall: async (info) => { - await this.middlewareRunner.runOnAfterToolCall( - this.middlewareCtx, - info, - ) - }, + { + onBeforeToolCall: async (toolCall, tool, args) => { + const hookCtx = { + toolCall, + tool, + args, + toolName: toolCall.function.name, + toolCallId: toolCall.id, } - : undefined, + return this.middlewareRunner.runOnBeforeToolCall( + this.middlewareCtx, + hookCtx, + ) + }, + onAfterToolCall: async (info) => { + await this.middlewareRunner.runOnAfterToolCall( + this.middlewareCtx, + info, + ) + }, + }, ) // Consume the async generator, yielding custom events and collecting the return value const executionResult = yield* this.drainToolCallGenerator(generator) - if (this.middlewareRunner.hasMiddleware) { - this.middlewareCtx.phase = 'afterTools' - } + this.middlewareCtx.phase = 'afterTools' // Check if middleware aborted during tool execution if (this.isMiddlewareAborted()) { @@ -876,18 +763,26 @@ class TextEngine< return } + // Notify middleware of tool phase completion (devtools emits aggregate events 
here) + await this.middlewareRunner.runOnToolPhaseComplete(this.middlewareCtx, { + toolCalls, + results: executionResult.results, + needsApproval: executionResult.needsApproval, + needsClientExecution: executionResult.needsClientExecution, + }) + if ( executionResult.needsApproval.length > 0 || executionResult.needsClientExecution.length > 0 ) { - for (const chunk of this.emitApprovalRequests( + for (const chunk of this.buildApprovalChunks( executionResult.needsApproval, finishEvent, )) { yield chunk } - for (const chunk of this.emitClientToolInputs( + for (const chunk of this.buildClientToolChunks( executionResult.needsClientExecution, finishEvent, )) { @@ -898,7 +793,7 @@ class TextEngine< return } - const toolResultChunks = this.emitToolResults( + const toolResultChunks = this.buildToolResultChunks( executionResult.results, finishEvent, ) @@ -921,7 +816,6 @@ class TextEngine< } private addAssistantToolCallMessage(toolCalls: Array): void { - const messageId = this.currentMessageId ?? this.createId('msg') this.messages = [ ...this.messages, { @@ -930,15 +824,6 @@ class TextEngine< toolCalls, }, ] - - aiEventClient.emit('text:message:created', { - ...this.buildTextEventContext(), - messageId, - role: 'assistant', - content: this.accumulatedContent || '', - toolCalls, - timestamp: Date.now(), - }) } /** @@ -1019,24 +904,13 @@ class TextEngine< return { approvals, clientToolResults } } - private emitApprovalRequests( + private buildApprovalChunks( approvals: Array, finishEvent: RunFinishedEvent, ): Array { const chunks: Array = [] for (const approval of approvals) { - aiEventClient.emit('tools:approval:requested', { - ...this.buildTextEventContext(), - messageId: this.currentMessageId || undefined, - toolCallId: approval.toolCallId, - toolName: approval.toolName, - input: approval.input, - approvalId: approval.approvalId, - timestamp: Date.now(), - }) - - // Emit a CUSTOM event for approval requests chunks.push({ type: 'CUSTOM', timestamp: Date.now(), @@ -1057,23 
+931,13 @@ class TextEngine< return chunks } - private emitClientToolInputs( + private buildClientToolChunks( clientRequests: Array, finishEvent: RunFinishedEvent, ): Array { const chunks: Array = [] for (const clientTool of clientRequests) { - aiEventClient.emit('tools:input:available', { - ...this.buildTextEventContext(), - messageId: this.currentMessageId || undefined, - toolCallId: clientTool.toolCallId, - toolName: clientTool.toolName, - input: clientTool.input, - timestamp: Date.now(), - }) - - // Emit a CUSTOM event for client tool inputs chunks.push({ type: 'CUSTOM', timestamp: Date.now(), @@ -1090,26 +954,15 @@ class TextEngine< return chunks } - private emitToolResults( + private buildToolResultChunks( results: Array, finishEvent: RunFinishedEvent, ): Array { const chunks: Array = [] for (const result of results) { - aiEventClient.emit('tools:call:completed', { - ...this.buildTextEventContext(), - messageId: this.currentMessageId || undefined, - toolCallId: result.toolCallId, - toolName: result.toolName, - result: result.result, - duration: result.duration ?? 
0, - timestamp: Date.now(), - }) - const content = JSON.stringify(result.result) - // Emit TOOL_CALL_END event chunks.push({ type: 'TOOL_CALL_END', timestamp: Date.now(), @@ -1127,14 +980,6 @@ class TextEngine< toolCallId: result.toolCallId, }, ] - - aiEventClient.emit('text:message:created', { - ...this.buildTextEventContext(), - messageId: this.createId('msg'), - role: 'tool', - content, - timestamp: Date.now(), - }) } return chunks @@ -1243,55 +1088,17 @@ class TextEngine< metadata: config.metadata, modelOptions: config.modelOptions, } - } - - private buildTextEventContext(): { - requestId: string - streamId: string - provider: string - model: string - clientId?: string - source?: 'client' | 'server' - systemPrompts?: Array - toolNames?: Array - options?: Record - modelOptions?: Record - messageCount: number - hasTools: boolean - streaming: boolean - } { - return { - requestId: this.requestId, - streamId: this.streamId, - provider: this.adapter.name, - model: this.params.model, - clientId: this.params.conversationId, - source: 'server', - systemPrompts: - this.systemPrompts.length > 0 ? this.systemPrompts : undefined, - toolNames: this.eventToolNames, - options: this.eventOptions, - modelOptions: this.params.modelOptions, - messageCount: this.initialMessageCount, - hasTools: this.tools.length > 0, - streaming: true, - } - } - private getContentString(content: ModelMessage['content']): string { - if (typeof content === 'string') return content - const text = - content - ?.map((part) => (part.type === 'text' ? 
part.content : '')) - .join('') || '' - return text + // Sync context fields that depend on config + this.middlewareCtx.messages = this.messages + this.middlewareCtx.systemPrompts = this.systemPrompts + this.middlewareCtx.hasTools = this.tools.length > 0 + this.middlewareCtx.toolNames = this.tools.map((t) => t.name) + this.middlewareCtx.modelOptions = config.modelOptions } private setToolPhase(phase: ToolPhaseResult): void { this.toolPhase = phase - if (phase === 'wait') { - this.shouldEmitStreamEnd = false - } } /** diff --git a/packages/typescript/ai/src/activities/chat/middleware/compose.ts b/packages/typescript/ai/src/activities/chat/middleware/compose.ts index 025caeb8f..d8d2df4fc 100644 --- a/packages/typescript/ai/src/activities/chat/middleware/compose.ts +++ b/packages/typescript/ai/src/activities/chat/middleware/compose.ts @@ -8,7 +8,9 @@ import type { ChatMiddlewareContext, ErrorInfo, FinishInfo, + IterationInfo, ToolCallHookContext, + ToolPhaseCompleteInfo, UsageInfo, } from './types' @@ -182,4 +184,34 @@ export class MiddlewareRunner { } } } + + /** + * Run onIteration on all middleware in order. + * Called at the start of each agent loop iteration. + */ + async runOnIteration( + ctx: ChatMiddlewareContext, + info: IterationInfo, + ): Promise { + for (const mw of this.middlewares) { + if (mw.onIteration) { + await mw.onIteration(ctx, info) + } + } + } + + /** + * Run onToolPhaseComplete on all middleware in order. + * Called after all tool calls in an iteration have been processed. 
+ */ + async runOnToolPhaseComplete( + ctx: ChatMiddlewareContext, + info: ToolPhaseCompleteInfo, + ): Promise { + for (const mw of this.middlewares) { + if (mw.onToolPhaseComplete) { + await mw.onToolPhaseComplete(ctx, info) + } + } + } } diff --git a/packages/typescript/ai/src/activities/chat/middleware/devtools-middleware.ts b/packages/typescript/ai/src/activities/chat/middleware/devtools-middleware.ts new file mode 100644 index 000000000..6eb98b371 --- /dev/null +++ b/packages/typescript/ai/src/activities/chat/middleware/devtools-middleware.ts @@ -0,0 +1,289 @@ +import { aiEventClient } from '../../../event-client.js' +import type { ModelMessage } from '../../../types' +import type { + ChatMiddleware, + ChatMiddlewareContext, + IterationInfo, + ToolPhaseCompleteInfo, +} from './types' + +/** + * Build the common event context object used by all devtools events. + */ +function buildEventContext(ctx: ChatMiddlewareContext) { + return { + requestId: ctx.requestId, + streamId: ctx.streamId, + provider: ctx.provider, + model: ctx.model, + clientId: ctx.conversationId, + source: ctx.source, + systemPrompts: + ctx.systemPrompts.length > 0 ? ctx.systemPrompts : undefined, + toolNames: ctx.toolNames, + options: ctx.options, + modelOptions: ctx.modelOptions, + messageCount: ctx.messageCount, + hasTools: ctx.hasTools, + streaming: ctx.streaming, + } +} + +/** + * Extract text content from a ModelMessage content field. + */ +function getContentString(content: ModelMessage['content']): string { + if (typeof content === 'string') return content + return ( + content + ?.map((part) => (part.type === 'text' ? part.content : '')) + .join('') || '' + ) +} + +/** + * Internal devtools middleware that emits all DevTools events. + * Auto-injected as the FIRST middleware in the TextEngine. + * + * All hooks are observation-only — `onChunk` returns void to pass through + * without transforming chunks. 
+ */ +export function devtoolsMiddleware(): ChatMiddleware { + // Local mutable state — tracked here because the devtools middleware + // runs first, before the engine updates ctx.currentMessageId / ctx.accumulatedContent + let localMessageId: string | null = null + let localAccumulatedContent = '' + + return { + name: 'devtools', + + onStart(ctx) { + // Emit text:request:started + aiEventClient.emit('text:request:started', { + ...buildEventContext(ctx), + timestamp: Date.now(), + }) + + // Emit text:message:created for initial messages + const messages = ctx.messages + const messagesToEmit = ctx.conversationId + ? messages.slice(-1).filter((m) => m.role === 'user') + : messages + + messagesToEmit.forEach((message, index) => { + const messageIndex = ctx.conversationId + ? messages.length - 1 + : index + const messageId = ctx.createId('msg') + const base = buildEventContext(ctx) + const content = getContentString(message.content) + + aiEventClient.emit('text:message:created', { + ...base, + messageId, + role: message.role, + content, + toolCalls: message.toolCalls, + messageIndex, + timestamp: Date.now(), + }) + + if (message.role === 'user') { + aiEventClient.emit('text:message:user', { + ...base, + messageId, + role: 'user' as const, + content, + messageIndex, + timestamp: Date.now(), + }) + } + }) + }, + + onIteration(ctx: ChatMiddlewareContext, info: IterationInfo) { + localMessageId = info.messageId + localAccumulatedContent = '' + + aiEventClient.emit('text:message:created', { + ...buildEventContext(ctx), + messageId: info.messageId, + role: 'assistant' as const, + content: '', + timestamp: Date.now(), + }) + }, + + onChunk(ctx, chunk) { + const base = buildEventContext(ctx) + + switch (chunk.type) { + case 'TEXT_MESSAGE_CONTENT': { + if (chunk.content) { + localAccumulatedContent = chunk.content + } else { + localAccumulatedContent += chunk.delta + } + aiEventClient.emit('text:chunk:content', { + ...base, + messageId: localMessageId || undefined, + content: 
localAccumulatedContent, + delta: chunk.delta, + timestamp: Date.now(), + }) + break + } + case 'TOOL_CALL_START': { + aiEventClient.emit('text:chunk:tool-call', { + ...base, + messageId: localMessageId || undefined, + toolCallId: chunk.toolCallId, + toolName: chunk.toolName, + index: chunk.index ?? 0, + arguments: '', + timestamp: Date.now(), + }) + break + } + case 'TOOL_CALL_ARGS': { + aiEventClient.emit('text:chunk:tool-call', { + ...base, + messageId: localMessageId || undefined, + toolCallId: chunk.toolCallId, + toolName: '', + index: 0, + arguments: chunk.delta, + timestamp: Date.now(), + }) + break + } + case 'TOOL_CALL_END': { + aiEventClient.emit('text:chunk:tool-result', { + ...base, + messageId: localMessageId || undefined, + toolCallId: chunk.toolCallId, + result: chunk.result || '', + timestamp: Date.now(), + }) + break + } + case 'RUN_FINISHED': { + aiEventClient.emit('text:chunk:done', { + ...base, + messageId: localMessageId || undefined, + finishReason: chunk.finishReason, + usage: chunk.usage, + timestamp: Date.now(), + }) + if (chunk.usage) { + aiEventClient.emit('text:usage', { + ...base, + messageId: localMessageId || undefined, + usage: chunk.usage, + timestamp: Date.now(), + }) + } + break + } + case 'RUN_ERROR': { + aiEventClient.emit('text:chunk:error', { + ...base, + messageId: localMessageId || undefined, + error: chunk.error.message, + timestamp: Date.now(), + }) + break + } + case 'STEP_FINISHED': { + if (chunk.content || chunk.delta) { + aiEventClient.emit('text:chunk:thinking', { + ...base, + messageId: localMessageId || undefined, + content: chunk.content || '', + delta: chunk.delta, + timestamp: Date.now(), + }) + } + break + } + } + + // Return void — observation only, pass through unchanged + }, + + onToolPhaseComplete(ctx, info: ToolPhaseCompleteInfo) { + const base = buildEventContext(ctx) + + // Emit text:message:created for assistant message with tool calls + if (info.toolCalls.length > 0) { + 
aiEventClient.emit('text:message:created', { + ...base, + messageId: localMessageId ?? ctx.createId('msg'), + role: 'assistant' as const, + content: localAccumulatedContent || '', + toolCalls: info.toolCalls, + timestamp: Date.now(), + }) + } + + // Emit tools:approval:requested for each pending approval + for (const approval of info.needsApproval) { + aiEventClient.emit('tools:approval:requested', { + ...base, + messageId: localMessageId || undefined, + toolCallId: approval.toolCallId, + toolName: approval.toolName, + input: approval.input, + approvalId: approval.approvalId, + timestamp: Date.now(), + }) + } + + // Emit tools:input:available for each client tool + for (const clientTool of info.needsClientExecution) { + aiEventClient.emit('tools:input:available', { + ...base, + messageId: localMessageId || undefined, + toolCallId: clientTool.toolCallId, + toolName: clientTool.toolName, + input: clientTool.input, + timestamp: Date.now(), + }) + } + + // Emit tools:call:completed and text:message:created (tool role) for each result + for (const result of info.results) { + aiEventClient.emit('tools:call:completed', { + ...base, + messageId: localMessageId || undefined, + toolCallId: result.toolCallId, + toolName: result.toolName, + result: result.result, + duration: result.duration ?? 
0, + timestamp: Date.now(), + }) + + const content = JSON.stringify(result.result) + aiEventClient.emit('text:message:created', { + ...base, + messageId: ctx.createId('msg'), + role: 'tool' as const, + content, + timestamp: Date.now(), + }) + } + }, + + onFinish(ctx, info) { + aiEventClient.emit('text:request:completed', { + ...buildEventContext(ctx), + content: info.content, + messageId: localMessageId || undefined, + finishReason: info.finishReason || undefined, + usage: info.usage, + duration: info.duration, + timestamp: Date.now(), + }) + }, + } +} diff --git a/packages/typescript/ai/src/activities/chat/middleware/index.ts b/packages/typescript/ai/src/activities/chat/middleware/index.ts index d2175bb14..9020a248a 100644 --- a/packages/typescript/ai/src/activities/chat/middleware/index.ts +++ b/packages/typescript/ai/src/activities/chat/middleware/index.ts @@ -6,6 +6,8 @@ export type { ToolCallHookContext, BeforeToolCallDecision, AfterToolCallInfo, + IterationInfo, + ToolPhaseCompleteInfo, UsageInfo, FinishInfo, AbortInfo, @@ -14,6 +16,8 @@ export type { export { MiddlewareRunner } from './compose' +export { devtoolsMiddleware } from './devtools-middleware' + export { toolCacheMiddleware } from './tool-cache-middleware' export type { ToolCacheMiddlewareOptions, diff --git a/packages/typescript/ai/src/activities/chat/middleware/types.ts b/packages/typescript/ai/src/activities/chat/middleware/types.ts index 3f4d8cb33..19ce04586 100644 --- a/packages/typescript/ai/src/activities/chat/middleware/types.ts +++ b/packages/typescript/ai/src/activities/chat/middleware/types.ts @@ -48,6 +48,49 @@ export interface ChatMiddlewareContext { * after the terminal hook (onFinish/onAbort/onError). 
*/ defer: (promise: Promise) => void + + // --- Provider / adapter info (immutable for the lifetime of the request) --- + + /** Provider name (e.g., 'openai', 'anthropic') */ + provider: string + /** Model identifier (e.g., 'gpt-4o') */ + model: string + /** Source of the chat invocation — always 'server' for server-side chat */ + source: 'client' | 'server' + /** Whether the chat is streaming */ + streaming: boolean + + // --- Config-derived info (may update per-iteration via onConfig) --- + + /** System prompts configured for this chat */ + systemPrompts: Array + /** Names of configured tools, if any */ + toolNames?: Array + /** Flattened generation options (temperature, topP, maxTokens, metadata) */ + options?: Record + /** Provider-specific model options */ + modelOptions?: Record + + // --- Computed info --- + + /** Number of messages at the start of the request */ + messageCount: number + /** Whether tools are configured */ + hasTools: boolean + + // --- Mutable per-iteration state --- + + /** Current assistant message ID (changes per iteration) */ + currentMessageId: string | null + /** Accumulated text content for the current iteration */ + accumulatedContent: string + + // --- References --- + + /** Current messages array (read-only view) */ + messages: ReadonlyArray + /** Generate a unique ID with the given prefix */ + createId: (prefix: string) => string } // =========================== @@ -126,6 +169,53 @@ export interface AfterToolCallInfo { error?: unknown } +// =========================== +// Iteration Info +// =========================== + +/** + * Information passed to onIteration at the start of each agent loop iteration. 
+ */ +export interface IterationInfo { + /** 0-based iteration index */ + iteration: number + /** The assistant message ID created for this iteration */ + messageId: string +} + +// =========================== +// Tool Phase Complete Info +// =========================== + +/** + * Aggregate information passed to onToolPhaseComplete after all tool calls + * in an iteration have been processed. + */ +export interface ToolPhaseCompleteInfo { + /** Tool calls that were assigned to the assistant message */ + toolCalls: Array + /** Completed tool results */ + results: Array<{ + toolCallId: string + toolName: string + result: unknown + duration?: number + }> + /** Tools that need user approval */ + needsApproval: Array<{ + toolCallId: string + toolName: string + input: unknown + approvalId: string + }> + /** Tools that need client-side execution */ + needsClientExecution: Array<{ + toolCallId: string + toolName: string + input: unknown + }> +} + // =========================== // Usage Info // =========================== @@ -240,6 +330,15 @@ export interface ChatMiddleware { */ onStart?: (ctx: ChatMiddlewareContext) => void | Promise + /** + * Called at the start of each agent loop iteration, after a new assistant message ID + * is created. Use this to observe iteration boundaries. + */ + onIteration?: ( + ctx: ChatMiddlewareContext, + info: IterationInfo, + ) => void | Promise + /** * Called for every chunk yielded by chat(). * Can observe, transform, expand, or drop chunks. @@ -273,6 +372,15 @@ export interface ChatMiddleware { info: AfterToolCallInfo, ) => void | Promise + /** + * Called after all tool calls in an iteration have been processed. + * Provides aggregate data about tool execution results, approvals, and client tools. + */ + onToolPhaseComplete?: ( + ctx: ChatMiddlewareContext, + info: ToolPhaseCompleteInfo, + ) => void | Promise + /** * Called when usage data is available from a RUN_FINISHED chunk. * Called once per model iteration that reports usage. 
diff --git a/packages/typescript/ai/src/index.ts b/packages/typescript/ai/src/index.ts index d44ed4fec..8fe3456fa 100644 --- a/packages/typescript/ai/src/index.ts +++ b/packages/typescript/ai/src/index.ts @@ -79,6 +79,8 @@ export type { ToolCallHookContext, BeforeToolCallDecision, AfterToolCallInfo, + IterationInfo, + ToolPhaseCompleteInfo, UsageInfo, FinishInfo, AbortInfo, @@ -88,7 +90,10 @@ export type { ToolCacheEntry, } from './activities/chat/middleware/index' -export { toolCacheMiddleware } from './activities/chat/middleware/index' +export { + devtoolsMiddleware, + toolCacheMiddleware, +} from './activities/chat/middleware/index' // All types export * from './types' From a6ece423904d1bf6de80bf042ed73d0fee8eb600 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Fri, 27 Feb 2026 11:09:27 +0000 Subject: [PATCH 05/16] ci: apply automated fixes --- packages/typescript/ai/src/activities/chat/index.ts | 10 ++-------- .../chat/middleware/devtools-middleware.ts | 7 ++----- packages/typescript/ai/tests/chat.test.ts | 12 +++++------- 3 files changed, 9 insertions(+), 20 deletions(-) diff --git a/packages/typescript/ai/src/activities/chat/index.ts b/packages/typescript/ai/src/activities/chat/index.ts index e8f78f9e4..4c9ddb28d 100644 --- a/packages/typescript/ai/src/activities/chat/index.ts +++ b/packages/typescript/ai/src/activities/chat/index.ts @@ -301,10 +301,7 @@ class TextEngine< this.effectiveSignal = config.params.abortController?.signal // Initialize middleware — devtools middleware is always first - const allMiddleware = [ - devtoolsMiddleware(), - ...(config.middleware || []), - ] + const allMiddleware = [devtoolsMiddleware(), ...(config.middleware || [])] this.middlewareRunner = new MiddlewareRunner(allMiddleware) this.middlewareAbortController = new AbortController() this.middlewareCtx = { @@ -550,10 +547,7 @@ class TextEngine< // Handle usage via middleware if (chunk.type === 'RUN_FINISHED' && 
chunk.usage) { - await this.middlewareRunner.runOnUsage( - this.middlewareCtx, - chunk.usage, - ) + await this.middlewareRunner.runOnUsage(this.middlewareCtx, chunk.usage) } if (this.earlyTermination) { diff --git a/packages/typescript/ai/src/activities/chat/middleware/devtools-middleware.ts b/packages/typescript/ai/src/activities/chat/middleware/devtools-middleware.ts index 6eb98b371..f6e1e2423 100644 --- a/packages/typescript/ai/src/activities/chat/middleware/devtools-middleware.ts +++ b/packages/typescript/ai/src/activities/chat/middleware/devtools-middleware.ts @@ -18,8 +18,7 @@ function buildEventContext(ctx: ChatMiddlewareContext) { model: ctx.model, clientId: ctx.conversationId, source: ctx.source, - systemPrompts: - ctx.systemPrompts.length > 0 ? ctx.systemPrompts : undefined, + systemPrompts: ctx.systemPrompts.length > 0 ? ctx.systemPrompts : undefined, toolNames: ctx.toolNames, options: ctx.options, modelOptions: ctx.modelOptions, @@ -71,9 +70,7 @@ export function devtoolsMiddleware(): ChatMiddleware { : messages messagesToEmit.forEach((message, index) => { - const messageIndex = ctx.conversationId - ? messages.length - 1 - : index + const messageIndex = ctx.conversationId ? 
messages.length - 1 : index const messageId = ctx.createId('msg') const base = buildEventContext(ctx) const content = getContentString(message.content) diff --git a/packages/typescript/ai/tests/chat.test.ts b/packages/typescript/ai/tests/chat.test.ts index 8ab8816b9..705e79fa0 100644 --- a/packages/typescript/ai/tests/chat.test.ts +++ b/packages/typescript/ai/tests/chat.test.ts @@ -64,7 +64,9 @@ describe('chat()', () => { expect(calls).toHaveLength(1) expect(calls[0]!.messages).toBeDefined() - expect((calls[0]!.messages as Array<{ role: string }>)[0]!.role).toBe('user') + expect((calls[0]!.messages as Array<{ role: string }>)[0]!.role).toBe( + 'user', + ) }) it('should pass systemPrompts to the adapter', async () => { @@ -222,9 +224,7 @@ describe('chat()', () => { // Second call should have tool result in messages const secondCallMessages = calls[1]!.messages as Array<{ role: string }> - const toolResultMsg = secondCallMessages.find( - (m) => m.role === 'tool', - ) + const toolResultMsg = secondCallMessages.find((m) => m.role === 'tool') expect(toolResultMsg).toBeDefined() }) @@ -314,9 +314,7 @@ describe('chat()', () => { // Second adapter call should have both tool results const secondCallMessages = calls[1]!.messages as Array<{ role: string }> - const toolResultMsgs = secondCallMessages.filter( - (m) => m.role === 'tool', - ) + const toolResultMsgs = secondCallMessages.filter((m) => m.role === 'tool') expect(toolResultMsgs).toHaveLength(2) }) }) From e5d40a60acdd7dab3d3610b105c8a51b3051aaf2 Mon Sep 17 00:00:00 2001 From: Alem Tuzlak Date: Mon, 2 Mar 2026 14:01:10 +0100 Subject: [PATCH 06/16] feat: enhance middleware instrumentation and add iteration events - Added new iteration events for tracking the start and completion of agent loop iterations. - Implemented middleware instrumentation to emit events for hook executions, configuration transformations, and chunk transformations. - Updated devtools middleware to handle iteration events and emit relevant data. 
- Expanded event client interfaces to include new iteration and middleware event types. - Introduced styles for iteration timeline and associated UI components. --- .../ts-react-chat/src/routes/api.tanchat.ts | 28 +- .../src/components/ConversationDetails.tsx | 268 ++++--- .../conversation/ConversationHeader.tsx | 149 ++-- .../conversation/ConversationTabs.tsx | 27 +- .../components/conversation/IterationCard.tsx | 614 +++++++++++++++ .../conversation/IterationTimeline.tsx | 392 ++++++++++ .../conversation/MiddlewareEventsSection.tsx | 59 ++ .../src/components/conversation/index.ts | 4 + .../ai-devtools/src/store/ai-context.tsx | 496 ++++++++++++- .../ai-devtools/src/store/ai-store.ts | 9 +- .../ai-devtools/src/styles/use-styles.ts | 702 ++++++++++++++++++ .../src/activities/chat/middleware/compose.ts | 181 ++++- .../chat/middleware/devtools-middleware.ts | 47 +- packages/typescript/ai/src/event-client.ts | 68 ++ 14 files changed, 2807 insertions(+), 237 deletions(-) create mode 100644 packages/typescript/ai-devtools/src/components/conversation/IterationCard.tsx create mode 100644 packages/typescript/ai-devtools/src/components/conversation/IterationTimeline.tsx create mode 100644 packages/typescript/ai-devtools/src/components/conversation/MiddlewareEventsSection.tsx diff --git a/examples/ts-react-chat/src/routes/api.tanchat.ts b/examples/ts-react-chat/src/routes/api.tanchat.ts index 45454c32d..a9e313467 100644 --- a/examples/ts-react-chat/src/routes/api.tanchat.ts +++ b/examples/ts-react-chat/src/routes/api.tanchat.ts @@ -11,7 +11,7 @@ import { anthropicText } from '@tanstack/ai-anthropic' import { geminiText } from '@tanstack/ai-gemini' import { openRouterText } from '@tanstack/ai-openrouter' import { grokText } from '@tanstack/ai-grok' -import type { AnyTextAdapter } from '@tanstack/ai' +import type { AnyTextAdapter, ChatMiddleware } from '@tanstack/ai' import { addToCartToolDef, addToWishListToolDef, @@ -69,6 +69,31 @@ const addToCartToolServer = 
addToCartToolDef.server((args, context) => { } }) +const loggingMiddleware: ChatMiddleware = { + name: 'logging', + onConfig(ctx, config) { + console.log(`[logging] onConfig iteration=${ctx.iteration} model=${ctx.model} tools=${config.tools.length}`) + }, + onStart(ctx) { + console.log(`[logging] onStart requestId=${ctx.requestId}`) + }, + onIteration(ctx, info) { + console.log(`[logging] onIteration iteration=${info.iteration}`) + }, + onBeforeToolCall(ctx, toolCtx) { + console.log(`[logging] onBeforeToolCall tool=${toolCtx.toolName}`) + }, + onAfterToolCall(ctx, info) { + console.log(`[logging] onAfterToolCall tool=${info.toolName} result=${JSON.stringify(info.result).slice(0, 100)}`) + }, + onFinish(ctx, info) { + console.log(`[logging] onFinish reason=${info.finishReason} iterations=${ctx.iteration}`) + }, + onUsage(ctx, usage) { + console.log(`[logging] onUsage tokens=${usage.totalTokens} input=${usage.promptTokens} output=${usage.completionTokens}, total: ${usage.totalTokens}`) + }, +} + export const Route = createFileRoute('/api/tanchat')({ server: { handlers: { @@ -159,6 +184,7 @@ export const Route = createFileRoute('/api/tanchat')({ addToWishListToolDef, getPersonalGuitarPreferenceToolDef, ], + middleware: [loggingMiddleware], systemPrompts: [SYSTEM_PROMPT], agentLoopStrategy: maxIterations(20), messages, diff --git a/packages/typescript/ai-devtools/src/components/ConversationDetails.tsx b/packages/typescript/ai-devtools/src/components/ConversationDetails.tsx index a465d5421..7700ce8cd 100644 --- a/packages/typescript/ai-devtools/src/components/ConversationDetails.tsx +++ b/packages/typescript/ai-devtools/src/components/ConversationDetails.tsx @@ -6,6 +6,7 @@ import { ChunksTab, ConversationHeader, ConversationTabs, + IterationTimeline, MessagesTab, SummariesTab, } from './conversation' @@ -23,69 +24,83 @@ export const ConversationDetails: Component = () => { return state.conversations[state.activeConversationId] } - // Update active tab when conversation 
changes + const hasIterations = () => { + const conv = activeConversation() + return conv && conv.iterations.length > 0 + } + + const hasActivityTabs = () => { + const conv = activeConversation() + if (!conv) return false + return ( + conv.hasSummarize || + conv.hasImage || + conv.hasSpeech || + conv.hasTranscription || + conv.hasVideo + ) + } + + // Update active tab when conversation changes (only for non-iteration views) createEffect(() => { const conv = activeConversation() - if (conv) { - // For server conversations, always use chunks (messages tab is hidden) - if (conv.type === 'server') { - if (conv.chunks.length > 0) { - setActiveTab('chunks') - } else if ( - conv.hasSummarize || - (conv.summaries && conv.summaries.length > 0) - ) { - setActiveTab('summaries') - } else if ( - conv.hasImage || - (conv.imageEvents && conv.imageEvents.length > 0) - ) { - setActiveTab('image') - } else if ( - conv.hasSpeech || - (conv.speechEvents && conv.speechEvents.length > 0) - ) { - setActiveTab('speech') - } else if ( - conv.hasTranscription || - (conv.transcriptionEvents && conv.transcriptionEvents.length > 0) - ) { - setActiveTab('transcription') - } else if ( - conv.hasVideo || - (conv.videoEvents && conv.videoEvents.length > 0) - ) { - setActiveTab('video') - } else { - setActiveTab('chunks') - } + if (!conv) return + + // If iterations exist, the timeline is the primary view — only set tab for activity + if (conv.iterations.length > 0) { + if (conv.hasSummarize || (conv.summaries && conv.summaries.length > 0)) { + setActiveTab('summaries') + } else if ( + conv.hasImage || + (conv.imageEvents && conv.imageEvents.length > 0) + ) { + setActiveTab('image') + } + return + } + + // No iterations — use flat message/chunk view + if (conv.type === 'server') { + if (conv.chunks.length > 0) { + setActiveTab('chunks') + } else if ( + conv.hasSummarize || + (conv.summaries && conv.summaries.length > 0) + ) { + setActiveTab('summaries') + } else if ( + conv.hasImage || + 
(conv.imageEvents && conv.imageEvents.length > 0) + ) { + setActiveTab('image') + } else if ( + conv.hasSpeech || + (conv.speechEvents && conv.speechEvents.length > 0) + ) { + setActiveTab('speech') + } else if ( + conv.hasTranscription || + (conv.transcriptionEvents && conv.transcriptionEvents.length > 0) + ) { + setActiveTab('transcription') + } else if ( + conv.hasVideo || + (conv.videoEvents && conv.videoEvents.length > 0) + ) { + setActiveTab('video') + } else { + setActiveTab('chunks') + } + } else { + if (conv.messages.length > 0) { + setActiveTab('messages') + } else if ( + conv.hasImage || + (conv.imageEvents && conv.imageEvents.length > 0) + ) { + setActiveTab('image') } else { - // For client conversations, default to messages tab - if (conv.messages.length > 0) { - setActiveTab('messages') - } else if ( - conv.hasImage || - (conv.imageEvents && conv.imageEvents.length > 0) - ) { - setActiveTab('image') - } else if ( - conv.hasSpeech || - (conv.speechEvents && conv.speechEvents.length > 0) - ) { - setActiveTab('speech') - } else if ( - conv.hasTranscription || - (conv.transcriptionEvents && conv.transcriptionEvents.length > 0) - ) { - setActiveTab('transcription') - } else if ( - conv.hasVideo || - (conv.videoEvents && conv.videoEvents.length > 0) - ) { - setActiveTab('video') - } else { - setActiveTab('messages') - } + setActiveTab('messages') } } }) @@ -102,46 +117,101 @@ export const ConversationDetails: Component = () => { {(conv) => (
- -
- - - - - - - - - - - - - - - - - - - - +
+ - -
+
+ + + {/* Fallback: flat message/chunk view when no iterations */} + + +
+ + + + + + + + + + + + + + + + + + + + + +
+
+ + {/* Activity tabs shown below iterations when relevant */} + + +
+ + + + + + + + + + + + + + + +
+
)} diff --git a/packages/typescript/ai-devtools/src/components/conversation/ConversationHeader.tsx b/packages/typescript/ai-devtools/src/components/conversation/ConversationHeader.tsx index 24d40a1e3..9a8e8d451 100644 --- a/packages/typescript/ai-devtools/src/components/conversation/ConversationHeader.tsx +++ b/packages/typescript/ai-devtools/src/components/conversation/ConversationHeader.tsx @@ -1,5 +1,4 @@ -import { For, Show } from 'solid-js' -import { JsonTree } from '@tanstack/devtools-ui' +import { Show } from 'solid-js' import { useStyles } from '../../styles/use-styles' import { formatDuration } from '../utils' import type { Component } from 'solid-js' @@ -15,11 +14,41 @@ export const ConversationHeader: Component = ( const styles = useStyles() const conv = () => props.conversation - const toolNames = () => conv().toolNames ?? [] - const options = () => conv().options - const modelOptions = () => conv().modelOptions - const iterationCount = () => conv().iterationCount - const systemPrompts = () => conv().systemPrompts ?? [] + const iterationCount = () => conv().iterationCount ?? conv().iterations.length + const totalDuration = () => { + if (!conv().completedAt) return undefined + return conv().completedAt! 
- conv().startedAt + } + + const totalMessages = () => conv().messages.length + + const totalToolCalls = () => { + let count = 0 + for (const iter of conv().iterations) { + if (iter.finishReason === 'tool_calls') count++ + } + return count + } + + // Sum usage across all iterations + const totalUsage = () => { + if (conv().usage) return conv().usage + if (conv().iterations.length === 0) return undefined + let promptTokens = 0 + let completionTokens = 0 + for (const iter of conv().iterations) { + if (iter.usage) { + promptTokens += iter.usage.promptTokens + completionTokens += iter.usage.completionTokens + } + } + if (promptTokens === 0 && completionTokens === 0) return undefined + return { + promptTokens, + completionTokens, + totalTokens: promptTokens + completionTokens, + } + } return (
@@ -39,107 +68,39 @@ export const ConversationHeader: Component = ( > {conv().status}
- 1}> -
- 🔄 {iterationCount()} iterations -
-
- {conv().model && `Model: ${conv().model}`} - {conv().provider && ` • Provider: ${conv().provider}`} - {conv().completedAt && - ` • Duration: ${formatDuration(conv().completedAt! - conv().startedAt)}`} + {totalDuration() !== undefined && formatDuration(totalDuration())} + 0}> + {totalDuration() !== undefined && ' · '} + {iterationCount()} {iterationCount() === 1 ? 'iteration' : 'iterations'} + + 0}> + {(totalDuration() !== undefined || iterationCount() > 0) && ' · '} + {totalMessages()} {totalMessages() === 1 ? 'message' : 'messages'} + + 0}> + {' · '}{totalToolCalls()} tool {totalToolCalls() === 1 ? 'call' : 'calls'} +
- {/* Tools list - always visible */} - 0}> -
- 🔧 -
- - {(toolName) => ( - - {toolName} - - )} - -
-
-
- {/* Options - always visible in compact form */} - 0}> -
- - ⚙️ Options: - -
- - {([key, value]) => ( - - {key}:{' '} - {typeof value === 'object' - ? JSON.stringify(value) - : String(value)} - - )} - -
-
-
- +
- 🎯 Tokens: + Tokens: - Prompt: {conv().usage?.promptTokens.toLocaleString() || 0} + {totalUsage()?.promptTokens.toLocaleString() || 0} in - + · - Completion: {conv().usage?.completionTokens.toLocaleString() || 0} + {totalUsage()?.completionTokens.toLocaleString() || 0} out - + · - Total: {conv().usage?.totalTokens.toLocaleString() || 0} + {totalUsage()?.totalTokens.toLocaleString() || 0} total
- {/* Model options - collapsible */} - 0}> -
- - 🧪 Model options - -
- } - defaultExpansionDepth={2} - /> -
-
-
- {/* System prompts - collapsible */} - 0}> -
- - 🧩 System prompts ({systemPrompts().length}) - -
- - {(prompt, index) => ( -
-
- #{index() + 1} -
-
- {prompt} -
-
- )} -
-
-
-
) diff --git a/packages/typescript/ai-devtools/src/components/conversation/ConversationTabs.tsx b/packages/typescript/ai-devtools/src/components/conversation/ConversationTabs.tsx index b69206b1a..b6a224ea2 100644 --- a/packages/typescript/ai-devtools/src/components/conversation/ConversationTabs.tsx +++ b/packages/typescript/ai-devtools/src/components/conversation/ConversationTabs.tsx @@ -21,6 +21,7 @@ interface ConversationTabsProps { export const ConversationTabs: Component = (props) => { const styles = useStyles() const conv = () => props.conversation + const hasIterations = () => conv().iterations.length > 0 // Total raw chunks = sum of all chunkCounts const totalRawChunks = () => @@ -79,11 +80,12 @@ export const ConversationTabs: Component = (props) => { previousVideoCount = count }) - // Determine if we should show any chat-related tabs - // For server conversations, don't show messages tab - only chunks + // When iterations exist, only show activity tabs (no messages/chunks) const hasMessages = () => - conv().type === 'client' && conv().messages.length > 0 - const hasChunks = () => conv().chunks.length > 0 || conv().type === 'server' + !hasIterations() && conv().type === 'client' && conv().messages.length > 0 + const hasChunks = () => + !hasIterations() && + (conv().chunks.length > 0 || conv().type === 'server') const hasSummaries = () => conv().hasSummarize || summariesCount() > 0 const hasImage = () => conv().hasImage || imageCount() > 0 const hasSpeech = () => conv().hasSpeech || speechCount() > 0 @@ -111,7 +113,6 @@ export const ConversationTabs: Component = (props) => { return (
- {/* Show messages tab for client conversations or when there are messages */} - {/* Show chunks tab for server conversations or when there are chunks */} - {/* Show summaries tab if there are summarize operations */} @@ -159,7 +158,7 @@ export const ConversationTabs: Component = (props) => { } ${imagePulse() ? styles().conversationDetails.tabButtonPulse : ''}`} onClick={() => props.onTabChange('image')} > - 🖼️ Image ({imageCount()}) + Image ({imageCount()}) @@ -171,7 +170,7 @@ export const ConversationTabs: Component = (props) => { } ${speechPulse() ? styles().conversationDetails.tabButtonPulse : ''}`} onClick={() => props.onTabChange('speech')} > - 🔊 Speech ({speechCount()}) + Speech ({speechCount()}) @@ -187,7 +186,7 @@ export const ConversationTabs: Component = (props) => { }`} onClick={() => props.onTabChange('transcription')} > - 📝 Transcription ({transcriptionCount()}) + Transcription ({transcriptionCount()}) @@ -199,7 +198,7 @@ export const ConversationTabs: Component = (props) => { } ${videoPulse() ? styles().conversationDetails.tabButtonPulse : ''}`} onClick={() => props.onTabChange('video')} > - 🎬 Video ({videoCount()}) + Video ({videoCount()})
diff --git a/packages/typescript/ai-devtools/src/components/conversation/IterationCard.tsx b/packages/typescript/ai-devtools/src/components/conversation/IterationCard.tsx new file mode 100644 index 000000000..6f5af5eb2 --- /dev/null +++ b/packages/typescript/ai-devtools/src/components/conversation/IterationCard.tsx @@ -0,0 +1,614 @@ +import { For, Index, Match, Show, Switch, createMemo, createSignal } from 'solid-js' +import { JsonTree } from '@tanstack/devtools-ui' +import { useStyles } from '../../styles/use-styles' +import { formatDuration } from '../utils' +import { SystemPromptItem } from './IterationTimeline' +import type { Iteration, Message, MiddlewareEvent, ToolCall } from '../../store/ai-store' +import type { Component } from 'solid-js' + +interface IterationCardProps { + iteration: Iteration + previousIteration?: Iteration + messages: Array + index: number + isLast: boolean +} + +// --- Step types --- + +type IterationStep = + | { kind: 'middleware'; event: MiddlewareEvent } + | { kind: 'thinking'; message: Message } + | { kind: 'assistant'; message: Message } + | { kind: 'tool_call'; toolCall: ToolCall; message: Message } + | { kind: 'tool_result'; message: Message } + +// --- Helpers --- + +function getIterationLabel(iter: Iteration, displayIndex: number): string { + if (!iter.completedAt) return `Iteration ${displayIndex} — Generating...` + if (iter.finishReason === 'error') return `Iteration ${displayIndex} — Error` + return `Iteration ${displayIndex}` +} + +/** + * Build steps in insertion order — no timestamp sorting. + * Events are emitted in order by the server, so we respect that order. + */ +function buildSteps( + iter: Iteration, + allMessages: Array, +): Array { + const steps: Array = [] + + // 1. Middleware events come first (they happen before/during generation) + for (const event of iter.middlewareEvents) { + steps.push({ kind: 'middleware', event }) + } + + // 2. 
Messages in their natural order from the store + const iterMessages = allMessages.filter( + (m) => iter.messageIds.includes(m.id) && m.role !== 'user', + ) + + for (const msg of iterMessages) { + if (msg.role === 'assistant') { + // Show thinking/reasoning as its own step before text content + if (msg.thinkingContent) { + steps.push({ kind: 'thinking', message: msg }) + } + if (msg.toolCalls && msg.toolCalls.length > 0) { + for (const tc of msg.toolCalls) { + steps.push({ kind: 'tool_call', toolCall: tc, message: msg }) + } + } + if (msg.content) { + steps.push({ kind: 'assistant', message: msg }) + } + } else if (msg.role === 'tool') { + steps.push({ kind: 'tool_result', message: msg }) + } + } + + return steps +} + +function truncate(str: string, max: number): string { + if (str.length <= max) return str + return str.slice(0, max) + '...' +} + +function tryParseJson(str: string): unknown | null { + try { + return JSON.parse(str) + } catch { + return null + } +} + +// --- Step renderers --- + +const MiddlewareStep: Component<{ step: Extract }> = (props) => { + const styles = useStyles() + const s = () => styles().iterationTimeline + const [expanded, setExpanded] = createSignal(false) + const ev = () => props.step.event + + const badgeClass = () => { + if (ev().wasDropped) return s().mwBadgeError + if (ev().hasTransform) return s().mwBadgeTransform + return s().mwBadgeDefault + } + + const suffix = () => { + if (ev().wasDropped) return 'DROP' + if (ev().hookName === 'onChunk' && ev().hasTransform) return 'TRANSFORM' + if (ev().hookName === 'onConfig' && ev().hasTransform) return 'TRANSFORM' + if (ev().hookName === 'onBeforeToolCall' && ev().hasTransform) return 'DECISION' + return null + } + + const hasChanges = () => ev().configChanges && Object.keys(ev().configChanges!).length > 0 + + return ( + <> +
+ Middleware + {ev().middlewareName} + {ev().hookName} + + {ev().duration}ms + + + {suffix()} + + + setExpanded(!expanded())}> + {expanded() ? 'hide changes' : 'show changes'} + + +
+ +
+ } + defaultExpansionDepth={2} + copyable + /> +
+
+ + ) +} + +const ThinkingStep: Component<{ step: Extract }> = (props) => { + const styles = useStyles() + const s = () => styles().iterationTimeline + const [expanded, setExpanded] = createSignal(false) + const msg = () => props.step.message + + const thinkingText = () => msg().thinkingContent || '' + const preview = () => truncate(thinkingText(), 150) + + return ( + <> +
+ Thinking + {preview() || '(empty)'} + 150}> + setExpanded(!expanded())}> + {expanded() ? 'hide' : 'show full'} + + +
+ +
{thinkingText()}
+
+ + ) +} + +const AssistantStep: Component<{ step: Extract }> = (props) => { + const styles = useStyles() + const s = () => styles().iterationTimeline + const [expanded, setExpanded] = createSignal(false) + const msg = () => props.step.message + + const contentLength = () => (msg().content || '').length + const isLong = () => contentLength() > 200 + + const preview = () => { + const text = msg().content || '' + if (text.length <= 200) return text + return truncate(text, 200) + } + + return ( + <> +
+ Response + + {preview() || '(empty)'} + + 200}> + setExpanded(!expanded())}> + {expanded() ? 'hide' : 'show full'} + + +
+ +
{msg().content}
+
+ + ) +} + +const ToolCallStep: Component<{ step: Extract }> = (props) => { + const styles = useStyles() + const s = () => styles().iterationTimeline + const [argsOpen, setArgsOpen] = createSignal(false) + const [resultOpen, setResultOpen] = createSignal(false) + const tc = () => props.step.toolCall + + const parsedArgs = () => { + const raw = tc().arguments || '{}' + return tryParseJson(raw) || raw + } + + const hasResult = () => tc().result !== undefined + + const parsedResult = () => { + if (!hasResult()) return null + if (typeof tc().result === 'string') { + return tryParseJson(tc().result as string) || tc().result + } + return tc().result + } + + return ( + <> +
setArgsOpen(!argsOpen())} style={{ cursor: 'pointer' }}> + Tool Call + {tc().name} + + {tc().duration}ms + + {'\u25B6'} +
+
+
+
+ } + defaultExpansionDepth={0} + copyable + /> +
+
+
+ +
setResultOpen(!resultOpen())} style={{ cursor: 'pointer' }}> + Result + {tc().name} + + {tc().duration}ms + + {'\u25B6'} +
+
+
+
+ } + defaultExpansionDepth={0} + copyable + /> +
+
+
+
+ + ) +} + +const ToolResultStep: Component<{ step: Extract }> = (props) => { + const styles = useStyles() + const s = () => styles().iterationTimeline + const [isOpen, setIsOpen] = createSignal(false) + const msg = () => props.step.message + + const parsedContent = () => { + const text = msg().content || '' + return tryParseJson(text) || text + } + + return ( + <> +
setIsOpen(!isOpen())} style={{ cursor: 'pointer' }}> + Result + {'\u25B6'} +
+
+
+
+ } + defaultExpansionDepth={0} + copyable + /> +
+
+
+ + ) +} + +// --- Main component --- + +export const IterationCard: Component = (props) => { + const styles = useStyles() + const s = () => styles().iterationTimeline + const [isOpen, setIsOpen] = createSignal(props.isLast) + const [configExpanded, setConfigExpanded] = createSignal(false) + + const iter = () => props.iteration + const isActive = () => !iter().completedAt + const isCompleted = () => !!iter().completedAt && iter().finishReason !== 'error' + const isError = () => iter().finishReason === 'error' + + const duration = () => { + if (!iter().completedAt) return undefined + return iter().completedAt! - iter().startedAt + } + + const label = () => getIterationLabel(iter(), props.index) + const steps = createMemo(() => buildSteps(iter(), props.messages)) + + /** + * Compute delta usage for display. + * The store holds cumulative usage per iteration (as reported by the provider). + * If a previous iteration exists on the same request, subtract its cumulative + * to get this iteration's incremental usage. 
+ */ + const deltaUsage = createMemo(() => { + const usage = iter().usage + if (!usage) return undefined + const prev = props.previousIteration + if (!prev?.usage) return usage + // Only subtract if same request (cumulative values are per-request) + if (prev.requestId !== iter().requestId) return usage + return { + promptTokens: Math.max(0, usage.promptTokens - prev.usage.promptTokens), + completionTokens: Math.max(0, usage.completionTokens - prev.usage.completionTokens), + totalTokens: Math.max(0, usage.totalTokens - prev.usage.totalTokens), + } + }) + + const finishLabel = () => { + if (!iter().finishReason) return null + switch (iter().finishReason) { + case 'stop': return 'completed' + case 'tool_calls': return 'tool calls' + case 'error': return 'error' + case 'length': return 'max length' + default: return iter().finishReason + } + } + + const headerAccent = () => { + if (isActive()) return s().iterHeaderActive + if (isError()) return s().iterHeaderError + if (isCompleted()) return s().iterHeaderCompleted + return '' + } + + // Config data from this iteration + const configSubtitle = () => { + const parts: Array = [] + if (iter().model) parts.push(iter().model!) + if (iter().provider) parts.push(iter().provider!) + return parts.length > 0 ? 
parts.join(' \u00B7 ') : null + } + + const toolNames = () => iter().toolNames || [] + const systemPrompts = () => iter().systemPrompts || [] + + /** Count actual tool invocations in this iteration's messages */ + const toolInvocationCounts = createMemo(() => { + const counts = new Map() + const msgIds = new Set(iter().messageIds) + for (const msg of props.messages) { + if (msgIds.has(msg.id) && msg.toolCalls) { + for (const tc of msg.toolCalls) { + counts.set(tc.name, (counts.get(tc.name) || 0) + 1) + } + } + } + return counts + }) + + const totalToolCalls = createMemo(() => { + let count = 0 + for (const v of toolInvocationCounts().values()) count += v + return count + }) + const modelOptions = () => iter().modelOptions + const hasModelOptions = () => { + const opts = modelOptions() + return opts && Object.keys(opts).length > 0 + } + + const middlewareTransformCount = createMemo(() => { + let count = 0 + for (const ev of iter().middlewareEvents) { + if (ev.hasTransform) count++ + } + return count + }) + + const hasConfigChanged = () => { + const prev = props.previousIteration + if (!prev) return false + return iter().model !== prev.model || + iter().provider !== prev.provider || + JSON.stringify(iter().toolNames) !== JSON.stringify(prev.toolNames) + } + + const configDiffs = () => { + const prev = props.previousIteration + if (!prev) return [] + const diffs: Array<{ key: string; from: string; to: string }> = [] + if (iter().model !== prev.model) { + diffs.push({ key: 'model', from: prev.model || '(none)', to: iter().model || '(none)' }) + } + if (iter().provider !== prev.provider) { + diffs.push({ key: 'provider', from: prev.provider || '(none)', to: iter().provider || '(none)' }) + } + if (JSON.stringify(iter().toolNames) !== JSON.stringify(prev.toolNames)) { + diffs.push({ + key: 'tools', + from: prev.toolNames?.join(', ') || '(none)', + to: iter().toolNames?.join(', ') || '(none)', + }) + } + return diffs + } + + const hasExpandableConfig = () => + 
toolNames().length > 0 || systemPrompts().length > 0 || hasModelOptions() + + return ( +
+ {/* Header */} +
setIsOpen(!isOpen())}> +
+ {label()} + {/* Config subtitle — same pattern as user message card */} +
+ + {configSubtitle()} + + 0}> + + {toolNames().length} tool{toolNames().length === 1 ? '' : 's'} + + + 0}> + + {systemPrompts().length} system prompt{systemPrompts().length === 1 ? '' : 's'} + + + + options + + + config changed + + 0}> + + {middlewareTransformCount()} middleware transform{middlewareTransformCount() === 1 ? '' : 's'} + + + + { e.stopPropagation(); setConfigExpanded(!configExpanded()) }} + > + {configExpanded() ? 'hide config' : 'show config'} + + +
+
+
+ + + {finishLabel()} + + + 0}> + + 🔧 {totalToolCalls()} + + + + + ⏱️ {formatDuration(duration())} + + + + + 🎯 {deltaUsage()!.totalTokens.toLocaleString()} + + + + ⟳ streaming + +
+ + {'\u25B6'} + +
+ + {/* Expandable config panel — between header and steps */} +
+
+
+ +
+ Config Changes +
+ + {(diff) => ( +
+ {diff.key}: + {diff.from} + {'\u2192'} + {diff.to} +
+ )} +
+
+
+
+ 0}> +
+ Tools +
+ + {(name) => ( + + {name} + + {toolInvocationCounts().get(name) || 0} + + + )} + +
+
+
+ 0}> +
+ + System Prompts ({systemPrompts().length}) + + + {(prompt, i) => ( + + )} + +
+
+ +
+ Model Options +
+ } + defaultExpansionDepth={2} + copyable + /> +
+
+
+
+
+
+ + {/* Body — step-by-step timeline */} +
+
+ + {(step) => ( + + + } /> + + + } /> + + + } /> + + + } /> + + + } /> + + + )} + +
+
+
+ ) +} diff --git a/packages/typescript/ai-devtools/src/components/conversation/IterationTimeline.tsx b/packages/typescript/ai-devtools/src/components/conversation/IterationTimeline.tsx new file mode 100644 index 000000000..c7746e026 --- /dev/null +++ b/packages/typescript/ai-devtools/src/components/conversation/IterationTimeline.tsx @@ -0,0 +1,392 @@ +import { For, Index, Show, createMemo, createSignal } from 'solid-js' +import { JsonTree } from '@tanstack/devtools-ui' +import { useStyles } from '../../styles/use-styles' +import { formatDuration } from '../utils' +import { IterationCard } from './IterationCard' +import type { Iteration, Message } from '../../store/ai-store' +import type { Component } from 'solid-js' + +/** A group of iterations triggered by a single user message */ +interface UserMessageGroup { + userMessage: Message | null + iterations: Array +} + +interface IterationTimelineProps { + iterations: Array + messages: Array +} + +export const IterationTimeline: Component = (props) => { + const styles = useStyles() + const s = () => styles().iterationTimeline + + /** + * Group iterations by user messages. + * Memoized to avoid recomputing on unrelated store changes. + */ + const groups = createMemo((): Array => { + const userMessages = props.messages.filter((m) => m.role === 'user') + const iters = props.iterations + + if (userMessages.length === 0) { + return iters.length > 0 ? [{ userMessage: null, iterations: iters }] : [] + } + + const result: Array = [] + + const sortedUsers = [...userMessages].sort( + (a, b) => a.timestamp - b.timestamp, + ) + + for (let u = 0; u < sortedUsers.length; u++) { + const currentUser = sortedUsers[u]! 
+ const nextUser = sortedUsers[u + 1] + + const groupIters = iters.filter((it) => { + if (it.startedAt < currentUser.timestamp) return false + if (nextUser && it.startedAt >= nextUser.timestamp) return false + return true + }) + + if (groupIters.length > 0) { + result.push({ userMessage: currentUser, iterations: groupIters }) + } + } + + // Catch any iterations before the first user message + if (sortedUsers[0]) { + const earlyIters = iters.filter( + (it) => it.startedAt < sortedUsers[0]!.timestamp, + ) + if (earlyIters.length > 0) { + result.unshift({ userMessage: null, iterations: earlyIters }) + } + } + + return result + }) + + return ( +
+ 0} + fallback={
No iterations recorded
} + > +
+ + {(group) => ( + + )} + +
+
+
+ ) +} + +/** Collapsible system prompt with preview */ +export const SystemPromptItem: Component<{ prompt: string; index: number }> = (props) => { + const styles = useStyles() + const s = () => styles().iterationTimeline + const [expanded, setExpanded] = createSignal(false) + + const isLong = () => props.prompt.length > 120 + const preview = () => + isLong() ? props.prompt.slice(0, 120) + '...' : props.prompt + + return ( +
+
isLong() && setExpanded(!expanded())} + > + #{props.index + 1} + + {expanded() ? '' : preview()} + + + + {expanded() ? 'collapse' : 'expand'} + + +
+ +
{props.prompt}
+
+
+ ) +} + +/** Card wrapping a user message and its child iterations */ +const UserMessageGroupCard: Component<{ + group: UserMessageGroup + allMessages: Array +}> = (props) => { + const styles = useStyles() + const s = () => styles().iterationTimeline + const [isOpen, setIsOpen] = createSignal(true) + const [configExpanded, setConfigExpanded] = createSignal(false) + + const group = () => props.group + const userMsg = () => group().userMessage + const iters = () => group().iterations + + const totalDuration = createMemo(() => { + let sum = 0 + for (const it of iters()) { + if (it.completedAt) { + sum += it.completedAt - it.startedAt + } + } + return sum > 0 ? sum : undefined + }) + + /** Count actual tool invocations across all messages in this group */ + const toolInvocationCounts = createMemo(() => { + const counts = new Map() + const allMsgIds = new Set() + for (const it of iters()) { + for (const id of it.messageIds) allMsgIds.add(id) + } + for (const msg of props.allMessages) { + if (allMsgIds.has(msg.id) && msg.toolCalls) { + for (const tc of msg.toolCalls) { + counts.set(tc.name, (counts.get(tc.name) || 0) + 1) + } + } + } + return counts + }) + + const totalToolCalls = createMemo(() => { + let count = 0 + for (const v of toolInvocationCounts().values()) count += v + return count + }) + + /** + * Total usage across this group. + * Iterations store CUMULATIVE usage per request, so we take the MAX + * cumulative from each request group (the last iteration's value + * represents the total for that request). 
+ */ + const totalUsage = createMemo(() => { + const maxByRequest = new Map() + for (const it of iters()) { + if (!it.usage) continue + const key = it.requestId || '__default__' + const existing = maxByRequest.get(key) + if (!existing || it.usage.totalTokens > existing.prompt + existing.completion) { + maxByRequest.set(key, { + prompt: it.usage.promptTokens, + completion: it.usage.completionTokens, + }) + } + } + let prompt = 0 + let completion = 0 + for (const v of maxByRequest.values()) { + prompt += v.prompt + completion += v.completion + } + if (prompt + completion === 0) return undefined + return { promptTokens: prompt, completionTokens: completion, totalTokens: prompt + completion } + }) + + const isActive = () => iters().some((it) => !it.completedAt) + const hasError = () => iters().some((it) => it.finishReason === 'error') + const allCompleted = () => iters().every((it) => !!it.completedAt) && !hasError() + + const groupAccentClass = () => { + if (isActive()) return s().cardActive + if (hasError()) return s().cardError + if (allCompleted()) return s().cardCompleted + return '' + } + + const userContent = () => { + const msg = userMsg() + if (!msg) return '(no user message)' + const text = msg.content || '' + return text.length > 120 ? text.slice(0, 120) + '...' : text + } + + // Config from the first iteration of this group + const firstIter = createMemo(() => iters()[0]) + + const configSubtitle = () => { + const first = firstIter() + if (!first) return null + const parts: Array = [] + if (first.model) parts.push(first.model) + if (first.provider) parts.push(first.provider) + return parts.length > 0 ? 
parts.join(' \u00B7 ') : null + } + + const toolNames = () => firstIter()?.toolNames || [] + const systemPrompts = () => firstIter()?.systemPrompts || [] + const modelOptions = () => firstIter()?.modelOptions + const hasModelOptions = () => { + const opts = modelOptions() + return opts && Object.keys(opts).length > 0 + } + + const middlewareTransformCount = createMemo(() => { + let count = 0 + for (const it of iters()) { + for (const ev of it.middlewareEvents) { + if (ev.hasTransform) count++ + } + } + return count + }) + + const hasExpandableConfig = () => + toolNames().length > 0 || systemPrompts().length > 0 || hasModelOptions() + + return ( +
+ {/* User message header */} +
setIsOpen(!isOpen())}> +
U
+
+ {userContent()} + {/* Config subtitle — always visible under user message */} +
+ + {configSubtitle()} + + 0}> + + {toolNames().length} tool{toolNames().length === 1 ? '' : 's'} + + + 0}> + + {systemPrompts().length} system prompt{systemPrompts().length === 1 ? '' : 's'} + + + + options + + 0}> + + {middlewareTransformCount()} middleware transform{middlewareTransformCount() === 1 ? '' : 's'} + + + + { e.stopPropagation(); setConfigExpanded(!configExpanded()) }} + > + {configExpanded() ? 'hide config' : 'show config'} + + +
+
+
+ 0}> + + 🔄 {iters().length} + + + 0}> + + 🔧 {totalToolCalls()} + + + + + ⏱️ {formatDuration(totalDuration())} + + + + + 🎯 {totalUsage()!.totalTokens.toLocaleString()} + + + + ⟳ streaming + +
+ + {'\u25B6'} + +
+ + {/* Expandable config details — sits between header and iterations */} +
+
+
+ 0}> +
+ Tools +
+ + {(name) => ( + + {name} + + {toolInvocationCounts().get(name) || 0} + + + )} + +
+
+
+ 0}> +
+ + System Prompts ({systemPrompts().length}) + + + {(prompt, i) => ( + + )} + +
+
+ +
+ Model Options +
+ } + defaultExpansionDepth={2} + copyable + /> +
+
+
+
+
+
+ + {/* Iterations list — full width */} +
+
+
+ + {(iteration, index) => ( + 0 ? iters()[index() - 1] : undefined + } + messages={props.allMessages} + index={index()} + isLast={index() === iters().length - 1} + /> + )} + +
+
+
+
+ ) +} diff --git a/packages/typescript/ai-devtools/src/components/conversation/MiddlewareEventsSection.tsx b/packages/typescript/ai-devtools/src/components/conversation/MiddlewareEventsSection.tsx new file mode 100644 index 000000000..ebaf61ca9 --- /dev/null +++ b/packages/typescript/ai-devtools/src/components/conversation/MiddlewareEventsSection.tsx @@ -0,0 +1,59 @@ +import { For, Show } from 'solid-js' +import { useStyles } from '../../styles/use-styles' +import type { MiddlewareEvent } from '../../store/ai-store' +import type { Component } from 'solid-js' + +interface MiddlewareEventsSectionProps { + events: Array +} + +export const MiddlewareEventsSection: Component< + MiddlewareEventsSectionProps +> = (props) => { + const styles = useStyles() + const s = () => styles().iterationTimeline + + const getSuffix = (event: MiddlewareEvent): string | null => { + if (event.wasDropped) return 'DROP' + if (event.hookName === 'onChunk' && event.hasTransform) return 'TRANSFORM' + if (event.hookName === 'onConfig' && event.hasTransform) return 'TRANSFORM' + if ( + event.hookName === 'onBeforeToolCall' && + event.hasTransform + ) + return 'DECISION' + return null + } + + return ( + 0}> +
+ + {(event) => { + const isTransform = () => event.hasTransform + const suffix = () => getSuffix(event) + const duration = () => + event.duration !== undefined ? `${event.duration}ms` : null + + return ( + + {event.middlewareName} + · + {event.hookName} + + · + {duration()} + + + {suffix()} + + + ) + }} + +
+
+ ) +} diff --git a/packages/typescript/ai-devtools/src/components/conversation/index.ts b/packages/typescript/ai-devtools/src/components/conversation/index.ts index 2a0c91770..2ee2079c3 100644 --- a/packages/typescript/ai-devtools/src/components/conversation/index.ts +++ b/packages/typescript/ai-devtools/src/components/conversation/index.ts @@ -6,3 +6,7 @@ export { MessagesTab } from './MessagesTab' export { ChunksTab } from './ChunksTab' export { SummariesTab } from './SummariesTab' export { ActivityEventsTab } from './ActivityEventsTab' + +export { IterationTimeline } from './IterationTimeline' +export { IterationCard } from './IterationCard' +export { MiddlewareEventsSection } from './MiddlewareEventsSection' diff --git a/packages/typescript/ai-devtools/src/store/ai-context.tsx b/packages/typescript/ai-devtools/src/store/ai-context.tsx index 4b852da15..1846a2d28 100644 --- a/packages/typescript/ai-devtools/src/store/ai-context.tsx +++ b/packages/typescript/ai-devtools/src/store/ai-context.tsx @@ -60,6 +60,8 @@ export interface Message { thinkingContent?: string /** Source of the message: 'client' for aggregated client-side data, 'server' for individual server chunks */ source?: 'client' | 'server' + /** The requestId this message belongs to (for scoping usage calculations) */ + requestId?: string } /** @@ -100,6 +102,38 @@ export interface Chunk { isClientTool?: boolean } +export interface MiddlewareEvent { + id: string + middlewareName: string + hookName: string + timestamp: number + duration?: number + hasTransform: boolean + configChanges?: Record + originalChunkType?: string + resultCount?: number + wasDropped?: boolean +} + +export interface Iteration { + /** The requestId this iteration belongs to (unique per chat() call) */ + requestId?: string + index: number + messageId: string + startedAt: number + completedAt?: number + model?: string + provider?: string + systemPrompts?: Array + toolNames?: Array + options?: Record + modelOptions?: Record + 
finishReason?: string + usage?: TokenUsage + middlewareEvents: Array + messageIds: Array +} + export interface SummarizeOperation { id: string model: string @@ -129,6 +163,7 @@ export interface Conversation { startedAt: number completedAt?: number usage?: TokenUsage + iterations: Array iterationCount?: number toolNames?: Array options?: Record @@ -178,6 +213,8 @@ export const AIProvider: ParentComponent = (props) => { const streamToConversation = new Map() const requestToConversation = new Map() + /** Track max cumulative usage per requestId per conversation for correct totals */ + const requestUsageByConversation = new Map>() // Batching system for high-frequency chunk updates with consolidated chunk merging // Stores: conversationId -> { chunks to merge, totalNewChunks count } @@ -366,6 +403,7 @@ export const AIProvider: ParentComponent = (props) => { label, messages: [], chunks: [], + iterations: [], imageEvents: [], speechEvents: [], transcriptionEvents: [], @@ -441,6 +479,7 @@ export const AIProvider: ParentComponent = (props) => { conversationId: string, messageId: string | undefined, cumulativeUsage: TokenUsage, + requestId?: string, ): void { const conv = state.conversations[conversationId] if (!conv) return @@ -467,10 +506,12 @@ export const AIProvider: ParentComponent = (props) => { if (targetMessageIndex === -1) return - // Sum up usage from all previous assistant messages + // Sum up usage from previous assistant messages in the SAME request only. + // Cumulative usage is per-request, so mixing requests gives wrong deltas. 
for (let i = 0; i < targetMessageIndex; i++) { const msg = conv.messages[i] if (msg?.role === 'assistant' && msg.usage) { + if (requestId && msg.requestId !== requestId) continue previousPromptTokens += msg.usage.promptTokens previousCompletionTokens += msg.usage.completionTokens } @@ -501,12 +542,46 @@ export const AIProvider: ParentComponent = (props) => { ) } + /** + * Update conversation-level usage by tracking max cumulative per request. + * Usage events report cumulative totals per-request, so we keep the highest + * value seen for each requestId and sum across all requests for the total. + */ + function updateConversationUsage( + conversationId: string, + requestId: string | undefined, + usage: TokenUsage, + ): void { + if (!state.conversations[conversationId]) return + const key = requestId || '__default__' + let requestMap = requestUsageByConversation.get(conversationId) + if (!requestMap) { + requestMap = new Map() + requestUsageByConversation.set(conversationId, requestMap) + } + const existing = requestMap.get(key) + if (!existing || usage.totalTokens > existing.totalTokens) { + requestMap.set(key, usage) + } + // Sum across all requests + let prompt = 0 + let completion = 0 + for (const v of requestMap.values()) { + prompt += v.promptTokens + completion += v.completionTokens + } + updateConversation(conversationId, { + usage: { promptTokens: prompt, completionTokens: completion, totalTokens: prompt + completion }, + }) + } + // Public actions function clearAllConversations() { setState('conversations', {}) setState('activeConversationId', null) streamToConversation.clear() requestToConversation.clear() + requestUsageByConversation.clear() pendingConversationChunks.clear() pendingMessageChunks.clear() } @@ -671,7 +746,7 @@ export const AIProvider: ParentComponent = (props) => { cleanupFns.push( aiEventClient.on('text:message:created', (e) => { - const { clientId, streamId, messageId, role, content, timestamp } = + const { clientId, streamId, messageId, 
role, content, timestamp, requestId } = e.payload const conversationId = clientId || @@ -781,6 +856,7 @@ export const AIProvider: ParentComponent = (props) => { parts, toolCalls, source, + requestId, } if (existingIndex >= 0) { @@ -789,13 +865,45 @@ export const AIProvider: ParentComponent = (props) => { addMessage(conversationId, messagePayload) } + // Track messageId in the correct iteration (scoped by requestId) + if (conv.iterations.length > 0) { + let iterIndex = -1 + if (requestId) { + // Find the latest iteration for this specific request + for (let i = conv.iterations.length - 1; i >= 0; i--) { + if (conv.iterations[i]?.requestId === requestId) { + iterIndex = i + break + } + } + } else { + // Fallback: use latest iteration + iterIndex = conv.iterations.length - 1 + } + if (iterIndex >= 0) { + const iter = conv.iterations[iterIndex] + if (iter && !iter.messageIds.includes(messageId)) { + setState( + 'conversations', + conversationId, + 'iterations', + iterIndex, + 'messageIds', + produce((arr: Array) => { + arr.push(messageId) + }), + ) + } + } + } + updateConversation(conversationId, { status: 'active', hasChat: true }) }), ) cleanupFns.push( aiEventClient.on('text:message:user', (e) => { - const { clientId, streamId, messageId, content, timestamp } = e.payload + const { clientId, streamId, messageId, content, timestamp, requestId } = e.payload const conversationId = clientId || (streamId ? 
streamToConversation.get(streamId) : undefined) @@ -822,6 +930,7 @@ export const AIProvider: ParentComponent = (props) => { content, timestamp, source, + requestId, }) }), ) @@ -1312,17 +1421,6 @@ export const AIProvider: ParentComponent = (props) => { const conv = state.conversations[conversationId] if (conv?.type === 'client') { addChunkToMessage(conversationId, chunk) - - if (e.payload.messageId) { - const messageIndex = conv.messages.findIndex( - (msg) => msg.id === e.payload.messageId, - ) - if (messageIndex !== -1) { - updateMessage(conversationId, messageIndex, { - thinkingContent: e.payload.content, - }) - } - } } else { ensureMessageForChunk( conversationId, @@ -1331,6 +1429,18 @@ export const AIProvider: ParentComponent = (props) => { ) addChunk(conversationId, chunk) } + + // Update thinkingContent on the message for all conversation types + if (e.payload.messageId && conv) { + const messageIndex = conv.messages.findIndex( + (msg) => msg.id === e.payload.messageId, + ) + if (messageIndex !== -1) { + updateMessage(conversationId, messageIndex, { + thinkingContent: e.payload.content, + }) + } + } }), ) @@ -1350,11 +1460,12 @@ export const AIProvider: ParentComponent = (props) => { } if (e.payload.usage) { - updateConversation(conversationId, { usage: e.payload.usage }) + updateConversationUsage(conversationId, e.payload.requestId, e.payload.usage) updateMessageUsage( conversationId, e.payload.messageId, e.payload.usage, + e.payload.requestId, ) } @@ -1370,6 +1481,43 @@ export const AIProvider: ParentComponent = (props) => { addChunk(conversationId, chunk) } + // Mark the current iteration as completed when the LLM finishes generating. + // This is critical for iterations that end with tool_calls — the + // text:iteration:completed event only fires when the NEXT iteration starts, + // so without this the iteration appears stuck in "streaming" during tool execution. 
+ if (e.payload.finishReason) { + const convForIter = state.conversations[conversationId] + if (convForIter) { + for (let i = convForIter.iterations.length - 1; i >= 0; i--) { + const iter = convForIter.iterations[i] + const msgId = e.payload.messageId + if ( + iter && + !iter.completedAt && + msgId && + (iter.messageId === msgId || + iter.messageIds?.includes(msgId)) + ) { + const iterIdx = i + setState( + 'conversations', + conversationId, + 'iterations', + iterIdx, + produce((it: Iteration) => { + it.completedAt = e.payload.timestamp + if (!it.finishReason) + it.finishReason = e.payload.finishReason || undefined + if (e.payload.usage && !it.usage) + it.usage = e.payload.usage + }), + ) + break + } + } + } + } + updateConversation(conversationId, { status: 'completed', completedAt: e.payload.timestamp, @@ -1404,6 +1552,33 @@ export const AIProvider: ParentComponent = (props) => { addChunk(conversationId, chunk) } + // Mark any active iterations as completed with error + const convForError = state.conversations[conversationId] + if (convForError) { + const errorRequestId = e.payload.requestId + const errorMsgId = e.payload.messageId + for (let i = convForError.iterations.length - 1; i >= 0; i--) { + const iter = convForError.iterations[i] + if (iter && !iter.completedAt) { + // Scope to matching requestId or messageId when available + const matchesRequest = !errorRequestId || iter.requestId === errorRequestId + const matchesMessage = !errorMsgId || iter.messageId === errorMsgId || iter.messageIds?.includes(errorMsgId) + if (matchesRequest || matchesMessage) { + setState( + 'conversations', + conversationId, + 'iterations', + i, + produce((it: Iteration) => { + it.completedAt = e.payload.timestamp + if (!it.finishReason) it.finishReason = 'error' + }), + ) + } + } + } + } + updateConversation(conversationId, { status: 'error', completedAt: e.payload.timestamp, @@ -1541,16 +1716,50 @@ export const AIProvider: ParentComponent = (props) => { const conversationId = 
requestToConversation.get(requestId) if (conversationId && state.conversations[conversationId]) { - const updates: Partial = { + updateConversation(conversationId, { status: 'completed', completedAt: e.payload.timestamp, - } + }) if (usage) { - updates.usage = usage + updateConversationUsage(conversationId, requestId, usage) + updateMessageUsage( + conversationId, + e.payload.messageId, + usage, + requestId, + ) } - updateConversation(conversationId, updates) - if (usage) { - updateMessageUsage(conversationId, e.payload.messageId, usage) + + // Failsafe: mark any remaining active iterations FOR THIS REQUEST as completed. + // Only scope to this requestId to avoid touching other requests' iterations. + const conv = state.conversations[conversationId] + if (conv) { + for (let i = 0; i < conv.iterations.length; i++) { + const iter = conv.iterations[i] + if ( + iter && + !iter.completedAt && + (!requestId || iter.requestId === requestId) + ) { + const iterIdx = i + setState( + 'conversations', + conversationId, + 'iterations', + iterIdx, + produce((it: Iteration) => { + it.completedAt = e.payload.timestamp + if (!it.finishReason) { + it.finishReason = + e.payload.finishReason || 'stop' + } + if (!it.usage && usage) { + it.usage = usage + } + }), + ) + } + } } } }), @@ -1562,9 +1771,250 @@ export const AIProvider: ParentComponent = (props) => { const conversationId = requestToConversation.get(requestId) if (conversationId && state.conversations[conversationId]) { - updateConversation(conversationId, { usage }) - updateMessageUsage(conversationId, messageId, usage) + updateConversationUsage(conversationId, requestId, usage) + updateMessageUsage(conversationId, messageId, usage, requestId) + } + }), + ) + + // ============= Iteration Events ============= + + cleanupFns.push( + aiEventClient.on('text:iteration:started', (e) => { + const { requestId, streamId, clientId, iteration, messageId } = + e.payload + + const conversationId = + clientId || + (streamId ? 
streamToConversation.get(streamId) : undefined) || + requestToConversation.get(requestId) + if (!conversationId || !state.conversations[conversationId]) return + + // Failsafe: when a new iteration starts, any previous uncompleted + // iterations for the same request must have ended (with tool_calls). + // This covers edge cases where text:chunk:done didn't match by messageId. + const convForFailsafe = state.conversations[conversationId] + if (convForFailsafe) { + for (let i = 0; i < convForFailsafe.iterations.length; i++) { + const iter = convForFailsafe.iterations[i] + if (iter && !iter.completedAt && iter.requestId === requestId) { + setState( + 'conversations', + conversationId, + 'iterations', + i, + produce((it: Iteration) => { + it.completedAt = e.payload.timestamp + if (!it.finishReason) it.finishReason = 'tool_calls' + }), + ) + } + } + } + + // Guard against duplicate iteration events (e.g. middleware registered twice) + const existingConv = state.conversations[conversationId] + if ( + existingConv && + existingConv.iterations.some( + (it) => + it.index === iteration && + it.requestId === requestId, + ) + ) { + return + } + + const newIteration: Iteration = { + requestId, + index: iteration, + messageId, + startedAt: e.payload.timestamp, + model: e.payload.model, + provider: e.payload.provider, + systemPrompts: e.payload.systemPrompts, + toolNames: e.payload.toolNames, + options: e.payload.options, + modelOptions: e.payload.modelOptions, + middlewareEvents: [], + messageIds: [messageId], + } + + setState( + 'conversations', + conversationId, + 'iterations', + produce((arr: Array) => { + arr.push(newIteration) + }), + ) + setState( + 'conversations', + conversationId, + 'iterationCount', + iteration + 1, + ) + }), + ) + + cleanupFns.push( + aiEventClient.on('text:iteration:completed', (e) => { + const { requestId, streamId, clientId, iteration } = e.payload + + const conversationId = + clientId || + (streamId ? 
streamToConversation.get(streamId) : undefined) || + requestToConversation.get(requestId) + if (!conversationId || !state.conversations[conversationId]) return + + const conv = state.conversations[conversationId] + // Find the iteration by BOTH requestId and index to avoid cross-request pollution. + // Without requestId scoping, request 2's iteration 0 would match request 1's iteration 0. + const iterIndex = conv.iterations.findIndex( + (it) => + it.index === iteration && + (!requestId || it.requestId === requestId), + ) + if (iterIndex === -1) return + + setState( + 'conversations', + conversationId, + 'iterations', + iterIndex, + produce((it: Iteration) => { + it.completedAt = e.payload.timestamp + it.finishReason = e.payload.finishReason + if (e.payload.usage) { + it.usage = e.payload.usage + } + }), + ) + }), + ) + + // ============= Middleware Events ============= + + /** Find the latest iteration for a given requestId, or the very latest iteration as fallback */ + function findLatestIterationIndex( + conv: Conversation, + reqId?: string, + ): number { + if (reqId) { + for (let i = conv.iterations.length - 1; i >= 0; i--) { + if (conv.iterations[i]?.requestId === reqId) return i } + } + return conv.iterations.length - 1 + } + + cleanupFns.push( + aiEventClient.on('middleware:hook:executed', (e) => { + const { requestId, streamId, clientId } = e.payload + + const conversationId = + clientId || + (streamId ? 
streamToConversation.get(streamId) : undefined) || + requestToConversation.get(requestId) + if (!conversationId || !state.conversations[conversationId]) return + + const conv = state.conversations[conversationId] + const iterIndex = findLatestIterationIndex(conv, requestId) + if (iterIndex < 0) return + + const mwEvent: MiddlewareEvent = { + id: `mw-${Date.now()}-${Math.random()}`, + middlewareName: e.payload.middlewareName, + hookName: e.payload.hookName, + timestamp: e.payload.timestamp, + duration: e.payload.duration, + hasTransform: e.payload.hasTransform, + } + + setState( + 'conversations', + conversationId, + 'iterations', + iterIndex, + 'middlewareEvents', + produce((arr: Array) => { + arr.push(mwEvent) + }), + ) + }), + ) + + cleanupFns.push( + aiEventClient.on('middleware:config:transformed', (e) => { + const { requestId, streamId, clientId } = e.payload + + const conversationId = + clientId || + (streamId ? streamToConversation.get(streamId) : undefined) || + requestToConversation.get(requestId) + if (!conversationId || !state.conversations[conversationId]) return + + const conv = state.conversations[conversationId] + const iterIndex = findLatestIterationIndex(conv, requestId) + if (iterIndex < 0) return + + const mwEvent: MiddlewareEvent = { + id: `mw-cfg-${Date.now()}-${Math.random()}`, + middlewareName: e.payload.middlewareName, + hookName: 'onConfig', + timestamp: e.payload.timestamp, + hasTransform: true, + configChanges: e.payload.changes, + } + + setState( + 'conversations', + conversationId, + 'iterations', + iterIndex, + 'middlewareEvents', + produce((arr: Array) => { + arr.push(mwEvent) + }), + ) + }), + ) + + cleanupFns.push( + aiEventClient.on('middleware:chunk:transformed', (e) => { + const { requestId, streamId, clientId } = e.payload + + const conversationId = + clientId || + (streamId ? 
streamToConversation.get(streamId) : undefined) || + requestToConversation.get(requestId) + if (!conversationId || !state.conversations[conversationId]) return + + const conv = state.conversations[conversationId] + const iterIndex = findLatestIterationIndex(conv, requestId) + if (iterIndex < 0) return + + const mwEvent: MiddlewareEvent = { + id: `mw-chunk-${Date.now()}-${Math.random()}`, + middlewareName: e.payload.middlewareName, + hookName: 'onChunk', + timestamp: e.payload.timestamp, + hasTransform: true, + originalChunkType: e.payload.originalChunkType, + resultCount: e.payload.resultCount, + wasDropped: e.payload.wasDropped, + } + + setState( + 'conversations', + conversationId, + 'iterations', + iterIndex, + 'middlewareEvents', + produce((arr: Array) => { + arr.push(mwEvent) + }), + ) }), ) diff --git a/packages/typescript/ai-devtools/src/store/ai-store.ts b/packages/typescript/ai-devtools/src/store/ai-store.ts index 600b9ffba..d8f6c3318 100644 --- a/packages/typescript/ai-devtools/src/store/ai-store.ts +++ b/packages/typescript/ai-devtools/src/store/ai-store.ts @@ -1,2 +1,9 @@ // Re-export types from ai-context for backward compatibility -export type { ToolCall, Message, Chunk, Conversation } from './ai-context' +export type { + ToolCall, + Message, + Chunk, + Conversation, + Iteration, + MiddlewareEvent, +} from './ai-context' diff --git a/packages/typescript/ai-devtools/src/styles/use-styles.ts b/packages/typescript/ai-devtools/src/styles/use-styles.ts index ba9a9cac3..6008dc8ac 100644 --- a/packages/typescript/ai-devtools/src/styles/use-styles.ts +++ b/packages/typescript/ai-devtools/src/styles/use-styles.ts @@ -1588,6 +1588,708 @@ const stylesFactory = (theme: 'light' | 'dark') => { font-weight: ${font.weight.semibold}; `, }, + + iterationTimeline: { + container: css` + position: relative; + padding: ${size[3]} ${size[3]}; + overflow-y: auto; + flex: 1; + `, + pipeline: css` + position: relative; + display: flex; + flex-direction: column; + gap: 
${size[3]}; + `, + iterList: css` + display: flex; + flex-direction: column; + gap: 0; + `, + // --- User message group card --- + card: css` + position: relative; + border-radius: ${border.radius.md}; + background: ${t(colors.gray[50], colors.darkGray[700])}; + border: 1px solid ${t(colors.gray[200], colors.darkGray[500])}; + overflow: hidden; + `, + cardCompleted: css` + border-color: ${t(colors.green[200], colors.green[900] + '60')}; + `, + cardError: css` + border-color: ${t(colors.red[300], colors.red[800])}; + `, + cardActive: css` + border-color: ${t(colors.blue[300], colors.blue[700])}; + `, + cardHeader: css` + display: flex; + align-items: flex-start; + gap: ${size[2]}; + padding: ${size[3]} ${size[3]}; + cursor: pointer; + user-select: none; + + &:hover { + background: ${t(colors.gray[100], colors.darkGray[600])}; + } + `, + cardHeaderContent: css` + flex: 1; + min-width: 0; + display: flex; + flex-direction: column; + gap: ${size[1]}; + `, + cardHeaderLabel: css` + font-size: ${fontSize.sm}; + font-weight: ${font.weight.semibold}; + color: ${t(colors.gray[800], colors.gray[200])}; + line-height: 1.3; + `, + cardSubtitle: css` + display: flex; + align-items: center; + gap: ${size[1.5]}; + flex-wrap: wrap; + `, + subtitleText: css` + font-size: 10px; + font-family: ${fontFamily.mono}; + color: ${t(colors.gray[500], colors.gray[400])}; + `, + subtitleBadge: css` + font-size: 9px; + padding: 0 ${size[1]}; + border-radius: ${border.radius.xs}; + background: ${t(colors.gray[100], colors.darkGray[500])}; + color: ${t(colors.gray[500], colors.gray[400])}; + `, + subtitleBadgeWarn: css` + font-size: 9px; + padding: 0 ${size[1]}; + border-radius: ${border.radius.xs}; + background: ${t(colors.purple[50], colors.purple[900] + '30')}; + color: ${t(colors.purple[600], colors.purple[300])}; + `, + subtitleExpandToggle: css` + font-size: 9px; + color: ${t(colors.blue[500], colors.blue[400])}; + cursor: pointer; + text-decoration: underline; + text-decoration-style: 
dotted; + &:hover { + color: ${t(colors.blue[600], colors.blue[300])}; + } + `, + configPanelWrapper: css` + display: grid; + grid-template-rows: 0fr; + transition: grid-template-rows 0.2s ease-out; + `, + configPanelWrapperOpen: css` + grid-template-rows: 1fr; + `, + configPanel: css` + overflow: hidden; + & > div { + padding: ${size[2]} ${size[3]}; + border-top: 1px solid ${t(colors.gray[200], colors.darkGray[500])}; + background: ${t(colors.gray[50], colors.darkGray[700])}; + font-size: ${fontSize.xs}; + } + `, + configPanelSection: css` + display: flex; + flex-direction: column; + gap: ${size[1]}; + padding-bottom: ${size[2]}; + + &:last-child { + padding-bottom: 0; + } + + & + & { + border-top: 1px solid ${t(colors.gray[200], colors.darkGray[500])}; + padding-top: ${size[2]}; + } + `, + configPanelLabel: css` + font-weight: ${font.weight.semibold}; + color: ${t(colors.gray[500], colors.gray[400])}; + font-size: 10px; + text-transform: uppercase; + letter-spacing: 0.5px; + flex-shrink: 0; + `, + configToolsList: css` + display: flex; + flex-wrap: wrap; + gap: ${size[1]}; + `, + configToolChip: css` + display: inline-flex; + align-items: center; + gap: ${size[1]}; + padding: 1px ${size[1.5]}; + border-radius: ${border.radius.sm}; + font-size: 10px; + font-family: ${fontFamily.mono}; + color: ${t(colors.yellow[800], colors.yellow[300])}; + background: ${t(colors.yellow[50], colors.yellow[900] + '25')}; + border: 1px solid ${t(colors.yellow[200], colors.yellow[800] + '40')}; + `, + configToolChipCount: css` + font-size: 9px; + font-weight: ${font.weight.bold}; + padding: 0 ${size[1]}; + border-radius: ${border.radius.xs}; + background: ${t(colors.yellow[200], colors.yellow[800] + '50')}; + color: ${t(colors.yellow[700], colors.yellow[200])}; + `, + configJsonTreeContainer: css` + border-radius: ${border.radius.sm}; + overflow: hidden; + border: 1px solid ${t(colors.gray[200], colors.darkGray[500])}; + background: ${t(colors.gray[50], colors.darkGray[800])}; + 
padding: ${size[1.5]} ${size[2]}; + `, + systemPromptCard: css` + border-radius: ${border.radius.sm}; + overflow: hidden; + border: 1px solid ${t(colors.gray[200], colors.darkGray[500])}; + background: ${t(colors.gray[50], colors.darkGray[800])}; + `, + systemPromptHeader: css` + display: flex; + align-items: center; + gap: ${size[1.5]}; + padding: ${size[1.5]} ${size[2]}; + cursor: pointer; + user-select: none; + font-size: ${fontSize.xs}; + + &:hover { + background: ${t(colors.gray[100], colors.darkGray[700])}; + } + `, + systemPromptIndex: css` + font-weight: ${font.weight.bold}; + color: ${t(colors.gray[400], colors.gray[500])}; + font-size: 10px; + flex-shrink: 0; + `, + systemPromptPreview: css` + flex: 1; + min-width: 0; + color: ${t(colors.gray[600], colors.gray[400])}; + font-family: ${fontFamily.mono}; + font-size: ${fontSize.xs}; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + `, + systemPromptFull: css` + margin: 0; + padding: ${size[2]} ${size[3]}; + font-size: ${fontSize.xs}; + font-family: ${fontFamily.mono}; + white-space: pre-wrap; + word-break: break-word; + max-height: 300px; + overflow-y: auto; + color: ${t(colors.gray[700], colors.gray[200])}; + border-top: 1px solid ${t(colors.gray[200], colors.darkGray[600])}; + background: ${t(colors.gray[50], colors.darkGray[900])}; + line-height: 1.5; + `, + cardHeaderBadges: css` + display: flex; + align-items: center; + gap: ${size[1]}; + flex-shrink: 0; + flex-wrap: wrap; + justify-content: flex-end; + `, + chevron: css` + color: ${t(colors.gray[400], colors.gray[500])}; + font-size: 10px; + transition: transform 0.2s ease; + flex-shrink: 0; + margin-top: 3px; + `, + chevronOpen: css` + transform: rotate(90deg); + `, + badge: css` + font-size: ${fontSize.xs}; + padding: 1px ${size[2]}; + border-radius: ${border.radius.sm}; + font-family: ${fontFamily.mono}; + font-weight: ${font.weight.medium}; + white-space: nowrap; + `, + badgeDuration: css` + background: ${t(colors.gray[100], 
colors.darkGray[500])}; + color: ${t(colors.gray[600], colors.gray[300])}; + `, + badgeFinishReason: css` + background: ${t(colors.blue[50], colors.blue[900] + '40')}; + color: ${t(colors.blue[700], colors.blue[300])}; + `, + badgeFinishReasonStop: css` + background: ${t(colors.green[50], colors.green[900] + '40')}; + color: ${t(colors.green[700], colors.green[300])}; + `, + badgeFinishReasonToolCalls: css` + background: ${t(colors.yellow[50], colors.yellow[900] + '40')}; + color: ${t(colors.yellow[700], colors.yellow[300])}; + `, + badgeUsage: css` + background: ${t(colors.purple[50], colors.purple[900] + '40')}; + color: ${t(colors.purple[700], colors.purple[300])}; + `, + cardBody: css` + display: grid; + grid-template-rows: 0fr; + transition: grid-template-rows 0.2s ease-out; + `, + cardBodyOpen: css` + grid-template-rows: 1fr; + `, + cardBodyInner: css` + overflow: hidden; + `, + userBubble: css` + width: 22px; + height: 22px; + border-radius: 50%; + display: flex; + align-items: center; + justify-content: center; + font-size: 10px; + font-weight: ${font.weight.bold}; + color: ${colors.white}; + flex-shrink: 0; + background: ${t(colors.blue[500], colors.blue[600])}; + margin-top: 1px; + `, + + // --- Iteration card (inside user group) --- + iterCard: css` + position: relative; + background: ${t(colors.gray[50], colors.darkGray[700])}; + border-top: 1px solid ${t(colors.gray[200], colors.darkGray[500])}; + overflow: hidden; + animation: iterStaggerIn 0.3s ease-out both; + width: 100%; + + @keyframes iterStaggerIn { + from { + opacity: 0; + transform: translateY(4px); + } + to { + opacity: 1; + transform: translateY(0); + } + } + `, + iterCardHeader: css` + display: flex; + align-items: center; + gap: ${size[2]}; + padding: ${size[2]} ${size[3]}; + cursor: pointer; + user-select: none; + + &:hover { + background: ${t(colors.gray[100], colors.darkGray[600])}; + } + `, + iterHeaderCompleted: css` + border-left: 3px solid ${t(colors.green[400], colors.green[500])}; 
+ `, + iterHeaderError: css` + border-left: 3px solid ${t(colors.red[400], colors.red[500])}; + `, + iterHeaderActive: css` + border-left: 3px solid ${t(colors.blue[400], colors.blue[500])}; + animation: iterActivePulse 2s ease-in-out infinite; + + @keyframes iterActivePulse { + 0%, 100% { border-left-color: ${t(colors.blue[400], colors.blue[500])}; } + 50% { border-left-color: ${t(colors.blue[200], colors.blue[700])}; } + } + `, + iterCardTitle: css` + font-size: ${fontSize.xs}; + font-weight: ${font.weight.semibold}; + color: ${t(colors.gray[700], colors.gray[300])}; + `, + + // --- Config row --- + configRow: css` + display: flex; + align-items: center; + flex-wrap: wrap; + gap: ${size[2]}; + padding: ${size[1.5]} ${size[3]}; + font-size: ${fontSize.xs}; + color: ${t(colors.gray[500], colors.gray[400])}; + font-family: ${fontFamily.mono}; + border-bottom: 1px solid ${t(colors.gray[100], colors.darkGray[600])}; + `, + configRowText: css` + color: ${t(colors.gray[500], colors.gray[400])}; + `, + configRowMeta: css` + color: ${t(colors.gray[400], colors.gray[500])}; + `, + configDiffChip: css` + font-size: 9px; + padding: 0 ${size[1]}; + border-radius: ${border.radius.xs}; + background: ${t(colors.yellow[100], colors.yellow[900] + '40')}; + color: ${t(colors.yellow[700], colors.yellow[400])}; + font-weight: ${font.weight.bold}; + text-transform: uppercase; + letter-spacing: 0.05em; + `, + configDetails: css` + width: 100%; + margin-top: ${size[2]}; + display: flex; + flex-direction: column; + gap: ${size[2]}; + `, + configDiffSection: css` + display: flex; + flex-direction: column; + gap: ${size[1]}; + `, + configDiffRow: css` + display: flex; + align-items: center; + gap: ${size[1]}; + font-size: 10px; + `, + configDiffKey: css` + font-weight: ${font.weight.semibold}; + color: ${t(colors.gray[600], colors.gray[300])}; + `, + configDiffFrom: css` + color: ${t(colors.red[600], colors.red[400])}; + text-decoration: line-through; + `, + configDiffArrow: css` + color: 
${t(colors.gray[400], colors.gray[500])}; + `, + configDiffTo: css` + color: ${t(colors.green[600], colors.green[400])}; + font-weight: ${font.weight.semibold}; + `, + configSystemPrompts: css` + display: flex; + flex-direction: column; + gap: ${size[1]}; + `, + systemPromptItem: css` + font-size: ${fontSize.xs}; + padding: ${size[2]}; + border-radius: ${border.radius.sm}; + background: ${t(colors.gray[50], colors.darkGray[600])}; + white-space: pre-wrap; + word-break: break-word; + line-height: 1.5; + color: ${t(colors.gray[700], colors.gray[200])}; + max-height: 120px; + overflow-y: auto; + `, + // --- Step row --- + step: css` + display: flex; + align-items: center; + gap: ${size[1.5]}; + padding: ${size[1.5]} ${size[3]}; + font-size: ${fontSize.xs}; + border-bottom: 1px solid ${t(colors.gray[100], colors.darkGray[600])}; + + &:last-child { + border-bottom: none; + } + `, + stepResponseLong: css` + flex-direction: column; + align-items: flex-start; + gap: ${size[1]}; + `, + stepPrefix: css` + flex-shrink: 0; + font-size: 10px; + font-weight: ${font.weight.bold}; + text-transform: uppercase; + letter-spacing: 0.03em; + padding: 1px ${size[1.5]}; + border-radius: ${border.radius.xs}; + `, + stepPrefixMiddleware: css` + color: ${t(colors.purple[700], colors.purple[300])}; + background: ${t(colors.purple[50], colors.purple[900] + '30')}; + `, + stepPrefixToolCall: css` + color: ${t(colors.yellow[800], colors.yellow[300])}; + background: ${t(colors.yellow[50], colors.yellow[900] + '30')}; + `, + stepPrefixToolResult: css` + color: ${t(colors.cyan[700], colors.cyan[300])}; + background: ${t(colors.cyan[900] + '15', colors.cyan[900] + '30')}; + `, + stepPrefixAssistant: css` + color: ${t(colors.blue[700], colors.blue[300])}; + background: ${t(colors.blue[50], colors.blue[900] + '30')}; + `, + stepPrefixThinking: css` + color: ${t(colors.pink[700], colors.pink[300])}; + background: ${t(colors.pink[50], colors.pink[900] + '30')}; + `, + stepContent: css` + flex: 1; + 
min-width: 0; + color: ${t(colors.gray[600], colors.gray[400])}; + font-family: ${fontFamily.mono}; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + `, + stepContentLong: css` + color: ${t(colors.gray[700], colors.gray[200])}; + font-size: ${fontSize.sm}; + line-height: 1.5; + white-space: pre-wrap; + word-break: break-word; + width: 100%; + `, + stepDuration: css` + flex-shrink: 0; + font-family: ${fontFamily.mono}; + font-size: 10px; + color: ${t(colors.gray[400], colors.gray[500])}; + `, + stepExpandToggle: css` + flex-shrink: 0; + cursor: pointer; + color: ${t(colors.blue[500], colors.blue[400])}; + font-size: 10px; + user-select: none; + padding: 0 ${size[1]}; + + &:hover { + text-decoration: underline; + } + `, + stepJsonPanel: css` + margin: ${size[1.5]} ${size[3]}; + padding: ${size[1.5]} ${size[2]} ${size[1.5]} ${size[5]}; + border-radius: ${border.radius.sm}; + border: 1px solid ${t(colors.gray[200], colors.darkGray[500])}; + background: ${t(colors.gray[50], colors.darkGray[800])}; + `, + stepDetail: css` + padding: ${size[2]}; + margin: 0 ${size[3]} ${size[1.5]}; + border-radius: ${border.radius.sm}; + background: ${t(colors.gray[50], colors.darkGray[600])}; + font-size: ${fontSize.xs}; + font-family: ${fontFamily.mono}; + white-space: pre-wrap; + word-break: break-word; + max-height: 200px; + overflow-y: auto; + color: ${t(colors.gray[700], colors.gray[200])}; + `, + responseDetail: css` + padding: ${size[3]}; + margin: 0 ${size[3]} ${size[1.5]}; + border-radius: ${border.radius.sm}; + background: ${t(colors.gray[50], colors.darkGray[600])}; + font-size: ${fontSize.sm}; + line-height: 1.6; + white-space: pre-wrap; + word-break: break-word; + max-height: 400px; + overflow-y: auto; + color: ${t(colors.gray[800], colors.gray[100])}; + `, + thinkingDetail: css` + padding: ${size[3]}; + margin: 0 ${size[3]} ${size[1.5]}; + border-radius: ${border.radius.sm}; + background: ${t(colors.pink[50], colors.pink[900] + '15')}; + border-left: 
3px solid ${t(colors.pink[300], colors.pink[600])}; + font-size: ${fontSize.sm}; + font-style: italic; + line-height: 1.6; + white-space: pre-wrap; + word-break: break-word; + max-height: 400px; + overflow-y: auto; + color: ${t(colors.gray[700], colors.gray[200])}; + `, + jsonTreeContainer: css` + padding: ${size[1]} ${size[3]} ${size[2]}; + `, + + // --- Middleware badge styles --- + mwBadge: css` + display: inline-flex; + align-items: center; + padding: 1px ${size[1.5]}; + border-radius: ${border.radius.sm}; + font-size: 10px; + font-family: ${fontFamily.mono}; + font-weight: ${font.weight.semibold}; + white-space: nowrap; + flex-shrink: 0; + `, + mwBadgeDefault: css` + background: ${t(colors.gray[100], colors.darkGray[500])}; + color: ${t(colors.gray[600], colors.gray[300])}; + `, + mwBadgeTransform: css` + background: ${t(colors.purple[50], colors.purple[900] + '30')}; + color: ${t(colors.purple[700], colors.purple[300])}; + `, + mwBadgeError: css` + background: ${t(colors.red[50], colors.red[900] + '30')}; + color: ${t(colors.red[700], colors.red[300])}; + `, + mwBadgeToolCall: css` + background: ${t(colors.yellow[50], colors.yellow[900] + '30')}; + color: ${t(colors.yellow[800], colors.yellow[300])}; + `, + mwBadgeToolResult: css` + background: ${t(colors.cyan[900] + '15', colors.cyan[900] + '30')}; + color: ${t(colors.cyan[700], colors.cyan[300])}; + `, + mwHook: css` + font-size: 10px; + font-family: ${fontFamily.mono}; + color: ${t(colors.gray[500], colors.gray[400])}; + flex-shrink: 0; + `, + mwSuffix: css` + font-size: 9px; + font-weight: ${font.weight.bold}; + text-transform: uppercase; + letter-spacing: 0.03em; + padding: 1px ${size[1]}; + border-radius: ${border.radius.xs}; + background: ${t(colors.yellow[100], colors.yellow[900] + '40')}; + color: ${t(colors.yellow[800], colors.yellow[300])}; + flex-shrink: 0; + `, + mwChangesContainer: css` + padding: ${size[1]} ${size[3]} ${size[2]}; + `, + + // --- JSON Viewer --- + jsonViewer: css` + border: 1px 
solid ${t(colors.gray[200], colors.darkGray[500])}; + border-radius: ${border.radius.sm}; + overflow: hidden; + background: ${t(colors.gray[50], colors.darkGray[800])}; + `, + jsonViewerHeader: css` + display: flex; + align-items: center; + gap: ${size[1.5]}; + padding: ${size[1.5]} ${size[2]}; + cursor: pointer; + user-select: none; + font-size: 10px; + + &:hover { + background: ${t(colors.gray[100], colors.darkGray[700])}; + } + `, + jsonViewerChevron: css` + color: ${t(colors.gray[400], colors.gray[500])}; + font-size: 8px; + transition: transform 0.15s ease; + flex-shrink: 0; + `, + jsonViewerLabel: css` + font-weight: ${font.weight.semibold}; + color: ${t(colors.gray[600], colors.gray[300])}; + font-size: 10px; + flex-shrink: 0; + `, + jsonViewerPreview: css` + flex: 1; + min-width: 0; + color: ${t(colors.gray[500], colors.gray[400])}; + font-family: ${fontFamily.mono}; + font-size: 10px; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + `, + jsonViewerContent: css` + margin: 0; + padding: ${size[2]}; + font-size: ${fontSize.xs}; + font-family: ${fontFamily.mono}; + white-space: pre-wrap; + word-break: break-word; + max-height: 300px; + overflow-y: auto; + color: ${t(colors.gray[700], colors.gray[200])}; + border-top: 1px solid ${t(colors.gray[200], colors.darkGray[600])}; + background: ${t(colors.gray[50], colors.darkGray[900])}; + line-height: 1.5; + `, + jsonViewerContainer: css` + padding: ${size[1]} ${size[3]} ${size[2]}; + `, + + // --- Standalone middleware display (MiddlewareEventsSection) --- + middlewareContainer: css` + display: flex; + flex-wrap: wrap; + gap: ${size[1]}; + `, + middlewarePill: css` + display: inline-flex; + align-items: center; + gap: ${size[1]}; + padding: 1px ${size[1.5]}; + border-radius: ${border.radius.sm}; + font-size: 10px; + font-family: ${fontFamily.mono}; + background: ${t(colors.gray[100], colors.darkGray[500])}; + color: ${t(colors.gray[600], colors.gray[300])}; + white-space: nowrap; + `, + 
middlewarePillTransform: css` + background: ${t(colors.purple[50], colors.purple[900] + '30')}; + color: ${t(colors.purple[700], colors.purple[300])}; + `, + middlewarePillSuffix: css` + font-weight: ${font.weight.bold}; + font-size: 9px; + text-transform: uppercase; + `, + + noIterations: css` + text-align: center; + padding: ${size[6]}; + color: ${t(colors.gray[400], colors.gray[500])}; + font-size: ${fontSize.sm}; + `, + }, } } diff --git a/packages/typescript/ai/src/activities/chat/middleware/compose.ts b/packages/typescript/ai/src/activities/chat/middleware/compose.ts index d8d2df4fc..460da29b5 100644 --- a/packages/typescript/ai/src/activities/chat/middleware/compose.ts +++ b/packages/typescript/ai/src/activities/chat/middleware/compose.ts @@ -1,3 +1,4 @@ +import { aiEventClient } from '../../../event-client.js' import type { StreamChunk } from '../../../types' import type { AbortInfo, @@ -14,6 +15,21 @@ import type { UsageInfo, } from './types' +/** Check if a middleware should be skipped for instrumentation events. */ +function shouldSkipInstrumentation(mw: ChatMiddleware): boolean { + return mw.name === 'devtools' +} + +/** Build the base context for middleware instrumentation events. */ +function instrumentCtx(ctx: ChatMiddlewareContext) { + return { + requestId: ctx.requestId, + streamId: ctx.streamId, + clientId: ctx.conversationId, + timestamp: Date.now(), + } +} + /** * Internal middleware runner that manages composed execution of middleware hooks. * Created once per chat() invocation. 
@@ -41,10 +57,32 @@ export class MiddlewareRunner { let current = config for (const mw of this.middlewares) { if (mw.onConfig) { + const skip = shouldSkipInstrumentation(mw) + const start = Date.now() const result = await mw.onConfig(ctx, current) - if (result !== undefined && result !== null) { + const hasTransform = result !== undefined && result !== null + if (hasTransform) { current = { ...current, ...result } } + if (!skip) { + const base = instrumentCtx(ctx) + aiEventClient.emit('middleware:hook:executed', { + ...base, + middlewareName: mw.name || 'unnamed', + hookName: 'onConfig', + iteration: ctx.iteration, + duration: Date.now() - start, + hasTransform, + }) + if (hasTransform) { + aiEventClient.emit('middleware:config:transformed', { + ...base, + middlewareName: mw.name || 'unnamed', + iteration: ctx.iteration, + changes: result as Record<string, unknown>, + }) + } + } } } return current @@ -56,7 +94,19 @@ export class MiddlewareRunner { async runOnStart(ctx: ChatMiddlewareContext): Promise<void> { for (const mw of this.middlewares) { if (mw.onStart) { + const skip = shouldSkipInstrumentation(mw) + const start = Date.now() await mw.onStart(ctx) + if (!skip) { + aiEventClient.emit('middleware:hook:executed', { + ...instrumentCtx(ctx), + middlewareName: mw.name || 'unnamed', + hookName: 'onStart', + iteration: ctx.iteration, + duration: Date.now() - start, + hasTransform: false, + }) + } } } } @@ -78,22 +128,50 @@ export class MiddlewareRunner { for (const mw of this.middlewares) { if (!mw.onChunk) continue + const skip = shouldSkipInstrumentation(mw) const nextChunks: Array<StreamChunk> = [] for (const c of chunks) { const result = await mw.onChunk(ctx, c) if (result === null) { // Drop this chunk + if (!skip) { + aiEventClient.emit('middleware:chunk:transformed', { + ...instrumentCtx(ctx), + middlewareName: mw.name || 'unnamed', + originalChunkType: c.type, + resultCount: 0, + wasDropped: true, + }) + } continue } else if (result === undefined) { - // Pass through + // Pass through — no 
instrumentation for pass-throughs nextChunks.push(c) } else if (Array.isArray(result)) { // Expand nextChunks.push(...result) + if (!skip) { + aiEventClient.emit('middleware:chunk:transformed', { + ...instrumentCtx(ctx), + middlewareName: mw.name || 'unnamed', + originalChunkType: c.type, + resultCount: result.length, + wasDropped: false, + }) + } } else { // Replace nextChunks.push(result) + if (!skip) { + aiEventClient.emit('middleware:chunk:transformed', { + ...instrumentCtx(ctx), + middlewareName: mw.name || 'unnamed', + originalChunkType: c.type, + resultCount: 1, + wasDropped: false, + }) + } } } chunks = nextChunks @@ -112,8 +190,21 @@ export class MiddlewareRunner { ): Promise { for (const mw of this.middlewares) { if (mw.onBeforeToolCall) { + const skip = shouldSkipInstrumentation(mw) + const start = Date.now() const decision = await mw.onBeforeToolCall(ctx, hookCtx) - if (decision !== undefined && decision !== null) { + const hasTransform = decision !== undefined && decision !== null + if (!skip) { + aiEventClient.emit('middleware:hook:executed', { + ...instrumentCtx(ctx), + middlewareName: mw.name || 'unnamed', + hookName: 'onBeforeToolCall', + iteration: ctx.iteration, + duration: Date.now() - start, + hasTransform, + }) + } + if (hasTransform) { return decision } } @@ -130,7 +221,19 @@ export class MiddlewareRunner { ): Promise { for (const mw of this.middlewares) { if (mw.onAfterToolCall) { + const skip = shouldSkipInstrumentation(mw) + const start = Date.now() await mw.onAfterToolCall(ctx, info) + if (!skip) { + aiEventClient.emit('middleware:hook:executed', { + ...instrumentCtx(ctx), + middlewareName: mw.name || 'unnamed', + hookName: 'onAfterToolCall', + iteration: ctx.iteration, + duration: Date.now() - start, + hasTransform: false, + }) + } } } } @@ -144,7 +247,19 @@ export class MiddlewareRunner { ): Promise { for (const mw of this.middlewares) { if (mw.onUsage) { + const skip = shouldSkipInstrumentation(mw) + const start = Date.now() await 
mw.onUsage(ctx, usage) + if (!skip) { + aiEventClient.emit('middleware:hook:executed', { + ...instrumentCtx(ctx), + middlewareName: mw.name || 'unnamed', + hookName: 'onUsage', + iteration: ctx.iteration, + duration: Date.now() - start, + hasTransform: false, + }) + } } } } @@ -158,7 +273,19 @@ export class MiddlewareRunner { ): Promise { for (const mw of this.middlewares) { if (mw.onFinish) { + const skip = shouldSkipInstrumentation(mw) + const start = Date.now() await mw.onFinish(ctx, info) + if (!skip) { + aiEventClient.emit('middleware:hook:executed', { + ...instrumentCtx(ctx), + middlewareName: mw.name || 'unnamed', + hookName: 'onFinish', + iteration: ctx.iteration, + duration: Date.now() - start, + hasTransform: false, + }) + } } } } @@ -169,7 +296,19 @@ export class MiddlewareRunner { async runOnAbort(ctx: ChatMiddlewareContext, info: AbortInfo): Promise { for (const mw of this.middlewares) { if (mw.onAbort) { + const skip = shouldSkipInstrumentation(mw) + const start = Date.now() await mw.onAbort(ctx, info) + if (!skip) { + aiEventClient.emit('middleware:hook:executed', { + ...instrumentCtx(ctx), + middlewareName: mw.name || 'unnamed', + hookName: 'onAbort', + iteration: ctx.iteration, + duration: Date.now() - start, + hasTransform: false, + }) + } } } } @@ -180,7 +319,19 @@ export class MiddlewareRunner { async runOnError(ctx: ChatMiddlewareContext, info: ErrorInfo): Promise { for (const mw of this.middlewares) { if (mw.onError) { + const skip = shouldSkipInstrumentation(mw) + const start = Date.now() await mw.onError(ctx, info) + if (!skip) { + aiEventClient.emit('middleware:hook:executed', { + ...instrumentCtx(ctx), + middlewareName: mw.name || 'unnamed', + hookName: 'onError', + iteration: ctx.iteration, + duration: Date.now() - start, + hasTransform: false, + }) + } } } } @@ -195,7 +346,19 @@ export class MiddlewareRunner { ): Promise { for (const mw of this.middlewares) { if (mw.onIteration) { + const skip = shouldSkipInstrumentation(mw) + const 
start = Date.now() await mw.onIteration(ctx, info) + if (!skip) { + aiEventClient.emit('middleware:hook:executed', { + ...instrumentCtx(ctx), + middlewareName: mw.name || 'unnamed', + hookName: 'onIteration', + iteration: ctx.iteration, + duration: Date.now() - start, + hasTransform: false, + }) + } } } } @@ -210,7 +373,19 @@ export class MiddlewareRunner { ): Promise { for (const mw of this.middlewares) { if (mw.onToolPhaseComplete) { + const skip = shouldSkipInstrumentation(mw) + const start = Date.now() await mw.onToolPhaseComplete(ctx, info) + if (!skip) { + aiEventClient.emit('middleware:hook:executed', { + ...instrumentCtx(ctx), + middlewareName: mw.name || 'unnamed', + hookName: 'onToolPhaseComplete', + iteration: ctx.iteration, + duration: Date.now() - start, + hasTransform: false, + }) + } } } } diff --git a/packages/typescript/ai/src/activities/chat/middleware/devtools-middleware.ts b/packages/typescript/ai/src/activities/chat/middleware/devtools-middleware.ts index f6e1e2423..9dec08c89 100644 --- a/packages/typescript/ai/src/activities/chat/middleware/devtools-middleware.ts +++ b/packages/typescript/ai/src/activities/chat/middleware/devtools-middleware.ts @@ -52,6 +52,8 @@ export function devtoolsMiddleware(): ChatMiddleware { // runs first, before the engine updates ctx.currentMessageId / ctx.accumulatedContent let localMessageId: string | null = null let localAccumulatedContent = '' + let currentIteration = -1 + let iterationStartTime = 0 return { name: 'devtools', @@ -99,15 +101,41 @@ export function devtoolsMiddleware(): ChatMiddleware { }, onIteration(ctx: ChatMiddlewareContext, info: IterationInfo) { + const now = Date.now() + + // Emit completed for previous iteration (it ended with tool_calls if we got here) + if (currentIteration >= 0) { + aiEventClient.emit('text:iteration:completed', { + ...buildEventContext(ctx), + iteration: currentIteration, + messageId: localMessageId || undefined, + duration: now - iterationStartTime, + finishReason: 
'tool_calls', + timestamp: now, + }) + } + + // Track new iteration + currentIteration = info.iteration + iterationStartTime = now localMessageId = info.messageId localAccumulatedContent = '' + // Emit iteration:started with config snapshot + aiEventClient.emit('text:iteration:started', { + ...buildEventContext(ctx), + iteration: info.iteration, + messageId: info.messageId, + timestamp: now, + }) + + // Emit assistant message placeholder aiEventClient.emit('text:message:created', { ...buildEventContext(ctx), messageId: info.messageId, role: 'assistant' as const, content: '', - timestamp: Date.now(), + timestamp: now, }) }, @@ -272,6 +300,21 @@ export function devtoolsMiddleware(): ChatMiddleware { }, onFinish(ctx, info) { + const now = Date.now() + + // Emit completed for the final iteration + if (currentIteration >= 0) { + aiEventClient.emit('text:iteration:completed', { + ...buildEventContext(ctx), + iteration: currentIteration, + messageId: localMessageId || undefined, + duration: now - iterationStartTime, + finishReason: info.finishReason || undefined, + usage: info.usage, + timestamp: now, + }) + } + aiEventClient.emit('text:request:completed', { ...buildEventContext(ctx), content: info.content, @@ -279,7 +322,7 @@ export function devtoolsMiddleware(): ChatMiddleware { finishReason: info.finishReason || undefined, usage: info.usage, duration: info.duration, - timestamp: Date.now(), + timestamp: now, }) }, } diff --git a/packages/typescript/ai/src/event-client.ts b/packages/typescript/ai/src/event-client.ts index b76d67055..f25f69a5c 100644 --- a/packages/typescript/ai/src/event-client.ts +++ b/packages/typescript/ai/src/event-client.ts @@ -162,6 +162,65 @@ export interface TextUsageEvent extends BaseEventContext { usage: TokenUsage } +// =========================== +// Iteration Events +// =========================== + +/** Emitted when a new agent loop iteration begins, with a config snapshot. 
*/ +export interface TextIterationStartedEvent extends BaseEventContext { + requestId: string + streamId: string + iteration: number + messageId: string + provider: string + model: string +} + +/** Emitted when an agent loop iteration completes. */ +export interface TextIterationCompletedEvent extends BaseEventContext { + requestId: string + streamId: string + iteration: number + messageId?: string + duration: number + finishReason?: string + usage?: TokenUsage +} + +// =========================== +// Middleware Events +// =========================== + +/** Emitted when a middleware hook completes execution. */ +export interface MiddlewareHookExecutedEvent extends BaseEventContext { + requestId: string + streamId: string + middlewareName: string + hookName: string + iteration: number + duration: number + hasTransform: boolean +} + +/** Emitted when onConfig returns a non-void transform. */ +export interface MiddlewareConfigTransformedEvent extends BaseEventContext { + requestId: string + streamId: string + middlewareName: string + iteration: number + changes: Record<string, unknown> +} + +/** Emitted when onChunk transforms, drops, or expands a chunk. 
*/ +export interface MiddlewareChunkTransformedEvent extends BaseEventContext { + requestId: string + streamId: string + middlewareName: string + originalChunkType: string + resultCount: number + wasDropped: boolean +} + // =========================== // Tool Events // =========================== @@ -442,6 +501,15 @@ export interface AIDevtoolsEventMap { 'tanstack-ai-devtools:text:chunk:error': TextChunkErrorEvent 'tanstack-ai-devtools:text:usage': TextUsageEvent + // Iteration events + 'tanstack-ai-devtools:text:iteration:started': TextIterationStartedEvent + 'tanstack-ai-devtools:text:iteration:completed': TextIterationCompletedEvent + + // Middleware events + 'tanstack-ai-devtools:middleware:hook:executed': MiddlewareHookExecutedEvent + 'tanstack-ai-devtools:middleware:config:transformed': MiddlewareConfigTransformedEvent + 'tanstack-ai-devtools:middleware:chunk:transformed': MiddlewareChunkTransformedEvent + // Tool events 'tanstack-ai-devtools:tools:approval:requested': ToolsApprovalRequestedEvent 'tanstack-ai-devtools:tools:approval:responded': ToolsApprovalRespondedEvent From 2c5968b87e64db48da2aba7fea0348167c9ddc74 Mon Sep 17 00:00:00 2001 From: Alem Tuzlak Date: Mon, 2 Mar 2026 14:37:29 +0100 Subject: [PATCH 07/16] feat: add image generation feature and update routing --- .../ts-react-chat/src/components/Header.tsx | 15 +- examples/ts-react-chat/src/routeTree.gen.ts | 48 ++- .../ts-react-chat/src/routes/api.image-gen.ts | 48 +++ .../ts-react-chat/src/routes/image-gen.tsx | 302 ++++++++++++++++++ examples/ts-react-chat/src/routes/index.tsx | 11 +- 5 files changed, 420 insertions(+), 4 deletions(-) create mode 100644 examples/ts-react-chat/src/routes/api.image-gen.ts create mode 100644 examples/ts-react-chat/src/routes/image-gen.tsx diff --git a/examples/ts-react-chat/src/components/Header.tsx b/examples/ts-react-chat/src/components/Header.tsx index 57745b7b0..7e5037a65 100644 --- a/examples/ts-react-chat/src/components/Header.tsx +++ 
b/examples/ts-react-chat/src/components/Header.tsx @@ -1,7 +1,7 @@ import { Link } from '@tanstack/react-router' import { useState } from 'react' -import { Guitar, Home, Menu, X } from 'lucide-react' +import { Guitar, Home, Image, Menu, X } from 'lucide-react' export default function Header() { const [isOpen, setIsOpen] = useState(false) @@ -57,6 +57,19 @@ export default function Header() { Home + setIsOpen(false)} + className="flex items-center gap-3 p-3 rounded-lg hover:bg-gray-800 transition-colors mb-2" + activeProps={{ + className: + 'flex items-center gap-3 p-3 rounded-lg bg-cyan-600 hover:bg-cyan-700 transition-colors mb-2', + }} + > + + Image Gen + +
rootRouteImport, +} as any) const IndexRoute = IndexRouteImport.update({ id: '/', path: '/', @@ -24,6 +31,11 @@ const ApiTanchatRoute = ApiTanchatRouteImport.update({ path: '/api/tanchat', getParentRoute: () => rootRouteImport, } as any) +const ApiImageGenRoute = ApiImageGenRouteImport.update({ + id: '/api/image-gen', + path: '/api/image-gen', + getParentRoute: () => rootRouteImport, +} as any) const ExampleGuitarsIndexRoute = ExampleGuitarsIndexRouteImport.update({ id: '/example/guitars/', path: '/example/guitars/', @@ -37,12 +49,16 @@ const ExampleGuitarsGuitarIdRoute = ExampleGuitarsGuitarIdRouteImport.update({ export interface FileRoutesByFullPath { '/': typeof IndexRoute + '/image-gen': typeof ImageGenRoute + '/api/image-gen': typeof ApiImageGenRoute '/api/tanchat': typeof ApiTanchatRoute '/example/guitars/$guitarId': typeof ExampleGuitarsGuitarIdRoute '/example/guitars/': typeof ExampleGuitarsIndexRoute } export interface FileRoutesByTo { '/': typeof IndexRoute + '/image-gen': typeof ImageGenRoute + '/api/image-gen': typeof ApiImageGenRoute '/api/tanchat': typeof ApiTanchatRoute '/example/guitars/$guitarId': typeof ExampleGuitarsGuitarIdRoute '/example/guitars': typeof ExampleGuitarsIndexRoute @@ -50,6 +66,8 @@ export interface FileRoutesByTo { export interface FileRoutesById { __root__: typeof rootRouteImport '/': typeof IndexRoute + '/image-gen': typeof ImageGenRoute + '/api/image-gen': typeof ApiImageGenRoute '/api/tanchat': typeof ApiTanchatRoute '/example/guitars/$guitarId': typeof ExampleGuitarsGuitarIdRoute '/example/guitars/': typeof ExampleGuitarsIndexRoute @@ -58,14 +76,24 @@ export interface FileRouteTypes { fileRoutesByFullPath: FileRoutesByFullPath fullPaths: | '/' + | '/image-gen' + | '/api/image-gen' | '/api/tanchat' | '/example/guitars/$guitarId' | '/example/guitars/' fileRoutesByTo: FileRoutesByTo - to: '/' | '/api/tanchat' | '/example/guitars/$guitarId' | '/example/guitars' + to: + | '/' + | '/image-gen' + | '/api/image-gen' + | 
'/api/tanchat' + | '/example/guitars/$guitarId' + | '/example/guitars' id: | '__root__' | '/' + | '/image-gen' + | '/api/image-gen' | '/api/tanchat' | '/example/guitars/$guitarId' | '/example/guitars/' @@ -73,6 +101,8 @@ export interface FileRouteTypes { } export interface RootRouteChildren { IndexRoute: typeof IndexRoute + ImageGenRoute: typeof ImageGenRoute + ApiImageGenRoute: typeof ApiImageGenRoute ApiTanchatRoute: typeof ApiTanchatRoute ExampleGuitarsGuitarIdRoute: typeof ExampleGuitarsGuitarIdRoute ExampleGuitarsIndexRoute: typeof ExampleGuitarsIndexRoute @@ -80,6 +110,13 @@ export interface RootRouteChildren { declare module '@tanstack/react-router' { interface FileRoutesByPath { + '/image-gen': { + id: '/image-gen' + path: '/image-gen' + fullPath: '/image-gen' + preLoaderRoute: typeof ImageGenRouteImport + parentRoute: typeof rootRouteImport + } '/': { id: '/' path: '/' @@ -94,6 +131,13 @@ declare module '@tanstack/react-router' { preLoaderRoute: typeof ApiTanchatRouteImport parentRoute: typeof rootRouteImport } + '/api/image-gen': { + id: '/api/image-gen' + path: '/api/image-gen' + fullPath: '/api/image-gen' + preLoaderRoute: typeof ApiImageGenRouteImport + parentRoute: typeof rootRouteImport + } '/example/guitars/': { id: '/example/guitars/' path: '/example/guitars' @@ -113,6 +157,8 @@ declare module '@tanstack/react-router' { const rootRouteChildren: RootRouteChildren = { IndexRoute: IndexRoute, + ImageGenRoute: ImageGenRoute, + ApiImageGenRoute: ApiImageGenRoute, ApiTanchatRoute: ApiTanchatRoute, ExampleGuitarsGuitarIdRoute: ExampleGuitarsGuitarIdRoute, ExampleGuitarsIndexRoute: ExampleGuitarsIndexRoute, diff --git a/examples/ts-react-chat/src/routes/api.image-gen.ts b/examples/ts-react-chat/src/routes/api.image-gen.ts new file mode 100644 index 000000000..678b3f8c5 --- /dev/null +++ b/examples/ts-react-chat/src/routes/api.image-gen.ts @@ -0,0 +1,48 @@ +import { createFileRoute } from '@tanstack/react-router' +import { generateImage } from 
'@tanstack/ai' +import { openRouterImage } from '@tanstack/ai-openrouter' + +export const Route = createFileRoute('/api/image-gen')({ + server: { + handlers: { + POST: async ({ request }) => { + const body = await request.json() + const { prompt, model, size } = body as { + prompt: string + model: string + size?: string + } + + if (!prompt) { + return new Response( + JSON.stringify({ error: 'Prompt is required' }), + { status: 400, headers: { 'Content-Type': 'application/json' } }, + ) + } + + try { + const result = await generateImage({ + adapter: openRouterImage((model || 'openai/gpt-5-image-mini') as 'openai/gpt-5-image-mini'), + prompt, + ...(size ? { size: size as any } : {}), + }) + + return new Response(JSON.stringify(result), { + headers: { 'Content-Type': 'application/json' }, + }) + } catch (error: any) { + console.error('[Image Gen API] Error:', { + message: error?.message, + name: error?.name, + status: error?.status, + stack: error?.stack, + }) + return new Response( + JSON.stringify({ error: error.message || 'Image generation failed' }), + { status: 500, headers: { 'Content-Type': 'application/json' } }, + ) + } + }, + }, + }, +}) diff --git a/examples/ts-react-chat/src/routes/image-gen.tsx b/examples/ts-react-chat/src/routes/image-gen.tsx new file mode 100644 index 000000000..8018df78e --- /dev/null +++ b/examples/ts-react-chat/src/routes/image-gen.tsx @@ -0,0 +1,302 @@ +import { useState } from 'react' +import { createFileRoute } from '@tanstack/react-router' +import { Download, Loader2, Send, X } from 'lucide-react' + +interface GeneratedImage { + url?: string + b64Json?: string + revisedPrompt?: string +} + +interface ImageGenResult { + id: string + model: string + images: Array + usage?: { + inputTokens?: number + outputTokens?: number + totalTokens?: number + } +} + +const IMAGE_MODELS = [ + { value: 'openai/gpt-5-image-mini', label: 'OpenAI GPT-5 Image Mini' }, + { value: 'openai/gpt-5-image', label: 'OpenAI GPT-5 Image' }, + { + value: 
'google/gemini-2.5-flash-image', + label: 'Gemini 2.5 Flash Image', + }, + { + value: 'google/gemini-2.5-flash-image-preview', + label: 'Gemini 2.5 Flash Image Preview', + }, + { + value: 'google/gemini-3-pro-image-preview', + label: 'Gemini 3 Pro Image Preview', + }, +] as const + +const IMAGE_SIZES = [ + { value: '1024x1024', label: '1024x1024 (1:1)' }, + { value: '1248x832', label: '1248x832 (3:2)' }, + { value: '832x1248', label: '832x1248 (2:3)' }, + { value: '1184x864', label: '1184x864 (4:3)' }, + { value: '864x1184', label: '864x1184 (3:4)' }, + { value: '1344x768', label: '1344x768 (16:9)' }, + { value: '768x1344', label: '768x1344 (9:16)' }, +] as const + +function ImageGenPage() { + const [prompt, setPrompt] = useState('') + const [model, setModel] = useState(IMAGE_MODELS[0].value) + const [size, setSize] = useState(IMAGE_SIZES[0].value) + const [isLoading, setIsLoading] = useState(false) + const [result, setResult] = useState(null) + const [error, setError] = useState(null) + const [history, setHistory] = useState< + Array<{ prompt: string; result: ImageGenResult }> + >([]) + + async function handleGenerate() { + if (!prompt.trim() || isLoading) return + + setIsLoading(true) + setError(null) + setResult(null) + + try { + const res = await fetch('/api/image-gen', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ prompt: prompt.trim(), model, size }), + }) + + if (!res.ok) { + const errBody = await res.json().catch(() => null) + throw new Error(errBody?.error || `Request failed (${res.status})`) + } + + const data: ImageGenResult = await res.json() + setResult(data) + setHistory((prev) => [{ prompt: prompt.trim(), result: data }, ...prev]) + } catch (err: any) { + setError(err.message || 'Something went wrong') + } finally { + setIsLoading(false) + } + } + + function getImageSrc(image: GeneratedImage): string | null { + if (image.url) return image.url + if (image.b64Json) return 
`data:image/png;base64,${image.b64Json}` + return null + } + + return ( +
+ {/* Controls bar */} +
+
+
+ + +
+
+ + +
+
+
+ + {/* Main content area */} +
+ {/* Current result */} + {result && result.images.length > 0 && ( +
+
+ Prompt:{' '} + {history[0]?.prompt} + {result.usage?.totalTokens != null && ( + + ({result.usage.totalTokens} tokens) + + )} +
+
+ {result.images.map((image, i) => { + const src = getImageSrc(image) + if (!src) return null + return ( +
+ {image.revisedPrompt + {image.revisedPrompt && ( +
+ Revised:{' '} + {image.revisedPrompt} +
+ )} + + + +
+ ) + })} +
+
+ )} + + {/* Loading */} + {isLoading && ( +
+ +

Generating image...

+
+ )} + + {/* Error */} + {error && ( +
+
+ +
+

Generation failed

+

{error}

+
+
+
+ )} + + {/* Empty state */} + {!result && !isLoading && !error && ( +
+
+ + + +
+

+ Enter a prompt below to generate an image +

+
+ )} + + {/* History */} + {history.length > 1 && ( +
+

+ Previous Generations +

+
+ {history.slice(1).map((entry, i) => + entry.result.images.map((image, j) => { + const src = getImageSrc(image) + if (!src) return null + return ( +
+ {entry.prompt} +
+ {entry.prompt} +
+
+ ) + }), + )} +
+
+ )} +
+ + {/* Input area */} +
+
+
+
+