Description
Python: 3.6.10
PyTorch: 1.5.0
Transformers: 2.8.0 and 2.9.0
In the following code, I wrap the pretrained BERT model with a DataParallel wrapper so that it runs on multiple GPUs:
import torch, transformers
model = transformers.AutoModel.from_pretrained("bert-base-multilingual-cased")
model = torch.nn.DataParallel(model)
model = model.cuda()
input = torch.ones([16, 10], dtype=torch.long)
input = input.cuda()
model(input)
But I got the following error:
Traceback (most recent call last):
File "", line 1, in
File "/home/anaconda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in call
result = self.forward(*input, **kwargs)
File "/home/anaconda/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 155, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/home/anaconda/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 165, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/home/anaconda/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
output.reraise()
File "/home/anaconda/lib/python3.6/site-packages/torch/_utils.py", line 395, in reraise
raise self.exc_type(msg)
StopIteration: Caught StopIteration in replica 0 on device 0.
Original Traceback (most recent call last):
File "/home/anaconda/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
output = module(*input, **kwargs)
File "/home/anaconda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in call
result = self.forward(*input, **kwargs)
File "/home/anaconda/lib/python3.6/site-packages/transformers/modeling_bert.py", line 734, in forward
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
StopIteration
However, the same code works if I remove the DataParallel wrapper.
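For reference, this is a minimal sketch of the single-GPU variant that runs without the error (assuming a CUDA device is available; the model name and dummy input are the same as above):

import torch, transformers

# Same pretrained model, but without the DataParallel wrapper
model = transformers.AutoModel.from_pretrained("bert-base-multilingual-cased")
model = model.cuda()

# Dummy batch of 16 sequences, 10 tokens each
input = torch.ones([16, 10], dtype=torch.long).cuda()

with torch.no_grad():
    outputs = model(input)  # forward pass completes on a single GPU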