Planner

The TorchRec Planner is responsible for determining the most performant, balanced sharding plan for distributed training and inference.

The main API for generating a sharding plan is EmbeddingShardingPlanner.plan.

class torchrec.distributed.types.ShardingPlan(plan: Dict[str, ModuleShardingPlan])

Representation of a sharding plan. Module paths in the plan use the FQNs of the larger wrapped model (i.e., the model that is wrapped using DistributedModelParallel). EmbeddingModuleShardingPlan should be used when TorchRec composability is desired.

plan

dict keyed by module path; each value is a dict of parameter sharding specs keyed by parameter name.

Type:

Dict[str, EmbeddingModuleShardingPlan]

get_plan_for_module(module_path: str) → Optional[ModuleShardingPlan]
Parameters:

module_path (str) – path of the module whose sharding specs to look up.

Returns:

dict of parameter sharding specs keyed by parameter name. None if sharding specs do not exist for given module_path.

Return type:

Optional[ModuleShardingPlan]
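
For an EmbeddingModuleShardingPlan, the returned value behaves like a dict of ParameterSharding keyed by parameter name. A minimal sketch of inspecting a plan, where "sparse.ebc" is a hypothetical module path that depends on how the wrapped model nests its modules:

# "sparse.ebc" is a hypothetical module path; actual paths depend on the model.
module_plan = plan.get_plan_for_module("sparse.ebc")
if module_plan is not None:
    for param_name, param_sharding in module_plan.items():
        print(param_name, param_sharding.sharding_type)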

class torchrec.distributed.planner.planners.EmbeddingShardingPlanner(topology: Optional[Topology] = None, batch_size: Optional[int] = None, enumerator: Optional[Enumerator] = None, storage_reservation: Optional[StorageReservation] = None, proposer: Optional[Union[Proposer, List[Proposer]]] = None, partitioner: Optional[Partitioner] = None, performance_model: Optional[PerfModel] = None, stats: Optional[Union[Stats, List[Stats]]] = None, constraints: Optional[Dict[str, ParameterConstraints]] = None, debug: bool = True, callbacks: Optional[List[Callable[[List[ShardingOption]], List[ShardingOption]]]] = None, timeout_seconds: Optional[int] = None)

Provides an optimized sharding plan for a given module with shardable parameters according to the provided sharders, topology, and constraints.

Parameters:
  • topology (Optional[Topology]) – the topology of the current process group.

  • batch_size (Optional[int]) – the batch size of the model.

  • enumerator (Optional[Enumerator]) – the enumerator to use.

  • storage_reservation (Optional[StorageReservation]) – the storage reservation to use.

  • proposer (Optional[Union[Proposer, List[Proposer]]]) – the proposer(s) to use.

  • partitioner (Optional[Partitioner]) – the partitioner to use.

  • performance_model (Optional[PerfModel]) – the performance model to use.

  • stats (Optional[Union[Stats, List[Stats]]]) – the stats to use.

  • constraints (Optional[Dict[str, ParameterConstraints]]) – per table constraints for sharding.

  • debug (bool) – whether to print debug information.

Example:

import torch
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner.planners import EmbeddingShardingPlanner
from torchrec.modules.embedding_modules import EmbeddingBagCollection

# eb_configs is a list of EmbeddingBagConfig defined elsewhere.
ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta"))
planner = EmbeddingShardingPlanner()
plan = planner.plan(
    module=ebc,
    sharders=[EmbeddingBagCollectionSharder()],
)
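
The resulting plan can then be passed to DistributedModelParallel when wrapping the model. A minimal sketch, assuming torch.distributed has already been initialized and a GPU is available:

from torchrec.distributed.model_parallel import DistributedModelParallel

# Sketch: apply the generated plan while wrapping the module for distributed execution.
model = DistributedModelParallel(
    module=ebc,
    device=torch.device("cuda"),
    plan=plan,
    sharders=[EmbeddingBagCollectionSharder()],
)
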
collective_plan(module: Module, sharders: Optional[List[ModuleSharder[Module]]] = None, pg: Optional[ProcessGroup] = None) → ShardingPlan

Calls self.plan(…) on rank 0 and broadcasts the resulting plan to all ranks in the process group.

Parameters:
  • module (nn.Module) – the module to shard.

  • sharders (Optional[List[ModuleSharder[nn.Module]]]) – the sharders to use for sharding.

  • pg (Optional[dist.ProcessGroup]) – the process group to use for collective operations.

Returns:

the sharding plan for the module.

Return type:

ShardingPlan
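
A sketch of using collective_plan so that rank 0 computes the plan and every rank receives an identical copy; this assumes the default process group has already been initialized:

import torch.distributed as dist

# Rank 0 computes the plan; the result is broadcast to every rank in the group.
plan = planner.collective_plan(
    module=ebc,
    sharders=[EmbeddingBagCollectionSharder()],
    pg=dist.group.WORLD,
)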

plan(module: Module, sharders: List[ModuleSharder[Module]]) → ShardingPlan

Provides an optimized sharding plan for a given module with shardable parameters according to the provided sharders, topology, and constraints.

Parameters:
  • module (nn.Module) – the module to shard.

  • sharders (List[ModuleSharder[nn.Module]]) – the sharders to use for sharding.

Returns:

the sharding plan for the module.

Return type:

ShardingPlan
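
Per-table ParameterConstraints can be used to restrict which sharding types the planner considers. A minimal sketch, where the table name "product_table" is hypothetical:

from torchrec.distributed.planner.types import ParameterConstraints

# "product_table" is a hypothetical embedding table name.
constraints = {
    "product_table": ParameterConstraints(
        sharding_types=["table_wise", "row_wise"],
    ),
}
planner = EmbeddingShardingPlanner(constraints=constraints)
plan = planner.plan(module=ebc, sharders=[EmbeddingBagCollectionSharder()])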

class torchrec.distributed.planner.enumerators.EmbeddingEnumerator(topology: Topology, batch_size: int, constraints: Optional[Dict[str, ParameterConstraints]] = None, estimator: Optional[Union[ShardEstimator, List[ShardEstimator]]] = None, use_exact_enumerate_order: Optional[bool] = False)

Generates embedding sharding options for given nn.Module, considering user provided constraints.

Parameters:
  • topology (Topology) – device topology.

  • batch_size (int) – batch size.

  • constraints (Optional[Dict[str, ParameterConstraints]]) – dict of parameter names to provided ParameterConstraints.

  • estimator (Optional[Union[ShardEstimator, List[ShardEstimator]]]) – shard performance estimators.

  • use_exact_enumerate_order (bool) – whether to enumerate shardable parameters in the exact named_children enumeration order.

enumerate(module: Module, sharders: List[ModuleSharder[Module]]) → List[ShardingOption]

Generates relevant sharding options given module and sharders.

Parameters:
  • module (nn.Module) – module to be sharded.

  • sharders (List[ModuleSharder[nn.Module]]) – provided sharders for module.

Returns:

valid sharding options with values populated.

Return type:

List[ShardingOption]

populate_estimates(sharding_options: List[ShardingOption]) → None

Populates perf and storage estimates for the given sharding options using the configured estimators; see the class description.
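
A sketch of constructing the enumerator explicitly and handing it to the planner; the world size, device type, and batch size below are illustrative values:

from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
from torchrec.distributed.planner.types import Topology

# Illustrative values; match them to the actual job configuration.
topology = Topology(world_size=2, compute_device="cuda")
enumerator = EmbeddingEnumerator(topology=topology, batch_size=512)
planner = EmbeddingShardingPlanner(topology=topology, enumerator=enumerator)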

class torchrec.distributed.planner.partitioners.GreedyPerfPartitioner(sort_by: SortBy = SortBy.STORAGE, balance_modules: bool = False)

Greedy Partitioner.

Parameters:
  • sort_by (SortBy) – Sort sharding options by storage or perf in descending order (i.e., large tables will be placed first).

  • balance_modules (bool) – Whether to sort by modules first, where smaller modules will be sorted first. In effect, this will place tables in each module in a balanced way.

partition(proposal: List[ShardingOption], storage_constraint: Topology, hbm_per_device: Optional[int] = None) → List[ShardingOption]

Places sharding options on topology based on each sharding option’s partition_by attribute. The topology, storage, and perfs are updated at the end of the placement.

Parameters:
  • proposal (List[ShardingOption]) – list of populated sharding options.

  • storage_constraint (Topology) – device topology.

Returns:

list of sharding options for selected plan.

Return type:

List[ShardingOption]

Example:

sharding_options = [
        ShardingOption(partition_by="uniform",
                shards=[
                    Shards(storage=1, perf=1),
                    Shards(storage=1, perf=1),
                ]),
        ShardingOption(partition_by="uniform",
                shards=[
                    Shards(storage=2, perf=2),
                    Shards(storage=2, perf=2),
                ]),
        ShardingOption(partition_by="device",
                shards=[
                    Shards(storage=3, perf=3),
                    Shards(storage=3, perf=3),
                ]),
        ShardingOption(partition_by="device",
                shards=[
                    Shards(storage=4, perf=4),
                    Shards(storage=4, perf=4),
                ]),
    ]
topology = Topology(world_size=2)

# First, sharding_options[0] and sharding_options[1] will be placed on the
# topology with the uniform strategy, resulting in

topology.devices[0].perf.total = (1,2)
topology.devices[1].perf.total = (1,2)

# Finally, sharding_options[2] and sharding_options[3] will be placed on the
# topology with the device strategy (see docstring of `partition_by_device` for
# more details).

topology.devices[0].perf.total = (1,2) + (3,4)
topology.devices[1].perf.total = (1,2) + (3,4)

# The topology updates are done after all placements are complete (the order
# shown in the example is just for clarity).

class torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation(percentage: float, parameter_multiplier: float = 6.0, dense_tensor_estimate: Optional[int] = None)

Reserves storage for the model to be sharded using a heuristic calculation. The storage reservation is composed of dense tensor storage, KJT storage, and an extra percentage of total storage.

Parameters:
  • percentage (float) – extra storage percent to reserve that acts as a margin of error beyond heuristic calculation of storage.

  • parameter_multiplier (float) – heuristic multiplier for total parameter storage.

  • dense_tensor_estimate (Optional[int]) – storage estimate for dense tensors, uses default heuristic estimate if not provided.

property last_reserved_topology: Optional[Topology]

Cached value of the most recent output from the reserve() method.
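
A sketch of supplying a custom storage reservation to the planner; the 15% margin below is an illustrative value, not a recommendation:

from torchrec.distributed.planner.storage_reservations import HeuristicalStorageReservation

# Reserve an extra 15% of total storage as a margin of error (illustrative value).
storage_reservation = HeuristicalStorageReservation(percentage=0.15)
planner = EmbeddingShardingPlanner(storage_reservation=storage_reservation)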

class torchrec.distributed.planner.proposers.GreedyProposer(use_depth: bool = True, threshold: Optional[int] = None)

Proposes sharding plans in greedy fashion.

Sorts sharding options for each shardable parameter by perf. On each iteration, it finds the parameter with the largest current storage usage and tries its next sharding option.

Parameters:
  • use_depth (bool) – When enabled, sharding_options of a fqn are sorted based on max(shard.perf.total), otherwise sharding_options are sorted by sum(shard.perf.total).

  • threshold (Optional[int]) – Threshold for early stopping. When specified, the proposer stops proposing after the given number of consecutive proposals with a worse perf_rating than the best_perf_rating seen so far.
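
A sketch of passing a configured proposer to the planner; the depth-based sorting below simply makes the default behavior explicit:

from torchrec.distributed.planner.proposers import GreedyProposer

# Sort each table's sharding options by max per-shard perf (use_depth=True is the default).
proposer = GreedyProposer(use_depth=True)
planner = EmbeddingShardingPlanner(proposer=[proposer])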

feedback(partitionable: bool, plan: Optional[