Spaces:
Runtime error
Runtime error
| # mamba_ssm package init — minimal override to avoid broken selective_scan_cuda.so | |
| # ABI mismatch with the base image's libtorch. | |
| # | |
| # The upstream __init__.py eagerly imports selective_scan_cuda which fails on | |
| # pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel (undefined c10::Warning ctor | |
| # symbol). We only need Mamba3 (grafted from main, pure-Triton), so we skip | |
| # all compiled-CUDA imports here and let Mamba3 load directly. | |
__version__ = "2.3.1+feather-graft"

# Placeholders for the compiled-CUDA entry points. They are NOT used by the
# Feather training path (which is Mamba3-only). Binding them to None means any
# code that tries to use them fails right at the point of use (TypeError when
# called, AttributeError on attribute access) rather than with an obscure
# ImportError while this package is being imported.
selective_scan_fn = mamba_inner_fn = None
| # --- triton API compatibility shims ----------------------------------------- | |
| # Version matrix is hostile: torch 2.6 pins triton==3.2.0 because torch._inductor | |
| # imports AttrsDescriptor from triton.compiler.compiler — removed in triton 3.4+. | |
| # Grafted Mamba3 (from mamba-ssm main) needs triton.set_allocator and | |
| # tl.make_tensor_descriptor, both added in triton 3.3+. No single triton version | |
| # satisfies both simultaneously. We run on triton 3.5.1 (latest, has both mamba3 | |
| # APIs). An earlier revision of this file also shimmed AttrsDescriptor as a stub | |
| # dataclass for torch._inductor, on the assumption that the codebase does not use | |
| # torch.compile and that importing torch._inductor.* only needs the symbol to | |
| # exist at module load time. That shim has since been removed — see the NOTE | |
| # above the AttrsDescriptor guard further down for the reasons. | |
| import triton as _triton # noqa: E402 | |
# triton builds that predate ``set_allocator`` get a do-nothing stand-in so
# that callers can invoke ``triton.set_allocator(...)`` unconditionally.
if not hasattr(_triton, "set_allocator"):

    def _noop_set_allocator(_fn):  # pragma: no cover
        """Accept an allocator callback, ignore it, and return None."""
        return None

    _triton.set_allocator = _noop_set_allocator
| import triton.compiler.compiler as _tcc # noqa: E402 | |
| # NOTE: on triton >= 3.4, AttrsDescriptor is intentionally ABSENT. An earlier | |
| # version of this file installed a `_AttrsDescriptorShim` here so that | |
| # `torch._inductor.runtime.hints` could `import AttrsDescriptor` without | |
| # failing, but that shim turned out to be actively harmful once | |
| # `HYDRA_MUON_COMPILE=1` started exercising the torch.compile codegen path: | |
| # | |
| # 1. torch/_inductor/runtime/hints.py branches on | |
| # `hasattr(triton.compiler.compiler, "AttrsDescriptor")`. With the shim | |
| # present it took the triton-3.0.0 path and constructed shim instances | |
| # into `triton_meta["configs"]`. | |
| # 2. torch/_inductor/codegen/triton.py then wrote `triton_meta={...!r}` | |
| # into the generated Triton-Python source. The shim had no __repr__, | |
| # so the emitted text was `<mamba_ssm._AttrsDescriptorShim object at 0x…>`, | |
| # which is a SyntaxError at line 15 of the generated file. | |
| # 3. torch/_dynamo/config.suppress_errors = True turned that SyntaxError | |
| # into a silent fall-back-to-eager, costing ~30% MFU on the RTX 3060 | |
| # (43k tps eager vs 63k tps compiled per hydra/optimizer.py:66). | |
| # | |
| # Without the shim, torch._inductor takes the intended "triton >= 3.4" | |
| # branch (hints.py:82-92) which emits `{(x,): [["tt.divisibility", 16]] | |
| # for x in divisible_by_16}` — a plain dict whose repr is valid Python — | |
| # and `gen_attr_descriptor_import()` returns "" so no broken | |
| # `from triton.compiler.compiler import AttrsDescriptor` line is baked | |
| # into generated kernel files. | |
| # | |
| # Leaving the guard as a no-op `pass` keeps this block visible for future | |
| # triton-version audits but performs no monkey-patch. | |
# Deliberate no-op: when AttrsDescriptor is missing, nothing is patched in.
# The guard exists only to mark this spot for future triton-version audits.
if not hasattr(_tcc, "AttrsDescriptor"):
    pass
| # triton_key: removed in triton 3.5, used by torch._inductor.codecache for | |
| # FxGraphCache key derivation. Return a stable string so caching still works. | |
if not hasattr(_tcc, "triton_key"):

    def _triton_key_shim():
        """Return a string derived from the installed triton version.

        Installed as ``triton.compiler.compiler.triton_key`` when that
        attribute is missing. The value only changes when the triton
        version string changes.
        """
        # Imported here so the shim does not rely on module-level names.
        import triton as _t

        version = _t.__version__
        return f"triton-{version}-shim"

    _tcc.triton_key = _triton_key_shim
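| # Suppress torch.compile/_dynamo errors globally: fall back to eager on any | |
| # dynamo failure rather than crashing. This is defense-in-depth against further | |
| # triton API drift. It was originally justified by not relying on torch.compile | |
| # for performance in this codebase (Muon + mamba3 CUDA kernels already fused). | |
| # CAVEAT: with this flag set, a broken compile path degrades silently to eager; | |
| # the AttrsDescriptor NOTE above records one such silent slowdown. | |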
# Best-effort: make torch._dynamo fall back to eager on compile failures.
# Importing this package must never fail because of this step, so every
# exception is caught. The failure is logged rather than dropped silently,
# because without the flag a dynamo error will raise instead of falling back.
try:
    import torch._dynamo  # noqa: F401 — triggers dynamo module init

    torch._dynamo.config.suppress_errors = True
except Exception as _dynamo_exc:  # pragma: no cover
    import logging as _logging

    _logging.getLogger(__name__).warning(
        "could not enable torch._dynamo.config.suppress_errors: %r",
        _dynamo_exc,
    )
| # Expose Mamba3 at top level to match `from mamba_ssm import Mamba3`. | |
| from mamba_ssm.modules.mamba3 import Mamba3 # noqa: E402 | |