diff --git a/README.md b/README.md index 219f3b3c19af202414aa8dbc0b6a885a05ffa7c7..fc54cd190280e4aee1712d5a74c5d28ca5c686a3 100644 --- a/README.md +++ b/README.md @@ -18,13 +18,15 @@ Activation is a python package that contains custom CUDA-based activation kernel ```python y = x + residual - out = rms_norm(y, weight, eps) + hidden_state = rms_norm(y, weight, eps) + out = y + some_op(hidden_state) ``` - Fused as: ```python - out = fused_add_rms_norm(x, residual, weight, eps) + hidden_state, y = fused_add_rms_norm(x, residual, weight, eps) + out = y + some_op(hidden_state) ``` - **FusedMulPolyNorm** diff --git a/activation/fused_add_rms_norm.cu b/activation/fused_add_rms_norm.cu index 9e73bd8f61fbd1ad59824398cf6b43e121071ea8..7d27947d377d2880e1e9c8cf2eb5c49a6fc366ff 100644 --- a/activation/fused_add_rms_norm.cu +++ b/activation/fused_add_rms_norm.cu @@ -117,9 +117,175 @@ fused_add_rms_norm_kernel(scalar_t *__restrict__ out, // [..., d] } } +template +__global__ std::enable_if_t<(width > 0)> fused_add_rms_norm_backward_kernel( + scalar_t *__restrict__ input_grad, // [..., d] + acc_t *__restrict__ temp_weight_grad, // [..., d] + const scalar_t *__restrict__ output_grad, // [..., d] + const scalar_t *__restrict__ add_output_grad, // [..., d] + const scalar_t *__restrict__ input, // [..., d] + const scalar_t *__restrict__ weight, // [d] + const float eps, const int d) { + using vec_t = type_vec_t; + using dw_vec_t = type_vec_t; + + const int64_t token_idx = blockIdx.x; + const int64_t vec_idx = threadIdx.x; + + const int vec_d = d / width; + const int64_t vec_offset = token_idx * vec_d; + + const vec_t *__restrict__ input_vec = reinterpret_cast(input); + const vec_t *__restrict__ output_grad_vec = + reinterpret_cast(output_grad); + const vec_t *__restrict__ weight_vec = + reinterpret_cast(weight); + + acc_t d_sum = 0.0f; + acc_t sum_square = 0.0f; + + for (int64_t vidx = vec_idx; vidx < vec_d; vidx += blockDim.x) { + vec_t x_vec = input_vec[vec_offset + vidx]; + vec_t dy_vec = output_grad_vec[vec_offset + vidx]; + vec_t w_vec = weight_vec[vidx]; + +#pragma unroll + for (int i = 0; i < width; ++i) { + acc_t x = x_vec.data[i]; + acc_t dy = dy_vec.data[i]; + acc_t w = w_vec.data[i]; + d_sum += dy * x * w; + sum_square += x * x; + } + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + struct SumOp { + __device__ float2 operator()(const float2 &a, const float2 &b) const { + return make_float2(a.x + b.x, a.y + b.y); + } + }; + float2 thread_sums = make_float2(d_sum, sum_square); + float2 block_sums = + BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x); + + d_sum = block_sums.x; + sum_square = block_sums.y; + + __shared__ acc_t s_scale; + __shared__ acc_t s_dxx; + + if (threadIdx.x == 0) { + acc_t scale = rsqrtf(sum_square / d + eps); + s_dxx = d_sum * scale * scale * scale / d; + s_scale = scale; + } + __syncthreads(); + acc_t scale = s_scale; + acc_t dxx = s_dxx; + vec_t *__restrict__ input_grad_vec = reinterpret_cast(input_grad); + dw_vec_t *__restrict__ temp_weight_grad_vec = + reinterpret_cast(temp_weight_grad); + const vec_t *__restrict__ add_output_grad_vec = + reinterpret_cast(add_output_grad); + + for (int64_t vidx = vec_idx; vidx < vec_d; vidx += blockDim.x) { + vec_t x_vec = input_vec[vec_offset + vidx]; + vec_t dy_vec = output_grad_vec[vec_offset + vidx]; + vec_t da_vec = add_output_grad_vec[vec_offset + vidx]; + vec_t w_vec = weight_vec[vidx]; + + vec_t in_grad_vec; + dw_vec_t tw_grad_vec; + +#pragma unroll + for (int i 
= 0; i < width; ++i) { + acc_t x = x_vec.data[i]; + acc_t dy = dy_vec.data[i]; + acc_t w = w_vec.data[i]; + + if (input_grad) { + scalar_t da = da_vec.data[i]; + scalar_t in_grad = scale * dy * w - dxx * x; + in_grad_vec.data[i] = in_grad + da; + } + tw_grad_vec.data[i] = dy * x * scale; + } + + if (input_grad) { + input_grad_vec[vec_offset + vidx] = in_grad_vec; + } + temp_weight_grad_vec[vec_offset + vidx] = tw_grad_vec; + } +} + +template +__global__ std::enable_if_t<(width == 0)> fused_add_rms_norm_backward_kernel( + scalar_t *__restrict__ input_grad, // [..., d] + acc_t *__restrict__ temp_weight_grad, // [..., d] + const scalar_t *__restrict__ output_grad, // [..., d] + const scalar_t *__restrict__ add_output_grad, // [..., d] + const scalar_t *__restrict__ input, // [..., d] + const scalar_t *__restrict__ weight, // [d] + const float eps, const int d) { + const int64_t token_idx = blockIdx.x; + const int64_t vec_idx = threadIdx.x; + acc_t d_sum = 0.0f; + acc_t sum_square = 0.0f; + + for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) { + acc_t x = input[token_idx * d + idx]; + acc_t dy = output_grad[token_idx * d + idx]; + acc_t w = weight[idx]; + d_sum += dy * x * w; + sum_square += x * x; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + struct SumOp { + __device__ float2 operator()(const float2 &a, const float2 &b) const { + return make_float2(a.x + b.x, a.y + b.y); + } + }; + float2 thread_sums = make_float2(d_sum, sum_square); + float2 block_sums = + BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x); + + d_sum = block_sums.x; + sum_square = block_sums.y; + + __shared__ acc_t s_scale; + __shared__ acc_t s_dxx; + + if (threadIdx.x == 0) { + acc_t scale = rsqrtf(sum_square / d + eps); + s_dxx = d_sum * scale * scale * scale / d; + s_scale = scale; + } + __syncthreads(); + + acc_t scale = s_scale; + acc_t dxx = s_dxx; + + for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) { + acc_t x = input[token_idx * d + idx]; + acc_t dy = output_grad[token_idx * d + idx]; + acc_t w = weight[idx]; + + if (input_grad) { + scalar_t da = add_output_grad[token_idx * d + idx]; + scalar_t in_grad = scale * dy * w - dxx * x; + input_grad[token_idx * d + idx] = in_grad + da; + } + temp_weight_grad[token_idx * d + idx] = dy * x * scale; + } +} + } // namespace motif -#define LAUNCH_RMS_NORM(width) \ +#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ MOTIF_DISPATCH_FLOATING_TYPES( \ input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ motif::fused_add_rms_norm_kernel \ @@ -150,8 +316,60 @@ void fused_add_rms_norm(torch::Tensor &out, // [..., d] const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (d % 8 == 0) { - LAUNCH_RMS_NORM(8); + LAUNCH_FUSED_ADD_RMS_NORM(8); + } else { + LAUNCH_FUSED_ADD_RMS_NORM(0); + } +} + +#define LAUNCH_FUSED_ADD_RMS_NORM_BWD(width) \ + MOTIF_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "fused_add_rms_norm_backward_kernel", [&] { \ + motif::fused_add_rms_norm_backward_kernel \ + <<>>(input_grad.data_ptr(), \ + temp_weight_grad.data_ptr(), \ + output_grad.data_ptr(), \ + add_output_grad.data_ptr(), \ + input.data_ptr(), \ + weight.data_ptr(), eps, d); \ + }); + +void fused_add_rms_norm_backward( + torch::Tensor &input_grad, // [..., d] + torch::Tensor &weight_grad, // [d] + const torch::Tensor &output_grad, // [..., d] + const torch::Tensor &add_output_grad, // [..., d] + const torch::Tensor &input, // [..., d] + const 
torch::Tensor &weight, // [d] + double eps) { + AssertTensorShapeEqual(input, input_grad, "input", "input_grad"); + AssertTensorShapeEqual(input, output_grad, "input", "output_grad"); + AssertTensorShapeEqual(input, add_output_grad, "input", "add_output_grad"); + AssertTensorNotNull(weight, "weight"); + // TODO shape check + // weight_grad, input_grad can be nullable + + int d = input.size(-1); + int64_t num_tokens = input.numel() / input.size(-1); + dim3 grid(num_tokens); + const int max_block_size = (num_tokens < 256) ? 1024 : 256; + dim3 block(std::min(d, max_block_size)); + + torch::Tensor temp_weight_grad = + torch::empty({num_tokens, d}, input.options().dtype(torch::kFloat)); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (d % 8 == 0) { + LAUNCH_FUSED_ADD_RMS_NORM_BWD(8); } else { - LAUNCH_RMS_NORM(0); + LAUNCH_FUSED_ADD_RMS_NORM_BWD(0); + } + + if (weight_grad.defined()) { + torch::Tensor acc = + torch::empty_like(weight_grad, temp_weight_grad.options()); + at::sum_out(acc, temp_weight_grad, {0}); + weight_grad.copy_(acc); } } diff --git a/activation/fused_mul_poly_norm.cu b/activation/fused_mul_poly_norm.cu index ce35c1dd67be0fd4c41658b55fc433bd87d1038a..42ef350d0be4fe42d34e4e435372e15debb42087 100644 --- a/activation/fused_mul_poly_norm.cu +++ b/activation/fused_mul_poly_norm.cu @@ -573,7 +573,7 @@ void fused_mul_poly_norm(torch::Tensor &out, // [..., d] } } -#define LAUNCH_POLY_NORM_BACKWARD(width) \ +#define LAUNCH_FUSED_MUL_POLY_NORM_BACKWARD(width) \ MOTIF_DISPATCH_FLOATING_TYPES( \ input.scalar_type(), "fused_mul_poly_norm_backward_kernel", [&] { \ motif::fused_mul_poly_norm_backward_kernel \ @@ -620,11 +620,11 @@ void fused_mul_poly_norm_backward(torch::Tensor &input_grad, // [..., d] const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if (d % 8 == 0 && input.element_size() == 2) { - LAUNCH_POLY_NORM_BACKWARD(8); + LAUNCH_FUSED_MUL_POLY_NORM_BACKWARD(8); } else if (d % 4 == 0 && input.element_size() == 4) { - LAUNCH_POLY_NORM_BACKWARD(4); + LAUNCH_FUSED_MUL_POLY_NORM_BACKWARD(4); } else { - LAUNCH_POLY_NORM_BACKWARD(0); + LAUNCH_FUSED_MUL_POLY_NORM_BACKWARD(0); } if (bias_grad.defined()) { diff --git a/benchmarks/cases/add_rms.py b/benchmarks/cases/add_rms.py index 5e055e197c2a9e8540c94b579e88db63824ce424..d0585ecf3ca2d32ca636ed0ab6524602d3c73452 100644 --- a/benchmarks/cases/add_rms.py +++ b/benchmarks/cases/add_rms.py @@ -12,7 +12,8 @@ class FusedAddRMSNorm(torch.nn.Module): self.eps = eps def forward(self, x, residual): - return activation.rms_norm((x + residual), self.weight, self.eps) + h = x + residual + return activation.rms_norm(h, self.weight, self.eps), h class AddRMS(DiffCase): diff --git a/benchmarks/common/bench_framework.py b/benchmarks/common/bench_framework.py index f24b8d8d6163a85ce347eeb3d3ad06f46f1c67cf..49dfe3c7deb1cd6595a2d411a0e17615d8f99e3b 100644 --- a/benchmarks/common/bench_framework.py +++ b/benchmarks/common/bench_framework.py @@ -149,7 +149,10 @@ def make_bwd_benchmark_for_case( obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I) y = case.forward(obj, I) gin = list(case.grad_inputs(I)) + list(obj.parameters()) - g = torch.randn_like(y) + if isinstance(y, torch.Tensor): + g = [torch.randn_like(y)] + else: + g = [torch.randn_like(r) for r in y] run = lambda: torch.autograd.grad(y, gin, g, @@ -201,7 +204,10 @@ def make_bwd_benchmark_plot_for_case( obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I) y =
case.forward(obj, I) gin = list(case.grad_inputs(I)) + list(obj.parameters()) - g = torch.randn_like(y) + if isinstance(y, torch.Tensor): + g = [torch.randn_like(y)] + else: + g = [torch.randn_like(r) for r in y] run = lambda: torch.autograd.grad(y, gin, g, diff --git a/benchmarks/common/diff_engine.py b/benchmarks/common/diff_engine.py index 276cd3900c34a69740cb35ec3183491055b32751..3c6edb75e20c4d84138e58b1238a3d6470ef05c7 100644 --- a/benchmarks/common/diff_engine.py +++ b/benchmarks/common/diff_engine.py @@ -68,7 +68,10 @@ def calculate_diff( torch.testing.assert_close(y_n, y_c, atol=atol, rtol=rtol) gin_n = list(case.grad_inputs(I_n)) + list(obj_n.parameters()) gin_c = list(case.grad_inputs(I_c)) + list(obj_c.parameters()) - g = _unit_grad_like(y_n).to(device) + if isinstance(y_n, torch.Tensor): + g = [_unit_grad_like(y_n).to(device)] + else: + g = [_unit_grad_like(r).to(device) for r in y_n] ng = torch.autograd.grad(y_n, gin_n, g, diff --git a/benchmarks/plots/h100/add_rms/plot_add_rms-bwd-perf.png b/benchmarks/plots/h100/add_rms/plot_add_rms-bwd-perf.png index b596caf613cf58456d23286293dca72d747f377f..9e73dcf9779ad8fd5bfbc0b996d879112f5a3d91 100644 Binary files a/benchmarks/plots/h100/add_rms/plot_add_rms-bwd-perf.png and b/benchmarks/plots/h100/add_rms/plot_add_rms-bwd-perf.png differ diff --git a/benchmarks/plots/h100/add_rms/plot_add_rms-fwd-perf.png b/benchmarks/plots/h100/add_rms/plot_add_rms-fwd-perf.png index 132e33a45d291c152b44398be9f5d75d2abaa232..c381d41e691e8eca3336a79c1510025976d632ec 100644 Binary files a/benchmarks/plots/h100/add_rms/plot_add_rms-fwd-perf.png and b/benchmarks/plots/h100/add_rms/plot_add_rms-fwd-perf.png differ diff --git a/benchmarks/plots/h100/mul_poly/plot_mul_poly-bwd-perf.png b/benchmarks/plots/h100/mul_poly/plot_mul_poly-bwd-perf.png index d8919a93e3d7be46a9d6cd57d832eff48377a637..a1f3f56eff89118ef5757fca676172417a0b547c 100644 Binary files a/benchmarks/plots/h100/mul_poly/plot_mul_poly-bwd-perf.png and b/benchmarks/plots/h100/mul_poly/plot_mul_poly-bwd-perf.png differ diff --git a/benchmarks/plots/h100/mul_poly/plot_mul_poly-fwd-perf.png b/benchmarks/plots/h100/mul_poly/plot_mul_poly-fwd-perf.png index 5191efe759849638f7ac9f0f619c8eb56a04975d..cbe8cbbc35f46f5e026b15e5a2dd6dc31cde57aa 100644 Binary files a/benchmarks/plots/h100/mul_poly/plot_mul_poly-fwd-perf.png and b/benchmarks/plots/h100/mul_poly/plot_mul_poly-fwd-perf.png differ diff --git a/benchmarks/plots/h100/poly/plot_poly-bwd-perf.png b/benchmarks/plots/h100/poly/plot_poly-bwd-perf.png index 06d36211ff47c6c1a6feac15d5ea1c708c289ece..0ece866ec3e6eec7852ebf83410a12f9974b6e14 100644 Binary files a/benchmarks/plots/h100/poly/plot_poly-bwd-perf.png and b/benchmarks/plots/h100/poly/plot_poly-bwd-perf.png differ diff --git a/benchmarks/plots/h100/poly/plot_poly-fwd-perf.png b/benchmarks/plots/h100/poly/plot_poly-fwd-perf.png index 59f9cbd32aa6532bfcf5284f51a3d56be9d3da9d..510394f391a2cc624212dbae42ca0927131ea916 100644 Binary files a/benchmarks/plots/h100/poly/plot_poly-fwd-perf.png and b/benchmarks/plots/h100/poly/plot_poly-fwd-perf.png differ diff --git a/benchmarks/plots/h100/rms/plot_rms-bwd-perf.png b/benchmarks/plots/h100/rms/plot_rms-bwd-perf.png index 95815d90490f6c5916fa4835271473ac2565b079..695ad06b5eef4bd3b0e1f378378a4b114704e10a 100644 Binary files a/benchmarks/plots/h100/rms/plot_rms-bwd-perf.png and b/benchmarks/plots/h100/rms/plot_rms-bwd-perf.png differ diff --git a/benchmarks/plots/h100/rms/plot_rms-fwd-perf.png b/benchmarks/plots/h100/rms/plot_rms-fwd-perf.png index 
5b86645cfecfbb0e9645527530a9fbd092f68b5f..69c35ce8bd941d0124e27856955d490aeea36e2e 100644 Binary files a/benchmarks/plots/h100/rms/plot_rms-fwd-perf.png and b/benchmarks/plots/h100/rms/plot_rms-fwd-perf.png differ diff --git a/benchmarks/plots/mi250/add_rms/plot_add_rms-bwd-perf.png b/benchmarks/plots/mi250/add_rms/plot_add_rms-bwd-perf.png index f36820d99706b319eeda8f8a3ed5246383d2524f..19296d853f59426be7eed63a6d74561e3b1ddc29 100644 Binary files a/benchmarks/plots/mi250/add_rms/plot_add_rms-bwd-perf.png and b/benchmarks/plots/mi250/add_rms/plot_add_rms-bwd-perf.png differ diff --git a/benchmarks/plots/mi250/add_rms/plot_add_rms-fwd-perf.png b/benchmarks/plots/mi250/add_rms/plot_add_rms-fwd-perf.png index dacca9cf397419f9af992675388f7299b35e30c0..94df77c414763adc493d39a3a04bfd4dd956bd7b 100644 Binary files a/benchmarks/plots/mi250/add_rms/plot_add_rms-fwd-perf.png and b/benchmarks/plots/mi250/add_rms/plot_add_rms-fwd-perf.png differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py index 254ab917bab4ccd4a19327c1fe0bf96060f11217..938feeff791794d011fec65cf86df957e2c4da2f 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py @@ -39,7 +39,7 @@ def fused_add_rms_norm( weight: torch.Tensor, eps: float = 1e-6, ) -> None: - return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0] + return FusedAddRMSNormFunction.apply(x, residual, weight, eps) __all__ = [ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..5a1e5a3587679a157ba7b067d28d762c6577fb8f --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec9ea7edc8b27f7983e20d615ab470cef6b82975afc214becfddfd05a867a839 +size 8600336 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py index a5ff861cc76e68ae4de5758b7acafa38f915e62a..fa68616c13166de47619ed052ed1eba664998b82 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_20250907180255 -ops = torch.ops._activation_20250907180255 +from . import _activation_e5e2eeb_dirty +ops = torch.ops._activation_e5e2eeb_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_activation_20250907180255::{op_name}" \ No newline at end of file + return f"_activation_e5e2eeb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/layers.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/layers.py index 156ea42df607e920731ad932d3a5b5d3a472c157..b1880bdbe8dd73ac76d7d4561cf60f9765097ca9 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/activation/layers.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/layers.py @@ -85,7 +85,7 @@ class FusedAddRMSNorm(nn.Module): residual: torch.Tensor, ): return FusedAddRMSNormFunction.apply(x, residual, self.weight, - self.eps)[0] + self.eps) def reset_parameters(self) -> None: """ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py index 7b4274f3a59c423a8662edf3bb8728a1daacb71f..0e2c29e955b87025e63f4795d58a14104318f736 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py @@ -54,20 +54,14 @@ class FusedAddRMSNormFunction(torch.autograd.Function): def setup_context(ctx, inputs, outputs): _, _, weight, eps = inputs _, add_output = outputs - ctx.mark_non_differentiable(add_output) - ctx.set_materialize_grads(False) ctx.save_for_backward(weight, add_output) ctx.eps = eps - # This function only needs one gradient @staticmethod - def backward(ctx, output_grad, _): + def backward(ctx, output_grad, add_output_grad): weight, add_output = ctx.saved_tensors eps = ctx.eps - if output_grad is None: - output_grad = torch.zeros_like(add_output) - need_in = ctx.needs_input_grad[0] need_res = ctx.needs_input_grad[1] @@ -76,7 +70,7 @@ class FusedAddRMSNormFunction(torch.autograd.Function): weight_grad = torch.empty_like( weight) if ctx.needs_input_grad[2] else None - ops.rms_norm_backward(grad, weight_grad, output_grad, add_output, + ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output, weight, eps) input_grad = grad if need_in else None residual_grad = grad if need_res else None diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/__init__.py index 254ab917bab4ccd4a19327c1fe0bf96060f11217..938feeff791794d011fec65cf86df957e2c4da2f 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/activation/__init__.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/__init__.py @@ -39,7 +39,7 @@ def fused_add_rms_norm( weight: torch.Tensor, eps: float = 1e-6, ) -> None: - return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0] + return FusedAddRMSNormFunction.apply(x, residual, weight, eps) __all__ = [ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..30ab86df7c79038bc40bcd1292a2fa606b44ebc1 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5d3511410cdc288d2fafc500223ed2e625e360f50fa341809cf892fb2c822924 +size 8779000 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py index a5ff861cc76e68ae4de5758b7acafa38f915e62a..fa68616c13166de47619ed052ed1eba664998b82 100644 --- 
a/build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_20250907180255 -ops = torch.ops._activation_20250907180255 +from . import _activation_e5e2eeb_dirty +ops = torch.ops._activation_e5e2eeb_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_20250907180255::{op_name}" \ No newline at end of file + return f"_activation_e5e2eeb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/layers.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/layers.py index 156ea42df607e920731ad932d3a5b5d3a472c157..b1880bdbe8dd73ac76d7d4561cf60f9765097ca9 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/activation/layers.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/layers.py @@ -85,7 +85,7 @@ class FusedAddRMSNorm(nn.Module): residual: torch.Tensor, ): return FusedAddRMSNormFunction.apply(x, residual, self.weight, - self.eps)[0] + self.eps) def reset_parameters(self) -> None: """ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py index 7b4274f3a59c423a8662edf3bb8728a1daacb71f..0e2c29e955b87025e63f4795d58a14104318f736 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py @@ -54,20 +54,14 @@ class FusedAddRMSNormFunction(torch.autograd.Function): def setup_context(ctx, inputs, outputs): _, _, weight, eps = inputs _, add_output = outputs - ctx.mark_non_differentiable(add_output) - ctx.set_materialize_grads(False) ctx.save_for_backward(weight, add_output) ctx.eps = eps - # This function only needs one gradient @staticmethod - def backward(ctx, output_grad, _): + def backward(ctx, output_grad, add_output_grad): weight, add_output = ctx.saved_tensors eps = ctx.eps - if output_grad is None: - output_grad = torch.zeros_like(add_output) - need_in = ctx.needs_input_grad[0] need_res = ctx.needs_input_grad[1] @@ -76,7 +70,7 @@ class FusedAddRMSNormFunction(torch.autograd.Function): weight_grad = torch.empty_like( weight) if ctx.needs_input_grad[2] else None - ops.rms_norm_backward(grad, weight_grad, output_grad, add_output, + ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output, weight, eps) input_grad = grad if need_in else None residual_grad = grad if need_res else None diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/__init__.py index 254ab917bab4ccd4a19327c1fe0bf96060f11217..938feeff791794d011fec65cf86df957e2c4da2f 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/activation/__init__.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/__init__.py @@ -39,7 +39,7 @@ def fused_add_rms_norm( weight: torch.Tensor, eps: float = 1e-6, ) -> None: - return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0] + return FusedAddRMSNormFunction.apply(x, residual, weight, eps) __all__ = [ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..b57174622d44e91556d4646cc225ce02ae186236 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so @@ 
-0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:25efc9c32e4bd6609a8326025aad861cbf79b544893755fe44519c9df7224c40 +size 13818872 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py index a5ff861cc76e68ae4de5758b7acafa38f915e62a..fa68616c13166de47619ed052ed1eba664998b82 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_20250907180255 -ops = torch.ops._activation_20250907180255 +from . import _activation_e5e2eeb_dirty +ops = torch.ops._activation_e5e2eeb_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_20250907180255::{op_name}" \ No newline at end of file + return f"_activation_e5e2eeb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/layers.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/layers.py index 156ea42df607e920731ad932d3a5b5d3a472c157..b1880bdbe8dd73ac76d7d4561cf60f9765097ca9 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/activation/layers.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/layers.py @@ -85,7 +85,7 @@ class FusedAddRMSNorm(nn.Module): residual: torch.Tensor, ): return FusedAddRMSNormFunction.apply(x, residual, self.weight, - self.eps)[0] + self.eps) def reset_parameters(self) -> None: """ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py index 7b4274f3a59c423a8662edf3bb8728a1daacb71f..0e2c29e955b87025e63f4795d58a14104318f736 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py @@ -54,20 +54,14 @@ class FusedAddRMSNormFunction(torch.autograd.Function): def setup_context(ctx, inputs, outputs): _, _, weight, eps = inputs _, add_output = outputs - ctx.mark_non_differentiable(add_output) - ctx.set_materialize_grads(False) ctx.save_for_backward(weight, add_output) ctx.eps = eps - # This function only needs one gradient @staticmethod - def backward(ctx, output_grad, _): + def backward(ctx, output_grad, add_output_grad): weight, add_output = ctx.saved_tensors eps = ctx.eps - if output_grad is None: - output_grad = torch.zeros_like(add_output) - need_in = ctx.needs_input_grad[0] need_res = ctx.needs_input_grad[1] @@ -76,7 +70,7 @@ class FusedAddRMSNormFunction(torch.autograd.Function): weight_grad = torch.empty_like( weight) if ctx.needs_input_grad[2] else None - ops.rms_norm_backward(grad, weight_grad, output_grad, add_output, + ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output, weight, eps) input_grad = grad if need_in else None residual_grad = grad if need_res else None diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py index 254ab917bab4ccd4a19327c1fe0bf96060f11217..938feeff791794d011fec65cf86df957e2c4da2f 100644 --- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py @@ -39,7 +39,7 @@ def fused_add_rms_norm( weight: torch.Tensor, eps: float = 1e-6, ) -> None: - return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0] + return FusedAddRMSNormFunction.apply(x, residual, weight, eps) __all__ = [ diff --git 
a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..c0069ea9e4f962208b869f671b23aa15f728cb92 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c80d05690547f2842d416ebb85c9f830370373bc7e6c54ba08eec61b3690280f +size 4386744 diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py index a5ff861cc76e68ae4de5758b7acafa38f915e62a..fa68616c13166de47619ed052ed1eba664998b82 100644 --- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_20250907180255 -ops = torch.ops._activation_20250907180255 +from . import _activation_e5e2eeb_dirty +ops = torch.ops._activation_e5e2eeb_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_20250907180255::{op_name}" \ No newline at end of file + return f"_activation_e5e2eeb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py index 156ea42df607e920731ad932d3a5b5d3a472c157..b1880bdbe8dd73ac76d7d4561cf60f9765097ca9 100644 --- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py @@ -85,7 +85,7 @@ class FusedAddRMSNorm(nn.Module): residual: torch.Tensor, ): return FusedAddRMSNormFunction.apply(x, residual, self.weight, - self.eps)[0] + self.eps) def reset_parameters(self) -> None: """ diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py index 7b4274f3a59c423a8662edf3bb8728a1daacb71f..0e2c29e955b87025e63f4795d58a14104318f736 100644 --- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py @@ -54,20 +54,14 @@ class FusedAddRMSNormFunction(torch.autograd.Function): def setup_context(ctx, inputs, outputs): _, _, weight, eps = inputs _, add_output = outputs - ctx.mark_non_differentiable(add_output) - ctx.set_materialize_grads(False) ctx.save_for_backward(weight, add_output) ctx.eps = eps - # This function only needs one gradient @staticmethod - def backward(ctx, output_grad, _): + def backward(ctx, output_grad, add_output_grad): weight, add_output = ctx.saved_tensors eps = ctx.eps - if output_grad is None: - output_grad = torch.zeros_like(add_output) - need_in = ctx.needs_input_grad[0] need_res = ctx.needs_input_grad[1] @@ -76,7 +70,7 @@ class FusedAddRMSNormFunction(torch.autograd.Function): weight_grad = torch.empty_like( weight) if ctx.needs_input_grad[2] else None - ops.rms_norm_backward(grad, weight_grad, output_grad, add_output, + ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output, weight, eps) input_grad = grad if need_in else None residual_grad = grad if need_res else None diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/__init__.py index 
254ab917bab4ccd4a19327c1fe0bf96060f11217..938feeff791794d011fec65cf86df957e2c4da2f 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/activation/__init__.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/__init__.py @@ -39,7 +39,7 @@ def fused_add_rms_norm( weight: torch.Tensor, eps: float = 1e-6, ) -> None: - return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0] + return FusedAddRMSNormFunction.apply(x, residual, weight, eps) __all__ = [ diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..a50764fa05ea1e21294f84d922050f5d70f7db93 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:440f5c17a7ddaf73c506bbc84fd1405e2e188b8ceaf4977910608be6b91e89bf +size 8730200 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py index a5ff861cc76e68ae4de5758b7acafa38f915e62a..fa68616c13166de47619ed052ed1eba664998b82 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_20250907180255 -ops = torch.ops._activation_20250907180255 +from . import _activation_e5e2eeb_dirty +ops = torch.ops._activation_e5e2eeb_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_20250907180255::{op_name}" \ No newline at end of file + return f"_activation_e5e2eeb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/layers.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/layers.py index 156ea42df607e920731ad932d3a5b5d3a472c157..b1880bdbe8dd73ac76d7d4561cf60f9765097ca9 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/activation/layers.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/layers.py @@ -85,7 +85,7 @@ class FusedAddRMSNorm(nn.Module): residual: torch.Tensor, ): return FusedAddRMSNormFunction.apply(x, residual, self.weight, - self.eps)[0] + self.eps) def reset_parameters(self) -> None: """ diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py index 7b4274f3a59c423a8662edf3bb8728a1daacb71f..0e2c29e955b87025e63f4795d58a14104318f736 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py @@ -54,20 +54,14 @@ class FusedAddRMSNormFunction(torch.autograd.Function): def setup_context(ctx, inputs, outputs): _, _, weight, eps = inputs _, add_output = outputs - ctx.mark_non_differentiable(add_output) - ctx.set_materialize_grads(False) ctx.save_for_backward(weight, add_output) ctx.eps = eps - # This function only needs one gradient @staticmethod - def backward(ctx, output_grad, _): + def backward(ctx, output_grad, add_output_grad): weight, add_output = ctx.saved_tensors eps = ctx.eps - if output_grad is None: - output_grad = torch.zeros_like(add_output) - need_in = ctx.needs_input_grad[0] need_res = ctx.needs_input_grad[1] @@ -76,7 +70,7 @@ class FusedAddRMSNormFunction(torch.autograd.Function): weight_grad = torch.empty_like( weight) if ctx.needs_input_grad[2] else None - 
ops.rms_norm_backward(grad, weight_grad, output_grad, add_output, + ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output, weight, eps) input_grad = grad if need_in else None residual_grad = grad if need_res else None diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/__init__.py index 254ab917bab4ccd4a19327c1fe0bf96060f11217..938feeff791794d011fec65cf86df957e2c4da2f 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/activation/__init__.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/__init__.py @@ -39,7 +39,7 @@ def fused_add_rms_norm( weight: torch.Tensor, eps: float = 1e-6, ) -> None: - return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0] + return FusedAddRMSNormFunction.apply(x, residual, weight, eps) __all__ = [ diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..d3e4416a52e04ff527f48c721c6c4f1fa16059ed --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1dfb6d468f9cef0239d4ea47f0a247fa721befc5b8db86e1cddfc25f1814b67a +size 13770064 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py index a5ff861cc76e68ae4de5758b7acafa38f915e62a..fa68616c13166de47619ed052ed1eba664998b82 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_20250907180255 -ops = torch.ops._activation_20250907180255 +from . import _activation_e5e2eeb_dirty +ops = torch.ops._activation_e5e2eeb_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_activation_20250907180255::{op_name}" \ No newline at end of file + return f"_activation_e5e2eeb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/layers.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/layers.py index 156ea42df607e920731ad932d3a5b5d3a472c157..b1880bdbe8dd73ac76d7d4561cf60f9765097ca9 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/activation/layers.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/layers.py @@ -85,7 +85,7 @@ class FusedAddRMSNorm(nn.Module): residual: torch.Tensor, ): return FusedAddRMSNormFunction.apply(x, residual, self.weight, - self.eps)[0] + self.eps) def reset_parameters(self) -> None: """ diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py index 7b4274f3a59c423a8662edf3bb8728a1daacb71f..0e2c29e955b87025e63f4795d58a14104318f736 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py @@ -54,20 +54,14 @@ class FusedAddRMSNormFunction(torch.autograd.Function): def setup_context(ctx, inputs, outputs): _, _, weight, eps = inputs _, add_output = outputs - ctx.mark_non_differentiable(add_output) - ctx.set_materialize_grads(False) ctx.save_for_backward(weight, add_output) ctx.eps = eps - # This function only needs one gradient @staticmethod - def backward(ctx, output_grad, _): + def backward(ctx, output_grad, add_output_grad): weight, add_output = ctx.saved_tensors eps = ctx.eps - if output_grad is None: - output_grad = torch.zeros_like(add_output) - need_in = ctx.needs_input_grad[0] need_res = ctx.needs_input_grad[1] @@ -76,7 +70,7 @@ class FusedAddRMSNormFunction(torch.autograd.Function): weight_grad = torch.empty_like( weight) if ctx.needs_input_grad[2] else None - ops.rms_norm_backward(grad, weight_grad, output_grad, add_output, + ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output, weight, eps) input_grad = grad if need_in else None residual_grad = grad if need_res else None diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/__init__.py index 254ab917bab4ccd4a19327c1fe0bf96060f11217..938feeff791794d011fec65cf86df957e2c4da2f 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/activation/__init__.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/__init__.py @@ -39,7 +39,7 @@ def fused_add_rms_norm( weight: torch.Tensor, eps: float = 1e-6, ) -> None: - return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0] + return FusedAddRMSNormFunction.apply(x, residual, weight, eps) __all__ = [ diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..ebdc9108aad1a1dfd16dc0d8baebf827bc0476f4 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0815a50e61497b357b2b90fc28602b3f53a25da1161edd2cb0b0fbebc7c62bf6 +size 13757248 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py index a5ff861cc76e68ae4de5758b7acafa38f915e62a..fa68616c13166de47619ed052ed1eba664998b82 100644 --- 
a/build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_20250907180255 -ops = torch.ops._activation_20250907180255 +from . import _activation_e5e2eeb_dirty +ops = torch.ops._activation_e5e2eeb_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_20250907180255::{op_name}" \ No newline at end of file + return f"_activation_e5e2eeb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/layers.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/layers.py index 156ea42df607e920731ad932d3a5b5d3a472c157..b1880bdbe8dd73ac76d7d4561cf60f9765097ca9 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/activation/layers.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/layers.py @@ -85,7 +85,7 @@ class FusedAddRMSNorm(nn.Module): residual: torch.Tensor, ): return FusedAddRMSNormFunction.apply(x, residual, self.weight, - self.eps)[0] + self.eps) def reset_parameters(self) -> None: """ diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm.py index 7b4274f3a59c423a8662edf3bb8728a1daacb71f..0e2c29e955b87025e63f4795d58a14104318f736 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm.py @@ -54,20 +54,14 @@ class FusedAddRMSNormFunction(torch.autograd.Function): def setup_context(ctx, inputs, outputs): _, _, weight, eps = inputs _, add_output = outputs - ctx.mark_non_differentiable(add_output) - ctx.set_materialize_grads(False) ctx.save_for_backward(weight, add_output) ctx.eps = eps - # This function only needs one gradient @staticmethod - def backward(ctx, output_grad, _): + def backward(ctx, output_grad, add_output_grad): weight, add_output = ctx.saved_tensors eps = ctx.eps - if output_grad is None: - output_grad = torch.zeros_like(add_output) - need_in = ctx.needs_input_grad[0] need_res = ctx.needs_input_grad[1] @@ -76,7 +70,7 @@ class FusedAddRMSNormFunction(torch.autograd.Function): weight_grad = torch.empty_like( weight) if ctx.needs_input_grad[2] else None - ops.rms_norm_backward(grad, weight_grad, output_grad, add_output, + ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output, weight, eps) input_grad = grad if need_in else None residual_grad = grad if need_res else None diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/__init__.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/__init__.py index 254ab917bab4ccd4a19327c1fe0bf96060f11217..938feeff791794d011fec65cf86df957e2c4da2f 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/__init__.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/__init__.py @@ -39,7 +39,7 @@ def fused_add_rms_norm( weight: torch.Tensor, eps: float = 1e-6, ) -> None: - return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0] + return FusedAddRMSNormFunction.apply(x, residual, weight, eps) __all__ = [ diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..a7e8ec3a1957ec7fa888600e141e2d6acdb1d4be --- /dev/null +++ 
b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4d404c88b72f1b6da551a64b3373395e80403a52ccff14fc401be3e8ee184d83 +size 4387536 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_ops.py index a5ff861cc76e68ae4de5758b7acafa38f915e62a..fa68616c13166de47619ed052ed1eba664998b82 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_ops.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_20250907180255 -ops = torch.ops._activation_20250907180255 +from . import _activation_e5e2eeb_dirty +ops = torch.ops._activation_e5e2eeb_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_20250907180255::{op_name}" \ No newline at end of file + return f"_activation_e5e2eeb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/layers.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/layers.py index 156ea42df607e920731ad932d3a5b5d3a472c157..b1880bdbe8dd73ac76d7d4561cf60f9765097ca9 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/layers.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/layers.py @@ -85,7 +85,7 @@ class FusedAddRMSNorm(nn.Module): residual: torch.Tensor, ): return FusedAddRMSNormFunction.apply(x, residual, self.weight, - self.eps)[0] + self.eps) def reset_parameters(self) -> None: """ diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm.py index 7b4274f3a59c423a8662edf3bb8728a1daacb71f..0e2c29e955b87025e63f4795d58a14104318f736 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm.py @@ -54,20 +54,14 @@ class FusedAddRMSNormFunction(torch.autograd.Function): def setup_context(ctx, inputs, outputs): _, _, weight, eps = inputs _, add_output = outputs - ctx.mark_non_differentiable(add_output) - ctx.set_materialize_grads(False) ctx.save_for_backward(weight, add_output) ctx.eps = eps - # This function only needs one gradient @staticmethod - def backward(ctx, output_grad, _): + def backward(ctx, output_grad, add_output_grad): weight, add_output = ctx.saved_tensors eps = ctx.eps - if output_grad is None: - output_grad = torch.zeros_like(add_output) - need_in = ctx.needs_input_grad[0] need_res = ctx.needs_input_grad[1] @@ -76,7 +70,7 @@ class FusedAddRMSNormFunction(torch.autograd.Function): weight_grad = torch.empty_like( weight) if ctx.needs_input_grad[2] else None - ops.rms_norm_backward(grad, weight_grad, output_grad, add_output, + ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output, weight, eps) input_grad = grad if need_in else None residual_grad = grad if need_res else None diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/__init__.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/__init__.py index 254ab917bab4ccd4a19327c1fe0bf96060f11217..938feeff791794d011fec65cf86df957e2c4da2f 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/__init__.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/__init__.py @@ -39,7 +39,7 @@ def fused_add_rms_norm( weight: torch.Tensor, eps: float = 1e-6, ) -> None: - return FusedAddRMSNormFunction.apply(x, residual, weight, 
eps)[0] + return FusedAddRMSNormFunction.apply(x, residual, weight, eps) __all__ = [ diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..dafb119147ed94f04203dd8c8a366ef9a6ed7680 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b8d52dee20ba3c4619f7c614984f656f34f32dd74ba6cf866cf80f32245117cf +size 4393240 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_ops.py index a5ff861cc76e68ae4de5758b7acafa38f915e62a..fa68616c13166de47619ed052ed1eba664998b82 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_ops.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_20250907180255 -ops = torch.ops._activation_20250907180255 +from . import _activation_e5e2eeb_dirty +ops = torch.ops._activation_e5e2eeb_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_20250907180255::{op_name}" \ No newline at end of file + return f"_activation_e5e2eeb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/layers.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/layers.py index 156ea42df607e920731ad932d3a5b5d3a472c157..b1880bdbe8dd73ac76d7d4561cf60f9765097ca9 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/layers.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/layers.py @@ -85,7 +85,7 @@ class FusedAddRMSNorm(nn.Module): residual: torch.Tensor, ): return FusedAddRMSNormFunction.apply(x, residual, self.weight, - self.eps)[0] + self.eps) def reset_parameters(self) -> None: """ diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm.py index 7b4274f3a59c423a8662edf3bb8728a1daacb71f..0e2c29e955b87025e63f4795d58a14104318f736 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm.py @@ -54,20 +54,14 @@ class FusedAddRMSNormFunction(torch.autograd.Function): def setup_context(ctx, inputs, outputs): _, _, weight, eps = inputs _, add_output = outputs - ctx.mark_non_differentiable(add_output) - ctx.set_materialize_grads(False) ctx.save_for_backward(weight, add_output) ctx.eps = eps - # This function only needs one gradient @staticmethod - def backward(ctx, output_grad, _): + def backward(ctx, output_grad, add_output_grad): weight, add_output = ctx.saved_tensors eps = ctx.eps - if output_grad is None: - output_grad = torch.zeros_like(add_output) - need_in = ctx.needs_input_grad[0] need_res = ctx.needs_input_grad[1] @@ -76,7 +70,7 @@ class FusedAddRMSNormFunction(torch.autograd.Function): weight_grad = torch.empty_like( weight) if ctx.needs_input_grad[2] else None - ops.rms_norm_backward(grad, weight_grad, output_grad, add_output, + ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output, weight, eps) input_grad = grad if need_in else None residual_grad = grad if need_res else None diff --git a/tests/test_fused_add_rms_norm.py b/tests/test_fused_add_rms_norm.py index 
5486471c131e4543917cf7f399553e6dc71a7a72..249f5119da0085dca4d7eeb911a9c8f4ee18d248 100644 --- a/tests/test_fused_add_rms_norm.py +++ b/tests/test_fused_add_rms_norm.py @@ -18,15 +18,22 @@ CUDA_DEVICES = [ def add_rms_norm_all_naive(x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor: - return torch.nn.functional.rms_norm((x + residual), weight.shape, weight, - eps) + h = x + residual + return torch.nn.functional.rms_norm(h, weight.shape, weight, eps) + h #use rms_norm kernel def add_rms_norm_partial_naive(x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor: - return activation.rms_norm((x + residual), weight, eps) + h = x + residual + return activation.rms_norm(h, weight, eps) + h + + +def fused_add_rms_norm(x: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, eps: float) -> torch.Tensor: + out, h = activation.fused_add_rms_norm(x, residual, weight, eps) + return out + h @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -67,7 +74,7 @@ def test_fused_add_rms_norm( torch_fn2 = add_rms_norm_partial_naive op = activation.ops.fused_add_rms_norm - fn = activation.fused_add_rms_norm + fn = fused_add_rms_norm layer = activation.layers.FusedAddRMSNorm(d, eps) layer.weight = torch.nn.Parameter(weight) @@ -77,11 +84,12 @@ def test_fused_add_rms_norm( opcheck(op, (out, add_out, x, residual, weight, eps)) out = fn(x, residual, weight, eps) - mod_out = layer(x, residual) + mod_out, mod_a_out = layer(x, residual) + mod_out = mod_out + mod_a_out ref_out = torch_fn(x_ref, residual_ref, weight_ref, eps) ref_out2 = torch_fn2(x_ref2, residual_ref2, weight_ref2, eps) - assert_close(out, ref_out) + assert_close(out, ref_out, atol=0.05, rtol=0.05) assert_close(out, ref_out2) assert_close(mod_out, out, atol=0.0, rtol=0.0) diff --git a/torch-ext/activation/__init__.py b/torch-ext/activation/__init__.py index 254ab917bab4ccd4a19327c1fe0bf96060f11217..938feeff791794d011fec65cf86df957e2c4da2f 100644 --- a/torch-ext/activation/__init__.py +++ b/torch-ext/activation/__init__.py @@ -39,7 +39,7 @@ def fused_add_rms_norm( weight: torch.Tensor, eps: float = 1e-6, ) -> None: - return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0] + return FusedAddRMSNormFunction.apply(x, residual, weight, eps) __all__ = [ diff --git a/torch-ext/activation/layers.py b/torch-ext/activation/layers.py index 156ea42df607e920731ad932d3a5b5d3a472c157..b1880bdbe8dd73ac76d7d4561cf60f9765097ca9 100644 --- a/torch-ext/activation/layers.py +++ b/torch-ext/activation/layers.py @@ -85,7 +85,7 @@ class FusedAddRMSNorm(nn.Module): residual: torch.Tensor, ): return FusedAddRMSNormFunction.apply(x, residual, self.weight, - self.eps)[0] + self.eps) def reset_parameters(self) -> None: """ diff --git a/torch-ext/activation/rms_norm.py b/torch-ext/activation/rms_norm.py index 7b4274f3a59c423a8662edf3bb8728a1daacb71f..7f9a470d9bb3833083cfa711e9d16c336b73238d 100644 --- a/torch-ext/activation/rms_norm.py +++ b/torch-ext/activation/rms_norm.py @@ -54,20 +54,14 @@ class FusedAddRMSNormFunction(torch.autograd.Function): def setup_context(ctx, inputs, outputs): _, _, weight, eps = inputs _, add_output = outputs - ctx.mark_non_differentiable(add_output) - ctx.set_materialize_grads(False) ctx.save_for_backward(weight, add_output) ctx.eps = eps - # This function only needs one gradient @staticmethod - def backward(ctx, output_grad, _): + def backward(ctx, output_grad, add_output_grad): weight, add_output = ctx.saved_tensors eps = ctx.eps - if output_grad is None: - 
output_grad = torch.zeros_like(add_output) - need_in = ctx.needs_input_grad[0] need_res = ctx.needs_input_grad[1] @@ -76,8 +70,9 @@ class FusedAddRMSNormFunction(torch.autograd.Function): weight_grad = torch.empty_like( weight) if ctx.needs_input_grad[2] else None - ops.rms_norm_backward(grad, weight_grad, output_grad, add_output, - weight, eps) + ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, + add_output_grad, add_output, weight, + eps) input_grad = grad if need_in else None residual_grad = grad if need_res else None diff --git a/torch-ext/torch_binding.cpp b/torch-ext/torch_binding.cpp index 5c71a5ba1e025e276b32d6878defb0ab1f28ffd1..a316859ac2ac4f7692717b7f3e0a175647835b4a 100644 --- a/torch-ext/torch_binding.cpp +++ b/torch-ext/torch_binding.cpp @@ -39,12 +39,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { &fused_mul_poly_norm_backward); // fused_add_rms_norm - // fused_add_rms_norm_backward uses rms_norm_backward_kernel ops.def( "fused_add_rms_norm(Tensor! out, Tensor! add_out, Tensor input, Tensor " "residual, Tensor " "weight, float eps) -> ()"); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); + + ops.def( + "fused_add_rms_norm_backward(Tensor! input_grad, Tensor! weight_grad, " + "Tensor " + "output_grad, Tensor add_output_grad, Tensor input, Tensor weight, float " + "eps) -> ()"); + ops.impl("fused_add_rms_norm_backward", torch::kCUDA, + &fused_add_rms_norm_backward); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/torch-ext/torch_binding.h b/torch-ext/torch_binding.h index b3629b1afbe07d99f6a60bd8c1f846fdc46ff8b0..57464c4f2d45238dbc5e3b11e40095a45d77db31 100644 --- a/torch-ext/torch_binding.h +++ b/torch-ext/torch_binding.h @@ -28,8 +28,13 @@ void fused_mul_poly_norm_backward( const torch::Tensor &mul, const torch::Tensor &weight, const torch::Tensor &bias, double eps); -// fused_add_rms_norm_backward uses rms_norm_backward_kernel void fused_add_rms_norm(torch::Tensor &out, torch::Tensor &add_out, const torch::Tensor &input, const torch::Tensor &residual, const torch::Tensor &weight, double eps); +void fused_add_rms_norm_backward(torch::Tensor &input_grad, + torch::Tensor &weight_grad, + const torch::Tensor &output_grad, + const torch::Tensor &add_output_grad, + const torch::Tensor &input, + const torch::Tensor &weight, double eps);
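Usage note for the `fused_add_rms_norm` API change: the Python wrapper now returns the pair `(hidden_state, add_output)` instead of only the normalized tensor, matching the README hunk above. A minimal sketch of the new call pattern, assuming a CUDA device; the `gelu` call merely stands in for the README's `some_op`:

```python
import torch
import activation  # this package

eps = 1e-6
x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
residual = torch.randn_like(x)
weight = torch.ones(4096, device="cuda", dtype=torch.bfloat16)

# hidden_state = rms_norm(x + residual, weight, eps); y = x + residual
hidden_state, y = activation.fused_add_rms_norm(x, residual, weight, eps)
out = y + torch.nn.functional.gelu(hidden_state)  # gelu stands in for `some_op`
```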
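For orientation, the gradient math implemented by the new `fused_add_rms_norm_backward_kernel` can be written as a short PyTorch reference. This is a sketch only (the function name and fp32 casts are illustrative, not part of the package): `h` is the saved `add_output` (`x + residual`), `output_grad` is the gradient of the normalized output, and `add_output_grad` is the gradient flowing directly into `h`.

```python
import torch

def fused_add_rms_norm_backward_ref(output_grad, add_output_grad, h, weight, eps):
    # Accumulate in fp32, mirroring acc_t in the kernel.
    h32, dy, w = h.float(), output_grad.float(), weight.float()
    d = h32.shape[-1]
    scale = torch.rsqrt(h32.pow(2).mean(dim=-1, keepdim=True) + eps)   # rsqrtf(sum_square / d + eps)
    dxx = (dy * h32 * w).sum(dim=-1, keepdim=True) * scale.pow(3) / d  # d_sum * scale^3 / d
    input_grad = scale * dy * w - dxx * h32 + add_output_grad.float()  # in_grad + da
    weight_grad = (dy * h32 * scale).reshape(-1, d).sum(dim=0)         # temp_weight_grad summed over tokens
    return input_grad.to(h.dtype), weight_grad.to(weight.dtype)
```

Because `h = x + residual`, the same `input_grad` serves as the gradient for both `x` and `residual`, which is why the autograd function returns its single `grad` buffer for both inputs.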
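The benchmark and diff-engine hunks switch to passing a list of output gradients because the forward can now return a tuple. A self-contained sketch of that `torch.autograd.grad` pattern; the `torch.nn.Linear` toy module is only a stand-in for the benchmark case objects:

```python
import torch

lin = torch.nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)

y = (lin(x), x + 1.0)  # tuple output, analogous to (hidden_state, add_output)
outs = (y,) if isinstance(y, torch.Tensor) else tuple(y)
grad_outputs = [torch.randn_like(o) for o in outs]  # one gradient per output
grads = torch.autograd.grad(outs, [x] + list(lin.parameters()), grad_outputs)
```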