@@ -2266,7 +2266,6 @@ def jsd_forward(_input: Tensor, target: Tensor, shift_labels: Tensor | None=None
 from __future__ import annotations
 
 import torch
-import helion
 import triton
 import triton.language as tl
 from torch._inductor.runtime import triton_helpers
@@ -2275,41 +2274,44 @@ from torch._inductor.runtime.triton_compat import libdevice
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _helion_kl_div_forward(y_pred, y_true, kl_loss, loss, kl_loss_stride_0, kl_loss_stride_1, loss_stride_0, y_pred_stride_0, y_pred_stride_1, y_true_stride_0, y_true_stride_1, BT, V, log_target, eps, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
+def _helion_kl_div_forward(y_pred, y_true, loss, loss_stride_0, y_pred_stride_0, y_pred_stride_1, y_true_stride_0, y_true_stride_1, BT, V, log_target, eps, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_1 = pid_0 * _BLOCK_SIZE_1
     indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
     mask_1 = indices_1 < BT
     loss_sum = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
-    for offset_0 in tl.range(0, V.to(tl.int32), _BLOCK_SIZE_0):
-        indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
-        mask_0 = indices_0 < V
+    for offset_0 in tl.range(0, V.to(tl.int32)):
+        indices_0 = offset_0 + tl.arange(0, 1).to(tl.int32)
         loss_sum_copy = loss_sum
         loss_sum_copy_0 = loss_sum_copy
-        y_pred_val = tl.load(y_pred + (indices_1[:, None] * y_pred_stride_0 + indices_0[None, :] * y_pred_stride_1), mask_1[:, None] & mask_0[None, :], other=0)
-        y_true_val = tl.load(y_true + (indices_1[:, None] * y_true_stride_0 + indices_0[None, :] * y_true_stride_1), mask_1[:, None] & mask_0[None, :], other=0)
+        kl_loss = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
+        y_pred_val = tl.load(y_pred + (indices_1[:, None] * y_pred_stride_0 + indices_0[None, :] * y_pred_stride_1), mask_1[:, None], other=0)
+        y_true_val = tl.load(y_true + (indices_1[:, None] * y_true_stride_0 + indices_0[None, :] * y_true_stride_1), mask_1[:, None], other=0)
         if log_target:
             y_true_val_copy = y_true_val
             y_pred_val_copy = y_pred_val
+            kl_loss_copy = kl_loss
             y_true_val_copy_0 = y_true_val_copy
             y_pred_val_copy_0 = y_pred_val_copy
+            kl_loss_copy_0 = kl_loss_copy
             v_0 = libdevice.exp(y_true_val_copy_0)
             v_1 = y_true_val_copy_0 - y_pred_val_copy_0
             v_2 = v_0 * v_1
-            tl.store(kl_loss + (indices_1[:, None] * kl_loss_stride_0 + indices_0[None, :] * kl_loss_stride_1), v_2, mask_1[:, None] & mask_0[None, :])
+            kl_loss = kl_loss_copy_0 + v_2
         _not = not log_target
         if _not:
             y_true_val_copy_1 = y_true_val
             y_pred_val_copy_1 = y_pred_val
+            kl_loss_copy_1 = kl_loss
             y_true_val_copy_1_0 = y_true_val_copy_1
             y_pred_val_copy_1_0 = y_pred_val_copy_1
-            v_3 = triton_helpers.maximum(y_true_val_copy_1_0, eps)
-            v_4 = tl_math.log(v_3)
-            v_5 = v_4 - y_pred_val_copy_1_0
-            v_6 = y_true_val_copy_1_0 * v_5
-            tl.store(kl_loss + (indices_1[:, None] * kl_loss_stride_0 + indices_0[None, :] * kl_loss_stride_1), v_6, mask_1[:, None] & mask_0[None, :])
-        load_2 = tl.load(kl_loss + (indices_1[:, None] * kl_loss_stride_0 + indices_0[None, :] * kl_loss_stride_1), mask_1[:, None] & mask_0[None, :], other=0)
-        loss_sum = loss_sum_copy_0 + load_2
+            kl_loss_copy_1_0 = kl_loss_copy_1
+            v_4 = triton_helpers.maximum(y_true_val_copy_1_0, eps)
+            v_5 = tl_math.log(v_4)
+            v_6 = v_5 - y_pred_val_copy_1_0
+            v_7 = y_true_val_copy_1_0 * v_6
+            kl_loss = kl_loss_copy_1_0 + v_7
+        loss_sum = loss_sum_copy_0 + kl_loss
     sum_1 = tl.cast(tl.sum(loss_sum, 1), tl.float32)
     tl.store(loss + indices_1 * loss_stride_0, sum_1, mask_1)
 
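The rewritten kernel accumulates the per-element KL terms in a register tile (`kl_loss`) instead of staging them through a global-memory buffer, then folds them into `loss_sum` and reduces once per row. A minimal eager-mode sketch of the quantity the kernel computes; the function name `kl_div_reference` and the `eps` default are assumptions for illustration, not part of this commit:

import torch

def kl_div_reference(y_pred: torch.Tensor, y_true: torch.Tensor,
                     log_target: bool = False, eps: float = 1e-10) -> torch.Tensor:
    # y_pred holds log-probabilities of shape [BT, V]; y_true holds
    # probabilities (log_target=False) or log-probabilities (log_target=True).
    # The eps default here is an assumption; the kernel receives eps as an argument.
    if log_target:
        # exp(y_true) * (y_true - y_pred), matching the libdevice.exp branch
        kl = torch.exp(y_true) * (y_true - y_pred)
    else:
        # y_true * (log(max(y_true, eps)) - y_pred), matching the
        # triton_helpers.maximum / tl_math.log branch
        kl = y_true * (torch.log(torch.clamp(y_true, min=eps)) - y_pred)
    # one partial sum per row, like the per-program tl.sum over loss_sum
    return kl.sum(dim=-1)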
@@ -2333,11 +2335,8 @@ def kl_div_forward(y_pred: Tensor, y_true: Tensor, log_target: bool=False, reduc
         loss = torch.zeros_like(y_pred)
     else:
         loss = torch.zeros((BT,), dtype=torch.float32, device=y_pred.device)
-    kl_loss = torch.zeros_like(y_pred)
-    BT_SIZE = helion.cdiv(BT, BT)
-    _BLOCK_SIZE_1 = BT_SIZE
-    _BLOCK_SIZE_0 = 4096
-    _launcher(_helion_kl_div_forward, (triton.cdiv(BT, _BLOCK_SIZE_1),), y_pred, y_true, kl_loss, loss, kl_loss.stride(0), kl_loss.stride(1), loss.stride(0), y_pred.stride(0), y_pred.stride(1), y_true.stride(0), y_true.stride(1), BT, V, log_target, eps, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    _BLOCK_SIZE_1 = 4096
+    _launcher(_helion_kl_div_forward, (triton.cdiv(BT, _BLOCK_SIZE_1),), y_pred, y_true, loss, loss.stride(0), y_pred.stride(0), y_pred.stride(1), y_true.stride(0), y_true.stride(1), BT, V, log_target, eps, _BLOCK_SIZE_1, 1, num_warps=4, num_stages=3)
     if reduction == 'batchmean':
         final_loss = torch.sum(loss) / BT
     elif reduction == 'sum':
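The new launch drops the intermediate `kl_loss` allocation and the `helion.cdiv(BT, BT)` block-size computation: each program now covers a 4096-row block (`_BLOCK_SIZE_1 = 4096`) and walks the vocabulary dimension one element at a time (the literal `1` passed for `_BLOCK_SIZE_0`), so no mask along `V` is needed. A hypothetical call, assuming `[BT, V]` inputs on a CUDA device with `y_pred` holding log-probabilities; shapes and device are illustrative:

y_pred = torch.randn(8, 32000, device='cuda').log_softmax(dim=-1)  # log-probs
y_true = torch.randn(8, 32000, device='cuda').softmax(dim=-1)      # probs
loss = kl_div_forward(y_pred, y_true, log_target=False, reduction='batchmean')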