@@ -2121,121 +2121,116 @@ def jagged_sum_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, *, _launche
21212121from __future__ import annotations
21222122
21232123import torch
2124- import helion
21252124import triton
21262125import triton.language as tl
2127- from torch._inductor.runtime import triton_helpers
21282126from torch._inductor.runtime.triton_helpers import math as tl_math
21292127from torch._inductor.runtime.triton_compat import libdevice
21302128from helion.runtime import default_launcher as _default_launcher
21312129
21322130@triton.jit
2133- def _helion_jsd_forward(_input, target, loss, dX, _input_stride_0, _input_stride_1, dX_stride_0, dX_stride_1, loss_stride_0, loss_stride_1, target_stride_0, target_stride_1, BT, V, beta, n_non_ignore, _BLOCK_SIZE_0 : tl.constexpr, _BLOCK_SIZE_1 : tl.constexpr):
2131+ def _helion_jsd_forward(_input, target, loss, dX, _input_stride_0, _input_stride_1, dX_stride_0, loss_stride_0, target_stride_0, target_stride_1, BT, V, beta, one_minus_beta, n_non_ignore, _BLOCK_SIZE_1 : tl.constexpr, _BLOCK_SIZE_0 : tl.constexpr):
21342132 pid_0 = tl.program_id(0)
2135- offset_0 = pid_0 * _BLOCK_SIZE_0
2136- indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
2137- mask_0 = indices_0 < BT
2138- for offset_1 in tl.range(0, V.to(tl.int32), _BLOCK_SIZE_1):
2139- indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
2140- mask_1 = indices_1 < V
2141- X = tl.load(_input + (indices_0[:, None] * _input_stride_0 + indices_1[None, :] * _input_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
2142- Y = tl.load(target + (indices_0[:, None] * target_stride_0 + indices_1[None, :] * target_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
2143- _mask_to = tl.where(mask_0[:, None] & mask_1[None, :], X, tl.full([], float('-inf'), tl.float32))
2144- X_max = tl.cast(tl.max(_mask_to, 0), tl.float32)
2145- _mask_to_1 = tl.where(mask_0[:, None] & mask_1[None, :], Y, tl.full([], float('-inf'), tl.float32))
2146- Y_max = tl.cast(tl.max(_mask_to_1, 0), tl.float32)
2133+ offset_1 = pid_0 * _BLOCK_SIZE_1
2134+ indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
2135+ mask_1 = indices_1 < BT
2136+ intermediate_loss = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
2137+ intermediate_dX = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
2138+ for offset_0 in tl.range(0, V.to(tl.int32)):
2139+ indices_0 = offset_0 + tl.arange(0, 1).to(tl.int32)
2140+ intermediate_loss_copy = intermediate_loss
2141+ intermediate_dX_copy = intermediate_dX
2142+ intermediate_loss = intermediate_loss_copy
2143+ intermediate_dX = intermediate_dX_copy
2144+ X = tl.load(_input + (indices_1[:, None] * _input_stride_0 + indices_0[None, :] * _input_stride_1), mask_1[:, None], other=0)
2145+ Y = tl.load(target + (indices_1[:, None] * target_stride_0 + indices_0[None, :] * target_stride_1), mask_1[:, None], other=0)
21472146 eq = beta == 0.0
21482147 if eq:
21492148 Y_copy = Y
2150- Y_max_copy = Y_max
21512149 X_copy = X
2150+ intermediate_loss_copy_0_copy = intermediate_loss
2151+ intermediate_dX_copy_0_copy = intermediate_dX
21522152 Y_copy_0 = Y_copy
2153- Y_max_copy_0 = Y_max_copy
21542153 X_copy_0 = X_copy
2155- v_0 = Y_max_copy_0[None, :]
2154+ intermediate_loss_copy_0_copy_0 = intermediate_loss_copy_0_copy
2155+ intermediate_dX_copy_0_copy_0 = intermediate_dX_copy_0_copy
2156+ _mask_to = tl.where(tl.broadcast_to(mask_1[:, None], [_BLOCK_SIZE_1, _BLOCK_SIZE_0]), Y_copy_0, tl.full([], float('-inf'), tl.float32))
2157+ Y_max = tl.cast(tl.max(_mask_to, 0), tl.float32)
2158+ v_0 = Y_max[None, :]
21562159 v_1 = Y_copy_0 - v_0
21572160 v_2 = libdevice.exp(v_1)
2158- v_3 = libdevice.exp(Y_max_copy_0 )
2161+ v_3 = libdevice.exp(Y_max )
21592162 v_4 = v_3[None, :]
21602163 v_5 = v_2 * v_4
21612164 v_6 = Y_copy_0 - X_copy_0
21622165 v_7 = v_5 * v_6
2163- tl.store(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), v_7, mask_0[:, None] & mask_1[None, :])
2164- v_8 = -v_5
2165- tl.store(dX + (indices_0[:, None] * dX_stride_0 + indices_1[None, :] * dX_stride_1), v_8, mask_0[:, None] & mask_1[None, :])
2166+ intermediate_loss = intermediate_loss_copy_0_copy_0 + v_7
2167+ v_9 = -v_5
2168+ intermediate_dX = intermediate_dX_copy_0_copy_0 + v_9
21662169 _not = not eq
21672170 if _not:
21682171 X_copy_1 = X
2169- X_max_copy = X_max
21702172 Y_copy_1 = Y
2171- Y_max_copy_1 = Y_max
2173+ intermediate_loss_copy_0_copy_1 = intermediate_loss
2174+ intermediate_dX_copy_0_copy_1 = intermediate_dX
21722175 X_copy_1_0 = X_copy_1
2173- X_max_copy_0 = X_max_copy
21742176 Y_copy_1_0 = Y_copy_1
2175- Y_max_copy_1_0 = Y_max_copy_1
2177+ intermediate_loss = intermediate_loss_copy_0_copy_1
2178+ intermediate_dX = intermediate_dX_copy_0_copy_1
21762179 eq_1 = beta == 1.0
21772180 if eq_1:
21782181 X_copy_1_0_copy = X_copy_1_0
2179- X_max_copy_0_copy = X_max_copy_0
21802182 Y_copy_1_0_copy = Y_copy_1_0
2183+ intermediate_loss_copy_0_copy_1_0_copy = intermediate_loss
2184+ intermediate_dX_copy_0_copy_1_0_copy = intermediate_dX
21812185 X_copy_1_0_copy_0 = X_copy_1_0_copy
2182- X_max_copy_0_copy_0 = X_max_copy_0_copy
21832186 Y_copy_1_0_copy_0 = Y_copy_1_0_copy
2184- v_9 = X_max_copy_0_copy_0[None, :]
2185- v_10 = X_copy_1_0_copy_0 - v_9
2186- v_11 = libdevice.exp(v_10)
2187- v_12 = libdevice.exp(X_max_copy_0_copy_0)
2188- v_13 = v_12[None, :]
2189- v_14 = v_11 * v_13
2190- v_15 = X_copy_1_0_copy_0 - Y_copy_1_0_copy_0
2191- v_16 = v_14 * v_15
2192- tl.store(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), v_16, mask_0[:, None] & mask_1[None, :])
2193- load = tl.load(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
2194- v_17 = load + v_14
2195- tl.store(dX + (indices_0[:, None] * dX_stride_0 + indices_1[None, :] * dX_stride_1), v_17, mask_0[:, None] & mask_1[None, :])
2187+ intermediate_loss_copy_0_copy_1_0_copy_0 = intermediate_loss_copy_0_copy_1_0_copy
2188+ intermediate_dX_copy_0_copy_1_0_copy_0 = intermediate_dX_copy_0_copy_1_0_copy
2189+ _mask_to_1 = tl.where(tl.broadcast_to(mask_1[:, None], [_BLOCK_SIZE_1, _BLOCK_SIZE_0]), X_copy_1_0_copy_0, tl.full([], float('-inf'), tl.float32))
2190+ X_max = tl.cast(tl.max(_mask_to_1, 0), tl.float32)
2191+ v_11 = X_max[None, :]
2192+ v_12 = X_copy_1_0_copy_0 - v_11
2193+ v_13 = libdevice.exp(v_12)
2194+ v_14 = libdevice.exp(X_max)
2195+ v_15 = v_14[None, :]
2196+ v_16 = v_13 * v_15
2197+ v_17 = X_copy_1_0_copy_0 - Y_copy_1_0_copy_0
2198+ v_18 = v_16 * v_17
2199+ intermediate_loss = intermediate_loss_copy_0_copy_1_0_copy_0 + v_18
2200+ v_20 = intermediate_loss + v_16
2201+ intermediate_dX = intermediate_dX_copy_0_copy_1_0_copy_0 + v_20
21962202 _not_1 = not eq_1
21972203 if _not_1:
2198- X_max_copy_0_copy_1 = X_max_copy_0
2199- Y_max_copy_1_0_copy = Y_max_copy_1_0
22002204 X_copy_1_0_copy_1 = X_copy_1_0
22012205 Y_copy_1_0_copy_1 = Y_copy_1_0
2202- X_max_copy_0_copy_1_0 = X_max_copy_0_copy_1
2203- Y_max_copy_1_0_copy_0 = Y_max_copy_1_0_copy
2206+ intermediate_loss_copy_0_copy_1_0_copy_1 = intermediate_loss
2207+ intermediate_dX_copy_0_copy_1_0_copy_1 = intermediate_dX
22042208 X_copy_1_0_copy_1_0 = X_copy_1_0_copy_1
22052209 Y_copy_1_0_copy_1_0 = Y_copy_1_0_copy_1
2206- v_18 = triton_helpers.maximum(X_max_copy_0_copy_1_0, Y_max_copy_1_0_copy_0)
2207- v_19 = v_18[None, :]
2208- v_20 = X_copy_1_0_copy_1_0 - v_19
2209- v_21 = v_18[None, :]
2210- v_22 = Y_copy_1_0_copy_1_0 - v_21
2211- v_23 = libdevice.exp(v_18)
2212- v_24 = libdevice.exp(v_20)
2213- v_25 = v_23[None, :]
2214- v_26 = v_24 * v_25
2215- v_27 = libdevice.exp(v_22)
2216- v_28 = v_23[None, :]
2217- v_29 = v_27 * v_28
2218- v_30 = v_29 * beta
2219- sub_2 = 1.0 + -1 * beta
2220- v_31 = v_26 * sub_2
2221- v_32 = v_30 + v_31
2222- v_33 = tl_math.log(v_32)
2223- v_34 = v_30 * Y_copy_1_0_copy_1_0
2224- v_35 = v_31 * X_copy_1_0_copy_1_0
2225- v_36 = v_34 + v_35
2226- v_37 = v_32 * v_33
2227- v_38 = v_36 - v_37
2228- tl.store(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), v_38, mask_0[:, None] & mask_1[None, :])
2229- v_39 = X_copy_1_0_copy_1_0 - v_33
2230- v_40 = v_31 * v_39
2231- tl.store(dX + (indices_0[:, None] * dX_stride_0 + indices_1[None, :] * dX_stride_1), v_40, mask_0[:, None] & mask_1[None, :])
2232- truediv = 1.0 / n_non_ignore
2233- load_2 = tl.load(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
2234- v_41 = load_2 * truediv
2235- tl.store(loss + (indices_0[:, None] * loss_stride_0 + indices_1[None, :] * loss_stride_1), v_41, mask_0[:, None] & mask_1[None, :])
2236- load_3 = tl.load(dX + (indices_0[:, None] * dX_stride_0 + indices_1[None, :] * dX_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
2237- v_42 = load_3 * truediv
2238- tl.store(dX + (indices_0[:, None] * dX_stride_0 + indices_1[None, :] * dX_stride_1), v_42, mask_0[:, None] & mask_1[None, :])
2210+ intermediate_loss_copy_0_copy_1_0_copy_1_0 = intermediate_loss_copy_0_copy_1_0_copy_1
2211+ intermediate_dX_copy_0_copy_1_0_copy_1_0 = intermediate_dX_copy_0_copy_1_0_copy_1
2212+ v_22 = libdevice.exp(X_copy_1_0_copy_1_0)
2213+ v_23 = libdevice.exp(Y_copy_1_0_copy_1_0)
2214+ v_24 = v_23 * beta
2215+ v_25 = v_22 * one_minus_beta
2216+ v_26 = v_24 + v_25
2217+ v_27 = tl_math.log(v_26)
2218+ v_28 = X_copy_1_0_copy_1_0 - v_27
2219+ v_29 = v_25 * v_28
2220+ v_30 = Y_copy_1_0_copy_1_0 - v_27
2221+ v_31 = v_24 * v_30
2222+ v_32 = v_31 + v_29
2223+ intermediate_loss = intermediate_loss_copy_0_copy_1_0_copy_1_0 + v_32
2224+ intermediate_dX = intermediate_dX_copy_0_copy_1_0_copy_1_0 + v_29
2225+ truediv = 1.0 / n_non_ignore
2226+ v_35 = intermediate_loss * truediv
2227+ _mask_to_2 = tl.where(tl.broadcast_to(mask_1[:, None], [_BLOCK_SIZE_1, _BLOCK_SIZE_0]), v_35, tl.full([], 0, tl.float32))
2228+ sum_1 = tl.cast(tl.sum(_mask_to_2, 1), tl.float32)
2229+ tl.store(loss + indices_1 * loss_stride_0, sum_1, mask_1)
2230+ v_36 = intermediate_dX * truediv
2231+ _mask_to_3 = tl.where(tl.broadcast_to(mask_1[:, None], [_BLOCK_SIZE_1, _BLOCK_SIZE_0]), v_36, tl.full([], 0, tl.float32))
2232+ sum_2 = tl.cast(tl.sum(_mask_to_3, 1), tl.float32)
2233+ tl.store(dX + indices_1 * dX_stride_0, sum_2, mask_1)
22392234
22402235def jsd_forward(_input: Tensor, target: Tensor, shift_labels: Tensor | None=None, beta: float=0.5, ignore_index: int=-100, *, _launcher=_default_launcher):
22412236 """
@@ -2254,18 +2249,16 @@ def jsd_forward(_input: Tensor, target: Tensor, shift_labels: Tensor | None=None
22542249 """
22552250 BT, V = _input.shape
22562251 assert target.shape == _input.shape, f'Shape mismatch: {target.shape} != {_input.shape}'
2257- n_rows = BT
2258- loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device )
2259- dX = torch.empty_like(_input)
2252+ loss = torch.zeros([BT], dtype=torch.float32, device=_input.device)
2253+ dX = torch.empty_like(loss )
2254+ one_minus_beta = 1 - beta
22602255 n_non_ignore = float(BT)
22612256 if shift_labels is not None:
22622257 n_non_ignore = float((shift_labels != ignore_index).sum().item())
22632258 if n_non_ignore == 0:
22642259 return (torch.zeros([], dtype=_input.dtype, device=_input.device), torch.zeros_like(_input))
2265- BT_SIZE = helion.cdiv(BT, n_rows)
2266- _BLOCK_SIZE_0 = BT_SIZE
22672260 _BLOCK_SIZE_1 = 4096
2268- _launcher(_helion_jsd_forward, (triton.cdiv(BT, _BLOCK_SIZE_0 ),), _input, target, loss, dX, _input.stride(0), _input.stride(1), dX.stride(0), dX.stride(1), loss.stride(0), loss.stride(1), target.stride(0), target.stride(1), BT, V, beta, n_non_ignore, _BLOCK_SIZE_0 , _BLOCK_SIZE_1, num_warps=4, num_stages=3)
2261+ _launcher(_helion_jsd_forward, (triton.cdiv(BT, _BLOCK_SIZE_1 ),), _input, target, loss, dX, _input.stride(0), _input.stride(1), dX.stride(0), loss.stride(0), target.stride(0), target.stride(1), BT, V, beta, one_minus_beta, n_non_ignore , _BLOCK_SIZE_1, 1 , num_warps=4, num_stages=3)
22692262 final_loss = torch.sum(loss)
22702263 return (final_loss, dX)
22712264
0 commit comments