diff --git a/.changeset/config.json b/.changeset/config.json index af66336b2..29b38eb85 100644 --- a/.changeset/config.json +++ b/.changeset/config.json @@ -8,13 +8,7 @@ ], "commit": false, "ignore": ["livekit-agents-examples"], - "fixed": [ - [ - "@livekit/agents", - "@livekit/agents-plugin-*", - "@livekit/agents-plugins-test" - ] - ], + "fixed": [["@livekit/agents", "@livekit/agents-plugin-*", "@livekit/agents-plugins-test"]], "access": "public", "baseBranch": "main", "updateInternalDependencies": "patch", diff --git a/.changeset/flat-pets-walk.md b/.changeset/flat-pets-walk.md new file mode 100644 index 000000000..e8c919fee --- /dev/null +++ b/.changeset/flat-pets-walk.md @@ -0,0 +1,5 @@ +--- +"@livekit/agents": patch +--- + +Add remote session event handler diff --git a/.changeset/green-tips-worry.md b/.changeset/green-tips-worry.md new file mode 100644 index 000000000..808c8b1dd --- /dev/null +++ b/.changeset/green-tips-worry.md @@ -0,0 +1,5 @@ +--- +'@livekit/agents': patch +--- + +Add Bargein Model Metrics Usages diff --git a/.changeset/lucky-grapes-care.md b/.changeset/lucky-grapes-care.md new file mode 100644 index 000000000..d14ff4227 --- /dev/null +++ b/.changeset/lucky-grapes-care.md @@ -0,0 +1,10 @@ +--- +"@livekit/agents": patch +"@livekit/agents-plugin-cartesia": patch +"@livekit/agents-plugin-deepgram": patch +"@livekit/agents-plugin-google": patch +"@livekit/agents-plugin-openai": patch +"livekit-agents-examples": patch +--- + +Add granular session models usage stats diff --git a/.changeset/silly-donkeys-shop.md b/.changeset/silly-donkeys-shop.md new file mode 100644 index 000000000..db7b099d9 --- /dev/null +++ b/.changeset/silly-donkeys-shop.md @@ -0,0 +1,5 @@ +--- +"@livekit/agents": minor +--- + +Refactor turn handling options and add barge-in model support diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f5a577688..b4472c81b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -46,11 +46,11 @@ 
jobs: - name: Test agents if: steps.filter.outputs.agents-or-tests == 'true' || github.event_name == 'push' run: pnpm test agents - - name: Test examples - if: (steps.filter.outputs.examples == 'true' || github.event_name == 'push') && secrets.OPENAI_API_KEY != '' - env: - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - run: pnpm test:examples + # - name: Test examples + # if: (steps.filter.outputs.examples == 'true' || github.event_name == 'push') + # env: + # OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + # run: pnpm test:examples # TODO (AJS-83) Re-enable once plugins are refactored with abort controllers # - name: Test all plugins # if: steps.filter.outputs.agents-or-tests == 'true' || github.event_name != 'pull_request' diff --git a/agents/package.json b/agents/package.json index 828f38d3e..74b492986 100644 --- a/agents/package.json +++ b/agents/package.json @@ -69,6 +69,7 @@ "heap-js": "^2.6.0", "json-schema": "^0.4.0", "livekit-server-sdk": "^2.14.1", + "ofetch": "^1.5.1", "openai": "^6.8.1", "pidusage": "^4.0.1", "pino": "^8.19.0", diff --git a/agents/src/constants.ts b/agents/src/constants.ts index 86ead5b4c..ba9c37dee 100644 --- a/agents/src/constants.ts +++ b/agents/src/constants.ts @@ -7,3 +7,16 @@ export const TOPIC_TRANSCRIPTION = 'lk.transcription'; export const ATTRIBUTE_TRANSCRIPTION_SEGMENT_ID = 'lk.segment_id'; export const ATTRIBUTE_PUBLISH_ON_BEHALF = 'lk.publish_on_behalf'; export const TOPIC_CHAT = 'lk.chat'; + +export const ATTRIBUTE_AGENT_STATE = 'lk.agent.state'; +export const ATTRIBUTE_AGENT_NAME = 'lk.agent.name'; + +// TODO(eval): export const ATTRIBUTE_SIMULATOR = 'lk.simulator'; + +export const TOPIC_CLIENT_EVENTS = 'lk.agent.events'; +export const RPC_GET_SESSION_STATE = 'lk.agent.get_session_state'; +export const RPC_GET_CHAT_HISTORY = 'lk.agent.get_chat_history'; +export const RPC_GET_AGENT_INFO = 'lk.agent.get_agent_info'; +export const RPC_SEND_MESSAGE = 'lk.agent.send_message'; +export const TOPIC_AGENT_REQUEST = 
'lk.agent.request'; +export const TOPIC_AGENT_RESPONSE = 'lk.agent.response'; diff --git a/agents/src/inference/interruption/defaults.ts b/agents/src/inference/interruption/defaults.ts new file mode 100644 index 000000000..b58dfd40f --- /dev/null +++ b/agents/src/inference/interruption/defaults.ts @@ -0,0 +1,51 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import type { ApiConnectOptions } from './interruption_stream.js'; +import type { InterruptionOptions } from './types.js'; + +export const MIN_INTERRUPTION_DURATION_IN_S = 0.025 * 2; // 25ms per frame, 2 consecutive frames +export const THRESHOLD = 0.5; +export const MAX_AUDIO_DURATION_IN_S = 3.0; +export const AUDIO_PREFIX_DURATION_IN_S = 0.5; +export const DETECTION_INTERVAL_IN_S = 0.1; +export const REMOTE_INFERENCE_TIMEOUT_IN_S = 1.0; +export const SAMPLE_RATE = 16000; +export const FRAMES_PER_SECOND = 40; +export const FRAME_DURATION_IN_S = 0.025; // 25ms per frame + +export const apiConnectDefaults: ApiConnectOptions = { + maxRetries: 3, + retryInterval: 2_000, + timeout: 10_000, +} as const; + +/** + * Calculate the retry interval using exponential backoff with jitter. + * Matches the Python implementation's _interval_for_retry behavior. 
+ */ +export function intervalForRetry( + attempt: number, + baseInterval: number = apiConnectDefaults.retryInterval, +): number { + // Exponential backoff: baseInterval * 2^attempt with some jitter + const exponentialDelay = baseInterval * Math.pow(2, attempt); + // Add jitter (0-25% of the delay) + const jitter = exponentialDelay * Math.random() * 0.25; + return exponentialDelay + jitter; +} + +// baseUrl and useProxy are resolved dynamically in the constructor +// to respect LIVEKIT_REMOTE_EOT_URL environment variable +export const interruptionOptionDefaults: Omit = { + sampleRate: SAMPLE_RATE, + threshold: THRESHOLD, + minFrames: Math.ceil(MIN_INTERRUPTION_DURATION_IN_S * FRAMES_PER_SECOND), + maxAudioDurationInS: MAX_AUDIO_DURATION_IN_S, + audioPrefixDurationInS: AUDIO_PREFIX_DURATION_IN_S, + detectionIntervalInS: DETECTION_INTERVAL_IN_S, + inferenceTimeout: REMOTE_INFERENCE_TIMEOUT_IN_S * 1_000, + apiKey: process.env.LIVEKIT_API_KEY || '', + apiSecret: process.env.LIVEKIT_API_SECRET || '', + minInterruptionDurationInS: MIN_INTERRUPTION_DURATION_IN_S, +} as const; diff --git a/agents/src/inference/interruption/errors.ts b/agents/src/inference/interruption/errors.ts new file mode 100644 index 000000000..5b5f6d370 --- /dev/null +++ b/agents/src/inference/interruption/errors.ts @@ -0,0 +1,25 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +/** + * Error thrown during interruption detection. 
+ */ +export class InterruptionDetectionError extends Error { + readonly type = 'interruption_detection_error' as const; + + readonly timestamp: number; + readonly label: string; + readonly recoverable: boolean; + + constructor(message: string, timestamp: number, label: string, recoverable: boolean) { + super(message); + this.name = 'InterruptionDetectionError'; + this.timestamp = timestamp; + this.label = label; + this.recoverable = recoverable; + } + + toString(): string { + return `${this.name}: ${this.message} (label=${this.label}, timestamp=${this.timestamp}, recoverable=${this.recoverable})`; + } +} diff --git a/agents/src/inference/interruption/http_transport.ts b/agents/src/inference/interruption/http_transport.ts new file mode 100644 index 000000000..b698ebc50 --- /dev/null +++ b/agents/src/inference/interruption/http_transport.ts @@ -0,0 +1,187 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { ofetch } from 'ofetch'; +import { TransformStream } from 'stream/web'; +import { z } from 'zod'; +import { log } from '../../log.js'; +import { createAccessToken } from '../utils.js'; +import { intervalForRetry } from './defaults.js'; +import { InterruptionCacheEntry } from './interruption_cache_entry.js'; +import type { OverlappingSpeechEvent } from './types.js'; +import type { BoundedCache } from './utils.js'; + +export interface PostOptions { + baseUrl: string; + token: string; + signal?: AbortSignal; + timeout?: number; + maxRetries?: number; +} + +export interface PredictOptions { + threshold: number; + minFrames: number; +} + +export const predictEndpointResponseSchema = z.object({ + created_at: z.number(), + is_bargein: z.boolean(), + probabilities: z.array(z.number()), +}); + +export type PredictEndpointResponse = z.infer; + +export interface PredictResponse { + createdAt: number; + isBargein: boolean; + probabilities: number[]; + predictionDurationInS: number; +} + +export async function predictHTTP( + 
data: Int16Array, + predictOptions: PredictOptions, + options: PostOptions, +): Promise { + const createdAt = performance.now(); + const url = new URL(`/bargein`, options.baseUrl); + url.searchParams.append('threshold', predictOptions.threshold.toString()); + url.searchParams.append('min_frames', predictOptions.minFrames.toFixed()); + url.searchParams.append('created_at', createdAt.toFixed()); + + let retryCount = 0; + const response = await ofetch(url.toString(), { + retry: options.maxRetries ?? 3, + retryDelay: () => { + const delay = intervalForRetry(retryCount); + retryCount++; + return delay; + }, + headers: { + 'Content-Type': 'application/octet-stream', + Authorization: `Bearer ${options.token}`, + }, + signal: options.signal, + timeout: options.timeout, + method: 'POST', + body: data, + }); + const { created_at, is_bargein, probabilities } = predictEndpointResponseSchema.parse(response); + + return { + createdAt: created_at, + isBargein: is_bargein, + probabilities, + predictionDurationInS: (performance.now() - createdAt) / 1000, + }; +} + +export interface HttpTransportOptions { + baseUrl: string; + apiKey: string; + apiSecret: string; + threshold: number; + minFrames: number; + timeout: number; + maxRetries?: number; +} + +export interface HttpTransportState { + overlapSpeechStarted: boolean; + overlapSpeechStartedAt: number | undefined; + cache: BoundedCache; +} + +/** + * Creates an HTTP transport TransformStream for interruption detection. + * + * This transport receives Int16Array audio slices and outputs InterruptionEvents. + * Each audio slice triggers an HTTP POST request. + * + * @param options - Transport options object. This is read on each request, so mutations + * to threshold/minFrames will be picked up dynamically. 
+ */ +export function createHttpTransport( + options: HttpTransportOptions, + getState: () => HttpTransportState, + setState: (partial: Partial) => void, + updateUserSpeakingSpan?: (entry: InterruptionCacheEntry) => void, + getAndResetNumRequests?: () => number, +): TransformStream { + const logger = log(); + + return new TransformStream( + { + async transform(chunk, controller) { + if (!(chunk instanceof Int16Array)) { + controller.enqueue(chunk); + return; + } + + const state = getState(); + const overlapSpeechStartedAt = state.overlapSpeechStartedAt; + if (overlapSpeechStartedAt === undefined || !state.overlapSpeechStarted) return; + + try { + const resp = await predictHTTP( + chunk, + { threshold: options.threshold, minFrames: options.minFrames }, + { + baseUrl: options.baseUrl, + timeout: options.timeout, + maxRetries: options.maxRetries, + token: await createAccessToken(options.apiKey, options.apiSecret), + }, + ); + + const { createdAt, isBargein, probabilities, predictionDurationInS } = resp; + const entry = state.cache.setOrUpdate( + createdAt, + () => new InterruptionCacheEntry({ createdAt }), + { + probabilities, + isInterruption: isBargein, + speechInput: chunk, + totalDurationInS: (performance.now() - createdAt) / 1000, + detectionDelayInS: (Date.now() - overlapSpeechStartedAt) / 1000, + predictionDurationInS, + }, + ); + + if (state.overlapSpeechStarted && entry.isInterruption) { + if (updateUserSpeakingSpan) { + updateUserSpeakingSpan(entry); + } + const event: OverlappingSpeechEvent = { + type: 'user_overlapping_speech', + timestamp: Date.now(), + overlapStartedAt: overlapSpeechStartedAt, + isInterruption: entry.isInterruption, + speechInput: entry.speechInput, + probabilities: entry.probabilities, + totalDurationInS: entry.totalDurationInS, + predictionDurationInS: entry.predictionDurationInS, + detectionDelayInS: entry.detectionDelayInS, + probability: entry.probability, + numRequests: getAndResetNumRequests?.() ?? 
0, + }; + logger.debug( + { + detectionDelayInS: entry.detectionDelayInS, + totalDurationInS: entry.totalDurationInS, + }, + 'interruption detected', + ); + setState({ overlapSpeechStarted: false }); + controller.enqueue(event); + } + } catch (err) { + logger.error({ err }, 'Failed to send audio data over HTTP'); + } + }, + }, + { highWaterMark: 2 }, + { highWaterMark: 2 }, + ); +} diff --git a/agents/src/inference/interruption/interruption_cache_entry.ts b/agents/src/inference/interruption/interruption_cache_entry.ts new file mode 100644 index 000000000..f318a04f5 --- /dev/null +++ b/agents/src/inference/interruption/interruption_cache_entry.ts @@ -0,0 +1,50 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { estimateProbability } from './utils.js'; + +/** + * Typed cache entry for interruption inference results. + * Mutable to support setOrUpdate pattern from Python's _BoundedCache. + */ +export class InterruptionCacheEntry { + createdAt: number; + requestStartedAt?: number; + totalDurationInS: number; + predictionDurationInS: number; + detectionDelayInS: number; + speechInput?: Int16Array; + probabilities?: number[]; + isInterruption?: boolean; + + constructor(params: { + createdAt: number; + requestStartedAt?: number; + speechInput?: Int16Array; + totalDurationInS?: number; + predictionDurationInS?: number; + detectionDelayInS?: number; + probabilities?: number[]; + isInterruption?: boolean; + }) { + this.createdAt = params.createdAt; + this.requestStartedAt = params.requestStartedAt; + this.totalDurationInS = params.totalDurationInS ?? 0; + this.predictionDurationInS = params.predictionDurationInS ?? 0; + this.detectionDelayInS = params.detectionDelayInS ?? 0; + this.speechInput = params.speechInput; + this.probabilities = params.probabilities; + this.isInterruption = params.isInterruption; + } + + /** + * The conservative estimated probability of the interruption event. 
+ */ + get probability(): number { + return this.probabilities ? estimateProbability(this.probabilities) : 0; + } + + static default(): InterruptionCacheEntry { + return new InterruptionCacheEntry({ createdAt: 0 }); + } +} diff --git a/agents/src/inference/interruption/interruption_detector.ts b/agents/src/inference/interruption/interruption_detector.ts new file mode 100644 index 000000000..26793c4f0 --- /dev/null +++ b/agents/src/inference/interruption/interruption_detector.ts @@ -0,0 +1,188 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import type { TypedEventEmitter } from '@livekit/typed-emitter'; +import EventEmitter from 'events'; +import { log } from '../../log.js'; +import type { InterruptionMetrics } from '../../metrics/base.js'; +import { DEFAULT_INFERENCE_URL, STAGING_INFERENCE_URL, getDefaultInferenceUrl } from '../utils.js'; +import { FRAMES_PER_SECOND, SAMPLE_RATE, interruptionOptionDefaults } from './defaults.js'; +import type { InterruptionDetectionError } from './errors.js'; +import { InterruptionStreamBase } from './interruption_stream.js'; +import type { InterruptionOptions, OverlappingSpeechEvent } from './types.js'; + +type InterruptionCallbacks = { + user_overlapping_speech: (event: OverlappingSpeechEvent) => void; + metrics_collected: (metrics: InterruptionMetrics) => void; + error: (error: InterruptionDetectionError) => void; +}; + +export type AdaptiveInterruptionDetectorOptions = Omit, 'useProxy'>; + +export class AdaptiveInterruptionDetector extends (EventEmitter as new () => TypedEventEmitter) { + options: InterruptionOptions; + private readonly _label: string; + private logger = log(); + // Use Set instead of WeakSet to allow iteration for propagating option updates + private streams: Set = new Set(); + + constructor(options: AdaptiveInterruptionDetectorOptions = {}) { + super(); + + const { + maxAudioDurationInS, + baseUrl, + apiKey, + apiSecret, + audioPrefixDurationInS, + threshold, + 
detectionIntervalInS, + inferenceTimeout, + minInterruptionDurationInS, + } = { ...interruptionOptionDefaults, ...options }; + + if (maxAudioDurationInS > 3.0) { + throw new RangeError('maxAudioDurationInS must be less than or equal to 3.0 seconds'); + } + + const lkBaseUrl = baseUrl ?? process.env.LIVEKIT_REMOTE_EOT_URL ?? getDefaultInferenceUrl(); + let lkApiKey = apiKey ?? ''; + let lkApiSecret = apiSecret ?? ''; + let useProxy: boolean; + + // Use LiveKit credentials if using the inference service (production or staging) + const isInferenceUrl = + lkBaseUrl === DEFAULT_INFERENCE_URL || lkBaseUrl === STAGING_INFERENCE_URL; + if (isInferenceUrl) { + lkApiKey = + apiKey ?? process.env.LIVEKIT_INFERENCE_API_KEY ?? process.env.LIVEKIT_API_KEY ?? ''; + if (!lkApiKey) { + throw new TypeError( + 'apiKey is required, either as argument or set LIVEKIT_API_KEY environmental variable', + ); + } + + lkApiSecret = + apiSecret ?? + process.env.LIVEKIT_INFERENCE_API_SECRET ?? + process.env.LIVEKIT_API_SECRET ?? 
+ ''; + if (!lkApiSecret) { + throw new TypeError( + 'apiSecret is required, either as argument or set LIVEKIT_API_SECRET environmental variable', + ); + } + useProxy = true; + } else { + useProxy = false; + } + + this.options = { + sampleRate: SAMPLE_RATE, + threshold, + minFrames: Math.ceil(minInterruptionDurationInS * FRAMES_PER_SECOND), + maxAudioDurationInS, + audioPrefixDurationInS, + detectionIntervalInS, + inferenceTimeout, + baseUrl: lkBaseUrl, + apiKey: lkApiKey, + apiSecret: lkApiSecret, + useProxy, + minInterruptionDurationInS, + }; + + this._label = `${this.constructor.name}`; + + this.logger.debug( + { + baseUrl: this.options.baseUrl, + detectionIntervalInS: this.options.detectionIntervalInS, + audioPrefixDurationInS: this.options.audioPrefixDurationInS, + maxAudioDurationInS: this.options.maxAudioDurationInS, + minFrames: this.options.minFrames, + threshold: this.options.threshold, + inferenceTimeout: this.options.inferenceTimeout, + useProxy: this.options.useProxy, + }, + 'adaptive interruption detector initialized', + ); + } + + /** + * The model identifier for this detector. + */ + get model(): string { + return 'adaptive interruption'; + } + + /** + * The provider identifier for this detector. + */ + get provider(): string { + return 'livekit'; + } + + /** + * The label for this detector instance. + */ + get label(): string { + return this._label; + } + + /** + * The sample rate used for audio processing. + */ + get sampleRate(): number { + return this.options.sampleRate; + } + + /** + * Emit an error event from the detector. + */ + emitError(error: InterruptionDetectionError): void { + this.emit('error', error); + } + + /** + * Creates a new InterruptionStreamBase for internal use. + * The stream can receive audio frames and sentinels via pushFrame(). + * Use this when you need direct access to the stream for pushing frames. 
+ */ + createStream(): InterruptionStreamBase { + const streamBase = new InterruptionStreamBase(this, {}); + this.streams.add(streamBase); + return streamBase; + } + + /** + * Remove a stream from tracking (called when stream is closed). + */ + removeStream(stream: InterruptionStreamBase): void { + this.streams.delete(stream); + } + + /** + * Update options for the detector and propagate to all active streams. + * For WebSocket streams, this triggers a reconnection with new settings. + */ + async updateOptions(options: { + threshold?: number; + minInterruptionDurationInS?: number; + }): Promise { + if (options.threshold !== undefined) { + this.options.threshold = options.threshold; + } + if (options.minInterruptionDurationInS !== undefined) { + this.options.minInterruptionDurationInS = options.minInterruptionDurationInS; + this.options.minFrames = Math.ceil(options.minInterruptionDurationInS * FRAMES_PER_SECOND); + } + + // Propagate option updates to all active streams (matching Python behavior) + const updatePromises: Promise[] = []; + for (const stream of this.streams) { + updatePromises.push(stream.updateOptions(options)); + } + await Promise.all(updatePromises); + } +} diff --git a/agents/src/inference/interruption/interruption_stream.ts b/agents/src/inference/interruption/interruption_stream.ts new file mode 100644 index 000000000..ce45ae804 --- /dev/null +++ b/agents/src/inference/interruption/interruption_stream.ts @@ -0,0 +1,467 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 +import { AudioFrame, AudioResampler } from '@livekit/rtc-node'; +import type { Span } from '@opentelemetry/api'; +import { type ReadableStream, TransformStream } from 'stream/web'; +import { log } from '../../log.js'; +import type { InterruptionMetrics } from '../../metrics/base.js'; +import { type StreamChannel, createStreamChannel } from '../../stream/stream_channel.js'; +import { traceTypes } from '../../telemetry/index.js'; +import { FRAMES_PER_SECOND, apiConnectDefaults } from './defaults.js'; +import type { InterruptionDetectionError } from './errors.js'; +import { createHttpTransport } from './http_transport.js'; +import { InterruptionCacheEntry } from './interruption_cache_entry.js'; +import type { AdaptiveInterruptionDetector } from './interruption_detector.js'; +import { + type AgentSpeechEnded, + type AgentSpeechStarted, + type ApiConnectOptions, + type Flush, + type InterruptionOptions, + type InterruptionSentinel, + type OverlapSpeechEnded, + type OverlapSpeechStarted, + type OverlappingSpeechEvent, +} from './types.js'; +import { BoundedCache } from './utils.js'; +import { createWsTransport } from './ws_transport.js'; + +// Re-export sentinel types for backwards compatibility +export type { + AgentSpeechEnded, + AgentSpeechStarted, + ApiConnectOptions, + Flush, + InterruptionSentinel, + OverlapSpeechEnded, + OverlapSpeechStarted, +}; + +export class InterruptionStreamSentinel { + static agentSpeechStarted(): AgentSpeechStarted { + return { type: 'agent-speech-started' }; + } + + static agentSpeechEnded(): AgentSpeechEnded { + return { type: 'agent-speech-ended' }; + } + + static overlapSpeechStarted( + speechDuration: number, + startedAt: number, + userSpeakingSpan?: Span, + ): OverlapSpeechStarted { + return { type: 'overlap-speech-started', speechDuration, startedAt, userSpeakingSpan }; + } + + static overlapSpeechEnded(endedAt: number): OverlapSpeechEnded { + return { type: 'overlap-speech-ended', 
endedAt }; + } + + static flush(): Flush { + return { type: 'flush' }; + } +} + +function updateUserSpeakingSpan(span: Span, entry: InterruptionCacheEntry) { + span.setAttribute( + traceTypes.ATTR_IS_INTERRUPTION, + (entry.isInterruption ?? false).toString().toLowerCase(), + ); + span.setAttribute(traceTypes.ATTR_INTERRUPTION_PROBABILITY, entry.probability); + span.setAttribute(traceTypes.ATTR_INTERRUPTION_TOTAL_DURATION, entry.totalDurationInS); + span.setAttribute(traceTypes.ATTR_INTERRUPTION_PREDICTION_DURATION, entry.predictionDurationInS); + span.setAttribute(traceTypes.ATTR_INTERRUPTION_DETECTION_DELAY, entry.detectionDelayInS); +} + +export class InterruptionStreamBase { + private inputStream: StreamChannel; + + private eventStream: ReadableStream; + + private resampler?: AudioResampler; + + private numRequests = 0; + + private userSpeakingSpan: Span | undefined; + + private overlapSpeechStartedAt: number | undefined; + + private options: InterruptionOptions; + + private apiOptions: ApiConnectOptions; + + private model: AdaptiveInterruptionDetector; + + private logger = log(); + + // Store reconnect function for WebSocket transport + private wsReconnect?: () => Promise; + + // Mutable transport options that can be updated via updateOptions() + private transportOptions: { + baseUrl: string; + apiKey: string; + apiSecret: string; + sampleRate: number; + threshold: number; + minFrames: number; + timeout: number; + maxRetries: number; + }; + + constructor(model: AdaptiveInterruptionDetector, apiOptions: Partial) { + this.inputStream = createStreamChannel< + InterruptionSentinel | AudioFrame, + InterruptionDetectionError + >(); + + this.model = model; + this.options = { ...model.options }; + this.apiOptions = { ...apiConnectDefaults, ...apiOptions }; + + // Initialize mutable transport options + this.transportOptions = { + baseUrl: this.options.baseUrl, + apiKey: this.options.apiKey, + apiSecret: this.options.apiSecret, + sampleRate: this.options.sampleRate, + 
threshold: this.options.threshold, + minFrames: this.options.minFrames, + timeout: this.options.inferenceTimeout, + maxRetries: this.apiOptions.maxRetries, + }; + + this.eventStream = this.setupTransform(); + } + + /** + * Update stream options. For WebSocket transport, this triggers a reconnection. + */ + async updateOptions(options: { + threshold?: number; + minInterruptionDurationInS?: number; + }): Promise { + if (options.threshold !== undefined) { + this.options.threshold = options.threshold; + this.transportOptions.threshold = options.threshold; + } + if (options.minInterruptionDurationInS !== undefined) { + this.options.minInterruptionDurationInS = options.minInterruptionDurationInS; + this.options.minFrames = Math.ceil(options.minInterruptionDurationInS * FRAMES_PER_SECOND); + this.transportOptions.minFrames = this.options.minFrames; + } + // Trigger WebSocket reconnection if using proxy (WebSocket transport) + if (this.options.useProxy && this.wsReconnect) { + await this.wsReconnect(); + } + } + + private setupTransform(): ReadableStream { + let agentSpeechStarted = false; + let startIdx = 0; + let accumulatedSamples = 0; + let overlapSpeechStarted = false; + let overlapCount = 0; + const cache = new BoundedCache(10); + const inferenceS16Data = new Int16Array( + Math.ceil(this.options.maxAudioDurationInS * this.options.sampleRate), + ).fill(0); + + // State accessors for transport + const getState = () => ({ + overlapSpeechStarted, + overlapSpeechStartedAt: this.overlapSpeechStartedAt, + cache, + overlapCount, + }); + const setState = (partial: { overlapSpeechStarted?: boolean }) => { + if (partial.overlapSpeechStarted !== undefined) { + overlapSpeechStarted = partial.overlapSpeechStarted; + } + }; + const handleSpanUpdate = (entry: InterruptionCacheEntry) => { + if (this.userSpeakingSpan) { + updateUserSpeakingSpan(this.userSpeakingSpan, entry); + this.userSpeakingSpan = undefined; + } + }; + + const onRequestSent = () => { + this.numRequests++; + }; + + 
const getAndResetNumRequests = (): number => { + const n = this.numRequests; + this.numRequests = 0; + return n; + }; + + // First transform: process input frames/sentinels and output audio slices or events + const audioTransformer = new TransformStream< + InterruptionSentinel | AudioFrame, + Int16Array | OverlappingSpeechEvent + >( + { + transform: (chunk, controller) => { + if (chunk instanceof AudioFrame) { + if (!agentSpeechStarted) { + return; + } + if (this.options.sampleRate !== chunk.sampleRate) { + controller.error('the sample rate of the input frames must be consistent'); + this.logger.error('the sample rate of the input frames must be consistent'); + return; + } + const result = writeToInferenceS16Data( + chunk, + startIdx, + inferenceS16Data, + this.options.maxAudioDurationInS, + ); + startIdx = result.startIdx; + accumulatedSamples += result.samplesWritten; + + if ( + accumulatedSamples >= + Math.floor(this.options.detectionIntervalInS * this.options.sampleRate) && + overlapSpeechStarted + ) { + const audioSlice = inferenceS16Data.slice(0, startIdx); + accumulatedSamples = 0; + controller.enqueue(audioSlice); + } + } else if (chunk.type === 'agent-speech-started') { + this.logger.debug('agent speech started'); + agentSpeechStarted = true; + overlapSpeechStarted = false; + this.overlapSpeechStartedAt = undefined; + accumulatedSamples = 0; + overlapCount = 0; + startIdx = 0; + this.numRequests = 0; + cache.clear(); + } else if (chunk.type === 'agent-speech-ended') { + this.logger.debug('agent speech ended'); + agentSpeechStarted = false; + overlapSpeechStarted = false; + this.overlapSpeechStartedAt = undefined; + accumulatedSamples = 0; + overlapCount = 0; + startIdx = 0; + this.numRequests = 0; + cache.clear(); + } else if (chunk.type === 'overlap-speech-started' && agentSpeechStarted) { + this.overlapSpeechStartedAt = chunk.startedAt; + this.userSpeakingSpan = chunk.userSpeakingSpan; + this.logger.debug('overlap speech started, starting interruption 
inference'); + overlapSpeechStarted = true; + accumulatedSamples = 0; + overlapCount += 1; + if (overlapCount <= 1) { + const keepSize = + Math.round((chunk.speechDuration / 1000) * this.options.sampleRate) + + Math.round(this.options.audioPrefixDurationInS * this.options.sampleRate); + const shiftCount = Math.max(0, startIdx - keepSize); + inferenceS16Data.copyWithin(0, shiftCount, startIdx); + startIdx -= shiftCount; + } + cache.clear(); + } else if (chunk.type === 'overlap-speech-ended') { + this.logger.debug('overlap speech ended'); + if (overlapSpeechStarted) { + this.userSpeakingSpan = undefined; + let latestEntry = cache.pop( + (entry) => entry.totalDurationInS !== undefined && entry.totalDurationInS > 0, + ); + if (!latestEntry) { + this.logger.debug('no request made for overlap speech'); + latestEntry = InterruptionCacheEntry.default(); + } + const e = latestEntry ?? InterruptionCacheEntry.default(); + const event: OverlappingSpeechEvent = { + type: 'user_overlapping_speech', + timestamp: chunk.endedAt, + isInterruption: false, + overlapStartedAt: this.overlapSpeechStartedAt, + speechInput: e.speechInput, + probabilities: e.probabilities, + totalDurationInS: e.totalDurationInS, + detectionDelayInS: e.detectionDelayInS, + predictionDurationInS: e.predictionDurationInS, + probability: e.probability, + numRequests: getAndResetNumRequests(), + }; + controller.enqueue(event); + overlapSpeechStarted = false; + accumulatedSamples = 0; + } + this.overlapSpeechStartedAt = undefined; + } else if (chunk.type === 'flush') { + // no-op + } + }, + }, + { highWaterMark: 32 }, + { highWaterMark: 32 }, + ); + + // Second transform: transport layer (HTTP or WebSocket based on useProxy) + const transportOptions = this.transportOptions; + + let transport: TransformStream; + if (this.options.useProxy) { + const wsResult = createWsTransport( + transportOptions, + getState, + setState, + handleSpanUpdate, + onRequestSent, + getAndResetNumRequests, + ); + transport = 
wsResult.transport; + this.wsReconnect = wsResult.reconnect; + } else { + transport = createHttpTransport( + transportOptions, + getState, + setState, + handleSpanUpdate, + getAndResetNumRequests, + ); + } + + const eventEmitter = new TransformStream({ + transform: (chunk, controller) => { + this.model.emit('user_overlapping_speech', chunk); + + const metrics: InterruptionMetrics = { + type: 'interruption_metrics', + timestamp: chunk.timestamp, + totalDuration: chunk.totalDurationInS * 1000, + predictionDuration: chunk.predictionDurationInS * 1000, + detectionDelay: chunk.detectionDelayInS * 1000, + numInterruptions: chunk.isInterruption ? 1 : 0, + numBackchannels: chunk.isInterruption ? 0 : 1, + numRequests: chunk.numRequests, + metadata: { + modelProvider: this.model.provider, + modelName: this.model.model, + }, + }; + this.model.emit('metrics_collected', metrics); + + controller.enqueue(chunk); + }, + }); + + // Pipeline: input -> audioTransformer -> transport -> eventEmitter -> eventStream + return this.inputStream + .stream() + .pipeThrough(audioTransformer) + .pipeThrough(transport) + .pipeThrough(eventEmitter); + } + + private ensureInputNotEnded() { + if (this.inputStream.closed) { + throw new Error('input stream is closed'); + } + } + + private ensureStreamsNotEnded() { + this.ensureInputNotEnded(); + } + + private getResamplerFor(inputSampleRate: number): AudioResampler { + if (!this.resampler) { + this.resampler = new AudioResampler(inputSampleRate, this.options.sampleRate); + } + return this.resampler; + } + + stream(): ReadableStream { + return this.eventStream; + } + + async pushFrame(frame: InterruptionSentinel | AudioFrame): Promise { + this.ensureStreamsNotEnded(); + if (!(frame instanceof AudioFrame)) { + return this.inputStream.write(frame); + } else if (this.options.sampleRate !== frame.sampleRate) { + const resampler = this.getResamplerFor(frame.sampleRate); + if (resampler.inputRate !== frame.sampleRate) { + throw new Error('the sample rate of 
the input frames must be consistent'); + } + for (const resampledFrame of resampler.push(frame)) { + await this.inputStream.write(resampledFrame); + } + } else { + await this.inputStream.write(frame); + } + } + + async flush(): Promise { + this.ensureStreamsNotEnded(); + await this.inputStream.write(InterruptionStreamSentinel.flush()); + } + + async endInput(): Promise { + await this.flush(); + await this.inputStream.close(); + } + + async close(): Promise { + if (!this.inputStream.closed) await this.inputStream.close(); + this.model.removeStream(this); + } +} + +/** + * Write the audio frame to the output data array and return the new start index + * and the number of samples written. + */ +function writeToInferenceS16Data( + frame: AudioFrame, + startIdx: number, + outData: Int16Array, + maxAudioDuration: number, +): { startIdx: number; samplesWritten: number } { + const maxWindowSize = Math.floor(maxAudioDuration * frame.sampleRate); + + if (frame.samplesPerChannel > outData.length) { + throw new Error('frame samples are greater than the max window size'); + } + + // Shift the data to the left if the window would overflow + const shift = startIdx + frame.samplesPerChannel - maxWindowSize; + if (shift > 0) { + outData.copyWithin(0, shift, startIdx); + startIdx -= shift; + } + + // Get the frame data as Int16Array + const frameData = new Int16Array( + frame.data.buffer, + frame.data.byteOffset, + frame.samplesPerChannel * frame.channels, + ); + + if (frame.channels > 1) { + // Mix down multiple channels to mono by averaging + for (let i = 0; i < frame.samplesPerChannel; i++) { + let sum = 0; + for (let ch = 0; ch < frame.channels; ch++) { + sum += frameData[i * frame.channels + ch] ?? 
0; + } + outData[startIdx + i] = Math.floor(sum / frame.channels); + } + } else { + // Single channel - copy directly + outData.set(frameData, startIdx); + } + + startIdx += frame.samplesPerChannel; + return { startIdx, samplesWritten: frame.samplesPerChannel }; +} diff --git a/agents/src/inference/interruption/types.ts b/agents/src/inference/interruption/types.ts new file mode 100644 index 000000000..a62596030 --- /dev/null +++ b/agents/src/inference/interruption/types.ts @@ -0,0 +1,84 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import type { Span } from '@opentelemetry/api'; + +export interface OverlappingSpeechEvent { + type: 'user_overlapping_speech'; + timestamp: number; + isInterruption: boolean; + totalDurationInS: number; + predictionDurationInS: number; + detectionDelayInS: number; + overlapStartedAt?: number; + speechInput?: Int16Array; + probabilities?: number[]; + probability: number; + numRequests: number; +} + +/** + * Configuration options for interruption detection. + */ +export interface InterruptionOptions { + sampleRate: number; + threshold: number; + minFrames: number; + maxAudioDurationInS: number; + audioPrefixDurationInS: number; + detectionIntervalInS: number; + inferenceTimeout: number; + minInterruptionDurationInS: number; + baseUrl: string; + apiKey: string; + apiSecret: string; + useProxy: boolean; +} + +/** + * API connection options for transport layers. + */ +export interface ApiConnectOptions { + maxRetries: number; + retryInterval: number; + timeout: number; +} + +// Sentinel types for stream control signals + +export interface AgentSpeechStarted { + type: 'agent-speech-started'; +} + +export interface AgentSpeechEnded { + type: 'agent-speech-ended'; +} + +export interface OverlapSpeechStarted { + type: 'overlap-speech-started'; + /** Duration of the speech segment in milliseconds (matches VADEvent.speechDuration units). 
*/ + speechDuration: number; + /** Absolute timestamp (ms) when overlap speech started, computed at call-site. */ + startedAt: number; + userSpeakingSpan?: Span; +} + +export interface OverlapSpeechEnded { + type: 'overlap-speech-ended'; + /** Absolute timestamp (ms) when overlap speech ended, used as the non-interruption event timestamp. */ + endedAt: number; +} + +export interface Flush { + type: 'flush'; +} + +/** + * Union type for all stream control signals. + */ +export type InterruptionSentinel = + | AgentSpeechStarted + | AgentSpeechEnded + | OverlapSpeechStarted + | OverlapSpeechEnded + | Flush; diff --git a/agents/src/inference/interruption/utils.test.ts b/agents/src/inference/interruption/utils.test.ts new file mode 100644 index 000000000..79b585fe6 --- /dev/null +++ b/agents/src/inference/interruption/utils.test.ts @@ -0,0 +1,132 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { describe, expect, it, vi } from 'vitest'; +import { BoundedCache } from './utils.js'; + +class Entry { + createdAt: number; + totalDurationInS: number | undefined = undefined; + predictionDurationInS: number | undefined = undefined; + note: string | undefined = undefined; + + constructor(createdAt: number, note?: string) { + this.createdAt = createdAt; + this.note = note; + } +} + +describe('BoundedCache', () => { + it('evicts oldest entry when maxLen is exceeded', () => { + const cache = new BoundedCache(2); + cache.set(1, new Entry(1)); + cache.set(2, new Entry(2)); + cache.set(3, new Entry(3)); + + expect(cache.size).toBe(2); + expect([...cache.keys()]).toEqual([2, 3]); + expect(cache.get(1)).toBeUndefined(); + expect(cache.get(2)!.createdAt).toBe(2); + expect(cache.get(3)!.createdAt).toBe(3); + }); + + it('setOrUpdate creates a value via factory when key is missing', () => { + const cache = new BoundedCache(10); + const factory = vi.fn(() => new Entry(100)); + + const value = cache.setOrUpdate(1, factory, { 
predictionDurationInS: 0.42 }); + + expect(factory).toHaveBeenCalledTimes(1); + expect(value.createdAt).toBe(100); + expect(value.predictionDurationInS).toBe(0.42); + expect(cache.get(1)?.predictionDurationInS).toBe(0.42); + }); + + it('setOrUpdate updates existing value and does not call factory', () => { + const cache = new BoundedCache(10); + cache.set(1, new Entry(1, 'before')); + const factory = vi.fn(() => new Entry(999)); + + const value = cache.setOrUpdate(1, factory, { note: 'after', totalDurationInS: 1.5 }); + + expect(factory).not.toHaveBeenCalled(); + expect(value.createdAt).toBe(1); + expect(value.note).toBe('after'); + expect(value.totalDurationInS).toBe(1.5); + }); + + it('updateValue returns undefined for missing key', () => { + const cache = new BoundedCache(10); + const result = cache.updateValue(404, { note: 'missing' }); + + expect(result).toBeUndefined(); + }); + + it('updateValue ignores undefined fields', () => { + const cache = new BoundedCache(10); + cache.set(1, new Entry(1, 'keep')); + + const result = cache.updateValue(1, { + note: undefined, + predictionDurationInS: 0.1, + }); + + expect(result?.createdAt).toBe(1); + expect(result?.note).toBe('keep'); + expect(result?.predictionDurationInS).toBe(0.1); + }); + + it('pop without predicate removes the oldest entry (python parity)', () => { + const cache = new BoundedCache(10); + cache.set(1, new Entry(1)); + cache.set(2, new Entry(2)); + cache.set(3, new Entry(3)); + + const popped = cache.pop(); + + expect(popped?.createdAt).toBe(1); + expect([...cache.keys()]).toEqual([2, 3]); + }); + + it('pop with predicate removes the most recent matching entry', () => { + const cache = new BoundedCache(10); + const e1 = new Entry(1); + e1.totalDurationInS = 0; + const e2 = new Entry(2); + e2.totalDurationInS = 1; + const e3 = new Entry(3); + e3.totalDurationInS = 2; + cache.set(1, e1); + cache.set(2, e2); + cache.set(3, e3); + + const popped = cache.pop((entry) => (entry.totalDurationInS ?? 
0) > 0); + + expect(popped?.createdAt).toBe(3); + expect(popped?.totalDurationInS).toBe(2); + expect([...cache.keys()]).toEqual([1, 2]); + }); + + it('pop with predicate returns undefined when no match exists', () => { + const cache = new BoundedCache(10); + const e1 = new Entry(1); + e1.totalDurationInS = 0; + cache.set(1, e1); + + const popped = cache.pop((entry) => (entry.totalDurationInS ?? 0) > 10); + + expect(popped).toBeUndefined(); + expect(cache.size).toBe(1); + }); + + it('clear removes all entries', () => { + const cache = new BoundedCache(10); + cache.set(1, new Entry(1)); + cache.set(2, new Entry(2)); + + cache.clear(); + + expect(cache.size).toBe(0); + expect([...cache.keys()]).toEqual([]); + }); +}); diff --git a/agents/src/inference/interruption/utils.ts b/agents/src/inference/interruption/utils.ts new file mode 100644 index 000000000..e614f3b6d --- /dev/null +++ b/agents/src/inference/interruption/utils.ts @@ -0,0 +1,137 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { FRAME_DURATION_IN_S, MIN_INTERRUPTION_DURATION_IN_S } from './defaults.js'; + +/** + * A bounded cache that automatically evicts the oldest entries when the cache exceeds max size. + * Uses FIFO eviction strategy. + */ +export class BoundedCache { + private cache: Map = new Map(); + private readonly maxLen: number; + + constructor(maxLen: number = 10) { + this.maxLen = maxLen; + } + + set(key: K, value: V): void { + this.cache.set(key, value); + if (this.cache.size > this.maxLen) { + // Remove the oldest entry (first inserted) + const firstKey = this.cache.keys().next().value as K; + this.cache.delete(firstKey); + } + } + + /** + * Update existing value fields if present and defined. + * Mirrors python BoundedDict.update_value behavior. 
+ */ + updateValue(key: K, fields: Partial): V | undefined { + const value = this.cache.get(key); + if (!value) return value; + + for (const [fieldName, fieldValue] of Object.entries(fields) as [keyof V, V[keyof V]][]) { + if (fieldValue === undefined) continue; + // Runtime field update parity with python's hasattr + setattr. + if (fieldName in (value as object)) { + (value as Record)[String(fieldName)] = fieldValue; + } + } + return value; + } + + /** + * Set a new value with factory when missing; otherwise update in place. + * Mirrors python BoundedDict.set_or_update behavior. + */ + setOrUpdate(key: K, factory: () => V, fields: Partial): V { + if (!this.cache.has(key)) { + this.set(key, factory()); + } + const result = this.updateValue(key, fields); + if (!result) { + throw new Error('setOrUpdate invariant failed: entry should exist after set'); + } + return result; + } + + get(key: K): V | undefined { + return this.cache.get(key); + } + + has(key: K): boolean { + return this.cache.has(key); + } + + delete(key: K): boolean { + return this.cache.delete(key); + } + + /** + * Pop an entry if it satisfies the predicate. 
+ * - No predicate: pop oldest (FIFO) + * - With predicate: search in reverse order and pop first match + */ + pop(predicate?: (value: V) => boolean): V | undefined { + if (predicate === undefined) { + const first = this.cache.entries().next().value as [K, V] | undefined; + if (!first) return undefined; + const [key, value] = first; + this.cache.delete(key); + return value; + } + + const keys = Array.from(this.cache.keys()); + for (let i = keys.length - 1; i >= 0; i--) { + const key = keys[i]!; + const value = this.cache.get(key)!; + if (predicate(value)) { + this.cache.delete(key); + return value; + } + } + return undefined; + } + + clear(): void { + this.cache.clear(); + } + + get size(): number { + return this.cache.size; + } + + values(): IterableIterator { + return this.cache.values(); + } + + keys(): IterableIterator { + return this.cache.keys(); + } + + entries(): IterableIterator<[K, V]> { + return this.cache.entries(); + } +} + +/** + * Estimate probability by finding the n-th maximum value in the probabilities array. + * The n-th position is determined by the window size (25ms per frame). + * Returns 0 if there are insufficient probabilities. + */ +export function estimateProbability( + probabilities: number[], + windowSizeInS: number = MIN_INTERRUPTION_DURATION_IN_S, +): number { + const nTh = Math.ceil(windowSizeInS / FRAME_DURATION_IN_S); + if (probabilities.length < nTh) { + return 0; + } + + // Find the n-th maximum value by sorting in descending order + // Create a copy to avoid mutating the original array + const sorted = [...probabilities].sort((a, b) => b - a); + return sorted[nTh - 1]!; +} diff --git a/agents/src/inference/interruption/ws_transport.ts b/agents/src/inference/interruption/ws_transport.ts new file mode 100644 index 000000000..e497bb1d1 --- /dev/null +++ b/agents/src/inference/interruption/ws_transport.ts @@ -0,0 +1,402 @@ +// SPDX-FileCopyrightText: 2025 LiveKit, Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 +import { TransformStream } from 'stream/web'; +import WebSocket from 'ws'; +import { z } from 'zod'; +import { log } from '../../log.js'; +import { createAccessToken } from '../utils.js'; +import { intervalForRetry } from './defaults.js'; +import { InterruptionCacheEntry } from './interruption_cache_entry.js'; +import type { OverlappingSpeechEvent } from './types.js'; +import type { BoundedCache } from './utils.js'; + +// WebSocket message types +const MSG_SESSION_CREATE = 'session.create'; +const MSG_SESSION_CLOSE = 'session.close'; +const MSG_SESSION_CREATED = 'session.created'; +const MSG_SESSION_CLOSED = 'session.closed'; +const MSG_INTERRUPTION_DETECTED = 'bargein_detected'; +const MSG_INFERENCE_DONE = 'inference_done'; +const MSG_ERROR = 'error'; + +export interface WsTransportOptions { + baseUrl: string; + apiKey: string; + apiSecret: string; + sampleRate: number; + threshold: number; + minFrames: number; + timeout: number; + maxRetries?: number; +} + +export interface WsTransportState { + overlapSpeechStarted: boolean; + overlapSpeechStartedAt: number | undefined; + cache: BoundedCache; +} + +const wsMessageSchema = z.discriminatedUnion('type', [ + z.object({ + type: z.literal(MSG_SESSION_CREATED), + }), + z.object({ + type: z.literal(MSG_SESSION_CLOSED), + }), + z.object({ + type: z.literal(MSG_INTERRUPTION_DETECTED), + created_at: z.number(), + probabilities: z.array(z.number()).default([]), + prediction_duration: z.number().default(0), + }), + z.object({ + type: z.literal(MSG_INFERENCE_DONE), + created_at: z.number(), + probabilities: z.array(z.number()).default([]), + prediction_duration: z.number().default(0), + is_bargein: z.boolean().optional(), + }), + z.object({ + type: z.literal(MSG_ERROR), + message: z.string(), + code: z.number().optional(), + session_id: z.string().optional(), + }), +]); + +type WsMessage = z.infer; + +/** + * Creates a WebSocket connection and waits for it to open. 
+ */ +async function connectWebSocket(options: WsTransportOptions): Promise { + const baseUrl = options.baseUrl.replace(/^http/, 'ws'); + const token = await createAccessToken(options.apiKey, options.apiSecret); + const url = `${baseUrl}/bargein`; + + const ws = new WebSocket(url, { + headers: { Authorization: `Bearer ${token}` }, + }); + + await new Promise((resolve, reject) => { + const timeout = setTimeout(() => { + ws.terminate(); + reject(new Error('WebSocket connection timeout')); + }, options.timeout); + ws.once('open', () => { + clearTimeout(timeout); + resolve(); + }); + ws.once('error', (err: Error) => { + clearTimeout(timeout); + ws.terminate(); + reject(err); + }); + }); + + return ws; +} + +export interface WsTransportResult { + transport: TransformStream; + reconnect: () => Promise; +} + +/** + * Creates a WebSocket transport TransformStream for interruption detection. + * + * This transport receives Int16Array audio slices and outputs InterruptionEvents. + * It maintains a persistent WebSocket connection with automatic retry on failure. + * Returns both the transport and a reconnect function for option updates. 
+ */ +export function createWsTransport( + options: WsTransportOptions, + getState: () => WsTransportState, + setState: (partial: Partial) => void, + updateUserSpeakingSpan?: (entry: InterruptionCacheEntry) => void, + onRequestSent?: () => void, + getAndResetNumRequests?: () => number, +): WsTransportResult { + const logger = log(); + let ws: WebSocket | null = null; + let outputController: TransformStreamDefaultController | null = null; + + function setupMessageHandler(socket: WebSocket): void { + socket.on('message', (data: WebSocket.Data) => { + try { + const message = wsMessageSchema.parse(JSON.parse(data.toString())); + handleMessage(message); + } catch { + logger.warn({ data: data.toString() }, 'Failed to parse WebSocket message'); + } + }); + + socket.on('error', (err: Error) => { + logger.error({ err }, 'WebSocket error'); + }); + + socket.on('close', (code: number, reason: Buffer) => { + logger.debug({ code, reason: reason.toString() }, 'WebSocket closed'); + }); + } + + async function ensureConnection(): Promise { + if (ws && ws.readyState === WebSocket.OPEN) return; + + const maxRetries = options.maxRetries ?? 3; + let lastError: Error | null = null; + + for (let attempt = 0; attempt <= maxRetries; attempt++) { + try { + ws = await connectWebSocket(options); + setupMessageHandler(ws); + + // Send session.create message + const sessionCreateMsg = JSON.stringify({ + type: MSG_SESSION_CREATE, + settings: { + sample_rate: options.sampleRate, + num_channels: 1, + threshold: options.threshold, + min_frames: options.minFrames, + encoding: 's16le', + }, + }); + ws.send(sessionCreateMsg); + return; + } catch (err) { + lastError = err instanceof Error ? err : new Error(String(err)); + if (attempt < maxRetries) { + const delay = intervalForRetry(attempt); + logger.debug( + { attempt, delay, err: lastError.message }, + 'WebSocket connection failed, retrying', + ); + await new Promise((resolve) => setTimeout(resolve, delay)); + } + } + } + + throw lastError ?? 
new Error('Failed to connect to WebSocket after retries'); + } + + function handleMessage(message: WsMessage): void { + const state = getState(); + + switch (message.type) { + case MSG_SESSION_CREATED: + logger.debug('WebSocket session created'); + break; + + case MSG_INTERRUPTION_DETECTED: { + const createdAt = message.created_at; + const overlapSpeechStartedAt = state.overlapSpeechStartedAt; + if (state.overlapSpeechStarted && overlapSpeechStartedAt !== undefined) { + const existing = state.cache.get(createdAt); + + const totalDurationInS = + existing?.requestStartedAt !== undefined + ? (performance.now() - existing.requestStartedAt) / 1000 + : (performance.now() - createdAt) / 1000; + + const entry = state.cache.setOrUpdate( + createdAt, + () => new InterruptionCacheEntry({ createdAt }), + { + speechInput: existing?.speechInput, + requestStartedAt: existing?.requestStartedAt, + totalDurationInS, + probabilities: message.probabilities, + isInterruption: true, + predictionDurationInS: message.prediction_duration, + detectionDelayInS: (Date.now() - overlapSpeechStartedAt) / 1000, + }, + ); + + if (updateUserSpeakingSpan) { + updateUserSpeakingSpan(entry); + } + + logger.debug( + { + totalDuration: entry.totalDurationInS, + predictionDuration: entry.predictionDurationInS, + detectionDelay: entry.detectionDelayInS, + probability: entry.probability, + }, + 'interruption detected', + ); + + const event: OverlappingSpeechEvent = { + type: 'user_overlapping_speech', + timestamp: Date.now(), + isInterruption: true, + totalDurationInS: entry.totalDurationInS, + predictionDurationInS: entry.predictionDurationInS, + overlapStartedAt: overlapSpeechStartedAt, + speechInput: entry.speechInput, + probabilities: entry.probabilities, + detectionDelayInS: entry.detectionDelayInS, + probability: entry.probability, + numRequests: getAndResetNumRequests?.() ?? 
0, + }; + + outputController?.enqueue(event); + setState({ overlapSpeechStarted: false }); + } + break; + } + + case MSG_INFERENCE_DONE: { + const createdAt = message.created_at; + const overlapSpeechStartedAt = state.overlapSpeechStartedAt; + if (state.overlapSpeechStarted && overlapSpeechStartedAt !== undefined) { + const existing = state.cache.get(createdAt); + const totalDurationInS = + existing?.requestStartedAt !== undefined + ? (performance.now() - existing.requestStartedAt) / 1000 + : (performance.now() - createdAt) / 1000; + const entry = state.cache.setOrUpdate( + createdAt, + () => new InterruptionCacheEntry({ createdAt }), + { + speechInput: existing?.speechInput, + requestStartedAt: existing?.requestStartedAt, + totalDurationInS, + predictionDurationInS: message.prediction_duration, + probabilities: message.probabilities, + isInterruption: message.is_bargein ?? false, + detectionDelayInS: (Date.now() - overlapSpeechStartedAt) / 1000, + }, + ); + + logger.debug( + { + totalDurationInS: entry.totalDurationInS, + predictionDurationInS: entry.predictionDurationInS, + }, + 'interruption inference done', + ); + } + break; + } + + case MSG_SESSION_CLOSED: + logger.debug('WebSocket session closed'); + break; + + case MSG_ERROR: + outputController?.error( + new Error( + `LiveKit Adaptive Interruption error${ + message.code !== undefined ? 
` (${message.code})` : '' + }: ${message.message}`, + ), + ); + break; + } + } + + function sendAudioData(audioSlice: Int16Array): void { + if (!ws || ws.readyState !== WebSocket.OPEN) { + throw new Error('WebSocket not connected'); + } + + const state = getState(); + // Use truncated timestamp consistently for both cache key and header + // This ensures the server's response created_at matches our cache key + const createdAt = Math.floor(performance.now()); + + // Store the audio data in cache with truncated timestamp + state.cache.set( + createdAt, + new InterruptionCacheEntry({ + createdAt, + requestStartedAt: performance.now(), + speechInput: audioSlice, + }), + ); + + // Create header: 8-byte little-endian uint64 timestamp (milliseconds as integer) + const header = new ArrayBuffer(8); + const view = new DataView(header); + view.setUint32(0, createdAt >>> 0, true); + view.setUint32(4, Math.floor(createdAt / 0x100000000) >>> 0, true); + + // Combine header and audio data + const audioBytes = new Uint8Array( + audioSlice.buffer, + audioSlice.byteOffset, + audioSlice.byteLength, + ); + const combined = new Uint8Array(8 + audioBytes.length); + combined.set(new Uint8Array(header), 0); + combined.set(audioBytes, 8); + + try { + ws.send(combined); + onRequestSent?.(); + } catch (e: unknown) { + logger.error(e, `failed to send audio via websocket`); + } + } + + function close(): void { + if (ws?.readyState === WebSocket.OPEN) { + const closeMsg = JSON.stringify({ type: MSG_SESSION_CLOSE }); + try { + ws.send(closeMsg); + } catch (e: unknown) { + logger.error(e, 'failed to send close message'); + } + } + ws?.close(1000); // signal normal websocket closure + ws = null; + } + + /** + * Reconnect the WebSocket with updated options. + * This is called when options are updated via updateOptions(). 
+ */ + async function reconnect(): Promise { + close(); + } + + const transport = new TransformStream< + Int16Array | OverlappingSpeechEvent, + OverlappingSpeechEvent + >( + { + async start(controller) { + outputController = controller; + await ensureConnection(); + }, + + transform(chunk, controller) { + if (!(chunk instanceof Int16Array)) { + controller.enqueue(chunk); + return; + } + + // Only forwards buffered audio while overlap speech is actively on. + const state = getState(); + if (!state.overlapSpeechStartedAt || !state.overlapSpeechStarted) return; + + try { + sendAudioData(chunk); + } catch (err) { + logger.error({ err }, 'Failed to send audio data over WebSocket'); + } + }, + + flush() { + close(); + }, + }, + { highWaterMark: 2 }, + { highWaterMark: 2 }, + ); + + return { transport, reconnect }; +} diff --git a/agents/src/inference/llm.ts b/agents/src/inference/llm.ts index c612b1654..d33d67d27 100644 --- a/agents/src/inference/llm.ts +++ b/agents/src/inference/llm.ts @@ -2,19 +2,12 @@ // // SPDX-License-Identifier: Apache-2.0 import OpenAI from 'openai'; -import { - APIConnectionError, - APIStatusError, - APITimeoutError, - DEFAULT_API_CONNECT_OPTIONS, - type Expand, - toError, -} from '../index.js'; +import { APIConnectionError, APIStatusError, APITimeoutError } from '../_exceptions.js'; import * as llm from '../llm/index.js'; +import { DEFAULT_API_CONNECT_OPTIONS } from '../types.js'; import type { APIConnectOptions } from '../types.js'; -import { type AnyString, createAccessToken } from './utils.js'; - -const DEFAULT_BASE_URL = 'https://agent-gateway.livekit.cloud/v1'; +import { type Expand, toError } from '../utils.js'; +import { type AnyString, createAccessToken, getDefaultInferenceUrl } from './utils.js'; export type OpenAIModels = | 'openai/gpt-5.2' @@ -127,7 +120,7 @@ export class LLM extends llm.LLM { strictToolSchema = false, } = opts; - const lkBaseURL = baseURL || process.env.LIVEKIT_INFERENCE_URL || DEFAULT_BASE_URL; + const lkBaseURL = 
baseURL || getDefaultInferenceUrl(); const lkApiKey = apiKey || process.env.LIVEKIT_INFERENCE_API_KEY || process.env.LIVEKIT_API_KEY; if (!lkApiKey) { throw new Error('apiKey is required: pass apiKey or set LIVEKIT_API_KEY'); @@ -163,6 +156,10 @@ export class LLM extends llm.LLM { return this.opts.model; } + get provider(): string { + return 'livekit'; + } + static fromModelString(modelString: string): LLM { return new LLM({ model: modelString }); } diff --git a/agents/src/inference/stt.ts b/agents/src/inference/stt.ts index eb5479db6..a31aabdc4 100644 --- a/agents/src/inference/stt.ts +++ b/agents/src/inference/stt.ts @@ -22,7 +22,7 @@ import { type SttTranscriptEvent, sttServerEventSchema, } from './api_protos.js'; -import { type AnyString, connectWs, createAccessToken } from './utils.js'; +import { type AnyString, connectWs, createAccessToken, getDefaultInferenceUrl } from './utils.js'; export type DeepgramModels = | 'deepgram/flux-general' @@ -151,7 +151,6 @@ export type STTEncoding = 'pcm_s16le'; const DEFAULT_ENCODING: STTEncoding = 'pcm_s16le'; const DEFAULT_SAMPLE_RATE = 16000; -const DEFAULT_BASE_URL = 'wss://agent-gateway.livekit.cloud/v1'; const DEFAULT_CANCEL_TIMEOUT = 5000; export interface InferenceSTTOptions { @@ -203,7 +202,7 @@ export class STT extends BaseSTT { connOptions, } = opts || {}; - const lkBaseURL = baseURL || process.env.LIVEKIT_INFERENCE_URL || DEFAULT_BASE_URL; + const lkBaseURL = baseURL || getDefaultInferenceUrl(); const lkApiKey = apiKey || process.env.LIVEKIT_INFERENCE_API_KEY || process.env.LIVEKIT_API_KEY; if (!lkApiKey) { throw new Error('apiKey is required: pass apiKey or set LIVEKIT_API_KEY'); @@ -253,6 +252,14 @@ export class STT extends BaseSTT { return 'inference.STT'; } + get model(): string { + return this.opts.model ?? 
'auto'; + } + + get provider(): string { + return 'livekit'; + } + static fromModelString(modelString: string): STT { const [model, language] = parseSTTModelString(modelString); return new STT({ model, language }); diff --git a/agents/src/inference/tts.ts b/agents/src/inference/tts.ts index 83f62db17..07e444587 100644 --- a/agents/src/inference/tts.ts +++ b/agents/src/inference/tts.ts @@ -19,7 +19,7 @@ import { ttsClientEventSchema, ttsServerEventSchema, } from './api_protos.js'; -import { type AnyString, connectWs, createAccessToken } from './utils.js'; +import { type AnyString, connectWs, createAccessToken, getDefaultInferenceUrl } from './utils.js'; export type CartesiaModels = | 'cartesia/sonic-3' @@ -136,7 +136,6 @@ type TTSEncoding = 'pcm_s16le'; const DEFAULT_ENCODING: TTSEncoding = 'pcm_s16le'; const DEFAULT_SAMPLE_RATE = 16000; -const DEFAULT_BASE_URL = 'https://agent-gateway.livekit.cloud/v1'; const NUM_CHANNELS = 1; const DEFAULT_LANGUAGE = 'en'; @@ -193,7 +192,7 @@ export class TTS extends BaseTTS { connOptions, } = opts || {}; - const lkBaseURL = baseURL || process.env.LIVEKIT_INFERENCE_URL || DEFAULT_BASE_URL; + const lkBaseURL = baseURL || getDefaultInferenceUrl(); const lkApiKey = apiKey || process.env.LIVEKIT_INFERENCE_API_KEY || process.env.LIVEKIT_API_KEY; if (!lkApiKey) { throw new Error('apiKey is required: pass apiKey or set LIVEKIT_API_KEY'); @@ -254,6 +253,14 @@ export class TTS extends BaseTTS { return 'inference.TTS'; } + get model(): string { + return this.opts.model ?? 
'unknown'; + } + + get provider(): string { + return 'livekit'; + } + static fromModelString(modelString: string): TTS { const [model, voice] = parseTTSModelString(modelString); return new TTS({ model, voice: voice || undefined }); diff --git a/agents/src/inference/utils.ts b/agents/src/inference/utils.ts index b3b772ef6..a80017c0b 100644 --- a/agents/src/inference/utils.ts +++ b/agents/src/inference/utils.ts @@ -3,10 +3,38 @@ // SPDX-License-Identifier: Apache-2.0 import { AccessToken } from 'livekit-server-sdk'; import { WebSocket } from 'ws'; -import { APIConnectionError, APIStatusError } from '../index.js'; +import { APIConnectionError, APIStatusError } from '../_exceptions.js'; export type AnyString = string & NonNullable; +/** Default production inference URL */ +export const DEFAULT_INFERENCE_URL = 'https://agent-gateway.livekit.cloud/v1'; + +/** Staging inference URL */ +export const STAGING_INFERENCE_URL = 'https://agent-gateway.staging.livekit.cloud/v1'; + +/** + * Get the default inference URL based on the environment. + * + * Priority: + * 1. LIVEKIT_INFERENCE_URL if set + * 2. If LIVEKIT_URL contains '.staging.livekit.cloud', use staging gateway + * 3. 
Otherwise, use production gateway + */ +export function getDefaultInferenceUrl(): string { + const inferenceUrl = process.env.LIVEKIT_INFERENCE_URL; + if (inferenceUrl) { + return inferenceUrl; + } + + const livekitUrl = process.env.LIVEKIT_URL || ''; + if (livekitUrl.includes('.staging.livekit.cloud')) { + return STAGING_INFERENCE_URL; + } + + return DEFAULT_INFERENCE_URL; +} + export async function createAccessToken( apiKey: string, apiSecret: string, diff --git a/agents/src/llm/chat_context.ts b/agents/src/llm/chat_context.ts index 5ac15c0c7..114b4ac75 100644 --- a/agents/src/llm/chat_context.ts +++ b/agents/src/llm/chat_context.ts @@ -81,6 +81,17 @@ export function createAudioContent(params: { }; } +export interface MetricsReport { + startedSpeakingAt?: number; + stoppedSpeakingAt?: number; + transcriptionDelay?: number; + endOfTurnDelay?: number; + onUserTurnCompletedDelay?: number; + llmNodeTtft?: number; + ttsNodeTtfb?: number; + e2eLatency?: number; +} + export class ChatMessage { readonly id: string; @@ -92,18 +103,24 @@ export class ChatMessage { interrupted: boolean; + transcriptConfidence?: number; + + extra: Record; + + metrics: MetricsReport; + hash?: Uint8Array; createdAt: number; - extra: Record; - constructor(params: { role: ChatRole; content: ChatContent[] | string; id?: string; interrupted?: boolean; createdAt?: number; + transcriptConfidence?: number; + metrics?: MetricsReport; extra?: Record; }) { const { @@ -112,6 +129,8 @@ export class ChatMessage { id = shortuuid('item_'), interrupted = false, createdAt = Date.now(), + transcriptConfidence, + metrics = {}, extra = {}, } = params; this.id = id; @@ -119,6 +138,8 @@ export class ChatMessage { this.content = Array.isArray(content) ? 
content : [content]; this.interrupted = interrupted; this.createdAt = createdAt; + this.transcriptConfidence = transcriptConfidence; + this.metrics = metrics; this.extra = extra; } @@ -128,6 +149,8 @@ export class ChatMessage { id?: string; interrupted?: boolean; createdAt?: number; + transcriptConfidence?: number; + metrics?: MetricsReport; extra?: Record; }) { return new ChatMessage(params); @@ -179,6 +202,16 @@ export class ChatMessage { result.createdAt = this.createdAt; } + if (this.transcriptConfidence !== undefined) { + result.transcriptConfidence = this.transcriptConfidence; + } + if (Object.keys(this.metrics).length > 0) { + result.metrics = { ...this.metrics }; + } + if (Object.keys(this.extra).length > 0) { + result.extra = this.extra as JSONValue; + } + return result; } } @@ -439,6 +472,8 @@ export class ChatContext { id?: string; interrupted?: boolean; createdAt?: number; + transcriptConfidence?: number; + metrics?: MetricsReport; extra?: Record; }): ChatMessage { const msg = new ChatMessage(params); @@ -623,6 +658,9 @@ export class ChatContext { id: item.id, interrupted: item.interrupted, createdAt: item.createdAt, + transcriptConfidence: item.transcriptConfidence, + metrics: item.metrics, + extra: item.extra, }); // Filter content based on options diff --git a/agents/src/llm/index.ts b/agents/src/llm/index.ts index e40e2051c..1f0c10a90 100644 --- a/agents/src/llm/index.ts +++ b/agents/src/llm/index.ts @@ -30,6 +30,7 @@ export { type ChatItem, type ChatRole, type ImageContent, + type MetricsReport, } from './chat_context.js'; export type { ProviderFormat } from './provider_format/index.js'; diff --git a/agents/src/llm/llm.ts b/agents/src/llm/llm.ts index 624bea490..a71bd4714 100644 --- a/agents/src/llm/llm.ts +++ b/agents/src/llm/llm.ts @@ -65,6 +65,18 @@ export abstract class LLM extends (EventEmitter as new () => TypedEmitter { } return (usage?.completionTokens || 0) / (durationMs / 1000); })(), + metadata: { + modelProvider: this.#llm.provider, + 
modelName: this.#llm.model, + }, }; if (this.#llmRequestSpan) { diff --git a/agents/src/llm/realtime.ts b/agents/src/llm/realtime.ts index 5c132afd0..864e25d2d 100644 --- a/agents/src/llm/realtime.ts +++ b/agents/src/llm/realtime.ts @@ -73,6 +73,10 @@ export abstract class RealtimeModel { /** The model name/identifier used by this realtime model */ abstract get model(): string; + get provider(): string { + return 'unknown'; + } + abstract session(): RealtimeSession; abstract close(): Promise; diff --git a/agents/src/metrics/base.ts b/agents/src/metrics/base.ts index 7f6d6a0cc..1c9c317c1 100644 --- a/agents/src/metrics/base.ts +++ b/agents/src/metrics/base.ts @@ -2,13 +2,21 @@ // // SPDX-License-Identifier: Apache-2.0 +export type MetricsMetadata = { + /** The provider name (e.g., 'openai', 'anthropic'). */ + modelProvider?: string; + /** The model name (e.g., 'gpt-4o', 'claude-3-5-sonnet'). */ + modelName?: string; +}; + export type AgentMetrics = | STTMetrics | LLMMetrics | TTSMetrics | VADMetrics | EOUMetrics - | RealtimeModelMetrics; + | RealtimeModelMetrics + | InterruptionMetrics; export type LLMMetrics = { type: 'llm_metrics'; @@ -26,6 +34,8 @@ export type LLMMetrics = { totalTokens: number; tokensPerSecond: number; speechId?: string; + /** Metadata for model provider and name tracking. */ + metadata?: MetricsMetadata; }; export type STTMetrics = { @@ -41,10 +51,16 @@ export type STTMetrics = { * The duration of the pushed audio in milliseconds. */ audioDurationMs: number; + /** Input audio tokens (for token-based billing). */ + inputTokens?: number; + /** Output text tokens (for token-based billing). */ + outputTokens?: number; /** * Whether the STT is streaming (e.g using websocket). */ streamed: boolean; + /** Metadata for model provider and name tracking. */ + metadata?: MetricsMetadata; }; export type TTSMetrics = { @@ -59,10 +75,17 @@ export type TTSMetrics = { /** Generated audio duration in milliseconds. 
*/ audioDurationMs: number; cancelled: boolean; + /** Number of characters synthesized (for character-based billing). */ charactersCount: number; + /** Input text tokens (for token-based billing, e.g., OpenAI TTS). */ + inputTokens?: number; + /** Output audio tokens (for token-based billing, e.g., OpenAI TTS). */ + outputTokens?: number; streamed: boolean; segmentId?: string; speechId?: string; + /** Metadata for model provider and name tracking. */ + metadata?: MetricsMetadata; }; export type VADMetrics = { @@ -133,6 +156,10 @@ export type RealtimeModelMetrics = { * The duration of the response from created to done in milliseconds. */ durationMs: number; + /** + * The duration of the session connection in milliseconds (for session-based billing like xAI). + */ + sessionDurationMs?: number; /** * Time to first audio token in milliseconds. -1 if no audio token was sent. */ @@ -165,4 +192,24 @@ export type RealtimeModelMetrics = { * Details about the output tokens used in the Response. */ outputTokenDetails: RealtimeModelMetricsOutputTokenDetails; + /** Metadata for model provider and name tracking. */ + metadata?: MetricsMetadata; +}; + +export type InterruptionMetrics = { + type: 'interruption_metrics'; + timestamp: number; + /** Latest RTT time taken to perform inference, in milliseconds. */ + totalDuration: number; + /** Latest time taken by the model side, in milliseconds. */ + predictionDuration: number; + /** Latest total time from onset of speech to final prediction, in milliseconds. */ + detectionDelay: number; + /** Number of interruptions detected (incremental). */ + numInterruptions: number; + /** Number of backchannels detected (incremental). */ + numBackchannels: number; + /** Number of requests sent to the model (incremental). 
*/ + numRequests: number; + metadata?: MetricsMetadata; }; diff --git a/agents/src/metrics/index.ts b/agents/src/metrics/index.ts index f400a9638..f3cce796c 100644 --- a/agents/src/metrics/index.ts +++ b/agents/src/metrics/index.ts @@ -5,11 +5,22 @@ export type { AgentMetrics, EOUMetrics, + InterruptionMetrics, LLMMetrics, + MetricsMetadata, RealtimeModelMetrics, STTMetrics, TTSMetrics, VADMetrics, } from './base.js'; +export { + filterZeroValues, + ModelUsageCollector, + type InterruptionModelUsage, + type LLMModelUsage, + type ModelUsage, + type STTModelUsage, + type TTSModelUsage, +} from './model_usage.js'; export { UsageCollector, type UsageSummary } from './usage_collector.js'; export { logMetrics } from './utils.js'; diff --git a/agents/src/metrics/model_usage.test.ts b/agents/src/metrics/model_usage.test.ts new file mode 100644 index 000000000..d2f983beb --- /dev/null +++ b/agents/src/metrics/model_usage.test.ts @@ -0,0 +1,545 @@ +// SPDX-FileCopyrightText: 2024 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { beforeEach, describe, expect, it } from 'vitest'; +import type { LLMMetrics, RealtimeModelMetrics, STTMetrics, TTSMetrics } from './base.js'; +import { + type LLMModelUsage, + ModelUsageCollector, + type STTModelUsage, + type TTSModelUsage, + filterZeroValues, +} from './model_usage.js'; + +describe('model_usage', () => { + describe('filterZeroValues', () => { + it('should filter out zero values from LLMModelUsage', () => { + const usage: LLMModelUsage = { + type: 'llm_usage', + provider: 'openai', + model: 'gpt-4o', + inputTokens: 100, + inputCachedTokens: 0, + inputAudioTokens: 0, + inputCachedAudioTokens: 0, + inputTextTokens: 0, + inputCachedTextTokens: 0, + inputImageTokens: 0, + inputCachedImageTokens: 0, + outputTokens: 50, + outputAudioTokens: 0, + outputTextTokens: 0, + sessionDurationMs: 0, + }; + + const filtered = filterZeroValues(usage); + + expect(filtered.type).toBe('llm_usage'); + 
expect(filtered.provider).toBe('openai'); + expect(filtered.model).toBe('gpt-4o'); + expect(filtered.inputTokens).toBe(100); + expect(filtered.outputTokens).toBe(50); + // Zero values should be filtered out + expect(filtered.inputCachedTokens).toBeUndefined(); + expect(filtered.inputAudioTokens).toBeUndefined(); + expect(filtered.sessionDurationMs).toBeUndefined(); + }); + + it('should filter out zero values from TTSModelUsage', () => { + const usage: TTSModelUsage = { + type: 'tts_usage', + provider: 'elevenlabs', + model: 'eleven_turbo_v2', + inputTokens: 0, + outputTokens: 0, + charactersCount: 500, + audioDurationMs: 3000, + }; + + const filtered = filterZeroValues(usage); + + expect(filtered.type).toBe('tts_usage'); + expect(filtered.provider).toBe('elevenlabs'); + expect(filtered.charactersCount).toBe(500); + expect(filtered.audioDurationMs).toBe(3000); + expect(filtered.inputTokens).toBeUndefined(); + expect(filtered.outputTokens).toBeUndefined(); + }); + + it('should keep all values when none are zero', () => { + const usage: STTModelUsage = { + type: 'stt_usage', + provider: 'deepgram', + model: 'nova-2', + inputTokens: 10, + outputTokens: 20, + audioDurationMs: 5000, + }; + + const filtered = filterZeroValues(usage); + + expect(Object.keys(filtered)).toHaveLength(6); + expect(filtered).toEqual(usage); + }); + }); + + describe('ModelUsageCollector', () => { + let collector: ModelUsageCollector; + + beforeEach(() => { + collector = new ModelUsageCollector(); + }); + + describe('collect LLM metrics', () => { + it('should aggregate LLM metrics by provider and model', () => { + const metrics1: LLMMetrics = { + type: 'llm_metrics', + label: 'test', + requestId: 'req1', + timestamp: Date.now(), + durationMs: 100, + ttftMs: 50, + cancelled: false, + completionTokens: 100, + promptTokens: 200, + promptCachedTokens: 50, + totalTokens: 300, + tokensPerSecond: 10, + metadata: { + modelProvider: 'openai', + modelName: 'gpt-4o', + }, + }; + + const metrics2: LLMMetrics 
= { + type: 'llm_metrics', + label: 'test', + requestId: 'req2', + timestamp: Date.now(), + durationMs: 150, + ttftMs: 60, + cancelled: false, + completionTokens: 150, + promptTokens: 300, + promptCachedTokens: 75, + totalTokens: 450, + tokensPerSecond: 12, + metadata: { + modelProvider: 'openai', + modelName: 'gpt-4o', + }, + }; + + collector.collect(metrics1); + collector.collect(metrics2); + + const usage = collector.flatten(); + expect(usage).toHaveLength(1); + + const llmUsage = usage[0] as LLMModelUsage; + expect(llmUsage.type).toBe('llm_usage'); + expect(llmUsage.provider).toBe('openai'); + expect(llmUsage.model).toBe('gpt-4o'); + expect(llmUsage.inputTokens).toBe(500); // 200 + 300 + expect(llmUsage.inputCachedTokens).toBe(125); // 50 + 75 + expect(llmUsage.outputTokens).toBe(250); // 100 + 150 + }); + + it('should separate metrics by different providers', () => { + const openaiMetrics: LLMMetrics = { + type: 'llm_metrics', + label: 'test', + requestId: 'req1', + timestamp: Date.now(), + durationMs: 100, + ttftMs: 50, + cancelled: false, + completionTokens: 100, + promptTokens: 200, + promptCachedTokens: 0, + totalTokens: 300, + tokensPerSecond: 10, + metadata: { + modelProvider: 'openai', + modelName: 'gpt-4o', + }, + }; + + const anthropicMetrics: LLMMetrics = { + type: 'llm_metrics', + label: 'test', + requestId: 'req2', + timestamp: Date.now(), + durationMs: 120, + ttftMs: 55, + cancelled: false, + completionTokens: 80, + promptTokens: 150, + promptCachedTokens: 0, + totalTokens: 230, + tokensPerSecond: 8, + metadata: { + modelProvider: 'anthropic', + modelName: 'claude-3-5-sonnet', + }, + }; + + collector.collect(openaiMetrics); + collector.collect(anthropicMetrics); + + const usage = collector.flatten(); + expect(usage).toHaveLength(2); + + const openaiUsage = usage.find( + (u) => u.type === 'llm_usage' && u.provider === 'openai', + ) as LLMModelUsage; + const anthropicUsage = usage.find( + (u) => u.type === 'llm_usage' && u.provider === 'anthropic', 
+ ) as LLMModelUsage; + + expect(openaiUsage.inputTokens).toBe(200); + expect(openaiUsage.outputTokens).toBe(100); + expect(anthropicUsage.inputTokens).toBe(150); + expect(anthropicUsage.outputTokens).toBe(80); + }); + }); + + describe('collect TTS metrics', () => { + it('should aggregate TTS metrics by provider and model', () => { + const metrics1: TTSMetrics = { + type: 'tts_metrics', + label: 'test', + requestId: 'req1', + timestamp: Date.now(), + ttfbMs: 100, + durationMs: 500, + audioDurationMs: 3000, + cancelled: false, + charactersCount: 100, + inputTokens: 10, + outputTokens: 20, + streamed: true, + metadata: { + modelProvider: 'elevenlabs', + modelName: 'eleven_turbo_v2', + }, + }; + + const metrics2: TTSMetrics = { + type: 'tts_metrics', + label: 'test', + requestId: 'req2', + timestamp: Date.now(), + ttfbMs: 120, + durationMs: 600, + audioDurationMs: 4000, + cancelled: false, + charactersCount: 200, + inputTokens: 15, + outputTokens: 25, + streamed: true, + metadata: { + modelProvider: 'elevenlabs', + modelName: 'eleven_turbo_v2', + }, + }; + + collector.collect(metrics1); + collector.collect(metrics2); + + const usage = collector.flatten(); + expect(usage).toHaveLength(1); + + const ttsUsage = usage[0] as TTSModelUsage; + expect(ttsUsage.type).toBe('tts_usage'); + expect(ttsUsage.provider).toBe('elevenlabs'); + expect(ttsUsage.model).toBe('eleven_turbo_v2'); + expect(ttsUsage.charactersCount).toBe(300); // 100 + 200 + expect(ttsUsage.audioDurationMs).toBe(7000); // 3000 + 4000 + expect(ttsUsage.inputTokens).toBe(25); // 10 + 15 + expect(ttsUsage.outputTokens).toBe(45); // 20 + 25 + }); + }); + + describe('collect STT metrics', () => { + it('should aggregate STT metrics by provider and model', () => { + const metrics1: STTMetrics = { + type: 'stt_metrics', + label: 'test', + requestId: 'req1', + timestamp: Date.now(), + durationMs: 0, + audioDurationMs: 5000, + inputTokens: 50, + outputTokens: 100, + streamed: true, + metadata: { + modelProvider: 
'deepgram', + modelName: 'nova-2', + }, + }; + + const metrics2: STTMetrics = { + type: 'stt_metrics', + label: 'test', + requestId: 'req2', + timestamp: Date.now(), + durationMs: 0, + audioDurationMs: 3000, + inputTokens: 30, + outputTokens: 60, + streamed: true, + metadata: { + modelProvider: 'deepgram', + modelName: 'nova-2', + }, + }; + + collector.collect(metrics1); + collector.collect(metrics2); + + const usage = collector.flatten(); + expect(usage).toHaveLength(1); + + const sttUsage = usage[0] as STTModelUsage; + expect(sttUsage.type).toBe('stt_usage'); + expect(sttUsage.provider).toBe('deepgram'); + expect(sttUsage.model).toBe('nova-2'); + expect(sttUsage.audioDurationMs).toBe(8000); // 5000 + 3000 + expect(sttUsage.inputTokens).toBe(80); // 50 + 30 + expect(sttUsage.outputTokens).toBe(160); // 100 + 60 + }); + }); + + describe('collect realtime model metrics', () => { + it('should aggregate realtime model metrics with detailed token breakdown', () => { + const metrics: RealtimeModelMetrics = { + type: 'realtime_model_metrics', + label: 'test', + requestId: 'req1', + timestamp: Date.now(), + durationMs: 1000, + ttftMs: 100, + cancelled: false, + inputTokens: 500, + outputTokens: 300, + totalTokens: 800, + tokensPerSecond: 10, + sessionDurationMs: 5000, + inputTokenDetails: { + audioTokens: 200, + textTokens: 250, + imageTokens: 50, + cachedTokens: 100, + cachedTokensDetails: { + audioTokens: 30, + textTokens: 50, + imageTokens: 20, + }, + }, + outputTokenDetails: { + textTokens: 200, + audioTokens: 100, + imageTokens: 0, + }, + metadata: { + modelProvider: 'openai', + modelName: 'gpt-4o-realtime', + }, + }; + + collector.collect(metrics); + + const usage = collector.flatten(); + expect(usage).toHaveLength(1); + + const llmUsage = usage[0] as LLMModelUsage; + expect(llmUsage.type).toBe('llm_usage'); + expect(llmUsage.provider).toBe('openai'); + expect(llmUsage.model).toBe('gpt-4o-realtime'); + expect(llmUsage.inputTokens).toBe(500); + 
expect(llmUsage.inputCachedTokens).toBe(100); + expect(llmUsage.inputAudioTokens).toBe(200); + expect(llmUsage.inputCachedAudioTokens).toBe(30); + expect(llmUsage.inputTextTokens).toBe(250); + expect(llmUsage.inputCachedTextTokens).toBe(50); + expect(llmUsage.inputImageTokens).toBe(50); + expect(llmUsage.inputCachedImageTokens).toBe(20); + expect(llmUsage.outputTokens).toBe(300); + expect(llmUsage.outputTextTokens).toBe(200); + expect(llmUsage.outputAudioTokens).toBe(100); + expect(llmUsage.sessionDurationMs).toBe(5000); + }); + }); + + describe('mixed metrics collection', () => { + it('should collect and separate LLM, TTS, and STT metrics', () => { + const llmMetrics: LLMMetrics = { + type: 'llm_metrics', + label: 'test', + requestId: 'req1', + timestamp: Date.now(), + durationMs: 100, + ttftMs: 50, + cancelled: false, + completionTokens: 100, + promptTokens: 200, + promptCachedTokens: 0, + totalTokens: 300, + tokensPerSecond: 10, + metadata: { + modelProvider: 'openai', + modelName: 'gpt-4o', + }, + }; + + const ttsMetrics: TTSMetrics = { + type: 'tts_metrics', + label: 'test', + requestId: 'req2', + timestamp: Date.now(), + ttfbMs: 100, + durationMs: 500, + audioDurationMs: 3000, + cancelled: false, + charactersCount: 100, + streamed: true, + metadata: { + modelProvider: 'elevenlabs', + modelName: 'eleven_turbo_v2', + }, + }; + + const sttMetrics: STTMetrics = { + type: 'stt_metrics', + label: 'test', + requestId: 'req3', + timestamp: Date.now(), + durationMs: 0, + audioDurationMs: 5000, + streamed: true, + metadata: { + modelProvider: 'deepgram', + modelName: 'nova-2', + }, + }; + + collector.collect(llmMetrics); + collector.collect(ttsMetrics); + collector.collect(sttMetrics); + + const usage = collector.flatten(); + expect(usage).toHaveLength(3); + + const llmUsage = usage.find((u) => u.type === 'llm_usage'); + const ttsUsage = usage.find((u) => u.type === 'tts_usage'); + const sttUsage = usage.find((u) => u.type === 'stt_usage'); + + 
expect(llmUsage).toBeDefined(); + expect(ttsUsage).toBeDefined(); + expect(sttUsage).toBeDefined(); + }); + }); + + describe('flatten returns copies', () => { + it('should return deep copies of usage objects', () => { + const metrics: LLMMetrics = { + type: 'llm_metrics', + label: 'test', + requestId: 'req1', + timestamp: Date.now(), + durationMs: 100, + ttftMs: 50, + cancelled: false, + completionTokens: 100, + promptTokens: 200, + promptCachedTokens: 0, + totalTokens: 300, + tokensPerSecond: 10, + metadata: { + modelProvider: 'openai', + modelName: 'gpt-4o', + }, + }; + + collector.collect(metrics); + + const usage1 = collector.flatten(); + const usage2 = collector.flatten(); + + // Should be equal values + expect(usage1[0]).toEqual(usage2[0]); + + // But not the same object reference + expect(usage1[0]).not.toBe(usage2[0]); + + // Modifying one shouldn't affect the other + (usage1[0] as LLMModelUsage).inputTokens = 9999; + expect((usage2[0] as LLMModelUsage).inputTokens).toBe(200); + }); + }); + + describe('handles missing metadata', () => { + it('should use empty strings when metadata is missing', () => { + const metrics: LLMMetrics = { + type: 'llm_metrics', + label: 'test', + requestId: 'req1', + timestamp: Date.now(), + durationMs: 100, + ttftMs: 50, + cancelled: false, + completionTokens: 100, + promptTokens: 200, + promptCachedTokens: 0, + totalTokens: 300, + tokensPerSecond: 10, + // No metadata + }; + + collector.collect(metrics); + + const usage = collector.flatten(); + expect(usage).toHaveLength(1); + + const llmUsage = usage[0] as LLMModelUsage; + expect(llmUsage.provider).toBe(''); + expect(llmUsage.model).toBe(''); + }); + }); + + describe('ignores VAD and EOU metrics', () => { + it('should not collect VAD metrics', () => { + const vadMetrics = { + type: 'vad_metrics' as const, + label: 'test', + timestamp: Date.now(), + idleTimeMs: 100, + inferenceDurationTotalMs: 50, + inferenceCount: 10, + }; + + collector.collect(vadMetrics); + + const usage = 
collector.flatten(); + expect(usage).toHaveLength(0); + }); + + it('should not collect EOU metrics', () => { + const eouMetrics = { + type: 'eou_metrics' as const, + timestamp: Date.now(), + endOfUtteranceDelayMs: 100, + transcriptionDelayMs: 50, + onUserTurnCompletedDelayMs: 30, + lastSpeakingTimeMs: Date.now(), + }; + + collector.collect(eouMetrics); + + const usage = collector.flatten(); + expect(usage).toHaveLength(0); + }); + }); + }); +}); diff --git a/agents/src/metrics/model_usage.ts b/agents/src/metrics/model_usage.ts new file mode 100644 index 000000000..5e723fb51 --- /dev/null +++ b/agents/src/metrics/model_usage.ts @@ -0,0 +1,262 @@ +// SPDX-FileCopyrightText: 2024 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import type { + AgentMetrics, + InterruptionMetrics, + LLMMetrics, + RealtimeModelMetrics, + STTMetrics, + TTSMetrics, +} from './base.js'; + +export type LLMModelUsage = { + type: 'llm_usage'; + /** The provider name (e.g., 'openai', 'anthropic'). */ + provider: string; + /** The model name (e.g., 'gpt-4o', 'claude-3-5-sonnet'). */ + model: string; + /** Total input tokens. */ + inputTokens: number; + /** Input tokens served from cache. */ + inputCachedTokens: number; + /** Input audio tokens (for multimodal models). */ + inputAudioTokens: number; + /** Cached input audio tokens. */ + inputCachedAudioTokens: number; + /** Input text tokens. */ + inputTextTokens: number; + /** Cached input text tokens. */ + inputCachedTextTokens: number; + /** Input image tokens (for multimodal models). */ + inputImageTokens: number; + /** Cached input image tokens. */ + inputCachedImageTokens: number; + /** Total output tokens. */ + outputTokens: number; + /** Output audio tokens (for multimodal models). */ + outputAudioTokens: number; + /** Output text tokens. */ + outputTextTokens: number; + /** Total session connection duration in milliseconds (for session-based billing like xAI). 
*/ + sessionDurationMs: number; +}; + +export type TTSModelUsage = { + type: 'tts_usage'; + /** The provider name (e.g., 'elevenlabs', 'cartesia'). */ + provider: string; + /** The model name (e.g., 'eleven_turbo_v2', 'sonic'). */ + model: string; + /** Input text tokens (for token-based TTS billing, e.g., OpenAI TTS). */ + inputTokens: number; + /** Output audio tokens (for token-based TTS billing, e.g., OpenAI TTS). */ + outputTokens: number; + /** Number of characters synthesized (for character-based TTS billing). */ + charactersCount: number; + /** + * Duration of generated audio in milliseconds. + */ + audioDurationMs: number; +}; + +export type STTModelUsage = { + type: 'stt_usage'; + /** The provider name (e.g., 'deepgram', 'assemblyai'). */ + provider: string; + /** The model name (e.g., 'nova-2', 'best'). */ + model: string; + /** Input audio tokens (for token-based STT billing). */ + inputTokens: number; + /** Output text tokens (for token-based STT billing). */ + outputTokens: number; + /** Duration of processed audio in milliseconds. */ + audioDurationMs: number; +}; + +export type InterruptionModelUsage = { + type: 'interruption_usage'; + /** The provider name (e.g., 'livekit'). */ + provider: string; + /** The model name (e.g., 'adaptive interruption'). */ + model: string; + /** Total number of requests sent. */ + totalRequests: number; +}; + +export type ModelUsage = LLMModelUsage | TTSModelUsage | STTModelUsage | InterruptionModelUsage; + +export function filterZeroValues(usage: T): Partial { + const result: Partial = {} as Partial; + for (const [key, value] of Object.entries(usage)) { + if (value !== 0 && value !== 0.0) { + (result as Record)[key] = value; + } + } + return result; +} + +export class ModelUsageCollector { + private llmUsage: Map = new Map(); + private ttsUsage: Map = new Map(); + private sttUsage: Map = new Map(); + + private interruptionUsage: Map = new Map(); + + /** Extract provider and model from metrics metadata. 
*/ + private extractProviderModel( + metrics: LLMMetrics | STTMetrics | TTSMetrics | RealtimeModelMetrics | InterruptionMetrics, + ): [string, string] { + let provider = ''; + let model = ''; + if (metrics.metadata) { + provider = metrics.metadata.modelProvider || ''; + model = metrics.metadata.modelName || ''; + } + return [provider, model]; + } + + /** Get or create an LLMModelUsage for the given provider/model combination. */ + private getLLMUsage(provider: string, model: string): LLMModelUsage { + const key = `${provider}:${model}`; + let usage = this.llmUsage.get(key); + if (!usage) { + usage = { + type: 'llm_usage', + provider, + model, + inputTokens: 0, + inputCachedTokens: 0, + inputAudioTokens: 0, + inputCachedAudioTokens: 0, + inputTextTokens: 0, + inputCachedTextTokens: 0, + inputImageTokens: 0, + inputCachedImageTokens: 0, + outputTokens: 0, + outputAudioTokens: 0, + outputTextTokens: 0, + sessionDurationMs: 0, + }; + this.llmUsage.set(key, usage); + } + return usage; + } + + /** Get or create a TTSModelUsage for the given provider/model combination. */ + private getTTSUsage(provider: string, model: string): TTSModelUsage { + const key = `${provider}:${model}`; + let usage = this.ttsUsage.get(key); + if (!usage) { + usage = { + type: 'tts_usage', + provider, + model, + inputTokens: 0, + outputTokens: 0, + charactersCount: 0, + audioDurationMs: 0, + }; + this.ttsUsage.set(key, usage); + } + return usage; + } + + /** Get or create an STTModelUsage for the given provider/model combination. 
*/ + private getSTTUsage(provider: string, model: string): STTModelUsage { + const key = `${provider}:${model}`; + let usage = this.sttUsage.get(key); + if (!usage) { + usage = { + type: 'stt_usage', + provider, + model, + inputTokens: 0, + outputTokens: 0, + audioDurationMs: 0, + }; + this.sttUsage.set(key, usage); + } + return usage; + } + + private getInterruptionUsage(provider: string, model: string): InterruptionModelUsage { + const key = `${provider}:${model}`; + let usage = this.interruptionUsage.get(key); + if (!usage) { + usage = { + type: 'interruption_usage', + provider, + model, + totalRequests: 0, + }; + this.interruptionUsage.set(key, usage); + } + return usage; + } + + /** Collect metrics and aggregate usage by model/provider. */ + collect(metrics: AgentMetrics): void { + if (metrics.type === 'llm_metrics') { + const [provider, model] = this.extractProviderModel(metrics); + const usage = this.getLLMUsage(provider, model); + usage.inputTokens += metrics.promptTokens; + usage.inputCachedTokens += metrics.promptCachedTokens; + usage.outputTokens += metrics.completionTokens; + } else if (metrics.type === 'realtime_model_metrics') { + const [provider, model] = this.extractProviderModel(metrics); + const usage = this.getLLMUsage(provider, model); + usage.inputTokens += metrics.inputTokens; + usage.inputCachedTokens += metrics.inputTokenDetails.cachedTokens; + + usage.inputTextTokens += metrics.inputTokenDetails.textTokens; + usage.inputCachedTextTokens += metrics.inputTokenDetails.cachedTokensDetails?.textTokens ?? 0; + usage.inputImageTokens += metrics.inputTokenDetails.imageTokens; + usage.inputCachedImageTokens += + metrics.inputTokenDetails.cachedTokensDetails?.imageTokens ?? 0; + usage.inputAudioTokens += metrics.inputTokenDetails.audioTokens; + usage.inputCachedAudioTokens += + metrics.inputTokenDetails.cachedTokensDetails?.audioTokens ?? 
0; + + usage.outputTextTokens += metrics.outputTokenDetails.textTokens; + usage.outputAudioTokens += metrics.outputTokenDetails.audioTokens; + usage.outputTokens += metrics.outputTokens; + usage.sessionDurationMs += metrics.sessionDurationMs ?? 0; + } else if (metrics.type === 'tts_metrics') { + const [provider, model] = this.extractProviderModel(metrics); + const ttsUsage = this.getTTSUsage(provider, model); + ttsUsage.inputTokens += metrics.inputTokens ?? 0; + ttsUsage.outputTokens += metrics.outputTokens ?? 0; + ttsUsage.charactersCount += metrics.charactersCount; + ttsUsage.audioDurationMs += metrics.audioDurationMs; + } else if (metrics.type === 'stt_metrics') { + const [provider, model] = this.extractProviderModel(metrics); + const sttUsage = this.getSTTUsage(provider, model); + sttUsage.inputTokens += metrics.inputTokens ?? 0; + sttUsage.outputTokens += metrics.outputTokens ?? 0; + sttUsage.audioDurationMs += metrics.audioDurationMs; + } else if (metrics.type === 'interruption_metrics') { + const [provider, model] = this.extractProviderModel(metrics); + const usage = this.getInterruptionUsage(provider, model); + usage.totalRequests += metrics.numRequests; + } + // VAD and EOU metrics are not aggregated for usage tracking. + } + + flatten(): ModelUsage[] { + const result: ModelUsage[] = []; + for (const u of this.llmUsage.values()) { + result.push({ ...u }); + } + for (const u of this.ttsUsage.values()) { + result.push({ ...u }); + } + for (const u of this.sttUsage.values()) { + result.push({ ...u }); + } + for (const u of this.interruptionUsage.values()) { + result.push({ ...u }); + } + return result; + } +} diff --git a/agents/src/metrics/usage_collector.ts b/agents/src/metrics/usage_collector.ts index c7f0e6c3d..c815c8394 100644 --- a/agents/src/metrics/usage_collector.ts +++ b/agents/src/metrics/usage_collector.ts @@ -1,8 +1,13 @@ // SPDX-FileCopyrightText: 2024 LiveKit, Inc. 
// // SPDX-License-Identifier: Apache-2.0 +import { log } from '../log.js'; import type { AgentMetrics } from './base.js'; +/** + * @deprecated Use LLMModelUsage, TTSModelUsage, or STTModelUsage instead. + * These new types provide per-model/provider usage aggregation for more detailed tracking. + */ export interface UsageSummary { llmPromptTokens: number; llmPromptCachedTokens: number; @@ -11,10 +16,16 @@ export interface UsageSummary { sttAudioDurationMs: number; } +/** + * @deprecated Use ModelUsageCollector instead. + * ModelUsageCollector provides per-model/provider usage aggregation for more detailed tracking. + */ export class UsageCollector { private summary: UsageSummary; + private logger = log(); constructor() { + this.logger.warn('UsageCollector is deprecated. Use ModelUsageCollector instead.'); this.summary = { llmPromptTokens: 0, llmPromptCachedTokens: 0, diff --git a/agents/src/metrics/utils.ts b/agents/src/metrics/utils.ts index cf98f8d1d..ced021e63 100644 --- a/agents/src/metrics/utils.ts +++ b/agents/src/metrics/utils.ts @@ -60,5 +60,16 @@ export const logMetrics = (metrics: AgentMetrics) => { audioDurationMs: Math.round(metrics.audioDurationMs), }) .info('STT metrics'); + } else if (metrics.type === 'interruption_metrics') { + logger + .child({ + totalDurationMs: roundTwoDecimals(metrics.totalDuration), + predictionDurationMs: roundTwoDecimals(metrics.predictionDuration), + detectionDelayMs: roundTwoDecimals(metrics.detectionDelay), + numInterruptions: metrics.numInterruptions, + numBackchannels: metrics.numBackchannels, + numRequests: metrics.numRequests, + }) + .info('Interruption metrics'); } }; diff --git a/agents/src/stream/multi_input_stream.test.ts b/agents/src/stream/multi_input_stream.test.ts index cda78b62b..fb648ff32 100644 --- a/agents/src/stream/multi_input_stream.test.ts +++ b/agents/src/stream/multi_input_stream.test.ts @@ -2,7 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 import { ReadableStream } from 'node:stream/web'; 
-import { describe, expect, it } from 'vitest'; +import { beforeAll, describe, expect, it } from 'vitest'; +import { initializeLogger } from '../log.js'; import { delay } from '../utils.js'; import { MultiInputStream } from './multi_input_stream.js'; @@ -16,6 +17,10 @@ function streamFrom(values: T[]): ReadableStream { } describe('MultiInputStream', () => { + beforeAll(() => { + initializeLogger({ pretty: false }); + }); + // --------------------------------------------------------------------------- // Basic functionality // --------------------------------------------------------------------------- diff --git a/agents/src/stream/stream_channel.ts b/agents/src/stream/stream_channel.ts index 1fb68bab2..edaeaa856 100644 --- a/agents/src/stream/stream_channel.ts +++ b/agents/src/stream/stream_channel.ts @@ -4,14 +4,16 @@ import type { ReadableStream } from 'node:stream/web'; import { IdentityTransform } from './identity_transform.js'; -export interface StreamChannel { +export interface StreamChannel { write(chunk: T): Promise; close(): Promise; stream(): ReadableStream; + abort(error: E): Promise; readonly closed: boolean; + addStreamInput(stream: ReadableStream): void; } -export function createStreamChannel(): StreamChannel { +export function createStreamChannel(): StreamChannel { const transform = new IdentityTransform(); const writer = transform.writable.getWriter(); let isClosed = false; @@ -19,6 +21,36 @@ export function createStreamChannel(): StreamChannel { return { write: (chunk: T) => writer.write(chunk), stream: () => transform.readable, + abort: async (error: E) => { + if (isClosed) return; + isClosed = true; + try { + await writer.abort(error); + } catch (e) { + if (e instanceof Error && e.name === 'TypeError') return; + throw e; + } + }, + addStreamInput: (newInputStream) => { + if (isClosed) return; + const reader = newInputStream.getReader(); + (async () => { + try { + while (!isClosed) { + const { done, value } = await reader.read(); + if (done) 
break; + await writer.write(value); + } + } catch (err) { + if (!isClosed) { + isClosed = true; + await writer.abort(err as E); + } + } finally { + reader.releaseLock(); + } + })().catch(() => {}); + }, close: async () => { try { const result = await writer.close(); diff --git a/agents/src/stt/stt.ts b/agents/src/stt/stt.ts index d7ee9bf19..6c2da2b8c 100644 --- a/agents/src/stt/stt.ts +++ b/agents/src/stt/stt.ts @@ -66,6 +66,10 @@ export interface SpeechData { export interface RecognitionUsage { /** Duration of the audio that was recognized in seconds. */ audioDuration: number; + /** Input audio tokens (for token-based STT billing). */ + inputTokens?: number; + /** Output text tokens (for token-based STT billing). */ + outputTokens?: number; } /** SpeechEvent is a packet of speech-to-text data. */ @@ -128,6 +132,30 @@ export abstract class STT extends (EventEmitter as new () => TypedEmitter { const startTime = process.hrtime.bigint(); @@ -141,6 +169,10 @@ export abstract class STT extends (EventEmitter as new () => TypedEmitter durationMs: 0, label: this.#stt.label, audioDurationMs: Math.round(event.recognitionUsage!.audioDuration * 1000), + inputTokens: event.recognitionUsage!.inputTokens ?? 0, + outputTokens: event.recognitionUsage!.outputTokens ?? 
0, streamed: true, + metadata: { + modelProvider: this.#stt.provider, + modelName: this.#stt.model, + }, }; this.#stt.emit('metrics_collected', metrics); } diff --git a/agents/src/telemetry/otel_http_exporter.ts b/agents/src/telemetry/otel_http_exporter.ts index ae6b6590e..43f01faea 100644 --- a/agents/src/telemetry/otel_http_exporter.ts +++ b/agents/src/telemetry/otel_http_exporter.ts @@ -58,6 +58,16 @@ export class SimpleOTLPHttpLogExporter { private readonly config: SimpleOTLPHttpLogExporterConfig; private jwt: string | null = null; + private static readonly FORCE_DOUBLE_KEYS = new Set([ + 'transcriptConfidence', + 'transcriptionDelay', + 'endOfTurnDelay', + 'onUserTurnCompletedDelay', + 'llmNodeTtft', + 'ttsNodeTtfb', + 'e2eLatency', + ]); + constructor(config: SimpleOTLPHttpLogExporterConfig) { this.config = config; } @@ -72,6 +82,7 @@ export class SimpleOTLPHttpLogExporter { const endpoint = `https://${this.config.cloudHostname}/observability/logs/otlp/v0`; const payload = this.buildPayload(records); + const payloadJson = JSON.stringify(payload); const response = await fetch(endpoint, { method: 'POST', @@ -79,7 +90,7 @@ export class SimpleOTLPHttpLogExporter { Authorization: `Bearer ${this.jwt}`, 'Content-Type': 'application/json', }, - body: JSON.stringify(payload), + body: payloadJson, }); if (!response.ok) { @@ -160,11 +171,11 @@ export class SimpleOTLPHttpLogExporter { ): Array<{ key: string; value: unknown }> { return Object.entries(attrs).map(([key, value]) => ({ key, - value: this.convertValue(value), + value: this.convertValue(value, key), })); } - private convertValue(value: unknown): unknown { + private convertValue(value: unknown, path: string = ''): unknown { if (value === null || value === undefined) { return { stringValue: '' }; } @@ -172,20 +183,32 @@ export class SimpleOTLPHttpLogExporter { return { stringValue: value }; } if (typeof value === 'number') { + const leafKey = + path + .split('.') + .pop() + ?.replace(/\[\d+\]$/, '') ?? 
path; + if (SimpleOTLPHttpLogExporter.FORCE_DOUBLE_KEYS.has(leafKey)) { + return { doubleValue: value }; + } return Number.isInteger(value) ? { intValue: String(value) } : { doubleValue: value }; } if (typeof value === 'boolean') { return { boolValue: value }; } if (Array.isArray(value)) { - return { arrayValue: { values: value.map((v) => this.convertValue(v)) } }; + return { + arrayValue: { + values: value.map((v, i) => this.convertValue(v, `${path}[${i}]`)), + }, + }; } if (typeof value === 'object') { return { kvlistValue: { values: Object.entries(value as Record).map(([k, v]) => ({ key: k, - value: this.convertValue(v), + value: this.convertValue(v, path ? `${path}.${k}` : k), })), }, }; diff --git a/agents/src/telemetry/trace_types.ts b/agents/src/telemetry/trace_types.ts index 1663bd75d..cc4d89443 100644 --- a/agents/src/telemetry/trace_types.ts +++ b/agents/src/telemetry/trace_types.ts @@ -33,6 +33,7 @@ export const ATTR_PROVIDER_TOOLS = 'lk.provider_tools'; export const ATTR_TOOL_SETS = 'lk.tool_sets'; export const ATTR_RESPONSE_TEXT = 'lk.response.text'; export const ATTR_RESPONSE_FUNCTION_CALLS = 'lk.response.function_calls'; +/** Time to first token in seconds. */ export const ATTR_RESPONSE_TTFT = 'lk.response.ttft'; // function tool @@ -46,6 +47,7 @@ export const ATTR_FUNCTION_TOOL_OUTPUT = 'lk.function_tool.output'; export const ATTR_TTS_INPUT_TEXT = 'lk.input_text'; export const ATTR_TTS_STREAMING = 'lk.tts.streaming'; export const ATTR_TTS_LABEL = 'lk.tts.label'; +/** Time to first byte in seconds. 
*/ export const ATTR_RESPONSE_TTFB = 'lk.response.ttfb'; // eou detection @@ -58,18 +60,26 @@ export const ATTR_TRANSCRIPT_CONFIDENCE = 'lk.transcript_confidence'; export const ATTR_TRANSCRIPTION_DELAY = 'lk.transcription_delay'; export const ATTR_END_OF_TURN_DELAY = 'lk.end_of_turn_delay'; +// Adaptive Interruption attributes +export const ATTR_IS_INTERRUPTION = 'lk.is_interruption'; +export const ATTR_INTERRUPTION_PROBABILITY = 'lk.interruption.probability'; +export const ATTR_INTERRUPTION_TOTAL_DURATION = 'lk.interruption.total_duration'; +export const ATTR_INTERRUPTION_PREDICTION_DURATION = 'lk.interruption.prediction_duration'; +export const ATTR_INTERRUPTION_DETECTION_DELAY = 'lk.interruption.detection_delay'; + // metrics export const ATTR_LLM_METRICS = 'lk.llm_metrics'; export const ATTR_TTS_METRICS = 'lk.tts_metrics'; export const ATTR_REALTIME_MODEL_METRICS = 'lk.realtime_model_metrics'; -// latency span attributes +/** End-to-end latency in seconds. */ export const ATTR_E2E_LATENCY = 'lk.e2e_latency'; // OpenTelemetry GenAI attributes // OpenTelemetry specification: https://opentelemetry.io/docs/specs/semconv/registry/attributes/gen-ai/ export const ATTR_GEN_AI_OPERATION_NAME = 'gen_ai.operation.name'; export const ATTR_GEN_AI_REQUEST_MODEL = 'gen_ai.request.model'; +/** The provider name (e.g., 'openai', 'anthropic'). 
*/ export const ATTR_GEN_AI_PROVIDER_NAME = 'gen_ai.provider.name'; export const ATTR_GEN_AI_USAGE_INPUT_TOKENS = 'gen_ai.usage.input_tokens'; export const ATTR_GEN_AI_USAGE_OUTPUT_TOKENS = 'gen_ai.usage.output_tokens'; @@ -97,10 +107,3 @@ export const ATTR_EXCEPTION_MESSAGE = 'exception.message'; // Platform-specific attributes export const ATTR_LANGFUSE_COMPLETION_START_TIME = 'langfuse.observation.completion_start_time'; - -// Adaptive Interruption attributes -export const ATTR_IS_INTERRUPTION = 'lk.is_interruption'; -export const ATTR_INTERRUPTION_PROBABILITY = 'lk.interruption.probability'; -export const ATTR_INTERRUPTION_TOTAL_DURATION = 'lk.interruption.total_duration'; -export const ATTR_INTERRUPTION_PREDICTION_DURATION = 'lk.interruption.prediction_duration'; -export const ATTR_INTERRUPTION_DETECTION_DELAY = 'lk.interruption.detection_delay'; diff --git a/agents/src/telemetry/traces.ts b/agents/src/telemetry/traces.ts index 28ef4c746..8ee52e586 100644 --- a/agents/src/telemetry/traces.ts +++ b/agents/src/telemetry/traces.ts @@ -22,8 +22,9 @@ import { ATTR_SERVICE_NAME } from '@opentelemetry/semantic-conventions'; import FormData from 'form-data'; import { AccessToken } from 'livekit-server-sdk'; import fs from 'node:fs/promises'; -import type { ChatContent, ChatItem } from '../llm/index.js'; +import type { ChatContent, ChatItem, ChatRole } from '../llm/index.js'; import { enableOtelLogging } from '../log.js'; +import { filterZeroValues } from '../metrics/model_usage.js'; import type { SessionReport } from '../voice/report.js'; import { type SimpleLogRecord, SimpleOTLPHttpLogExporter } from './otel_http_exporter.js'; import { flushPinoLogs, initPinoCloudExporter } from './pino_otel_transport.js'; @@ -285,24 +286,80 @@ export async function flushOtelLogs(): Promise { await flushPinoLogs(); } +/** Proto-compatible role enum values. 
*/ +type ProtoRole = 'DEVELOPER' | 'SYSTEM' | 'USER' | 'ASSISTANT'; + +const ROLE_MAP: Record = { + developer: 'DEVELOPER', + system: 'SYSTEM', + user: 'USER', + assistant: 'ASSISTANT', +}; + +interface ProtoMetricsReport { + startedSpeakingAt?: string; + stoppedSpeakingAt?: string; + transcriptionDelay?: number; + endOfTurnDelay?: number; + onUserTurnCompletedDelay?: number; + llmNodeTtft?: number; + ttsNodeTtfb?: number; + e2eLatency?: number; +} + +interface ProtoMessage { + id: string; + role: ProtoRole; + content: { text: ChatContent }[]; + createdAt: string; + interrupted?: boolean; + extra?: Record; + transcriptConfidence?: number; + metrics?: ProtoMetricsReport; +} + +interface ProtoFunctionCall { + id: string; + callId: string; + arguments: string | Record; + name: string; + createdAt: string; +} + +interface ProtoFunctionCallOutput { + id: string; + name: string; + callId: string; + output: string; + isError: boolean; + createdAt: string; +} + +interface ProtoAgentHandoff { + id: string; + newAgentId: string; + createdAt: string; + oldAgentId?: string; +} + +interface ProtoChatItem { + message?: ProtoMessage; + functionCall?: ProtoFunctionCall; + functionCallOutput?: ProtoFunctionCallOutput; + agentHandoff?: ProtoAgentHandoff; +} + /** * Convert ChatItem to proto-compatible dictionary format. * TODO: Use actual agent_session proto types once @livekit/protocol v1.43.1+ is published */ -function chatItemToProto(item: ChatItem): Record { - const itemDict: Record = {}; +function chatItemToProto(item: ChatItem): ProtoChatItem { + const itemDict: ProtoChatItem = {}; if (item.type === 'message') { - const roleMap: Record = { - developer: 'DEVELOPER', - system: 'SYSTEM', - user: 'USER', - assistant: 'ASSISTANT', - }; - - const msg: Record = { + const msg: ProtoMessage = { id: item.id, - role: roleMap[item.role] || item.role.toUpperCase(), + role: ROLE_MAP[item.role] ?? 
(item.role.toUpperCase() as ProtoRole), content: item.content.map((c: ChatContent) => ({ text: c })), createdAt: toRFC3339(item.createdAt), }; @@ -311,44 +368,43 @@ function chatItemToProto(item: ChatItem): Record { msg.interrupted = item.interrupted; } - // TODO(brian): Add extra and transcriptConfidence to ChatMessage - // if (item.extra && Object.keys(item.extra).length > 0) { - // msg.extra = item.extra; - // } - - // if (item.transcriptConfidence !== undefined && item.transcriptConfidence !== null) { - // msg.transcriptConfidence = item.transcriptConfidence; - // } - - // TODO(brian): Add metrics to ChatMessage - // const metrics = item.metrics || {}; - // if (Object.keys(metrics).length > 0) { - // msg.metrics = {}; - // if (metrics.started_speaking_at) { - // msg.metrics.startedSpeakingAt = toRFC3339(metrics.started_speaking_at); - // } - // if (metrics.stopped_speaking_at) { - // msg.metrics.stoppedSpeakingAt = toRFC3339(metrics.stopped_speaking_at); - // } - // if (metrics.transcription_delay !== undefined) { - // msg.metrics.transcriptionDelay = metrics.transcription_delay; - // } - // if (metrics.end_of_turn_delay !== undefined) { - // msg.metrics.endOfTurnDelay = metrics.end_of_turn_delay; - // } - // if (metrics.on_user_turn_completed_delay !== undefined) { - // msg.metrics.onUserTurnCompletedDelay = metrics.on_user_turn_completed_delay; - // } - // if (metrics.llm_node_ttft !== undefined) { - // msg.metrics.llmNodeTtft = metrics.llm_node_ttft; - // } - // if (metrics.tts_node_ttfb !== undefined) { - // msg.metrics.ttsNodeTtfb = metrics.tts_node_ttfb; - // } - // if (metrics.e2e_latency !== undefined) { - // msg.metrics.e2eLatency = metrics.e2e_latency; - // } - // } + if (item.extra && Object.keys(item.extra).length > 0) { + msg.extra = item.extra; + } + + if (item.transcriptConfidence !== undefined) { + msg.transcriptConfidence = item.transcriptConfidence; + } + + const metrics = item.metrics; + if (metrics && Object.keys(metrics).length > 0) { + 
const protoMetrics: ProtoMetricsReport = {}; + if (metrics.startedSpeakingAt !== undefined) { + protoMetrics.startedSpeakingAt = toRFC3339(metrics.startedSpeakingAt * 1000); + } + if (metrics.stoppedSpeakingAt !== undefined) { + protoMetrics.stoppedSpeakingAt = toRFC3339(metrics.stoppedSpeakingAt * 1000); + } + if (metrics.transcriptionDelay !== undefined) { + protoMetrics.transcriptionDelay = metrics.transcriptionDelay; + } + if (metrics.endOfTurnDelay !== undefined) { + protoMetrics.endOfTurnDelay = metrics.endOfTurnDelay; + } + if (metrics.onUserTurnCompletedDelay !== undefined) { + protoMetrics.onUserTurnCompletedDelay = metrics.onUserTurnCompletedDelay; + } + if (metrics.llmNodeTtft !== undefined) { + protoMetrics.llmNodeTtft = metrics.llmNodeTtft; + } + if (metrics.ttsNodeTtfb !== undefined) { + protoMetrics.ttsNodeTtfb = metrics.ttsNodeTtfb; + } + if (metrics.e2eLatency !== undefined) { + protoMetrics.e2eLatency = metrics.e2eLatency; + } + msg.metrics = protoMetrics; + } itemDict.message = msg; } else if (item.type === 'function_call') { @@ -369,7 +425,7 @@ function chatItemToProto(item: ChatItem): Record { createdAt: toRFC3339(item.createdAt), }; } else if (item.type === 'agent_handoff') { - const handoff: Record = { + const handoff: ProtoAgentHandoff = { id: item.id, newAgentId: item.newAgentId, createdAt: toRFC3339(item.createdAt), @@ -397,9 +453,7 @@ function chatItemToProto(item: ChatItem): Record { } /** - * Convert timestamp to RFC3339 format matching Python's _to_rfc3339. - * Note: TypeScript createdAt is in milliseconds (Date.now()), not seconds like Python. 
- * @internal + * Convert timestamp to RFC3339 format */ function toRFC3339(valueMs: number | Date): string { // valueMs is already in milliseconds (from Date.now()) @@ -445,6 +499,8 @@ export async function uploadSessionReport(options: { 'logger.name': 'chat_history', }; + const usage = report.modelUsage?.map(filterZeroValues) || null; + logRecords.push({ body: 'session report', timestampMs: report.startedAt || report.timestamp || 0, @@ -453,6 +509,7 @@ export async function uploadSessionReport(options: { 'session.options': report.options || {}, 'session.report_timestamp': report.timestamp, agent_name: agentName, + usage, }, }); diff --git a/agents/src/tts/tts.ts b/agents/src/tts/tts.ts index 8b8dcfda0..ab0477144 100644 --- a/agents/src/tts/tts.ts +++ b/agents/src/tts/tts.ts @@ -96,6 +96,30 @@ export abstract class TTS extends (EventEmitter as new () => TypedEmitter; #ttsRequestSpan?: Span; + #inputTokens = 0; + #outputTokens = 0; constructor(tts: TTS, connOptions: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS) { this.#tts = tts; @@ -284,6 +310,18 @@ export abstract class SynthesizeStream } } + /** + * Set token usage for token-based TTS billing (e.g., OpenAI TTS). + * Plugins should call this method to report token usage. 
+ */ + protected setTokenUsage({ + inputTokens = 0, + outputTokens = 0, + }: { inputTokens?: number; outputTokens?: number } = {}): void { + this.#inputTokens = inputTokens; + this.#outputTokens = outputTokens; + } + protected async monitorMetrics() { const startTime = process.hrtime.bigint(); let audioDurationMs = 0; @@ -305,12 +343,22 @@ export abstract class SynthesizeStream audioDurationMs: roundedAudioDurationMs, cancelled: this.abortController.signal.aborted, label: this.#tts.label, - streamed: false, + inputTokens: this.#inputTokens, + outputTokens: this.#outputTokens, + streamed: true, + metadata: { + modelProvider: this.#tts.provider, + modelName: this.#tts.model, + }, }; if (this.#ttsRequestSpan) { this.#ttsRequestSpan.setAttribute(traceTypes.ATTR_TTS_METRICS, JSON.stringify(metrics)); } this.#tts.emit('metrics_collected', metrics); + + // Reset token usage after emitting metrics for the next segment + this.#inputTokens = 0; + this.#outputTokens = 0; } }; @@ -434,6 +482,8 @@ export abstract class ChunkedStream implements AsyncIterableIterator(); export const speechHandleStorage = new AsyncLocalStorage(); @@ -110,6 +113,7 @@ export interface AgentOptions { instructions: string; chatCtx?: ChatContext; tools?: ToolContext; + /** @deprecated use turnHandling instead */ turnDetection?: TurnDetectionMode; stt?: STT | STTModelString; vad?: VAD; @@ -117,16 +121,19 @@ export interface AgentOptions { tts?: TTS | TTSModelString; allowInterruptions?: boolean; minConsecutiveSpeechDelay?: number; + turnHandling?: TurnHandlingOptions; useTtsAlignedTranscript?: boolean; } export class Agent { private _id: string; - private turnDetection?: TurnDetectionMode; private _stt?: STT; private _vad?: VAD; private _llm?: LLM | RealtimeModel; private _tts?: TTS; + private turnHandling?: TurnHandlingOptions; + private _interruptionDetection: InterruptionOptions['mode']; + private _allowInterruptions?: boolean; private _useTtsAlignedTranscript?: boolean; /** @internal */ @@ -151,7 
+158,9 @@ export class Agent { vad, llm, tts, + turnHandling, useTtsAlignedTranscript, + allowInterruptions, }: AgentOptions) { if (id) { this._id = id; @@ -176,7 +185,12 @@ export class Agent { }) : ChatContext.empty(); - this.turnDetection = turnDetection; + const migratedOptions = migrateLegacyOptions({ + turnDetection, + options: { turnHandling, allowInterruptions }, + }); + this.turnHandling = migratedOptions.options.turnHandling; + this._vad = vad; if (typeof stt === 'string') { @@ -197,6 +211,10 @@ export class Agent { this._tts = tts; } + this._interruptionDetection = this.turnHandling?.interruption.mode; + if (this.turnHandling?.interruption.mode !== undefined) { + this._allowInterruptions = !!this.turnHandling.interruption.mode; + } this._useTtsAlignedTranscript = useTtsAlignedTranscript; this._agentActivity = undefined; @@ -242,6 +260,14 @@ export class Agent { return this.getActivityOrThrow().agentSession as AgentSession; } + get interruptionDetection(): InterruptionOptions['mode'] { + return this._interruptionDetection; + } + + get allowInterruptions(): boolean | undefined { + return this._allowInterruptions; + } + async onEnter(): Promise {} async onExit(): Promise {} @@ -341,7 +367,8 @@ export class Agent { // Set startTimeOffset to provide linear timestamps across reconnections const audioInputStartedAt = - activity.agentSession._recorderIO?.recordingStartedAt ?? // Use recording start time if available + activity.inputStartedAt ?? // Use input started at proxied from AudioRecognition if available + activity.agentSession._recorderIO?.recordingStartedAt ?? // Fallback to recording start time if available activity.agentSession._startedAt ?? 
// Fallback to session start time Date.now(); // Fallback to current time diff --git a/agents/src/voice/agent_activity.ts b/agents/src/voice/agent_activity.ts index 252768f48..56027942a 100644 --- a/agents/src/voice/agent_activity.ts +++ b/agents/src/voice/agent_activity.ts @@ -8,7 +8,10 @@ import { ROOT_CONTEXT, context as otelContext, trace } from '@opentelemetry/api' import { Heap } from 'heap-js'; import { AsyncLocalStorage } from 'node:async_hooks'; import { ReadableStream, TransformStream } from 'node:stream/web'; -import { type ChatContext, ChatMessage } from '../llm/chat_context.js'; +import type { InterruptionDetectionError } from '../inference/interruption/errors.js'; +import { AdaptiveInterruptionDetector } from '../inference/interruption/interruption_detector.js'; +import type { OverlappingSpeechEvent } from '../inference/interruption/types.js'; +import { type ChatContext, ChatMessage, type MetricsReport } from '../llm/chat_context.js'; import { type ChatItem, type FunctionCall, @@ -30,6 +33,7 @@ import { isSameToolChoice, isSameToolContext } from '../llm/tool_context.js'; import { log } from '../log.js'; import type { EOUMetrics, + InterruptionMetrics, LLMMetrics, RealtimeModelMetrics, STTMetrics, @@ -57,7 +61,6 @@ import { type EndOfTurnInfo, type PreemptiveGenerationInfo, type RecognitionHooks, - type _TurnDetector, } from './audio_recognition.js'; import { AgentSessionEventTypes, @@ -101,6 +104,7 @@ interface PreemptiveGeneration { createdAt: number; } +// TODO add false interruption handling and barge in handling for https://github.com/livekit/agents/pull/3109/changes export class AgentActivity implements RecognitionHooks { agent: Agent; agentSession: AgentSession; @@ -111,7 +115,7 @@ export class AgentActivity implements RecognitionHooks { private audioRecognition?: AudioRecognition; private realtimeSession?: RealtimeSession; private realtimeSpans?: Map; // Maps response_id to OTEL span for metrics recording - private turnDetectionMode?: Exclude; 
+ private turnDetectionMode?: TurnDetectionMode; private logger = log(); private _schedulingPaused = true; private _drainBlockedTasks: Task[] = []; @@ -126,6 +130,43 @@ export class AgentActivity implements RecognitionHooks { // default to null as None, which maps to the default provider tool choice value private toolChoice: ToolChoice | null = null; private _preemptiveGeneration?: PreemptiveGeneration; + private interruptionDetector?: AdaptiveInterruptionDetector; + private isInterruptionDetectionEnabled: boolean; + private isInterruptionByAudioActivityEnabled: boolean; + private isDefaultInterruptionByAudioActivityEnabled: boolean; + + private readonly onRealtimeGenerationCreated = (ev: GenerationCreatedEvent): void => + this.onGenerationCreated(ev); + + private readonly onRealtimeInputSpeechStarted = (ev: InputSpeechStartedEvent): void => + this.onInputSpeechStarted(ev); + + private readonly onRealtimeInputSpeechStopped = (ev: InputSpeechStoppedEvent): void => + this.onInputSpeechStopped(ev); + + private readonly onRealtimeInputAudioTranscriptionCompleted = ( + ev: InputTranscriptionCompleted, + ): void => this.onInputAudioTranscriptionCompleted(ev); + + private readonly onModelError = (ev: RealtimeModelError | STTError | TTSError | LLMError): void => + this.onError(ev); + + private readonly onInterruptionOverlappingSpeech = (ev: OverlappingSpeechEvent): void => { + this.agentSession.emit(AgentSessionEventTypes.UserOverlappingSpeech, ev); + }; + + private readonly onInterruptionMetricsCollected = (ev: InterruptionMetrics): void => { + this.agentSession.emit( + AgentSessionEventTypes.MetricsCollected, + createMetricsCollectedEvent({ metrics: ev }), + ); + }; + + private readonly onInterruptionError = (ev: InterruptionDetectionError): void => { + const errorEvent = createErrorEvent(ev, this.interruptionDetector); + this.agentSession.emit(AgentSessionEventTypes.Error, errorEvent); + this.agentSession._onError(ev); + }; /** @internal */ _mainTask?: Task; @@ -133,16 
+174,6 @@ export class AgentActivity implements RecognitionHooks { _onExitTask?: Task; _userTurnCompletedTask?: Task; - private readonly onRealtimeGenerationCreated = (ev: GenerationCreatedEvent) => - this.onGenerationCreated(ev); - private readonly onRealtimeInputSpeechStarted = (ev: InputSpeechStartedEvent) => - this.onInputSpeechStarted(ev); - private readonly onRealtimeInputSpeechStopped = (ev: InputSpeechStoppedEvent) => - this.onInputSpeechStopped(ev); - private readonly onRealtimeInputAudioTranscriptionCompleted = (ev: InputTranscriptionCompleted) => - this.onInputAudioTranscriptionCompleted(ev); - private readonly onModelError = (ev: RealtimeModelError | STTError | TTSError | LLMError) => - this.onError(ev); constructor(agent: Agent, agentSession: AgentSession) { this.agent = agent; this.agentSession = agentSession; @@ -235,6 +266,16 @@ export class AgentActivity implements RecognitionHooks { 'for more responsive interruption handling.', ); } + + this.interruptionDetector = this.resolveInterruptionDetector(); + this.isInterruptionDetectionEnabled = !!this.interruptionDetector; + + // this allows taking over audio interruption temporarily until interruption is detected + // by default it is true unless turnDetection is manual or realtime_llm + this.isInterruptionByAudioActivityEnabled = + this.turnDetectionMode !== 'manual' && this.turnDetectionMode !== 'realtime_llm'; + + this.isDefaultInterruptionByAudioActivityEnabled = this.isInterruptionByAudioActivityEnabled; + } async start(): Promise { @@ -348,8 +389,9 @@ export class AgentActivity implements RecognitionHooks { vad: this.vad, turnDetector: typeof this.turnDetection === 'string' ? 
undefined : this.turnDetection, turnDetectionMode: this.turnDetectionMode, - minEndpointingDelay: this.agentSession.options.minEndpointingDelay, - maxEndpointingDelay: this.agentSession.options.maxEndpointingDelay, + interruptionDetection: this.interruptionDetector, + minEndpointingDelay: this.agentSession.options.turnHandling.endpointing.minDelay, + maxEndpointingDelay: this.agentSession.options.turnHandling.endpointing.maxDelay, rootSpanContext: this.agentSession.rootSpanContext, sttModel: this.stt?.label, sttProvider: this.getSttProvider(), @@ -423,7 +465,7 @@ export class AgentActivity implements RecognitionHooks { get allowInterruptions(): boolean { // TODO(AJS-51): Allow options to be defined in Agent class - return this.agentSession.options.allowInterruptions; + return this.agentSession.options.turnHandling.interruption?.mode !== false; } get useTtsAlignedTranscript(): boolean { @@ -440,6 +482,11 @@ export class AgentActivity implements RecognitionHooks { return this.agent.toolCtx; } + /** @internal */ + get inputStartedAt() { + return this.audioRecognition?.inputStartedAt; + } + async updateChatCtx(chatCtx: ChatContext): Promise { chatCtx = chatCtx.copy({ toolCtx: this.toolCtx }); @@ -471,7 +518,13 @@ export class AgentActivity implements RecognitionHooks { } } - updateOptions({ toolChoice }: { toolChoice?: ToolChoice | null }): void { + updateOptions({ + toolChoice, + turnDetection, + }: { + toolChoice?: ToolChoice | null; + turnDetection?: TurnDetectionMode; + }): void { if (toolChoice !== undefined) { this.toolChoice = toolChoice; } @@ -479,6 +532,22 @@ export class AgentActivity implements RecognitionHooks { if (this.realtimeSession) { this.realtimeSession.updateOptions({ toolChoice: this.toolChoice }); } + + if (turnDetection !== undefined) { + this.turnDetectionMode = turnDetection; + this.isDefaultInterruptionByAudioActivityEnabled = + this.turnDetectionMode !== 'manual' && this.turnDetectionMode !== 'realtime_llm'; + + // sync live flag immediately 
when not speaking so the change takes effect right away + if (this.agentSession.agentState !== 'speaking') { + this.isInterruptionByAudioActivityEnabled = + this.isDefaultInterruptionByAudioActivityEnabled; + } + } + + if (this.audioRecognition) { + this.audioRecognition.updateOptions({ turnDetection: this.turnDetectionMode }); + } } attachAudioInput(audioStream: ReadableStream): void { @@ -655,6 +724,13 @@ export class AgentActivity implements RecognitionHooks { if (!this.vad) { this.agentSession._updateUserState('speaking'); + if (this.isInterruptionDetectionEnabled && this.audioRecognition) { + this.audioRecognition.onStartOfOverlapSpeech( + 0, + Date.now(), + this.agentSession._userSpeakingSpan, + ); + } } // this.interrupt() is going to raise when allow_interruptions is False, @@ -673,6 +749,9 @@ export class AgentActivity implements RecognitionHooks { this.logger.info(ev, 'onInputSpeechStopped'); if (!this.vad) { + if (this.isInterruptionDetectionEnabled && this.audioRecognition) { + this.audioRecognition.onEndOfOverlapSpeech(Date.now(), this.agentSession._userSpeakingSpan); + } this.agentSession._updateUserState('listening'); } @@ -746,15 +825,32 @@ export class AgentActivity implements RecognitionHooks { onStartOfSpeech(ev: VADEvent): void { let speechStartTime = Date.now(); if (ev) { - speechStartTime = speechStartTime - ev.speechDuration; + // Subtract both speechDuration and inferenceDuration to correct for VAD model latency. + speechStartTime = speechStartTime - ev.speechDuration - ev.inferenceDuration; } this.agentSession._updateUserState('speaking', speechStartTime); + if (this.isInterruptionDetectionEnabled && this.audioRecognition) { + // Pass speechStartTime as the absolute startedAt timestamp. 
+ this.audioRecognition.onStartOfOverlapSpeech( + ev.speechDuration, + speechStartTime, + this.agentSession._userSpeakingSpan, + ); + } } onEndOfSpeech(ev: VADEvent): void { let speechEndTime = Date.now(); if (ev) { - speechEndTime = speechEndTime - ev.silenceDuration; + // Subtract both silenceDuration and inferenceDuration to correct for VAD model latency. + speechEndTime = speechEndTime - ev.silenceDuration - ev.inferenceDuration; + } + if (this.isInterruptionDetectionEnabled && this.audioRecognition) { + // Pass speechEndTime as the absolute endedAt timestamp. + this.audioRecognition.onEndOfOverlapSpeech( + speechEndTime, + this.agentSession._userSpeakingSpan, + ); } this.agentSession._updateUserState('listening', speechEndTime); } @@ -765,12 +861,16 @@ export class AgentActivity implements RecognitionHooks { return; } - if (ev.speechDuration >= this.agentSession.options.minInterruptionDuration) { + if (ev.speechDuration >= this.agentSession.options.turnHandling.interruption?.minDuration) { this.interruptByAudioActivity(); } } private interruptByAudioActivity(): void { + if (!this.isInterruptionByAudioActivityEnabled) { + return; + } + if (this.agentSession._aecWarmupRemaining > 0) { // Disable interruption from audio activity while AEC warmup is active. 
return; @@ -785,7 +885,11 @@ export class AgentActivity implements RecognitionHooks { // - Always apply minInterruptionWords filtering when STT is available and minInterruptionWords > 0 // - Apply check to all STT results: empty string, undefined, or any length // - This ensures consistent behavior across all interruption scenarios - if (this.stt && this.agentSession.options.minInterruptionWords > 0 && this.audioRecognition) { + if ( + this.stt && + this.agentSession.options.turnHandling.interruption?.minWords > 0 && + this.audioRecognition + ) { const text = this.audioRecognition.currentTranscript; // TODO(shubhra): better word splitting for multi-language @@ -795,7 +899,7 @@ export class AgentActivity implements RecognitionHooks { // Only allow interruption if word count meets or exceeds minInterruptionWords // This applies to all cases: empty strings, partial speech, and full speech - if (wordCount < this.agentSession.options.minInterruptionWords) { + if (wordCount < this.agentSession.options.turnHandling.interruption?.minWords) { return; } } @@ -816,6 +920,14 @@ export class AgentActivity implements RecognitionHooks { } } + onInterruption(ev: OverlappingSpeechEvent) { + this.restoreInterruptionByAudioActivity(); + this.interruptByAudioActivity(); + if (this.audioRecognition) { + this.audioRecognition.onEndOfAgentSpeech(ev.overlapStartedAt || ev.timestamp); + } + } + onInterimTranscript(ev: SpeechEvent): void { if (this.llm instanceof RealtimeModel && this.llm.capabilities.userTranscription) { // skip stt transcription if userTranscription is enabled on the realtime model @@ -891,6 +1003,7 @@ export class AgentActivity implements RecognitionHooks { const userMessage = ChatMessage.create({ role: 'user', content: info.newTranscript, + transcriptConfidence: info.transcriptConfidence, }); const chatCtx = this.agent.chatCtx.copy(); const speechHandle = this.generateReply({ @@ -986,16 +1099,16 @@ export class AgentActivity implements RecognitionHooks { 
this._currentSpeech && this._currentSpeech.allowInterruptions && !this._currentSpeech.interrupted && - this.agentSession.options.minInterruptionWords > 0 + this.agentSession.options.turnHandling.interruption?.minWords > 0 ) { const wordCount = splitWords(info.newTranscript, true).length; - if (wordCount < this.agentSession.options.minInterruptionWords) { + if (wordCount < this.agentSession.options.turnHandling.interruption?.minWords) { // avoid interruption if the new_transcript contains fewer words than minInterruptionWords this.cancelPreemptiveGeneration(); this.logger.info( { wordCount, - minInterruptionWords: this.agentSession.options.minInterruptionWords, + minInterruptionWords: this.agentSession.options.turnHandling.interruption.minWords, }, 'skipping user input, word count below minimum interruption threshold', ); @@ -1293,6 +1406,7 @@ export class AgentActivity implements RecognitionHooks { let userMessage: ChatMessage | undefined = ChatMessage.create({ role: 'user', content: info.newTranscript, + transcriptConfidence: info.transcriptConfidence, }); // create a temporary mutable chat context to pass to onUserTurnCompleted @@ -1319,6 +1433,24 @@ export class AgentActivity implements RecognitionHooks { return; } + const userMetricsReport: MetricsReport = {}; + if (info.startedSpeakingAt !== undefined) { + userMetricsReport.startedSpeakingAt = info.startedSpeakingAt / 1000; // ms -> seconds + } + if (info.stoppedSpeakingAt !== undefined) { + userMetricsReport.stoppedSpeakingAt = info.stoppedSpeakingAt / 1000; // ms -> seconds + } + if (info.transcriptionDelay !== undefined) { + userMetricsReport.transcriptionDelay = info.transcriptionDelay / 1000; // ms -> seconds + } + if (info.endOfUtteranceDelay !== undefined) { + userMetricsReport.endOfTurnDelay = info.endOfUtteranceDelay / 1000; // ms -> seconds + } + userMetricsReport.onUserTurnCompletedDelay = callbackDuration / 1000; // ms -> seconds + if (userMessage) { + userMessage.metrics = userMetricsReport; + } + 
let speechHandle: SpeechHandle | undefined; if (this._preemptiveGeneration !== undefined) { const preemptive = this._preemptiveGeneration; @@ -1331,6 +1463,14 @@ export class AgentActivity implements RecognitionHooks { isSameToolChoice(preemptive.toolChoice, this.toolChoice) ) { speechHandle = preemptive.speechHandle; + // The preemptive userMessage was created without metrics. + // Copy the metrics and transcriptConfidence from the new userMessage + // to the preemptive message BEFORE scheduling (so the pipeline inserts + // the message with metrics already set). + if (preemptive.userMessage && userMessage) { + preemptive.userMessage.metrics = userMetricsReport; + preemptive.userMessage.transcriptConfidence = userMessage.transcriptConfidence; + } this.scheduleSpeech(speechHandle, SpeechHandle.SPEECH_PRIORITY_NORMAL); this.logger.debug( { @@ -1424,11 +1564,19 @@ export class AgentActivity implements RecognitionHooks { tasks.push(textForwardTask); } + let replyStartedSpeakingAt: number | undefined; + let replyTtsGenData: _TTSGenerationData | null = null; + const onFirstFrame = (startedSpeakingAt?: number) => { + replyStartedSpeakingAt = startedSpeakingAt ?? 
Date.now(); this.agentSession._updateAgentState('speaking', { startTime: startedSpeakingAt, otelContext: speechHandle._agentTurnContext, }); + if (this.isInterruptionDetectionEnabled && this.audioRecognition) { + this.audioRecognition.onStartOfAgentSpeech(); + this.isInterruptionByAudioActivityEnabled = false; + } }; if (!audioOutput) { @@ -1446,8 +1594,11 @@ export class AgentActivity implements RecognitionHooks { audioSource, modelSettings, replyAbortController, + this.tts?.model, + this.tts?.provider, ); tasks.push(ttsTask); + replyTtsGenData = ttsGenData; const [forwardTask, _audioOut] = performAudioForwarding( ttsGenData.audioStream, @@ -1487,10 +1638,21 @@ export class AgentActivity implements RecognitionHooks { } if (addToChatCtx) { + const replyStoppedSpeakingAt = Date.now(); + const replyAssistantMetrics: MetricsReport = {}; + if (replyTtsGenData?.ttfb !== undefined) { + replyAssistantMetrics.ttsNodeTtfb = replyTtsGenData.ttfb; + } + if (replyStartedSpeakingAt !== undefined) { + replyAssistantMetrics.startedSpeakingAt = replyStartedSpeakingAt / 1000; // ms -> seconds + replyAssistantMetrics.stoppedSpeakingAt = replyStoppedSpeakingAt / 1000; // ms -> seconds + } + const message = ChatMessage.create({ role: 'assistant', content: textOut?.text || '', interrupted: speechHandle.interrupted, + metrics: replyAssistantMetrics, }); this.agent._chatCtx.insert(message); this.agentSession._conversationItemAdded(message); @@ -1498,6 +1660,10 @@ export class AgentActivity implements RecognitionHooks { if (this.agentSession.agentState === 'speaking') { this.agentSession._updateAgentState('listening'); + if (this.isInterruptionDetectionEnabled && this.audioRecognition) { + this.audioRecognition.onEndOfAgentSpeech(Date.now()); + } + this.restoreInterruptionByAudioActivity(); } } @@ -1511,6 +1677,7 @@ export class AgentActivity implements RecognitionHooks { newMessage, toolsMessages, span, + _previousUserMetrics, }: { speechHandle: SpeechHandle; chatCtx: ChatContext; @@ 
-1521,6 +1688,7 @@ export class AgentActivity implements RecognitionHooks { newMessage?: ChatMessage; toolsMessages?: ChatItem[]; span: Span; + _previousUserMetrics?: MetricsReport; }): Promise => { speechHandle._agentTurnContext = otelContext.active(); @@ -1573,6 +1741,8 @@ export class AgentActivity implements RecognitionHooks { toolCtx, modelSettings, replyAbortController, + this.llm?.model, + this.llm?.provider, ); tasks.push(llmTask); @@ -1589,6 +1759,8 @@ export class AgentActivity implements RecognitionHooks { ttsTextInput, modelSettings, replyAbortController, + this.tts?.model, + this.tts?.provider, ); tasks.push(ttsTask); } else { @@ -1598,10 +1770,12 @@ export class AgentActivity implements RecognitionHooks { await speechHandle.waitIfNotInterrupted([speechHandle._waitForScheduled()]); + let userMetrics: MetricsReport | undefined = _previousUserMetrics; // Add new message to actual chat context if the speech is scheduled if (newMessage && speechHandle.scheduled) { this.agent._chatCtx.insert(newMessage); this.agentSession._conversationItemAdded(newMessage); + userMetrics = newMessage.metrics; } if (speechHandle.interrupted) { @@ -1647,11 +1821,17 @@ export class AgentActivity implements RecognitionHooks { textOut = _textOut; } + let agentStartedSpeakingAt: number | undefined; const onFirstFrame = (startedSpeakingAt?: number) => { + agentStartedSpeakingAt = startedSpeakingAt ?? 
Date.now(); this.agentSession._updateAgentState('speaking', { startTime: startedSpeakingAt, otelContext: speechHandle._agentTurnContext, }); + if (this.isInterruptionDetectionEnabled && this.audioRecognition) { + this.audioRecognition.onStartOfAgentSpeech(); + this.isInterruptionByAudioActivityEnabled = false; + } }; let audioOut: _AudioOut | null = null; @@ -1708,6 +1888,29 @@ export class AgentActivity implements RecognitionHooks { await speechHandle.waitIfNotInterrupted([audioOutput.waitForPlayout()]); } + const agentStoppedSpeakingAt = Date.now(); + const assistantMetrics: MetricsReport = {}; + + if (llmGenData.ttft !== undefined) { + assistantMetrics.llmNodeTtft = llmGenData.ttft; // already in seconds + } + if (ttsGenData?.ttfb !== undefined) { + assistantMetrics.ttsNodeTtfb = ttsGenData.ttfb; // already in seconds + } + if (agentStartedSpeakingAt !== undefined) { + assistantMetrics.startedSpeakingAt = agentStartedSpeakingAt / 1000; // ms -> seconds + assistantMetrics.stoppedSpeakingAt = agentStoppedSpeakingAt / 1000; // ms -> seconds + + if (userMetrics?.stoppedSpeakingAt !== undefined) { + const e2eLatency = agentStartedSpeakingAt / 1000 - userMetrics.stoppedSpeakingAt; + assistantMetrics.e2eLatency = e2eLatency; + span.setAttribute(traceTypes.ATTR_E2E_LATENCY, e2eLatency); + } + } + + span.setAttribute(traceTypes.ATTR_SPEECH_INTERRUPTED, speechHandle.interrupted); + let hasSpeechMessage = false; + // add the tools messages that triggers this reply to the chat context if (toolsMessages) { for (const msg of toolsMessages) { @@ -1762,45 +1965,54 @@ export class AgentActivity implements RecognitionHooks { } if (forwardedText) { + hasSpeechMessage = true; const message = ChatMessage.create({ role: 'assistant', content: forwardedText, id: llmGenData.id, interrupted: true, createdAt: replyStartedAt, + metrics: assistantMetrics, }); chatCtx.insert(message); this.agent._chatCtx.insert(message); speechHandle._itemAdded([message]); 
this.agentSession._conversationItemAdded(message); + span.setAttribute(traceTypes.ATTR_RESPONSE_TEXT, forwardedText); } if (this.agentSession.agentState === 'speaking') { this.agentSession._updateAgentState('listening'); + if (this.isInterruptionDetectionEnabled && this.audioRecognition) { + this.audioRecognition.onEndOfAgentSpeech(Date.now()); + this.restoreInterruptionByAudioActivity(); + } } this.logger.info( { speech_id: speechHandle.id, message: forwardedText }, 'playout completed with interrupt', ); - // TODO(shubhra) add chat message to speech handle speechHandle._markGenerationDone(); await executeToolsTask.cancelAndWait(AgentActivity.REPLY_TASK_CANCEL_TIMEOUT); return; } if (textOut && textOut.text) { + hasSpeechMessage = true; const message = ChatMessage.create({ role: 'assistant', id: llmGenData.id, interrupted: false, createdAt: replyStartedAt, content: textOut.text, + metrics: assistantMetrics, }); chatCtx.insert(message); this.agent._chatCtx.insert(message); speechHandle._itemAdded([message]); this.agentSession._conversationItemAdded(message); + span.setAttribute(traceTypes.ATTR_RESPONSE_TEXT, textOut.text); this.logger.info( { speech_id: speechHandle.id, message: textOut.text }, 'playout completed without interruption', @@ -1811,6 +2023,12 @@ export class AgentActivity implements RecognitionHooks { this.agentSession._updateAgentState('thinking'); } else if (this.agentSession.agentState === 'speaking') { this.agentSession._updateAgentState('listening'); + if (this.isInterruptionDetectionEnabled && this.audioRecognition) { + { + this.audioRecognition.onEndOfAgentSpeech(Date.now()); + this.restoreInterruptionByAudioActivity(); + } + } } // mark the playout done before waiting for the tool execution @@ -1870,6 +2088,7 @@ export class AgentActivity implements RecognitionHooks { instructions, undefined, toolMessages, + hasSpeechMessage ? 
undefined : userMetrics, ), ownedSpeechHandle: speechHandle, name: 'AgentActivity.pipelineReply', @@ -1903,6 +2122,7 @@ export class AgentActivity implements RecognitionHooks { instructions?: string, newMessage?: ChatMessage, toolsMessages?: ChatItem[], + _previousUserMetrics?: MetricsReport, ): Promise => tracer.startActiveSpan( async (span) => @@ -1916,6 +2136,7 @@ export class AgentActivity implements RecognitionHooks { newMessage, toolsMessages, span, + _previousUserMetrics, }), { name: 'agent_turn', @@ -2066,6 +2287,8 @@ export class AgentActivity implements RecognitionHooks { ttsTextInput, modelSettings, abortController, + this.tts?.model, + this.tts?.provider, ); tasks.push(ttsTask); realtimeAudioResult = ttsGenData.audioStream; @@ -2575,6 +2798,14 @@ export class AgentActivity implements RecognitionHooks { if (this._mainTask) { await this._mainTask.cancelAndWait(); } + if (this.interruptionDetector) { + this.interruptionDetector.off( + 'user_overlapping_speech', + this.onInterruptionOverlappingSpeech, + ); + this.interruptionDetector.off('metrics_collected', this.onInterruptionMetricsCollected); + this.interruptionDetector.off('error', this.onInterruptionError); + } this.agent._agentActivity = undefined; } finally { @@ -2582,6 +2813,53 @@ export class AgentActivity implements RecognitionHooks { } } + private resolveInterruptionDetector(): AdaptiveInterruptionDetector | undefined { + const interruptionDetection = + this.agent.interruptionDetection ?? 
this.agentSession.interruptionDetection; + if ( + !( + this.stt && + this.stt.capabilities.alignedTranscript && + this.stt.capabilities.streaming && + this.vad && + this.turnDetection !== 'manual' && + this.turnDetection !== 'realtime_llm' && + !(this.llm instanceof RealtimeModel) + ) + ) { + if (interruptionDetection === 'adaptive') { + this.logger.warn( + "interruptionDetection is provided, but it's not compatible with the current configuration and will be disabled", + ); + return undefined; + } + } + + if ( + (interruptionDetection !== undefined && interruptionDetection === false) || + interruptionDetection === 'vad' + ) { + return undefined; + } + + try { + const detector = new AdaptiveInterruptionDetector(); + + detector.on('user_overlapping_speech', this.onInterruptionOverlappingSpeech); + detector.on('metrics_collected', this.onInterruptionMetricsCollected); + detector.on('error', this.onInterruptionError); + + return detector; + } catch (error: unknown) { + this.logger.warn({ error }, 'could not instantiate AdaptiveInterruptionDetector'); + } + return undefined; + } + + private restoreInterruptionByAudioActivity(): void { + this.isInterruptionByAudioActivityEnabled = this.isDefaultInterruptionByAudioActivityEnabled; + } + private async _closeSessionResources(): Promise { // Unregister event handlers to prevent duplicate metrics if (this.llm instanceof LLM) { diff --git a/agents/src/voice/agent_session.ts b/agents/src/voice/agent_session.ts index b1f095622..6b686499b 100644 --- a/agents/src/voice/agent_session.ts +++ b/agents/src/voice/agent_session.ts @@ -17,12 +17,15 @@ import { type STTModelString, type TTSModelString, } from '../inference/index.js'; +import type { InterruptionDetectionError } from '../inference/interruption/errors.js'; +import type { OverlappingSpeechEvent } from '../inference/interruption/types.js'; import { type JobContext, getJobContext } from '../job.js'; import type { FunctionCall, FunctionCallOutput } from '../llm/chat_context.js'; 
import { AgentHandoffItem, ChatContext, ChatMessage } from '../llm/chat_context.js'; import type { LLM, RealtimeModel, RealtimeModelError, ToolChoice } from '../llm/index.js'; import type { LLMError } from '../llm/llm.js'; import { log } from '../log.js'; +import { type ModelUsage, ModelUsageCollector, filterZeroValues } from '../metrics/model_usage.js'; import type { STT } from '../stt/index.js'; import type { STTError } from '../stt/stt.js'; import { traceTypes, tracer } from '../telemetry/index.js'; @@ -38,6 +41,7 @@ import type { VAD } from '../vad.js'; import type { Agent } from './agent.js'; import { AgentActivity } from './agent_activity.js'; import type { _TurnDetector } from './audio_recognition.js'; +import { ClientEventsHandler } from './client_events.js'; import { type AgentEvent, AgentSessionEventTypes, @@ -61,39 +65,90 @@ import { } from './events.js'; import { AgentInput, AgentOutput } from './io.js'; import { RecorderIO } from './recorder_io/index.js'; -import { RoomIO, type RoomInputOptions, type RoomOutputOptions } from './room_io/index.js'; +import { + DEFAULT_TEXT_INPUT_CALLBACK, + RoomIO, + type RoomInputOptions, + type RoomOutputOptions, +} from './room_io/index.js'; import type { UnknownUserData } from './run_context.js'; import type { SpeechHandle } from './speech_handle.js'; import { RunResult } from './testing/run_result.js'; +import type { InterruptionOptions } from './turn_config/interruption.js'; +import type { + InternalTurnHandlingOptions, + TurnHandlingOptions, +} from './turn_config/turn_handling.js'; +import { migrateLegacyOptions } from './turn_config/utils.js'; import { setParticipantSpanAttributes } from './utils.js'; -export interface VoiceOptions { - allowInterruptions: boolean; - discardAudioIfUninterruptible: boolean; - minInterruptionDuration: number; - minInterruptionWords: number; - minEndpointingDelay: number; - maxEndpointingDelay: number; +export interface AgentSessionUsage { + /** List of usage summaries, one per 
model/provider combination. */ + modelUsage: Array>; +} + +export interface SessionOptions { maxToolSteps: number; + /** + * Whether to speculatively begin LLM and TTS requests before an end-of-turn is detected. + * When `true`, the agent sends inference calls as soon as a user transcript is received rather + * than waiting for a definitive turn boundary. This can reduce response latency by overlapping + * model inference with user audio, but may incur extra compute if the user interrupts or + * revises mid-utterance. + * @defaultValue false + */ preemptiveGeneration: boolean; - userAwayTimeout?: number | null; + + /** + * If set, set the user state as "away" after this amount of time after user and agent are + * silent. Set to `null` to disable. + * @defaultValue 15.0 + */ + userAwayTimeout: number | null; + + /** + * Duration in milliseconds for AEC (Acoustic Echo Cancellation) warmup, during which + * interruptions from audio activity are suppressed. Set to `null` to disable. + * @defaultValue 3000 + */ aecWarmupDuration: number | null; + + /** + * Configuration for turn handling. + */ + turnHandling: Partial; + useTtsAlignedTranscript: boolean; + + /** @deprecated Use {@link SessionOptions.turnHandling}.interruption.mode instead. */ + allowInterruptions?: boolean; + /** @deprecated Use {@link SessionOptions.turnHandling}.interruption.discardAudioIfUninterruptible instead. */ + discardAudioIfUninterruptible?: boolean; + /** @deprecated Use {@link SessionOptions.turnHandling}.interruption.minDuration instead. */ + minInterruptionDuration?: number; + /** @deprecated Use {@link SessionOptions.turnHandling}.interruption.minWords instead. */ + minInterruptionWords?: number; + /** @deprecated Use {@link SessionOptions.turnHandling}.endpointing.minDelay instead. */ + minEndpointingDelay?: number; + /** @deprecated Use {@link SessionOptions.turnHandling}.endpointing.maxDelay instead. 
*/ + maxEndpointingDelay?: number; +} + +export interface InternalSessionOptions extends SessionOptions { + turnHandling: InternalTurnHandlingOptions; } -const defaultVoiceOptions: VoiceOptions = { - allowInterruptions: true, - discardAudioIfUninterruptible: true, - minInterruptionDuration: 500, - minInterruptionWords: 0, - minEndpointingDelay: 500, - maxEndpointingDelay: 6000, +export const defaultSessionOptions = { maxToolSteps: 3, preemptiveGeneration: false, userAwayTimeout: 15.0, aecWarmupDuration: 3000, + turnHandling: {}, useTtsAlignedTranscript: true, -} as const; +} as const satisfies SessionOptions; + +/** @deprecated {@link VoiceOptions} has been renamed to {@link SessionOptions} */ +export type VoiceOptions = SessionOptions; export type TurnDetectionMode = 'stt' | 'vad' | 'realtime_llm' | 'manual' | _TurnDetector; @@ -107,17 +162,22 @@ export type AgentSessionCallbacks = { [AgentSessionEventTypes.SpeechCreated]: (ev: SpeechCreatedEvent) => void; [AgentSessionEventTypes.Error]: (ev: ErrorEvent) => void; [AgentSessionEventTypes.Close]: (ev: CloseEvent) => void; + [AgentSessionEventTypes.UserOverlappingSpeech]: (ev: OverlappingSpeechEvent) => void; }; export type AgentSessionOptions = { - turnDetection?: TurnDetectionMode; stt?: STT | STTModelString; vad?: VAD; llm?: LLM | RealtimeModel | LLMModels; tts?: TTS | TTSModelString; userData?: UserData; - voiceOptions?: Partial; + options?: Partial; connOptions?: SessionConnectOptions; + + /** @deprecated use {@link AgentSessionOptions.options}.turnHandling.turnDetection instead */ + turnDetection?: TurnDetectionMode; + /** @deprecated use {@link AgentSessionOptions.options} instead */ + voiceOptions?: Partial; }; type ActivityTransitionOptions = { @@ -136,22 +196,19 @@ export class AgentSession< tts?: TTS; turnDetection?: TurnDetectionMode; - readonly options: VoiceOptions; + readonly options: InternalSessionOptions; + private readonly activityLock = new Mutex(); private agent?: Agent; private activity?: 
AgentActivity; private nextActivity?: AgentActivity; private updateActivityTask?: Task; private started = false; - private userState: UserState = 'listening'; - private readonly activityLock = new Mutex(); - - /** @internal */ - _roomIO?: RoomIO; - private logger = log(); + private clientEventsHandler?: ClientEventsHandler; private _chatCtx: ChatContext; private _userData: UserData | undefined; + private _userState: UserState = 'listening'; private _agentState: AgentState = 'initializing'; private _input: AgentInput; @@ -168,11 +225,18 @@ export class AgentSession< // Unrecoverable error counts, reset after agent speaking private llmErrorCounts = 0; private ttsErrorCounts = 0; + private interruptionDetectionErrorCounts = 0; private sessionSpan?: Span; - private userSpeakingSpan?: Span; private agentSpeakingSpan?: Span; + private _interruptionDetection?: InterruptionOptions['mode']; + + private _usageCollector: ModelUsageCollector = new ModelUsageCollector(); + + /** @internal */ + _roomIO?: RoomIO; + /** @internal */ _aecWarmupRemaining = 0; @@ -194,20 +258,17 @@ export class AgentSession< /** @internal - Current run state for testing */ _globalRunState?: RunResult; - constructor(opts: AgentSessionOptions) { + /** @internal */ + _userSpeakingSpan?: Span; + + private logger = log(); + + constructor(options: AgentSessionOptions) { super(); - const { - vad, - stt, - llm, - tts, - turnDetection, - userData, - voiceOptions = defaultVoiceOptions, - connOptions, - } = opts; + const opts = migrateLegacyOptions(options); + const { vad, stt, llm, tts, userData, connOptions, options: sessionOptions } = opts; // Merge user-provided connOptions with defaults this._connOptions = { sttConnOptions: { ...DEFAULT_API_CONNECT_OPTIONS, ...connOptions?.sttConnOptions }, @@ -238,7 +299,8 @@ export class AgentSession< this.tts = tts; } - this.turnDetection = turnDetection; + this.turnDetection = sessionOptions?.turnHandling?.turnDetection; + this._interruptionDetection = 
sessionOptions?.turnHandling?.interruption?.mode; this._userData = userData; // configurable IO @@ -247,7 +309,7 @@ export class AgentSession< // This is the "global" chat context, it holds the entire conversation history this._chatCtx = ChatContext.empty(); - this.options = { ...defaultVoiceOptions, ...voiceOptions }; + this.options = opts.options; this._aecWarmupRemaining = this.options.aecWarmupDuration ?? 0; this._onUserInputTranscribed = this._onUserInputTranscribed.bind(this); @@ -260,6 +322,9 @@ export class AgentSession< ): boolean { const eventData = args[0] as AgentEvent; this._recordedEvents.push(eventData); + if (event === AgentSessionEventTypes.MetricsCollected) { + this._usageCollector.collect((eventData as MetricsCollectedEvent).metrics); + } return super.emit(event, ...args); } @@ -288,6 +353,18 @@ export class AgentSession< return this._connOptions; } + get interruptionDetection() { + return this._interruptionDetection; + } + + /** + * Returns usage summaries for this session, one per model/provider combination. + */ + get usage(): AgentSessionUsage { + // Skip zero fields for more concise usage display (matches python behavior). + return { modelUsage: this._usageCollector.flatten().map(filterZeroValues) }; + } + get useTtsAlignedTranscript(): boolean { return this.options.useTtsAlignedTranscript; } @@ -342,7 +419,15 @@ export class AgentSession< inputOptions, outputOptions, }); + this._roomIO.start(); + + this.clientEventsHandler = new ClientEventsHandler(this, this._roomIO); + if (inputOptions?.textEnabled !== false) { + this.clientEventsHandler.registerTextInput( + inputOptions?.textInputCallback ?? DEFAULT_TEXT_INPUT_CALLBACK, + ); + } } let ctx: JobContext | undefined = undefined; @@ -385,6 +470,10 @@ export class AgentSession< await Promise.allSettled(tasks); + if (this.clientEventsHandler) { + await this.clientEventsHandler.start(); + } + // Log used IO configuration this.logger.debug( `using audio io: ${this.input.audio ? 
'`' + this.input.audio.constructor.name + '`' : '(none)'} -> \`AgentSession\` -> ${this.output.audio ? '`' + this.output.audio.constructor.name + '`' : '(none)'}`, @@ -416,6 +505,8 @@ export class AgentSession< return; } + this._usageCollector = new ModelUsageCollector(); + let ctx: JobContext | undefined = undefined; try { ctx = getJobContext(); @@ -748,6 +839,10 @@ export class AgentSession< return this._agentState; } + get userState(): UserState { + return this._userState; + } + get currentAgent(): Agent { if (!this.agent) { throw new Error('AgentSession is not running'); @@ -786,7 +881,9 @@ export class AgentSession< } /** @internal */ - _onError(error: RealtimeModelError | STTError | TTSError | LLMError): void { + _onError( + error: RealtimeModelError | STTError | TTSError | LLMError | InterruptionDetectionError, + ): void { if (this.closingTask || error.recoverable) { return; } @@ -802,6 +899,11 @@ export class AgentSession< if (this.ttsErrorCounts <= this._connOptions.maxUnrecoverableErrors) { return; } + } else if (error.type === 'interruption_detection_error') { + this.interruptionDetectionErrorCounts += 1; + if (this.interruptionDetectionErrorCounts <= this._connOptions.maxUnrecoverableErrors) { + return; + } } this.logger.error(error, 'AgentSession is closing due to unrecoverable error'); @@ -831,9 +933,9 @@ export class AgentSession< } if (state === 'speaking') { - // Reset error counts when agent starts speaking this.llmErrorCounts = 0; this.ttsErrorCounts = 0; + this.interruptionDetectionErrorCounts = 0; if (this.agentSpeakingSpan === undefined) { this.agentSpeakingSpan = tracer.startSpan({ @@ -865,7 +967,7 @@ export class AgentSession< this._agentState = state; // Handle user away timer based on state changes - if (state === 'listening' && this.userState === 'listening') { + if (state === 'listening' && this._userState === 'listening') { this._setUserAwayTimer(); } else { this._cancelUserAwayTimer(); @@ -879,12 +981,12 @@ export class AgentSession< 
/** @internal */ _updateUserState(state: UserState, lastSpeakingTime?: number) { - if (this.userState === state) { + if (this._userState === state) { return; } - if (state === 'speaking' && this.userSpeakingSpan === undefined) { - this.userSpeakingSpan = tracer.startSpan({ + if (state === 'speaking' && this._userSpeakingSpan === undefined) { + this._userSpeakingSpan = tracer.startSpan({ name: 'user_speaking', context: this.rootSpanContext, startTime: lastSpeakingTime, @@ -892,15 +994,15 @@ export class AgentSession< const linked = this._roomIO?.linkedParticipant; if (linked) { - setParticipantSpanAttributes(this.userSpeakingSpan, linked); + setParticipantSpanAttributes(this._userSpeakingSpan, linked); } - } else if (this.userSpeakingSpan !== undefined) { - this.userSpeakingSpan.end(lastSpeakingTime); - this.userSpeakingSpan = undefined; + } else if (this._userSpeakingSpan !== undefined) { + this._userSpeakingSpan.end(lastSpeakingTime); + this._userSpeakingSpan = undefined; } - const oldState = this.userState; - this.userState = state; + const oldState = this._userState; + this._userState = state; // Handle user away timer based on state changes if (state === 'listening' && this._agentState === 'listening') { @@ -968,7 +1070,7 @@ export class AgentSession< } private _onUserInputTranscribed(ev: UserInputTranscribedEvent): void { - if (this.userState === 'away' && ev.isFinal) { + if (this._userState === 'away' && ev.isFinal) { this.logger.debug('User returned from away state due to speech input'); this._updateUserState('listening'); } @@ -976,7 +1078,13 @@ export class AgentSession< private async closeImpl( reason: ShutdownReason, - error: RealtimeModelError | LLMError | TTSError | STTError | null = null, + error: + | RealtimeModelError + | LLMError + | TTSError + | STTError + | InterruptionDetectionError + | null = null, drain: boolean = false, ): Promise { if (this.rootSpanContext) { @@ -990,7 +1098,13 @@ export class AgentSession< private async closeImplInner( 
reason: ShutdownReason, - error: RealtimeModelError | LLMError | TTSError | STTError | null = null, + error: + | RealtimeModelError + | LLMError + | TTSError + | STTError + | InterruptionDetectionError + | null = null, drain: boolean = false, ): Promise { if (!this.started) { @@ -1036,6 +1150,9 @@ export class AgentSession< this.output.audio = null; this.output.transcription = null; + await this.clientEventsHandler?.close(); + this.clientEventsHandler = undefined; + await this._roomIO?.close(); this._roomIO = undefined; @@ -1047,9 +1164,9 @@ export class AgentSession< this.sessionSpan = undefined; } - if (this.userSpeakingSpan) { - this.userSpeakingSpan.end(); - this.userSpeakingSpan = undefined; + if (this._userSpeakingSpan) { + this._userSpeakingSpan.end(); + this._userSpeakingSpan = undefined; } if (this.agentSpeakingSpan) { @@ -1061,11 +1178,12 @@ export class AgentSession< this.emit(AgentSessionEventTypes.Close, createCloseEvent(reason, error)); - this.userState = 'listening'; + this._userState = 'listening'; this._agentState = 'initializing'; this.rootSpanContext = undefined; this.llmErrorCounts = 0; this.ttsErrorCounts = 0; + this.interruptionDetectionErrorCounts = 0; this.logger.info({ reason, error }, 'AgentSession closed'); } diff --git a/agents/src/voice/audio_recognition.ts b/agents/src/voice/audio_recognition.ts index a564a842d..e5a2be078 100644 --- a/agents/src/voice/audio_recognition.ts +++ b/agents/src/voice/audio_recognition.ts @@ -12,14 +12,22 @@ import { } from '@opentelemetry/api'; import type { WritableStreamDefaultWriter } from 'node:stream/web'; import { ReadableStream } from 'node:stream/web'; +import { InterruptionDetectionError } from '../inference/interruption/errors.js'; +import type { AdaptiveInterruptionDetector } from '../inference/interruption/interruption_detector.js'; +import { InterruptionStreamSentinel } from '../inference/interruption/interruption_stream.js'; +import { + type InterruptionSentinel, + type OverlappingSpeechEvent, 
+} from '../inference/interruption/types.js'; import { type ChatContext } from '../llm/chat_context.js'; import { log } from '../log.js'; import { DeferredReadableStream, isStreamReaderReleaseError } from '../stream/deferred_stream.js'; import { IdentityTransform } from '../stream/identity_transform.js'; import { mergeReadableStreams } from '../stream/merge_readable_streams.js'; +import { type StreamChannel, createStreamChannel } from '../stream/stream_channel.js'; import { type SpeechEvent, SpeechEventType } from '../stt/stt.js'; import { traceTypes, tracer } from '../telemetry/index.js'; -import { Task, delay } from '../utils.js'; +import { Task, delay, waitForAbort } from '../utils.js'; import { type VAD, type VADEvent, VADEventType } from '../vad.js'; import type { TurnDetectionMode } from './agent_session.js'; import type { STTNode } from './io.js'; @@ -46,6 +54,7 @@ export interface PreemptiveGenerationInfo { } export interface RecognitionHooks { + onInterruption: (ev: OverlappingSpeechEvent) => void; onStartOfSpeech: (ev: VADEvent) => void; onVADInferenceDone: (ev: VADEvent) => void; onEndOfSpeech: (ev: VADEvent) => void; @@ -58,9 +67,13 @@ export interface RecognitionHooks { } export interface _TurnDetector { + /** The model name used by this turn detector. */ + readonly model: string; + /** The provider name for this turn detector. */ + readonly provider: string; unlikelyThreshold: (language?: string) => Promise; supportsLanguage: (language?: string) => Promise; - predictEndOfTurn(chatCtx: ChatContext): Promise; + predictEndOfTurn(chatCtx: ChatContext, timeout?: number): Promise; } export interface AudioRecognitionOptions { @@ -73,7 +86,8 @@ export interface AudioRecognitionOptions { /** Turn detector for end-of-turn prediction. */ turnDetector?: _TurnDetector; /** Turn detection mode. 
*/ - turnDetectionMode?: Exclude; + turnDetectionMode?: TurnDetectionMode; + interruptionDetection?: AdaptiveInterruptionDetector; /** Minimum endpointing delay in milliseconds. */ minEndpointingDelay: number; /** Maximum endpointing delay in milliseconds. */ @@ -98,12 +112,13 @@ export interface ParticipantLike { kind: ParticipantKind; } +// TODO add ability to update stt/vad/interruption-detection export class AudioRecognition { private hooks: RecognitionHooks; private stt?: STTNode; private vad?: VAD; private turnDetector?: _TurnDetector; - private turnDetectionMode?: Exclude; + private turnDetectionMode?: TurnDetectionMode; private minEndpointingDelay: number; private maxEndpointingDelay: number; private lastLanguage?: string; @@ -137,6 +152,16 @@ export class AudioRecognition { private commitUserTurnTask?: Task; private vadTask?: Task; private sttTask?: Task; + private interruptionTask?: Task; + + // interruption detection + private interruptionDetection?: AdaptiveInterruptionDetector; + private _inputStartedAt?: number; + private ignoreUserTranscriptUntil?: number; + private transcriptBuffer: SpeechEvent[]; + private isInterruptionEnabled: boolean; + private isAgentSpeaking: boolean; + private interruptionStreamChannel?: StreamChannel; constructor(opts: AudioRecognitionOptions) { this.hooks = opts.recognitionHooks; @@ -153,9 +178,29 @@ export class AudioRecognition { this.getLinkedParticipant = opts.getLinkedParticipant; this.deferredInputStream = new DeferredReadableStream(); - const [vadInputStream, sttInputStream] = this.deferredInputStream.stream.tee(); - this.vadInputStream = vadInputStream; - this.sttInputStream = mergeReadableStreams(sttInputStream, this.silenceAudioTransform.readable); + this.interruptionDetection = opts.interruptionDetection; + this.transcriptBuffer = []; + this.isInterruptionEnabled = !!(opts.interruptionDetection && opts.vad); + this.isAgentSpeaking = false; + + if (opts.interruptionDetection) { + const [vadInputStream, teedInput] 
= this.deferredInputStream.stream.tee(); + const [inputStream, sttInputStream] = teedInput.tee(); + this.vadInputStream = vadInputStream; + this.sttInputStream = mergeReadableStreams( + sttInputStream, + this.silenceAudioTransform.readable, + ); + this.interruptionStreamChannel = createStreamChannel(); + this.interruptionStreamChannel.addStreamInput(inputStream); + } else { + const [vadInputStream, sttInputStream] = this.deferredInputStream.stream.tee(); + this.vadInputStream = vadInputStream; + this.sttInputStream = mergeReadableStreams( + sttInputStream, + this.silenceAudioTransform.readable, + ); + } this.silenceAudioWriter = this.silenceAudioTransform.writable.getWriter(); } @@ -169,6 +214,16 @@ export class AudioRecognition { return this.audioTranscript; } + /** @internal */ + get inputStartedAt() { + return this._inputStartedAt; + } + + /** @internal */ + updateOptions(options: { turnDetection: TurnDetectionMode | undefined }): void { + this.turnDetectionMode = options.turnDetection; + } + async start() { this.vadTask = Task.from(({ signal }) => this.createVadTask(this.vad, signal)); this.vadTask.result.catch((err) => { @@ -179,6 +234,211 @@ export class AudioRecognition { this.sttTask.result.catch((err) => { this.logger.error(`Error running STT task: ${err}`); }); + + this.interruptionTask = Task.from(({ signal }) => + this.createInterruptionTask(this.interruptionDetection, signal), + ); + this.interruptionTask.result.catch((err) => { + this.logger.error(`Error running interruption task: ${err}`); + }); + } + + async stop() { + await this.sttTask?.cancelAndWait(); + await this.vadTask?.cancelAndWait(); + await this.interruptionTask?.cancelAndWait(); + } + + async onStartOfAgentSpeech() { + this.isAgentSpeaking = true; + return this.trySendInterruptionSentinel(InterruptionStreamSentinel.agentSpeechStarted()); + } + + async onEndOfAgentSpeech(ignoreUserTranscriptUntil: number) { + if (!this.isInterruptionEnabled) { + this.isAgentSpeaking = false; + return; + } 
+ + const inputOpen = await this.trySendInterruptionSentinel( + InterruptionStreamSentinel.agentSpeechEnded(), + ); + if (!inputOpen) { + this.isAgentSpeaking = false; + return; + } + + if (this.isAgentSpeaking) { + if (this.ignoreUserTranscriptUntil === undefined) { + this.onEndOfOverlapSpeech(Date.now()); + } + this.ignoreUserTranscriptUntil = this.ignoreUserTranscriptUntil + ? Math.min(ignoreUserTranscriptUntil, this.ignoreUserTranscriptUntil) + : ignoreUserTranscriptUntil; + + // flush held transcripts if possible + await this.flushHeldTranscripts(); + } + this.isAgentSpeaking = false; + } + + /** Start interruption inference when agent is speaking and overlap speech starts. */ + async onStartOfOverlapSpeech(speechDuration: number, startedAt: number, userSpeakingSpan?: Span) { + if (this.isAgentSpeaking) { + this.trySendInterruptionSentinel( + InterruptionStreamSentinel.overlapSpeechStarted( + speechDuration, + startedAt, + userSpeakingSpan, + ), + ); + } + } + + /** End interruption inference when overlap speech ends. */ + async onEndOfOverlapSpeech(endedAt: number, userSpeakingSpan?: Span) { + if (!this.isInterruptionEnabled) { + return; + } + if (userSpeakingSpan && userSpeakingSpan.isRecording()) { + userSpeakingSpan.setAttribute(traceTypes.ATTR_IS_INTERRUPTION, 'false'); + } + + return this.trySendInterruptionSentinel(InterruptionStreamSentinel.overlapSpeechEnded(endedAt)); + } + + /** + * Flush held transcripts whose *end time* is after the ignoreUserTranscriptUntil timestamp. + * If the event has no timestamps, we assume it is the same as the next valid event. 
+ */ + private async flushHeldTranscripts() { + if ( + !this.isInterruptionEnabled || + this.ignoreUserTranscriptUntil === undefined || + this.transcriptBuffer.length === 0 + ) { + return; + } + + if (!this._inputStartedAt) { + this.transcriptBuffer = []; + this.ignoreUserTranscriptUntil = undefined; + return; + } + + let emitFromIndex: number | null = null; + let shouldFlush = false; + + for (let i = 0; i < this.transcriptBuffer.length; i++) { + const ev = this.transcriptBuffer[i]; + if (!ev || !ev.alternatives || ev.alternatives.length === 0) { + emitFromIndex = Math.min(emitFromIndex ?? i, i); + continue; + } + const firstAlternative = ev.alternatives[0]; + if ( + firstAlternative.startTime === firstAlternative.endTime && + firstAlternative.startTime === 0 + ) { + this.transcriptBuffer = []; + this.ignoreUserTranscriptUntil = undefined; + return; + } + + if (this.#alternativeEndsBeforeIgnoreWindow(firstAlternative)) { + emitFromIndex = null; + } else { + emitFromIndex = Math.min(emitFromIndex ?? i, i); + shouldFlush = true; + break; + } + } + + const eventsToEmit = + emitFromIndex !== null && shouldFlush ? this.transcriptBuffer.slice(emitFromIndex) : []; + + this.transcriptBuffer = []; + this.ignoreUserTranscriptUntil = undefined; + + for (const event of eventsToEmit) { + this.logger.trace( + { + event: event.type, + }, + 're-emitting held user transcript', + ); + this.onSTTEvent(event); + } + } + + #alternativeEndsBeforeIgnoreWindow( + alternative: NonNullable[number], + ): boolean { + if ( + this.ignoreUserTranscriptUntil === undefined || + !this._inputStartedAt || + alternative.startTime <= 0 + ) { + return false; + } + + // `SpeechData.startTime` is in seconds relative to audio start, while `inputStartedAt` and + // `ignoreUserTranscriptUntil` are epoch milliseconds. 
+ return alternative.startTime * 1000 + this._inputStartedAt < this.ignoreUserTranscriptUntil; + } + + private shouldHoldSttEvent(ev: SpeechEvent): boolean { + if (!this.isInterruptionEnabled) { + return false; + } + if (this.isAgentSpeaking) { + return true; + } + + // reset when the user starts speaking after the agent speech + if (ev.type === SpeechEventType.START_OF_SPEECH) { + this.ignoreUserTranscriptUntil = undefined; + this.transcriptBuffer = []; + return false; + } + + if (this.ignoreUserTranscriptUntil === undefined) { + return false; + } + // sentinel events are always held until we have something concrete to release them + if (!ev.alternatives || ev.alternatives.length === 0) { + return true; + } + + const alternative = ev.alternatives[0]; + + if ( + alternative.startTime !== alternative.endTime && + this.#alternativeEndsBeforeIgnoreWindow(alternative) + ) { + return true; + } + return false; + } + + private async trySendInterruptionSentinel( + frame: AudioFrame | InterruptionSentinel, + ): Promise { + if ( + this.isInterruptionEnabled && + this.interruptionStreamChannel && + !this.interruptionStreamChannel.closed + ) { + try { + await this.interruptionStreamChannel.write(frame); + return true; + } catch (e: unknown) { + this.logger.warn( + `could not forward interruption sentinel: ${e instanceof Error ? 
e.message : String(e)}`, + ); + } + } + return false; } private ensureUserTurnSpan(startTime?: number): Span { @@ -234,6 +494,25 @@ export class AudioRecognition { return; } + // handle interruption detection + // - hold the event until the ignore_user_transcript_until expires + // - release only relevant events + // - allow RECOGNITION_USAGE to pass through immediately + + if (ev.type !== SpeechEventType.RECOGNITION_USAGE && this.isInterruptionEnabled) { + if (this.shouldHoldSttEvent(ev)) { + this.logger.trace( + { event: ev.type, ignoreUserTranscriptUntil: this.ignoreUserTranscriptUntil }, + 'holding STT event until ignore_user_transcript_until expires', + ); + this.transcriptBuffer.push(ev); + return; + } else { + await this.flushHeldTranscripts(); + // no return here to allow the new event to be processed normally + } + } + switch (ev.type) { case SpeechEventType.FINAL_TRANSCRIPT: const transcript = ev.alternatives?.[0]?.text; @@ -417,6 +696,12 @@ export class AudioRecognition { } } + private onOverlapSpeechEvent(ev: OverlappingSpeechEvent) { + if (ev.isInterruption) { + this.hooks.onInterruption(ev); + } + } + private runEOUDetection(chatCtx: ChatContext) { this.logger.debug( { @@ -675,7 +960,9 @@ export class AudioRecognition { this.lastSpeakingTime = Date.now(); if (this.speechStartTime === undefined) { - this.speechStartTime = Date.now(); + // Backdate speechStartTime to the actual start of accumulated speech. + // ev.rawAccumulatedSpeech is in ms (VADEvent durations are all ms in TS). 
+ this.speechStartTime = Date.now() - ev.rawAccumulatedSpeech; } } break; @@ -707,6 +994,85 @@ export class AudioRecognition { } } + private async createInterruptionTask( + interruptionDetection: AdaptiveInterruptionDetector | undefined, + signal: AbortSignal, + ) { + if (!interruptionDetection || !this.interruptionStreamChannel) return; + + const stream = interruptionDetection.createStream(); + const inputReader = this.interruptionStreamChannel.stream().getReader(); + + const cleanup = async () => { + try { + signal.removeEventListener('abort', abortHandler); + eventReader.releaseLock(); + await stream.close(); + } catch (e) { + this.logger.debug('createInterruptionTask: error during abort handler:', e); + } + }; + + // Forward input frames/sentinels to the interruption stream + const forwardTask = (async () => { + try { + const abortPromise = waitForAbort(signal); + while (!signal.aborted) { + const res = await Promise.race([inputReader.read(), abortPromise]); + if (!res) break; + const { value, done } = res; + if (done) break; + // Backdate to the actual start of the audio frame, not when it was received. 
+ if (value instanceof AudioFrame) { + const frameDurationMs = (value.samplesPerChannel / value.sampleRate) * 1000; + this._inputStartedAt ??= Date.now() - frameDurationMs; + } else { + this._inputStartedAt ??= Date.now(); + } + await stream.pushFrame(value); + } + } finally { + inputReader.releaseLock(); + } + })(); + + // Read output events from the interruption stream + const eventReader = stream.stream().getReader(); + const abortHandler = async () => { + await cleanup(); + }; + signal.addEventListener('abort', abortHandler); + + try { + const abortPromise = waitForAbort(signal); + + while (!signal.aborted) { + const res = await Promise.race([eventReader.read(), abortPromise]); + if (!res) break; + const { done, value: ev } = res; + if (done) break; + this.onOverlapSpeechEvent(ev); + } + } catch (e) { + if (!signal.aborted) { + const cause = e instanceof Error ? e : new Error(String(e)); + interruptionDetection.emitError( + new InterruptionDetectionError( + cause.message, + Date.now(), + interruptionDetection.label, + false, + ), + ); + this.logger.error(e, 'Error in interruption task'); + } + } finally { + await cleanup(); + await forwardTask; + this.logger.debug('Interruption task closed'); + } + } + setInputAudioStream(audioStream: ReadableStream) { this.deferredInputStream.setSource(audioStream); } @@ -783,6 +1149,8 @@ export class AudioRecognition { await this.sttTask?.cancelAndWait(); await this.vadTask?.cancelAndWait(); await this.bounceEOUTask?.cancelAndWait(); + await this.interruptionTask?.cancelAndWait(); + await this.interruptionStreamChannel?.close(); } private _endUserTurnSpan({ @@ -809,6 +1177,14 @@ export class AudioRecognition { } private get vadBaseTurnDetection() { - return ['vad', undefined].includes(this.turnDetectionMode); + if (typeof this.turnDetectionMode === 'object') { + return false; + } + + if (this.turnDetectionMode === undefined || this.turnDetectionMode === 'vad') { + return true; + } + + return false; } } diff --git 
a/agents/src/voice/client_events.ts b/agents/src/voice/client_events.ts new file mode 100644 index 000000000..510331072 --- /dev/null +++ b/agents/src/voice/client_events.ts @@ -0,0 +1,838 @@ +// SPDX-FileCopyrightText: 2025 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import type { Room, RpcInvocationData, TextStreamInfo, TextStreamReader } from '@livekit/rtc-node'; +import type { TypedEventEmitter } from '@livekit/typed-emitter'; +import EventEmitter from 'events'; +import type { z } from 'zod'; +import { + RPC_GET_AGENT_INFO, + RPC_GET_CHAT_HISTORY, + RPC_GET_SESSION_STATE, + RPC_SEND_MESSAGE, + TOPIC_AGENT_REQUEST, + TOPIC_AGENT_RESPONSE, + TOPIC_CHAT, + TOPIC_CLIENT_EVENTS, +} from '../constants.js'; +import type { OverlappingSpeechEvent } from '../inference/interruption/types.js'; +import type { ToolContext } from '../llm/tool_context.js'; +import { log } from '../log.js'; +import { Future, Task, cancelAndWait, shortuuid } from '../utils.js'; +import type { AgentSession } from './agent_session.js'; +import { + AgentSessionEventTypes, + type AgentStateChangedEvent, + type ConversationItemAddedEvent, + type ErrorEvent, + type FunctionToolsExecutedEvent, + type MetricsCollectedEvent, + type UserInputTranscribedEvent, + type UserStateChangedEvent, +} from './events.js'; +import type { RoomIO } from './room_io/room_io.js'; +import { + agentMetricsToWire, + agentSessionUsageToWire, + chatItemToWire, + chatMessageToWire, + type clientAgentStateChangedSchema, + type clientConversationItemAddedSchema, + type clientErrorSchema, + clientEventSchema, + type clientFunctionToolsExecutedSchema, + type clientMetricsCollectedSchema, + type clientSessionUsageSchema, + type clientUserInputTranscribedSchema, + type clientUserOverlappingSpeechSchema, + type clientUserStateChangedSchema, + functionCallOutputToWire, + functionCallToWire, + getAgentInfoResponseSchema, + getChatHistoryResponseSchema, + getRTCStatsResponseSchema, + getSessionStateResponseSchema, + 
getSessionUsageResponseSchema, + msToS, + sendMessageRequestSchema, + sendMessageResponseSchema, + streamRequestSchema, + streamResponseSchema, +} from './wire_format.js'; + +/** @experimental */ +export type ClientAgentStateChangedEvent = z.infer; + +/** @experimental */ +export type ClientUserStateChangedEvent = z.infer; + +/** @experimental */ +export type ClientConversationItemAddedEvent = z.infer; + +/** @experimental */ +export type ClientUserInputTranscribedEvent = z.infer; + +/** @experimental */ +export type ClientFunctionToolsExecutedEvent = z.infer; + +/** @experimental */ +export type ClientMetricsCollectedEvent = z.infer; + +/** @experimental */ +export type ClientErrorEvent = z.infer; + +/** @experimental */ +export type ClientUserOverlappingSpeechEvent = z.infer; + +/** @experimental */ +export type ClientSessionUsageEvent = z.infer; + +/** @experimental */ +export type ClientEvent = z.infer; + +/** @experimental */ +export type ClientEventType = ClientEvent['type']; + +/** @experimental */ +export type StreamRequest = z.infer; + +/** @experimental */ +export type StreamResponse = z.infer; + +/** @experimental */ +export type GetSessionStateRequest = Record; + +/** @experimental */ +export type GetSessionStateResponse = z.infer; + +/** @experimental */ +export type GetChatHistoryRequest = Record; + +/** @experimental */ +export type GetChatHistoryResponse = z.infer; + +/** @experimental */ +export type GetAgentInfoRequest = Record; + +/** @experimental */ +export type GetAgentInfoResponse = z.infer; + +/** @experimental */ +export type SendMessageRequest = z.infer; + +/** @experimental */ +export type SendMessageResponse = z.infer; + +/** @experimental */ +export type GetRTCStatsRequest = Record; + +/** @experimental */ +export type GetRTCStatsResponse = z.infer; + +/** @experimental */ +export type GetSessionUsageRequest = Record; + +/** @experimental */ +export type GetSessionUsageResponse = z.infer; + +function serializeOptions(opts: { + 
turnHandling?: { + endpointing?: unknown; + interruption?: unknown; + }; + maxToolSteps?: number; + userAwayTimeout?: number | null; + preemptiveGeneration?: boolean; + useTtsAlignedTranscript?: boolean; +}): Record { + return { + endpointing: opts.turnHandling?.endpointing ?? {}, + interruption: opts.turnHandling?.interruption ?? {}, + max_tool_steps: opts.maxToolSteps, + user_away_timeout: opts.userAwayTimeout, + preemptive_generation: opts.preemptiveGeneration, + use_tts_aligned_transcript: opts.useTtsAlignedTranscript, + }; +} + +function toolNames(toolCtx: ToolContext | undefined): string[] { + if (!toolCtx) return []; + return Object.keys(toolCtx); +} + +/** @experimental */ +export type RemoteSessionEventTypes = + | 'agent_state_changed' + | 'user_state_changed' + | 'conversation_item_added' + | 'user_input_transcribed' + | 'function_tools_executed' + | 'metrics_collected' + | 'user_overlapping_speech' + | 'session_usage' + | 'error'; + +/** @experimental */ +export type RemoteSessionCallbacks = { + agent_state_changed: (ev: ClientAgentStateChangedEvent) => void; + user_state_changed: (ev: ClientUserStateChangedEvent) => void; + conversation_item_added: (ev: ClientConversationItemAddedEvent) => void; + user_input_transcribed: (ev: ClientUserInputTranscribedEvent) => void; + function_tools_executed: (ev: ClientFunctionToolsExecutedEvent) => void; + metrics_collected: (ev: ClientMetricsCollectedEvent) => void; + user_overlapping_speech: (ev: ClientUserOverlappingSpeechEvent) => void; + session_usage: (ev: ClientSessionUsageEvent) => void; + error: (ev: ClientErrorEvent) => void; +}; + +export interface TextInputEvent { + text: string; + info: TextStreamInfo; + participantIdentity: string; +} + +export type TextInputCallback = (session: AgentSession, ev: TextInputEvent) => void | Promise; + +/** + * Handles exposing AgentSession state to room participants and allows interaction. 
+ * + * This class provides: + * - Event streaming: Automatically streams AgentSession events to clients via a text stream + * - RPC handlers: Allows clients to request state, chat history, and agent info on demand + * - Text input handling: Receives text messages from clients and generates agent replies + */ + +/** @experimental */ +export class ClientEventsHandler { + private readonly session: AgentSession; + private readonly roomIO: RoomIO; + + private textInputCb?: TextInputCallback; + private textStreamHandlerRegistered = false; + private rpcHandlersRegistered = false; + private requestHandlerRegistered = false; + private eventHandlersRegistered = false; + private started = false; + + private readonly tasks = new Set>(); + private readonly logger = log(); + + constructor(session: AgentSession, roomIO: RoomIO) { + this.session = session; + this.roomIO = roomIO; + } + + private get room(): Room { + return this.roomIO.rtcRoom; + } + + async start(): Promise { + if (this.started) return; + + this.started = true; + this.registerRpcHandlers(); + this.registerRequestHandler(); + this.registerEventHandlers(); + } + + async close(): Promise { + if (!this.started) return; + this.started = false; + + if (this.textStreamHandlerRegistered) { + this.room.unregisterTextStreamHandler(TOPIC_CHAT); + this.textStreamHandlerRegistered = false; + } + + if (this.rpcHandlersRegistered) { + const localParticipant = this.room.localParticipant; + if (localParticipant) { + localParticipant.unregisterRpcMethod(RPC_GET_SESSION_STATE); + localParticipant.unregisterRpcMethod(RPC_GET_CHAT_HISTORY); + localParticipant.unregisterRpcMethod(RPC_GET_AGENT_INFO); + localParticipant.unregisterRpcMethod(RPC_SEND_MESSAGE); + } + this.rpcHandlersRegistered = false; + } + + if (this.requestHandlerRegistered) { + this.room.unregisterTextStreamHandler(TOPIC_AGENT_REQUEST); + this.requestHandlerRegistered = false; + } + + if (this.eventHandlersRegistered) { + 
this.session.off(AgentSessionEventTypes.AgentStateChanged, this.onAgentStateChanged); + this.session.off(AgentSessionEventTypes.UserStateChanged, this.onUserStateChanged); + this.session.off(AgentSessionEventTypes.ConversationItemAdded, this.onConversationItemAdded); + this.session.off(AgentSessionEventTypes.FunctionToolsExecuted, this.onFunctionToolsExecuted); + this.session.off(AgentSessionEventTypes.MetricsCollected, this.onMetricsCollected); + this.session.off(AgentSessionEventTypes.UserInputTranscribed, this.onUserInputTranscribed); + this.session.off(AgentSessionEventTypes.UserOverlappingSpeech, this.onUserOverlapSpeech); + this.session.off(AgentSessionEventTypes.Error, this.onError); + this.eventHandlersRegistered = false; + } + + await cancelAndWait([...this.tasks]); + this.tasks.clear(); + } + + /** + * Registers a callback to handle text input from clients. + * + * This callback will be called when a client sends a text message to the agent. + * The callback should return a promise that resolves when the text input has been processed. + * + * @param textInputCb - The callback to handle text input. 
+ */ + registerTextInput(textInputCb: TextInputCallback): void { + this.textInputCb = textInputCb; + if (this.textStreamHandlerRegistered) return; + this.room.registerTextStreamHandler(TOPIC_CHAT, this.onUserTextInput); + this.textStreamHandlerRegistered = true; + } + + private registerRpcHandlers(): void { + if (this.rpcHandlersRegistered) return; + + const localParticipant = this.room.localParticipant; + if (!localParticipant) return; + + localParticipant.registerRpcMethod(RPC_GET_SESSION_STATE, this.rpcGetSessionState); + localParticipant.registerRpcMethod(RPC_GET_CHAT_HISTORY, this.rpcGetChatHistory); + localParticipant.registerRpcMethod(RPC_GET_AGENT_INFO, this.rpcGetAgentInfo); + localParticipant.registerRpcMethod(RPC_SEND_MESSAGE, this.rpcSendMessage); + this.rpcHandlersRegistered = true; + } + + private registerRequestHandler(): void { + if (this.requestHandlerRegistered) return; + + this.room.registerTextStreamHandler(TOPIC_AGENT_REQUEST, this.onStreamRequest); + this.requestHandlerRegistered = true; + } + + private registerEventHandlers(): void { + if (this.eventHandlersRegistered) return; + + this.session.on(AgentSessionEventTypes.AgentStateChanged, this.onAgentStateChanged); + this.session.on(AgentSessionEventTypes.UserStateChanged, this.onUserStateChanged); + this.session.on(AgentSessionEventTypes.ConversationItemAdded, this.onConversationItemAdded); + this.session.on(AgentSessionEventTypes.FunctionToolsExecuted, this.onFunctionToolsExecuted); + this.session.on(AgentSessionEventTypes.MetricsCollected, this.onMetricsCollected); + this.session.on(AgentSessionEventTypes.UserInputTranscribed, this.onUserInputTranscribed); + this.session.on(AgentSessionEventTypes.UserOverlappingSpeech, this.onUserOverlapSpeech); + this.session.on(AgentSessionEventTypes.Error, this.onError); + this.eventHandlersRegistered = true; + } + + private onStreamRequest = ( + reader: TextStreamReader, + participantInfo: { identity: string }, + ): void => { + const task = 
Task.from(async () => this.handleStreamRequest(reader, participantInfo.identity)); + this.trackTask(task); + }; + + private async handleStreamRequest( + reader: TextStreamReader, + participantIdentity: string, + ): Promise { + try { + const data = await reader.readAll(); + const request = streamRequestSchema.parse(JSON.parse(data)); + + let responsePayload = ''; + let error: string | null = null; + + try { + switch (request.method) { + case 'get_session_state': + responsePayload = await this.streamGetSessionState(); + break; + case 'get_chat_history': + responsePayload = await this.streamGetChatHistory(); + break; + case 'get_agent_info': + responsePayload = await this.streamGetAgentInfo(); + break; + case 'send_message': + responsePayload = await this.streamSendMessage(request.payload); + break; + case 'get_rtc_stats': + responsePayload = await this.streamGetRtcStats(); + break; + case 'get_session_usage': + responsePayload = await this.streamGetSessionUsage(); + break; + default: + error = `Unknown method: ${request.method}`; + } + } catch (e) { + error = e instanceof Error ? 
e.message : String(e); + } + + const response: StreamResponse = { + request_id: request.request_id, + payload: responsePayload, + error, + }; + + const localParticipant = this.room.localParticipant; + await localParticipant!.sendText(JSON.stringify(response), { + topic: TOPIC_AGENT_RESPONSE, + destinationIdentities: [participantIdentity], + }); + } catch (e) { + this.logger.warn({ error: e }, 'failed to handle stream request'); + } + } + + private async streamGetSessionState(): Promise { + const agent = this.session.currentAgent; + + const response: GetSessionStateResponse = { + agent_state: this.session.agentState, + user_state: this.session.userState, + agent_id: agent.id, + options: serializeOptions({ + turnHandling: this.session.options.turnHandling, + maxToolSteps: this.session.options.maxToolSteps, + userAwayTimeout: this.session.options.userAwayTimeout, + preemptiveGeneration: this.session.options.preemptiveGeneration, + useTtsAlignedTranscript: this.session.options.useTtsAlignedTranscript, + }), + created_at: msToS(this.session._startedAt ?? 
Date.now()), + }; + return JSON.stringify(response); + } + + private async streamGetChatHistory(): Promise { + return JSON.stringify({ + items: this.session.history.items.map(chatItemToWire), + }); + } + + private async streamGetAgentInfo(): Promise { + const agent = this.session.currentAgent; + return JSON.stringify({ + id: agent.id, + instructions: agent.instructions, + tools: toolNames(agent.toolCtx), + chat_ctx: agent.chatCtx.items.map(chatItemToWire), + }); + } + + private async streamSendMessage(payload: string): Promise { + const request = sendMessageRequestSchema.parse(JSON.parse(payload)); + const runResult = this.session.run({ userInput: request.text }); + await runResult.wait(); + return JSON.stringify({ + items: runResult.events.map((ev) => chatItemToWire(ev.item)), + }); + } + + private async streamGetRtcStats(): Promise { + // TODO(parity): map rtc stats fields once getRtcStats API shape is finalized in rtc-node. + return JSON.stringify({ + publisher_stats: [], + subscriber_stats: [], + }); + } + + private async streamGetSessionUsage(): Promise { + return JSON.stringify({ + usage: agentSessionUsageToWire(this.session.usage), + created_at: msToS(Date.now()), + }); + } + + private onUserOverlapSpeech = (event: OverlappingSpeechEvent): void => { + const clientEvent: ClientUserOverlappingSpeechEvent = { + type: 'user_overlapping_speech', + is_interruption: event.isInterruption, + created_at: msToS(event.timestamp), + overlap_started_at: event.overlapStartedAt != null ? 
msToS(event.overlapStartedAt) : null, + detection_delay: event.detectionDelayInS, + sent_at: msToS(Date.now()), + }; + this.streamClientEvent(clientEvent); + }; + + private onAgentStateChanged = (event: AgentStateChangedEvent): void => { + const clientEvent: ClientAgentStateChangedEvent = { + type: 'agent_state_changed', + old_state: event.oldState, + new_state: event.newState, + created_at: msToS(event.createdAt), + }; + this.streamClientEvent(clientEvent); + }; + + private onUserStateChanged = (event: UserStateChangedEvent): void => { + const clientEvent: ClientUserStateChangedEvent = { + type: 'user_state_changed', + old_state: event.oldState, + new_state: event.newState, + created_at: msToS(event.createdAt), + }; + this.streamClientEvent(clientEvent); + }; + + private onConversationItemAdded = (event: ConversationItemAddedEvent): void => { + if (event.item.type !== 'message') { + return; + } + this.streamClientEvent({ + type: 'conversation_item_added', + item: chatMessageToWire(event.item) as ClientConversationItemAddedEvent['item'], + created_at: msToS(event.createdAt), + }); + }; + + private onUserInputTranscribed = (event: UserInputTranscribedEvent): void => { + this.streamClientEvent({ + type: 'user_input_transcribed', + transcript: event.transcript, + is_final: event.isFinal, + language: event.language, + created_at: msToS(event.createdAt), + }); + }; + + private onFunctionToolsExecuted = (event: FunctionToolsExecutedEvent): void => { + this.streamClientEvent({ + type: 'function_tools_executed', + function_calls: event.functionCalls.map( + functionCallToWire, + ) as ClientFunctionToolsExecutedEvent['function_calls'], + function_call_outputs: event.functionCallOutputs.map((o) => + o + ? 
(functionCallOutputToWire(o) as NonNullable< + ClientFunctionToolsExecutedEvent['function_call_outputs'][number] + >) + : null, + ), + created_at: msToS(event.createdAt), + }); + }; + + private onMetricsCollected = (event: MetricsCollectedEvent): void => { + this.streamClientEvent({ + type: 'metrics_collected', + metrics: agentMetricsToWire(event.metrics) as ClientMetricsCollectedEvent['metrics'], + created_at: msToS(event.createdAt), + }); + + this.streamClientEvent({ + type: 'session_usage', + usage: agentSessionUsageToWire(this.session.usage) as ClientSessionUsageEvent['usage'], + created_at: msToS(Date.now()), + }); + }; + + private onError = (event: ErrorEvent): void => { + const clientEvent: ClientErrorEvent = { + type: 'error', + message: event.error ? String(event.error) : 'Unknown error', + created_at: msToS(event.createdAt), + }; + this.streamClientEvent(clientEvent); + }; + + private getTargetIdentities(): string[] | null { + const linked = this.roomIO.linkedParticipant; + + // TODO(permissions): check linked.permissions.can_subscribe_metrics + return linked ? 
[linked.identity] : null; + } + + private streamClientEvent(event: ClientEvent): void { + const task = Task.from(async () => this.sendClientEvent(event)); + this.trackTask(task); + } + + private async sendClientEvent(event: ClientEvent): Promise { + if (!this.room.isConnected) return; + + const destinationIdentities = this.getTargetIdentities(); + if (!destinationIdentities) return; + + try { + const localParticipant = this.room.localParticipant; + if (!localParticipant) return; + + const writer = await localParticipant.streamText({ + topic: TOPIC_CLIENT_EVENTS, + destinationIdentities, + }); + await writer.write(JSON.stringify(event)); + await writer.close(); + } catch (e) { + this.logger.warn({ error: e }, 'failed to stream event to clients'); + } + } + + private rpcGetSessionState = async (): Promise => { + return this.streamGetSessionState(); + }; + + private rpcGetChatHistory = async (): Promise => { + return this.streamGetChatHistory(); + }; + + private rpcGetAgentInfo = async (): Promise => { + return this.streamGetAgentInfo(); + }; + + private rpcSendMessage = async (data: RpcInvocationData): Promise => { + return this.streamSendMessage(data.payload); + }; + + private onUserTextInput = ( + reader: TextStreamReader, + participantInfo: { identity: string }, + ): void => { + const linkedParticipant = this.roomIO.linkedParticipant; + if (linkedParticipant && participantInfo.identity !== linkedParticipant.identity) { + return; + } + + const participant = this.room.remoteParticipants.get(participantInfo.identity); + if (!participant) { + this.logger.warn('participant not found, ignoring text input'); + return; + } + + if (!this.textInputCb) { + this.logger.error('text input callback is not set, ignoring text input'); + return; + } + + const task = Task.from(async () => { + const text = await reader.readAll(); + const result = this.textInputCb!(this.session, { + text, + info: reader.info, + participantIdentity: participantInfo.identity, + }); + + if (result 
instanceof Promise) { + await result; + } + }); + + this.trackTask(task); + }; + + private trackTask(task: Task): void { + this.tasks.add(task); + task.addDoneCallback(() => { + this.tasks.delete(task); + }); + } +} + +/** + * Client-side interface to interact with a remote AgentSession. + * + * This class allows frontends/clients to: + * - Subscribe to real-time events from the agent session + * - Query session state, chat history, and agent info via RPC + * - Send messages to the agent + * + * Example: + * ```typescript + * const session = new RemoteSession(room, agentIdentity); + * session.on('agent_state_changed', (event) => { + * console.log('Agent state changed:', event.new_state); + * }); + * session.on('user_state_changed', (event) => { + * console.log('User state changed:', event.new_state); + * }); + * session.on('conversation_item_added', (event) => { + * console.log('Conversation item added:', event.item); + * }); + * await session.start(); + * + * const state = await session.fetchSessionState(); + * console.log('Session state:', state); + * + * const response = await session.sendMessage('Hello!'); + * console.log('Response:', response); + * ``` + */ +// TODO: expose this class +/** @experimental */ +export class RemoteSession extends (EventEmitter as new () => TypedEventEmitter) { + private readonly room: Room; + private readonly agentIdentity: string; + private started = false; + + private readonly tasks = new Set>(); + private readonly pendingRequests = new Map>(); + private readonly logger = log(); + + constructor(room: Room, agentIdentity: string) { + super(); + this.room = room; + this.agentIdentity = agentIdentity; + } + + async start(): Promise { + if (this.started) return; + this.started = true; + this.room.registerTextStreamHandler(TOPIC_CLIENT_EVENTS, this.onEventStream); + this.room.registerTextStreamHandler(TOPIC_AGENT_RESPONSE, this.onResponseStream); + } + + async close(): Promise { + if (!this.started) return; + + this.started = false; + 
this.room.unregisterTextStreamHandler(TOPIC_CLIENT_EVENTS); + this.room.unregisterTextStreamHandler(TOPIC_AGENT_RESPONSE); + + for (const pending of this.pendingRequests.values()) { + pending.reject(new Error('RemoteSession closed')); + } + + this.pendingRequests.clear(); + + await cancelAndWait([...this.tasks]); + this.tasks.clear(); + } + + private onEventStream = ( + reader: TextStreamReader, + participantInfo: { identity: string }, + ): void => { + if (participantInfo.identity !== this.agentIdentity) return; + this.trackTask(Task.from(async () => this.readEvent(reader))); + }; + + private onResponseStream = ( + reader: TextStreamReader, + participantInfo: { identity: string }, + ): void => { + if (participantInfo.identity !== this.agentIdentity) return; + this.trackTask(Task.from(async () => this.readResponse(reader))); + }; + + private async readResponse(reader: TextStreamReader): Promise { + try { + const data = await reader.readAll(); + const response = streamResponseSchema.parse(JSON.parse(data)); + const future = this.pendingRequests.get(response.request_id); + this.pendingRequests.delete(response.request_id); + + if (!future || future.done) return; + future.resolve(response); + } catch (e) { + this.logger.warn({ error: e }, 'failed to read stream response'); + } + } + + private async readEvent(reader: TextStreamReader): Promise { + try { + const data = await reader.readAll(); + const event = this.parseEvent(data); + if (event) { + this.emit(event.type, event as never); + } + } catch (e) { + this.logger.warn({ error: e }, 'failed to parse client event'); + } + } + + private parseEvent(data: string): ClientEvent | null { + try { + const result = clientEventSchema.safeParse(JSON.parse(data)); + if (!result.success) { + this.logger.warn({ error: result.error }, 'failed to validate event'); + return null; + } + return result.data; + } catch (e) { + this.logger.warn({ error: e }, 'failed to parse event'); + return null; + } + } + + private async 
sendRequest(method: string, payload: string, timeout = 60000): Promise { + const requestId = shortuuid('req_'); + const request: StreamRequest = { + request_id: requestId, + method, + payload, + }; + + const future = new Future(); + this.pendingRequests.set(requestId, future); + + const localParticipant = this.room.localParticipant; + if (!localParticipant) { + this.pendingRequests.delete(requestId); + throw new Error('RemoteSession room has no local participant'); + } + + await localParticipant.sendText(JSON.stringify(request), { + topic: TOPIC_AGENT_REQUEST, + destinationIdentities: [this.agentIdentity], + }); + + const timer = setTimeout(() => { + if (!future.done) { + this.pendingRequests.delete(requestId); + future.reject(new Error(`RemoteSession request timed out: ${method}`)); + } + }, timeout); + + try { + const response = await future.await; + if (response.error) { + throw new Error(response.error); + } + return response.payload; + } finally { + clearTimeout(timer); + } + } + + async fetchSessionState(): Promise { + const raw = JSON.parse(await this.sendRequest('get_session_state', '{}')); + return getSessionStateResponseSchema.parse(raw); + } + + async fetchChatHistory(): Promise { + const raw = JSON.parse(await this.sendRequest('get_chat_history', '{}')); + return getChatHistoryResponseSchema.parse(raw); + } + + async fetchAgentInfo(): Promise { + const raw = JSON.parse(await this.sendRequest('get_agent_info', '{}')); + return getAgentInfoResponseSchema.parse(raw); + } + + async sendMessage(text: string, responseTimeout = 60000): Promise { + const payload = JSON.stringify({ text } satisfies SendMessageRequest); + const raw = JSON.parse(await this.sendRequest('send_message', payload, responseTimeout)); + return sendMessageResponseSchema.parse(raw); + } + + async fetchRtcStats(): Promise { + const raw = JSON.parse(await this.sendRequest('get_rtc_stats', '{}')); + return getRTCStatsResponseSchema.parse(raw); + } + + async fetchSessionUsage(): Promise { + 
const raw = JSON.parse(await this.sendRequest('get_session_usage', '{}')); + return getSessionUsageResponseSchema.parse(raw); + } + + private trackTask(task: Task): void { + this.tasks.add(task); + task.addDoneCallback(() => { + this.tasks.delete(task); + }); + } +} diff --git a/agents/src/voice/events.ts b/agents/src/voice/events.ts index 7d8ff325f..f063af1ce 100644 --- a/agents/src/voice/events.ts +++ b/agents/src/voice/events.ts @@ -1,6 +1,8 @@ // SPDX-FileCopyrightText: 2024 LiveKit, Inc. // // SPDX-License-Identifier: Apache-2.0 +import type { InterruptionDetectionError } from '../inference/interruption/errors.js'; +import type { OverlappingSpeechEvent } from '../inference/interruption/types.js'; import type { ChatMessage, FunctionCall, @@ -25,6 +27,7 @@ export enum AgentSessionEventTypes { FunctionToolsExecuted = 'function_tools_executed', MetricsCollected = 'metrics_collected', SpeechCreated = 'speech_created', + UserOverlappingSpeech = 'user_overlapping_speech', Error = 'error', Close = 'close', } @@ -215,13 +218,13 @@ export const createSpeechCreatedEvent = ({ export type ErrorEvent = { type: 'error'; - error: RealtimeModelError | STTError | TTSError | LLMError | unknown; + error: RealtimeModelError | STTError | TTSError | LLMError | InterruptionDetectionError | unknown; source: LLM | STT | TTS | RealtimeModel | unknown; createdAt: number; }; export const createErrorEvent = ( - error: RealtimeModelError | STTError | TTSError | LLMError | unknown, + error: RealtimeModelError | STTError | TTSError | LLMError | InterruptionDetectionError | unknown, source: LLM | STT | TTS | RealtimeModel | unknown, createdAt: number = Date.now(), ): ErrorEvent => ({ @@ -233,14 +236,20 @@ export const createErrorEvent = ( export type CloseEvent = { type: 'close'; - error: RealtimeModelError | STTError | TTSError | LLMError | null; + error: RealtimeModelError | STTError | TTSError | LLMError | InterruptionDetectionError | null; reason: ShutdownReason; createdAt: number; }; 
export const createCloseEvent = ( reason: ShutdownReason, - error: RealtimeModelError | STTError | TTSError | LLMError | null = null, + error: + | RealtimeModelError + | STTError + | TTSError + | LLMError + | InterruptionDetectionError + | null = null, createdAt: number = Date.now(), ): CloseEvent => ({ type: 'close', @@ -257,5 +266,6 @@ export type AgentEvent = | ConversationItemAddedEvent | FunctionToolsExecutedEvent | SpeechCreatedEvent + | OverlappingSpeechEvent | ErrorEvent | CloseEvent; diff --git a/agents/src/voice/generation.ts b/agents/src/voice/generation.ts index 1f141ab37..d2eba8fc0 100644 --- a/agents/src/voice/generation.ts +++ b/agents/src/voice/generation.ts @@ -51,6 +51,7 @@ export class _LLMGenerationData { generatedText: string = ''; generatedToolCalls: FunctionCall[]; id: string; + ttft?: number; constructor( public readonly textStream: ReadableStream, @@ -416,6 +417,8 @@ export function performLLMInference( toolCtx: ToolContext, modelSettings: ModelSettings, controller: AbortController, + model?: string, + provider?: string, ): [Task, _LLMGenerationData] { const textStream = new IdentityTransform(); const toolCallStream = new IdentityTransform(); @@ -431,8 +434,17 @@ export function performLLMInference( ); span.setAttribute(traceTypes.ATTR_FUNCTION_TOOLS, JSON.stringify(Object.keys(toolCtx))); + if (model) { + span.setAttribute(traceTypes.ATTR_GEN_AI_REQUEST_MODEL, model); + } + if (provider) { + span.setAttribute(traceTypes.ATTR_GEN_AI_PROVIDER_NAME, provider); + } + let llmStreamReader: ReadableStreamDefaultReader | null = null; let llmStream: ReadableStream | null = null; + const startTime = performance.now() / 1000; // Convert to seconds + let firstTokenReceived = false; try { llmStream = await node(chatCtx, toolCtx, modelSettings); @@ -455,6 +467,11 @@ export function performLLMInference( const { done, value: chunk } = result; if (done) break; + if (!firstTokenReceived) { + firstTokenReceived = true; + data.ttft = performance.now() / 1000 
- startTime; + } + if (typeof chunk === 'string') { data.generatedText += chunk; await textWriter.write(chunk); @@ -493,6 +510,9 @@ export function performLLMInference( } span.setAttribute(traceTypes.ATTR_RESPONSE_TEXT, data.generatedText); + if (data.ttft !== undefined) { + span.setAttribute(traceTypes.ATTR_RESPONSE_TTFT, data.ttft); + } } catch (error) { if (error instanceof DOMException && error.name === 'AbortError') { // Abort signal was triggered, handle gracefully @@ -527,6 +547,8 @@ export function performTTSInference( text: ReadableStream, modelSettings: ModelSettings, controller: AbortController, + model?: string, + provider?: string, ): [Task, _TTSGenerationData] { const audioStream = new IdentityTransform(); const outputWriter = audioStream.writable.getWriter(); @@ -558,10 +580,27 @@ export function performTTSInference( } })(); - const _performTTSInferenceImpl = async (signal: AbortSignal) => { + let ttfb: number | undefined; + + const genData: _TTSGenerationData = { + audioStream: audioOutputStream, + timedTextsFut, + ttfb: undefined, + }; + + const _performTTSInferenceImpl = async (signal: AbortSignal, span: Span) => { + if (model) { + span.setAttribute(traceTypes.ATTR_GEN_AI_REQUEST_MODEL, model); + } + if (provider) { + span.setAttribute(traceTypes.ATTR_GEN_AI_PROVIDER_NAME, provider); + } + let ttsStreamReader: ReadableStreamDefaultReader | null = null; let ttsStream: ReadableStream | null = null; let pushedDuration = 0; + const startTime = performance.now() / 1000; // Convert to seconds + let firstByteReceived = false; try { ttsStream = await node(textOnlyStream.readable, modelSettings); @@ -595,6 +634,13 @@ export function performTTSInference( break; } + if (!firstByteReceived) { + firstByteReceived = true; + ttfb = performance.now() / 1000 - startTime; + genData.ttfb = ttfb; + span.setAttribute(traceTypes.ATTR_RESPONSE_TTFB, ttfb); + } + // Write the audio frame to the output stream await outputWriter.write(frame); @@ -631,6 +677,10 @@ export 
function performTTSInference( } throw error; } finally { + if (!timedTextsFut.done) { + // Ensure downstream consumers don't hang on errors. + timedTextsFut.resolve(null); + } ttsStreamReader?.releaseLock(); await ttsStream?.cancel(); await outputWriter.close(); @@ -642,16 +692,11 @@ export function performTTSInference( const currentContext = otelContext.active(); const inferenceTask = async (signal: AbortSignal) => - tracer.startActiveSpan(async () => _performTTSInferenceImpl(signal), { + tracer.startActiveSpan(async (span) => _performTTSInferenceImpl(signal, span), { name: 'tts_node', context: currentContext, }); - const genData: _TTSGenerationData = { - audioStream: audioOutputStream, - timedTextsFut, - }; - return [ Task.from((controller) => inferenceTask(controller.signal), controller, 'performTTSInference'), genData, @@ -719,7 +764,6 @@ export function performTextForwarding( export interface _AudioOut { audio: Array; - /** Future that will be set with the timestamp of the first frame's capture */ firstFrameFut: Future; } @@ -807,7 +851,6 @@ export function performAudioForwarding( ]; } -// function_tool span is already implemented in tracableToolExecution below (line ~796) export function performToolExecutions({ session, speechHandle, diff --git a/agents/src/voice/index.ts b/agents/src/voice/index.ts index 947013336..ea6573ffe 100644 --- a/agents/src/voice/index.ts +++ b/agents/src/voice/index.ts @@ -5,6 +5,7 @@ export { Agent, AgentTask, StopResponse, type AgentOptions, type ModelSettings } export { AgentSession, type AgentSessionOptions, type VoiceOptions } from './agent_session.js'; export * from './avatar/index.js'; export * from './background_audio.js'; +export { type TextInputCallback, type TextInputEvent } from './client_events.js'; export * from './events.js'; export { type TimedString } from './io.js'; export * from './report.js'; diff --git a/agents/src/voice/report.test.ts b/agents/src/voice/report.test.ts new file mode 100644 index 
000000000..b7b95a451 --- /dev/null +++ b/agents/src/voice/report.test.ts @@ -0,0 +1,117 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { describe, expect, it } from 'vitest'; +import { ChatContext } from '../llm/chat_context.js'; +import type { VoiceOptions } from './agent_session.js'; +import { createSessionReport, sessionReportToJSON } from './report.js'; + +function baseOptions(): VoiceOptions { + return { + maxToolSteps: 3, + preemptiveGeneration: false, + userAwayTimeout: 15, + useTtsAlignedTranscript: true, + turnHandling: {}, + }; +} + +function serializeOptions(options: VoiceOptions) { + const report = createSessionReport({ + jobId: 'job', + roomId: 'room-id', + room: 'room', + options, + events: [], + chatHistory: ChatContext.empty(), + enableRecording: false, + timestamp: 0, + startedAt: 0, + }); + + const payload = sessionReportToJSON(report); + return payload.options as Record; +} + +describe('sessionReportToJSON', () => { + it('serializes interruption and endpointing values from turnHandling', () => { + const options = baseOptions(); + options.turnHandling = { + interruption: { + mode: 'adaptive', + discardAudioIfUninterruptible: false, + minDuration: 1200, + minWords: 2, + }, + endpointing: { + minDelay: 900, + maxDelay: 4500, + }, + }; + + const serialized = serializeOptions(options); + expect(serialized).toMatchObject({ + allow_interruptions: true, + discard_audio_if_uninterruptible: false, + min_interruption_duration: 1200, + min_interruption_words: 2, + min_endpointing_delay: 900, + max_endpointing_delay: 4500, + max_tool_steps: 3, + }); + }); + + it('prefers turnHandling values over deprecated flat fields', () => { + const options = baseOptions(); + options.allowInterruptions = false; + options.discardAudioIfUninterruptible = true; + options.minInterruptionDuration = 400; + options.minInterruptionWords = 1; + options.minEndpointingDelay = 500; + options.maxEndpointingDelay = 2500; + 
options.turnHandling = { + interruption: { + mode: 'vad', + discardAudioIfUninterruptible: false, + minDuration: 1400, + minWords: 4, + }, + endpointing: { + minDelay: 700, + maxDelay: 3900, + }, + }; + + const serialized = serializeOptions(options); + expect(serialized).toMatchObject({ + allow_interruptions: true, + discard_audio_if_uninterruptible: false, + min_interruption_duration: 1400, + min_interruption_words: 4, + min_endpointing_delay: 700, + max_endpointing_delay: 3900, + max_tool_steps: 3, + }); + }); + + it('falls back to deprecated flat fields when turnHandling values are absent', () => { + const options = baseOptions(); + options.allowInterruptions = false; + options.discardAudioIfUninterruptible = false; + options.minInterruptionDuration = 600; + options.minInterruptionWords = 3; + options.minEndpointingDelay = 1000; + options.maxEndpointingDelay = 5000; + + const serialized = serializeOptions(options); + expect(serialized).toMatchObject({ + allow_interruptions: false, + discard_audio_if_uninterruptible: false, + min_interruption_duration: 600, + min_interruption_words: 3, + min_endpointing_delay: 1000, + max_endpointing_delay: 5000, + max_tool_steps: 3, + }); + }); +}); diff --git a/agents/src/voice/report.ts b/agents/src/voice/report.ts index 49701a696..a7498e45f 100644 --- a/agents/src/voice/report.ts +++ b/agents/src/voice/report.ts @@ -2,6 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 import type { ChatContext } from '../llm/chat_context.js'; +import { type ModelUsage, filterZeroValues } from '../metrics/model_usage.js'; import type { VoiceOptions } from './agent_session.js'; import type { AgentEvent } from './events.js'; @@ -23,6 +24,8 @@ export interface SessionReport { audioRecordingStartedAt?: number; /** Duration of the session in milliseconds */ duration?: number; + /** Usage summaries for the session, one per model/provider combination */ + modelUsage?: ModelUsage[]; } export interface SessionReportOptions { @@ -41,6 +44,8 @@ export 
interface SessionReportOptions { audioRecordingPath?: string; /** Timestamp when the audio recording started (milliseconds) */ audioRecordingStartedAt?: number; + /** Usage summaries for the session, one per model/provider combination */ + modelUsage?: ModelUsage[]; } export function createSessionReport(opts: SessionReportOptions): SessionReport { @@ -61,6 +66,7 @@ export function createSessionReport(opts: SessionReportOptions): SessionReport { audioRecordingStartedAt, duration: audioRecordingStartedAt !== undefined ? timestamp - audioRecordingStartedAt : undefined, + modelUsage: opts.modelUsage, }; } @@ -70,6 +76,22 @@ export function createSessionReport(opts: SessionReportOptions): SessionReport { // - Uploads to LiveKit Cloud observability endpoint with JWT auth export function sessionReportToJSON(report: SessionReport): Record { const events: Record[] = []; + const interruptionConfig = report.options.turnHandling?.interruption; + const endpointingConfig = report.options.turnHandling?.endpointing; + + // Keep backwards compatibility with deprecated fields + const allowInterruptions = + interruptionConfig?.mode !== undefined + ? interruptionConfig.mode !== false + : report.options.allowInterruptions; + const discardAudioIfUninterruptible = + interruptionConfig?.discardAudioIfUninterruptible ?? + report.options.discardAudioIfUninterruptible; + const minInterruptionDuration = + interruptionConfig?.minDuration ?? report.options.minInterruptionDuration; + const minInterruptionWords = interruptionConfig?.minWords ?? report.options.minInterruptionWords; + const minEndpointingDelay = endpointingConfig?.minDelay ?? report.options.minEndpointingDelay; + const maxEndpointingDelay = endpointingConfig?.maxDelay ?? 
report.options.maxEndpointingDelay; for (const event of report.events) { if (event.type === 'metrics_collected') { @@ -85,16 +107,17 @@ export function sessionReportToJSON(report: SessionReport): Record void | Promise; - -const DEFAULT_TEXT_INPUT_CALLBACK: TextInputCallback = (sess: AgentSession, ev: TextInputEvent) => { +export const DEFAULT_TEXT_INPUT_CALLBACK: TextInputCallback = (sess, ev) => { sess.interrupt(); sess.generateReply({ userInput: ev.text }); }; @@ -146,8 +137,6 @@ export class RoomIO { private forwardUserTranscriptTask?: Task; private initTask?: Task; - private textStreamHandlerRegistered = false; - private logger = log(); constructor({ @@ -282,37 +271,6 @@ export class RoomIO { } }; - private onUserTextInput = (reader: TextStreamReader, participantInfo: { identity: string }) => { - if (participantInfo.identity !== this.participantIdentity) { - return; - } - - const participant = this.room.remoteParticipants.get(participantInfo.identity); - if (!participant) { - this.logger.warn('participant not found, ignoring text input'); - return; - } - - const readText = async () => { - const text = await reader.readAll(); - - const textInputResult = this.inputOptions.textInputCallback!(this.agentSession, { - text, - info: reader.info, - participant, - }); - - // check if callback is a Promise - if (textInputResult instanceof Promise) { - await textInputResult; - } - }; - - readText().catch((error) => { - this.logger.error({ error }, 'Error reading text input'); - }); - }; - private async forwardUserTranscript(signal: AbortSignal): Promise { const reader = this.userTranscriptStream.readable.getReader(); try { @@ -387,6 +345,10 @@ export class RoomIO { return this.participantAvailableFuture.done; } + get rtcRoom(): Room { + return this.room; + } + get linkedParticipant(): RemoteParticipant | undefined { if (!this.isParticipantAvailable) { return undefined; @@ -439,17 +401,6 @@ export class RoomIO { } start() { - if (this.inputOptions.textEnabled) { - try { - 
this.room.registerTextStreamHandler(TOPIC_CHAT, this.onUserTextInput); - this.textStreamHandlerRegistered = true; - } catch (error) { - if (this.inputOptions.textEnabled) { - this.logger.warn(`text stream handler for topic "${TOPIC_CHAT}" already set, ignoring`); - } - } - } - // -- create inputs -- if (this.inputOptions.audioEnabled) { this.audioInput = new ParticipantAudioInputStream({ @@ -525,11 +476,6 @@ export class RoomIO { this.agentSession.off(AgentSessionEventTypes.UserInputTranscribed, this.onUserInputTranscribed); this.agentSession.off(AgentSessionEventTypes.AgentStateChanged, this.onAgentStateChanged); - if (this.textStreamHandlerRegistered) { - this.room.unregisterTextStreamHandler(TOPIC_CHAT); - this.textStreamHandlerRegistered = false; - } - await this.initTask?.cancelAndWait(); // Close stream FIRST so reader.read() in forwardUserTranscript can exit. diff --git a/agents/src/voice/turn_config/endpointing.ts b/agents/src/voice/turn_config/endpointing.ts new file mode 100644 index 000000000..f2603e00f --- /dev/null +++ b/agents/src/voice/turn_config/endpointing.ts @@ -0,0 +1,33 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +/** + * Configuration for endpointing, which determines when the user's turn is complete. + */ +export interface EndpointingOptions { + /** + * Endpointing mode. `"fixed"` uses a fixed delay, `"dynamic"` adjusts delay based on + * end-of-utterance prediction. + * @defaultValue "fixed" + */ + mode: 'fixed' | 'dynamic'; + /** + * Minimum time in milliseconds since the last detected speech before the agent declares the user's + * turn complete. In VAD mode this effectively behaves like `max(VAD silence, minDelay)`; + * in STT mode it is applied after the STT end-of-speech signal, so it can be additive with + * the STT provider's endpointing delay. + * @defaultValue 500 + */ + minDelay: number; + /** + * Maximum time in milliseconds the agent will wait before terminating the turn. 
+ * @defaultValue 3000 + */ + maxDelay: number; +} + +export const defaultEndpointingOptions = { + mode: 'fixed', + minDelay: 500, + maxDelay: 3000, +} as const satisfies EndpointingOptions; diff --git a/agents/src/voice/turn_config/interruption.ts b/agents/src/voice/turn_config/interruption.ts new file mode 100644 index 000000000..06197f7f3 --- /dev/null +++ b/agents/src/voice/turn_config/interruption.ts @@ -0,0 +1,56 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +/** + * Configuration for interruption handling. + */ +export interface InterruptionOptions { + /** + * Whether interruptions are enabled. + * @defaultValue true + */ + enabled: boolean; + /** + * Interruption handling strategy. `"adaptive"` for ML-based detection, `"vad"` for simple + * voice-activity detection. `undefined` means auto-detect. + * @defaultValue undefined + */ + mode: 'adaptive' | 'vad' | false | undefined; + /** + * When `true`, buffered audio is dropped while the agent is speaking and cannot be interrupted. + * @defaultValue true + */ + discardAudioIfUninterruptible: boolean; + /** + * Minimum speech length in milliseconds to register as an interruption. + * @defaultValue 500 + */ + minDuration: number; + /** + * Minimum number of words to consider an interruption, only used if STT is enabled. + * @defaultValue 0 + */ + minWords: number; + /** + * If set, emit an `agentFalseInterruption` event after this amount of time if the user is + * silent and no user transcript is detected after the interruption. Set to `undefined` to + * disable. The value is in milliseconds. + * @defaultValue 2000 + */ + falseInterruptionTimeout: number; + /** + * Whether to resume the false interruption after the `falseInterruptionTimeout`. 
+ * @defaultValue true + */ + resumeFalseInterruption: boolean; +} + +export const defaultInterruptionOptions = { + enabled: true, + mode: undefined, + discardAudioIfUninterruptible: true, + minDuration: 500, + minWords: 0, + falseInterruptionTimeout: 2000, + resumeFalseInterruption: true, +} as const satisfies InterruptionOptions; diff --git a/agents/src/voice/turn_config/turn_handling.ts b/agents/src/voice/turn_config/turn_handling.ts new file mode 100644 index 000000000..1458fb663 --- /dev/null +++ b/agents/src/voice/turn_config/turn_handling.ts @@ -0,0 +1,45 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import type { TurnDetectionMode } from '../agent_session.js'; +import { type EndpointingOptions, defaultEndpointingOptions } from './endpointing.js'; +import { type InterruptionOptions, defaultInterruptionOptions } from './interruption.js'; + +/** + * Configuration for the turn handling system. Used to configure the turn taking behavior of the + * session. + */ +export interface TurnHandlingOptions { + /** + * Strategy for deciding when the user has finished speaking. + * + * - `"stt"` – rely on speech-to-text end-of-utterance cues + * - `"vad"` – rely on Voice Activity Detection start/stop cues + * - `"realtime_llm"` – use server-side detection from a realtime LLM + * - `"manual"` – caller controls turn boundaries explicitly + * + * If not set, the session chooses the best available mode in priority order + * `realtime_llm → vad → stt → manual`; it automatically falls back if the necessary model + * is missing. + */ + turnDetection: TurnDetectionMode | undefined; + /** + * Configuration for endpointing. + */ + endpointing: Partial; + /** + * Configuration for interruption handling. 
+ */ + interruption: Partial; +} + +export interface InternalTurnHandlingOptions extends TurnHandlingOptions { + endpointing: EndpointingOptions; + interruption: InterruptionOptions; +} + +export const defaultTurnHandlingOptions: InternalTurnHandlingOptions = { + turnDetection: undefined, + interruption: defaultInterruptionOptions, + endpointing: defaultEndpointingOptions, +}; diff --git a/agents/src/voice/turn_config/utils.test.ts b/agents/src/voice/turn_config/utils.test.ts new file mode 100644 index 000000000..d6b6d19ac --- /dev/null +++ b/agents/src/voice/turn_config/utils.test.ts @@ -0,0 +1,100 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { beforeAll, describe, expect, it } from 'vitest'; +import { initializeLogger } from '../../log.js'; +import { defaultEndpointingOptions } from './endpointing.js'; +import { defaultInterruptionOptions } from './interruption.js'; +import { defaultTurnHandlingOptions } from './turn_handling.js'; +import { migrateLegacyOptions } from './utils.js'; + +beforeAll(() => { + initializeLogger({ pretty: true, level: 'info' }); +}); + +describe('migrateLegacyOptions', () => { + it('should return all defaults when no options are provided', () => { + const result = migrateLegacyOptions({}); + + expect(result.options.turnHandling).toEqual({ + turnDetection: defaultTurnHandlingOptions.turnDetection, + endpointing: defaultEndpointingOptions, + interruption: defaultInterruptionOptions, + }); + expect(result.options.maxToolSteps).toBe(3); + expect(result.options.preemptiveGeneration).toBe(false); + expect(result.options.userAwayTimeout).toBe(15.0); + }); + + it('should migrate legacy flat fields into nested turnHandling config', () => { + const result = migrateLegacyOptions({ + voiceOptions: { + minInterruptionDuration: 1000, + minInterruptionWords: 3, + discardAudioIfUninterruptible: false, + minEndpointingDelay: 800, + maxEndpointingDelay: 5000, + }, + }); + + 
expect(result.options.turnHandling.interruption!.minDuration).toBe(1000); + expect(result.options.turnHandling.interruption!.minWords).toBe(3); + expect(result.options.turnHandling.interruption!.discardAudioIfUninterruptible).toBe(false); + expect(result.options.turnHandling.endpointing!.minDelay).toBe(800); + expect(result.options.turnHandling.endpointing!.maxDelay).toBe(5000); + }); + + it('should set interruption.enabled to false when allowInterruptions is false', () => { + const result = migrateLegacyOptions({ + options: { + allowInterruptions: false, + }, + }); + + expect(result.options.turnHandling.interruption!.enabled).toBe(false); + }); + + it('should give options precedence over voiceOptions when both are provided', () => { + const result = migrateLegacyOptions({ + voiceOptions: { + minInterruptionDuration: 1000, + maxEndpointingDelay: 5000, + maxToolSteps: 10, + }, + options: { + minInterruptionDuration: 2000, + maxEndpointingDelay: 8000, + maxToolSteps: 5, + }, + }); + + expect(result.options.turnHandling.interruption!.minDuration).toBe(2000); + expect(result.options.turnHandling.endpointing!.maxDelay).toBe(8000); + expect(result.options.maxToolSteps).toBe(5); + }); + + it('should let explicit turnHandling override legacy flat fields', () => { + const result = migrateLegacyOptions({ + options: { + minInterruptionDuration: 1000, + minEndpointingDelay: 800, + turnHandling: { + interruption: { minDuration: 3000 }, + endpointing: { minDelay: 2000 }, + }, + }, + }); + + expect(result.options.turnHandling.interruption!.minDuration).toBe(3000); + expect(result.options.turnHandling.endpointing!.minDelay).toBe(2000); + }); + + it('should preserve top-level turnDetection in the result', () => { + const result = migrateLegacyOptions({ + turnDetection: 'vad', + }); + + expect(result.turnDetection).toBe('vad'); + expect(result.options.turnHandling.turnDetection).toBe('vad'); + }); +}); diff --git a/agents/src/voice/turn_config/utils.ts 
b/agents/src/voice/turn_config/utils.ts new file mode 100644 index 000000000..55234dc76 --- /dev/null +++ b/agents/src/voice/turn_config/utils.ts @@ -0,0 +1,103 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { log } from '../../log.js'; +import { + type AgentSessionOptions, + type InternalSessionOptions, + defaultSessionOptions, +} from '../agent_session.js'; +import { defaultEndpointingOptions } from './endpointing.js'; +import { defaultInterruptionOptions } from './interruption.js'; +import { type TurnHandlingOptions, defaultTurnHandlingOptions } from './turn_handling.js'; + +export function migrateLegacyOptions( + legacyOptions: AgentSessionOptions, +): AgentSessionOptions & { options: InternalSessionOptions } { + const logger = log(); + const { voiceOptions, turnDetection, options: sessionOptions, ...rest } = legacyOptions; + + if (voiceOptions !== undefined && sessionOptions !== undefined) { + logger.warn( + 'Both voiceOptions and options have been supplied as part of the AgentSessionOptions, voiceOptions will be merged with options taking precedence', + ); + } + + // Preserve turnDetection before cloning since structuredClone converts class instances to plain objects + const originalTurnDetection = + sessionOptions?.turnHandling?.turnDetection ?? + voiceOptions?.turnHandling?.turnDetection ?? + turnDetection; + + // Exclude potentially non-cloneable turnDetection objects before structuredClone. + // They are restored from originalTurnDetection below. + const cloneableVoiceOptions = voiceOptions + ? { + ...voiceOptions, + turnHandling: voiceOptions.turnHandling + ? { ...voiceOptions.turnHandling, turnDetection: undefined } + : voiceOptions.turnHandling, + } + : voiceOptions; + const cloneableSessionOptions = sessionOptions + ? { + ...sessionOptions, + turnHandling: sessionOptions.turnHandling + ? 
{ ...sessionOptions.turnHandling, turnDetection: undefined } + : sessionOptions.turnHandling, + } + : sessionOptions; + + const mergedOptions = structuredClone({ ...cloneableVoiceOptions, ...cloneableSessionOptions }); + + const turnHandling: TurnHandlingOptions = { + interruption: { + discardAudioIfUninterruptible: mergedOptions?.discardAudioIfUninterruptible, + minDuration: mergedOptions?.minInterruptionDuration, + minWords: mergedOptions?.minInterruptionWords, + }, + endpointing: { + minDelay: mergedOptions?.minEndpointingDelay, + maxDelay: mergedOptions?.maxEndpointingDelay, + }, + + ...mergedOptions.turnHandling, + // Restore original turnDetection after spread to preserve class instance with methods + // (structuredClone converts class instances to plain objects, losing prototype methods) + turnDetection: originalTurnDetection, + } as const; + + if (mergedOptions?.allowInterruptions === false) { + turnHandling.interruption.enabled = false; + } + + const optionsWithDefaults = { + ...defaultSessionOptions, + ...mergedOptions, + turnHandling: mergeWithDefaults(turnHandling), + }; + + const newAgentSessionOptions: AgentSessionOptions & { + options: InternalSessionOptions; + } = { + ...rest, + options: optionsWithDefaults, + voiceOptions: optionsWithDefaults, + turnDetection: turnHandling.turnDetection, + }; + + return newAgentSessionOptions; +} + +/** Remove keys whose value is `undefined` so they don't shadow defaults when spread. */ +export function stripUndefined(obj: T): Partial { + return Object.fromEntries(Object.entries(obj).filter(([, v]) => v !== undefined)) as Partial; +} + +export function mergeWithDefaults(config: TurnHandlingOptions) { + return { + turnDetection: config.turnDetection ?? 
defaultTurnHandlingOptions.turnDetection, + endpointing: { ...defaultEndpointingOptions, ...stripUndefined(config.endpointing) }, + interruption: { ...defaultInterruptionOptions, ...stripUndefined(config.interruption) }, + } as const; +} diff --git a/agents/src/voice/wire_format.ts b/agents/src/voice/wire_format.ts new file mode 100644 index 000000000..3ea7782e5 --- /dev/null +++ b/agents/src/voice/wire_format.ts @@ -0,0 +1,827 @@ +// SPDX-FileCopyrightText: 2025 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Explicit wire-format converters that produce the exact JSON shape emitted by +// Python Pydantic models (snake_case keys, durations in seconds). +// The agents-playground frontend (types.ts / useClientEvents.ts) consumes this +// format directly via JSON.parse — any mismatch breaks the UI. +import { z } from 'zod'; +import type { + AgentHandoffItem, + AudioContent, + ChatContent, + ChatItem, + ChatMessage, + FunctionCall, + FunctionCallOutput, + ImageContent, + MetricsReport, +} from '../llm/chat_context.js'; +import type { + AgentMetrics, + EOUMetrics, + InterruptionMetrics, + LLMMetrics, + MetricsMetadata, + RealtimeModelMetrics, + RealtimeModelMetricsCachedTokenDetails, + RealtimeModelMetricsInputTokenDetails, + RealtimeModelMetricsOutputTokenDetails, + STTMetrics, + TTSMetrics, + VADMetrics, +} from '../metrics/base.js'; +import type { + InterruptionModelUsage, + LLMModelUsage, + ModelUsage, + STTModelUsage, + TTSModelUsage, +} from '../metrics/model_usage.js'; +import type { AgentSessionUsage } from './agent_session.js'; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +type WireObject = Record; + +export function msToS(ms: number): number { + return ms / 1000; +} + +function omitUndefined(obj: WireObject): WireObject { + const result: WireObject = {}; + for (const [k, v] of Object.entries(obj)) { + if (v !== 
undefined) { + result[k] = v; + } + } + return result; +} + +function imageContentToWire(img: ImageContent): WireObject { + return omitUndefined({ + id: img.id, + type: img.type, + image: typeof img.image === 'string' ? img.image : undefined, + inference_detail: img.inferenceDetail, + inference_width: img.inferenceWidth, + inference_height: img.inferenceHeight, + mime_type: img.mimeType, + }); +} + +function audioContentToWire(audio: AudioContent): WireObject { + return omitUndefined({ + type: audio.type, + transcript: audio.transcript, + }); +} + +function chatContentToWire(content: ChatContent): unknown { + if (typeof content === 'string') return content; + if (content.type === 'image_content') return imageContentToWire(content); + return audioContentToWire(content); +} + +function metricsReportToWire(m: MetricsReport): WireObject { + return omitUndefined({ + started_speaking_at: m.startedSpeakingAt, + stopped_speaking_at: m.stoppedSpeakingAt, + transcription_delay: m.transcriptionDelay, + end_of_turn_delay: m.endOfTurnDelay, + on_user_turn_completed_delay: m.onUserTurnCompletedDelay, + llm_node_ttft: m.llmNodeTtft, + tts_node_ttfb: m.ttsNodeTtfb, + e2e_latency: m.e2eLatency, + }); +} + +export function chatMessageToWire(msg: ChatMessage): WireObject { + const result: WireObject = { + id: msg.id, + type: msg.type, + role: msg.role, + content: msg.content.map(chatContentToWire), + interrupted: msg.interrupted, + created_at: msToS(msg.createdAt), + }; + + if (msg.transcriptConfidence !== undefined) { + result.transcript_confidence = msg.transcriptConfidence; + } + if (Object.keys(msg.metrics).length > 0) { + result.metrics = metricsReportToWire(msg.metrics); + } + if (Object.keys(msg.extra).length > 0) { + result.extra = msg.extra; + } + return result; +} + +export function functionCallToWire(fc: FunctionCall): WireObject { + const result: WireObject = { + id: fc.id, + type: fc.type, + call_id: fc.callId, + arguments: fc.args, + name: fc.name, + created_at: 
msToS(fc.createdAt), + }; + + if (Object.keys(fc.extra).length > 0) { + result.extra = fc.extra; + } + if (fc.groupId !== undefined) { + result.group_id = fc.groupId; + } + return result; +} + +export function functionCallOutputToWire(fco: FunctionCallOutput): WireObject { + return { + id: fco.id, + type: fco.type, + name: fco.name, + call_id: fco.callId, + output: fco.output, + is_error: fco.isError, + created_at: msToS(fco.createdAt), + }; +} + +export function agentHandoffToWire(ah: AgentHandoffItem): WireObject { + const result: WireObject = { + id: ah.id, + type: ah.type, + new_agent_id: ah.newAgentId, + created_at: msToS(ah.createdAt), + }; + if (ah.oldAgentId !== undefined) { + result.old_agent_id = ah.oldAgentId; + } + return result; +} + +export function chatItemToWire(item: ChatItem): WireObject { + switch (item.type) { + case 'message': + return chatMessageToWire(item); + case 'function_call': + return functionCallToWire(item); + case 'function_call_output': + return functionCallOutputToWire(item); + case 'agent_handoff': + return agentHandoffToWire(item); + } +} + +function metadataToWire(m: MetricsMetadata | undefined): WireObject | null { + if (!m) return null; + return omitUndefined({ + model_name: m.modelName, + model_provider: m.modelProvider, + }); +} + +function llmMetricsToWire(m: LLMMetrics): WireObject { + return omitUndefined({ + type: m.type, + label: m.label, + request_id: m.requestId, + timestamp: msToS(m.timestamp), + duration: msToS(m.durationMs), + ttft: msToS(m.ttftMs), + cancelled: m.cancelled, + completion_tokens: m.completionTokens, + prompt_tokens: m.promptTokens, + prompt_cached_tokens: m.promptCachedTokens, + total_tokens: m.totalTokens, + tokens_per_second: m.tokensPerSecond, + speech_id: m.speechId, + metadata: metadataToWire(m.metadata), + }); +} + +function sttMetricsToWire(m: STTMetrics): WireObject { + return omitUndefined({ + type: m.type, + label: m.label, + request_id: m.requestId, + timestamp: msToS(m.timestamp), + 
duration: msToS(m.durationMs), + audio_duration: msToS(m.audioDurationMs), + input_tokens: m.inputTokens, + output_tokens: m.outputTokens, + streamed: m.streamed, + metadata: metadataToWire(m.metadata), + }); +} + +function ttsMetricsToWire(m: TTSMetrics): WireObject { + return omitUndefined({ + type: m.type, + label: m.label, + request_id: m.requestId, + timestamp: msToS(m.timestamp), + ttfb: msToS(m.ttfbMs), + duration: msToS(m.durationMs), + audio_duration: msToS(m.audioDurationMs), + cancelled: m.cancelled, + characters_count: m.charactersCount, + input_tokens: m.inputTokens, + output_tokens: m.outputTokens, + streamed: m.streamed, + segment_id: m.segmentId, + speech_id: m.speechId, + metadata: metadataToWire(m.metadata), + }); +} + +function vadMetricsToWire(m: VADMetrics): WireObject { + return { + type: m.type, + label: m.label, + timestamp: msToS(m.timestamp), + idle_time: msToS(m.idleTimeMs), + inference_duration_total: msToS(m.inferenceDurationTotalMs), + inference_count: m.inferenceCount, + }; +} + +function eouMetricsToWire(m: EOUMetrics): WireObject { + return omitUndefined({ + type: m.type, + timestamp: msToS(m.timestamp), + end_of_utterance_delay: msToS(m.endOfUtteranceDelayMs), + transcription_delay: msToS(m.transcriptionDelayMs), + on_user_turn_completed_delay: msToS(m.onUserTurnCompletedDelayMs), + speech_id: m.speechId, + }); +} + +function cachedTokenDetailsToWire(d: RealtimeModelMetricsCachedTokenDetails): WireObject { + return { + audio_tokens: d.audioTokens, + text_tokens: d.textTokens, + image_tokens: d.imageTokens, + }; +} + +function inputTokenDetailsToWire(d: RealtimeModelMetricsInputTokenDetails): WireObject { + return omitUndefined({ + audio_tokens: d.audioTokens, + text_tokens: d.textTokens, + image_tokens: d.imageTokens, + cached_tokens: d.cachedTokens, + cached_tokens_details: d.cachedTokensDetails + ? 
cachedTokenDetailsToWire(d.cachedTokensDetails) + : undefined, + }); +} + +function outputTokenDetailsToWire(d: RealtimeModelMetricsOutputTokenDetails): WireObject { + return { + text_tokens: d.textTokens, + audio_tokens: d.audioTokens, + image_tokens: d.imageTokens, + }; +} + +function realtimeModelMetricsToWire(m: RealtimeModelMetrics): WireObject { + return omitUndefined({ + type: m.type, + label: m.label, + request_id: m.requestId, + timestamp: msToS(m.timestamp), + duration: msToS(m.durationMs), + session_duration: m.sessionDurationMs !== undefined ? msToS(m.sessionDurationMs) : undefined, + ttft: msToS(m.ttftMs), + cancelled: m.cancelled, + input_tokens: m.inputTokens, + output_tokens: m.outputTokens, + total_tokens: m.totalTokens, + tokens_per_second: m.tokensPerSecond, + input_token_details: inputTokenDetailsToWire(m.inputTokenDetails), + output_token_details: outputTokenDetailsToWire(m.outputTokenDetails), + metadata: metadataToWire(m.metadata), + }); +} + +function interruptionMetricsToWire(m: InterruptionMetrics): WireObject { + return omitUndefined({ + type: m.type, + timestamp: msToS(m.timestamp), + total_duration: msToS(m.totalDuration), + prediction_duration: msToS(m.predictionDuration), + detection_delay: msToS(m.detectionDelay), + num_interruptions: m.numInterruptions, + num_backchannels: m.numBackchannels, + num_requests: m.numRequests, + metadata: metadataToWire(m.metadata), + }); +} + +export function agentMetricsToWire(m: AgentMetrics): WireObject { + switch (m.type) { + case 'llm_metrics': + return llmMetricsToWire(m); + case 'stt_metrics': + return sttMetricsToWire(m); + case 'tts_metrics': + return ttsMetricsToWire(m); + case 'vad_metrics': + return vadMetricsToWire(m); + case 'eou_metrics': + return eouMetricsToWire(m); + case 'realtime_model_metrics': + return realtimeModelMetricsToWire(m); + case 'interruption_metrics': + return interruptionMetricsToWire(m); + } +} + +function llmModelUsageToWire(u: Partial): WireObject { + return { + 
type: u.type, + provider: u.provider ?? '', + model: u.model ?? '', + input_tokens: u.inputTokens ?? 0, + input_cached_tokens: u.inputCachedTokens ?? 0, + input_audio_tokens: u.inputAudioTokens ?? 0, + input_cached_audio_tokens: u.inputCachedAudioTokens ?? 0, + input_text_tokens: u.inputTextTokens ?? 0, + input_cached_text_tokens: u.inputCachedTextTokens ?? 0, + input_image_tokens: u.inputImageTokens ?? 0, + input_cached_image_tokens: u.inputCachedImageTokens ?? 0, + output_tokens: u.outputTokens ?? 0, + output_audio_tokens: u.outputAudioTokens ?? 0, + output_text_tokens: u.outputTextTokens ?? 0, + session_duration: msToS(u.sessionDurationMs ?? 0), + }; +} + +function ttsModelUsageToWire(u: Partial): WireObject { + return { + type: u.type, + provider: u.provider ?? '', + model: u.model ?? '', + input_tokens: u.inputTokens ?? 0, + output_tokens: u.outputTokens ?? 0, + characters_count: u.charactersCount ?? 0, + audio_duration: msToS(u.audioDurationMs ?? 0), + }; +} + +function sttModelUsageToWire(u: Partial): WireObject { + return { + type: u.type, + provider: u.provider ?? '', + model: u.model ?? '', + input_tokens: u.inputTokens ?? 0, + output_tokens: u.outputTokens ?? 0, + audio_duration: msToS(u.audioDurationMs ?? 0), + }; +} + +function interruptionModelUsageToWire(u: Partial): WireObject { + return { + type: u.type, + provider: u.provider ?? '', + model: u.model ?? '', + total_requests: u.totalRequests ?? 
0, + }; +} + +export function modelUsageToWire(u: Partial): WireObject { + switch (u.type) { + case 'llm_usage': + return llmModelUsageToWire(u as Partial); + case 'tts_usage': + return ttsModelUsageToWire(u as Partial); + case 'stt_usage': + return sttModelUsageToWire(u as Partial); + case 'interruption_usage': + return interruptionModelUsageToWire(u as Partial); + default: + return u as WireObject; + } +} + +export function agentSessionUsageToWire(u: AgentSessionUsage): WireObject { + return { + model_usage: u.modelUsage.map(modelUsageToWire), + }; +} + +// =========================================================================== +// Zod wire-format schemas +// These validate the exact JSON shape that Python Pydantic emits on the wire. +// Inferred types via z.infer give fully typed parse results. +// =========================================================================== +const imageContentWireSchema = z.object({ + id: z.string(), + type: z.literal('image_content'), + image: z.string(), + inference_detail: z.enum(['auto', 'high', 'low']).optional(), + inference_width: z.number().optional(), + inference_height: z.number().optional(), + mime_type: z.string().optional(), +}); + +const audioContentWireSchema = z.object({ + type: z.literal('audio_content'), + transcript: z.string().nullable().optional(), +}); + +const chatContentWireSchema = z.union([z.string(), imageContentWireSchema, audioContentWireSchema]); + +const metricsReportWireSchema = z + .object({ + started_speaking_at: z.number().optional(), + stopped_speaking_at: z.number().optional(), + transcription_delay: z.number().optional(), + end_of_turn_delay: z.number().optional(), + on_user_turn_completed_delay: z.number().optional(), + llm_node_ttft: z.number().optional(), + tts_node_ttfb: z.number().optional(), + e2e_latency: z.number().optional(), + }) + .optional(); + +export const chatMessageWireSchema = z.object({ + id: z.string(), + type: z.literal('message'), + role: z.enum(['developer', 
'system', 'user', 'assistant']), + content: z.array(chatContentWireSchema), + interrupted: z.boolean(), + created_at: z.number(), + transcript_confidence: z.number().optional(), + metrics: metricsReportWireSchema, + extra: z.record(z.string(), z.unknown()).optional(), +}); + +export const functionCallWireSchema = z.object({ + id: z.string(), + type: z.literal('function_call'), + call_id: z.string(), + arguments: z.string(), + name: z.string(), + created_at: z.number(), + extra: z.record(z.string(), z.unknown()).optional(), + group_id: z.string().optional(), +}); + +export const functionCallOutputWireSchema = z.object({ + id: z.string(), + type: z.literal('function_call_output'), + name: z.string(), + call_id: z.string(), + output: z.string(), + is_error: z.boolean(), + created_at: z.number(), +}); + +export const agentHandoffWireSchema = z.object({ + id: z.string(), + type: z.literal('agent_handoff'), + new_agent_id: z.string(), + created_at: z.number(), + old_agent_id: z.string().optional(), +}); + +export const chatItemWireSchema = z.discriminatedUnion('type', [ + chatMessageWireSchema, + functionCallWireSchema, + functionCallOutputWireSchema, + agentHandoffWireSchema, +]); + +const metadataWireSchema = z + .object({ + model_name: z.string().optional(), + model_provider: z.string().optional(), + }) + .nullable() + .optional(); + +export const llmMetricsWireSchema = z.object({ + type: z.literal('llm_metrics'), + label: z.string(), + request_id: z.string(), + timestamp: z.number(), + duration: z.number(), + ttft: z.number(), + cancelled: z.boolean(), + completion_tokens: z.number(), + prompt_tokens: z.number(), + prompt_cached_tokens: z.number(), + total_tokens: z.number(), + tokens_per_second: z.number(), + speech_id: z.string().nullable().optional(), + metadata: metadataWireSchema, +}); + +export const sttMetricsWireSchema = z.object({ + type: z.literal('stt_metrics'), + label: z.string(), + request_id: z.string(), + timestamp: z.number(), + duration: z.number(), 
+ audio_duration: z.number(), + input_tokens: z.number().optional(), + output_tokens: z.number().optional(), + streamed: z.boolean(), + metadata: metadataWireSchema, +}); + +export const ttsMetricsWireSchema = z.object({ + type: z.literal('tts_metrics'), + label: z.string(), + request_id: z.string(), + timestamp: z.number(), + ttfb: z.number(), + duration: z.number(), + audio_duration: z.number(), + cancelled: z.boolean(), + characters_count: z.number(), + input_tokens: z.number().optional(), + output_tokens: z.number().optional(), + streamed: z.boolean(), + segment_id: z.string().nullable().optional(), + speech_id: z.string().nullable().optional(), + metadata: metadataWireSchema, +}); + +export const vadMetricsWireSchema = z.object({ + type: z.literal('vad_metrics'), + label: z.string(), + timestamp: z.number(), + idle_time: z.number(), + inference_duration_total: z.number(), + inference_count: z.number(), +}); + +export const eouMetricsWireSchema = z.object({ + type: z.literal('eou_metrics'), + timestamp: z.number(), + end_of_utterance_delay: z.number(), + transcription_delay: z.number(), + on_user_turn_completed_delay: z.number(), + speech_id: z.string().nullable().optional(), +}); + +const cachedTokenDetailsWireSchema = z.object({ + audio_tokens: z.number(), + text_tokens: z.number(), + image_tokens: z.number(), +}); + +const inputTokenDetailsWireSchema = z.object({ + audio_tokens: z.number(), + text_tokens: z.number(), + image_tokens: z.number(), + cached_tokens: z.number(), + cached_tokens_details: cachedTokenDetailsWireSchema.nullable().optional(), +}); + +const outputTokenDetailsWireSchema = z.object({ + text_tokens: z.number(), + audio_tokens: z.number(), + image_tokens: z.number(), +}); + +export const realtimeModelMetricsWireSchema = z.object({ + type: z.literal('realtime_model_metrics'), + label: z.string(), + request_id: z.string(), + timestamp: z.number(), + duration: z.number(), + session_duration: z.number().optional(), + ttft: z.number(), + 
cancelled: z.boolean(), + input_tokens: z.number(), + output_tokens: z.number(), + total_tokens: z.number(), + tokens_per_second: z.number(), + input_token_details: inputTokenDetailsWireSchema, + output_token_details: outputTokenDetailsWireSchema, + metadata: metadataWireSchema, +}); + +export const interruptionMetricsWireSchema = z.object({ + type: z.literal('interruption_metrics'), + timestamp: z.number(), + total_duration: z.number(), + prediction_duration: z.number(), + detection_delay: z.number(), + num_interruptions: z.number(), + num_backchannels: z.number(), + num_requests: z.number(), + metadata: metadataWireSchema, +}); + +export const agentMetricsWireSchema = z.discriminatedUnion('type', [ + llmMetricsWireSchema, + sttMetricsWireSchema, + ttsMetricsWireSchema, + vadMetricsWireSchema, + eouMetricsWireSchema, + realtimeModelMetricsWireSchema, + interruptionMetricsWireSchema, +]); + +// --------------------------------------------------------------------------- +// Model usage schemas +// --------------------------------------------------------------------------- + +export const llmModelUsageWireSchema = z.object({ + type: z.literal('llm_usage'), + provider: z.string().optional(), + model: z.string().optional(), + input_tokens: z.number().optional(), + input_cached_tokens: z.number().optional(), + input_audio_tokens: z.number().optional(), + input_cached_audio_tokens: z.number().optional(), + input_text_tokens: z.number().optional(), + input_cached_text_tokens: z.number().optional(), + input_image_tokens: z.number().optional(), + input_cached_image_tokens: z.number().optional(), + output_tokens: z.number().optional(), + output_audio_tokens: z.number().optional(), + output_text_tokens: z.number().optional(), + session_duration: z.number().optional(), +}); + +export const ttsModelUsageWireSchema = z.object({ + type: z.literal('tts_usage'), + provider: z.string().optional(), + model: z.string().optional(), + input_tokens: z.number().optional(), + 
output_tokens: z.number().optional(), + characters_count: z.number().optional(), + audio_duration: z.number().optional(), +}); + +export const sttModelUsageWireSchema = z.object({ + type: z.literal('stt_usage'), + provider: z.string().optional(), + model: z.string().optional(), + input_tokens: z.number().optional(), + output_tokens: z.number().optional(), + audio_duration: z.number().optional(), +}); + +export const interruptionModelUsageWireSchema = z.object({ + type: z.literal('interruption_usage'), + provider: z.string().optional(), + model: z.string().optional(), + total_requests: z.number().optional(), +}); + +export const modelUsageWireSchema = z.discriminatedUnion('type', [ + llmModelUsageWireSchema, + ttsModelUsageWireSchema, + sttModelUsageWireSchema, + interruptionModelUsageWireSchema, +]); + +export const agentSessionUsageWireSchema = z.object({ + model_usage: z.array(modelUsageWireSchema), +}); + +// --------------------------------------------------------------------------- +// Client event schemas +// --------------------------------------------------------------------------- + +const agentStateSchema = z.enum(['initializing', 'idle', 'listening', 'thinking', 'speaking']); +const userStateSchema = z.enum(['speaking', 'listening', 'away']); + +export const clientAgentStateChangedSchema = z.object({ + type: z.literal('agent_state_changed'), + old_state: agentStateSchema, + new_state: agentStateSchema, + created_at: z.number(), +}); + +export const clientUserStateChangedSchema = z.object({ + type: z.literal('user_state_changed'), + old_state: userStateSchema, + new_state: userStateSchema, + created_at: z.number(), +}); + +export const clientConversationItemAddedSchema = z.object({ + type: z.literal('conversation_item_added'), + item: chatMessageWireSchema, + created_at: z.number(), +}); + +export const clientUserInputTranscribedSchema = z.object({ + type: z.literal('user_input_transcribed'), + transcript: z.string(), + is_final: z.boolean(), + language: 
z.string().nullable(), + created_at: z.number(), +}); + +export const clientFunctionToolsExecutedSchema = z.object({ + type: z.literal('function_tools_executed'), + function_calls: z.array(functionCallWireSchema), + function_call_outputs: z.array(functionCallOutputWireSchema.nullable()), + created_at: z.number(), +}); + +export const clientMetricsCollectedSchema = z.object({ + type: z.literal('metrics_collected'), + metrics: agentMetricsWireSchema, + created_at: z.number(), +}); + +export const clientErrorSchema = z.object({ + type: z.literal('error'), + message: z.string(), + created_at: z.number(), +}); + +export const clientUserOverlappingSpeechSchema = z.object({ + type: z.literal('user_overlapping_speech'), + is_interruption: z.boolean(), + created_at: z.number(), + sent_at: z.number(), + detection_delay: z.number(), + overlap_started_at: z.number().nullable(), +}); + +export const clientSessionUsageSchema = z.object({ + type: z.literal('session_usage'), + usage: agentSessionUsageWireSchema, + created_at: z.number(), +}); + +export const clientEventSchema = z.discriminatedUnion('type', [ + clientAgentStateChangedSchema, + clientUserStateChangedSchema, + clientConversationItemAddedSchema, + clientUserInputTranscribedSchema, + clientFunctionToolsExecutedSchema, + clientMetricsCollectedSchema, + clientErrorSchema, + clientUserOverlappingSpeechSchema, + clientSessionUsageSchema, +]); + +// --------------------------------------------------------------------------- +// RPC schemas +// --------------------------------------------------------------------------- + +export const sendMessageRequestSchema = z.object({ + text: z.string(), +}); + +export const streamRequestSchema = z.object({ + request_id: z.string(), + method: z.string(), + payload: z.string(), +}); + +export const streamResponseSchema = z.object({ + request_id: z.string(), + payload: z.string(), + error: z.string().nullable().optional(), +}); + +export const getSessionStateResponseSchema = z.object({ + 
agent_state: agentStateSchema, + user_state: userStateSchema, + agent_id: z.string(), + options: z.record(z.string(), z.unknown()), + created_at: z.number(), +}); + +export const getChatHistoryResponseSchema = z.object({ + items: z.array(chatItemWireSchema), +}); + +export const getAgentInfoResponseSchema = z.object({ + id: z.string(), + instructions: z.string().nullable(), + tools: z.array(z.string()), + chat_ctx: z.array(chatItemWireSchema), +}); + +export const sendMessageResponseSchema = z.object({ + items: z.array(chatItemWireSchema), +}); + +export const getRTCStatsResponseSchema = z.object({ + publisher_stats: z.array(z.record(z.string(), z.unknown())), + subscriber_stats: z.array(z.record(z.string(), z.unknown())), +}); + +export const getSessionUsageResponseSchema = z.object({ + usage: agentSessionUsageWireSchema, + created_at: z.number(), +}); diff --git a/examples/src/basic_agent.ts b/examples/src/basic_agent.ts index ac2512b2c..e5fd5290f 100644 --- a/examples/src/basic_agent.ts +++ b/examples/src/basic_agent.ts @@ -9,6 +9,7 @@ import { defineAgent, inference, llm, + log, metrics, voice, } from '@livekit/agents'; @@ -39,6 +40,8 @@ export default defineAgent({ }, }); + const logger = log(); + const session = new voice.AgentSession({ // Speech-to-text (STT) is your agent's ears, turning the user's speech into text that the LLM can understand // See all available models at https://docs.livekit.io/agents/models/stt/ @@ -64,12 +67,20 @@ export default defineAgent({ // VAD and turn detection are used to determine when the user is speaking and when the agent should respond // See more at https://docs.livekit.io/agents/build/turns vad: ctx.proc.userData.vad! 
as silero.VAD, - turnDetection: new livekit.turnDetector.MultilingualModel(), + // to use realtime model, replace the stt, llm, tts and vad with the following // llm: new openai.realtime.RealtimeModel(), - voiceOptions: { + options: { // allow the LLM to generate a response while waiting for the end of turn preemptiveGeneration: true, + turnHandling: { + turnDetection: new livekit.turnDetector.MultilingualModel(), + interruption: { + resumeFalseInterruption: true, + falseInterruptionTimeout: 1, + mode: 'adaptive', + }, + }, useTtsAlignedTranscript: true, aecWarmupDuration: 3000, }, @@ -83,11 +94,23 @@ export default defineAgent({ }, }); - const usageCollector = new metrics.UsageCollector(); - + // Log metrics as they are emitted session.on(voice.AgentSessionEventTypes.MetricsCollected, (ev) => { metrics.logMetrics(ev.metrics); - usageCollector.collect(ev.metrics); + }); + + // Log usage summary when job shuts down + ctx.addShutdownCallback(async () => { + logger.info( + { + usage: session.usage, + }, + 'Session usage summary', + ); + }); + + session.on(voice.AgentSessionEventTypes.UserOverlappingSpeech, (ev) => { + logger.warn({ type: ev.type, isInterruption: ev.isInterruption }, 'user overlapping speech'); }); await session.start({ diff --git a/examples/src/bey_avatar.ts b/examples/src/bey_avatar.ts index f8eb1f3d1..5cad7655c 100644 --- a/examples/src/bey_avatar.ts +++ b/examples/src/bey_avatar.ts @@ -1,7 +1,15 @@ // SPDX-FileCopyrightText: 2025 LiveKit, Inc. // // SPDX-License-Identifier: Apache-2.0 -import { type JobContext, WorkerOptions, cli, defineAgent, metrics, voice } from '@livekit/agents'; +import { + type JobContext, + WorkerOptions, + cli, + defineAgent, + log, + metrics, + voice, +} from '@livekit/agents'; import * as bey from '@livekit/agents-plugin-bey'; import * as openai from '@livekit/agents-plugin-openai'; import { fileURLToPath } from 'node:url'; @@ -12,6 +20,7 @@ export default defineAgent({ instructions: 'You are a helpful assistant. 
Speak clearly and concisely.', }); + const logger = log(); const session = new voice.AgentSession({ llm: new openai.realtime.RealtimeModel({ voice: 'alloy', @@ -32,11 +41,19 @@ export default defineAgent({ }); await avatar.start(session, ctx.room); - const usageCollector = new metrics.UsageCollector(); - + // Log metrics as they are emitted (session.usage is automatically collected) session.on(voice.AgentSessionEventTypes.MetricsCollected, (ev) => { metrics.logMetrics(ev.metrics); - usageCollector.collect(ev.metrics); + }); + + // Log usage summary when job shuts down + ctx.addShutdownCallback(async () => { + logger.info( + { + usage: session.usage, + }, + 'Session usage summary', + ); }); session.generateReply({ diff --git a/examples/src/cartesia_tts.ts b/examples/src/cartesia_tts.ts index a11aae33a..4d40b7334 100644 --- a/examples/src/cartesia_tts.ts +++ b/examples/src/cartesia_tts.ts @@ -7,6 +7,7 @@ import { WorkerOptions, cli, defineAgent, + log, metrics, voice, } from '@livekit/agents'; @@ -28,6 +29,7 @@ export default defineAgent({ "You are a helpful assistant, you can hear the user's message and respond to it.", }); + const logger = log(); const vad = ctx.proc.userData.vad! 
as silero.VAD; const session = new voice.AgentSession({ @@ -40,11 +42,19 @@ export default defineAgent({ turnDetection: new livekit.turnDetector.MultilingualModel(), }); - const usageCollector = new metrics.UsageCollector(); - + // Log metrics as they are emitted (session.usage is automatically collected) session.on(voice.AgentSessionEventTypes.MetricsCollected, (ev) => { metrics.logMetrics(ev.metrics); - usageCollector.collect(ev.metrics); + }); + + // Log usage summary when job shuts down + ctx.addShutdownCallback(async () => { + logger.info( + { + usage: session.usage, + }, + 'Session usage summary', + ); }); await session.start({ diff --git a/examples/src/comprehensive_test.ts b/examples/src/comprehensive_test.ts index b6d08d6cd..bac9910cc 100644 --- a/examples/src/comprehensive_test.ts +++ b/examples/src/comprehensive_test.ts @@ -8,6 +8,7 @@ import { cli, defineAgent, llm, + log, metrics, voice, } from '@livekit/agents'; @@ -238,6 +239,7 @@ export default defineAgent({ proc.userData.vad = await silero.VAD.load(); }, entry: async (ctx: JobContext) => { + const logger = log(); const vad = ctx.proc.userData.vad! 
as silero.VAD; const session = new voice.AgentSession({ vad, @@ -249,11 +251,19 @@ export default defineAgent({ testedRealtimeLlmChoices: new Set(), }, }); - const usageCollector = new metrics.UsageCollector(); - + // Log metrics as they are emitted (session.usage is automatically collected) session.on(voice.AgentSessionEventTypes.MetricsCollected, (ev) => { metrics.logMetrics(ev.metrics); - usageCollector.collect(ev.metrics); + }); + + // Log usage summary when job shuts down + ctx.addShutdownCallback(async () => { + logger.info( + { + usage: session.usage, + }, + 'Session usage summary', + ); }); await session.start({ diff --git a/examples/src/hedra/hedra_avatar.ts b/examples/src/hedra/hedra_avatar.ts index 38c103d66..9bead6094 100644 --- a/examples/src/hedra/hedra_avatar.ts +++ b/examples/src/hedra/hedra_avatar.ts @@ -9,6 +9,7 @@ import { defineAgent, inference, initializeLogger, + log, metrics, voice, } from '@livekit/agents'; @@ -33,6 +34,7 @@ export default defineAgent({ instructions: 'You are a helpful assistant. 
Speak clearly and concisely.', }); + const logger = log(); const session = new voice.AgentSession({ stt: new inference.STT({ model: 'deepgram/nova-3', @@ -68,11 +70,19 @@ export default defineAgent({ }); await avatar.start(session, ctx.room); - const usageCollector = new metrics.UsageCollector(); - + // Log metrics as they are emitted (session.usage is automatically collected) session.on(voice.AgentSessionEventTypes.MetricsCollected, (ev) => { metrics.logMetrics(ev.metrics); - usageCollector.collect(ev.metrics); + }); + + // Log usage summary when job shuts down + ctx.addShutdownCallback(async () => { + logger.info( + { + usage: session.usage, + }, + 'Session usage summary', + ); }); session.generateReply({ diff --git a/examples/src/inworld_tts.ts b/examples/src/inworld_tts.ts index fec5c552d..45bb6961c 100644 --- a/examples/src/inworld_tts.ts +++ b/examples/src/inworld_tts.ts @@ -7,6 +7,7 @@ import { WorkerOptions, cli, defineAgent, + log, metrics, voice, } from '@livekit/agents'; @@ -26,6 +27,7 @@ export default defineAgent({ "You are a helpful assistant, you can hear the user's message and respond to it in 1-2 short sentences.", }); + const logger = log(); // Create TTS instance const tts = new inworld.TTS({ timestampType: 'WORD', @@ -96,11 +98,19 @@ export default defineAgent({ } }); - const usageCollector = new metrics.UsageCollector(); - + // Log metrics as they are emitted (session.usage is automatically collected) session.on(voice.AgentSessionEventTypes.MetricsCollected, (ev) => { metrics.logMetrics(ev.metrics); - usageCollector.collect(ev.metrics); + }); + + // Log usage summary when job shuts down + ctx.addShutdownCallback(async () => { + logger.info( + { + usage: session.usage, + }, + 'Session usage summary', + ); }); await session.start({ diff --git a/plugins/cartesia/src/tts.ts b/plugins/cartesia/src/tts.ts index 37cfcf367..467535406 100644 --- a/plugins/cartesia/src/tts.ts +++ b/plugins/cartesia/src/tts.ts @@ -82,6 +82,14 @@ export class TTS extends 
tts.TTS { #opts: TTSOptions; label = 'cartesia.TTS'; + get model(): string { + return this.#opts.model; + } + + get provider(): string { + return 'Cartesia'; + } + constructor(opts: Partial = {}) { const resolvedOpts = { ...defaultTTSOptions, diff --git a/plugins/deepgram/src/stt.ts b/plugins/deepgram/src/stt.ts index 805015ec4..f2f232010 100644 --- a/plugins/deepgram/src/stt.ts +++ b/plugins/deepgram/src/stt.ts @@ -70,6 +70,14 @@ export class STT extends stt.STT { label = 'deepgram.STT'; private abortController = new AbortController(); + get model(): string { + return this.#opts.model; + } + + get provider(): string { + return 'Deepgram'; + } + constructor(opts: Partial = defaultSTTOptions) { super({ streaming: true, diff --git a/plugins/deepgram/src/tts.ts b/plugins/deepgram/src/tts.ts index 5e9aceb30..6c6c2ff98 100644 --- a/plugins/deepgram/src/tts.ts +++ b/plugins/deepgram/src/tts.ts @@ -46,6 +46,14 @@ export class TTS extends tts.TTS { private opts: TTSOptions; label = 'deepgram.TTS'; + get model(): string { + return this.opts.model; + } + + get provider(): string { + return 'Deepgram'; + } + constructor(opts: Partial = {}) { super(opts.sampleRate || defaultTTSOptions.sampleRate, NUM_CHANNELS, { streaming: opts.capabilities?.streaming ?? defaultTTSOptions.capabilities.streaming, diff --git a/plugins/google/src/llm.ts b/plugins/google/src/llm.ts index 7b9d4b4ac..cea76ac52 100644 --- a/plugins/google/src/llm.ts +++ b/plugins/google/src/llm.ts @@ -51,6 +51,13 @@ export class LLM extends llm.LLM { return this.#opts.model; } + get provider(): string { + if (this.#opts.vertexai) { + return 'Vertex AI'; + } + return 'Gemini'; + } + /** * Create a new instance of Google GenAI LLM. 
* diff --git a/plugins/livekit/src/turn_detector/base.ts b/plugins/livekit/src/turn_detector/base.ts index 0b31ed80a..b9fff38b9 100644 --- a/plugins/livekit/src/turn_detector/base.ts +++ b/plugins/livekit/src/turn_detector/base.ts @@ -170,6 +170,14 @@ export abstract class EOUModel { #logger = log(); + get model(): string { + return MODEL_REVISIONS[this.modelType]; + } + + get provider(): string { + return 'livekit'; + } + constructor(opts: EOUModelOptions) { const { modelType = 'en', diff --git a/plugins/openai/src/llm.ts b/plugins/openai/src/llm.ts index 22299344a..f4d055506 100644 --- a/plugins/openai/src/llm.ts +++ b/plugins/openai/src/llm.ts @@ -86,6 +86,15 @@ export class LLM extends llm.LLM { return this.#opts.model; } + get provider(): string { + try { + const url = new URL(this.#client.baseURL); + return url.host; + } catch { + return 'api.openai.com'; + } + } + /** * Create a new instance of OpenAI LLM with Azure. * diff --git a/plugins/openai/src/realtime/realtime_model.ts b/plugins/openai/src/realtime/realtime_model.ts index 1aaffd014..2711bff5b 100644 --- a/plugins/openai/src/realtime/realtime_model.ts +++ b/plugins/openai/src/realtime/realtime_model.ts @@ -144,6 +144,15 @@ export class RealtimeModel extends llm.RealtimeModel { return this._options.model; } + get provider(): string { + try { + const url = new URL(this._options.baseURL); + return url.host; + } catch { + return 'api.openai.com'; + } + } + constructor( options: { model?: string; diff --git a/plugins/openai/src/stt.ts b/plugins/openai/src/stt.ts index ef2f32aea..2933c62fd 100644 --- a/plugins/openai/src/stt.ts +++ b/plugins/openai/src/stt.ts @@ -28,6 +28,19 @@ export class STT extends stt.STT { #client: OpenAI; label = 'openai.STT'; + get model(): string { + return this.#opts.model; + } + + get provider(): string { + try { + const url = new URL(this.#client.baseURL); + return url.host; + } catch { + return 'api.openai.com'; + } + } + /** * Create a new instance of OpenAI STT. 
* diff --git a/plugins/openai/src/tts.ts b/plugins/openai/src/tts.ts index 2bb77c3d5..3bce9d501 100644 --- a/plugins/openai/src/tts.ts +++ b/plugins/openai/src/tts.ts @@ -32,6 +32,19 @@ export class TTS extends tts.TTS { label = 'openai.TTS'; private abortController = new AbortController(); + get model(): string { + return this.#opts.model; + } + + get provider(): string { + try { + const url = new URL(this.#client.baseURL); + return url.host; + } catch { + return 'api.openai.com'; + } + } + /** * Create a new instance of OpenAI TTS. * diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 94cda9f29..1ebb9eab0 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -169,6 +169,9 @@ importers: livekit-server-sdk: specifier: ^2.14.1 version: 2.14.1 + ofetch: + specifier: ^1.5.1 + version: 1.5.1 openai: specifier: ^6.8.1 version: 6.8.1(ws@8.18.3)(zod@3.25.76) @@ -3082,6 +3085,9 @@ packages: resolution: {integrity: sha512-0je+qPKHEMohvfRTCEo3CrPG6cAzAYgmzKyxRiYSSDkS6eGJdyVJm7WaYA5ECaAD9wLB2T4EEeymA5aFVcYXCA==} engines: {node: '>=6'} + destr@2.0.5: + resolution: {integrity: sha512-ugFTXCtDZunbzasqBxrK93Ik/DRYsO6S/fedkWEMKqt04xZ4csmnmwGDBAb07QWNaGMAmnTIemsYZCksjATwsA==} + detect-indent@6.1.0: resolution: {integrity: sha512-reYkTUJAZb9gUuZ2RvVCNhVHdg62RHnJ7WJl8ftMi4diZ6NWlciOzQN88pUhSELEwflJht4oQDv0F0BMlwaYtA==} engines: {node: '>=8'} @@ -4213,6 +4219,9 @@ packages: engines: {node: '>=10.5.0'} deprecated: Use your platform's native DOMException instead + node-fetch-native@1.6.7: + resolution: {integrity: sha512-g9yhqoedzIUm0nTnTqAQvueMPVOuIY16bqgAJJC8XOOubYFNwz6IER9qs0Gq2Xd0+CecCKFjtdDTMA4u4xG06Q==} + node-fetch@2.7.0: resolution: {integrity: sha512-c4FRfUm/dbcWZ7U+1Wq0AwCyFL+3nt2bEw05wfxSz+DWpWsitgmSgYmy2dQdWyKC1694ELPqMs/YzUSNozLt8A==} engines: {node: 4.x || >=6.0.0} @@ -4268,6 +4277,9 @@ packages: obug@2.1.1: resolution: {integrity: sha512-uTqF9MuPraAQ+IsnPf366RG4cP9RtUi7MLO1N3KEc+wb0a6yKpeL0lmk2IB1jY5KHPAlTc6T/JRdC/YqxHNwkQ==} + ofetch@1.5.1: + resolution: {integrity: 
sha512-2W4oUZlVaqAPAil6FUg/difl6YhqhUR7x2eZY4bQCko22UXg3hptq9KLQdqFClV+Wu85UX7hNtdGTngi/1BxcA==} + on-exit-leak-free@2.1.2: resolution: {integrity: sha512-0eJJY6hXLGf1udHwfNftBqH+g73EU4B504nZeKpz1sYRKafAghwxEJunB2O7rDZkL4PGfsMVnTXZ2EjibbqcsA==} engines: {node: '>=14.0.0'} @@ -5105,6 +5117,9 @@ packages: ufo@1.5.3: resolution: {integrity: sha512-Y7HYmWaFwPUmkoQCUIAYpKqkOf+SbVj/2fJJZ4RJMCfZp0rTGwRbzQD+HghfnhKOjL9E01okqz+ncJskGYfBNw==} + ufo@1.6.3: + resolution: {integrity: sha512-yDJTmhydvl5lJzBmy/hyOAA0d+aqCBuwl818haVdYCRrWV84o7YyeVm4QlVHStqNrrJSTb6jKuFAVqAFsr+K3Q==} + unbox-primitive@1.0.2: resolution: {integrity: sha512-61pPlCD9h51VoreyJ0BReideM3MDKMKnh6+V9L08331ipq6Q8OFXZYiqP6n/tbHx4s5I9uRhcye6BrbkizkBDw==} @@ -7449,6 +7464,8 @@ snapshots: dequal@2.0.3: {} + destr@2.0.5: {} + detect-indent@6.1.0: {} detect-libc@2.1.2: {} @@ -8785,6 +8802,8 @@ snapshots: node-domexception@1.0.0: {} + node-fetch-native@1.6.7: {} + node-fetch@2.7.0: dependencies: whatwg-url: 5.0.0 @@ -8845,6 +8864,12 @@ snapshots: obug@2.1.1: {} + ofetch@1.5.1: + dependencies: + destr: 2.0.5 + node-fetch-native: 1.6.7 + ufo: 1.6.3 + on-exit-leak-free@2.1.2: {} once@1.4.0: @@ -9814,6 +9839,8 @@ snapshots: ufo@1.5.3: {} + ufo@1.6.3: {} + unbox-primitive@1.0.2: dependencies: call-bind: 1.0.7