diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 3820782db..9789a04f0 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -818,7 +818,16 @@ def inference_with_vad(self, input, input_len=None, **cfg): spk_embedding.cpu(), oracle_num=kwargs.get("preset_spk_num", None) ) # del result['spk_embedding'] - sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu()) + if kwargs.get("return_spk_center", False): + sv_output, spk_center = postprocess( + all_segments, None, labels, spk_embedding.cpu(), return_spk_center=True + ) + # Per-speaker ERes2NetV2 centroids, indexed by the `spk` id in + # sentence_info. Kept on the result for downstream voiceprint use + # (the per-chunk spk_embedding below is still deleted to keep output small). + result["spk_embedding_center"] = spk_center + else: + sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu()) if self.spk_mode == "punc_segment" and "timestamp" not in result and "timestamps" not in result: logging.warning("No timestamps in ASR result (e.g. SenseVoice), falling back to vad_segment mode for speaker diarization.") self.spk_mode = "vad_segment" diff --git a/funasr/models/campplus/utils.py b/funasr/models/campplus/utils.py index e9f5eb4a1..a3314f745 100644 --- a/funasr/models/campplus/utils.py +++ b/funasr/models/campplus/utils.py @@ -138,7 +138,11 @@ def extract_feature(audio): def postprocess( - segments: list, vad_segments: list, labels: np.ndarray, embeddings: np.ndarray + segments: list, + vad_segments: list, + labels: np.ndarray, + embeddings: np.ndarray, + return_spk_center: bool = False, ) -> list: """Postprocess. @@ -184,6 +188,10 @@ def is_overlapped(t1, t2): # smooth the result distribute_res = smooth(distribute_res) + if return_spk_center: + # spk_embs[i] is the centroid (mean of clustered chunk embeddings) for + # corrected speaker label i, aligned with the `spk` ids in sentence_info. + return distribute_res, spk_embs return distribute_res