From 3ce6b2b817b735b7ce8a572cadec86bf2685e836 Mon Sep 17 00:00:00 2001
From: sahilsharma05
Date: Fri, 6 Sep 2019 16:35:56 +0900
Subject: [PATCH 1/4] Fixed issue: invalid index while using
 dataset_map['labels']

---
 utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/utils.py b/utils.py
index 3dc9be1..e67fed1 100644
--- a/utils.py
+++ b/utils.py
@@ -124,7 +124,7 @@ def get_and_tokenize_dataset(tokenizer, dataset_dir='wikitext-103', dataset_cach
     if with_labels:
         label_conversion_map = DATASETS_LABELS_CONVERSION[dataset_dir]
         for split_name in DATASETS_LABELS_URL[dataset_dir]:
-            dataset_file = cached_path(dataset_map['labels'][split_name])
+            dataset_file = cached_path(DATASETS_LABELS_URL[dataset_dir][split_name])
             with open(dataset_file, "r", encoding="utf-8") as f:
                 all_lines = f.readlines()
                 labels[split_name] = [label_conversion_map[line.strip()] for line in tqdm(all_lines)]

From cb56df75378f8ae8368b1543e08ed7a7dd6287e4 Mon Sep 17 00:00:00 2001
From: sahilsharma05
Date: Fri, 6 Sep 2019 16:41:11 +0900
Subject: [PATCH 2/4] Fixed issue:

- replaced the 'test' key with 'valid' in the global variables DATASETS_URL
  and DATASETS_LABELS_URL; the training logic uses the valid files.
- for the IMDB dataset, test.txt and test.labels.txt don't exist; replaced
  them with valid.txt and valid.labels.txt (which exist on the s3 server).
---
 utils.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/utils.py b/utils.py
index e67fed1..53272a0 100644
--- a/utils.py
+++ b/utils.py
@@ -25,16 +25,16 @@
     'simplebooks-92-raw': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/simplebooks-92-raw/train.txt",
                            'valid': "https://s3.amazonaws.com/datasets.huggingface.co/simplebooks-92-raw/valid.txt"},
     'imdb': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/train.txt",
-             'test': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/test.txt"},
+             'valid': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/valid.txt"},
     'trec': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/trec/train.txt",
-             'test': "https://s3.amazonaws.com/datasets.huggingface.co/trec/test.txt"},
+             'valid': "https://s3.amazonaws.com/datasets.huggingface.co/trec/test.txt"},
 }
 
 DATASETS_LABELS_URL = {
     'imdb': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/train.labels.txt",
-             'test': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/test.labels.txt"},
+             'valid': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/valid.labels.txt"},
     'trec': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/trec/train.labels.txt",
-             'test': "https://s3.amazonaws.com/datasets.huggingface.co/trec/test.labels.txt"},
+             'valid': "https://s3.amazonaws.com/datasets.huggingface.co/trec/test.labels.txt"},
 }
 
 DATASETS_LABELS_CONVERSION = {
@@ -154,3 +154,10 @@ def encode(obj):
 
         torch.save(encoded_dataset, dataset_cache)
     return encoded_dataset
+
+
+if __name__ == '__main__':
+    tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)
+    get_and_tokenize_dataset(tokenizer, dataset_dir='imdb', dataset_cache=None, with_labels=True)
+
+

From ac9f44b50fef94575c14ab3fc51da78973ad52a5 Mon Sep 17 00:00:00 2001
From: sahilsharma05
Date: Fri, 6 Sep 2019 16:42:36 +0900
Subject: [PATCH 3/4] Delete the extra code added in PATCH 2/4

---
 utils.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/utils.py b/utils.py
index 53272a0..dd79c6c 100644
--- a/utils.py
+++ b/utils.py
@@ -156,8 +156,4 @@ def encode(obj):
     return encoded_dataset
 
 
-if __name__ == '__main__':
-    tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)
-    get_and_tokenize_dataset(tokenizer, dataset_dir='imdb', dataset_cache=None, with_labels=True)
-
 

From ec5e9bbcd57f3b42e2d8df48b563fb4fa1190ffc Mon Sep 17 00:00:00 2001
From: sahilsharma05
Date: Fri, 6 Sep 2019 16:43:17 +0900
Subject: [PATCH 4/4] Added a unit test case for the get_and_tokenize_dataset
 method

---
 utils.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/utils.py b/utils.py
index dd79c6c..0d6cec1 100644
--- a/utils.py
+++ b/utils.py
@@ -13,7 +13,7 @@
 from ignite.contrib.handlers import ProgressBar
 from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler, OutputHandler, TensorboardLogger
 
-from pytorch_pretrained_bert import cached_path
+from pytorch_pretrained_bert import cached_path, BertTokenizer
 
 DATASETS_URL = {
     'wikitext-2': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/wikitext-2/train.txt",
@@ -156,4 +156,8 @@ def encode(obj):
     return encoded_dataset
 
 
+if __name__ == '__main__':
+    tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)
+    get_and_tokenize_dataset(tokenizer, dataset_dir='imdb', dataset_cache=None, with_labels=True)
+
 
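
Note on PATCH 4/4: the block added under `if __name__ == '__main__':` is a
download-and-tokenize smoke run rather than an assertion-based unit test. For
reference, a minimal pytest-style sketch of the same checks is below. The file
name test_utils.py is hypothetical, and the assumption that
get_and_tokenize_dataset returns a non-empty object is inferred only from the
`return encoded_dataset` visible in the hunks above, not stated by the patches.

# test_utils.py -- hypothetical sketch, not part of the patches above.
import pytest
from pytorch_pretrained_bert import BertTokenizer

from utils import DATASETS_URL, DATASETS_LABELS_URL, get_and_tokenize_dataset


def test_label_splits_match_dataset_splits():
    # PATCH 2/4 aligned the split keys on 'train'/'valid': every labeled
    # split should have a matching text split in DATASETS_URL.
    for dataset_dir, label_urls in DATASETS_LABELS_URL.items():
        assert set(label_urls) <= set(DATASETS_URL[dataset_dir])


@pytest.mark.slow  # downloads the IMDB splits from S3
def test_get_and_tokenize_imdb_with_labels():
    tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)
    encoded = get_and_tokenize_dataset(tokenizer, dataset_dir='imdb',
                                       dataset_cache=None, with_labels=True)
    # Assumed non-empty; the patches only show `return encoded_dataset`.
    assert encoded

Keeping the checks in a separate file would leave importing utils.py free of
side effects and let pytest collect them, while `python utils.py` still runs
the original smoke check from PATCH 4/4.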