Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions livekit-plugins/livekit-plugins-sarvam/livekit/plugins/sarvam/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ class ModelConfig:
supports_prompt: Whether the model accepts prompt parameter.
supports_mode: Whether the model accepts mode parameter.
supports_language: Whether the model accepts language parameter.
supports_vad_params: Whether the model accepts fine-grained VAD parameters.
default_language: Default language code (None = auto-detect).
default_mode: Default mode (None = not applicable).
use_translate_endpoint: Whether to use speech_to_text_translate_streaming endpoint.
Expand All @@ -134,6 +135,7 @@ class ModelConfig:
supports_prompt: bool
supports_mode: bool
supports_language: bool
supports_vad_params: bool
default_language: str | None
default_mode: str | None
use_translate_endpoint: bool
Expand All @@ -146,6 +148,7 @@ class ModelConfig:
supports_prompt=False,
supports_mode=False,
supports_language=True,
supports_vad_params=False,
default_language="unknown",
default_mode=None,
use_translate_endpoint=False,
Expand All @@ -156,6 +159,7 @@ class ModelConfig:
supports_prompt=True,
supports_mode=False,
supports_language=False,
supports_vad_params=False,
default_language=None,
default_mode=None,
use_translate_endpoint=True,
Expand All @@ -166,6 +170,7 @@ class ModelConfig:
supports_prompt=True,
supports_mode=True,
supports_language=True,
supports_vad_params=True,
default_language="en-IN",
default_mode="transcribe",
use_translate_endpoint=False,
Expand Down Expand Up @@ -253,6 +258,14 @@ def _model_supports_mode(model: str) -> bool:
return False


def _model_supports_vad_params(model: str) -> bool:
    """Report whether ``model`` accepts the fine-grained VAD parameters.

    The model's configuration is looked up with ``_get_model_config``; when
    that lookup yields a (truthy) configuration its ``supports_vad_params``
    flag is returned, otherwise the answer is ``False``.
    """
    config = _get_model_config(model)
    return config.supports_vad_params if config else False


class ConnectionState(enum.Enum):
"""WebSocket connection states."""

Expand Down Expand Up @@ -287,6 +300,16 @@ class SarvamSTTOptions:
sample_rate: int = 16000
flush_signal: bool | None = None
input_audio_codec: str | None = None
positive_speech_threshold: float | None = None
negative_speech_threshold: float | None = None
min_speech_frames: int | None = None
first_turn_min_speech_frames: int | None = None
negative_frames_count: int | None = None
negative_frames_window: int | None = None
start_speech_volume_threshold: float | None = None
interrupt_min_speech_frames: int | None = None
pre_speech_pad_frames: int | None = None
num_initial_ignored_frames: int | None = None

def __post_init__(self) -> None:
"""Set URLs based on model if not explicitly provided."""
Expand Down Expand Up @@ -360,6 +383,28 @@ def _build_websocket_url(base_url: str, opts: SarvamSTTOptions) -> str:
if opts.input_audio_codec:
params["input_audio_codec"] = opts.input_audio_codec

if _model_supports_vad_params(opts.model):
if opts.positive_speech_threshold is not None:
params["positive_speech_threshold"] = str(opts.positive_speech_threshold)
if opts.negative_speech_threshold is not None:
params["negative_speech_threshold"] = str(opts.negative_speech_threshold)
if opts.min_speech_frames is not None:
params["min_speech_frames"] = str(opts.min_speech_frames)
if opts.first_turn_min_speech_frames is not None:
params["first_turn_min_speech_frames"] = str(opts.first_turn_min_speech_frames)
if opts.negative_frames_count is not None:
params["negative_frames_count"] = str(opts.negative_frames_count)
if opts.negative_frames_window is not None:
params["negative_frames_window"] = str(opts.negative_frames_window)
if opts.start_speech_volume_threshold is not None:
params["start_speech_volume_threshold"] = str(opts.start_speech_volume_threshold)
if opts.interrupt_min_speech_frames is not None:
params["interrupt_min_speech_frames"] = str(opts.interrupt_min_speech_frames)
if opts.pre_speech_pad_frames is not None:
params["pre_speech_pad_frames"] = str(opts.pre_speech_pad_frames)
if opts.num_initial_ignored_frames is not None:
params["num_initial_ignored_frames"] = str(opts.num_initial_ignored_frames)

return f"{base_url}?{urlencode(params)}"


