diff --git a/torchTextClassifiers/tokenizers/base.py b/torchTextClassifiers/tokenizers/base.py
index dee5546..3a3b360 100644
--- a/torchTextClassifiers/tokenizers/base.py
+++ b/torchTextClassifiers/tokenizers/base.py
@@ -103,6 +103,11 @@ def __repr__(self):
     def __call__(self, text: Union[str, List[str]], **kwargs) -> list:
         return self.tokenize(text, **kwargs)
 
+    @classmethod
+    @abstractmethod
+    def load_from_s3(cls, s3_path: str, filesystem):
+        pass
+
 
 class HuggingFaceTokenizer(BaseTokenizer):
     def __init__(
@@ -178,17 +183,14 @@ def load(cls, load_path: str):
     @classmethod
     def load_from_s3(cls, s3_path: str, filesystem):
         if filesystem.exists(s3_path) is False:
-            raise FileNotFoundError(
-                f"Tokenizer not found at {s3_path}. Please train it first (see src/train_tokenizers)."
-            )
+            raise FileNotFoundError(f"Tokenizer not found at {s3_path}.")
 
         with filesystem.open(s3_path, "rb") as f:
             json_str = f.read().decode("utf-8")
 
         tokenizer_obj = Tokenizer.from_str(json_str)
-        tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer_obj)
-        instance = cls(vocab_size=len(tokenizer), trained=True)
-        instance.tokenizer = tokenizer
+        instance = cls(vocab_size=tokenizer_obj.get_vocab_size(), trained=True)
+        instance.tokenizer = tokenizer_obj
         instance._post_training()
 
         return instance
diff --git a/torchTextClassifiers/tokenizers/ngram.py b/torchTextClassifiers/tokenizers/ngram.py
index ed0d8cb..fae323e 100644
--- a/torchTextClassifiers/tokenizers/ngram.py
+++ b/torchTextClassifiers/tokenizers/ngram.py
@@ -432,11 +432,24 @@ def save_pretrained(self, save_directory: str):
         print(f"✓ Tokenizer saved to {save_directory}")
 
     @classmethod
-    def from_pretrained(cls, directory: str):
+    def load_from_s3(cls, s3_path: str, filesystem):
         """Load tokenizer from saved configuration."""
-        with open(f"{directory}/tokenizer.json", "r") as f:
+
+        config = json.load(filesystem.open(s3_path, "r"))
+        tokenizer = cls.build_from_config(config)
+        return tokenizer
+
+    @classmethod
+    def load(cls, path: str):
+        """Load tokenizer from saved configuration."""
+
+        with open(path, "r") as f:
             config = json.load(f)
+        tokenizer = cls.build_from_config(config)
+        return tokenizer
 
+    @classmethod
+    def build_from_config(cls, config):
         tokenizer = cls(
             min_count=config["min_count"],
             min_n=config["min_n"],
@@ -468,5 +481,4 @@ def from_pretrained(cls, directory: str):
         )
 
         print("✓ Subword cache built")
-        print(f"✓ Tokenizer loaded from {directory}")
         return tokenizer
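
A minimal usage sketch of the new load_from_s3 entry points, not part of the patch. It assumes an fsspec-compatible filesystem such as s3fs.S3FileSystem (the methods only call .exists() and .open()), that the class in ngram.py is named NGramTokenizer (its declaration is not shown in the diff), and hypothetical bucket paths:

import s3fs

from torchTextClassifiers.tokenizers.base import HuggingFaceTokenizer
from torchTextClassifiers.tokenizers.ngram import NGramTokenizer  # assumed class name

# Any fsspec-style filesystem works; only .exists() and .open() are used.
fs = s3fs.S3FileSystem()

# HuggingFace tokenizer: the S3 object is a serialized tokenizers.Tokenizer JSON,
# now rebuilt with Tokenizer.from_str() instead of wrapping it in PreTrainedTokenizerFast.
hf_tokenizer = HuggingFaceTokenizer.load_from_s3(
    "my-bucket/tokenizers/hf/tokenizer.json",  # hypothetical path
    filesystem=fs,
)

# N-gram tokenizer: the S3 object is the config JSON written by save_pretrained();
# both load_from_s3() and the local load() delegate to the new build_from_config().
ngram_tokenizer = NGramTokenizer.load_from_s3(
    "my-bucket/tokenizers/ngram/tokenizer.json",  # hypothetical path
    filesystem=fs,
)

tokens = hf_tokenizer("an example sentence")  # __call__ forwards to tokenize()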