diff --git a/pythaitts/pretrained/khanomtan_tts.py b/pythaitts/pretrained/khanomtan_tts.py index c895cfb..9e94fe5 100644 --- a/pythaitts/pretrained/khanomtan_tts.py +++ b/pythaitts/pretrained/khanomtan_tts.py @@ -11,7 +11,6 @@ This model uses the TTS package from: `https://github.com/idiap/coqui-ai-TTS `_ """ import tempfile -from TTS.utils.synthesizer import Synthesizer from huggingface_hub import hf_hub_download @@ -46,7 +45,13 @@ def load_synthesizer(self, mode): """ mode: The model mode (best_mode or last_checkpoint) """ - if mode=="best_model": + try: + from TTS.utils.synthesizer import Synthesizer + except ImportError: + raise ImportError( + "You must install coqui-tts before using this model: pip install coqui-tts" + ) + if mode == "best_model": self.best_model_path = hf_hub_download(repo_id="wannaphong/khanomtan-tts-v{0}".format(self.version),filename=self.best_model_path_name,force_filename="best_model-v{0}.pth".format(self.version)) self.synthesizer = Synthesizer( tts_checkpoint=self.best_model_path, diff --git a/tests/test_khanomtan.py b/tests/test_khanomtan.py new file mode 100644 index 0000000..0630908 --- /dev/null +++ b/tests/test_khanomtan.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +""" +Unit tests for KhanomTan TTS integration +""" +import sys +import unittest +from unittest.mock import patch, MagicMock + + +class TestKhanomTanImportError(unittest.TestCase): + """Test that a helpful ImportError is raised when coqui-tts is not installed""" + + def test_import_error_when_tts_not_installed(self): + """Test that ImportError with helpful message is raised when TTS package is missing""" + with patch.dict(sys.modules, {"TTS": None, "TTS.utils": None, "TTS.utils.synthesizer": None}): + # Remove cached module if present + for key in list(sys.modules.keys()): + if key.startswith("pythaitts.pretrained.khanomtan"): + del sys.modules[key] + + from pythaitts.pretrained.khanomtan_tts import KhanomTan + + instance = KhanomTan.__new__(KhanomTan) + instance.version = "1.0" + instance.best_model_path_name = "best_model.pth" + instance.last_checkpoint_model_path_name = "checkpoint_440000.pth" + instance.config_path = "config.json" + instance.speakers_path = "speakers.pth" + instance.languages_path = "language_ids.json" + instance.speaker_encoder_model_path = "model_se.pth" + instance.speaker_encoder_config_path = "config_se.json" + instance.synthesizer = None + + with self.assertRaises(ImportError) as ctx: + instance.load_synthesizer("last_checkpoint") + + self.assertIn("coqui-tts", str(ctx.exception)) + self.assertIn("pip install", str(ctx.exception)) + + +if __name__ == "__main__": + unittest.main()