diff --git a/examples/tts/magpietts_inference.py b/examples/tts/magpietts_inference.py
index f581555b9739..1e50a3ca25d5 100644
--- a/examples/tts/magpietts_inference.py
+++ b/examples/tts/magpietts_inference.py
@@ -117,11 +117,27 @@ def create_formatted_metrics_mean_ci(metrics_mean_ci: dict) -> dict:
     return metrics_mean_ci
 
 
+def filter_datasets(dataset_meta_info: dict, datasets: Optional[str]) -> List[str]:
+    """Select and validate datasets from the dataset meta info."""
+    if datasets is None:
+        # Dataset filtering not specified, return all datasets
+        return list(dataset_meta_info.keys())
+    else:
+        datasets = datasets.split(",")
+        # Check that all requested datasets are valid
+        for dataset in datasets:
+            if dataset not in dataset_meta_info:
+                raise ValueError(f"Dataset {dataset} not found in dataset meta info")
+        # Return all requested datasets
+        return datasets
+
+
 def run_inference_and_evaluation(
     model_config: ModelLoadConfig,
     inference_config: InferenceConfig,
     eval_config: EvaluationConfig,
     dataset_meta_info: dict,
+    datasets: Optional[List[str]],
     out_dir: str,
     num_repeats: int = 1,
     confidence_level: float = 0.95,
@@ -137,6 +153,8 @@ def run_inference_and_evaluation(
         inference_config: Configuration for inference.
         eval_config: Configuration for evaluation.
         dataset_meta_info: Dictionary containing dataset metadata.
+        datasets: List of dataset names to run inference and evaluation on. If None, all datasets in the
+            dataset meta info will be processed.
         out_dir: Output directory for results.
         num_repeats: Number of times to repeat inference (for CI estimation).
         confidence_level: Confidence level for CI calculation.
@@ -170,7 +188,6 @@ def run_inference_and_evaluation(
     runner = MagpieInferenceRunner(model, inference_config)
 
     # Tracking metrics across datasets
-    datasets = list(dataset_meta_info.keys())
     ssim_per_dataset = []
     cer_per_dataset = []
     all_datasets_filewise_metrics = {}
@@ -369,8 +386,15 @@ def create_argument_parser() -> argparse.ArgumentParser:
     data_group.add_argument(
         '--datasets_json_path',
         type=str,
+        required=True,
+        default=None,
+        help='Path to dataset configuration JSON file (will process all datasets in the file if --datasets is not specified)',
+    )
+    data_group.add_argument(
+        '--datasets',
+        type=str,
         default=None,
-        help='Path to dataset configuration JSON file (will process all datasets in the file)',
+        help='Comma-separated list of dataset names to process, matching names in the datasets_json_path file. If not specified, all datasets in datasets_json_path will be processed.',
     )
     data_group.add_argument(
         '--out_dir',
@@ -478,7 +502,7 @@ def main():
     args = parser.parse_args()
 
     dataset_meta_info = load_evalset_config(args.datasets_json_path)
-    datasets = list(dataset_meta_info.keys())
+    datasets = filter_datasets(dataset_meta_info, args.datasets)
 
     logging.info(f"Loaded {len(datasets)} datasets: {', '.join(datasets)}")
 
@@ -549,6 +573,7 @@ def main():
         inference_config=inference_config,
         eval_config=eval_config,
         dataset_meta_info=dataset_meta_info,
+        datasets=datasets,
         out_dir=args.out_dir,
         num_repeats=args.num_repeats,
         confidence_level=args.confidence_level,
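
For reference, a minimal self-contained sketch of the new filtering behavior. The meta-info contents and dataset names (`libritts_clean`, `vctk`) are hypothetical placeholders, not taken from the repository's eval sets:

```python
from typing import List, Optional


def filter_datasets(dataset_meta_info: dict, datasets: Optional[str]) -> List[str]:
    """Mirror of the helper added in the diff above."""
    if datasets is None:
        # No filter given: run on every dataset in the meta info
        return list(dataset_meta_info.keys())
    requested = datasets.split(",")
    for dataset in requested:
        if dataset not in dataset_meta_info:
            raise ValueError(f"Dataset {dataset} not found in dataset meta info")
    return requested


# Hypothetical meta info, standing in for what load_evalset_config() returns.
meta = {"libritts_clean": {}, "vctk": {}}

assert filter_datasets(meta, None) == ["libritts_clean", "vctk"]  # no --datasets flag
assert filter_datasets(meta, "vctk") == ["vctk"]                  # --datasets vctk
try:
    filter_datasets(meta, "nonexistent")
except ValueError as err:
    print(err)  # Dataset nonexistent not found in dataset meta info
```

An invocation would then look something like `python examples/tts/magpietts_inference.py --datasets_json_path evalsets.json --datasets libritts_clean,vctk ...` (file names hypothetical), with `--datasets` omitted to process every dataset in the JSON file.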