# Kernels
# activation/tests/test_fused_add_rms_norm_sequence_parallel.py
# Author: wyldecat
# Commit 151bb5a: feat: support sequence parallel with fused_add_rms_norm
import random
import sys
from collections.abc import Sequence
import pytest
import torch
import torch.distributed as dist
from packaging import version
from torch.distributed.tensor.placement_types import (Partial, Placement,
Replicate, Shard)
import activation
from .utils import assert_close, opcheck
# Parameter grids consumed by the parametrized sequence-parallel test.
DTYPES = [torch.float32]

# Arbitrary token count; combinations not divisible by the world size are
# skipped by the test.
NUM_TOKENS = [512]

# Dimension that holds the sequence and is sharded across ranks:
#   0 -> packed input of shape [T, D]
#   1 -> batched input of shape [B, S, D]
SEQUENCE_DIMS = [0, 1]

# Arbitrary size of the last (normalized) dimension.
D = [16]

SEEDS = [0]
from activation.parallel_style import ResidualSequenceParallel
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module
@pytest.fixture(scope="session", autouse=True)
def init_dist(request):
    """Create an NCCL process group for the whole test session.

    Every test is skipped when torch is older than 2.8, when the process
    group cannot be initialized, or when fewer than two ranks are present.
    The process group is destroyed on teardown. If a skip happens after the
    group was created, it is destroyed first so it does not leak.

    ``request`` is unused; it is kept so the fixture signature is unchanged.
    """
    if version.parse(torch.__version__) < version.parse("2.8"):
        # pytest.skip() raises, so nothing after it runs.
        pytest.skip("torch>=2.8.0 is required for sequence parallel")

    try:
        dist.init_process_group(backend="nccl")
        # Pick a GPU by rank modulo the number of locally visible devices.
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    except Exception as e:
        print(f"Failed to initialize torch.distributed: {e}")
        # init_process_group may have succeeded before set_device failed.
        if dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()
        pytest.skip("Failed to initialize torch.distributed")

    if dist.get_world_size() < 2:
        # Tear down before skipping; the post-yield teardown never runs here.
        dist.destroy_process_group()
        pytest.skip("Need at least 2 processes in dist group. "
                    "You can run with `torchrun --nproc-per-node=2 "
                    "--local-ranks-filter 0 -m pytest "
                    "test_fused_add_rms_norm_sequence_parallel.py`")

    yield
    dist.destroy_process_group()
class Model(torch.nn.Module):
    """Minimal module wrapping a single ``FusedAddRMSNorm``.

    The submodule lives under the attribute name ``fused_add_rms_norm``.
    The test uses that name to target it in ``parallelize_module`` and to
    replace its ``weight``.
    """

    def __init__(self, num_tokens: int, d: int) -> None:
        """Build the model.

        Args:
            num_tokens: Unused; kept for call-site compatibility.
            d: Size of the last input dimension, forwarded to
                ``FusedAddRMSNorm``.
        """
        super().__init__()
        self.fused_add_rms_norm = activation.layers.FusedAddRMSNorm(d)

    def forward(
        self, x: torch.Tensor, residual: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the fused add + RMS norm to ``x`` and ``residual``.

        Returns the pair produced by ``FusedAddRMSNorm``. Callers in this
        file unpack it as ``(y, add_output)``.
        """
        return self.fused_add_rms_norm(x, residual=residual)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("sequence_dim", SEQUENCE_DIMS)
@pytest.mark.parametrize("x_requires_grad", [True, False])
@pytest.mark.parametrize("residual_requires_grad", [True, False])
def test_fused_add_rms_norm_sequence_parallel(
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
sequence_dim: int,
x_requires_grad: bool,
residual_requires_grad: bool,
) -> None:
if num_tokens % dist.get_world_size() != 0:
# It hangs at `y.full_tensor()` if not divisible
pytest.skip("num_tokens must be divisible by world_size for sharding")
if not x_requires_grad and not residual_requires_grad:
pytest.skip("For now, at least one of x or residual must require grad")
random.seed(seed)
torch.manual_seed(seed)
num_ranks = dist.get_world_size()
rank = dist.get_rank()
mesh = init_device_mesh("cuda", (num_ranks, ), mesh_dim_names=("shard", ))
match sequence_dim:
case 0:
x_shape = (num_tokens, d)
case 1:
BATCH_SIZE = 2
x_shape = (BATCH_SIZE, num_tokens, d)
case _:
raise ValueError(f"Invalid sequence_dim: {sequence_dim}")
x = torch.randn(x_shape, dtype=dtype, requires_grad=x_requires_grad).cuda()
residual = torch.randn(x_shape,
dtype=dtype,
requires_grad=residual_requires_grad).cuda()
weight = torch.ones(d, dtype=dtype, requires_grad=True).cuda()
eps = 1e-05
if x_requires_grad:
x.retain_grad()
if residual_requires_grad:
residual.retain_grad()
weight.retain_grad()
# Copy x, weight for reference
x_ref = x.detach().clone().requires_grad_(True)
residual_ref = residual.detach().clone().requires_grad_(True)
weight_ref = weight.detach().clone().requires_grad_(True)
model_sharded = Model(num_tokens, d).to(dtype=dtype).cuda()
model_sharded.fused_add_rms_norm.weight = torch.nn.Parameter(weight)
parallelize_module(model_sharded, mesh, {
"fused_add_rms_norm":
ResidualSequenceParallel(sequence_dim=sequence_dim)
})
x_sharded = DTensor.from_local(
x.chunk(num_ranks, dim=sequence_dim)[rank].contiguous(),
placements=(Shard(sequence_dim), ),
device_mesh=mesh,
)
residual_sharded = DTensor.from_local(
residual.chunk(num_ranks, dim=sequence_dim)[rank].contiguous(),
placements=(Shard(sequence_dim), ),
device_mesh=mesh,
)
y, add_output = model_sharded(x_sharded, residual_sharded)
y_from_sharded = y.full_tensor()
add_output_from_sharded = add_output.full_tensor()
model_unsharded = Model(num_tokens, d).to(dtype=dtype).cuda()
model_unsharded.fused_add_rms_norm.weight = torch.nn.Parameter(weight_ref)
y_from_unsharded, add_output_from_unsharded = model_unsharded(
x_ref, residual_ref)
assert_close(y_from_sharded, y_from_unsharded)
assert_close(add_output_from_sharded, add_output_from_unsharded)
# Backward
y_grad = torch.randn_like(y_from_unsharded)
add_output_grad = torch.randn_like(add_output_from_unsharded)
(y_grad * y_from_sharded +
add_output_grad * add_output_from_sharded).sum().backward()
(y_grad * y_from_unsharded +
add_output_grad * add_output_from_unsharded).sum().backward()
weight_grad_from_sharded = model_sharded.fused_add_rms_norm.weight.grad._local_tensor
weight_grad_from_unsharded = model_unsharded.fused_add_rms_norm.weight.grad
assert (x.grad is None) ^ x_requires_grad
assert (residual.grad is None) ^ residual_requires_grad
torch.distributed.all_reduce(weight_grad_from_sharded,
op=torch.distributed.ReduceOp.SUM)
if x.grad is not None:
torch.distributed.all_reduce(x.grad, op=torch.distributed.ReduceOp.SUM)
assert_close(x.grad, x_ref.grad)
if residual.grad is not None:
torch.distributed.all_reduce(residual.grad,
op=torch.distributed.ReduceOp.SUM)
assert_close(residual.grad, residual_ref.grad)
assert_close(weight_grad_from_sharded, weight_grad_from_unsharded)