Kda mixer #395
base: main
Conversation
jlamypoirier left a comment
Some comments, most also apply to GDA
# The image is still compatible with any user id.
RUN useradd user
USER user
Unnecessary diff
        super()._validate()

@config_class(dynamic_type={MixerConfig: "kda"})
"kimi_delta_attention"
| desc="Configuration for the gated normalization applied to the KDA output.", | ||
| hint=FieldHint.architecture, | ||
| ) | ||
| q_projection_layer: AffineLinearConfig = Field( |
`projection` seems unnecessary in these field names.
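Continuing the class sketched above, the renamed field might read as follows; only one field is shown and the description text is an assumption:

```python
# Class-body fragment continuing the sketch above: drop "projection" from the field
# name. The description text is an assumption; `AffineLinearConfig`, `Field`, and
# `FieldHint` are taken from the diff context rather than imported here.
q_layer: AffineLinearConfig = Field(
    desc="Configuration for the query projection layer.",
    hint=FieldHint.architecture,
)
```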
    )

    @property
    def layer_class(self) -> "type":
type["KimiDeltaAttention"]
        return KimiDeltaAttention

    def _validate(self) -> None:
        with self._set_implicit_default():
Not sure that's a good idea; it makes configs hard to understand. Better to require the user to specify these explicitly. (And most of the time we're creating from HF, so that's not a problem.)
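As a rough illustration of the explicit-configuration alternative (the field name and the standalone class are placeholders, not the KDA config's real fields):

```python
# Minimal sketch of requiring explicit values instead of filling them in under
# `_set_implicit_default()`. The field name and class are placeholders.
class KDAConfigSketch:
    head_size: int | None = None  # hypothetical field

    def _validate(self) -> None:
        # Fail fast rather than silently deriving a default during validation.
        if self.head_size is None:
            raise ValueError("`head_size` must be set explicitly for the KDA mixer.")
```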
@pytest.mark.slow
@pytest.mark.skipif(not torch.cuda.is_available(), reason="KDA equivalence test needs CUDA")
pytest.mark.requires_cuda
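The test suite presumably already defines this marker; for reference, a minimal conftest.py sketch that would provide it (an illustration, not the repository's actual implementation):

```python
# Minimal sketch of a `requires_cuda` marker; an illustration, not Fast-LLM's
# actual conftest implementation.
import pytest
import torch


def pytest_configure(config):
    config.addinivalue_line("markers", "requires_cuda: skip the test when CUDA is unavailable")


def pytest_collection_modifyitems(config, items):
    if torch.cuda.is_available():
        return
    skip_marker = pytest.mark.skip(reason="CUDA is not available")
    for item in items:
        if "requires_cuda" in item.keywords:
            item.add_marker(skip_marker)
```

With that in place, the decorators above collapse to `@pytest.mark.slow` and `@pytest.mark.requires_cuda`.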
AprielHybridSSMConfig, KimiDeltaAttention = None, None


def _materialize_mixer_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None:
Please use get_stage, it already does this. See the example here: https://github.com/ServiceNow/Fast-LLM/blob/main/tests/layers/test_lm_head.py#L264
Also, please don't copy utils into every file; they can go in utils.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="KDA equivalence test needs CUDA")
@pytest.mark.skipif(KimiDeltaAttention is None or AprielHybridSSMConfig is None, reason="Apriel KDA deps missing")
@pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available")
def test_fast_llm_kda_matches_apriel_forward():
Not sure we need this test at all; test_huggingface_model already tests the equivalence.
    ModelTestingGroup.convert: ModelTestingGroupAction.normal,
    ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented,
    ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
    ModelTestingGroup.distributed: ModelTestingGroupAction.normal,
We might want to test once and then leave it as unimportant; this has a huge impact on testing time.
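A sketch of that adjustment, assuming `ModelTestingGroupAction` exposes an `unimportant` member and that the enums live in `tests.utils.model_configs` (both unverified here):

```python
# Hypothetical adjustment following the review comment; the import path and the
# `unimportant` member are assumptions.
from tests.utils.model_configs import ModelTestingGroup, ModelTestingGroupAction  # assumed path

groups = {
    ModelTestingGroup.convert: ModelTestingGroupAction.normal,
    ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented,
    ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
    # Run the distributed test once, then downgrade it to keep CI time manageable.
    ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant,
}
```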
Update to nvcr.io/nvidia/pytorch:25.11-py3, which includes:
- PyTorch 2.10
- CUDA 13.0
- flash-attn 2.7.4.post1 (pre-installed, no compilation needed)

Dependency updates:
- causal-conv1d: v1.5.4 (was pinned to commit 2a288a1)
- mamba-ssm: 2.2.6.post3 (was pinned to commit 4a8a2a2)
- flash-linear-attention: pin to commit 67eee20 (was @main)
- flash-attn: 2.7.4.post1 to match the base image (was 2.7.3)
- triton: 3.5.1 in the Dockerfile (was 3.1.0)

These updates enable Kimi Delta Attention (KDA) support via the flash-linear-attention library. The pinned versions are tested and working, unlike the nightly/unpinned approach in #395.

Note: the dropless MoE kernel remains broken with triton >= 3.2.0 and needs a complete rewrite (it is also limited to 32 experts). This is tracked separately and doesn't block KDA work.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
✨ Description
Should be merged after GDN #392.
Adding the KDA mixer from Kimi Linear.
Note: for now this requires nightly triton and pytorch; see https://github.com/fla-org/flash-linear-attention/blob/main/FAQs.md.
Added a new Dockerfile for the KDA image. I uploaded it to registry.toolkit-sp.yul201.service-now.com/snow.research.afm/kda_image:kda_torch_nightly
🔍 Type of change
Select all that apply:
📝 Changes
✅ Checklist
Make sure the following tasks are completed before submitting the PR:
General
Dependencies and Configuration
Testing
Performance Impact
📊 Performance Impact Details
If there is any impact on performance, describe it and provide benchmark results, if applicable:
🗒️ Additional Notes
Include any additional context, information, or considerations here, such as known issues, follow-up tasks, or backward compatibility concerns.