Skip to content

Commit 9c8e371

Browse files
committed
d
1 parent 356ece2 commit 9c8e371

File tree

8 files changed

+2891
-0
lines changed

8 files changed

+2891
-0
lines changed
Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
# MLX Metal Kernel Optimization Integration
2+
3+
This package provides seamless integration of optimized Metal kernels with MLX-LM, delivering significant performance improvements for transformer attention computations on Apple Silicon.
4+
5+
## 🚀 Key Features
6+
7+
- **Intelligent Dispatch**: Automatically detects model architecture and applies appropriate optimizations
8+
- **Graceful Fallback**: Falls back to standard MLX operations when optimizations aren't beneficial
9+
- **Multiple Attention Patterns**: Supports GQA, MQA, and MHA with pattern-specific optimizations
10+
- **Easy Integration**: Simple monkey-patching for existing mlx-lm code
11+
- **Comprehensive Benchmarking**: Built-in tools for performance measurement and comparison
12+
- **Apple Silicon Optimized**: Leverages Metal Performance Shaders and unified memory architecture
13+
14+
## 📊 Performance Improvements
15+
16+
| Model Type | Architecture | Expected Speedup | Memory Reduction |
17+
|------------|--------------|------------------|------------------|
18+
| Qwen3 | 40:8 GQA | 1.5-2.0x | 10-15% |
19+
| Llama-3 | 32:8 GQA | 1.3-1.8x | 8-12% |
20+
| Gemma | 24:24 MHA | 1.2-1.5x | 5-10% |
21+
| Mistral | 32:8 GQA | 1.4-1.9x | 8-12% |
22+
23+
## 🛠 Installation
24+
25+
1. **Prerequisites**:
26+
```bash
27+
pip install mlx mlx-lm
28+
```
29+
30+
2. **Integration Setup**:
31+
```bash
32+
# Copy the integration folder to your project
33+
cp -r integration/ /path/to/your/project/
34+
```
35+
36+
## 🔧 Quick Start
37+
38+
### Basic Usage
39+
40+
```python
41+
from integration import patch_mlx_lm, unpatch_mlx_lm
42+
from mlx_lm import load, generate
43+
44+
# Apply optimizations
45+
patch_mlx_lm(enable_debug=True)
46+
47+
# Use mlx-lm normally - optimizations applied automatically
48+
model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")
49+
response = generate(model, tokenizer, prompt="Hello!", max_tokens=100)
50+
51+
# Remove optimizations when done
52+
unpatch_mlx_lm()
53+
```
54+
55+
### Context Manager Pattern
56+
57+
```python
58+
from integration.mlx_lm_integration import MLXLMIntegration
59+
60+
class OptimizedMLX:
61+
def __enter__(self):
62+
self.patched_count = patch_mlx_lm(enable_debug=False)
63+
return self
64+
65+
def __exit__(self, exc_type, exc_val, exc_tb):
66+
unpatch_mlx_lm(enable_debug=False)
67+
68+
# Optimizations applied only within this block
69+
with OptimizedMLX():
70+
model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")
71+
response = generate(model, tokenizer, prompt="Hello!", max_tokens=100)
72+
# Optimizations automatically removed
73+
```
74+
75+
### Custom Configuration
76+
77+
```python
78+
from integration import configure_optimizer, patch_mlx_lm
79+
80+
# Configure optimization thresholds
81+
configure_optimizer(
82+
enable_debug=True,
83+
min_seq_len=128, # Lower threshold for short sequences
84+
max_seq_len=4096, # Higher limit for long sequences
85+
gqa_ratio_min=3, # Require at least 3:1 GQA ratio
86+
min_heads=16 # Require at least 16 heads
87+
)
88+
89+
# Apply with custom configuration
90+
patch_mlx_lm()
91+
```
92+
93+
## 🧪 Testing and Benchmarking
94+
95+
### Quick Demo
96+
97+
```bash
98+
python integration/demo_integration.py --quick-test
99+
```
100+
101+
### Interactive Demo
102+
103+
```bash
104+
python integration/demo_integration.py --interactive --model qwen2.5-0.5b
105+
```
106+
107+
### Comprehensive Benchmark
108+
109+
```bash
110+
python integration/demo_integration.py --comprehensive
111+
```
112+
113+
### Usage Examples
114+
115+
```bash
116+
python integration/usage_examples.py
117+
```
118+
119+
## 📈 Monitoring Performance
120+
121+
### Check Optimization Status
122+
123+
```python
124+
from integration import get_integration_status
125+
126+
status = get_integration_status()
127+
print(f"Patched: {status['is_patched']}")
128+
print(f"Optimization rate: {status['optimizer_stats']['optimization_rate']:.1%}")
129+
```
130+
131+
### Benchmark Specific Models
132+
133+
```python
134+
from integration import benchmark_optimization
135+
136+
results = benchmark_optimization(
137+
model_name="qwen3",
138+
seq_lengths=[256, 512, 1024, 2048],
139+
warmup_runs=3,
140+
benchmark_runs=10,
141+
save_results=True
142+
)
143+
144+
for result in results:
145+
print(f"Seq {result.seq_length}: {result.speedup:.2f}x speedup")
146+
```
147+
148+
## 🎯 Supported Models
149+
150+
| Model Family | Pattern | Priority | Status |
151+
|--------------|---------|----------|--------|
152+
| Qwen3 | GQA 5:1 | High | ✅ Optimized |
153+
| Qwen2 | GQA 4:1 | High | ✅ Optimized |
154+
| Llama-3 | GQA 4:1 | High | ✅ Optimized |
155+
| Mistral | GQA 4:1 | High | ✅ Optimized |
156+
| Gemma | MHA 1:1 | Medium | ✅ Optimized |
157+
| Phi-3 | GQA 4:1 | Medium | ✅ Optimized |
158+
| DeepSeek-V3 | GQA | High | ✅ Optimized |
159+
160+
## ⚙️ How It Works
161+
162+
### 1. Attention Pattern Detection
163+
164+
The optimizer automatically detects attention patterns:
165+
166+
```python
167+
config = AttentionConfig(
168+
num_heads=40,
169+
num_kv_heads=8,
170+
head_dim=128,
171+
seq_len=1024,
172+
batch_size=1
173+
)
174+
175+
# Automatically detects: GQA-5:1 pattern
176+
print(config.attention_pattern) # "GQA-5:1"
177+
```
178+
179+
### 2. Intelligent Dispatch
180+
181+
Based on the detected pattern and thresholds:
182+
183+
```python
184+
should_optimize, reason = optimizer.should_optimize(config)
185+
if should_optimize:
186+
# Apply optimized Metal kernel
187+
result = optimized_attention(queries, keys, values, scale, mask)
188+
else:
189+
# Fall back to standard MLX implementation
190+
result = mx.fast.scaled_dot_product_attention(queries, keys, values, scale, mask)
191+
```
192+
193+
### 3. Metal Kernel Optimization
194+
195+
The Metal kernels include:
196+
197+
- **Memory Coalescing**: Optimized memory access patterns for Apple Silicon
198+
- **SIMD Vectorization**: 4-way and 8-way vectorized operations
199+
- **Online Softmax**: Memory-efficient attention computation
200+
- **Pattern-Specific Logic**: GQA head mapping, MQA single-head optimization
201+
202+
## 🔍 Technical Details
203+
204+
### Optimization Thresholds
205+
206+
| Parameter | Default | Description |
207+
|-----------|---------|-------------|
208+
| `min_seq_len` | 64 | Minimum sequence length for optimization |
209+
| `max_seq_len` | 4096 | Maximum supported sequence length |
210+
| `min_head_dim` | 64 | Minimum head dimension for vectorization |
211+
| `max_head_dim` | 256 | Maximum supported head dimension |
212+
| `min_heads` | 8 | Minimum number of heads for optimization |
213+
| `gqa_ratio_min` | 2 | Minimum GQA ratio to trigger optimization |
214+
215+
### Metal Kernel Features
216+
217+
1. **GQA Optimization**:
218+
- Efficient head mapping for grouped queries
219+
- Optimized memory layout for KV head sharing
220+
- Vectorized computation with loop unrolling
221+
222+
2. **MQA Optimization**:
223+
- Single KV head specialized kernel
224+
- Reduced memory bandwidth requirements
225+
- Optimized for single-head broadcasting
226+
227+
3. **MHA Optimization**:
228+
- Standard multi-head attention with vectorization
229+
- Memory-efficient implementation
230+
- SIMD optimizations for large head counts
231+
232+
## 🐛 Troubleshooting
233+
234+
### Common Issues
235+
236+
1. **No Optimization Applied**:
237+
```python
238+
# Check if model meets thresholds
239+
status = get_integration_status()
240+
print(status['optimizer_stats'])
241+
```
242+
243+
2. **Fallback to Standard Implementation**:
244+
```python
245+
# Enable debug to see fallback reasons
246+
patch_mlx_lm(enable_debug=True)
247+
```
248+
249+
3. **Memory Issues**:
250+
```python
251+
# Lower sequence length threshold
252+
configure_optimizer(max_seq_len=2048)
253+
```
254+
255+
### Debug Mode
256+
257+
Enable debug output to see optimization decisions:
258+
259+
```python
260+
patch_mlx_lm(enable_debug=True)
261+
# Output will show:
262+
# ✅ Patched qwen3 attention
263+
# ⚡ Applying GQA-5:1 optimization: GQA pattern with 5:1 ratio benefits from custom kernel
264+
# 🔄 Falling back to MLX SDPA: Sequence length 32 below threshold 64
265+
```
266+
267+
## 📋 API Reference
268+
269+
### Main Functions
270+
271+
- `patch_mlx_lm(enable_debug=False, **kwargs)` - Apply optimizations
272+
- `unpatch_mlx_lm(enable_debug=False)` - Remove optimizations
273+
- `get_integration_status()` - Get current status and stats
274+
- `configure_optimizer(**kwargs)` - Configure optimization parameters
275+
- `benchmark_optimization(...)` - Run performance benchmarks
276+
277+
### Classes
278+
279+
- `MetalKernelOptimizer` - Core optimization engine
280+
- `AttentionConfig` - Attention pattern configuration
281+
- `MLXLMIntegration` - Integration management
282+
- `BenchmarkResult` - Benchmark result container
283+
284+
## 🤝 Contributing
285+
286+
1. Test on different model architectures
287+
2. Optimize for specific sequence length ranges
288+
3. Add support for new attention patterns
289+
4. Improve Metal kernel performance
290+
5. Add more comprehensive benchmarks
291+
292+
## 📜 License
293+
294+
This project is part of the OpenEvolve framework and follows the same licensing terms.
295+
296+
## 🙏 Acknowledgments
297+
298+
- Built on the AlphaEvolve framework for automated optimization discovery
299+
- Inspired by the Metal kernel optimizations described in the AlphaEvolve paper
300+
- Uses MLX and MLX-LM as the foundation for Apple Silicon machine learning
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
"""
2+
MLX Metal Kernel Optimization Integration
3+
4+
This package provides seamless integration of optimized Metal kernels with mlx-lm,
5+
offering significant performance improvements for transformer attention computations
6+
on Apple Silicon.
7+
8+
Key Features:
9+
- Automatic dispatch based on model architecture and configuration
10+
- Graceful fallback to standard MLX operations when optimizations aren't beneficial
11+
- Support for GQA, MQA, and MHA attention patterns
12+
- Easy monkey-patching for existing mlx-lm code
13+
- Comprehensive benchmarking and profiling tools
14+
15+
Quick Start:
16+
from integration import patch_mlx_lm, unpatch_mlx_lm
17+
18+
# Apply optimizations
19+
patch_mlx_lm(enable_debug=True)
20+
21+
# Use mlx-lm normally
22+
from mlx_lm import load, generate
23+
model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")
24+
response = generate(model, tokenizer, prompt="Hello", max_tokens=100)
25+
26+
# Remove optimizations when done
27+
unpatch_mlx_lm()
28+
29+
Supported Models:
30+
- Qwen3 (40:8 GQA) - High priority optimization
31+
- Qwen2 (32:8 GQA) - High priority optimization
32+
- Llama (32:8 GQA) - High priority optimization
33+
- Mistral3 (32:8 GQA) - High priority optimization
34+
- Gemma (24:24 MHA) - Medium priority optimization
35+
- Phi3 (32:8 GQA) - Medium priority optimization
36+
- DeepSeek-V3 (GQA) - High priority optimization
37+
"""
38+
39+
from .metal_kernel_optimizer import (
40+
MetalKernelOptimizer,
41+
AttentionConfig,
42+
optimized_scaled_dot_product_attention,
43+
configure_optimizer,
44+
get_optimizer_stats,
45+
reset_optimizer_stats
46+
)
47+
48+
from .mlx_lm_integration import (
49+
MLXLMIntegration,
50+
patch_mlx_lm,
51+
unpatch_mlx_lm,
52+
get_integration_status,
53+
is_mlx_lm_patched,
54+
benchmark_optimization,
55+
quick_benchmark,
56+
BenchmarkResult
57+
)
58+
59+
__version__ = "1.0.0"
60+
__author__ = "OpenEvolve Team"
61+
__description__ = "Metal kernel optimizations for MLX-LM attention computations"
62+
63+
__all__ = [
64+
# Core optimizer
65+
'MetalKernelOptimizer',
66+
'AttentionConfig',
67+
'optimized_scaled_dot_product_attention',
68+
'configure_optimizer',
69+
'get_optimizer_stats',
70+
'reset_optimizer_stats',
71+
72+
# Integration
73+
'MLXLMIntegration',
74+
'patch_mlx_lm',
75+
'unpatch_mlx_lm',
76+
'get_integration_status',
77+
'is_mlx_lm_patched',
78+
79+
# Benchmarking
80+
'benchmark_optimization',
81+
'quick_benchmark',
82+
'BenchmarkResult'
83+
]

0 commit comments

Comments
 (0)