
Commit e1bfad4

Merge pull request #112 from huggingface/fourth-release
Fourth release
2 parents: 91aab2a + d821358

14 files changed (+234, -141 lines)

README.md

Lines changed: 48 additions & 19 deletions
@@ -19,7 +19,7 @@ This implementation is provided with [Google's pre-trained models](https://githu
 
 ## Installation
 
-This repo was tested on Python 3.6+ and PyTorch 0.4.1
+This repo was tested on Python 3.5+ and PyTorch 0.4.1/1.0.0
 
 ### With pip
 
@@ -46,13 +46,13 @@ python -m pytest -sv tests/
 
 This package comprises the following classes that can be imported in Python and are detailed in the [Doc](#doc) section of this readme:
 
-- Seven PyTorch models (`torch.nn.Module`) for Bert with pre-trained weights (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file):
+- Eight PyTorch models (`torch.nn.Module`) for Bert with pre-trained weights (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file):
   - [`BertModel`](./pytorch_pretrained_bert/modeling.py#L537) - raw BERT Transformer model (**fully pre-trained**),
   - [`BertForMaskedLM`](./pytorch_pretrained_bert/modeling.py#L691) - BERT Transformer with the pre-trained masked language modeling head on top (**fully pre-trained**),
   - [`BertForNextSentencePrediction`](./pytorch_pretrained_bert/modeling.py#L752) - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**),
   - [`BertForPreTraining`](./pytorch_pretrained_bert/modeling.py#L620) - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
   - [`BertForSequenceClassification`](./pytorch_pretrained_bert/modeling.py#L814) - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
-  - [`BertForMultipleChoice`](./pytorch_pretrained_bert/modeling.py#L880) - BERT Transformer with a multiple choice head on top (used for task like Swag) (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
+  - [`BertForMultipleChoice`](./pytorch_pretrained_bert/modeling.py#L880) - BERT Transformer with a multiple choice head on top (used for task like Swag) (BERT Transformer is **pre-trained**, the multiple choice classification head **is only initialized and has to be trained**),
   - [`BertForTokenClassification`](./pytorch_pretrained_bert/modeling.py#L949) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**),
   - [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L1015) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).
 
@@ -156,7 +156,7 @@ Here is a detailed documentation of the classes in the package and how to use th
 | Sub-section | Description |
 |-|-|
 | [Loading Google AI's pre-trained weigths](#Loading-Google-AIs-pre-trained-weigths-and-PyTorch-dump) | How to load Google AI's pre-trained weight or a PyTorch saved instance |
-| [PyTorch models](#PyTorch-models) | API of the seven PyTorch model classes: `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification` or `BertForQuestionAnswering` |
+| [PyTorch models](#PyTorch-models) | API of the eight PyTorch model classes: `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification`, `BertForMultipleChoice` or `BertForQuestionAnswering` |
 | [Tokenizer: `BertTokenizer`](#Tokenizer-BertTokenizer) | API of the `BertTokenizer` class|
 | [Optimizer: `BertAdam`](#Optimizer-BertAdam) | API of the `BertAdam` class |
 
@@ -170,7 +170,7 @@ model = BERT_CLASS.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None)
 
 where
 
-- `BERT_CLASS` is either the `BertTokenizer` class (to load the vocabulary) or one of the seven PyTorch model classes (to load the pre-trained weights): `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification`, `BertForTokenClassification` or `BertForQuestionAnswering`, and
+- `BERT_CLASS` is either the `BertTokenizer` class (to load the vocabulary) or one of the eight PyTorch model classes (to load the pre-trained weights): `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification`, `BertForTokenClassification`, `BertForMultipleChoice` or `BertForQuestionAnswering`, and
 - `PRE_TRAINED_MODEL_NAME_OR_PATH` is either:
 
   - the shortcut name of a Google AI's pre-trained model selected in the list:
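
For reference, the loading call documented in this hunk can be exercised end to end with a minimal sketch like the one below. It assumes the `bert-base-uncased` shortcut name from the README's model list and the package-level imports exposed by `pytorch_pretrained_bert`; weights and vocabulary are downloaded and cached on first use.

```python
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel

# Load the vocabulary and the pre-trained weights via a shortcut name
# (cached locally after the first download).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
model.eval()

# Tokenize a sentence and run it through the raw Transformer.
tokens = ['[CLS]'] + tokenizer.tokenize("Hello, how are you?") + ['[SEP]']
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])
encoded_layers, pooled_output = model(input_ids)  # per-layer hidden states + pooled [CLS] output
```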
@@ -353,14 +353,13 @@ The optimizer accepts the following arguments:
 
 BERT-base and BERT-large are respectively 110M and 340M parameters models and it can be difficult to fine-tune them on a single GPU with the recommended batch size for good performance (in most case a batch size of 32).
 
-To help with fine-tuning these models, we have included five techniques that you can activate in the fine-tuning scripts [`run_classifier.py`](./examples/run_classifier.py) and [`run_squad.py`](./examples/run_squad.py): gradient-accumulation, multi-gpu training, distributed training, optimize on CPU and 16-bits training . For more details on how to use these techniques you can read [the tips on training large batches in PyTorch](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) that I published earlier this month.
+To help with fine-tuning these models, we have included several techniques that you can activate in the fine-tuning scripts [`run_classifier.py`](./examples/run_classifier.py) and [`run_squad.py`](./examples/run_squad.py): gradient-accumulation, multi-gpu training, distributed training and 16-bits training . For more details on how to use these techniques you can read [the tips on training large batches in PyTorch](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) that I published earlier this month.
 
 Here is how to use these techniques in our scripts:
 
 - **Gradient Accumulation**: Gradient accumulation can be used by supplying a integer greater than 1 to the `--gradient_accumulation_steps` argument. The batch at each step will be divided by this integer and gradient will be accumulated over `gradient_accumulation_steps` steps.
 - **Multi-GPU**: Multi-GPU is automatically activated when several GPUs are detected and the batches are splitted over the GPUs.
 - **Distributed training**: Distributed training can be activated by supplying an integer greater or equal to 0 to the `--local_rank` argument (see below).
-- **Optimize on CPU**: The Adam optimizer stores 2 moving average of the weights of the model. If you keep them on GPU 1 (typical behavior), your first GPU will have to store 3-times the size of the model. This is not optimal for large models like `BERT-large` and means your batch size is a lot lower than it could be. This option will perform the optimization and store the averages on the CPU/RAM to free more room on the GPU(s). As the most computational intensive operation is usually the backward pass, this doesn't have a significant impact on the training time. Activate this option with `--optimize_on_cpu` on the [`run_squad.py`](./examples/run_squad.py) script.
 - **16-bits training**: 16-bits training, also called mixed-precision training, can reduce the memory requirement of your model on the GPU by using half-precision training, basically allowing to double the batch size. If you have a recent GPU (starting from NVIDIA Volta architecture) you should see no decrease in speed. A good introduction to Mixed precision training can be found [here](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) and a full documentation is [here](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html). In our scripts, this option can be activated by setting the `--fp16` flag and you can play with loss scaling using the `--loss_scaling` flag (see the previously linked documentation for details on loss scaling). If the loss scaling is too high (`Nan` in the gradients) it will be automatically scaled down until the value is acceptable. The default loss scaling is 128 which behaved nicely in our tests.
 
 Note: To use *Distributed Training*, you will need to run one training script on each of your machines. This can be done for example by running the following command on each server (see [the above mentioned blog post]((https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255)) for more details):
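
The gradient-accumulation option listed above amounts to scaling each partial loss by `gradient_accumulation_steps` and deferring the optimizer step. A minimal, self-contained PyTorch sketch of that pattern (toy model and random data, not taken from `run_classifier.py`):

```python
import torch
from torch import nn

# Toy stand-ins for the real model and dataloader (illustrative only).
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
gradient_accumulation_steps = 4  # plays the role of --gradient_accumulation_steps

# Eight mini-batches of 8 examples; every 4 of them form one effective batch of 32.
batches = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(8)]

for step, (inputs, labels) in enumerate(batches):
    loss = nn.functional.cross_entropy(model(inputs), labels)
    # Scale the loss so the accumulated gradient matches a full-batch update.
    (loss / gradient_accumulation_steps).backward()
    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```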
@@ -371,16 +370,21 @@ Where `$THIS_MACHINE_INDEX` is an sequential index assigned to each of your mach
 
 ### Fine-tuning with BERT: running the examples
 
-We showcase the same examples as [the original implementation](https://github.com/google-research/bert/): fine-tuning a sequence-level classifier on the MRPC classification corpus and a token-level classifier on the question answering dataset SQuAD.
+We showcase several fine-tuning examples based on (and extended from) [the original implementation](https://github.com/google-research/bert/):
 
-Before running these examples you should download the
+- a *sequence-level classifier* on the MRPC classification corpus,
+- a *token-level classifier* on the question answering dataset SQuAD, and
+- a *sequence-level multiple-choice classifier* on the SWAG classification corpus.
+
+#### MRPC
+
+This example code fine-tunes BERT on the Microsoft Research Paraphrase
+Corpus (MRPC) corpus and runs in less than 10 minutes on a single K-80 and in 27 seconds (!) on single tesla V100 16GB with apex installed.
+
+Before running this example you should download the
 [GLUE data](https://gluebenchmark.com/tasks) by running
 [this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
-and unpack it to some directory `$GLUE_DIR`. Please also download the `BERT-Base`
-checkpoint, unzip it to some directory `$BERT_BASE_DIR`, and convert it to its PyTorch version as explained in the previous section.
-
-This example code fine-tunes `BERT-Base` on the Microsoft Research Paraphrase
-Corpus (MRPC) corpus and runs in less than 10 minutes on a single K-80.
+and unpack it to some directory `$GLUE_DIR`.
 
 ```shell
 export GLUE_DIR=/path/to/glue
@@ -401,7 +405,29 @@ python run_classifier.py \
 
 Our test ran on a few seeds with [the original implementation hyper-parameters](https://github.com/google-research/bert#sentence-and-sentence-pair-classification-tasks) gave evaluation results between 84% and 88%.
 
-The second example fine-tunes `BERT-Base` on the SQuAD question answering task.
+**Fast run with apex and 16 bit precision: fine-tuning on MRPC in 27 seconds!**
+First install apex as indicated [here](https://github.com/NVIDIA/apex).
+Then run
+```shell
+export GLUE_DIR=/path/to/glue
+
+python run_classifier.py \
+  --task_name MRPC \
+  --do_train \
+  --do_eval \
+  --do_lower_case \
+  --data_dir $GLUE_DIR/MRPC/ \
+  --bert_model bert-base-uncased \
+  --max_seq_length 128 \
+  --train_batch_size 32 \
+  --learning_rate 2e-5 \
+  --num_train_epochs 3.0 \
+  --output_dir /tmp/mrpc_output/
+```
+
+#### SQuAD
+
+This example code fine-tunes BERT on the SQuAD dataset. It runs in 24 min (with BERT-base) or 68 min (with BERT-large) on a single tesla V100 16GB.
 
 The data for SQuAD can be downloaded with the following links and should be saved in a `$SQUAD_DIR` directory.
 
@@ -432,25 +458,28 @@ Training with the previous hyper-parameters gave us the following results:
 {"f1": 88.52381567990474, "exact_match": 81.22043519394512}
 ```
 
-The data for Swag can be downloaded by cloning the following [repository](https://github.com/rowanz/swagaf)
+#### SWAG
+
+The data for SWAG can be downloaded by cloning the following [repository](https://github.com/rowanz/swagaf)
 
 ```shell
 export SWAG_DIR=/path/to/SWAG
 
 python run_swag.py \
   --bert_model bert-base-uncased \
   --do_train \
+  --do_lower_case \
   --do_eval \
-  --data_dir $SWAG_DIR/data
+  --data_dir $SWAG_DIR/data \
   --train_batch_size 16 \
   --learning_rate 2e-5 \
   --num_train_epochs 3.0 \
   --max_seq_length 80 \
-  --output_dir /tmp/swag_output/
+  --output_dir /tmp/swag_output/ \
   --gradient_accumulation_steps 4
 ```
 
-Training with the previous hyper-parameters gave us the following results:
+Training with the previous hyper-parameters on a single GPU gave us the following results:
 ```
 eval_accuracy = 0.8062081375587323
 eval_loss = 0.5966546792367169

docker/Dockerfile

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+FROM pytorch/pytorch:latest
+
+RUN git clone https://github.com/NVIDIA/apex.git && cd apex && python setup.py install --cuda_ext --cpp_ext
+
+RUN pip install pytorch-pretrained-bert
+
+WORKDIR /workspace

examples/extract_features.py

Lines changed: 1 addition & 1 deletion
@@ -168,7 +168,7 @@ def read_examples(input_file):
     """Read a list of `InputExample`s from an input file."""
     examples = []
     unique_id = 0
-    with open(input_file, "r") as reader:
+    with open(input_file, "r", encoding='utf-8') as reader:
         while True:
             line = reader.readline()
             if not line:

examples/run_classifier.py

Lines changed: 25 additions & 10 deletions
@@ -36,13 +36,6 @@
 from pytorch_pretrained_bert.optimization import BertAdam
 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
 
-try:
-    from apex.optimizers import FP16_Optimizer
-    from apex.optimizers import FusedAdam
-    from apex.parallel import DistributedDataParallel as DDP
-except ImportError:
-    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this.")
-
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt = '%m/%d/%Y %H:%M:%S',
                     level = logging.INFO)
@@ -98,7 +91,7 @@ def get_labels(self):
     @classmethod
     def _read_tsv(cls, input_file, quotechar=None):
         """Reads a tab separated value file."""
-        with open(input_file, "r") as f:
+        with open(input_file, "r", encoding='utf-8') as f:
             reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
             lines = []
             for line in reader:
@@ -329,7 +322,7 @@ def main():
                         default=None,
                         type=str,
                         required=True,
-                        help="The output directory where the model checkpoints will be written.")
+                        help="The output directory where the model predictions and checkpoints will be written.")
 
     ## Other parameters
     parser.add_argument("--max_seq_length",
@@ -420,7 +413,8 @@ def main():
         n_gpu = 1
         # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
         torch.distributed.init_process_group(backend='nccl')
-    logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))
+    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
+        device, n_gpu, bool(args.local_rank != -1), args.fp16))
 
     if args.gradient_accumulation_steps < 1:
         raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
@@ -467,6 +461,11 @@ def main():
         model.half()
     model.to(device)
     if args.local_rank != -1:
+        try:
+            from apex.parallel import DistributedDataParallel as DDP
+        except ImportError:
+            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
+
         model = DDP(model)
     elif n_gpu > 1:
         model = torch.nn.DataParallel(model)
@@ -482,6 +481,12 @@ def main():
     if args.local_rank != -1:
         t_total = t_total // torch.distributed.get_world_size()
     if args.fp16:
+        try:
+            from apex.optimizers import FP16_Optimizer
+            from apex.optimizers import FusedAdam
+        except ImportError:
+            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
+
         optimizer = FusedAdam(optimizer_grouped_parameters,
                               lr=args.learning_rate,
                               bias_correction=False,
@@ -546,6 +551,16 @@ def main():
                 optimizer.zero_grad()
                 global_step += 1
 
+    # Save a trained model
+    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
+    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
+    torch.save(model_to_save.state_dict(), output_model_file)
+
+    # Load a trained model that you have fine-tuned
+    model_state_dict = torch.load(output_model_file)
+    model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict)
+    model.to(device)
+
     if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         eval_examples = processor.get_dev_examples(args.data_dir)
         eval_features = convert_examples_to_features(
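
The block added above writes the fine-tuned weights to `pytorch_model.bin` in `--output_dir` and immediately reloads them. The same checkpoint can also be reloaded later, outside the training script; a short sketch (the output path is hypothetical and the number of labels is left at the binary default used for MRPC):

```python
import torch
from pytorch_pretrained_bert import BertForSequenceClassification

# Hypothetical path produced by an earlier `run_classifier.py` run (--output_dir /tmp/mrpc_output/).
output_model_file = "/tmp/mrpc_output/pytorch_model.bin"

# Rebuild the architecture from the base pre-trained model, then overwrite its
# weights with the fine-tuned state dict, mirroring the loading code in the diff.
model_state_dict = torch.load(output_model_file)
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", state_dict=model_state_dict)
model.eval()
```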
