
utils

Classes

fastvideo.distributed.utils.StatelessProcessGroup dataclass

StatelessProcessGroup(rank: int, world_size: int, store: Store, data_expiration_seconds: int = 3600, send_dst_counter: dict[int, int] = dict(), recv_src_counter: dict[int, int] = dict(), broadcast_send_counter: int = 0, broadcast_recv_src_counter: dict[int, int] = dict(), entries: deque[tuple[str, float]] = deque())

A dataclass that holds a metadata store together with the rank and world_size of the group. Use it only to communicate metadata between processes; for data-plane communication, create NCCL-related objects.

Functions

fastvideo.distributed.utils.StatelessProcessGroup.all_gather_obj
all_gather_obj(obj: Any) -> list[Any]

All gather an object from all ranks.

Source code in fastvideo/distributed/utils.py
def all_gather_obj(self, obj: Any) -> list[Any]:
    """All gather an object from all ranks."""
    gathered_objs = []
    for i in range(self.world_size):
        if i == self.rank:
            gathered_objs.append(obj)
            self.broadcast_obj(obj, src=self.rank)
        else:
            recv_obj = self.broadcast_obj(None, src=i)
            gathered_objs.append(recv_obj)
    return gathered_objs
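
For illustration, a minimal usage sketch, assuming group is an already-initialized StatelessProcessGroup (see create below for how to build one):

# Runs on every rank of the group; the result lists each rank's object in rank order.
gathered = group.all_gather_obj({"rank": group.rank})
assert len(gathered) == group.world_size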
fastvideo.distributed.utils.StatelessProcessGroup.barrier
barrier()

A barrier to synchronize all ranks.

Source code in fastvideo/distributed/utils.py
def barrier(self):
    """A barrier to synchronize all ranks."""
    for i in range(self.world_size):
        if i == self.rank:
            self.broadcast_obj(None, src=self.rank)
        else:
            self.broadcast_obj(None, src=i)
fastvideo.distributed.utils.StatelessProcessGroup.broadcast_obj
broadcast_obj(obj: Any | None, src: int) -> Any

Broadcast an object from a source rank to all other ranks. The store entries are not cleaned up after all ranks have received the object, so call it only a limited number of times, e.g., during initialization.

Source code in fastvideo/distributed/utils.py
def broadcast_obj(self, obj: Any | None, src: int) -> Any:
    """Broadcast an object from a source rank to all other ranks.
    It does not clean up after all ranks have received the object.
    Use it for limited times, e.g., for initialization.
    """
    if self.rank == src:
        self.expire_data()
        key = (f"broadcast_from/{src}/"
               f"{self.broadcast_send_counter}")
        self.store.set(key, pickle.dumps(obj))
        self.broadcast_send_counter += 1
        self.entries.append((key, time.perf_counter()))
        return obj
    else:
        key = (f"broadcast_from/{src}/"
               f"{self.broadcast_recv_src_counter[src]}")
        recv_obj = pickle.loads(self.store.get(key))
        self.broadcast_recv_src_counter[src] += 1
        return recv_obj
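
A hedged initialization-time sketch, again assuming an existing group named group; rank 0 supplies the object and every other rank passes None:

# Rank 0 provides the value; the other ranks receive a copy of it.
config = {"seed": 1234} if group.rank == 0 else None
config = group.broadcast_obj(config, src=0)
# Every rank now holds {"seed": 1234}. Because the store entries are not
# cleaned up, reserve broadcast_obj for a few initialization-time calls.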
fastvideo.distributed.utils.StatelessProcessGroup.create staticmethod
create(host: str, port: int, rank: int, world_size: int, data_expiration_seconds: int = 3600) -> StatelessProcessGroup

A replacement for torch.distributed.init_process_group that does not pollute the global state.

If process A and process B have called torch.distributed.init_process_group to form a group, and we then want to form another group with processes A, B, C, and D, this is not possible in PyTorch, because A and B have already formed a group that C and D cannot join. This function is a workaround for this issue.

torch.distributed.init_process_group is a global call, while this function is a stateless call. It will return a StatelessProcessGroup object that can be used for exchanging metadata. With this function, process A and process B can call StatelessProcessGroup.create to form a group, and then process A, B, C, and D can call StatelessProcessGroup.create to form another group.

Source code in fastvideo/distributed/utils.py
@staticmethod
def create(
    host: str,
    port: int,
    rank: int,
    world_size: int,
    data_expiration_seconds: int = 3600,
) -> "StatelessProcessGroup":
    """A replacement for `torch.distributed.init_process_group` that does not
    pollute the global state.

    If we have process A and process B called `torch.distributed.init_process_group`
    to form a group, and then we want to form another group with process A, B, C,
    D, it is not possible in PyTorch, because process A and process B have already
    formed a group, and process C and process D cannot join that group. This
    function is a workaround for this issue.

    `torch.distributed.init_process_group` is a global call, while this function
    is a stateless call. It will return a `StatelessProcessGroup` object that can be
    used for exchanging metadata. With this function, process A and process B
    can call `StatelessProcessGroup.create` to form a group, and then process A, B,
    C, and D can call `StatelessProcessGroup.create` to form another group.
    """ # noqa
    store = TCPStore(
        host_name=host,
        port=port,
        world_size=world_size,
        is_master=(rank == 0),
    )

    return StatelessProcessGroup(
        rank=rank,
        world_size=world_size,
        store=store,
        data_expiration_seconds=data_expiration_seconds)
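
A minimal two-process sketch of how create and the object-passing helpers might be combined; the host, port, and payload values are illustrative only:

from fastvideo.distributed.utils import StatelessProcessGroup

def exchange_metadata(rank: int) -> None:
    # Both processes must point at the same TCP store; rank 0 hosts it.
    group = StatelessProcessGroup.create(
        host="127.0.0.1",
        port=29600,
        rank=rank,
        world_size=2,
    )
    if rank == 0:
        group.send_obj({"model": "demo", "step": 0}, dst=1)
    else:
        meta = group.recv_obj(src=0)   # blocks until rank 0's object arrives
        print(meta)
    group.barrier()                    # keep both processes in lockstep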
fastvideo.distributed.utils.StatelessProcessGroup.expire_data
expire_data() -> None

Expire data that is older than data_expiration_seconds seconds.

Source code in fastvideo/distributed/utils.py
def expire_data(self) -> None:
    """Expire data that is older than `data_expiration_seconds` seconds."""
    while self.entries:
        # check the oldest entry
        key, timestamp = self.entries[0]
        if time.perf_counter() - timestamp > self.data_expiration_seconds:
            self.store.delete_key(key)
            self.entries.popleft()
        else:
            break
fastvideo.distributed.utils.StatelessProcessGroup.recv_obj
recv_obj(src: int) -> Any

Receive an object from a source rank.

Source code in fastvideo/distributed/utils.py
def recv_obj(self, src: int) -> Any:
    """Receive an object from a source rank."""
    obj = pickle.loads(
        self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}"))
    self.recv_src_counter[src] += 1
    return obj
fastvideo.distributed.utils.StatelessProcessGroup.send_obj
send_obj(obj: Any, dst: int)

Send an object to a destination rank.

Source code in fastvideo/distributed/utils.py
def send_obj(self, obj: Any, dst: int):
    """Send an object to a destination rank."""
    self.expire_data()
    key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
    self.store.set(key, pickle.dumps(obj))
    self.send_dst_counter[dst] += 1
    self.entries.append((key, time.perf_counter()))

Functions

fastvideo.distributed.utils.compute_padding_for_sp

compute_padding_for_sp(seq_len: int, sp_world_size: int) -> tuple[int, int]

Compute padding needed for sequence parallel.

Parameters:

Name           Type  Description                    Default
seq_len        int   Original sequence length       required
sp_world_size  int   Sequence parallel world size   required

Returns:

Name   Type             Description
tuple  tuple[int, int]  (padded_seq_len, padding_amount)

Source code in fastvideo/distributed/utils.py
def compute_padding_for_sp(seq_len: int, sp_world_size: int) -> tuple[int, int]:
    """
    Compute padding needed for sequence parallel.

    Args:
        seq_len: Original sequence length
        sp_world_size: Sequence parallel world size

    Returns:
        tuple: (padded_seq_len, padding_amount)
    """
    if seq_len % sp_world_size == 0:
        return seq_len, 0

    padding_amount = sp_world_size - (seq_len % sp_world_size)
    padded_seq_len = seq_len + padding_amount

    return padded_seq_len, padding_amount
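
A worked example, assuming a 4-way sequence-parallel group:

from fastvideo.distributed.utils import compute_padding_for_sp

# 117 is not divisible by 4, so 3 padding tokens bring the length to 120,
# which then splits evenly into 4 shards of 30 tokens each.
padded_seq_len, padding_amount = compute_padding_for_sp(seq_len=117, sp_world_size=4)
assert (padded_seq_len, padding_amount) == (120, 3)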

fastvideo.distributed.utils.create_attention_mask_for_padding

create_attention_mask_for_padding(seq_len: int, padded_seq_len: int, batch_size: int, device: device, dtype: dtype = bool) -> Tensor | None

Create attention mask to ignore padded tokens.

Parameters:

Name            Type    Description                                 Default
seq_len         int     Original sequence length (before padding)   required
padded_seq_len  int     Padded sequence length                      required
batch_size      int     Batch size                                  required
device          device  Device to create mask on                    required
dtype           dtype   Data type for the mask (default: bool)      bool

Returns:

Name    Type           Description
Tensor  Tensor | None  Boolean mask [B, padded_seq_len] where True = valid token, or None if no padding is needed

Source code in fastvideo/distributed/utils.py
def create_attention_mask_for_padding(
    seq_len: int,
    padded_seq_len: int,
    batch_size: int,
    device: torch.device,
    dtype: torch.dtype = torch.bool,
) -> torch.Tensor | None:
    """
    Create attention mask to ignore padded tokens.

    Args:
        seq_len: Original sequence length (before padding)
        padded_seq_len: Padded sequence length
        batch_size: Batch size
        device: Device to create mask on
        dtype: Data type for the mask (default: bool)

    Returns:
        Tensor: Boolean mask [B, padded_seq_len] where True = valid token,
                or None if no padding is needed
    """
    if seq_len == padded_seq_len:
        return None

    # Create mask: True for valid tokens, False for padding
    attention_mask = torch.ones(
        (batch_size, padded_seq_len),
        dtype=dtype,
        device=device,
    )

    # Mask out padding tokens
    attention_mask[:, seq_len:] = 0

    return attention_mask
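
A short sketch of pairing this helper with compute_padding_for_sp; the batch size and device are placeholder choices:

import torch

from fastvideo.distributed.utils import (compute_padding_for_sp,
                                          create_attention_mask_for_padding)

seq_len = 117
padded_seq_len, _ = compute_padding_for_sp(seq_len, sp_world_size=4)
mask = create_attention_mask_for_padding(
    seq_len=seq_len,
    padded_seq_len=padded_seq_len,
    batch_size=2,
    device=torch.device("cpu"),
)
# Shape [2, 120]; the last 3 positions of every row are False (padding).
assert mask.shape == (2, 120)
assert not mask[:, seq_len:].any()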

fastvideo.distributed.utils.divide

divide(numerator: int, denominator: int) -> int

Ensure that numerator is divisible by the denominator and return the division value.

Source code in fastvideo/distributed/utils.py
def divide(numerator: int, denominator: int) -> int:
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator

fastvideo.distributed.utils.ensure_divisibility

ensure_divisibility(numerator, denominator) -> None

Ensure that numerator is divisible by the denominator.

Source code in fastvideo/distributed/utils.py
def ensure_divisibility(numerator, denominator) -> None:
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator)
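
These two helpers back split_tensor_along_last_dim; a quick illustration:

from fastvideo.distributed.utils import divide, ensure_divisibility

ensure_divisibility(120, 4)   # passes silently
assert divide(120, 4) == 30
# divide(121, 4) would raise AssertionError: "121 is not divisible by 4"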

fastvideo.distributed.utils.pad_sequence_tensor

pad_sequence_tensor(tensor: Tensor, target_seq_len: int, seq_dim: int = 1, pad_value: float = 0.0) -> Tensor

Pad a tensor along the sequence dimension.

Parameters:

Name            Type    Description                              Default
tensor          Tensor  Input tensor to pad                      required
target_seq_len  int     Target sequence length after padding     required
seq_dim         int     Dimension to pad along (default: 1)      1
pad_value       float   Value to use for padding (default: 0.0)  0.0

Returns:

Name    Type    Description
Tensor  Tensor  Padded tensor

Source code in fastvideo/distributed/utils.py
def pad_sequence_tensor(
    tensor: torch.Tensor,
    target_seq_len: int,
    seq_dim: int = 1,
    pad_value: float = 0.0,
) -> torch.Tensor:
    """
    Pad a tensor along the sequence dimension.

    Args:
        tensor: Input tensor to pad
        target_seq_len: Target sequence length after padding
        seq_dim: Dimension to pad along (default: 1)
        pad_value: Value to use for padding (default: 0.0)

    Returns:
        Tensor: Padded tensor
    """
    current_seq_len = tensor.shape[seq_dim]

    if current_seq_len >= target_seq_len:
        return tensor

    padding_amount = target_seq_len - current_seq_len

    # Create padding shape
    pad_shape = list(tensor.shape)
    pad_shape[seq_dim] = padding_amount

    # Create padding tensor
    padding = torch.full(
        pad_shape,
        pad_value,
        dtype=tensor.dtype,
        device=tensor.device,
    )

    # Concatenate along sequence dimension
    padded_tensor = torch.cat([tensor, padding], dim=seq_dim)

    return padded_tensor
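
A hedged sketch with an illustrative [batch, seq, hidden] tensor:

import torch

from fastvideo.distributed.utils import pad_sequence_tensor

x = torch.randn(2, 117, 64)
x_padded = pad_sequence_tensor(x, target_seq_len=120, seq_dim=1, pad_value=0.0)
assert x_padded.shape == (2, 120, 64)
# The original tokens are preserved; positions 117..119 are filled with 0.0.
assert torch.equal(x_padded[:, :117], x)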

fastvideo.distributed.utils.split_tensor_along_last_dim

split_tensor_along_last_dim(tensor: Tensor, num_partitions: int, contiguous_split_chunks: bool = False) -> Sequence[Tensor]

Split a tensor along its last dimension.

Parameters:

Name                     Type    Description                                     Default
tensor                   Tensor  Input tensor                                    required
num_partitions           int     Number of partitions to split the tensor into   required
contiguous_split_chunks  bool    If True, make each chunk contiguous in memory   False

Returns:

Type              Description
Sequence[Tensor]  A list of Tensors

Source code in fastvideo/distributed/utils.py
def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
    """ Split a tensor along its last dimension.

        Arguments:
            tensor: input tensor.
            num_partitions: number of partitions to split the tensor
            contiguous_split_chunks: If True, make each chunk contiguous
                                     in memory.

        Returns:
            A list of Tensors
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # NOTE: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tuple(tensor_list)
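
For example, splitting a hidden dimension of 64 across 4 partitions (sizes chosen for illustration):

import torch

from fastvideo.distributed.utils import split_tensor_along_last_dim

hidden = torch.randn(2, 120, 64)
chunks = split_tensor_along_last_dim(hidden, num_partitions=4,
                                      contiguous_split_chunks=True)
# Four chunks of shape [2, 120, 16]; each is a contiguous copy because the
# flag is True (otherwise they would be non-contiguous views).
assert len(chunks) == 4
assert all(c.shape == (2, 120, 16) for c in chunks)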

fastvideo.distributed.utils.unpad_sequence_tensor

unpad_sequence_tensor(tensor: Tensor, original_seq_len: int, seq_dim: int = 1) -> Tensor

Remove padding from a tensor along the sequence dimension.

Parameters:

Name              Type    Description                                 Default
tensor            Tensor  Padded tensor                               required
original_seq_len  int     Original sequence length (before padding)   required
seq_dim           int     Dimension to unpad along (default: 1)       1

Returns:

Name    Type    Description
Tensor  Tensor  Unpadded tensor

Source code in fastvideo/distributed/utils.py
def unpad_sequence_tensor(
    tensor: torch.Tensor,
    original_seq_len: int,
    seq_dim: int = 1,
) -> torch.Tensor:
    """
    Remove padding from a tensor along the sequence dimension.

    Args:
        tensor: Padded tensor
        original_seq_len: Original sequence length (before padding)
        seq_dim: Dimension to unpad along (default: 1)

    Returns:
        Tensor: Unpadded tensor
    """
    # Use slice to remove padding
    indices = [slice(None)] * tensor.dim()
    indices[seq_dim] = slice(0, original_seq_len)

    return tensor[tuple(indices)]
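
Finally, a round-trip sketch combining the padding helpers above (tensor sizes are illustrative):

import torch

from fastvideo.distributed.utils import (compute_padding_for_sp,
                                          pad_sequence_tensor,
                                          unpad_sequence_tensor)

x = torch.randn(2, 117, 64)
padded_seq_len, _ = compute_padding_for_sp(x.shape[1], sp_world_size=4)
x_padded = pad_sequence_tensor(x, target_seq_len=padded_seq_len)
x_restored = unpad_sequence_tensor(x_padded, original_seq_len=x.shape[1])
assert torch.equal(x_restored, x)   # padding round-trips without data loss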