examples/10_gemm_all_scatter_wg_specialization/benchmark.py (9 changes: 7 additions, 2 deletions)

@@ -136,7 +136,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
     total_blocks_N = triton.cdiv(args["n"], args["BLK_N"])
     total_tiles = total_blocks_M * total_blocks_N
 
-    locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int8)
+    locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
 
     bias = None
 
@@ -157,6 +157,11 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
     # Allocate Timestamps
     timestamps = Timestamps(num_tiles=total_tiles)
 
+    def preamble():
+        shmem.barrier()
+        locks.zero_()
+        shmem.barrier()
+
     def run_experiment():
         nonlocal local_C
         nonlocal global_C
@@ -244,7 +249,7 @@ def run_experiment():
     matmul.set_debug(False)
     shmem.info("Benchmarking...")
     perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3)
-    triton_ms = iris.do_bench(run_experiment, shmem.barrier)
+    triton_ms = iris.do_bench(run_experiment, shmem.barrier, preamble)
     triton_tflops = perf(triton_ms)
     algo_string = "all_scatter"
     shmem.info(
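For context: the new preamble re-zeroes the lock flags between timed iterations, with shmem.barrier() on both sides so no rank clears locks while a neighbour is still reading them. This guarantees every run of run_experiment starts from an unsignalled lock state, regardless of where the previous iteration left the flags. Below is a minimal sketch of how a do_bench-style harness might use such a hook; the harness name, the n_iters parameter, and the timing details are assumptions for illustration, not the actual iris.do_bench implementation.

```python
import time

def do_bench_sketch(fn, barrier, preamble=None, n_iters=10):
    # Hypothetical harness: reset shared state, synchronize all ranks, then time fn.
    times = []
    for _ in range(n_iters):
        if preamble is not None:
            preamble()          # e.g. barrier -> locks.zero_() -> barrier
        barrier()               # all ranks enter the timed region together
        t0 = time.perf_counter()
        fn()
        barrier()               # wait for every rank before stopping the clock
        times.append((time.perf_counter() - t0) * 1e3)  # milliseconds
    return sum(times) / len(times)
```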
@@ -142,8 +142,7 @@ def persistent_gemm_all_scatter_wg_specialization(
             tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp)
 
             tl.store(c_global + global_offset, c, mask=sub_mask, cache_modifier=".wt")
-            tl.debug_barrier()
-            tl.store(locks + tile_id, 1, cache_modifier=".wt")
+            tl.atomic_xchg(locks + tile_id, 1, sem="release", scope="gpu")
 
     else:  # pid >= GEMM_SMS
         COMM_SMS = NUM_SMS - GEMM_SMS
@@ -165,7 +164,7 @@ def persistent_gemm_all_scatter_wg_specialization(
             global_offset = rm[:, None] * stride_cm_global + (rn[None, :] + cur_rank * N) * stride_cn_global
             # End: masks/offset calculations.
 
-            while tl.load(locks + tile_id, cache_modifier=".cv", volatile=True) != 1:
+            while tl.atomic_xchg(locks + tile_id, 0, sem="acquire", scope="gpu") != 1:
                 pass
 
             for remote_rank in range(world_size):
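The kernel-side change replaces a plain flag store (ordered only by tl.debug_barrier()) and a volatile polling load with atomic exchanges carrying explicit release/acquire semantics, so the tile written with .wt is guaranteed visible before the flag is observed as 1. A self-contained sketch of that handshake follows; the kernel names, the two-launch structure, and the toy payload are invented for illustration (example 10 runs producer and consumer workgroups inside one kernel, example 11 splits them into two kernels), but the tl.atomic_xchg calls mirror the ones in this diff.

```python
# Minimal sketch of the release/acquire flag handshake used in this PR.
# producer_kernel / consumer_kernel and the payload are hypothetical.
import torch
import triton
import triton.language as tl


@triton.jit
def producer_kernel(data_ptr, locks_ptr):
    tile_id = tl.program_id(0)
    tl.store(data_ptr + tile_id, tile_id * 2)  # stand-in for the GEMM tile result
    # Release: the data store above must be visible before the flag reads as 1.
    tl.atomic_xchg(locks_ptr + tile_id, 1, sem="release", scope="gpu")


@triton.jit
def consumer_kernel(data_ptr, out_ptr, locks_ptr):
    tile_id = tl.program_id(0)
    # Acquire: spin until the flag is 1; the exchange also resets it to 0.
    while tl.atomic_xchg(locks_ptr + tile_id, 0, sem="acquire", scope="gpu") != 1:
        pass
    tl.store(out_ptr + tile_id, tl.load(data_ptr + tile_id) + 1)


n_tiles = 64
data = torch.zeros(n_tiles, device="cuda", dtype=torch.int32)
out = torch.zeros(n_tiles, device="cuda", dtype=torch.int32)
locks = torch.zeros(n_tiles, device="cuda", dtype=torch.int32)  # int32, as in this PR

producer_kernel[(n_tiles,)](data, locks)
consumer_kernel[(n_tiles,)](data, out, locks)
torch.cuda.synchronize()
print(out[:4])  # expected: 1, 3, 5, 7
```

Because the acquire side exchanges the flag back to 0 rather than merely loading it, each flag is consumed exactly once per run; the benchmark preamble above then re-zeroes the whole buffer to keep iterations independent. The move from torch.int8 to torch.int32 most likely reflects that GPU atomics operate on 32-bit (or wider) words.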
examples/11_gemm_all_scatter_producer_consumer/benchmark.py (9 changes: 7 additions, 2 deletions)

@@ -136,7 +136,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
     total_blocks_N = triton.cdiv(args["n"], args["BLK_N"])
     total_tiles = total_blocks_M * total_blocks_N
 
-    locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int8)
+    locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
 
     bias = None
 
@@ -166,6 +166,11 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
     # Allocate Timestamps
     timestamps = Timestamps(num_tiles=total_tiles)
 
+    def preamble():
+        shmem.barrier()
+        locks.zero_()
+        shmem.barrier()
+
     def run_experiment():
         nonlocal C
         nonlocal kernel_timing
@@ -275,7 +280,7 @@ def run_experiment():
     matmul.set_debug(False)
     shmem.info("Benchmarking...")
     perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3)
-    triton_ms = iris.do_bench(run_experiment, shmem.barrier)
+    triton_ms = iris.do_bench(run_experiment, shmem.barrier, preamble)
     triton_tflops = perf(triton_ms)
     algo_string = "all_scatter"
     shmem.info(
@@ -133,8 +133,7 @@ def persistent_gemm(
         tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp)
 
         tl.store(C + global_offset, c, mask=sub_mask, cache_modifier=".wt")
-        tl.debug_barrier()
-        tl.store(locks + tile_id, 1, cache_modifier=".wt")
+        tl.atomic_xchg(locks + tile_id, 1, sem="release", scope="gpu")
 
 
 @triton.jit()
@@ -185,7 +184,7 @@ def persistent_all_scatter(
         global_offset = rm[:, None] * stride_cm_global + (rn[None, :] + cur_rank * N) * stride_cn_global
         # End: masks/offset calculations.
 
-        while tl.load(locks + tile_id, cache_modifier=".cv", volatile=True) != 1:
+        while tl.atomic_xchg(locks + tile_id, 0, sem="acquire", scope="gpu") != 1:
            pass
 
        for remote_rank in range(world_size):
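In this producer/consumer variant the release side (persistent_gemm) and the acquire side (persistent_all_scatter) are separate kernels, so the flag handshake has to be correct across concurrently resident grids. The sem="release"/"acquire" pair together with scope="gpu" expresses that ordering directly, whereas the previous .wt store and volatile .cv load only bypassed caches and presumably needed tl.debug_barrier() to keep the flag store from being reordered ahead of the tile store. The handshake sketch after example 10's kernel diff applies here unchanged, just split across two launches.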