Skip to content

Commit 59aeb6b

Browse files
authored
[Gluon] Require warp_specialize default_args and worker_args be tuples (#8368)
I'm not a fan of using `isinstance` here as it leads to an inconsistency. Say we want to pass a single argument `x` that happens to be a tuple of tensors, then that has to be passed as `(x,)` or otherwise the tuple will be interpreted as a list of arguments to pass to the worker function. The original PR (#8269) justified this as improving the error message, so I also add a validation check with a better error message if you forget to pass a tuple.
1 parent 1888f81 commit 59aeb6b

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

python/triton/experimental/gluon/language/_core.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -509,10 +509,6 @@ def warp_specialize(default_args, default_partition, worker_args, worker_partiti
509509
"""
510510
worker_num_warps = [_unwrap_if_constexpr(w) for w in worker_num_warps]
511511
worker_num_regs = [_unwrap_if_constexpr(r) for r in worker_num_regs]
512-
if not isinstance(default_args, tuple):
513-
default_args = (default_args, )
514-
if not isinstance(worker_args, tuple):
515-
worker_args = (worker_args, )
516512
return _semantic.warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps,
517513
worker_num_regs, _generator)
518514

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,10 @@ def gather(self, src: TensorTy, index: TensorTy, axis: int) -> TensorTy:
420420
def warp_specialize(self, default_args, default_partition, worker_args, worker_partitions,
421421
worker_num_warps: Sequence[int], worker_num_regs: Sequence[int], generator):
422422
num_partitions = len(worker_partitions)
423+
_check(isinstance(default_args, (tuple, ttgl.tuple)),
424+
lambda: f"default_args must be a tuple of arguments, but got {type(default_args)}")
425+
_check(isinstance(worker_args, (tuple, ttgl.tuple)),
426+
lambda: f"worker_args must be a tuple of arguments, but got {type(worker_args)}")
423427
assert num_partitions == len(
424428
worker_num_warps
425429
), f"warp specialize got {num_partitions} partitions but {len(worker_num_warps)} warp counts"

0 commit comments

Comments
 (0)