"""Correctness test for FusedAddRMSNorm under sequence parallelism.

The sharded model (parallelized with ResidualSequenceParallel and fed DTensor
shards along the sequence dimension) must produce the same outputs and
gradients as an unsharded reference run on the full tensors.

Run with, e.g.:
    torchrun --nproc-per-node=2 --local-ranks-filter 0 -m pytest \
        test_rms_norm_sequence_parallel.py
"""

import random
import sys
from collections.abc import Sequence

import pytest
import torch
import torch.distributed as dist
from packaging import version
from torch.distributed.tensor.placement_types import (Partial, Placement,
                                                      Replicate, Shard)

import activation

from .utils import assert_close, opcheck

DTYPES = [torch.float32]
NUM_TOKENS = [512]  # Arbitrary values for testing
SEQUENCE_DIMS = [0, 1]  # 0 is for [T, D] (packed), 1 is for [B, S, D]
D = [16]  # Arbitrary values for testing
SEEDS = [0]

from activation.parallel_style import ResidualSequenceParallel
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module


@pytest.fixture(scope="session", autouse=True)
def init_dist(request):
    if version.parse(torch.__version__) < version.parse("2.8"):
        pytest.skip("torch>=2.8.0 is required for sequence parallel")
    try:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    except Exception as e:
        print(f"Failed to initialize torch.distributed: {e}")
        pytest.skip("Failed to initialize torch.distributed")
    if dist.get_world_size() < 2:
        pytest.skip("Need at least 2 processes in dist group. "
                    "You can run with `torchrun --nproc-per-node=2 "
                    "--local-ranks-filter 0 -m pytest "
                    "test_rms_norm_sequence_parallel.py`")
    yield
    dist.destroy_process_group()


class Model(torch.nn.Module):

    def __init__(self, num_tokens, d) -> None:
        super().__init__()
        self.fused_add_rms_norm = activation.layers.FusedAddRMSNorm(d)

    def forward(self, x: torch.Tensor,
                residual: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Returns (normalized output, x + residual).
        return self.fused_add_rms_norm(x, residual=residual)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("sequence_dim", SEQUENCE_DIMS)
@pytest.mark.parametrize("x_requires_grad", [True, False])
@pytest.mark.parametrize("residual_requires_grad", [True, False])
def test_fused_add_rms_norm_sequence_parallel(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    sequence_dim: int,
    x_requires_grad: bool,
    residual_requires_grad: bool,
) -> None:
    if num_tokens % dist.get_world_size() != 0:
        # It hangs at `y.full_tensor()` if not divisible.
        pytest.skip("num_tokens must be divisible by world_size for sharding")
    if not x_requires_grad and not residual_requires_grad:
        pytest.skip("For now, at least one of x or residual must require grad")
    random.seed(seed)
    torch.manual_seed(seed)

    num_ranks = dist.get_world_size()
    rank = dist.get_rank()
    mesh = init_device_mesh("cuda", (num_ranks, ), mesh_dim_names=("shard", ))

    match sequence_dim:
        case 0:
            x_shape = (num_tokens, d)
        case 1:
            BATCH_SIZE = 2
            x_shape = (BATCH_SIZE, num_tokens, d)
        case _:
            raise ValueError(f"Invalid sequence_dim: {sequence_dim}")

    x = torch.randn(x_shape, dtype=dtype,
                    requires_grad=x_requires_grad).cuda()
    residual = torch.randn(x_shape, dtype=dtype,
                           requires_grad=residual_requires_grad).cuda()
    weight = torch.ones(d, dtype=dtype, requires_grad=True).cuda()
    eps = 1e-05
    # `.cuda()` makes these tensors non-leaf, so retain_grad() is needed for
    # `.grad` to be populated for the checks below.
    if x_requires_grad:
        x.retain_grad()
    if residual_requires_grad:
        residual.retain_grad()
    weight.retain_grad()

    # Copy x, residual, and weight for the unsharded reference.
    x_ref = x.detach().clone().requires_grad_(True)
    residual_ref = residual.detach().clone().requires_grad_(True)
    weight_ref = weight.detach().clone().requires_grad_(True)

    model_sharded = Model(num_tokens, d).to(dtype=dtype).cuda()
    model_sharded.fused_add_rms_norm.weight = torch.nn.Parameter(weight)
    # Shard the fused add + RMS norm layer along the sequence dimension.
    parallelize_module(
        model_sharded, mesh, {
            "fused_add_rms_norm":
            ResidualSequenceParallel(sequence_dim=sequence_dim)
        })

    # Each rank feeds its own contiguous chunk of the sequence as a DTensor
    # shard.
    x_sharded = DTensor.from_local(
        x.chunk(num_ranks, dim=sequence_dim)[rank].contiguous(),
        placements=(Shard(sequence_dim), ),
        device_mesh=mesh,
    )
    residual_sharded = DTensor.from_local(
        residual.chunk(num_ranks, dim=sequence_dim)[rank].contiguous(),
        placements=(Shard(sequence_dim), ),
        device_mesh=mesh,
    )
    y, add_output = model_sharded(x_sharded, residual_sharded)
    y_from_sharded = y.full_tensor()
    add_output_from_sharded = add_output.full_tensor()

    # Unsharded reference run on the full tensors.
    model_unsharded = Model(num_tokens, d).to(dtype=dtype).cuda()
    model_unsharded.fused_add_rms_norm.weight = torch.nn.Parameter(weight_ref)
    y_from_unsharded, add_output_from_unsharded = model_unsharded(
        x_ref, residual_ref)

    assert_close(y_from_sharded, y_from_unsharded)
    assert_close(add_output_from_sharded, add_output_from_unsharded)

    # Backward. The upstream gradients are identical across ranks because
    # every rank uses the same seed.
    y_grad = torch.randn_like(y_from_unsharded)
    add_output_grad = torch.randn_like(add_output_from_unsharded)
    (y_grad * y_from_sharded +
     add_output_grad * add_output_from_sharded).sum().backward()
    (y_grad * y_from_unsharded +
     add_output_grad * add_output_from_unsharded).sum().backward()

    weight_grad_from_sharded = (
        model_sharded.fused_add_rms_norm.weight.grad._local_tensor)
    weight_grad_from_unsharded = model_unsharded.fused_add_rms_norm.weight.grad

    assert (x.grad is None) ^ x_requires_grad
    assert (residual.grad is None) ^ residual_requires_grad

    # Each rank's local gradients only cover its own chunk of the sequence,
    # so sum them across ranks before comparing with the unsharded reference.
    dist.all_reduce(weight_grad_from_sharded, op=dist.ReduceOp.SUM)
    if x.grad is not None:
        dist.all_reduce(x.grad, op=dist.ReduceOp.SUM)
        assert_close(x.grad, x_ref.grad)
    if residual.grad is not None:
        dist.all_reduce(residual.grad, op=dist.ReduceOp.SUM)
        assert_close(residual.grad, residual_ref.grad)
    assert_close(weight_grad_from_sharded, weight_grad_from_unsharded)
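

# For reference, a minimal sketch of the unfused math this test expects the
# fused kernel to reproduce, assuming FusedAddRMSNorm(x, residual=residual)
# returns (rms_norm(x + residual) * weight, x + residual) with eps = 1e-05.
# This helper is illustrative only and is not used by the test; the actual
# implementation lives in the `activation` package.
def _reference_add_rms_norm(
        x: torch.Tensor,
        residual: torch.Tensor,
        weight: torch.Tensor,
        eps: float = 1e-05) -> tuple[torch.Tensor, torch.Tensor]:
    added = x + residual
    # RMS norm: scale by the reciprocal root-mean-square over the hidden dim.
    variance = added.pow(2).mean(dim=-1, keepdim=True)
    normed = added * torch.rsqrt(variance + eps)
    return normed * weight, added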