Skip to content
38 changes: 37 additions & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl

if is_fairscale_available():
dep_version_check("fairscale")
import fairscale
Expand Down Expand Up @@ -417,6 +417,42 @@ def __init__(
elif FSDPOption.NO_SHARD in args.fsdp:
self.fsdp = ShardingStrategy.NO_SHARD

if args.xla_fsdp:
try:
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP, checkpoint_module
except ImportError:
assert False, "Missing module XLAFullyShardedDataParallel; this module is available in torch-xla >= 1.12.0."

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please don't use assert in code of the library, we prefer explicit exceptions.

fsdp_kwargs = args.xla_fsdp_config
fsdp_wrap = lambda m: FSDP(m, **fsdp_kwargs)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not refactor in a lambda this for 2 uses, especially since it's so simple. Let's call FSDP with the kwargs in the 2 instances below instead.

# A wrapper for gradient checkpointing
grad_ckpt_wrap = checkpoint_module if args.xla_fsdp_grad_ckpt else (lambda m: m)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, we prefer explicit tests even if it takes 2 more lines of code.

if args.xla_fsdp_nested:
if not(hasattr(model, 'transformer') and hasattr(model.transformer, 'h')):
raise ValueError(
"Nested XLA FSDP is currently only supported for models which expose their"
" transformer blocks through `transformer.h`."
)
Comment on lines +431 to +434

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's just GPT/GPT-2. This is too restrictive to be merged like this in the Trainer.

# Wrap each transformer block
for i in range(len(model.transformer.h)):
model.transformer.h[i] = fsdp_wrap(grad_ckpt_wrap(model.transformer.h[i]))
# Wrap the base model with an outer FSDP wrapper
# Also, copy the signature of the original model's forward method -- otherwise
# columns not appearing in the forward method's argument will be dropped by
# the `_remove_unused_columns` method
forward_signature = inspect.signature(model.forward.__func__)
model = fsdp_wrap(model)
model.forward.__func__.__signature__ = forward_signature

# Patch `xm.optimizer_step` should not reduce gradients in this case,
# as FSDP does not need gradient reduction over sharded parameters.
def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
loss = optimizer.step(**optimizer_args)
if barrier:
xm.mark_step()
return loss

xm.optimizer_step = patched_optimizer_step

# one place to sort out whether to place the model on device or not
# postpone switching model to cuda when:
# 1. MP - since we are trying to fit a much bigger than 1 gpu model
Expand Down
78 changes: 78 additions & 0 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,25 @@ class TrainingArguments:
The mode to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`.

Possible choices are `"default"`, `"reduce-overhead"` and `"max-autotune"`.

xla_fsdp (`str`, `dict`, *optional*):
Use PyTorch/XLA Fully Sharded Data Parallel Training

For a complete list of options, please see [here](
https://github.com/pytorch/xla/blob/master/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py).

This is an experimental feature and its API may evolve in the future. The value is the location of config
file (e.g., `fsdp_config.json`).
xla_fsdp_nested (`bool`, *optional*, defaults to `False`):
Will use nested XLA FSDP to shard each transformer block layer. This setting can only be used with
xla_fsdp. Currently, only models which expose their their transformers block through the class attribute
`transformer.h` may use this feature.
xla_fsdp_grad_ckpt (`bool`, *optional*, defaults to `False`):
Will use gradient checkpointing over each XLA FSDP wrapped layer. This setting can only be used with
xla_fsdp and xla_fsdp_nested.



"""

framework = "pt"
Expand Down Expand Up @@ -953,6 +972,35 @@ class TrainingArguments:
include_inputs_for_metrics: bool = field(
default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."}
)
xla_fsdp: Optional[str] = field(
default=None,
metadata={
"help": (
"Whether or not to use PyTorch/XLA Fully Sharded Data Parallel (FSDP) training. For a complete list"
" of configuration options, please see the PyTorch/XLA FSDP definitions."
),
},
)
xla_fsdp_nested: Optional[bool] = field(
default=False,
metadata={
"help": (
"Will use nested XLA FSDP to shard each transformer block layer. This setting can only be used with xla_fsdp."
" Currently, only models which expose their their transformers block through the class attribute `transformer.h`"
" may use this feature."
),
},
)
xla_fsdp_grad_ckpt: Optional[bool] = field(
default=False,
metadata={
"help": (
"Will use gradient checkpointing over each XLA FSDP wrapped layer. This setting can only be used with xla_fsdp"
" and xla_fsdp_nested."

),
},
)
# Deprecated arguments
fp16_backend: str = field(
default="auto",
Expand Down Expand Up @@ -1299,6 +1347,36 @@ def __post_init__(self):
self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed)
self.hf_deepspeed_config.trainer_config_process(self)

if self.xla_fsdp:
# gather fsdp configuration parameters into a dictionary from specified json file
with io.open(self.xla_fsdp, "r", encoding="utf-8") as f:
self.xla_fsdp = json.load(f)
# apply appropriate string to torch.dtype conversions for parameters
dtype_dict = {
"torch.float32": torch.float32,
"torch.float16" : torch.float16,
"torch.bfloat16" : torch.bfloat16,
}
if "compute_dtype" in self.xla_fsdp_config:
self.xla_fsdp_config["compute_dtype"] = dtype_dict[self.xla_fsdp_config["compute_dtype"]]
if "buffer_dtype" in self.xla_fsdp_config:
self.xla_fsdp_config["buffer_dtype"] = dtype_dict[self.xla_fsdp_config["buffer_dtype"]]
if self.xla_fsdp_grad_ckpt and not self.xla_fsdp_nested:
raise ValueError(
"`--xla_fsdp_grad_ckpt` may only be used when --xla_fsdp_nested is enabled."
)
else:
if self.xla_fsdp_nested:
raise ValueError(
"`--xla_fsdp_nested` may only be used when --xla_fsdp is enabled."
)
elif self.xla_fsdp_grad_ckpt:
raise ValueError(
"`--xla_fsdp_grad_ckpt` may only be used when --xla_fsdp is enabled."
)



if self.push_to_hub_token is not None:
warnings.warn(
"`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
Expand Down