33 changes: 21 additions & 12 deletions sdk/voice/speechmatics/voice/_audio.py
@@ -16,11 +16,11 @@ class AudioBuffer:
frame_size and total_seconds. As the buffer fills, the oldest
data is removed and the start_time is updated.

The function get_slice(start_time, end_time) will return a snapshot
of the data between the start_time and end_time. If the start_time is
before the start of the buffer, then the start_time will be set to the
start of the buffer. If the end_time is after the end of the buffer,
then the end_time will be set to the end of the buffer.
The function get_frames(start_time, end_time) will return a snapshot
of the data between the start_time and end_time, with optional fade-out.
If the start_time is before the start of the buffer, then the start_time
will be set to the start of the buffer. If the end_time is after the end
of the buffer, then the end_time will be set to the end of the buffer.

Timing is based on the number of bytes added to the buffer.
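A minimal usage sketch of the rolling window described here; the constructor arguments, the fade_out keyword, and awaiting get_frames are assumptions drawn from this docstring rather than the full class:

buffer = AudioBuffer(sample_rate=16000, frame_size=160, total_seconds=30.0)
await buffer.put_bytes(b"\x00\x00" * 160)                      # one 16-bit frame; oldest data drops once the window is full
frames = await buffer.get_frames(start_time=2.0, end_time=4.5, fade_out=0.01)
# Times outside the buffered window are clamped to its start and end.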

@@ -90,7 +90,8 @@ async def put_bytes(self, data: bytes) -> None:
data: The data frame to add to the buffer.
"""

# If the right length and buffer zero
# If data is exactly one frame and there's no buffered remainder,
# put the frame directly into the buffer.
if len(data) // self._sample_width == self._frame_size and len(self._buffer) == 0:
return await self.put_frame(data)

@@ -109,19 +110,23 @@ async def put_bytes(self, data: bytes) -> None:
await self.put_frame(frame)

async def put_frame(self, data: bytes) -> None:
"""Add data to the buffer.
"""Add data frame to the buffer.

New data added to the end of the buffer. The oldest data is removed
A new data frame is added to the end of the buffer. The oldest data is removed
to maintain the total number of seconds in the buffer.

Args:
data: The data frame to add to the buffer.
"""
# Verify number of bytes matches frame size
if len(data) != self._frame_bytes:
raise ValueError(f"Invalid frame size: {len(data)} bytes, expected {self._frame_bytes} bytes")

# Add data to the buffer
async with self._lock:
self._frames.append(data)
self._total_frames += 1
# Trim to rolling window, keep last _max_frames frames
if len(self._frames) > self._max_frames:
self._frames = self._frames[-self._max_frames :]

@@ -192,6 +197,7 @@ def _fade_out_audio(self, data: bytes, fade_out: float = 0.01) -> bytes:
Bytes with fade-out applied.
"""
# Choose dtype
# Todo - establish supported sample_width values
dtype: type[np.signedinteger]
if self._sample_width == 1:
dtype = np.int8
@@ -212,11 +218,14 @@ def _fade_out_audio(self, data: bytes, fade_out: float = 0.01) -> bytes:
envelope = np.linspace(1.0, 0.0, fade_samples, endpoint=True)

# Apply fade
faded = samples.astype(np.float32)
faded[-fade_samples:] *= envelope
# Only convert the section being modified to save memory
tail = samples[-fade_samples:].astype(np.float32) * envelope

# Robust Conversion: Round to nearest integer and clip to valid range to avoid wraparound
info = np.iinfo(dtype)
faded_tail = np.round(tail).clip(info.min, info.max).astype(dtype)

# Convert back to original dtype and bytes
return bytes(faded.astype(dtype).tobytes())
return samples[:-fade_samples].tobytes() + faded_tail.tobytes()

async def reset(self) -> None:
"""Reset the buffer."""
1 change: 1 addition & 0 deletions sdk/voice/speechmatics/voice/_client.py
@@ -354,6 +354,7 @@ def __init__(

# Audio sampling info
self._audio_sample_rate: int = self._audio_format.sample_rate
# Todo - establish supported sample_width values
self._audio_sample_width: int = {
AudioEncoding.PCM_F32LE: 4,
AudioEncoding.PCM_S16LE: 2,
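For orientation, the sample width chosen here drives byte counts downstream; the 10 ms frame length below is an illustrative assumption, not a value from this diff:

sample_width = 2                              # PCM_S16LE: two bytes per sample
frame_samples = 16000 // 100                  # assumed 10 ms frame at 16 kHz
frame_bytes = frame_samples * sample_width    # 320 bytes per frame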
4 changes: 2 additions & 2 deletions sdk/voice/speechmatics/voice/_models.py
@@ -142,9 +142,9 @@ class AgentServerMessageType(str, Enum):
StartOfTurn: Start of turn has been detected.
EndOfTurnPrediction: End of turn prediction timing.
EndOfTurn: End of turn has been detected.
SmartTurn: Smart turn metadata.
SmartTurnResult: Smart turn metadata.
SpeakersResult: Speakers result has been detected.
Metrics: Metrics for the STT engine.
SessionMetrics: Metrics for the STT engine.
SpeakerMetrics: Metrics relating to speakers.

Examples:
110 changes: 73 additions & 37 deletions sdk/voice/speechmatics/voice/_smart_turn.py
@@ -78,7 +78,12 @@ class SmartTurnDetector:
Further information at https://github.com/pipecat-ai/smart-turn
"""

def __init__(self, auto_init: bool = True, threshold: float = 0.8):
# Constants
DEFAULT_SAMPLE_RATE = 16000
DEFAULT_THRESHOLD = 0.8
WINDOW_SECONDS = 8

def __init__(self, auto_init: bool = True, threshold: float = DEFAULT_THRESHOLD):
"""Create the new SmartTurnDetector.

Args:
@@ -125,7 +130,7 @@ def setup(self) -> None:
self.session = self.build_session(SMART_TURN_MODEL_LOCAL_PATH)

# Load the feature extractor
self.feature_extractor = WhisperFeatureExtractor(chunk_length=8)
self.feature_extractor = WhisperFeatureExtractor(chunk_length=self.WINDOW_SECONDS)

# Set initialized
self._is_initialized = True
@@ -156,83 +161,113 @@ def build_session(self, onnx_path: str) -> ort.InferenceSession:
# Return the new session
return ort.InferenceSession(onnx_path, sess_options=so)

async def predict(
self, audio_array: bytes, language: str, sample_rate: int = 16000, sample_width: int = 2
) -> SmartTurnPredictionResult:
"""Predict whether an audio segment is complete (turn ended) or incomplete.
def _prepare_audio(self, audio_array: bytes, sample_rate: int, sample_width: int) -> np.ndarray:
"""Prepare the audio for inference.

Args:
audio_array: Numpy array containing audio samples at 16kHz. The function
will convert the audio into float32 and truncate to 8 seconds (keeping the end)
or pad to 8 seconds.
language: Language of the audio.
audio_array: Raw PCM bytes at 16kHz. The function converts the audio into float32 and
truncates to WINDOW_SECONDS seconds (keeping the end).
sample_rate: Sample rate of the audio.
sample_width: Sample width of the audio.

Returns:
Prediction result containing completion status and probability.
Numpy array containing audio samples at DEFAULT_SAMPLE_RATE.
"""

# Check if initialized
if not self._is_initialized:
return SmartTurnPredictionResult(error="SmartTurnDetector is not initialized")

# Check a valid language
if not self.valid_language(language):
logger.warning(f"Invalid language: {language}. Results may be unreliable.")

# Record start time
start_time = datetime.datetime.now()

# Todo - fix support for other sample widths
# Convert into numpy array
dtype = np.int16 if sample_width == 2 else np.int8
int16_array: np.ndarray = np.frombuffer(audio_array, dtype=dtype).astype(np.int16)

# Truncate to last 8 seconds if needed (keep the tail/end of audio)
max_samples = 8 * sample_rate
# Truncate to last WINDOW_SECONDS seconds if needed (keep the tail/end of audio)
max_samples = self.WINDOW_SECONDS * sample_rate
if len(int16_array) > max_samples:
int16_array = int16_array[-max_samples:]

# Convert int16 to float32 in range [-1, 1] (same as reference implementation)
float32_array: np.ndarray = int16_array.astype(np.float32) / 32768.0

# Process audio using Whisper's feature extractor
return float32_array

def _get_input_features(self, audio_data: np.ndarray, sample_rate: int) -> np.ndarray:
"""
Get the input features for the audio data using Whisper's feature extractor.

Args:
audio_data: Numpy array containing audio samples.
sample_rate: Sample rate of the audio.
"""

inputs = self.feature_extractor(
float32_array,
audio_data,
sampling_rate=sample_rate,
return_tensors="np",
padding="max_length",
max_length=max_samples,
max_length=self.WINDOW_SECONDS * sample_rate,
truncation=True,
do_normalize=True,
)

# Extract features and ensure correct shape for ONNX
# Ensure dimensions are the correct shape for ONNX
input_features = inputs.input_features.squeeze(0).astype(np.float32)
input_features = np.expand_dims(input_features, axis=0)

# Run ONNX inference
outputs = self.session.run(None, {"input_features": input_features})
return input_features

# Extract probability (ONNX model returns sigmoid probabilities)
async def predict(
self, audio_array: bytes, language: str, sample_rate: int = DEFAULT_SAMPLE_RATE, sample_width: int = 2
) -> SmartTurnPredictionResult:
"""Predict whether an audio segment is complete (turn ended) or incomplete.

Args:
audio_array: Raw PCM bytes at sample_rate. The function
will convert the audio into float32 and truncate to WINDOW_SECONDS seconds (keeping the end)
or pad to WINDOW_SECONDS seconds.
language: Language of the audio.
sample_rate: Sample rate of the audio.
sample_width: Sample width of the audio.

Returns:
Prediction result containing completion status and probability.
"""

# Check if initialized
if not self._is_initialized:
return SmartTurnPredictionResult(error="SmartTurnDetector is not initialized")

# Check a valid language
if not self.valid_language(language):
logger.warning(f"Invalid language: {language}. Results may be unreliable.")

# Record start time
start_time = datetime.datetime.now()

# Convert the audio into required format
prepared_audio = self._prepare_audio(audio_array, sample_rate, sample_width)

# Feature extraction
input_features = self._get_input_features(prepared_audio, sample_rate)

# Model inference
outputs = self.session.run(None, {"input_features": input_features})
probability = outputs[0][0].item()

# Make prediction (True for Complete, False for Incomplete)
prediction = probability >= self._threshold

# Record end time
# Result Formatting
end_time = datetime.datetime.now()
duration = float((end_time - start_time).total_seconds())

# Return the result
return SmartTurnPredictionResult(
prediction=prediction,
probability=round(probability, 3),
processing_time=round(float((end_time - start_time).total_seconds()), 3),
processing_time=round(duration, 3),
)

@staticmethod
def truncate_audio_to_last_n_seconds(
audio_array: np.ndarray, n_seconds: float = 8.0, sample_rate: int = 16000
audio_array: np.ndarray, n_seconds: float = WINDOW_SECONDS, sample_rate: int = DEFAULT_SAMPLE_RATE
) -> np.ndarray:
"""Truncate audio to last n seconds or pad with zeros to meet n seconds.

Expand Down Expand Up @@ -270,7 +305,7 @@ def download_model() -> None:
If not, it will download the model from HuggingFace.
"""

# Check if model file exists
# Check if model file already exists
if SmartTurnDetector.model_exists():
return

@@ -300,7 +335,8 @@ def model_exists() -> bool:

@staticmethod
def valid_language(language: str) -> bool:
"""Check if the language is valid.
"""Check if the language is valid against list of supported languages
for the Pipecat model.

Args:
language: Language code to validate.
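Taken together, the refactor splits inference into prepare audio, extract features, run ONNX, and apply the threshold. A condensed, hedged sketch of that path (the model path and the synchronous call are assumptions; the feature-extractor arguments and the "input_features" input name mirror the diff):

import numpy as np
import onnxruntime as ort
from transformers import WhisperFeatureExtractor

SAMPLE_RATE = 16000          # DEFAULT_SAMPLE_RATE in the diff
WINDOW_SECONDS = 8

extractor = WhisperFeatureExtractor(chunk_length=WINDOW_SECONDS)
session = ort.InferenceSession("smart-turn-v3.onnx")   # assumed local model path

def end_of_turn_probability(pcm_s16le: bytes) -> float:
    # bytes -> int16 -> keep only the last WINDOW_SECONDS -> float32 in [-1, 1]
    samples = np.frombuffer(pcm_s16le, dtype=np.int16)[-WINDOW_SECONDS * SAMPLE_RATE:]
    audio = samples.astype(np.float32) / 32768.0
    # Whisper log-mel features, padded/truncated to the fixed window
    inputs = extractor(audio, sampling_rate=SAMPLE_RATE, return_tensors="np",
                       padding="max_length", max_length=WINDOW_SECONDS * SAMPLE_RATE,
                       truncation=True, do_normalize=True)
    features = inputs.input_features.astype(np.float32)   # shape (1, n_mels, frames)
    return session.run(None, {"input_features": features})[0][0].item()

# A turn is considered complete when the probability clears the threshold (0.8 by default).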
32 changes: 20 additions & 12 deletions sdk/voice/speechmatics/voice/_utils.py
@@ -16,6 +16,13 @@
from ._models import SpeakerSegmentView
from ._models import SpeechFragment

# Constants
PAUSE_MIN_GAP_S = 0.1 # minimum gap in seconds to consider a pause
WPM_RECENT_WORD_WINDOW = 10 # number of recent words to estimate WPM from
WPM_VERY_SLOW_MAX = 80 # wpm < 80 => VERY_SLOW_SPEAKER
WPM_SLOW_MAX = 110 # 80 <= wpm < 110 => SLOW_SPEAKER
WPM_FAST_MIN = 250 # wpm > 250 => FAST_SPEAKER


class FragmentUtils:
"""Set of utility functions for working with SpeechFragment and SpeakerSegment objects."""
@@ -110,7 +117,7 @@ def segment_list_from_fragments(
speaker_groups.append([])
speaker_groups[-1].append(frag)

# Create SpeakerFragments objects
# Create SpeakerSegment objects
segments: list[SpeakerSegment] = []
for group in speaker_groups:
# Skip if the group is empty
@@ -143,7 +150,7 @@ def segment_list_from_fragments(
FragmentUtils.update_segment_text(session=session, segment=segment)
segments.append(segment)

# Return the grouped SpeakerFragments objects
# Return the grouped SpeakerSegment objects
return segments

@staticmethod
@@ -288,17 +295,18 @@ def _annotate_segment(segment: SpeakerSegment) -> AnnotationResult:
# Rate of speech
if len(words) > 1:
# Calculate the approximate words-per-minute (for last few words)
recent_words = words[-10:]
recent_words = words[-WPM_RECENT_WORD_WINDOW:]
word_time_span = recent_words[-1].end_time - recent_words[0].start_time
wpm = (len(recent_words) / word_time_span) * 60
if word_time_span != 0:
wpm = (len(recent_words) / word_time_span) * 60

# Categorize the speaker
if wpm < 80:
result.add(AnnotationFlags.VERY_SLOW_SPEAKER)
elif wpm < 110:
result.add(AnnotationFlags.SLOW_SPEAKER)
elif wpm > 250:
result.add(AnnotationFlags.FAST_SPEAKER)
# Categorize the speaker
if wpm < WPM_VERY_SLOW_MAX:
result.add(AnnotationFlags.VERY_SLOW_SPEAKER)
elif wpm < WPM_SLOW_MAX:
result.add(AnnotationFlags.SLOW_SPEAKER)
elif wpm > WPM_FAST_MIN:
result.add(AnnotationFlags.FAST_SPEAKER)

# Return the annotation result
return result
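A hypothetical worked example of the thresholds above (numbers are illustrative):

word_time_span = 4.0                   # seconds spanned by the last 10 words
wpm = (10 / word_time_span) * 60       # 150 wpm: neither SLOW_SPEAKER (<110) nor FAST_SPEAKER (>250)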
@@ -400,7 +408,7 @@ def find_segment_pauses(session: ClientSessionInfo, view: SpeakerSegmentView) ->
next_word = words[i + 1]
gap_start = word.end_time
gap_end = next_word.start_time
if gap_end - gap_start > 0.1:
if gap_end - gap_start > PAUSE_MIN_GAP_S:
segment.fragments.append(
SpeechFragment(
idx=word.idx + 1,
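Illustrative numbers for the pause rule above (values are made up):

gap_start = 3.20                                     # previous word ends
gap_end = 3.45                                       # next word starts
is_pause = (gap_end - gap_start) > PAUSE_MIN_GAP_S   # 0.25 s > 0.1 s -> a pause fragment is inserted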