Skip to content

Commit 432c653

Browse files
authored
Faster Helion JSD (#733)
1 parent 6c49259 commit 432c653

File tree

3 files changed

+105
-115
lines changed

3 files changed

+105
-115
lines changed

examples/jsd.py

Lines changed: 28 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -63,11 +63,14 @@ def jsd_forward(
6363
assert target.shape == _input.shape, (
6464
f"Shape mismatch: {target.shape} != {_input.shape}"
6565
)
66-
n_rows = BT
66+
block_size_n = hl.register_block_size(V)
67+
block_size_m = hl.register_block_size(BT)
6768

6869
# Create output tensor for accumulating loss
69-
loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
70-
dX = torch.empty_like(_input)
70+
loss = torch.zeros([BT], dtype=torch.float32, device=_input.device)
71+
dX = torch.empty_like(loss)
72+
73+
one_minus_beta = 1 - beta
7174

7275
# Count non-ignored elements
7376
n_non_ignore = float(BT)
@@ -79,60 +82,54 @@ def jsd_forward(
7982
), torch.zeros_like(_input)
8083

8184
# Process each sequence position
82-
BT_SIZE = helion.cdiv(BT, n_rows) # The liger kernel uses 1
83-
for tile_bt in hl.tile(BT, block_size=BT_SIZE):
85+
for tile_bt in hl.tile(BT, block_size=block_size_m):
8486
# Check for label masking
8587
if shift_labels is not None:
8688
if shift_labels[tile_bt] == ignore_index:
8789
for tile_X in hl.tile(V):
8890
dX[tile_bt, tile_X] = 0.0
8991
continue
90-
91-
for tile_v in hl.tile(V):
92+
intermediate_loss = hl.zeros([tile_bt, block_size_n], dtype=torch.float32)
93+
intermediate_dX = hl.zeros([tile_bt, block_size_n], dtype=_input.dtype)
94+
for tile_v in hl.tile(V, block_size=block_size_n):
9295
# Load log probabilities and convert to float32
9396
X = _input[tile_bt, tile_v]
9497
Y = target[tile_bt, tile_v]
95-
X_max = torch.amax(X, dim=0)
96-
Y_max = torch.amax(Y, dim=0)
9798

9899
if beta == 0.0: # Forward KL: KL(P || Q)
100+
Y_max = torch.amax(Y, dim=0)
99101
Y_shift = Y - Y_max
100102
Y_prob = torch.exp(Y_shift) * torch.exp(
101103
Y_max
102104
) # Compensate for the shift
103-
loss[tile_bt, tile_v] = Y_prob * (Y - X)
104-
dX[tile_bt, tile_v] = -Y_prob
105+
intermediate_loss += Y_prob * (Y - X)
106+
intermediate_dX += -Y_prob
105107
elif beta == 1.0: # Reverse KL: KL(Q || P)
108+
X_max = torch.amax(X, dim=0)
106109
X_shift = X - X_max
107110
X_prob = torch.exp(X_shift) * torch.exp(
108111
X_max
109112
) # Compensate for the shift
110-
loss[tile_bt, tile_v] = X_prob * (X - Y)
111-
dX[tile_bt, tile_v] = loss[tile_bt, tile_v] + X_prob
113+
intermediate_loss += X_prob * (X - Y)
114+
intermediate_dX += intermediate_loss + X_prob
112115
else: # General JSD: beta*KL(P||M) + (1-beta)*KL(Q||M)
113-
max_val = torch.maximum(X_max, Y_max)
114-
X_shifted = X - max_val
115-
Y_shifted = Y - max_val
116-
117-
exp_max = torch.exp(max_val)
118-
119-
Q = torch.exp(X_shifted) * exp_max # = exp(X)
120-
P = torch.exp(Y_shifted) * exp_max # = exp(Y)
116+
Q = torch.exp(X) # = exp(X)
117+
P = torch.exp(Y) # = exp(Y)
121118

122119
beta_P = beta * P
123-
one_minus_beta_Q = (1 - beta) * Q
120+
one_minus_beta_Q = one_minus_beta * Q
124121
M = beta_P + one_minus_beta_Q
125-
log_M = torch.log(
126-
M
127-
) # No need to compensate as M is already in original scale
122+
log_M = torch.log(M)
123+
x_minus_log_m = X - log_M
124+
kl_q_m = one_minus_beta_Q * x_minus_log_m
128125

129-
loss[tile_bt, tile_v] = beta_P * Y + one_minus_beta_Q * X - M * log_M
130-
dX[tile_bt, tile_v] = one_minus_beta_Q * (X - log_M)
126+
intermediate_loss += beta_P * (Y - log_M) + kl_q_m
127+
intermediate_dX += kl_q_m
131128

132-
# Accumulate over vocabulary dimension
133-
scale = 1.0 / n_non_ignore
134-
loss[tile_bt, tile_v] = loss[tile_bt, tile_v] * scale
135-
dX[tile_bt, tile_v] = dX[tile_bt, tile_v] * scale
129+
# Accumulate over vocabulary dimension
130+
scale = 1.0 / n_non_ignore
131+
loss[tile_bt] = torch.sum(intermediate_loss * scale, dim=1)
132+
dX[tile_bt] = torch.sum(intermediate_dX * scale, dim=1)
136133

137134
# Normalize by number of non-ignored elements, run it on host to match liger_kernel
138135
final_loss = torch.sum(

test/test_examples.expected

Lines changed: 76 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -2121,121 +2121,116 @@ def jagged_sum_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, *, _launche
21212121
from __future__ import annotations
21222122

21232123
import torch
2124-
import helion
21252124
import triton
21262125
import triton.language as tl
2127-
from torch._inductor.runtime import triton_helpers
21282126
from torch._inductor.runtime.triton_helpers import math as tl_math
21292127
from torch._inductor.runtime.triton_compat import libdevice
21302128
from helion.runtime import default_launcher as _default_launcher
21312129

21322130
@triton.jit
2133-
def _helion_jsd_forward(_input, target, loss, dX, _input_stride_0, _input_stride_1, dX_stride_0, dX_stride_1, loss_stride_0, loss_stride_1, target_stride_0, target_stride_1, BT, V, beta, n_non_ignore, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
2131+
def _helion_jsd_forward(_input, target, loss, dX, _input_stride_0, _input_stride_1, dX_stride_0, loss_stride_0, target_stride_0, target_stride_1, BT, V, beta, one_minus_beta, n_non_ignore, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
21342132
pid_0 = tl.program_id(0)
2135-
offset_0 = pid_0 * _BLOCK_SIZE_0
2136-
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
2137-
mask_0 = indices_0 < BT
2138-
for offset_1 in tl.range(0, V.to(tl.int32), _BLOCK_SIZE_1):
2139-
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
2140-
mask_1 = indices_1 < V
2141-
X = tl.load(_input + (indices_0[:, None] * _input_stride_0 + indices_1[None, :] * _input_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
2142-
Y = tl.load(target + (indices_0[:, None] * target_stride_0 + indices_1[None, :] * target_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
2143-
_mask_to = tl.where(mask_0[:, None] & mask_1[None, :], X, tl.full([], float('-inf'), tl.float32))
2144-
X_max = tl.cast(tl.max(_mask_to, 0), tl.float32)
2145-
_mask_to_1 = tl.where(mask_0[:, None] & mask_1[None, :], Y, tl.full([], float('-inf'), tl.float32))
2146-
Y_max = tl.cast(tl.max(_mask_to_1, 0), tl.float32)
2133+
offset_1 = pid_0 * _BLOCK_SIZE_1
2134+
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
2135+
mask_1 = indices_1 < BT
2136+
intermediate_loss = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
2137+
intermediate_dX = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
2138+
for offset_0 in tl.range(0, V.to(tl.int32)):
2139+
indices_0 = offset_0 + tl.arange(0, 1).to(tl.int32)
2140+
intermediate_loss_copy = intermediate_loss
2141+
intermediate_dX_copy = intermediate_dX
2142+
intermediate_loss = intermediate_loss_copy
2143+
intermediate_dX = intermediate_dX_copy
2144+
X = tl.load(_input + (indices_1[:, None] * _input_stride_0 + indices_0[None, :] * _input_stride_1), mask_1[:, None], other=0)
2145+
Y = tl.load(target + (indices_1[:, None] * target_stride_0 + indices_0[None, :] * target_stride_1), mask_1[:, None], other=0)
21472146
eq = beta == 0.0
21482147
if eq:
21492148
Y_copy = Y
2150-
Y_max_copy = Y_max
21512149
X_copy = X
2150+
intermediate_loss_copy_0_copy = intermediate_loss
2151+
intermediate_dX_copy_0_copy = intermediate_dX
21522152
Y_copy_0 = Y_copy
2153-
Y_max_copy_0 = Y_max_copy
21542153
X_copy_0 = X_copy
2155-
v_0 = Y_max_copy_0[None, :]
2154+
intermediate_loss_copy_0_copy_0 = intermediate_loss_copy_0_copy
2155+
intermediate_dX_copy_0_copy_0 = intermediate_dX_copy_0_copy
2156+
_mask_to = tl.where(tl.broadcast_to(mask_1[:, None], [_BLOCK_SIZE_1, _BLOCK_SIZE_0]), Y_copy_0, tl.full([], float('-inf'), tl.float32))
2157+
Y_max = tl.cast(tl.max(_mask_to, 0), tl.float32)
2158+
v_0 = Y_max[None, :]
21562159
v_1 = Y_copy_0 - v_0
21572160
v_2 = libdevice.exp(v_1)
2158-
v_3 = libdevice.exp(Y_max_copy_0)
2161+
v_3 = libdevice.exp(Y_max)
21592162
v_4 = v_3[None, :]
21602163
v_5 = v_2 * v_4
21612164
v_6 = Y_copy_0 - X_copy_0
21622165
v_7 = v_5 * v_6
2163-
tl.store(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), v_7, mask_0[:, None] & mask_1[None, :])
2164-
v_8 = -v_5
2165-
tl.store(dX + (indices_0[:, None] * dX_stride_0 + indices_1[None, :] * dX_stride_1), v_8, mask_0[:, None] & mask_1[None, :])
2166+
intermediate_loss = intermediate_loss_copy_0_copy_0 + v_7
2167+
v_9 = -v_5
2168+
intermediate_dX = intermediate_dX_copy_0_copy_0 + v_9
21662169
_not = not eq
21672170
if _not:
21682171
X_copy_1 = X
2169-
X_max_copy = X_max
21702172
Y_copy_1 = Y
2171-
Y_max_copy_1 = Y_max
2173+
intermediate_loss_copy_0_copy_1 = intermediate_loss
2174+
intermediate_dX_copy_0_copy_1 = intermediate_dX
21722175
X_copy_1_0 = X_copy_1
2173-
X_max_copy_0 = X_max_copy
21742176
Y_copy_1_0 = Y_copy_1
2175-
Y_max_copy_1_0 = Y_max_copy_1
2177+
intermediate_loss = intermediate_loss_copy_0_copy_1
2178+
intermediate_dX = intermediate_dX_copy_0_copy_1
21762179
eq_1 = beta == 1.0
21772180
if eq_1:
21782181
X_copy_1_0_copy = X_copy_1_0
2179-
X_max_copy_0_copy = X_max_copy_0
21802182
Y_copy_1_0_copy = Y_copy_1_0
2183+
intermediate_loss_copy_0_copy_1_0_copy = intermediate_loss
2184+
intermediate_dX_copy_0_copy_1_0_copy = intermediate_dX
21812185
X_copy_1_0_copy_0 = X_copy_1_0_copy
2182-
X_max_copy_0_copy_0 = X_max_copy_0_copy
21832186
Y_copy_1_0_copy_0 = Y_copy_1_0_copy
2184-
v_9 = X_max_copy_0_copy_0[None, :]
2185-
v_10 = X_copy_1_0_copy_0 - v_9
2186-
v_11 = libdevice.exp(v_10)
2187-
v_12 = libdevice.exp(X_max_copy_0_copy_0)
2188-
v_13 = v_12[None, :]
2189-
v_14 = v_11 * v_13
2190-
v_15 = X_copy_1_0_copy_0 - Y_copy_1_0_copy_0
2191-
v_16 = v_14 * v_15
2192-
tl.store(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), v_16, mask_0[:, None] & mask_1[None, :])
2193-
load = tl.load(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
2194-
v_17 = load + v_14
2195-
tl.store(dX + (indices_0[:, None] * dX_stride_0 + indices_1[None, :] * dX_stride_1), v_17, mask_0[:, None] & mask_1[None, :])
2187+
intermediate_loss_copy_0_copy_1_0_copy_0 = intermediate_loss_copy_0_copy_1_0_copy
2188+
intermediate_dX_copy_0_copy_1_0_copy_0 = intermediate_dX_copy_0_copy_1_0_copy
2189+
_mask_to_1 = tl.where(tl.broadcast_to(mask_1[:, None], [_BLOCK_SIZE_1, _BLOCK_SIZE_0]), X_copy_1_0_copy_0, tl.full([], float('-inf'), tl.float32))
2190+
X_max = tl.cast(tl.max(_mask_to_1, 0), tl.float32)
2191+
v_11 = X_max[None, :]
2192+
v_12 = X_copy_1_0_copy_0 - v_11
2193+
v_13 = libdevice.exp(v_12)
2194+
v_14 = libdevice.exp(X_max)
2195+
v_15 = v_14[None, :]
2196+
v_16 = v_13 * v_15
2197+
v_17 = X_copy_1_0_copy_0 - Y_copy_1_0_copy_0
2198+
v_18 = v_16 * v_17
2199+
intermediate_loss = intermediate_loss_copy_0_copy_1_0_copy_0 + v_18
2200+
v_20 = intermediate_loss + v_16
2201+
intermediate_dX = intermediate_dX_copy_0_copy_1_0_copy_0 + v_20
21962202
_not_1 = not eq_1
21972203
if _not_1:
2198-
X_max_copy_0_copy_1 = X_max_copy_0
2199-
Y_max_copy_1_0_copy = Y_max_copy_1_0
22002204
X_copy_1_0_copy_1 = X_copy_1_0
22012205
Y_copy_1_0_copy_1 = Y_copy_1_0
2202-
X_max_copy_0_copy_1_0 = X_max_copy_0_copy_1
2203-
Y_max_copy_1_0_copy_0 = Y_max_copy_1_0_copy
2206+
intermediate_loss_copy_0_copy_1_0_copy_1 = intermediate_loss
2207+
intermediate_dX_copy_0_copy_1_0_copy_1 = intermediate_dX
22042208
X_copy_1_0_copy_1_0 = X_copy_1_0_copy_1
22052209
Y_copy_1_0_copy_1_0 = Y_copy_1_0_copy_1
2206-
v_18 = triton_helpers.maximum(X_max_copy_0_copy_1_0, Y_max_copy_1_0_copy_0)
2207-
v_19 = v_18[None, :]
2208-
v_20 = X_copy_1_0_copy_1_0 - v_19
2209-
v_21 = v_18[None, :]
2210-
v_22 = Y_copy_1_0_copy_1_0 - v_21
2211-
v_23 = libdevice.exp(v_18)
2212-
v_24 = libdevice.exp(v_20)
2213-
v_25 = v_23[None, :]
2214-
v_26 = v_24 * v_25
2215-
v_27 = libdevice.exp(v_22)
2216-
v_28 = v_23[None, :]
2217-
v_29 = v_27 * v_28
2218-
v_30 = v_29 * beta
2219-
sub_2 = 1.0 + -1 * beta
2220-
v_31 = v_26 * sub_2
2221-
v_32 = v_30 + v_31
2222-
v_33 = tl_math.log(v_32)
2223-
v_34 = v_30 * Y_copy_1_0_copy_1_0
2224-
v_35 = v_31 * X_copy_1_0_copy_1_0
2225-
v_36 = v_34 + v_35
2226-
v_37 = v_32 * v_33
2227-
v_38 = v_36 - v_37
2228-
tl.store(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), v_38, mask_0[:, None] & mask_1[None, :])
2229-
v_39 = X_copy_1_0_copy_1_0 - v_33
2230-
v_40 = v_31 * v_39
2231-
tl.store(dX + (indices_0[:, None] * dX_stride_0 + indices_1[None, :] * dX_stride_1), v_40, mask_0[:, None] & mask_1[None, :])
2232-
truediv = 1.0 / n_non_ignore
2233-
load_2 = tl.load(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
2234-
v_41 = load_2 * truediv
2235-
tl.store(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), v_41, mask_0[:, None] & mask_1[None, :])
2236-
load_3 = tl.load(dX + (indices_0[:, None] * dX_stride_0 + indices_1[None, :] * dX_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
2237-
v_42 = load_3 * truediv
2238-
tl.store(dX + (indices_0[:, None] * dX_stride_0 + indices_1[None, :] * dX_stride_1), v_42, mask_0[:, None] & mask_1[None, :])
2210+
intermediate_loss_copy_0_copy_1_0_copy_1_0 = intermediate_loss_copy_0_copy_1_0_copy_1
2211+
intermediate_dX_copy_0_copy_1_0_copy_1_0 = intermediate_dX_copy_0_copy_1_0_copy_1
2212+
v_22 = libdevice.exp(X_copy_1_0_copy_1_0)
2213+
v_23 = libdevice.exp(Y_copy_1_0_copy_1_0)
2214+
v_24 = v_23 * beta
2215+
v_25 = v_22 * one_minus_beta
2216+
v_26 = v_24 + v_25
2217+
v_27 = tl_math.log(v_26)
2218+
v_28 = X_copy_1_0_copy_1_0 - v_27
2219+
v_29 = v_25 * v_28
2220+
v_30 = Y_copy_1_0_copy_1_0 - v_27
2221+
v_31 = v_24 * v_30
2222+
v_32 = v_31 + v_29
2223+
intermediate_loss = intermediate_loss_copy_0_copy_1_0_copy_1_0 + v_32
2224+
intermediate_dX = intermediate_dX_copy_0_copy_1_0_copy_1_0 + v_29
2225+
truediv = 1.0 / n_non_ignore
2226+
v_35 = intermediate_loss * truediv
2227+
_mask_to_2 = tl.where(tl.broadcast_to(mask_1[:, None], [_BLOCK_SIZE_1, _BLOCK_SIZE_0]), v_35, tl.full([], 0, tl.float32))
2228+
sum_1 = tl.cast(tl.sum(_mask_to_2, 1), tl.float32)
2229+
tl.store(loss + indices_1 * loss_stride_0, sum_1, mask_1)
2230+
v_36 = intermediate_dX * truediv
2231+
_mask_to_3 = tl.where(tl.broadcast_to(mask_1[:, None], [_BLOCK_SIZE_1, _BLOCK_SIZE_0]), v_36, tl.full([], 0, tl.float32))
2232+
sum_2 = tl.cast(tl.sum(_mask_to_3, 1), tl.float32)
2233+
tl.store(dX + indices_1 * dX_stride_0, sum_2, mask_1)
22392234

22402235
def jsd_forward(_input: Tensor, target: Tensor, shift_labels: Tensor | None=None, beta: float=0.5, ignore_index: int=-100, *, _launcher=_default_launcher):
22412236
"""
@@ -2254,18 +2249,16 @@ def jsd_forward(_input: Tensor, target: Tensor, shift_labels: Tensor | None=None
22542249
"""
22552250
BT, V = _input.shape
22562251
assert target.shape == _input.shape, f'Shape mismatch: {target.shape} != {_input.shape}'
2257-
n_rows = BT
2258-
loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
2259-
dX = torch.empty_like(_input)
2252+
loss = torch.zeros([BT], dtype=torch.float32, device=_input.device)
2253+
dX = torch.empty_like(loss)
2254+
one_minus_beta = 1 - beta
22602255
n_non_ignore = float(BT)
22612256
if shift_labels is not None:
22622257
n_non_ignore = float((shift_labels != ignore_index).sum().item())
22632258
if n_non_ignore == 0:
22642259
return (torch.zeros([], dtype=_input.dtype, device=_input.device), torch.zeros_like(_input))
2265-
BT_SIZE = helion.cdiv(BT, n_rows)
2266-
_BLOCK_SIZE_0 = BT_SIZE
22672260
_BLOCK_SIZE_1 = 4096
2268-
_launcher(_helion_jsd_forward, (triton.cdiv(BT, _BLOCK_SIZE_0),), _input, target, loss, dX, _input.stride(0), _input.stride(1), dX.stride(0), dX.stride(1), loss.stride(0), loss.stride(1), target.stride(0), target.stride(1), BT, V, beta, n_non_ignore, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
2261+
_launcher(_helion_jsd_forward, (triton.cdiv(BT, _BLOCK_SIZE_1),), _input, target, loss, dX, _input.stride(0), _input.stride(1), dX.stride(0), loss.stride(0), target.stride(0), target.stride(1), BT, V, beta, one_minus_beta, n_non_ignore, _BLOCK_SIZE_1, 1, num_warps=4, num_stages=3)
22692262
final_loss = torch.sum(loss)
22702263
return (final_loss, dX)
22712264

test/test_examples.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1116,7 +1116,7 @@ def test_jsd(self):
11161116
args,
11171117
(expected(*args), None),
11181118
fn_name="jsd_forward",
1119-
block_sizes=[4096],
1119+
block_sizes=[1, 4096],
11201120
num_warps=4,
11211121
num_stages=3,
11221122
)

0 commit comments

Comments
 (0)