
Commit 05053d1

update cache_dir in readme and examples
1 parent 63ae5d2 commit 05053d1

File tree

4 files changed: 8 additions, 5 deletions


README.md

Lines changed: 3 additions & 3 deletions

````diff
@@ -162,13 +162,12 @@
 To load one of Google AI's pre-trained models or a PyTorch saved model (an instance of `BertForPreTraining` saved with `torch.save()`), the PyTorch model classes and the tokenizer can be instantiated as

 ```python
-model = BERT_CLASS.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH)
+model = BERT_CLASS.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None)
 ```

 where

 - `BERT_CLASS` is either the `BertTokenizer` class (to load the vocabulary) or one of the six PyTorch model classes (to load the pre-trained weights): `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification` or `BertForQuestionAnswering`, and
-
 - `PRE_TRAINED_MODEL_NAME_OR_PATH` is either:

   - the shortcut name of a Google AI's pre-trained model selected in the list:
@@ -184,7 +183,8 @@ where
     - `bert_config.json` a configuration file for the model, and
     - `pytorch_model.bin` a PyTorch dump of a pre-trained instance `BertForPreTraining` (saved with the usual `torch.save()`)

-If `PRE_TRAINED_MODEL_NAME_OR_PATH` is a shortcut name, the pre-trained weights will be downloaded from AWS S3 (see the links [here](pytorch_pretrained_bert/modeling.py)) and stored in a cache folder to avoid future download (the cache folder can be found at `~/.pytorch_pretrained_bert/`).
+If `PRE_TRAINED_MODEL_NAME_OR_PATH` is a shortcut name, the pre-trained weights will be downloaded from AWS S3 (see the links [here](pytorch_pretrained_bert/modeling.py)) and stored in a cache folder to avoid future download (the cache folder can be found at `~/.pytorch_pretrained_bert/`).
+- `cache_dir` can be an optional path to a specific directory to download and cache the pre-trained model weights. This option is useful in particular when you are using distributed training: to avoid concurrent access to the same weights you can set for example `cache_dir='./pretrained_model_{}'.format(args.local_rank)` (see the section on distributed training for more information)

 Example:
 ```python
````

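The per-rank `cache_dir` recipe suggested in the README change can be sketched as a small helper; this is an illustration only, and the helper name `rank_cache_dir` is ours, not part of the library:

```python
from pathlib import Path

def rank_cache_dir(base, local_rank):
    """Build a per-worker cache directory, mirroring the README's
    cache_dir='./pretrained_model_{}'.format(args.local_rank) suggestion."""
    return Path(base) / 'pretrained_model_{}'.format(local_rank)

# With one folder per rank, concurrent workers never download into the
# same location, so cached weight files cannot clobber each other.
paths = [rank_cache_dir('.', rank) for rank in range(4)]
assert len(set(paths)) == 4
```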
examples/run_classifier.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -482,7 +482,8 @@ def main():
             len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

     # Prepare model
-    model = BertForSequenceClassification.from_pretrained(args.bert_model, len(label_list))
+    model = BertForSequenceClassification.from_pretrained(args.bert_model, len(label_list),
+              cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
     if args.fp16:
         model.half()
     model.to(device)
```
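The effect of giving each distributed worker its own cache folder can be simulated without the library; a minimal sketch, using a temporary directory as a stand-in for `PYTORCH_PRETRAINED_BERT_CACHE` so nothing touches the real cache:

```python
import tempfile
from pathlib import Path

# Stand-in for PYTORCH_PRETRAINED_BERT_CACHE (a temp dir here, so this
# sketch does not touch the real ~/.pytorch_pretrained_bert folder).
cache_root = Path(tempfile.mkdtemp())

# Mirror the example's cache_dir expression for a few simulated local ranks:
rank_dirs = [cache_root / 'distributed_{}'.format(rank) for rank in range(3)]
for d in rank_dirs:
    d.mkdir(parents=True, exist_ok=True)  # each rank caches weights here

# All three directories are distinct and exist on disk, so simultaneous
# downloads by different workers cannot collide.
assert len(set(rank_dirs)) == 3 and all(d.is_dir() for d in rank_dirs)
```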

examples/run_squad.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -821,7 +821,8 @@ def main():
             len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

     # Prepare model
-    model = BertForQuestionAnswering.from_pretrained(args.bert_model)
+    model = BertForQuestionAnswering.from_pretrained(args.bert_model,
+              cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
     if args.fp16:
         model.half()
     model.to(device)
```

pytorch_pretrained_bert/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -3,3 +3,4 @@
                        BertForMaskedLM, BertForNextSentencePrediction,
                        BertForSequenceClassification, BertForQuestionAnswering)
 from .optimization import BertAdam
+from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
```
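For the examples to combine `PYTORCH_PRETRAINED_BERT_CACHE` with a subfolder via `/`, the constant must behave like a `pathlib.Path`. A plausible sketch of such a constant (an illustration only; the actual definition lives in `pytorch_pretrained_bert/file_utils.py`, and the default shown here is assumed from the README's `~/.pytorch_pretrained_bert/` cache location):

```python
import os
from pathlib import Path

# Illustration: honor an environment override, falling back to a hidden
# cache folder in the user's home directory.
PYTORCH_PRETRAINED_BERT_CACHE = Path(
    os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
              str(Path.home() / '.pytorch_pretrained_bert')))

# Because the constant is a Path, the example scripts can derive
# per-rank cache folders with the `/` operator:
rank_dir = PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_0'
```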

0 commit comments
