import random
import sys
from collections.abc import Sequence

import pytest
import torch
import torch.distributed as dist
from packaging import version
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor.parallel import (SequenceParallel,
                                               parallelize_module)
from torch.distributed.tensor.placement_types import (Partial, Placement,
                                                       Replicate, Shard)

import activation

from .utils import assert_close, opcheck

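# Parameter grids swept by the parametrized test below.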
DTYPES = [torch.float32]
NUM_TOKENS = [512]
SEQUENCE_DIMS = [0, 1]
D = [16]
SEEDS = [0]


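# Session-scoped, autouse fixture: initialize the NCCL process group once,
# bind one CUDA device per rank, and skip the module when the environment
# cannot run the sequence-parallel tests.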
@pytest.fixture(scope="session", autouse=True)
def init_dist(request):
    if version.parse(torch.__version__) < version.parse("2.8"):
        pytest.skip("torch>=2.8.0 is required for sequence parallel")

    try:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    except Exception as e:
        print(f"Failed to initialize torch.distributed: {e}")
        pytest.skip("Failed to initialize torch.distributed")

    if dist.get_world_size() < 2:
        pytest.skip("Need at least 2 processes in dist group. "
                    "You can run with `torchrun --nproc-per-node=2 "
                    "--local-ranks-filter 0 -m pytest "
                    "test_rms_norm_sequence_parallel.py`")

    yield
    dist.destroy_process_group()


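# Minimal module wrapper so parallelize_module can address the RMSNorm layer
# by its "rms_norm" name.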
class Model(torch.nn.Module):

    def __init__(self, num_tokens, d) -> None:
        super().__init__()
        self.rms_norm = activation.layers.RMSNorm(d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.rms_norm(x)


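# Forward and backward parity check: the SequenceParallel-sharded RMSNorm must
# match an unsharded reference for the output, the input gradient, and the
# weight gradient.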
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("sequence_dim", SEQUENCE_DIMS)
def test_rms_norm_sequence_parallel(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    sequence_dim: int,
) -> None:
    if num_tokens % dist.get_world_size() != 0:
        pytest.skip("num_tokens must be divisible by world_size for sharding")

    random.seed(seed)
    torch.manual_seed(seed)

    num_ranks = dist.get_world_size()
    rank = dist.get_rank()
    mesh = init_device_mesh("cuda", (num_ranks, ), mesh_dim_names=("shard", ))

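    # sequence_dim == 0 exercises a (num_tokens, d) input; sequence_dim == 1
    # adds a leading batch dimension so the sequence axis is no longer dim 0.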
    match sequence_dim:
        case 0:
            x_shape = (num_tokens, d)
        case 1:
            BATCH_SIZE = 2
            x_shape = (BATCH_SIZE, num_tokens, d)
        case _:
            raise ValueError(f"Invalid sequence_dim: {sequence_dim}")

    x = torch.randn(x_shape, dtype=dtype, requires_grad=True).cuda()
    weight = torch.ones(d, dtype=dtype, requires_grad=True).cuda()

    # .cuda() returns non-leaf tensors, so gradients have to be retained
    # explicitly for the checks at the end of the test.
    x.retain_grad()
    weight.retain_grad()

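    # Detached copies feed the unsharded reference model.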
    x_ref = x.detach().clone().requires_grad_(True)
    weight_ref = weight.detach().clone().requires_grad_(True)

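    # Sharded path: wrap the RMSNorm in SequenceParallel, feed each rank its
    # local shard of x along sequence_dim, and gather the full output.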
    model_sharded = Model(num_tokens, d).to(dtype=dtype).cuda()
    model_sharded.rms_norm.weight = torch.nn.Parameter(weight)
    parallelize_module(
        model_sharded, mesh,
        {"rms_norm": SequenceParallel(sequence_dim=sequence_dim)})
    x_sharded = DTensor.from_local(
        x.chunk(num_ranks, dim=sequence_dim)[rank].contiguous(),
        placements=(Shard(sequence_dim), ),
        device_mesh=mesh,
    )
    y = model_sharded(x_sharded)
    y_from_sharded = y.full_tensor()

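    # Reference path: the same model without parallelization on the full input.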
    model_unsharded = Model(num_tokens, d).to(dtype=dtype).cuda()
    model_unsharded.rms_norm.weight = torch.nn.Parameter(weight_ref)

    y_from_unsharded = model_unsharded(x_ref)

    assert_close(y_from_sharded, y_from_unsharded)

    y_grad = torch.randn_like(y_from_unsharded)
    y_from_sharded.backward(y_grad)
    y_from_unsharded.backward(y_grad)

    weight_grad_from_sharded = model_sharded.rms_norm.weight.grad._local_tensor
    weight_grad_from_unsharded = model_unsharded.rms_norm.weight.grad

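    # Each rank only holds the gradient contribution of its own sequence shard,
    # so x.grad and the local weight gradient are summed across ranks before
    # being compared against the unsharded reference.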
    dist.all_reduce(x.grad, op=dist.ReduceOp.SUM)
    dist.all_reduce(weight_grad_from_sharded, op=dist.ReduceOp.SUM)

    assert_close(x.grad, x_ref.grad)
    assert_close(weight_grad_from_sharded, weight_grad_from_unsharded)