Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion funasr/auto/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,16 @@ def inference_with_vad(self, input, input_len=None, **cfg):
spk_embedding.cpu(), oracle_num=kwargs.get("preset_spk_num", None)
)
# del result['spk_embedding']
sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu())
if kwargs.get("return_spk_center", False):
sv_output, spk_center = postprocess(
all_segments, None, labels, spk_embedding.cpu(), return_spk_center=True
)
# Per-speaker ERes2NetV2 centroids, indexed by the `spk` id in
# sentence_info. Kept on the result for downstream voiceprint use
# (the per-chunk spk_embedding below is still deleted to keep output small).
result["spk_embedding_center"] = spk_center
else:
sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu())
if self.spk_mode == "punc_segment" and "timestamp" not in result and "timestamps" not in result:
logging.warning("No timestamps in ASR result (e.g. SenseVoice), falling back to vad_segment mode for speaker diarization.")
self.spk_mode = "vad_segment"
Expand Down
10 changes: 9 additions & 1 deletion funasr/models/campplus/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,11 @@ def extract_feature(audio):


def postprocess(
segments: list, vad_segments: list, labels: np.ndarray, embeddings: np.ndarray
segments: list,
vad_segments: list,
labels: np.ndarray,
embeddings: np.ndarray,
return_spk_center: bool = False,
) -> list:
"""Postprocess.

Expand Down Expand Up @@ -184,6 +188,10 @@ def is_overlapped(t1, t2):
# smooth the result
distribute_res = smooth(distribute_res)

if return_spk_center:
# spk_embs[i] is the centroid (mean of clustered chunk embeddings) for
# corrected speaker label i, aligned with the `spk` ids in sentence_info.
return distribute_res, spk_embs
Comment on lines +191 to +194

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Recompute centroids after smoothing speaker labels

For recordings containing diarization regions shorter than `smooth()`'s 0.7s threshold, `smooth()` can reassign those regions to neighboring speakers, but `spk_embs` was already computed from the pre-smoothed labels. Returning it here means `spk_embedding_center` can include speakers that no longer appear in `sentence_info`, and the remaining speakers' centroids exclude embeddings that were assigned to them in the final diarization output. As a result, downstream voiceprint matching uses centroids that do not match the returned `spk` IDs.

Useful? React with 👍 / 👎.

return distribute_res


Expand Down