
Conversation


@oleksost (Contributor) commented Nov 26, 2025

✨ Description

Should be merged after GDN #392.

Adds the KDA mixer from Kimi Linear.

Note: for now this requires nightly Triton and PyTorch; see https://github.com/fla-org/flash-linear-attention/blob/main/FAQs.md.
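As a quick environment sanity check, something along these lines confirms the nightly builds are picked up (a minimal sketch, not part of the PR; the exact minimum versions the FLA kernels need are listed in the linked FAQ):

# Illustrative check only: the KDA kernels in flash-linear-attention currently
# expect a nightly PyTorch build and a recent Triton.
import torch
import triton

print("torch:", torch.__version__)    # expect a dev/nightly version string
print("triton:", triton.__version__)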

Added a new Dockerfile for the KDA image and uploaded it to registry.toolkit-sp.yul201.service-now.com/snow.research.afm/kda_image:kda_torch_nightly.

🔍 Type of change

Select all that apply:

  • 🐛 Bug fix (non-breaking change that addresses a specific issue)
  • 🚀 New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
  • 📝 Documentation change (updates documentation, including new content or typo fixes)
  • 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)

📝 Changes

  • Added kda.py to the SSM layers
  • Added KDA to the varlen test
  • Added hybrid_kda to the model configs used for testing
  • Added a Dockerfile for KDA

✅ Checklist

Make sure the following tasks are completed before submitting the PR:

General

  • 📜 I have read and followed the contributing guidelines.
  • 🏷️ I am using a clear and descriptive PR title that summarizes the key change or feature introduced.
  • 🎉 The functionality is complete, and I have tested the changes.
  • 📝 I have updated the documentation if needed.
  • ⚠️ The change does not introduce any new issues (e.g., runtime warnings, type checker errors, linting problems, unhandled edge cases).
  • 🧩 I have commented my code, especially in hard-to-understand areas.

Dependencies and Configuration

  • 🐋 I have updated the Docker configuration or dependencies, if applicable.
  • 🔄 I have ensured compatibility with the existing setup after dependency changes.

Testing

  • 🧪 I have added or updated tests to cover my changes.
  • ✔️ New and existing tests pass locally with my changes.
  • 🚦 I have tested these changes on GPUs and verified training stability.
  • 🏋️ I have tested the changes on realistic training workloads, if applicable.

Performance Impact

  • 📊 I have run benchmarks where applicable to evaluate the performance impact.
  • ✅ The benchmarks show no performance regression.
  • 🚀 The benchmarks indicate a potential performance improvement.
  • ⚠️ The benchmarks indicate a potential performance degradation.
  • 📈 I have provided benchmark results and detailed any performance impact below, if applicable.

📊 Performance Impact Details

If there is any impact on performance, describe it and provide benchmark results, if applicable:


🗒️ Additional Notes

Include any additional context, information, or considerations here, such as known issues, follow-up tasks, or backward compatibility concerns.

@oleksost marked this pull request as ready for review November 26, 2025 20:40

@jlamypoirier (Collaborator) left a comment


Some comments, most also apply to GDN.

# The image is still compatible with any user id.
RUN useradd user
USER user
USER user

Unnecessary diff

super()._validate()


@config_class(dynamic_type={MixerConfig: "kda"})

"kimi_delta_attention"

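I.e., roughly the following registration (the config class name below is illustrative, not necessarily the one in the PR):

@config_class(dynamic_type={MixerConfig: "kimi_delta_attention"})
class KimiDeltaAttentionConfig(MixerConfig):  # illustrative class name
    ...
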
        desc="Configuration for the gated normalization applied to the KDA output.",
        hint=FieldHint.architecture,
    )
    q_projection_layer: AffineLinearConfig = Field(

projection seems unnecessary in these fields.
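For example, roughly (only the rename is the point; the Field pattern is taken from the quoted hunk, and the description text here is illustrative):

    q_layer: AffineLinearConfig = Field(  # instead of q_projection_layer
        desc="Configuration for the query projection.",
        hint=FieldHint.architecture,
    )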

    )

    @property
    def layer_class(self) -> "type":

type["KimiDeltaAttention"]
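I.e., a small sketch of the suggested annotation:

    @property
    def layer_class(self) -> type["KimiDeltaAttention"]:
        return KimiDeltaAttention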

        return KimiDeltaAttention

    def _validate(self) -> None:
        with self._set_implicit_default():

Not sure that's a good idea; it makes configs hard to understand. Better to assume the user specifies these explicitly (and most of the time we're creating from HF, so that's not a problem).
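For instance, the explicit alternative looks roughly like this (the field name is illustrative; the point is that nothing is derived inside _validate(), so the value has to come from the user or from the HF conversion):

    head_size: int = Field(  # illustrative field, no implicit default
        desc="KDA head size; must be set explicitly.",
        hint=FieldHint.architecture,
    )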



@pytest.mark.slow
@pytest.mark.skipif(not torch.cuda.is_available(), reason="KDA equivalence test needs CUDA")

pytest.mark.requires_cuda
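I.e., roughly (this assumes the repo-wide requires_cuda pytest marker; the test name is illustrative):

@pytest.mark.slow
@pytest.mark.requires_cuda
def test_kda_equivalence():  # illustrative name
    ...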

AprielHybridSSMConfig, KimiDeltaAttention = None, None


def _materialize_mixer_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None:

Please use get_stage, it already does this. See example here https://github.com/ServiceNow/Fast-LLM/blob/main/tests/layers/test_lm_head.py#L264

Also, please don't copy utils into every file; they can go in utils.
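A rough sketch of that direction (get_stage comes from the comment above, but the exact signature here is an assumption; see the linked test_lm_head.py for the real usage):

# Build the mixer through a stage so its tensors are materialized on the target
# device, instead of a hand-rolled _materialize_mixer_tensors helper.
stage = get_stage([mixer], distributed)  # signature is an assumption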

@pytest.mark.skipif(not torch.cuda.is_available(), reason="KDA equivalence test needs CUDA")
@pytest.mark.skipif(KimiDeltaAttention is None or AprielHybridSSMConfig is None, reason="Apriel KDA deps missing")
@pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available")
def test_fast_llm_kda_matches_apriel_forward():

Not sure we need this test at all. test_huggingface_model already tests the equivalence

ModelTestingGroup.convert: ModelTestingGroupAction.normal,
ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented,
ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
ModelTestingGroup.distributed: ModelTestingGroupAction.normal,

We might want to test once and then leave this as unimportant; it has a huge impact on testing time.
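I.e., something like the following (assuming unimportant is a valid ModelTestingGroupAction, as the comment suggests):

ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant,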

tscholak added a commit that referenced this pull request Dec 7, 2025
Update to nvcr.io/nvidia/pytorch:25.11-py3 which includes:
- PyTorch 2.10
- CUDA 13.0
- flash-attn 2.7.4.post1 (pre-installed, no compilation needed)

Dependency updates:
- causal-conv1d: v1.5.4 (was pinned to commit 2a288a1)
- mamba-ssm: 2.2.6.post3 (was pinned to commit 4a8a2a2)
- flash-linear-attention: pin to commit 67eee20 (was @main)
- flash-attn: 2.7.4.post1 to match base image (was 2.7.3)
- triton: 3.5.1 in Dockerfile (was 3.1.0)

These updates enable Kimi Delta Attention (KDA) support via the
flash-linear-attention library. The pinned versions are tested and
working, unlike the nightly/unpinned approach in #395.

Note: Dropless MoE kernel remains broken with triton >= 3.2.0 and
needs a complete rewrite (also limited to 32 experts). This is
tracked separately and doesn't block KDA work.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
