diff --git a/utils.py b/utils.py
index 3dc9be1..0d6cec1 100644
--- a/utils.py
+++ b/utils.py
@@ -13,7 +13,7 @@
 from ignite.contrib.handlers import ProgressBar
 from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler, OutputHandler, TensorboardLogger
 
-from pytorch_pretrained_bert import cached_path
+from pytorch_pretrained_bert import cached_path, BertTokenizer
 
 DATASETS_URL = {
     'wikitext-2': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/wikitext-2/train.txt",
@@ -25,16 +25,16 @@
     'simplebooks-92-raw': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/simplebooks-92-raw/train.txt",
                            'valid': "https://s3.amazonaws.com/datasets.huggingface.co/simplebooks-92-raw/valid.txt"},
     'imdb': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/train.txt",
-             'test': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/test.txt"},
+             'valid': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/valid.txt"},
     'trec': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/trec/train.txt",
-             'test': "https://s3.amazonaws.com/datasets.huggingface.co/trec/test.txt"},
+             'valid': "https://s3.amazonaws.com/datasets.huggingface.co/trec/test.txt"},
 }
 
 DATASETS_LABELS_URL = {
     'imdb': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/train.labels.txt",
-             'test': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/test.labels.txt"},
+             'valid': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/valid.labels.txt"},
     'trec': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/trec/train.labels.txt",
-             'test': "https://s3.amazonaws.com/datasets.huggingface.co/trec/test.labels.txt"},
+             'valid': "https://s3.amazonaws.com/datasets.huggingface.co/trec/test.labels.txt"},
 }
 
 DATASETS_LABELS_CONVERSION = {
@@ -124,7 +124,7 @@ def get_and_tokenize_dataset(tokenizer, dataset_dir='wikitext-103', dataset_cach
         if with_labels:
             label_conversion_map = DATASETS_LABELS_CONVERSION[dataset_dir]
             for split_name in DATASETS_LABELS_URL[dataset_dir]:
-                dataset_file = cached_path(dataset_map['labels'][split_name])
+                dataset_file = cached_path(DATASETS_LABELS_URL[dataset_dir][split_name])
                 with open(dataset_file, "r", encoding="utf-8") as f:
                     all_lines = f.readlines()
                     labels[split_name] = [label_conversion_map[line.strip()] for line in tqdm(all_lines)]
@@ -154,3 +154,10 @@ def encode(obj):
     torch.save(encoded_dataset, dataset_cache)
 
     return encoded_dataset
+
+
+if __name__ == '__main__':
+    tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)
+    get_and_tokenize_dataset(tokenizer, dataset_dir='imdb', dataset_cache=None, with_labels=True)
+
+
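
Note on the rename above: the 'test' keys become 'valid' in both URL maps so every dataset exposes the same split names; the 'trec' entries still point at the test.txt / test.labels.txt files, presumably because TREC ships no separate validation file. A minimal consistency check one could run after this change; the snippet itself is hypothetical, but DATASETS_URL and DATASETS_LABELS_URL are the module-level tables from utils.py:

from utils import DATASETS_URL, DATASETS_LABELS_URL

# Each labelled dataset should use the same split names in both maps, so
# downstream code can index splits['train'] / splits['valid'] uniformly.
for name, label_splits in DATASETS_LABELS_URL.items():
    missing = set(label_splits) - set(DATASETS_URL[name])
    assert not missing, f"{name}: label splits {missing} have no matching data URL"
print("split names are consistent across the data and label URL maps")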
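
The hunk at utils.py line 124 replaces the stale dataset_map['labels'] lookup with a direct read from DATASETS_LABELS_URL, and the new __main__ block exercises the pipeline end to end (python utils.py downloads and tokenizes IMDB with labels). A standalone sketch of the corrected loop, factored into a hypothetical load_labels helper for illustration; cached_path and the two module-level tables are the real names from the diff:

from pytorch_pretrained_bert import cached_path
from tqdm import tqdm

from utils import DATASETS_LABELS_CONVERSION, DATASETS_LABELS_URL

def load_labels(dataset_dir):
    # Convert raw label strings via the dataset's conversion table,
    # producing one list of converted labels per split.
    label_conversion_map = DATASETS_LABELS_CONVERSION[dataset_dir]
    labels = {}
    for split_name in DATASETS_LABELS_URL[dataset_dir]:
        # cached_path downloads the file on first use and returns a local path.
        dataset_file = cached_path(DATASETS_LABELS_URL[dataset_dir][split_name])
        with open(dataset_file, "r", encoding="utf-8") as f:
            labels[split_name] = [label_conversion_map[line.strip()] for line in tqdm(f.readlines())]
    return labels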