
Conversation

Collaborator

@tscholak tscholak commented Dec 7, 2025

This PR updates the base image to nvcr.io/nvidia/pytorch:25.11-py3 and bumps dependencies to enable Kimi Delta Attention (KDA) support.

Base image changes:

  • PyTorch: 2.10
  • CUDA: 13.0
  • flash-attn: 2.7.4.post1 (pre-installed, no compilation needed)
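For a quick sanity check of what the new base image actually ships, something along these lines works (a minimal sketch; it assumes Docker is available locally and the NGC tag is reachable):

```bash
# Pull the 25.11 base image and print the versions it ships with.
docker pull nvcr.io/nvidia/pytorch:25.11-py3
docker run --rm nvcr.io/nvidia/pytorch:25.11-py3 python -c \
  "import torch, flash_attn; print('torch', torch.__version__, 'cuda', torch.version.cuda, 'flash-attn', flash_attn.__version__)"
```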

Dependency updates:

| Package | Old | New |
| --- | --- | --- |
| causal-conv1d | commit 2a288a1 | v1.5.4 |
| mamba-ssm | commit 4a8a2a2 | 2.2.6.post3 |
| flash-linear-attention | @main | commit 67eee20 |
| flash-attn | 2.7.3 | 2.7.4.post1 |
| triton (Dockerfile) | 3.1.0 | 3.5.1 |
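The Dockerfile diff further down shows the exact RUN commands for causal-conv1d and mamba-ssm; outside the image build, the same pins can be reproduced roughly as follows (a hedged sketch: the flash-linear-attention repository URL is an assumption, and flash-attn needs no install step because the base image already provides 2.7.4.post1):

```bash
# Sketch of the pinned dependency set from the table above.
pip install --no-build-isolation "causal-conv1d @ git+https://github.com/Dao-AILab/causal-conv1d@v1.5.4"
pip install --no-build-isolation mamba-ssm==2.2.6.post3
# Assumed source URL for the pinned flash-linear-attention commit:
pip install "flash-linear-attention @ git+https://github.com/fla-org/flash-linear-attention@67eee20"
pip install triton==3.5.1
```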

This is an alternative approach to #395 for enabling KDA support. Instead of using PyTorch/triton nightly builds, we use the official NVIDIA PyTorch image which:

  • Avoids recompiling flash-attn - the 25.11 image ships with flash-attn 2.7.4.post1, saving significant build time
  • Provides broader architecture support out of the box
  • Uses pinned versions rather than nightly builds

Thanks to @oleksost for the groundwork in #395 exploring the nightly path.

All KDA tests from #395 pass:

  • tests/layers/test_kda_equivalence.py::test_fast_llm_kda_matches_apriel_forward
  • tests/test_varlen.py::test_mixer_varlen_stacking_equivalence[config2-False] (KDA, sequence_first=False)
  • tests/test_varlen.py::test_mixer_varlen_stacking_equivalence[config3-True] (KDA, sequence_first=True)
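To re-run just these tests against the updated image, a plain pytest invocation with the node IDs above is enough (assuming a checkout of the repository inside the container):

```bash
# Run only the KDA equivalence and varlen stacking tests listed above.
pytest -v \
  "tests/layers/test_kda_equivalence.py::test_fast_llm_kda_matches_apriel_forward" \
  "tests/test_varlen.py::test_mixer_varlen_stacking_equivalence[config2-False]" \
  "tests/test_varlen.py::test_mixer_varlen_stacking_equivalence[config3-True]"
```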

Known issues:

Dropless MoE kernel remains broken with triton >= 3.2.0 and needs a complete rewrite (also limited to 32 experts). This is tracked separately and doesn't block KDA work.

Update to nvcr.io/nvidia/pytorch:25.11-py3 which includes:
- PyTorch 2.10
- CUDA 13.0
- flash-attn 2.7.4.post1 (pre-installed, no compilation needed)

Dependency updates:
- causal-conv1d: v1.5.4 (was pinned to commit 2a288a1)
- mamba-ssm: 2.2.6.post3 (was pinned to commit 4a8a2a2)
- flash-linear-attention: pin to commit 67eee20 (was @main)
- flash-attn: 2.7.4.post1 to match base image (was 2.7.3)
- triton: 3.5.1 in Dockerfile (was 3.1.0)

These updates enable Kimi Delta Attention (KDA) support via the
flash-linear-attention library. The pinned versions are tested and
working, unlike the nightly/unpinned approach in #395.

Note: Dropless MoE kernel remains broken with triton >= 3.2.0 and
needs a complete rewrite (also limited to 32 experts). This is
tracked separately and doesn't block KDA work.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Copilot AI left a comment


Pull request overview

This PR updates the base Docker image to NVIDIA PyTorch 25.11 and bumps several dependencies to enable Kimi Delta Attention (KDA) support. The changes leverage the pre-installed flash-attn 2.7.4.post1 in the new base image to avoid lengthy compilation, while updating mamba-ssm, causal-conv1d, flash-linear-attention, and triton to their latest compatible versions.

Key Changes:

  • Base image upgrade from nvcr.io/nvidia/pytorch:25.05-py3 to 25.11-py3 (PyTorch 2.10, CUDA 13.0)
  • Dependency version bumps: flash-attn (2.7.3 → 2.7.4.post1), mamba-ssm (2.2.4 → 2.2.6.post3), triton (3.1.0 → 3.5.1)
  • Switch from git commits to versioned releases for causal-conv1d and mamba-ssm

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| Dockerfile | Updates base image to 25.11-py3, switches mamba-ssm to PyPI release, updates causal-conv1d to v1.5.4, adds flash-linear-attention installation, and bumps triton to 3.5.1 |
| setup.cfg | Bumps flash-attn to 2.7.4.post1, mamba_ssm to 2.2.6.post3, pins flash-linear-attention to a specific commit, and clarifies the comment about Dropless MoE kernel limitations |


# The base image enforces versions for things like pytest for no good reason.
ENV PIP_CONSTRAINT=""
# There is no pre-build mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds.
# We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d)

Copilot AI Dec 7, 2025


Partially outdated comment: This comment states "We need to compile from the repo because of state-spaces/mamba#720 (same for causal-conv1d)". However, line 33 now installs mamba-ssm from PyPI (not from git repo), so the comment is only accurate for causal-conv1d (line 32). Consider updating the comment to clarify which packages still require compilation from git and which are now installed from PyPI.

Suggested change
# We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d)
# We need to compile causal-conv1d from the repo because of https://github.com/state-spaces/mamba/issues/720.
# mamba-ssm is now installed from PyPI.

RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1"
RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2"
RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d @ git+https://github.com/Dao-AILab/causal-conv1d@v1.5.4"
RUN MAX_JOBS=2 pip install --no-build-isolation mamba-ssm==2.2.6.post3

Copilot AI Dec 7, 2025


Package name inconsistency: This line uses mamba-ssm (with hyphen) while setup.cfg line 55 uses mamba_ssm (with underscore). For consistency across the codebase, both should use the same format. Recommend using mamba_ssm to match setup.cfg and the import statements used throughout the codebase.

Suggested change
RUN MAX_JOBS=2 pip install --no-build-isolation mamba-ssm==2.2.6.post3
RUN MAX_JOBS=2 pip install --no-build-isolation mamba_ssm==2.2.6.post3


# The base image enforces versions for things like pytest for no good reason.
ENV PIP_CONSTRAINT=""
# There is no pre-build mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds.

Copilot AI Dec 7, 2025


Outdated comment: This comment references "pytorch 2.8" but the PR description indicates the new base image (25.11) includes PyTorch 2.10. The comment should be updated to reflect the current PyTorch version to avoid confusion.

Suggested change
# There is no pre-build mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds.
# There is no pre-build mamba image for pytorch 2.10, we build it before the rest to avoid rebuilds.

Contributor

@oleksost oleksost left a comment


It's great that this works! Having to build flash attention from source would have been a huge pain.
