+import triton
+import triton.language as tl
+
+
+@triton.jit
+def softmax_kernel(
+    output_ptr,
+    input_ptr,
+    input_row_stride,
+    output_row_stride,
+    n_rows,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+    num_stages: tl.constexpr,
+):
+    # Starting row of the program
+    row_start = tl.program_id(0)
+    row_step = tl.num_programs(0)
+    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
+        # The stride represents how much we need to increase the pointer to advance 1 row
+        row_start_ptr = input_ptr + row_idx * input_row_stride
+        # The block size is the smallest power of two greater than or equal to n_cols,
+        # so we can fit each row in a single block
+        col_offsets = tl.arange(0, BLOCK_SIZE)
+        input_ptrs = row_start_ptr + col_offsets
+        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
+        mask = col_offsets < n_cols
+        row = tl.load(input_ptrs, mask=mask, other=-float("inf"))
+        # Subtract the maximum for numerical stability
+        row_minus_max = row - tl.max(row, axis=0)
+        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
+        numerator = tl.exp(row_minus_max)
+        denominator = tl.sum(numerator, axis=0)
+        softmax_output = numerator / denominator
+        # Write back output to DRAM
+        output_row_start_ptr = output_ptr + row_idx * output_row_stride
+        output_ptrs = output_row_start_ptr + col_offsets
+        tl.store(output_ptrs, softmax_output, mask=mask)
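For context, here is a minimal host-side wrapper sketch (not part of the diff) showing how the kernel above might be launched from PyTorch. The helper name softmax, the one-program-per-row grid, and the num_stages value of 4 are assumptions for illustration; the Triton tutorial this kernel comes from instead sizes a persistent grid from device occupancy.

# Hypothetical host-side wrapper; illustrative sketch only, not from the diff.
import torch
import triton

def softmax(x: torch.Tensor) -> torch.Tensor:
    n_rows, n_cols = x.shape
    # BLOCK_SIZE must be a power of two >= n_cols so one block covers a full row.
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    y = torch.empty_like(x)
    # Simplest grid choice: one program per row, so the kernel's row loop runs once
    # per program. (A persistent, occupancy-sized grid is the tutorial's choice.)
    softmax_kernel[(n_rows,)](
        y, x,
        x.stride(0), y.stride(0),
        n_rows, n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_stages=4,  # binds the kernel's constexpr pipelining hint; 4 is an assumed value
    )
    return y

# Usage (assumes a CUDA-capable device):
# x = torch.randn(1823, 781, device="cuda")
# y = softmax(x)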