Expand Down Expand Up @@ -425,6 +470,16 @@ def __init__(
sample_rate: int = 16000,
flush_signal: bool | None = None,
input_audio_codec: str | None = None,
positive_speech_threshold: float | None = None,
negative_speech_threshold: float | None = None,
min_speech_frames: int | None = None,
first_turn_min_speech_frames: int | None = None,
negative_frames_count: int | None = None,
negative_frames_window: int | None = None,
start_speech_volume_threshold: float | None = None,
interrupt_min_speech_frames: int | None = None,
pre_speech_pad_frames: int | None = None,
num_initial_ignored_frames: int | None = None,
) -> None:
super().__init__(
capabilities=stt.STTCapabilities(
Expand Down Expand Up @@ -453,6 +508,16 @@ def __init__(
sample_rate=sample_rate,
flush_signal=flush_signal,
input_audio_codec=input_audio_codec,
positive_speech_threshold=positive_speech_threshold,
negative_speech_threshold=negative_speech_threshold,
min_speech_frames=min_speech_frames,
first_turn_min_speech_frames=first_turn_min_speech_frames,
negative_frames_count=negative_frames_count,
negative_frames_window=negative_frames_window,
start_speech_volume_threshold=start_speech_volume_threshold,
interrupt_min_speech_frames=interrupt_min_speech_frames,
pre_speech_pad_frames=pre_speech_pad_frames,
num_initial_ignored_frames=num_initial_ignored_frames,
)
self._session = http_session
self._logger = logger.getChild(self.__class__.__name__)
Expand Down Expand Up @@ -668,6 +733,16 @@ def stream(
sample_rate: NotGivenOr[int] = NOT_GIVEN,
flush_signal: NotGivenOr[bool] = NOT_GIVEN,
input_audio_codec: NotGivenOr[str] = NOT_GIVEN,
positive_speech_threshold: NotGivenOr[float] = NOT_GIVEN,
negative_speech_threshold: NotGivenOr[float] = NOT_GIVEN,
min_speech_frames: NotGivenOr[int] = NOT_GIVEN,
first_turn_min_speech_frames: NotGivenOr[int] = NOT_GIVEN,
negative_frames_count: NotGivenOr[int] = NOT_GIVEN,
negative_frames_window: NotGivenOr[int] = NOT_GIVEN,
start_speech_volume_threshold: NotGivenOr[float] = NOT_GIVEN,
interrupt_min_speech_frames: NotGivenOr[int] = NOT_GIVEN,
pre_speech_pad_frames: NotGivenOr[int] = NOT_GIVEN,
num_initial_ignored_frames: NotGivenOr[int] = NOT_GIVEN,
) -> SpeechStream:
"""Create a streaming transcription session."""
opts_language, opts_model, opts_mode = self._resolve_opts(
Expand All @@ -689,6 +764,54 @@ def stream(
opts_input_codec = (
input_audio_codec if is_given(input_audio_codec) else self._opts.input_audio_codec
)
opts_positive_speech = (
positive_speech_threshold
if is_given(positive_speech_threshold)
else self._opts.positive_speech_threshold
)
opts_negative_speech = (
negative_speech_threshold
if is_given(negative_speech_threshold)
else self._opts.negative_speech_threshold
)
opts_min_speech = (
min_speech_frames if is_given(min_speech_frames) else self._opts.min_speech_frames
)
opts_first_turn = (
first_turn_min_speech_frames
if is_given(first_turn_min_speech_frames)
else self._opts.first_turn_min_speech_frames
)
opts_neg_count = (
negative_frames_count
if is_given(negative_frames_count)
else self._opts.negative_frames_count
)
opts_neg_window = (
negative_frames_window
if is_given(negative_frames_window)
else self._opts.negative_frames_window
)
opts_vol_threshold = (
start_speech_volume_threshold
if is_given(start_speech_volume_threshold)
else self._opts.start_speech_volume_threshold
)
opts_interrupt = (
interrupt_min_speech_frames
if is_given(interrupt_min_speech_frames)
else self._opts.interrupt_min_speech_frames
)
opts_pre_pad = (
pre_speech_pad_frames
if is_given(pre_speech_pad_frames)
else self._opts.pre_speech_pad_frames
)
opts_initial_ignored = (
num_initial_ignored_frames
if is_given(num_initial_ignored_frames)
else self._opts.num_initial_ignored_frames
)
single_attempt_conn_options = self._single_attempt_conn_options(conn_options)

# Create options for the stream
Expand All @@ -702,6 +825,16 @@ def stream(
sample_rate=opts_sample_rate,
flush_signal=opts_flush_signal,
input_audio_codec=opts_input_codec,
positive_speech_threshold=opts_positive_speech,
negative_speech_threshold=opts_negative_speech,
min_speech_frames=opts_min_speech,
first_turn_min_speech_frames=opts_first_turn,
negative_frames_count=opts_neg_count,
negative_frames_window=opts_neg_window,
start_speech_volume_threshold=opts_vol_threshold,
interrupt_min_speech_frames=opts_interrupt,
pre_speech_pad_frames=opts_pre_pad,
num_initial_ignored_frames=opts_initial_ignored,
)

# Create a fresh session for this stream to avoid conflicts
Expand Down