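"""Multi-process test for RMSNorm under sequence parallelism.

Compares the forward output and the input/weight gradients of a
SequenceParallel-wrapped ``activation.layers.RMSNorm`` against an unsharded
reference. Requires at least two ranks, e.g.::

    torchrun --nproc-per-node=2 --local-ranks-filter 0 -m pytest \
        test_rms_norm_sequence_parallel.py
"""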
import random

import pytest
import torch
import torch.distributed as dist
from packaging import version
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (SequenceParallel,
                                               parallelize_module)
from torch.distributed.tensor.placement_types import Shard

import activation

from .utils import assert_close

DTYPES = [torch.float32]
NUM_TOKENS = [512]  # Arbitrary values for testing
SEQUENCE_DIMS = [0, 1]  # 0 is for [T, D] (packed), 1 is for [B, S, D]
D = [16]  # Arbitrary values for testing
SEEDS = [0]


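# Session-scoped fixture: initialize the NCCL process group once for all
# tests and tear it down at the end of the session.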
@pytest.fixture(scope="session", autouse=True)
def init_dist(request):
    if version.parse(torch.__version__) < version.parse("2.8"):
        pytest.skip("torch>=2.8.0 is required for sequence parallel")
    try:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    except Exception as e:
        print(f"Failed to initialize torch.distributed: {e}")
        pytest.skip("Failed to initialize torch.distributed")
    if dist.get_world_size() < 2:
        pytest.skip("Need at least 2 processes in dist group. "
                    "You can run with `torchrun --nproc-per-node=2 "
                    "--local-ranks-filter 0 -m pytest "
                    "test_rms_norm_sequence_parallel.py`")
    yield
    dist.destroy_process_group()


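# Thin wrapper so `parallelize_module` can target the RMSNorm layer by its
# attribute name ("rms_norm").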
class Model(torch.nn.Module):

    def __init__(self, num_tokens, d) -> None:
        super().__init__()
        self.rms_norm = activation.layers.RMSNorm(d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.rms_norm(x)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("sequence_dim", SEQUENCE_DIMS)
def test_rms_norm_sequence_parallel(
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
sequence_dim: int,
) -> None:
if num_tokens % dist.get_world_size() != 0:
# It hangs at `y.full_tensor()` if not divisible
pytest.skip("num_tokens must be divisible by world_size for sharding")
random.seed(seed)
torch.manual_seed(seed)
num_ranks = dist.get_world_size()
rank = dist.get_rank()
mesh = init_device_mesh("cuda", (num_ranks, ), mesh_dim_names=("shard", ))
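    # Build the input shape: packed [T, D] when sequence_dim == 0, batched
    # [B, S, D] when sequence_dim == 1.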
    match sequence_dim:
        case 0:
            x_shape = (num_tokens, d)
        case 1:
            BATCH_SIZE = 2
            x_shape = (BATCH_SIZE, num_tokens, d)
        case _:
            raise ValueError(f"Invalid sequence_dim: {sequence_dim}")

    x = torch.randn(x_shape, dtype=dtype, requires_grad=True).cuda()
    weight = torch.ones(d, dtype=dtype, requires_grad=True).cuda()
    eps = 1e-05
    # `.cuda()` returns non-leaf tensors, so grads must be retained explicitly
    # for the gradient comparison below.
    x.retain_grad()
    weight.retain_grad()

    # Copy x, weight for reference
    x_ref = x.detach().clone().requires_grad_(True)
    weight_ref = weight.detach().clone().requires_grad_(True)

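    # Sharded path: apply SequenceParallel to the RMSNorm layer and feed a
    # DTensor that is sharded along the sequence dimension.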
    model_sharded = Model(num_tokens, d).to(dtype=dtype).cuda()
    model_sharded.rms_norm.weight = torch.nn.Parameter(weight)
    parallelize_module(
        model_sharded, mesh,
        {"rms_norm": SequenceParallel(sequence_dim=sequence_dim)})
    x_sharded = DTensor.from_local(
        x.chunk(num_ranks, dim=sequence_dim)[rank].contiguous(),
        placements=(Shard(sequence_dim),),
        device_mesh=mesh,
    )
    y = model_sharded(x_sharded)
    y_from_sharded = y.full_tensor()

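    # Reference path: run the same weights on the full, unsharded input.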
    model_unsharded = Model(num_tokens, d).to(dtype=dtype).cuda()
    model_unsharded.rms_norm.weight = torch.nn.Parameter(weight_ref)
    y_from_unsharded = model_unsharded(x_ref)

    assert_close(y_from_sharded, y_from_unsharded)

    # Backward
    y_grad = torch.randn_like(y_from_unsharded)
    y_from_sharded.backward(y_grad)
    y_from_unsharded.backward(y_grad)

    weight_grad_from_sharded = model_sharded.rms_norm.weight.grad._local_tensor
    weight_grad_from_unsharded = model_unsharded.rms_norm.weight.grad
    # Each rank holds only the gradient contribution of its own shard, so sum
    # across ranks before comparing against the single-process reference.
    torch.distributed.all_reduce(x.grad, op=torch.distributed.ReduceOp.SUM)
    torch.distributed.all_reduce(weight_grad_from_sharded,
                                 op=torch.distributed.ReduceOp.SUM)
    assert_close(x.grad, x_ref.grad)
    assert_close(weight_grad_from_sharded, weight_grad_from_unsharded)