From 9708f36d669c0a065a848664aedd96352df4010a Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Mon, 5 May 2025 19:45:18 -0700 Subject: [PATCH 1/2] FSDP2 example Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: f3eeff8c67a43405f18d4271809c8385da5e9b70 Pull Request resolved: https://github.com/pytorch/examples/pull/1339 --- distributed/FSDP2/README.md | 17 ++++++++++++++ distributed/FSDP2/train.py | 45 +++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) create mode 100644 distributed/FSDP2/README.md create mode 100644 distributed/FSDP2/train.py diff --git a/distributed/FSDP2/README.md b/distributed/FSDP2/README.md new file mode 100644 index 0000000000..1f85acb469 --- /dev/null +++ b/distributed/FSDP2/README.md @@ -0,0 +1,17 @@ +## FSDP2 + +To run FSDP2 on a transformer model: + +## Install the requirements: +~~~ +pip install -r requirements.txt +~~~ + +## Ensure you are running a recent version of PyTorch: +See https://pytorch.org/get-started/locally/ to install PyTorch 2.5 or later, ideally a current nightly build. 
+ +Start the training with `torchrun` (adjust nproc_per_node to your GPU count): + +``` +torchrun --nnodes 1 --nproc_per_node 2 train.py +``` diff --git a/distributed/FSDP2/train.py b/distributed/FSDP2/train.py new file mode 100644 index 0000000000..3b38f214d6 --- /dev/null +++ b/distributed/FSDP2/train.py @@ -0,0 +1,45 @@ +import os +import argparse +import torch +from torch.distributed.fsdp import fully_shard +from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer + + +def main(args): + torch.distributed.init_process_group(backend="nccl") + rank = int(os.environ["LOCAL_RANK"]) + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + torch.manual_seed(rank) + vocab_size = 1024 + model_args = ModelArgs( + n_layers=3, + n_heads=4, + vocab_size=vocab_size, + max_seq_len=64, + dropout_p=0, + ) + model = Transformer(model_args) + for layer in model.layers: + fully_shard(layer) + fully_shard(model) + optim = torch.optim.Adam(model.parameters(), lr=1e-2) + for _ in range(10): + x = torch.randint(0, vocab_size, (32, 32), device=device) + loss = model(x).sum() + loss.backward() + optim.step() + optim.zero_grad() + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='PyTorch FSDP2 example') + parser.add_argument('--meta-init', type=int, default=4, metavar='N', + help='input batch size for training (default: 64)') + parser.add_argument('--epochs', type=int, default=2, metavar='N', + help='number of epochs to train (default: 3)') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + args = parser.parse_args() + main(args) From 48eb64b99a0838616f699e78f55ebc7676d71efe Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Mon, 5 May 2025 19:51:11 -0700 Subject: [PATCH 2/2] args ghstack-source-id: 73725dcf539537e2a3e341370c42290dd0fc8792 Pull Request resolved: 
https://github.com/pytorch/examples/pull/1340 --- distributed/FSDP2/train.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/distributed/FSDP2/train.py b/distributed/FSDP2/train.py index 3b38f214d6..f9c123baad 100644 --- a/distributed/FSDP2/train.py +++ b/distributed/FSDP2/train.py @@ -10,7 +10,7 @@ def main(args): rank = int(os.environ["LOCAL_RANK"]) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - torch.manual_seed(rank) + torch.manual_seed(args.seed) vocab_size = 1024 model_args = ModelArgs( n_layers=3, @@ -35,10 +35,6 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser(description='PyTorch FSDP2 example') - parser.add_argument('--meta-init', type=int, default=4, metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument('--epochs', type=int, default=2, metavar='N', - help='number of epochs to train (default: 3)') parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') args = parser.parse_args()