-
Notifications
You must be signed in to change notification settings - Fork 33.5k
Enable PyTorch/XLA Fully Sharded Data Parallel (FSDP) for a Specific Class of Transformer Models #20774
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable PyTorch/XLA Fully Sharded Data Parallel (FSDP) for a Specific Class of Transformer Models #20774
Changes from all commits
362ba54
f232e2b
46fa159
89aaf58
68e338c
04da398
54eab8b
0a0b8da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -171,7 +171,7 @@ | |
| import torch_xla.core.xla_model as xm | ||
| import torch_xla.debug.metrics as met | ||
| import torch_xla.distributed.parallel_loader as pl | ||
|
|
||
| if is_fairscale_available(): | ||
| dep_version_check("fairscale") | ||
| import fairscale | ||
|
|
@@ -417,6 +417,42 @@ def __init__( | |
| elif FSDPOption.NO_SHARD in args.fsdp: | ||
| self.fsdp = ShardingStrategy.NO_SHARD | ||
|
|
||
| if args.xla_fsdp: | ||
| try: | ||
| from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP, checkpoint_module | ||
| except ImportError: | ||
| assert False, "Missing module XLAFullyShardedDataParallel; this module is available in torch-xla >= 1.12.0." | ||
| fsdp_kwargs = args.xla_fsdp_config | ||
| fsdp_wrap = lambda m: FSDP(m, **fsdp_kwargs) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's not refactor in a lambda this for 2 uses, especially since it's so simple. Let's call FSDP with the kwargs in the 2 instances below instead. |
||
| # A wrapper for gradient checkpointing | ||
| grad_ckpt_wrap = checkpoint_module if args.xla_fsdp_grad_ckpt else (lambda m: m) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, we prefer explicit tests even if it takes 2 more lines of code. |
||
| if args.xla_fsdp_nested: | ||
| if not(hasattr(model, 'transformer') and hasattr(model.transformer, 'h')): | ||
| raise ValueError( | ||
| "Nested XLA FSDP is currently only supported for models which expose their" | ||
| " transformer blocks through `transformer.h`." | ||
| ) | ||
|
Comment on lines
+431
to
+434
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's just GPT/GPT-2. This is too restrictive to be merged like this in the Trainer. |
||
| # Wrap each transformer block | ||
| for i in range(len(model.transformer.h)): | ||
| model.transformer.h[i] = fsdp_wrap(grad_ckpt_wrap(model.transformer.h[i])) | ||
| # Wrap the base model with an outer FSDP wrapper | ||
| # Also, copy the signature of the original model's forward method -- otherwise | ||
| # columns not appearing in the forward method's argument will be dropped by | ||
| # the `_remove_unused_columns` method | ||
| forward_signature = inspect.signature(model.forward.__func__) | ||
| model = fsdp_wrap(model) | ||
| model.forward.__func__.__signature__ = forward_signature | ||
|
|
||
| # Patch `xm.optimizer_step` should not reduce gradients in this case, | ||
| # as FSDP does not need gradient reduction over sharded parameters. | ||
| def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}): | ||
| loss = optimizer.step(**optimizer_args) | ||
| if barrier: | ||
| xm.mark_step() | ||
| return loss | ||
|
|
||
| xm.optimizer_step = patched_optimizer_step | ||
|
|
||
| # one place to sort out whether to place the model on device or not | ||
| # postpone switching model to cuda when: | ||
| # 1. MP - since we are trying to fit a much bigger than 1 gpu model | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please don't use assert in code of the library, we prefer explicit exceptions.