
Bug: can not use pretrained BERT on multiple GPUs with DataParallel (PyTorch 1.5.0) #4189

@erikchwang


Python: 3.6.10
PyTorch: 1.5.0
Transformers: 2.8.0 and 2.9.0

In the following code, I wrap a pretrained BERT model in torch.nn.DataParallel so that it runs on multiple GPUs:

import torch, transformers
model = transformers.AutoModel.from_pretrained("bert-base-multilingual-cased")
model = torch.nn.DataParallel(model)
model = model.cuda()
input = torch.ones([16, 10], dtype=torch.long)
input = input.cuda()
model(input)

Calling the wrapped model raises the following error:

Traceback (most recent call last):
File "", line 1, in
File "/home/anaconda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
result = self.forward(*input, **kwargs)
File "/home/anaconda/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 155, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/home/anaconda/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 165, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/home/anaconda/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
output.reraise()
File "/home/anaconda/lib/python3.6/site-packages/torch/_utils.py", line 395, in reraise
raise self.exc_type(msg)
StopIteration: Caught StopIteration in replica 0 on device 0.
Original Traceback (most recent call last):
File "/home/anaconda/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
output = module(*input, **kwargs)
File "/home/anaconda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
result = self.forward(*input, **kwargs)
File "/home/anaconda/lib/python3.6/site-packages/transformers/modeling_bert.py", line 734, in forward
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
StopIteration

The same code runs without error if I remove the DataParallel wrapper.
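
The traceback ends at `next(self.parameters())`, which suggests that `parameters()` yields nothing inside the replica, so `next()` raises StopIteration. The sketch below reproduces that failure mode without a GPU and shows one possible guard. `StubModule` and `first_param_dtype` are hypothetical names used only for illustration; they are not part of PyTorch or Transformers, and the dtypes are plain strings standing in for `torch.dtype` values.

```python
from types import SimpleNamespace


def first_param_dtype(module, default):
    """Return the dtype of the module's first parameter, or `default` if it has none."""
    # next() with a second argument returns that value instead of raising StopIteration.
    first = next(iter(module.parameters()), None)
    return default if first is None else first.dtype


class StubModule:
    """Hypothetical stand-in for a module; only parameters() is modelled."""

    def __init__(self, params):
        self._params = list(params)

    def parameters(self):
        return iter(self._params)


# Assumption: a replica whose parameters() iterator is empty, as the traceback implies.
empty_replica = StubModule([])
normal_module = StubModule([SimpleNamespace(dtype="float16")])

# The unguarded pattern from the traceback fails on the empty replica.
try:
    next(empty_replica.parameters())
    raised = False
except StopIteration:
    raised = True

# The guarded lookup falls back to the default instead of raising.
fallback_dtype = first_param_dtype(empty_replica, default="float32")
regular_dtype = first_param_dtype(normal_module, default="float32")
```

A guard like this only avoids the exception. Whether falling back to a default dtype is correct for fp16 runs would need to be checked against the real model.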
