# feather-h200-runtime / mamba_ssm_init.py
# Last commit: "Update Feather H200 training runtime image" (7be430d, icarus112)
# mamba_ssm package init — minimal override that avoids loading selective_scan_cuda.so,
# whose ABI does not match the base image's libtorch.
#
# The upstream __init__.py eagerly imports selective_scan_cuda which fails on
# pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel (undefined c10::Warning ctor
# symbol). We only need Mamba3 (grafted from main, pure-Triton), so we skip
# all compiled-CUDA imports here and let Mamba3 load directly.
__version__ = "2.3.1+feather-graft"

# Placeholders for upstream's compiled-CUDA entry points. The Feather training
# path is Mamba3-only and never calls them. Binding them to None keeps
# `from mamba_ssm import selective_scan_fn` (or `mamba_inner_fn`) working.
# Actually calling either one raises TypeError ('NoneType' object is not
# callable). Code that may run without the CUDA kernels can check `is None`
# before calling.
selective_scan_fn = mamba_inner_fn = None
# --- triton API compatibility shims -----------------------------------------
# The version matrix is hostile. torch 2.6 pins triton==3.2.0, and parts of
# torch._inductor expect AttrsDescriptor from triton.compiler.compiler, which
# was removed in triton 3.4+. The grafted Mamba3 (from mamba-ssm main) needs
# triton.set_allocator and tl.make_tensor_descriptor, both added in triton
# 3.3+. No single triton version satisfies both, so we run on triton 3.5.1
# (which has both Mamba3 APIs) and patch only what is still missing.
# AttrsDescriptor is deliberately NOT shimmed. An earlier stub shim broke
# torch.compile codegen; without it, torch._inductor takes its "triton >= 3.4"
# code path instead.
import triton as _triton  # noqa: E402

# Grafted Mamba3 calls triton.set_allocator, which triton releases before 3.3
# lack. When it is missing, install a stand-in that accepts and discards the
# allocator, so the call becomes a no-op instead of an AttributeError.
if not hasattr(_triton, "set_allocator"):

    def _noop_set_allocator(_fn):  # pragma: no cover
        """Ignore ``_fn``; stand-in for ``triton.set_allocator`` on old triton."""

    setattr(_triton, "set_allocator", _noop_set_allocator)
# Handle on triton.compiler.compiler; the blocks below probe and patch symbols
# that torch._inductor looks up on this module.
import triton.compiler.compiler as _tcc  # noqa: E402
# NOTE: on triton >= 3.4, AttrsDescriptor is intentionally ABSENT. An earlier
# version of this file installed a `_AttrsDescriptorShim` here so that
# `torch._inductor.runtime.hints` could `import AttrsDescriptor` without
# failing, but that shim turned out to be actively harmful once
# `HYDRA_MUON_COMPILE=1` started exercising the torch.compile codegen path:
#
#   1. torch/_inductor/runtime/hints.py branches on
#      `hasattr(triton.compiler.compiler, "AttrsDescriptor")`. With the shim
#      present it took the triton-3.0.0 path and constructed shim instances
#      into `triton_meta["configs"]`.
#   2. torch/_inductor/codegen/triton.py then wrote `triton_meta={...!r}`
#      into the generated Triton-Python source. The shim had no __repr__,
#      so the emitted text was `<mamba_ssm._AttrsDescriptorShim object at 0x…>`,
#      which is a SyntaxError at line 15 of the generated file.
#   3. torch/_dynamo/config.suppress_errors = True turned that SyntaxError
#      into a silent fall-back-to-eager, costing ~30% MFU on the RTX 3060
#      (43k tps eager vs 63k tps compiled per hydra/optimizer.py:66).
#
# Without the shim, torch._inductor takes the intended "triton >= 3.4"
# branch (hints.py:82-92) which emits `{(x,): [["tt.divisibility", 16]]
# for x in divisible_by_16}` — a plain dict whose repr is valid Python —
# and `gen_attr_descriptor_import()` returns "" so no broken
# `from triton.compiler.compiler import AttrsDescriptor` line is baked
# into generated kernel files.
#
# Leaving the guard as a no-op `pass` keeps this block visible for future
# triton-version audits but performs no monkey-patch.
if not hasattr(_tcc, "AttrsDescriptor"):
    pass  # intentional no-op: do NOT install an AttrsDescriptor shim here
# triton_key is gone from triton.compiler.compiler in triton 3.5, but
# torch._inductor.codecache still calls it to derive FxGraphCache keys.
# Supply a replacement that returns a deterministic, version-derived string so
# the cache keeps working.
# NOTE(review): the key encodes only triton.__version__, so two different
# builds of the same triton version map to the same cache key; clear the
# inductor cache after rebuilding triton in place.
if not hasattr(_tcc, "triton_key"):

    def _triton_key_shim():
        """Return ``"triton-<version>-shim"`` for the installed triton."""
        import triton as _t

        return "triton-{}-shim".format(_t.__version__)

    _tcc.triton_key = _triton_key_shim
# Make torch._dynamo fall back to eager instead of raising when compilation
# fails. This is defense-in-depth against further triton API drift.
#
# Caveat: suppress_errors also HIDES compile failures. A codegen SyntaxError
# (caused by an old AttrsDescriptor shim) was once silently turned into an
# eager fallback this way, costing ~30% MFU. So an explicit user choice is
# honoured: if TORCHDYNAMO_SUPPRESS_ERRORS is set in the environment, its value
# decides. Values "", "0", "false", "no" and "off" (case-insensitive) disable
# suppression; anything else enables it. Only when the variable is unset do we
# force suppress_errors = True, which matches this file's earlier behavior.
import os as _os  # noqa: E402
import warnings as _warnings  # noqa: E402

_FALSY_ENV_VALUES = ("", "0", "false", "no", "off")

try:
    import torch._dynamo  # noqa: F401 — triggers dynamo module init

    _suppress_env = _os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS")
    if _suppress_env is None:
        # No explicit user choice: default to eager fallback.
        torch._dynamo.config.suppress_errors = True
    else:
        torch._dynamo.config.suppress_errors = (
            _suppress_env.strip().lower() not in _FALSY_ENV_VALUES
        )
except Exception as _exc:  # pragma: no cover
    # Best-effort configuration: never break `import mamba_ssm` over this,
    # but do not fail silently either.
    _warnings.warn(
        f"mamba_ssm: could not configure torch._dynamo suppress_errors "
        f"({_exc!r}); leaving torch defaults in place.",
        RuntimeWarning,
    )
# Expose Mamba3 at top level to match `from mamba_ssm import Mamba3`.
# Not wrapped in try/except: any error raised while importing
# mamba_ssm.modules.mamba3 propagates to whoever imports mamba_ssm.
from mamba_ssm.modules.mamba3 import Mamba3  # noqa: E402