From db0d9310558d7ace73dfef3dccae9359ec36d173 Mon Sep 17 00:00:00 2001 From: phoenixray2000 Date: Wed, 27 May 2026 00:45:11 +0800 Subject: [PATCH 1/7] Add voiceprint enrollment helpers --- aTrain/voiceprint_cli.py | 125 +++++++++++++++++++++++++++++++++++ tests/test_voiceprint_cli.py | 92 ++++++++++++++++++++++++++ 2 files changed, 217 insertions(+) create mode 100644 aTrain/voiceprint_cli.py create mode 100644 tests/test_voiceprint_cli.py diff --git a/aTrain/voiceprint_cli.py b/aTrain/voiceprint_cli.py new file mode 100644 index 0000000..226c5de --- /dev/null +++ b/aTrain/voiceprint_cli.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import numpy as np +from aTrain_core.load_resources import get_model + +from aTrain.voiceprint_identification import extract_embedding +from aTrain.voiceprints import ( + EMBEDDING_MODEL_ID, + VOICEPRINT_SCHEMA_VERSION, + VoiceprintProfile, + load_voiceprint, + merge_centroid, + save_voiceprint, + validate_voiceprint_name, +) + + +def _utc_now() -> str: + return datetime.now(timezone.utc).isoformat() + + +def _normalise_vector(vector: np.ndarray) -> np.ndarray: + array = np.asarray(vector, dtype=np.float32) + if array.ndim != 1 or array.size == 0: + raise ValueError("Voiceprint embedding must be a non-empty 1-D vector.") + norm = float(np.linalg.norm(array)) + if norm <= 0: + raise ValueError("Voiceprint embedding cannot be a zero vector.") + return array / norm + + +def read_embedding_file(path: Path) -> tuple[list[str], np.ndarray]: + if not path.exists(): + raise FileNotFoundError(f"Speaker embedding file does not exist: {path}") + with np.load(path, allow_pickle=False) as payload: + labels = [str(item) for item in payload["labels"].tolist()] + embeddings = np.asarray(payload["embeddings"], dtype=np.float32) + if embeddings.ndim != 2 or len(labels) != embeddings.shape[0]: + raise ValueError("Speaker embedding file is invalid: labels and embeddings do not match.") + return labels, embeddings + + +def _save_new_or_update( + name: str, embedding: np.ndarray, enrollment: dict[str, Any], update: bool +) -> VoiceprintProfile: + cleaned_name = validate_voiceprint_name(name) + vector = _normalise_vector(embedding) + try: + existing = load_voiceprint(cleaned_name) + except FileNotFoundError: + profile = VoiceprintProfile( + name=cleaned_name, + model_id=EMBEDDING_MODEL_ID, + schema_version=VOICEPRINT_SCHEMA_VERSION, + embedding_dim=int(vector.shape[0]), + embedding=vector, + enrollments=[enrollment], + ) + save_voiceprint(profile) + return profile + + if not update: + raise FileExistsError(f"Voiceprint already exists: {cleaned_name}. Use --update to merge a new sample.") + merged = merge_centroid(existing.embedding, vector, max(1, len(existing.enrollments))) + profile = VoiceprintProfile( + name=existing.name, + model_id=existing.model_id, + schema_version=existing.schema_version, + embedding_dim=existing.embedding_dim, + embedding=merged, + enrollments=[*existing.enrollments, enrollment], + ) + save_voiceprint(profile) + return profile + + +def enroll_voiceprint_from_audio( + audio_path: Path, + name: str, + update: bool, + device, + min_duration_sec: float = 3.0, +) -> VoiceprintProfile: + model_path = get_model("speaker-detection") + embedding = extract_embedding(audio_path, model_path, device, min_duration_sec=min_duration_sec) + return _save_new_or_update( + name, + embedding, + { + "source_type": "audio", + "source_path": str(audio_path), + "created_at": _utc_now(), + }, + update, + ) + + +def enroll_voiceprint_from_speaker_embedding( + embedding_file: Path, + speaker_label: str, + name: str, + update: bool, + source: str | None = None, +) -> VoiceprintProfile: + labels, embeddings = read_embedding_file(embedding_file) + try: + index = labels.index(speaker_label) + except ValueError as error: + raise ValueError(f"Speaker label not found in embedding file: {speaker_label}") from error + return _save_new_or_update( + name, + embeddings[index], + { + "source_type": "speaker_embedding", + "source_path": str(embedding_file), + "speaker_label": speaker_label, + "source": source, + "created_at": _utc_now(), + }, + update, + ) diff --git a/tests/test_voiceprint_cli.py b/tests/test_voiceprint_cli.py new file mode 100644 index 0000000..e56f7c4 --- /dev/null +++ b/tests/test_voiceprint_cli.py @@ -0,0 +1,92 @@ +import unittest +from pathlib import Path +from unittest.mock import patch + +import numpy as np + +from aTrain.voiceprints import ( + EMBEDDING_MODEL_ID, + VOICEPRINT_SCHEMA_VERSION, + VoiceprintProfile, +) + + +class VoiceprintCliTests(unittest.TestCase): + def test_audio_enrollment_creates_new_profile(self): + from aTrain.voiceprint_cli import enroll_voiceprint_from_audio + + vector = np.array([1.0, 0.0], dtype=np.float32) + with patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), \ + patch("aTrain.voiceprint_cli.extract_embedding", return_value=vector), \ + patch("aTrain.voiceprint_cli.load_voiceprint", side_effect=FileNotFoundError("missing")), \ + patch("aTrain.voiceprint_cli.save_voiceprint") as save: + profile = enroll_voiceprint_from_audio(Path("sample.wav"), "李想", update=False, device="CPU") + + self.assertEqual(profile.name, "李想") + self.assertEqual(profile.schema_version, VOICEPRINT_SCHEMA_VERSION) + self.assertEqual(profile.model_id, EMBEDDING_MODEL_ID) + self.assertEqual(profile.embedding_dim, 2) + self.assertEqual(profile.enrollments[0]["source_type"], "audio") + save.assert_called_once() + + def test_existing_profile_without_update_fails(self): + from aTrain.voiceprint_cli import enroll_voiceprint_from_audio + + existing = VoiceprintProfile( + name="李想", + model_id=EMBEDDING_MODEL_ID, + schema_version=VOICEPRINT_SCHEMA_VERSION, + embedding_dim=2, + embedding=np.array([1.0, 0.0], dtype=np.float32), + enrollments=[], + ) + with patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), \ + patch("aTrain.voiceprint_cli.extract_embedding", return_value=np.array([0.0, 1.0], dtype=np.float32)), \ + patch("aTrain.voiceprint_cli.load_voiceprint", return_value=existing): + with self.assertRaises(FileExistsError): + enroll_voiceprint_from_audio(Path("sample.wav"), "李想", update=False, device="CPU") + + def test_existing_profile_with_update_merges_centroid(self): + from aTrain.voiceprint_cli import enroll_voiceprint_from_audio + + existing = VoiceprintProfile( + name="李想", + model_id=EMBEDDING_MODEL_ID, + schema_version=VOICEPRINT_SCHEMA_VERSION, + embedding_dim=2, + embedding=np.array([1.0, 0.0], dtype=np.float32), + enrollments=[{"source_type": "audio"}], + ) + with patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), \ + patch("aTrain.voiceprint_cli.extract_embedding", return_value=np.array([0.0, 1.0], dtype=np.float32)), \ + patch("aTrain.voiceprint_cli.load_voiceprint", return_value=existing), \ + patch("aTrain.voiceprint_cli.save_voiceprint") as save: + profile = enroll_voiceprint_from_audio(Path("sample.wav"), "李想", update=True, device="CPU") + + np.testing.assert_allclose(profile.embedding, np.array([0.5, 0.5], dtype=np.float32)) + self.assertEqual(len(profile.enrollments), 2) + save.assert_called_once() + + def test_speaker_embedding_enrollment_selects_requested_label(self): + from aTrain.voiceprint_cli import enroll_voiceprint_from_speaker_embedding + + labels = ["SPEAKER_00", "SPEAKER_01"] + embeddings = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32) + with patch("aTrain.voiceprint_cli.read_embedding_file", return_value=(labels, embeddings)), \ + patch("aTrain.voiceprint_cli.load_voiceprint", side_effect=FileNotFoundError("missing")), \ + patch("aTrain.voiceprint_cli.save_voiceprint"): + profile = enroll_voiceprint_from_speaker_embedding(Path("speakers.npz"), "SPEAKER_01", "李想", update=False) + + np.testing.assert_allclose(profile.embedding, np.array([0.0, 1.0], dtype=np.float32)) + self.assertEqual(profile.enrollments[0]["speaker_label"], "SPEAKER_01") + + def test_speaker_embedding_enrollment_rejects_missing_label(self): + from aTrain.voiceprint_cli import enroll_voiceprint_from_speaker_embedding + + with patch("aTrain.voiceprint_cli.read_embedding_file", return_value=(["SPEAKER_00"], np.array([[1.0, 0.0]], dtype=np.float32))): + with self.assertRaisesRegex(ValueError, "Speaker label not found"): + enroll_voiceprint_from_speaker_embedding(Path("speakers.npz"), "SPEAKER_01", "李想", update=False) + + +if __name__ == "__main__": + unittest.main() From ecc796966227c9cf62a46625bcd23da7368e746e Mon Sep 17 00:00:00 2001 From: phoenixray2000 Date: Wed, 27 May 2026 00:46:51 +0800 Subject: [PATCH 2/7] Add voiceprint enroll CLI command --- aTrain/cli.py | 77 ++++++++++++++++++++++++++++++ tests/test_cli_voiceprints.py | 89 +++++++++++++++++++++++++++++++++++ 2 files changed, 166 insertions(+) create mode 100644 tests/test_cli_voiceprints.py diff --git a/aTrain/cli.py b/aTrain/cli.py index 8a9b4ec..034d41d 100644 --- a/aTrain/cli.py +++ b/aTrain/cli.py @@ -23,8 +23,24 @@ check_file, check_inputs_transcribe, ) +from aTrain.voiceprint_cli import ( + enroll_voiceprint_from_audio, + enroll_voiceprint_from_speaker_embedding, +) +from aTrain.voiceprint_identification import ( + patch_core_speaker_capture, + read_captured_embeddings, +) +from aTrain.voiceprints import ( + apply_speaker_map_to_transcript, + assign_voiceprints, + cosine_similarity_matrix, + list_voiceprints, +) cli = typer.Typer(help="CLI for aTrain.", no_args_is_help=True) +voiceprint_cli = typer.Typer(help="Manage speaker voiceprints.", no_args_is_help=True) +cli.add_typer(voiceprint_cli, name="voiceprint") FORMAT_OUTPUTS = { "json": ("transcription.json", "{stem}.json"), @@ -329,6 +345,67 @@ def _print_summary(results: list[FileResult], skipped: list[Path]) -> None: typer.echo(f" - {result.path} - {result.reason}{staging}", err=True) +@voiceprint_cli.command("enroll") +def voiceprint_enroll( + name: Annotated[str, typer.Option("--name", help="Person name for the voiceprint.")], + audio: Annotated[ + Path | None, + typer.Option("--audio", help="Audio sample used for enrollment."), + ] = None, + speaker_embeddings: Annotated[ + Path | None, + typer.Option("--speaker-embeddings", help="NPZ speaker embedding artifact exported by transcribe."), + ] = None, + speaker: Annotated[ + str | None, + typer.Option("--speaker", help="Speaker label to enroll, such as SPEAKER_01."), + ] = None, + update: Annotated[ + bool, + typer.Option("--update", help="Merge into an existing profile."), + ] = False, + source: Annotated[ + str | None, + typer.Option("--source", help="Optional audit source stored with the enrollment."), + ] = None, + device: Annotated[Device, typer.Option(help="Hardware used for audio embedding extraction.")] = Device.CPU, + min_duration_sec: Annotated[ + float, + typer.Option("--min-duration-sec", help="Minimum audio duration accepted for direct audio enrollment.", min=0.1), + ] = 3.0, +): + """Create or update a local speaker voiceprint.""" + try: + source_count = int(audio is not None) + int(speaker_embeddings is not None) + if source_count != 1: + raise ValueError("Provide exactly one of --audio or --speaker-embeddings.") + if speaker_embeddings is not None and not speaker: + raise ValueError("--speaker is required with --speaker-embeddings.") + if audio is not None and speaker: + raise ValueError("--speaker can only be used with --speaker-embeddings.") + if audio is not None: + profile = enroll_voiceprint_from_audio( + audio, + name, + update=update, + device=device, + min_duration_sec=min_duration_sec, + ) + else: + profile = enroll_voiceprint_from_speaker_embedding( + speaker_embeddings, + speaker, + name, + update=update, + source=source, + ) + except Exception as error: + typer.echo(str(error), err=True) + raise typer.Exit(code=2) from error + + typer.echo(f"Voiceprint enrolled: {getattr(profile, 'name', name)}") + + @cli.command() def transcribe( input_path: Annotated[ diff --git a/tests/test_cli_voiceprints.py b/tests/test_cli_voiceprints.py new file mode 100644 index 0000000..24cc17b --- /dev/null +++ b/tests/test_cli_voiceprints.py @@ -0,0 +1,89 @@ +import tempfile +import unittest +from pathlib import Path +from unittest import mock + +from typer.testing import CliRunner + +from aTrain.cli import cli + + +class CliVoiceprintTests(unittest.TestCase): + def test_identify_speakers_requires_speaker_detection(self): + runner = CliRunner() + + result = runner.invoke( + cli, + [ + "transcribe", + "missing.wav", + "--no-speaker-detection", + "--identify-speakers", + ], + ) + + self.assertEqual(result.exit_code, 2) + self.assertIn("requires --speaker-detection", result.stderr) + + def test_voiceprint_enroll_rejects_missing_source(self): + runner = CliRunner() + + result = runner.invoke(cli, ["voiceprint", "enroll", "--name", "李想"]) + + self.assertEqual(result.exit_code, 2) + self.assertIn("Provide exactly one of --audio or --speaker-embeddings", result.stderr) + + def test_voiceprint_enroll_rejects_both_sources(self): + runner = CliRunner() + + result = runner.invoke( + cli, + [ + "voiceprint", + "enroll", + "--name", + "李想", + "--audio", + "sample.wav", + "--speaker-embeddings", + "speakers.npz", + "--speaker", + "SPEAKER_01", + ], + ) + + self.assertEqual(result.exit_code, 2) + self.assertIn("Provide exactly one of --audio or --speaker-embeddings", result.stderr) + + def test_voiceprint_enroll_from_speaker_embedding_calls_helper(self): + runner = CliRunner() + called = {} + + def fake_enroll(path, speaker, name, update, source=None): + called.update(path=path, speaker=speaker, name=name, update=update, source=source) + + with tempfile.TemporaryDirectory() as temp_dir: + embedding_file = Path(temp_dir) / "speakers.npz" + embedding_file.write_bytes(b"npz-stub") + with mock.patch("aTrain.cli.enroll_voiceprint_from_speaker_embedding", side_effect=fake_enroll): + result = runner.invoke( + cli, + [ + "voiceprint", + "enroll", + "--name", + "李想", + "--speaker-embeddings", + str(embedding_file), + "--speaker", + "SPEAKER_01", + "--update", + "--source", + "recording:abc", + ], + ) + + self.assertEqual(result.exit_code, 0) + self.assertEqual(called["speaker"], "SPEAKER_01") + self.assertEqual(called["name"], "李想") + self.assertIs(called["update"], True) From 72efb4bf1cca73e13f6f64f0e52450b718844648 Mon Sep 17 00:00:00 2001 From: phoenixray2000 Date: Wed, 27 May 2026 00:48:39 +0800 Subject: [PATCH 3/7] Export speaker embeddings from CLI transcription --- aTrain/cli.py | 148 +++++++++++++++++++++++++++++++++- tests/test_cli_voiceprints.py | 28 +++++++ 2 files changed, 175 insertions(+), 1 deletion(-) diff --git a/aTrain/cli.py b/aTrain/cli.py index 034d41d..ddfba6a 100644 --- a/aTrain/cli.py +++ b/aTrain/cli.py @@ -1,3 +1,4 @@ +import json import shutil import sys import tempfile @@ -28,6 +29,7 @@ enroll_voiceprint_from_speaker_embedding, ) from aTrain.voiceprint_identification import ( + CAPTURE_FILENAME, patch_core_speaker_capture, read_captured_embeddings, ) @@ -192,6 +194,41 @@ def _copy_outputs( shutil.copy2(source_dir / planned.source_name, planned.target_path) +def _copy_speaker_embeddings( + staging_dir: Path, file_id: str, output_path: Path, overwrite: bool = True +) -> Path: + source = staging_dir / file_id / CAPTURE_FILENAME + if not source.exists(): + raise FileNotFoundError( + "Speaker embeddings were not captured. Use --speaker-detection and --identify-speakers." + ) + if output_path.exists() and not overwrite: + raise FileExistsError( + f"Target file exists: {output_path}. Use --overwrite to replace it." + ) + output_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(source, output_path) + return output_path + + +def _postprocess_staged_outputs( + staging_dir: Path, + file_id: str, + speaker_detection: bool, + speaker_map: dict[str, str] | None, +) -> None: + if not speaker_map: + return + + from aTrain_core import outputs as core_outputs + + transcript_path = staging_dir / file_id / "transcription.json" + with transcript_path.open("r", encoding="utf-8") as handle: + transcript = json.load(handle) + apply_speaker_map_to_transcript(transcript, speaker_map) + core_outputs.create_output_files(transcript, speaker_detection, file_id) + + def _transcribe_one( item: InputFile, output_plan: list[OutputPlan], @@ -204,6 +241,10 @@ def _transcribe_one( compute_type: ComputeType, temperature: float | None, prompt: str | None, + identify_speakers: bool, + voiceprint_threshold: float, + voiceprint_margin: float, + speaker_embeddings_output: Path | None, cpu_threads: int, ) -> Path: for planned in output_plan: @@ -211,6 +252,10 @@ def _transcribe_one( raise FileExistsError( f"Target file exists: {planned.target_path}. Use --overwrite to replace it." ) + if speaker_embeddings_output is not None and speaker_embeddings_output.exists() and not overwrite: + raise FileExistsError( + f"Target file exists: {speaker_embeddings_output}. Use --overwrite to replace it." + ) from aTrain_core import outputs as core_outputs from aTrain_core.transcribe import prepare_transcription @@ -247,7 +292,30 @@ def _transcribe_one( progress={}, cpu_threads=cpu_threads, ) - transcribe_core(settings) + capture_speakers = identify_speakers and speaker_detection + if capture_speakers: + with patch_core_speaker_capture(): + transcribe_core(settings) + else: + transcribe_core(settings) + + speaker_map = _identify_captured_speakers( + staging_dir=staging_dir, + file_id=file_id, + enabled=capture_speakers, + threshold=voiceprint_threshold, + margin=voiceprint_margin, + ) + if speaker_embeddings_output is not None: + _copy_speaker_embeddings( + staging_dir, file_id, speaker_embeddings_output, overwrite=overwrite + ) + _postprocess_staged_outputs( + staging_dir, + file_id, + speaker_detection, + speaker_map, + ) _copy_outputs(staging_dir, file_id, output_plan, overwrite) shutil.rmtree(staging_dir, ignore_errors=True) return staging_dir @@ -259,6 +327,35 @@ def _transcribe_one( shutil.rmtree(gpu_log_dir, ignore_errors=True) +def _identify_captured_speakers( + staging_dir: Path, + file_id: str, + enabled: bool, + threshold: float, + margin: float, +) -> dict[str, str] | None: + if not enabled: + return None + captured = read_captured_embeddings(staging_dir, file_id) + if captured is None: + return None + + labels, embeddings = captured + voiceprints = list_voiceprints() + if not labels or not voiceprints: + return None + + scores = cosine_similarity_matrix(embeddings, voiceprints) + speaker_map = assign_voiceprints( + scores, + labels, + [profile.name for profile in voiceprints], + threshold, + margin, + ) + return speaker_map or None + + def _run_batch( inputs: list[InputFile], skipped: list[Path], @@ -273,6 +370,10 @@ def _run_batch( compute_type: ComputeType, temperature: float | None, prompt: str | None, + identify_speakers: bool, + voiceprint_threshold: float, + voiceprint_margin: float, + speaker_embeddings_output: Path | None, cpu_threads: int, ) -> int: results: list[FileResult] = [] @@ -296,6 +397,10 @@ def _run_batch( compute_type=compute_type, temperature=temperature, prompt=prompt, + identify_speakers=identify_speakers, + voiceprint_threshold=voiceprint_threshold, + voiceprint_margin=voiceprint_margin, + speaker_embeddings_output=speaker_embeddings_output, cpu_threads=cpu_threads, ) elapsed = int(time.monotonic() - started) @@ -427,6 +532,35 @@ def transcribe( int, typer.Option(help="Number of speakers. Use 0 to let aTrain auto-detect."), ] = 0, + identify_speakers: Annotated[ + bool, + typer.Option( + "--identify-speakers/--no-identify-speakers", + help="Rename diarized SPEAKER_xx labels using enrolled voiceprints.", + ), + ] = True, + voiceprint_threshold: Annotated[ + float, + typer.Option( + "--voiceprint-threshold", + help="Minimum cosine similarity required for voiceprint identification.", + min=0.0, + max=1.0, + ), + ] = 0.5, + voiceprint_margin: Annotated[ + float, + typer.Option( + "--voiceprint-margin", + help="Minimum score gap over competing speaker/name assignments.", + min=0.0, + max=1.0, + ), + ] = 0.05, + speaker_embeddings_output: Annotated[ + Path | None, + typer.Option("--speaker-embeddings-output", help="Write captured speaker embeddings to this NPZ path."), + ] = None, device: Annotated[Device, typer.Option(help="Hardware used to transcribe.")] = Device.GPU, compute_type: Annotated[ ComputeType, typer.Option(help="Data type used in computations.") @@ -463,6 +597,14 @@ def transcribe( try: selected_formats = _parse_formats(formats) inputs, skipped = _collect_inputs(input_path, recursive) + if identify_speakers and not speaker_detection: + raise ValueError("--identify-speakers requires --speaker-detection.") + if speaker_embeddings_output is not None and not speaker_detection: + raise ValueError("--speaker-embeddings-output requires --speaker-detection.") + if speaker_embeddings_output is not None and not identify_speakers: + raise ValueError("--speaker-embeddings-output requires --identify-speakers.") + if speaker_embeddings_output is not None and len(inputs) != 1: + raise ValueError("--speaker-embeddings-output supports single-file input only.") _check_model_downloaded(model) if speaker_detection: _check_model_downloaded("speaker-detection") @@ -492,6 +634,10 @@ def transcribe( compute_type=compute_type, temperature=temperature, prompt=prompt, + identify_speakers=identify_speakers, + voiceprint_threshold=voiceprint_threshold, + voiceprint_margin=voiceprint_margin, + speaker_embeddings_output=speaker_embeddings_output, cpu_threads=cpu_threads, ) raise typer.Exit(code=exit_code) diff --git a/tests/test_cli_voiceprints.py b/tests/test_cli_voiceprints.py index 24cc17b..334e74b 100644 --- a/tests/test_cli_voiceprints.py +++ b/tests/test_cli_voiceprints.py @@ -87,3 +87,31 @@ def fake_enroll(path, speaker, name, update, source=None): self.assertEqual(called["speaker"], "SPEAKER_01") self.assertEqual(called["name"], "李想") self.assertIs(called["update"], True) + + def test_copy_speaker_embeddings_writes_requested_output(self): + from aTrain import cli as cli_module + + with tempfile.TemporaryDirectory() as temp_dir: + root = Path(temp_dir) + staging_dir = root / "staging" + file_dir = staging_dir / "file-id" + file_dir.mkdir(parents=True) + captured = file_dir / "_speaker_embeddings.npz" + captured.write_bytes(b"npz-data") + output_path = root / "exported.npz" + + result = cli_module._copy_speaker_embeddings(staging_dir, "file-id", output_path) + + self.assertEqual(output_path.read_bytes(), b"npz-data") + self.assertEqual(result, output_path) + + def test_copy_speaker_embeddings_fails_when_capture_missing(self): + from aTrain import cli as cli_module + + with tempfile.TemporaryDirectory() as temp_dir: + root = Path(temp_dir) + staging_dir = root / "staging" + (staging_dir / "file-id").mkdir(parents=True) + + with self.assertRaisesRegex(FileNotFoundError, "Speaker embeddings were not captured"): + cli_module._copy_speaker_embeddings(staging_dir, "file-id", root / "exported.npz") From 6a2f69d48b7824461f0c879cb901da813dbb1939 Mon Sep 17 00:00:00 2001 From: phoenixray2000 Date: Wed, 27 May 2026 00:49:47 +0800 Subject: [PATCH 4/7] Document CLI voiceprint enrollment --- README.md | 81 +++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 79 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 108c23a..5347bd9 100644 --- a/README.md +++ b/README.md @@ -27,16 +27,93 @@ aTrain-cli.exe --help python -m aTrain.cli --help ``` -The CLI provides two commands: +The CLI provides these commands: ```powershell aTrain-cli init aTrain-cli transcribe INPUT [OPTIONS] +aTrain-cli voiceprint enroll [OPTIONS] ``` `aTrain-cli init` downloads the default transcription model and speaker-detection model used by `transcribe`. `INPUT` can be a single audio/video file or a directory. Directory input scans only the top-level directory by default; pass `--recursive` to include subdirectories. Outputs are copied from a temporary transcription workspace into the selected output directory. Existing output files are kept by default; pass `--overwrite` to replace them. -Example: +| Option | Type | Default | Notes | +| --- | --- | --- | --- | +| `--model` | string | `large-v3` | Whisper model name. | +| `--language` | string | `auto-detect` | Language code or `auto-detect`. | +| `--speaker-detection / --no-speaker-detection` | bool | `True` | Enables pyannote speaker detection. | +| `--speaker-count` | integer | `0` | `0` means auto-detect speaker count. | +| `--identify-speakers / --no-identify-speakers` | bool | `True` | Renames diarized `SPEAKER_xx` labels with enrolled voiceprints. Requires `--speaker-detection`; no-ops when no voiceprints are enrolled. | +| `--voiceprint-threshold` | float | `0.5` | Minimum cosine similarity required for a voiceprint match. | +| `--voiceprint-margin` | float | `0.05` | Minimum score gap over competing speaker/name assignments. | +| `--speaker-embeddings-output` | file | `None` | Writes captured per-speaker embeddings to an `.npz` file. Requires `--speaker-detection` and `--identify-speakers`; single-file input only. | +| `--device` | `cpu`, `gpu` | `gpu` | Hardware backend. | +| `--compute-type` | `int8`, `float16`, `float32` | `float32` | Model compute precision. | +| `--temperature` | float | `None` | Optional sampling temperature, `0.0` to `1.0`. | +| `--prompt` | string | `None` | Optional initial prompt for Whisper. | +| `--cpu-threads` | integer | `aTrain_core.globals.DEFAULT_CPU_THREADS` | `0` means automatic CPU thread selection. | +| `--recursive / --no-recursive` | bool | `False` | Applies only when `INPUT` is a directory. | +| `--formats` | CSV | `txt,timestamps` | Allowed values: `json`, `txt`, `timestamps`, `maxqda`, `srt`. | +| `--output` | directory | `./atrain-output` | Fallback output directory for all selected formats. | +| `--json-output` | directory | fallback to `--output` | Dedicated directory for JSON output. | +| `--txt-output` | directory | fallback to `--output` | Dedicated directory for plain text output. | +| `--timestamps-output` | directory | fallback to `--output` | Dedicated directory for timestamped text output. | +| `--maxqda-output` | directory | fallback to `--output` | Dedicated directory for MAXQDA output. | +| `--srt-output` | directory | fallback to `--output` | Dedicated directory for SRT output. | +| `--overwrite / --no-overwrite` | bool | `False` | Existing target files are kept by default; use `--overwrite` to replace them. | + +### Output Contract + +Output filenames are derived from the input file stem. For an input file named `interview01.wav`, the selected formats are written as: + +| Format | Output filename | +| --- | --- | +| `json` | `interview01.json` | +| `txt` | `interview01.txt` | +| `timestamps` | `interview01_timestamps.txt` | +| `maxqda` | `interview01_maxqda.txt` | +| `srt` | `interview01.srt` | + +For recursive directory input, the input folder's relative subdirectory structure is preserved below each output directory. This prevents collisions when different subdirectories contain files with the same stem. Top-level directory input without `--recursive` writes all selected files directly into the chosen output directories. + +### Model Initialization + +Use `init` to download models for both CLI and GUI use: + +```powershell +aTrain-cli init large-v3 +aTrain-cli init speaker-detection +aTrain-cli init all +``` + +Because `transcribe` defaults to `--model large-v3` and `--speaker-detection`, a fresh environment needs both `large-v3` and `speaker-detection` before the default transcription command can run. A model is treated as available when its model directory exists and contains at least one `.bin` file, including nested `.bin` files. + +### Speaker Voiceprints + +The GUI provides a `Voiceprints` page for enrolling and managing persistent speaker profiles. Each profile is stored as a JSON file below the local aTrain data directory's `voiceprints` folder. Enrollment uses the local `speaker-detection/embedding` model; it does not upload reference audio. + +### CLI voiceprint enrollment + +The CLI can create or update local voiceprint profiles. Profiles are stored in the same local voiceprint directory used by the GUI, and `transcribe --identify-speakers` consumes those profiles during later transcription runs. + +Enroll from a direct audio sample: + +```powershell +aTrain-cli voiceprint enroll --name "李想" --audio "D:\samples\li-xiang.wav" --update +``` + +Enroll from a captured speaker embedding exported during transcription: + +```powershell +aTrain-cli transcribe "D:\input\meeting.wav" --speaker-detection --identify-speakers --speaker-embeddings-output "D:\out\meeting.speaker-embeddings.npz" +aTrain-cli voiceprint enroll --name "李想" --speaker-embeddings "D:\out\meeting.speaker-embeddings.npz" --speaker SPEAKER_01 --update +``` + +If a diarized speaker matches an enrolled profile above `--voiceprint-threshold` and above the competing-match `--voiceprint-margin`, output speaker fields are rewritten from labels such as `SPEAKER_00` to the enrolled name. Low-confidence matches remain as `SPEAKER_xx`; tune `--voiceprint-threshold` and `--voiceprint-margin` when needed. + +### CLI Examples + +Transcribe one file with the default outputs: ```powershell aTrain-cli transcribe "D:\media\interview01.wav" --output "D:\transcripts" From 19c47536a691da3633225142169ac15819d133c3 Mon Sep 17 00:00:00 2001 From: phoenixray2000 Date: Wed, 27 May 2026 01:35:35 +0800 Subject: [PATCH 5/7] Fix CLI voiceprint review issues --- aTrain/voiceprint_cli.py | 26 ++++++++++++++- tests/test_cli_voiceprints.py | 19 +++++++++++ tests/test_voiceprint_cli.py | 61 +++++++++++++++++++++++++++++++++-- 3 files changed, 102 insertions(+), 4 deletions(-) diff --git a/aTrain/voiceprint_cli.py b/aTrain/voiceprint_cli.py index 226c5de..8428744 100644 --- a/aTrain/voiceprint_cli.py +++ b/aTrain/voiceprint_cli.py @@ -5,8 +5,11 @@ from typing import Any import numpy as np +import torch from aTrain_core.load_resources import get_model +from aTrain_core.settings import Device +from aTrain.model_downloads import check_model_downloaded from aTrain.voiceprint_identification import extract_embedding from aTrain.voiceprints import ( EMBEDDING_MODEL_ID, @@ -44,6 +47,20 @@ def read_embedding_file(path: Path) -> tuple[list[str], np.ndarray]: return labels, embeddings +def _resolve_embedding_device(device) -> torch.device: + if isinstance(device, torch.device): + return device + + value = str(device).lower() + if value in {Device.GPU.value, "gpu", "cuda"}: + if not torch.cuda.is_available(): + raise ValueError("GPU is not available. Please choose CPU instead.") + return torch.device("cuda") + if value in {Device.CPU.value, "cpu"}: + return torch.device("cpu") + raise ValueError(f"Unsupported voiceprint enrollment device: {device}") + + def _save_new_or_update( name: str, embedding: np.ndarray, enrollment: dict[str, Any], update: bool ) -> VoiceprintProfile: @@ -85,8 +102,15 @@ def enroll_voiceprint_from_audio( device, min_duration_sec: float = 3.0, ) -> VoiceprintProfile: + check_model_downloaded("speaker-detection") + embedding_device = _resolve_embedding_device(device) model_path = get_model("speaker-detection") - embedding = extract_embedding(audio_path, model_path, device, min_duration_sec=min_duration_sec) + embedding = extract_embedding( + audio_path, + model_path, + embedding_device, + min_duration_sec=min_duration_sec, + ) return _save_new_or_update( name, embedding, diff --git a/tests/test_cli_voiceprints.py b/tests/test_cli_voiceprints.py index 334e74b..27a9fa4 100644 --- a/tests/test_cli_voiceprints.py +++ b/tests/test_cli_voiceprints.py @@ -115,3 +115,22 @@ def test_copy_speaker_embeddings_fails_when_capture_missing(self): with self.assertRaisesRegex(FileNotFoundError, "Speaker embeddings were not captured"): cli_module._copy_speaker_embeddings(staging_dir, "file-id", root / "exported.npz") + + def test_copy_speaker_embeddings_respects_no_overwrite(self): + from aTrain import cli as cli_module + + with tempfile.TemporaryDirectory() as temp_dir: + root = Path(temp_dir) + staging_dir = root / "staging" + file_dir = staging_dir / "file-id" + file_dir.mkdir(parents=True) + (file_dir / "_speaker_embeddings.npz").write_bytes(b"new-data") + output_path = root / "exported.npz" + output_path.write_bytes(b"old-data") + + with self.assertRaisesRegex(FileExistsError, "Target file exists"): + cli_module._copy_speaker_embeddings( + staging_dir, "file-id", output_path, overwrite=False + ) + + self.assertEqual(output_path.read_bytes(), b"old-data") diff --git a/tests/test_voiceprint_cli.py b/tests/test_voiceprint_cli.py index e56f7c4..14a77b2 100644 --- a/tests/test_voiceprint_cli.py +++ b/tests/test_voiceprint_cli.py @@ -3,6 +3,8 @@ from unittest.mock import patch import numpy as np +import torch +from aTrain_core.settings import Device from aTrain.voiceprints import ( EMBEDDING_MODEL_ID, @@ -16,7 +18,8 @@ def test_audio_enrollment_creates_new_profile(self): from aTrain.voiceprint_cli import enroll_voiceprint_from_audio vector = np.array([1.0, 0.0], dtype=np.float32) - with patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), \ + with patch("aTrain.voiceprint_cli.check_model_downloaded"), \ + patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), \ patch("aTrain.voiceprint_cli.extract_embedding", return_value=vector), \ patch("aTrain.voiceprint_cli.load_voiceprint", side_effect=FileNotFoundError("missing")), \ patch("aTrain.voiceprint_cli.save_voiceprint") as save: @@ -40,7 +43,8 @@ def test_existing_profile_without_update_fails(self): embedding=np.array([1.0, 0.0], dtype=np.float32), enrollments=[], ) - with patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), \ + with patch("aTrain.voiceprint_cli.check_model_downloaded"), \ + patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), \ patch("aTrain.voiceprint_cli.extract_embedding", return_value=np.array([0.0, 1.0], dtype=np.float32)), \ patch("aTrain.voiceprint_cli.load_voiceprint", return_value=existing): with self.assertRaises(FileExistsError): @@ -57,7 +61,8 @@ def test_existing_profile_with_update_merges_centroid(self): embedding=np.array([1.0, 0.0], dtype=np.float32), enrollments=[{"source_type": "audio"}], ) - with patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), \ + with patch("aTrain.voiceprint_cli.check_model_downloaded"), \ + patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), \ patch("aTrain.voiceprint_cli.extract_embedding", return_value=np.array([0.0, 1.0], dtype=np.float32)), \ patch("aTrain.voiceprint_cli.load_voiceprint", return_value=existing), \ patch("aTrain.voiceprint_cli.save_voiceprint") as save: @@ -87,6 +92,56 @@ def test_speaker_embedding_enrollment_rejects_missing_label(self): with self.assertRaisesRegex(ValueError, "Speaker label not found"): enroll_voiceprint_from_speaker_embedding(Path("speakers.npz"), "SPEAKER_01", "李想", update=False) + def test_audio_enrollment_checks_model_before_resolving_path(self): + from aTrain.voiceprint_cli import enroll_voiceprint_from_audio + + with patch("aTrain.voiceprint_cli.check_model_downloaded", side_effect=FileNotFoundError("missing")) as check, \ + patch("aTrain.voiceprint_cli.get_model") as get_model, \ + patch("aTrain.voiceprint_cli.extract_embedding") as extract: + with self.assertRaisesRegex(FileNotFoundError, "missing"): + enroll_voiceprint_from_audio(Path("sample.wav"), "李想", update=False, device=Device.CPU) + + check.assert_called_once_with("speaker-detection") + get_model.assert_not_called() + extract.assert_not_called() + + def test_audio_enrollment_maps_cli_cpu_device_to_torch_device(self): + from aTrain.voiceprint_cli import enroll_voiceprint_from_audio + + with patch("aTrain.voiceprint_cli.check_model_downloaded"), \ + patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), \ + patch("aTrain.voiceprint_cli.extract_embedding", return_value=np.array([1.0, 0.0], dtype=np.float32)) as extract, \ + patch("aTrain.voiceprint_cli.load_voiceprint", side_effect=FileNotFoundError("missing")), \ + patch("aTrain.voiceprint_cli.save_voiceprint"): + enroll_voiceprint_from_audio(Path("sample.wav"), "李想", update=False, device=Device.CPU) + + self.assertEqual(extract.call_args.args[2], torch.device("cpu")) + + def test_audio_enrollment_maps_cli_gpu_device_to_cuda_when_available(self): + from aTrain.voiceprint_cli import enroll_voiceprint_from_audio + + with patch("aTrain.voiceprint_cli.check_model_downloaded"), \ + patch("aTrain.voiceprint_cli.torch.cuda.is_available", return_value=True), \ + patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), \ + patch("aTrain.voiceprint_cli.extract_embedding", return_value=np.array([1.0, 0.0], dtype=np.float32)) as extract, \ + patch("aTrain.voiceprint_cli.load_voiceprint", side_effect=FileNotFoundError("missing")), \ + patch("aTrain.voiceprint_cli.save_voiceprint"): + enroll_voiceprint_from_audio(Path("sample.wav"), "李想", update=False, device=Device.GPU) + + self.assertEqual(extract.call_args.args[2], torch.device("cuda")) + + def test_audio_enrollment_rejects_unavailable_gpu_before_extracting(self): + from aTrain.voiceprint_cli import enroll_voiceprint_from_audio + + with patch("aTrain.voiceprint_cli.check_model_downloaded"), \ + patch("aTrain.voiceprint_cli.torch.cuda.is_available", return_value=False), \ + patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), \ + patch("aTrain.voiceprint_cli.extract_embedding") as extract: + with self.assertRaisesRegex(ValueError, "GPU is not available"): + enroll_voiceprint_from_audio(Path("sample.wav"), "李想", update=False, device=Device.GPU) + + extract.assert_not_called() + if __name__ == "__main__": unittest.main() From 365a43ad9a7be87ce7c76d9f8b781b410c25e60d Mon Sep 17 00:00:00 2001 From: phoenixray2000 Date: Wed, 27 May 2026 13:14:10 +0800 Subject: [PATCH 6/7] Fix CLI voiceprint validation order --- aTrain/cli.py | 54 +++++++----- aTrain/voiceprint_cli.py | 8 +- tests/test_cli_voiceprints.py | 7 +- tests/test_voiceprint_cli.py | 161 +++++++++++++++++++++++----------- 4 files changed, 153 insertions(+), 77 deletions(-) diff --git a/aTrain/cli.py b/aTrain/cli.py index ddfba6a..c6322c3 100644 --- a/aTrain/cli.py +++ b/aTrain/cli.py @@ -24,6 +24,7 @@ check_file, check_inputs_transcribe, ) + from aTrain.voiceprint_cli import ( enroll_voiceprint_from_audio, enroll_voiceprint_from_speaker_embedding, @@ -203,9 +204,7 @@ def _copy_speaker_embeddings( "Speaker embeddings were not captured. Use --speaker-detection and --identify-speakers." ) if output_path.exists() and not overwrite: - raise FileExistsError( - f"Target file exists: {output_path}. Use --overwrite to replace it." - ) + raise FileExistsError(f"Target file exists: {output_path}. Use --overwrite to replace it.") output_path.parent.mkdir(parents=True, exist_ok=True) shutil.copy2(source, output_path) return output_path @@ -241,18 +240,22 @@ def _transcribe_one( compute_type: ComputeType, temperature: float | None, prompt: str | None, - identify_speakers: bool, - voiceprint_threshold: float, - voiceprint_margin: float, - speaker_embeddings_output: Path | None, - cpu_threads: int, + identify_speakers: bool = True, + voiceprint_threshold: float = 0.5, + voiceprint_margin: float = 0.05, + speaker_embeddings_output: Path | None = None, + cpu_threads: int = DEFAULT_CPU_THREADS, ) -> Path: for planned in output_plan: if planned.target_path.exists() and not overwrite: raise FileExistsError( f"Target file exists: {planned.target_path}. Use --overwrite to replace it." ) - if speaker_embeddings_output is not None and speaker_embeddings_output.exists() and not overwrite: + if ( + speaker_embeddings_output is not None + and speaker_embeddings_output.exists() + and not overwrite + ): raise FileExistsError( f"Target file exists: {speaker_embeddings_output}. Use --overwrite to replace it." ) @@ -369,12 +372,12 @@ def _run_batch( device: Device, compute_type: ComputeType, temperature: float | None, - prompt: str | None, - identify_speakers: bool, - voiceprint_threshold: float, - voiceprint_margin: float, - speaker_embeddings_output: Path | None, - cpu_threads: int, + prompt: str | None = None, + identify_speakers: bool = True, + voiceprint_threshold: float = 0.5, + voiceprint_margin: float = 0.05, + speaker_embeddings_output: Path | None = None, + cpu_threads: int = DEFAULT_CPU_THREADS, ) -> int: results: list[FileResult] = [] total = len(inputs) @@ -459,7 +462,9 @@ def voiceprint_enroll( ] = None, speaker_embeddings: Annotated[ Path | None, - typer.Option("--speaker-embeddings", help="NPZ speaker embedding artifact exported by transcribe."), + typer.Option( + "--speaker-embeddings", help="NPZ speaker embedding artifact exported by transcribe." + ), ] = None, speaker: Annotated[ str | None, @@ -473,10 +478,16 @@ def voiceprint_enroll( str | None, typer.Option("--source", help="Optional audit source stored with the enrollment."), ] = None, - device: Annotated[Device, typer.Option(help="Hardware used for audio embedding extraction.")] = Device.CPU, + device: Annotated[ + Device, typer.Option(help="Hardware used for audio embedding extraction.") + ] = Device.CPU, min_duration_sec: Annotated[ float, - typer.Option("--min-duration-sec", help="Minimum audio duration accepted for direct audio enrollment.", min=0.1), + typer.Option( + "--min-duration-sec", + help="Minimum audio duration accepted for direct audio enrollment.", + min=0.1, + ), ] = 3.0, ): """Create or update a local speaker voiceprint.""" @@ -559,7 +570,10 @@ def transcribe( ] = 0.05, speaker_embeddings_output: Annotated[ Path | None, - typer.Option("--speaker-embeddings-output", help="Write captured speaker embeddings to this NPZ path."), + typer.Option( + "--speaker-embeddings-output", + help="Write captured speaker embeddings to this NPZ path.", + ), ] = None, device: Annotated[Device, typer.Option(help="Hardware used to transcribe.")] = Device.GPU, compute_type: Annotated[ @@ -596,13 +610,13 @@ def transcribe( """Transcribe a single file or a directory of files.""" try: selected_formats = _parse_formats(formats) - inputs, skipped = _collect_inputs(input_path, recursive) if identify_speakers and not speaker_detection: raise ValueError("--identify-speakers requires --speaker-detection.") if speaker_embeddings_output is not None and not speaker_detection: raise ValueError("--speaker-embeddings-output requires --speaker-detection.") if speaker_embeddings_output is not None and not identify_speakers: raise ValueError("--speaker-embeddings-output requires --identify-speakers.") + inputs, skipped = _collect_inputs(input_path, recursive) if speaker_embeddings_output is not None and len(inputs) != 1: raise ValueError("--speaker-embeddings-output supports single-file input only.") _check_model_downloaded(model) diff --git a/aTrain/voiceprint_cli.py b/aTrain/voiceprint_cli.py index 8428744..faf2118 100644 --- a/aTrain/voiceprint_cli.py +++ b/aTrain/voiceprint_cli.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path from typing import Any @@ -23,7 +23,7 @@ def _utc_now() -> str: - return datetime.now(timezone.utc).isoformat() + return datetime.now(UTC).isoformat() def _normalise_vector(vector: np.ndarray) -> np.ndarray: @@ -81,7 +81,9 @@ def _save_new_or_update( return profile if not update: - raise FileExistsError(f"Voiceprint already exists: {cleaned_name}. Use --update to merge a new sample.") + raise FileExistsError( + f"Voiceprint already exists: {cleaned_name}. Use --update to merge a new sample." + ) merged = merge_centroid(existing.embedding, vector, max(1, len(existing.enrollments))) profile = VoiceprintProfile( name=existing.name, diff --git a/tests/test_cli_voiceprints.py b/tests/test_cli_voiceprints.py index 27a9fa4..c294bdd 100644 --- a/tests/test_cli_voiceprints.py +++ b/tests/test_cli_voiceprints.py @@ -3,9 +3,8 @@ from pathlib import Path from unittest import mock -from typer.testing import CliRunner - from aTrain.cli import cli +from typer.testing import CliRunner class CliVoiceprintTests(unittest.TestCase): @@ -65,7 +64,9 @@ def fake_enroll(path, speaker, name, update, source=None): with tempfile.TemporaryDirectory() as temp_dir: embedding_file = Path(temp_dir) / "speakers.npz" embedding_file.write_bytes(b"npz-stub") - with mock.patch("aTrain.cli.enroll_voiceprint_from_speaker_embedding", side_effect=fake_enroll): + with mock.patch( + "aTrain.cli.enroll_voiceprint_from_speaker_embedding", side_effect=fake_enroll + ): result = runner.invoke( cli, [ diff --git a/tests/test_voiceprint_cli.py b/tests/test_voiceprint_cli.py index 14a77b2..1f9f593 100644 --- a/tests/test_voiceprint_cli.py +++ b/tests/test_voiceprint_cli.py @@ -4,13 +4,12 @@ import numpy as np import torch -from aTrain_core.settings import Device - from aTrain.voiceprints import ( EMBEDDING_MODEL_ID, VOICEPRINT_SCHEMA_VERSION, VoiceprintProfile, ) +from aTrain_core.settings import Device class VoiceprintCliTests(unittest.TestCase): @@ -18,12 +17,18 @@ def test_audio_enrollment_creates_new_profile(self): from aTrain.voiceprint_cli import enroll_voiceprint_from_audio vector = np.array([1.0, 0.0], dtype=np.float32) - with patch("aTrain.voiceprint_cli.check_model_downloaded"), \ - patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), \ - patch("aTrain.voiceprint_cli.extract_embedding", return_value=vector), \ - patch("aTrain.voiceprint_cli.load_voiceprint", side_effect=FileNotFoundError("missing")), \ - patch("aTrain.voiceprint_cli.save_voiceprint") as save: - profile = enroll_voiceprint_from_audio(Path("sample.wav"), "李想", update=False, device="CPU") + with ( + patch("aTrain.voiceprint_cli.check_model_downloaded"), + patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), + patch("aTrain.voiceprint_cli.extract_embedding", return_value=vector), + patch( + "aTrain.voiceprint_cli.load_voiceprint", side_effect=FileNotFoundError("missing") + ), + patch("aTrain.voiceprint_cli.save_voiceprint") as save, + ): + profile = enroll_voiceprint_from_audio( + Path("sample.wav"), "李想", update=False, device="CPU" + ) self.assertEqual(profile.name, "李想") self.assertEqual(profile.schema_version, VOICEPRINT_SCHEMA_VERSION) @@ -43,12 +48,17 @@ def test_existing_profile_without_update_fails(self): embedding=np.array([1.0, 0.0], dtype=np.float32), enrollments=[], ) - with patch("aTrain.voiceprint_cli.check_model_downloaded"), \ - patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), \ - patch("aTrain.voiceprint_cli.extract_embedding", return_value=np.array([0.0, 1.0], dtype=np.float32)), \ - patch("aTrain.voiceprint_cli.load_voiceprint", return_value=existing): - with self.assertRaises(FileExistsError): - enroll_voiceprint_from_audio(Path("sample.wav"), "李想", update=False, device="CPU") + with ( + patch("aTrain.voiceprint_cli.check_model_downloaded"), + patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), + patch( + "aTrain.voiceprint_cli.extract_embedding", + return_value=np.array([0.0, 1.0], dtype=np.float32), + ), + patch("aTrain.voiceprint_cli.load_voiceprint", return_value=existing), + self.assertRaises(FileExistsError), + ): + enroll_voiceprint_from_audio(Path("sample.wav"), "李想", update=False, device="CPU") def test_existing_profile_with_update_merges_centroid(self): from aTrain.voiceprint_cli import enroll_voiceprint_from_audio @@ -61,12 +71,19 @@ def test_existing_profile_with_update_merges_centroid(self): embedding=np.array([1.0, 0.0], dtype=np.float32), enrollments=[{"source_type": "audio"}], ) - with patch("aTrain.voiceprint_cli.check_model_downloaded"), \ - patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), \ - patch("aTrain.voiceprint_cli.extract_embedding", return_value=np.array([0.0, 1.0], dtype=np.float32)), \ - patch("aTrain.voiceprint_cli.load_voiceprint", return_value=existing), \ - patch("aTrain.voiceprint_cli.save_voiceprint") as save: - profile = enroll_voiceprint_from_audio(Path("sample.wav"), "李想", update=True, device="CPU") + with ( + patch("aTrain.voiceprint_cli.check_model_downloaded"), + patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), + patch( + "aTrain.voiceprint_cli.extract_embedding", + return_value=np.array([0.0, 1.0], dtype=np.float32), + ), + patch("aTrain.voiceprint_cli.load_voiceprint", return_value=existing), + patch("aTrain.voiceprint_cli.save_voiceprint") as save, + ): + profile = enroll_voiceprint_from_audio( + Path("sample.wav"), "李想", update=True, device="CPU" + ) np.testing.assert_allclose(profile.embedding, np.array([0.5, 0.5], dtype=np.float32)) self.assertEqual(len(profile.enrollments), 2) @@ -77,10 +94,16 @@ def test_speaker_embedding_enrollment_selects_requested_label(self): labels = ["SPEAKER_00", "SPEAKER_01"] embeddings = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32) - with patch("aTrain.voiceprint_cli.read_embedding_file", return_value=(labels, embeddings)), \ - patch("aTrain.voiceprint_cli.load_voiceprint", side_effect=FileNotFoundError("missing")), \ - patch("aTrain.voiceprint_cli.save_voiceprint"): - profile = enroll_voiceprint_from_speaker_embedding(Path("speakers.npz"), "SPEAKER_01", "李想", update=False) + with ( + patch("aTrain.voiceprint_cli.read_embedding_file", return_value=(labels, embeddings)), + patch( + "aTrain.voiceprint_cli.load_voiceprint", side_effect=FileNotFoundError("missing") + ), + patch("aTrain.voiceprint_cli.save_voiceprint"), + ): + profile = enroll_voiceprint_from_speaker_embedding( + Path("speakers.npz"), "SPEAKER_01", "李想", update=False + ) np.testing.assert_allclose(profile.embedding, np.array([0.0, 1.0], dtype=np.float32)) self.assertEqual(profile.enrollments[0]["speaker_label"], "SPEAKER_01") @@ -88,18 +111,32 @@ def test_speaker_embedding_enrollment_selects_requested_label(self): def test_speaker_embedding_enrollment_rejects_missing_label(self): from aTrain.voiceprint_cli import enroll_voiceprint_from_speaker_embedding - with patch("aTrain.voiceprint_cli.read_embedding_file", return_value=(["SPEAKER_00"], np.array([[1.0, 0.0]], dtype=np.float32))): - with self.assertRaisesRegex(ValueError, "Speaker label not found"): - enroll_voiceprint_from_speaker_embedding(Path("speakers.npz"), "SPEAKER_01", "李想", update=False) + with ( + patch( + "aTrain.voiceprint_cli.read_embedding_file", + return_value=(["SPEAKER_00"], np.array([[1.0, 0.0]], dtype=np.float32)), + ), + self.assertRaisesRegex(ValueError, "Speaker label not found"), + ): + enroll_voiceprint_from_speaker_embedding( + Path("speakers.npz"), "SPEAKER_01", "李想", update=False + ) def test_audio_enrollment_checks_model_before_resolving_path(self): from aTrain.voiceprint_cli import enroll_voiceprint_from_audio - with patch("aTrain.voiceprint_cli.check_model_downloaded", side_effect=FileNotFoundError("missing")) as check, \ - patch("aTrain.voiceprint_cli.get_model") as get_model, \ - patch("aTrain.voiceprint_cli.extract_embedding") as extract: - with self.assertRaisesRegex(FileNotFoundError, "missing"): - enroll_voiceprint_from_audio(Path("sample.wav"), "李想", update=False, device=Device.CPU) + with ( + patch( + "aTrain.voiceprint_cli.check_model_downloaded", + side_effect=FileNotFoundError("missing"), + ) as check, + patch("aTrain.voiceprint_cli.get_model") as get_model, + patch("aTrain.voiceprint_cli.extract_embedding") as extract, + self.assertRaisesRegex(FileNotFoundError, "missing"), + ): + enroll_voiceprint_from_audio( + Path("sample.wav"), "李想", update=False, device=Device.CPU + ) check.assert_called_once_with("speaker-detection") get_model.assert_not_called() @@ -108,37 +145,59 @@ def test_audio_enrollment_checks_model_before_resolving_path(self): def test_audio_enrollment_maps_cli_cpu_device_to_torch_device(self): from aTrain.voiceprint_cli import enroll_voiceprint_from_audio - with patch("aTrain.voiceprint_cli.check_model_downloaded"), \ - patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), \ - patch("aTrain.voiceprint_cli.extract_embedding", return_value=np.array([1.0, 0.0], dtype=np.float32)) as extract, \ - patch("aTrain.voiceprint_cli.load_voiceprint", side_effect=FileNotFoundError("missing")), \ - patch("aTrain.voiceprint_cli.save_voiceprint"): - enroll_voiceprint_from_audio(Path("sample.wav"), "李想", update=False, device=Device.CPU) + with ( + patch("aTrain.voiceprint_cli.check_model_downloaded"), + patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), + patch( + "aTrain.voiceprint_cli.extract_embedding", + return_value=np.array([1.0, 0.0], dtype=np.float32), + ) as extract, + patch( + "aTrain.voiceprint_cli.load_voiceprint", side_effect=FileNotFoundError("missing") + ), + patch("aTrain.voiceprint_cli.save_voiceprint"), + ): + enroll_voiceprint_from_audio( + Path("sample.wav"), "李想", update=False, device=Device.CPU + ) self.assertEqual(extract.call_args.args[2], torch.device("cpu")) def test_audio_enrollment_maps_cli_gpu_device_to_cuda_when_available(self): from aTrain.voiceprint_cli import enroll_voiceprint_from_audio - with patch("aTrain.voiceprint_cli.check_model_downloaded"), \ - patch("aTrain.voiceprint_cli.torch.cuda.is_available", return_value=True), \ - patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), \ - patch("aTrain.voiceprint_cli.extract_embedding", return_value=np.array([1.0, 0.0], dtype=np.float32)) as extract, \ - patch("aTrain.voiceprint_cli.load_voiceprint", side_effect=FileNotFoundError("missing")), \ - patch("aTrain.voiceprint_cli.save_voiceprint"): - enroll_voiceprint_from_audio(Path("sample.wav"), "李想", update=False, device=Device.GPU) + with ( + patch("aTrain.voiceprint_cli.check_model_downloaded"), + patch("aTrain.voiceprint_cli.torch.cuda.is_available", return_value=True), + patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), + patch( + "aTrain.voiceprint_cli.extract_embedding", + return_value=np.array([1.0, 0.0], dtype=np.float32), + ) as extract, + patch( + "aTrain.voiceprint_cli.load_voiceprint", side_effect=FileNotFoundError("missing") + ), + patch("aTrain.voiceprint_cli.save_voiceprint"), + ): + enroll_voiceprint_from_audio( + Path("sample.wav"), "李想", update=False, device=Device.GPU + ) self.assertEqual(extract.call_args.args[2], torch.device("cuda")) def test_audio_enrollment_rejects_unavailable_gpu_before_extracting(self): from aTrain.voiceprint_cli import enroll_voiceprint_from_audio - with patch("aTrain.voiceprint_cli.check_model_downloaded"), \ - patch("aTrain.voiceprint_cli.torch.cuda.is_available", return_value=False), \ - patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), \ - patch("aTrain.voiceprint_cli.extract_embedding") as extract: - with self.assertRaisesRegex(ValueError, "GPU is not available"): - enroll_voiceprint_from_audio(Path("sample.wav"), "李想", update=False, device=Device.GPU) + with ( + patch("aTrain.voiceprint_cli.check_model_downloaded"), + patch("aTrain.voiceprint_cli.torch.cuda.is_available", return_value=False), + patch("aTrain.voiceprint_cli.get_model", return_value=Path("speaker-detection")), + patch("aTrain.voiceprint_cli.extract_embedding") as extract, + self.assertRaisesRegex(ValueError, "GPU is not available"), + ): + enroll_voiceprint_from_audio( + Path("sample.wav"), "李想", update=False, device=Device.GPU + ) extract.assert_not_called() From 1d1904f0b7ba71b8480471beb26b42aca80d53cc Mon Sep 17 00:00:00 2001 From: phoenixray2000 Date: Wed, 27 May 2026 13:17:57 +0800 Subject: [PATCH 7/7] Cover CLI voiceprint postprocessing --- tests/test_cli_voiceprints.py | 40 +++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/test_cli_voiceprints.py b/tests/test_cli_voiceprints.py index c294bdd..6e3580c 100644 --- a/tests/test_cli_voiceprints.py +++ b/tests/test_cli_voiceprints.py @@ -1,3 +1,4 @@ +import json import tempfile import unittest from pathlib import Path @@ -8,6 +9,45 @@ class CliVoiceprintTests(unittest.TestCase): + def test_postprocess_applies_speaker_map_to_staged_transcript(self): + from aTrain import cli as cli_module + + with tempfile.TemporaryDirectory() as temp_dir: + staging_dir = Path(temp_dir) + transcript_dir = staging_dir / "file-id" + transcript_dir.mkdir() + (transcript_dir / "transcription.json").write_text( + json.dumps( + { + "segments": [ + { + "speaker": "SPEAKER_00", + "text": "hello", + "words": [ + {"speaker": "SPEAKER_00", "word": "hello"}, + ], + } + ] + } + ), + encoding="utf-8", + ) + + with mock.patch("aTrain_core.outputs.create_output_files") as create_outputs: + cli_module._postprocess_staged_outputs( + staging_dir, + "file-id", + speaker_detection=True, + speaker_map={"SPEAKER_00": "Ray"}, + ) + + create_outputs.assert_called_once() + transcript, speaker_detection, file_id = create_outputs.call_args.args + self.assertIs(speaker_detection, True) + self.assertEqual(file_id, "file-id") + self.assertEqual(transcript["segments"][0]["speaker"], "Ray") + self.assertEqual(transcript["segments"][0]["words"][0]["speaker"], "Ray") + def test_identify_speakers_requires_speaker_detection(self): runner = CliRunner()