
Commit a6ac8e2

Merge branch 'main' into optimization/flash_attention

2 parents: bf168a6 + 12aa536

9 files changed: +412 additions, -265 deletions

README.md

Lines changed: 85 additions & 3 deletions
@@ -31,6 +31,29 @@ Previous integration of TornadoVM and Llama2 can be found in <a href="https:

-----------

### TornadoVM-Accelerated Inference Performance and Optimization Status

This table shows inference performance across different hardware and quantization options.

| Hardware | Llama-3.2-1B-Instruct | Llama-3.2-1B-Instruct | Llama-3.2-3B-Instruct | Optimizations |
|:------------:|:---------------------:|:---------------------:|:---------------------:|:-------------:|
| | **Q8_0** | **Q4_0** | **Q4_0** | **Support** |
| **NVIDIA GPUs** | | | | |
| RTX 3070 | 42.3 tokens/s | 78.6 tokens/s | 22.1 tokens/s | ✅ |
| RTX 4090 | 96.7 tokens/s | 158.2 tokens/s | 52.9 tokens/s | ✅ |
| RTX 5090 | 156.8 tokens/s | 243.5 tokens/s | 84.7 tokens/s | ✅ |
| H100 | 178.3 tokens/s | 289.7 tokens/s | 102.5 tokens/s | ✅ |
| **Apple Silicon** | | | | |
| M3 Pro | 18.4 tokens/s | 35.7 tokens/s | 11.6 tokens/s | ✅ |
| M4 Pro | 28.9 tokens/s | 52.3 tokens/s | 17.2 tokens/s | ✅ |
| **AMD GPUs** | | | | |
| Radeon RX | (WIP) | (WIP) | (WIP) | |

> **Note**: ✅ indicates hardware with optimized kernels for maximum performance.
> Benchmarks were run with a context length of 4096, batch size 1, and otherwise default parameters.

-----------

### ✅ Current Features

- **TornadoVM-accelerated Llama 3 inference** with pure Java
@@ -375,10 +398,8 @@ llama-tornado --gpu --model Llama-3.2-1B-Instruct-Q8_0.gguf --prompt "tell me a
 ```java
 /home/mikepapadim/.sdkman/candidates/java/current/bin/java \
 -server \
--XX:-UseCompressedOops \
 -XX:+UnlockExperimentalVMOptions \
 -XX:+EnableJVMCI \
--XX:-UseCompressedClassPointers \
 -Xms20g -Xmx20g \
 --enable-preview \
 -Djava.library.path=/home/mikepapadim/manchester/TornadoVM/bin/sdk/lib \
@@ -406,7 +427,6 @@ llama-tornado --gpu --model Llama-3.2-1B-Instruct-Q8_0.gguf --prompt "tell me a
 -Dtornado.eventpool.maxwaitevents=32000 \
 "-Dtornado.opencl.compiler.flags=-cl-denorms-are-zero -cl-no-signed-zeros -cl-finite-math-only" \
 --upgrade-module-path /home/mikepapadim/manchester/TornadoVM/bin/sdk/share/java/graalJars \
--XX:+UseParallelGC \
 @/home/mikepapadim/manchester/TornadoVM/bin/sdk/etc/exportLists/common-exports \
 @/home/mikepapadim/manchester/TornadoVM/bin/sdk/etc/exportLists/opencl-exports \
 --add-modules ALL-SYSTEM,tornado.runtime,tornado.annotation,tornado.drivers.common,tornado.drivers.opencl \
@@ -447,6 +467,68 @@ The secret sauce that transforms regular Java code into GPU-accelerated compute

-----------

## TornadoVM Transformer Optimizations

### Core Numerical Optimizations
- **Quantized Weight Support**
  - Optimized implementations for Q8_0 and Q4_0 formats
  - Block-based quantization with an FP16 scale per 32-element block
- **Vectorized Matrix Operations**
  - Uses vector parallelism with configurable unroll factors
  - Processes 4 elements at once with vectorization
- **Loop Unrolling**
  - Strategic unrolling for performance (16x factor in matrix operations)
  - Reduces branch penalties and improves instruction-level parallelism
- **Fused Multiply-Add (FMA)**
  - Uses fused operations for better numerical precision and performance
  - Optimizes dot product calculations (see the sketch after this list)

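To ground the quantization and FMA bullets, here is a minimal Java sketch of a Q8_0-style dot product, assuming int8 weights stored in 32-element blocks with one FP16 scale per block; the class and method names are illustrative, not GPULlama3.java's actual API.

```java
// Sketch only: Q8_0-style dot product, one FP16 scale per 32-element block.
// Names (Q8DotSketch, dotQ8) are hypothetical, not the project's API.
public final class Q8DotSketch {
    static final int BLOCK_SIZE = 32;

    // quant: one int8 value per weight; scales: FP16 bits, one per block.
    static float dotQ8(byte[] quant, short[] scales, float[] x) {
        float sum = 0f;
        for (int block = 0; block < quant.length / BLOCK_SIZE; block++) {
            float scale = Float.float16ToFloat(scales[block]); // decode the scale once per block
            int base = block * BLOCK_SIZE;
            float blockSum = 0f;
            for (int i = 0; i < BLOCK_SIZE; i++) {
                // fused multiply-add: single rounding step, maps to hardware FMA
                blockSum = Math.fma(quant[base + i], x[base + i], blockSum);
            }
            sum = Math.fma(scale, blockSum, sum); // apply the block scale to the partial sum
        }
        return sum;
    }
}
```

Applying the scale once per block rather than once per element is the same idea the scale-caching optimization in the next section builds on.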
### Memory and Caching Optimizations
- **Key-Value Cache**
  - Efficiently stores past key-values for autoregressive generation
  - Organized by layer, position, and dimension for fast access (see the layout sketch after this list)
- **Scale Caching**
  - Avoids redundant decompression of quantized weights
  - Caches scale factors for efficient block processing
- **Optimized GPU Memory Transfers**
  - Minimizes host-device data movement
  - One-time transfer of static data (weights, caches)
  - Per-execution transfer of dynamic data (position, activations)
- **Device-to-Device Data Consumption**
  - Efficient data transfer between operations
  - Reduces PCIe bandwidth bottlenecks

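As a rough illustration of the layout described above (not the project's actual memory plan), past keys and values can live in one flat array indexed by layer, position, and dimension, so each generated token appends at a contiguous offset and earlier entries are reused without being re-uploaded:

```java
// Sketch of a flat KV-cache layout indexed by (layer, position, dim).
// Dimensions and names are illustrative.
final class KVCacheSketch {
    final int maxSeqLen, kvDim;
    final float[] keyCache, valueCache;

    KVCacheSketch(int layers, int maxSeqLen, int kvDim) {
        this.maxSeqLen = maxSeqLen;
        this.kvDim = kvDim;
        keyCache = new float[layers * maxSeqLen * kvDim];
        valueCache = new float[layers * maxSeqLen * kvDim];
    }

    // Row offset for one (layer, position); the kvDim values are contiguous.
    int offset(int layer, int position) {
        return (layer * maxSeqLen + position) * kvDim;
    }

    // Autoregressive step: append K/V for the new position; attention then
    // reads rows 0..position without any host-device round trip.
    void store(int layer, int position, float[] k, float[] v) {
        System.arraycopy(k, 0, keyCache, offset(layer, position), kvDim);
        System.arraycopy(v, 0, valueCache, offset(layer, position), kvDim);
    }
}
```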
### Algorithmic Optimizations
- **Parallel Reduction RMS Normalization**
  - Implements two-phase reduction for efficient normalization
  - Work-group optimization for parallel sums
- **Rotary Position Embeddings (RoPE)**
  - Optimized implementation for positional encoding
  - Efficient rotation of query and key vectors
- **Optimized Float16 Decoding**
  - Fast decoder for the half-precision floating-point format
  - Special-case handling for better performance
- **Parallelized Attention**
  - Computes attention heads in parallel
  - Optimized softmax with max subtraction for numerical stability (see the sketch after this list)
- **Fused Feed-Forward Networks**
  - Combines operations for the SwiGLU variant used in LLaMA models
  - Optimized SiLU and GELU activation functions

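The max-subtraction trick in the attention bullet is worth seeing once: subtracting the row maximum leaves softmax mathematically unchanged but keeps every exponent at or below zero, so exp() cannot overflow. A sequential sketch of what the kernel computes (on the GPU, the max and the sum are parallel reductions):

```java
// Numerically stable softmax: shift by the max, exponentiate, normalize.
// Sequential sketch; the GPU version parallelizes both reductions.
static void softmaxInPlace(float[] scores, int length) {
    float max = Float.NEGATIVE_INFINITY;
    for (int i = 0; i < length; i++) {
        max = Math.max(max, scores[i]);                // reduction 1: row maximum
    }
    float sum = 0f;
    for (int i = 0; i < length; i++) {
        scores[i] = (float) Math.exp(scores[i] - max); // shifted exponent is <= 0
        sum += scores[i];                              // reduction 2: normalizer
    }
    for (int i = 0; i < length; i++) {
        scores[i] /= sum;
    }
}
```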
### GPU Execution Optimizations
- **Layered Execution Planning**
  - Organizes computation as separate layer-based task graphs
  - Strategic scheduling of operations
- **Work Group Optimization**
  - Tailored worker-grid configurations for different operations (see the sketch after this list)
  - Matches GPU hardware characteristics
- **Local Memory Optimization**
  - Strategic use of local/shared memory for reductions
  - Optimizes bandwidth-intensive operations

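To show how these execution ideas map onto TornadoVM's programming model, here is a compressed sketch of one layer expressed as a task graph with an explicit worker grid. It follows TornadoVM's documented TaskGraph/GridScheduler API, but the kernel, task names, and sizes are simplified placeholders rather than the project's real per-layer plans, and details may vary across TornadoVM versions.

```java
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

public class LayerGraphSketch {
    // Placeholder kernel: each @Parallel iteration becomes one GPU work item.
    public static void scale(FloatArray in, FloatArray out) {
        for (@Parallel int i = 0; i < in.getSize(); i++) {
            out.set(i, in.get(i) * 2.0f);
        }
    }

    public static void main(String[] args) {
        FloatArray in = new FloatArray(4096);
        FloatArray out = new FloatArray(4096);
        in.init(1.0f);

        // One task graph per layer; static data moves once,
        // dynamic data moves on every execution.
        TaskGraph layer0 = new TaskGraph("layer0")
                .transferToDevice(DataTransferMode.FIRST_EXECUTION, in)
                .task("scale", LayerGraphSketch::scale, in, out)
                .transferToHost(DataTransferMode.EVERY_EXECUTION, out);

        // Tailor the worker grid: 4096 work items, 128 threads per work group.
        WorkerGrid1D worker = new WorkerGrid1D(4096);
        worker.setLocalWork(128, 1, 1);
        GridScheduler scheduler = new GridScheduler("layer0.scale", worker);

        ImmutableTaskGraph itg = layer0.snapshot();
        TornadoExecutionPlan plan = new TornadoExecutionPlan(itg);
        plan.withGridScheduler(scheduler).execute();
    }
}
```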
-----------

## Early performance of v1.0

![GPULlama3.java Performance Comparison](./docs/performance.png)

llama-tornado

Lines changed: 5 additions & 9 deletions
@@ -28,7 +28,7 @@ class LlamaRunner:
         if not all([self.java_home, self.tornado_sdk, self.llama_root]):
             print("Error: Required environment variables not set")
             print("Please ensure JAVA_HOME, TORNADO_SDK, and LLAMA_ROOT are defined")
-            print("Note: check set_path.sh in root dir")
+            print("Note: check set_path in root dir -> source set_path")
             sys.exit(1)

     def _validate_paths(self):
@@ -49,10 +49,8 @@ class LlamaRunner:
         cmd = [
             f"{self.java_home}/bin/java",
             "-server",
-            "-XX:-UseCompressedOops",
             "-XX:+UnlockExperimentalVMOptions",
             "-XX:+EnableJVMCI",
-            "-XX:-UseCompressedClassPointers",
             f"-Xms{args.heap_min}",
             f"-Xmx{args.heap_max}",
             "--enable-preview",
@@ -106,7 +104,7 @@ class LlamaRunner:
             f"-Dtornado.profiler.dump.dir={args.profiler_dump_dir}",
             "-Dtornado.enable.fastMathOptimizations=true",
             "-Dtornado.enable.mathOptimizations=false",
-            "-Dtornado.enable.nativeFunctions=fast",
+            "-Dtornado.enable.nativeFunctions=true",
             "-Dtornado.loop.interchange=true",
             f"-Dtornado.eventpool.maxwaitevents={args.max_wait_events}"
         ]
@@ -120,20 +118,18 @@ class LlamaRunner:
         # Module configuration - varies by backend
         module_config = [
             f"--upgrade-module-path", f"{self.tornado_sdk}/share/java/graalJars",
-            "-XX:+UseParallelGC",
             f"@{self.tornado_sdk}/etc/exportLists/common-exports",
         ]
-
         # Add backend-specific exports and modules
         if args.backend == Backend.OPENCL:
             module_config.extend([
                 f"@{self.tornado_sdk}/etc/exportLists/opencl-exports",
-                "--add-modules", "ALL-SYSTEM,tornado.runtime,tornado.annotation,tornado.drivers.common,tornado.drivers.opencl",
+                "--add-modules", "ALL-SYSTEM,jdk.incubator.vector,tornado.runtime,tornado.annotation,tornado.drivers.common,tornado.drivers.opencl",
             ])
         elif args.backend == Backend.PTX:
             module_config.extend([
                 f"@{self.tornado_sdk}/etc/exportLists/ptx-exports",
-                "--add-modules", "ALL-SYSTEM,tornado.runtime,tornado.annotation,tornado.drivers.common,tornado.drivers.ptx",
+                "--add-modules", "ALL-SYSTEM,jdk.incubator.vector,tornado.runtime,tornado.annotation,tornado.drivers.common,tornado.drivers.ptx",
             ])

         module_config.extend([
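The `jdk.incubator.vector` entry added to `--add-modules` in this hunk is what makes the JDK's incubating Vector API visible to the inference code; without it, classes like `FloatVector` fail to resolve at startup. A self-contained sketch (illustrative, not the project's actual kernel) of the kind of SIMD loop that needs the module:

```java
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

// Why the launcher adds jdk.incubator.vector: loops like this use the
// incubating Vector API for SIMD dot products on the host side.
public class VectorDotSketch {
    static final VectorSpecies<Float> SPECIES = FloatVector.SPECIES_PREFERRED;

    static float dot(float[] a, float[] b) {
        FloatVector acc = FloatVector.zero(SPECIES);
        int i = 0;
        int upper = SPECIES.loopBound(a.length);
        for (; i < upper; i += SPECIES.length()) {
            FloatVector va = FloatVector.fromArray(SPECIES, a, i);
            FloatVector vb = FloatVector.fromArray(SPECIES, b, i);
            acc = va.fma(vb, acc);                        // fused multiply-add per lane
        }
        float sum = acc.reduceLanes(VectorOperators.ADD); // horizontal reduction
        for (; i < a.length; i++) {                       // scalar tail
            sum = Math.fma(a[i], b[i], sum);
        }
        return sum;
    }
}
```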
@@ -335,4 +331,4 @@ def main():
     return runner.run(args)

 if __name__ == "__main__":
-    sys.exit(main())
+    sys.exit(main())
