Planner¶
The TorchRec Planner is responsible for determining the most performant, balanced sharding plan for distributed training and inference.
The main API for generating a sharding plan is EmbeddingShardingPlanner.plan.
- class torchrec.distributed.types.ShardingPlan(plan: Dict[str, ModuleShardingPlan])¶
Representation of sharding plan. This uses the FQN of the larger wrapped model (i.e., the model that is wrapped using DistributedModelParallel). EmbeddingModuleShardingPlan should be used when TorchRec composability is desired.
- plan¶
dict keyed by module path; each value is a dict of parameter sharding specs keyed by parameter name.
- Type:
Dict[str, EmbeddingModuleShardingPlan]
- get_plan_for_module(module_path: str) → Optional[ModuleShardingPlan]¶
- Parameters:
module_path (str) – path of the module whose sharding specs should be fetched.
- Returns:
dict of parameter sharding specs keyed by parameter name. None if sharding specs do not exist for given module_path.
- Return type:
Optional[ModuleShardingPlan]
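A minimal sketch of reading per-module sharding specs out of a plan, assuming `plan` was produced by EmbeddingShardingPlanner.plan (see the example under EmbeddingShardingPlanner below); the module path "ebc" is illustrative:

    # Each module plan maps parameter names to their ParameterSharding specs.
    module_plan = plan.get_plan_for_module("ebc")
    if module_plan is not None:
        for param_name, param_sharding in module_plan.items():
            print(param_name, param_sharding.sharding_type)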
- class torchrec.distributed.planner.planners.EmbeddingShardingPlanner(topology: Optional[Topology] = None, batch_size: Optional[int] = None, enumerator: Optional[Enumerator] = None, storage_reservation: Optional[StorageReservation] = None, proposer: Optional[Union[Proposer, List[Proposer]]] = None, partitioner: Optional[Partitioner] = None, performance_model: Optional[PerfModel] = None, stats: Optional[Union[Stats, List[Stats]]] = None, constraints: Optional[Dict[str, ParameterConstraints]] = None, debug: bool = True, callbacks: Optional[List[Callable[[List[ShardingOption]], List[ShardingOption]]]] = None, timeout_seconds: Optional[int] = None)¶
Provides an optimized sharding plan for a given module with shardable parameters according to the provided sharders, topology, and constraints.
- Parameters:
topology (Optional[Topology]) – the topology of the current process group.
batch_size (Optional[int]) – the batch size of the model.
enumerator (Optional[Enumerator]) – the enumerator to use.
storage_reservation (Optional[StorageReservation]) – the storage reservation to use.
proposer (Optional[Union[Proposer, List[Proposer]]]) – the proposer(s) to use.
partitioner (Optional[Partitioner]) – the partitioner to use.
performance_model (Optional[PerfModel]) – the performance model to use.
stats (Optional[Union[Stats, List[Stats]]]) – the stats to use.
constraints (Optional[Dict[str, ParameterConstraints]]) – per table constraints for sharding.
debug (bool) – whether to print debug information.
Example:
    ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta"))
    planner = EmbeddingShardingPlanner()
    plan = planner.plan(
        module=ebc,
        sharders=[EmbeddingBagCollectionSharder()],
    )
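As a sketch, the planner can also be constructed with an explicit topology and per-table constraints; the table name "table_0" and the parameter values below are illustrative, and `ebc` is the EmbeddingBagCollection from the example above:

    from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
    from torchrec.distributed.planner.types import ParameterConstraints

    topology = Topology(world_size=2, compute_device="cuda")
    constraints = {
        # table name -> constraints; restricts "table_0" to table-wise sharding
        "table_0": ParameterConstraints(sharding_types=["table_wise"]),
    }
    planner = EmbeddingShardingPlanner(
        topology=topology,
        batch_size=512,
        constraints=constraints,
    )
    plan = planner.plan(
        module=ebc,
        sharders=[EmbeddingBagCollectionSharder()],
    )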
- collective_plan(module: Module, sharders: Optional[List[ModuleSharder[Module]]] = None, pg: Optional[ProcessGroup] = None) → ShardingPlan¶
Calls self.plan(…) on rank 0 and broadcasts the resulting plan to all ranks in the process group.
- Parameters:
module (nn.Module) – the module to shard.
sharders (Optional[List[ModuleSharder[nn.Module]]]) – the sharders to use for sharding.
pg (Optional[dist.ProcessGroup]) – the process group to use for collective operations.
- Returns:
the sharding plan for the module.
- Return type:
ShardingPlan
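A minimal sketch of collective_plan, assuming torch.distributed has already been initialized on every rank and that `ebc` and `topology` are defined as in the examples above:

    import torch.distributed as dist

    planner = EmbeddingShardingPlanner(topology=topology, batch_size=512)
    plan = planner.collective_plan(
        module=ebc,
        sharders=[EmbeddingBagCollectionSharder()],
        pg=dist.GroupMember.WORLD,
    )
    # Every rank now holds the same plan, computed on rank 0.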
- plan(module: Module, sharders: List[ModuleSharder[Module]]) → ShardingPlan¶
Provides an optimized sharding plan for a given module with shardable parameters according to the provided sharders, topology, and constraints.
- Parameters:
module (nn.Module) – the module to shard.
sharders (List[ModuleSharder[nn.Module]]) – the sharders to use for sharding.
- Returns:
the sharding plan for the module.
- Return type:
ShardingPlan
- class torchrec.distributed.planner.enumerators.EmbeddingEnumerator(topology: Topology, batch_size: int, constraints: Optional[Dict[str, ParameterConstraints]] = None, estimator: Optional[Union[ShardEstimator, List[ShardEstimator]]] = None, use_exact_enumerate_order: Optional[bool] = False)¶
Generates embedding sharding options for given nn.Module, considering user provided constraints.
- Parameters:
topology (Topology) – device topology.
batch_size (int) – batch size.
constraints (Optional[Dict[str, ParameterConstraints]]) – dict of parameter names to provided ParameterConstraints.
estimator (Optional[Union[ShardEstimator, List[ShardEstimator]]]) – shard performance estimators.
use_exact_enumerate_order (bool) – whether to enumerate shardable parameters in the exact named_children enumeration order.
- enumerate(module: Module, sharders: List[ModuleSharder[Module]]) → List[ShardingOption]¶
Generates relevant sharding options given module and sharders.
- Parameters:
module (nn.Module) – module to be sharded.
sharders (List[ModuleSharder[nn.Module]]) – provided sharders for module.
- Returns:
valid sharding options with values populated.
- Return type:
List[ShardingOption]
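A sketch of invoking the enumerator directly (the planner normally does this internally); `ebc` is an EmbeddingBagCollection as in the planner example above:

    from torchrec.distributed.planner import Topology
    from torchrec.distributed.planner.enumerators import EmbeddingEnumerator

    topology = Topology(world_size=2, compute_device="cuda")
    enumerator = EmbeddingEnumerator(topology=topology, batch_size=512)
    sharding_options = enumerator.enumerate(
        module=ebc,
        sharders=[EmbeddingBagCollectionSharder()],
    )
    for option in sharding_options:
        # Each option pairs a table with a candidate sharding type and kernel.
        print(option.name, option.sharding_type, option.compute_kernel)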
- populate_estimates(sharding_options: List[ShardingOption]) → None¶
See class description.
- class torchrec.distributed.planner.partitioners.GreedyPerfPartitioner(sort_by: SortBy = SortBy.STORAGE, balance_modules: bool = False)¶
Greedy Partitioner.
- Parameters:
sort_by (SortBy) – Sort sharding options by storage or perf in descending order (i.e., large tables will be placed first).
balance_modules (bool) – Whether to sort by modules first, where smaller modules will be sorted first. In effect, this will place tables in each module in a balanced way.
- partition(proposal: List[ShardingOption], storage_constraint: Topology, hbm_per_device: Optional[int] = None) → List[ShardingOption]¶
Places sharding options on topology based on each sharding option’s partition_by attribute. The topology, storage, and perfs are updated at the end of the placement.
- Parameters:
proposal (List[ShardingOption]) – list of populated sharding options.
storage_constraint (Topology) – device topology.
- Returns:
list of sharding options for selected plan.
- Return type:
List[ShardingOption]
Example:
    sharding_options = [
        ShardingOption(partition_by="uniform",
                       shards=[
                           Shards(storage=1, perf=1),
                           Shards(storage=1, perf=1),
                       ]),
        ShardingOption(partition_by="uniform",
                       shards=[
                           Shards(storage=2, perf=2),
                           Shards(storage=2, perf=2),
                       ]),
        ShardingOption(partition_by="device",
                       shards=[
                           Shards(storage=3, perf=3),
                           Shards(storage=3, perf=3),
                       ]),
        ShardingOption(partition_by="device",
                       shards=[
                           Shards(storage=4, perf=4),
                           Shards(storage=4, perf=4),
                       ]),
    ]
    topology = Topology(world_size=2)

    # First, sharding_options[0] and sharding_options[1] will be placed on the
    # topology with the uniform strategy, resulting in:
    #
    #   topology.devices[0].perf.total = (1,2)
    #   topology.devices[1].perf.total = (1,2)
    #
    # Then sharding_options[2] and sharding_options[3] will be placed on the
    # topology with the device strategy (see docstring of `partition_by_device`
    # for more details), resulting in:
    #
    #   topology.devices[0].perf.total = (1,2) + (3,4)
    #   topology.devices[1].perf.total = (1,2) + (3,4)
    #
    # The topology updates are done after the end of all the placements (the
    # order in the example is just for clarity).
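As a sketch, a custom-configured partitioner can be handed to the planner; the constructor arguments shown are the ones documented above:

    from torchrec.distributed.planner.partitioners import (
        GreedyPerfPartitioner,
        SortBy,
    )

    # Sort candidate tables by perf instead of storage, and balance
    # placement across modules.
    partitioner = GreedyPerfPartitioner(
        sort_by=SortBy.PERF,
        balance_modules=True,
    )
    planner = EmbeddingShardingPlanner(partitioner=partitioner)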
- class torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation(percentage: float, parameter_multiplier: float = 6.0, dense_tensor_estimate: Optional[int] = None)¶
Reserves storage for the model to be sharded using a heuristic calculation. The storage reservation is comprised of dense tensor storage, KJT storage, and an extra percentage of total storage.
- Parameters:
percentage (float) – extra storage percent to reserve that acts as a margin of error beyond heuristic calculation of storage.
parameter_multiplier (float) – heuristic multiplier for total parameter storage.
dense_tensor_estimate (Optional[int]) – storage estimate for dense tensors, uses default heuristic estimate if not provided.
- property last_reserved_topology: Optional[Topology]¶
Cached value of the most recent output from the reserve() method.
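A sketch of configuring the reservation through the planner (the 15% margin is illustrative):

    from torchrec.distributed.planner.storage_reservations import (
        HeuristicalStorageReservation,
    )

    # Reserve an extra 15% of total storage as a margin of error on top
    # of the heuristic dense-tensor and KJT estimates.
    storage_reservation = HeuristicalStorageReservation(percentage=0.15)
    planner = EmbeddingShardingPlanner(storage_reservation=storage_reservation)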
- class torchrec.distributed.planner.proposers.GreedyProposer(use_depth: bool = True, threshold: Optional[int] = None)¶
Proposes sharding plans in greedy fashion.
Sorts sharding options for each shardable parameter by perf. On each iteration, finds the parameter with the largest current storage usage and tries its next sharding option.
- Parameters:
use_depth (bool) – When enabled, sharding_options of a fqn are sorted based on max(shard.perf.total), otherwise sharding_options are sorted by sum(shard.perf.total).
threshold (Optional[int]) – Threshold for early stopping. When specified, the proposer stops once it has produced that many consecutive proposals with a worse perf_rating than the best_perf_rating seen so far.
- feedback(partitionable: bool, plan: Optional[List[ShardingOption]] = None, perf_rating: Optional[float] = None, storage_constraint: Optional[Topology] = None) → None¶
Provides the outcome of the previous proposal back to the proposer so that the next proposal can be adjusted.
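A sketch of the propose/feedback loop that the planner drives internally; load() seeds the proposer from the enumerated search space (`sharding_options` as produced by the enumerator above), and the feedback values shown are illustrative:

    from torchrec.distributed.planner.proposers import GreedyProposer

    proposer = GreedyProposer(use_depth=True)
    proposer.load(search_space=sharding_options)

    proposal = proposer.propose()
    while proposal is not None:
        # In the real planner the partitioner attempts to place the proposal
        # and the perf model rates it; both results are fed back so the
        # proposer can adjust its next proposal.
        proposer.feedback(partitionable=True, plan=proposal)
        proposal = proposer.propose()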