@@ -47,14 +47,16 @@ function overlayed_pycall_with_jax_tracing(f::Py, args...)
4747 return length (res) == 0 ? nothing : (length (res) == 1 ? res[1 ] : res)
4848end
4949
50- function normalize_grid_and_blocks (grid_fn, metadata)
51- return normalize_grid_and_blocks (grid_fn (metadata), metadata)
50+ function normalize_grid_and_blocks (grid_fn, metadata, device_properties)
51+ return normalize_grid_and_blocks (
52+ grid_fn (metadata, device_properties), metadata, device_properties
53+ )
5254end
5355
54- function normalize_grid_and_blocks (grid:: Integer , metadata)
55- return normalize_grid_and_blocks ((grid,), metadata)
56+ function normalize_grid_and_blocks (grid:: Integer , metadata, device_properties )
57+ return normalize_grid_and_blocks ((grid,), metadata, device_properties )
5658end
57- function normalize_grid_and_blocks (grid:: Dims{N} , metadata) where {N}
59+ function normalize_grid_and_blocks (grid:: Dims{N} , metadata, device_properties ) where {N}
5860 @assert N <= 3
5961 @assert all (grid .> 0 )
6062 return (grid... , ntuple (_ -> 1 , 3 - N)... )
@@ -71,8 +73,9 @@ function overlayed_pycall_with_triton(
7173 args... ;
7274 grid,
7375 blocks,
74- num_warps:: Integer = 1 ,
76+ num_warps:: Integer = 4 ,
7577 num_stages:: Integer = 3 ,
78+ num_ctas:: Integer = 1 ,
7679 hints= nothing ,
7780)
7881 triton = tritonptr[]
@@ -105,16 +108,23 @@ function overlayed_pycall_with_triton(
105108 fn= kernel, constexprs= constants, signature= sigmap, attrs= attrs
106109 )
107110
111+ # TODO : pass the device/client here from `compile`
112+ client = Reactant. XLA. default_backend ()
113+ @assert Reactant. XLA. platform_name (client) == " cuda"
114+ device = Reactant. XLA. default_device (client)
115+ device_properties = Reactant. XLA. device_properties (device)
116+
108117 target = triton. backends. compiler. GPUTarget (
109- " cuda " ,
110- parse (Int, Reactant . Compiler . cubinChip[][ 4 : end ] ),
111- Reactant . Compiler . cuWarpSize[] ,
118+ Reactant . XLA . platform_name (client) ,
119+ parse (Int, " $(device_properties . major)$(device_properties . minor) " ),
120+ device_properties . warp_size ,
112121 )
113122 backend = triton. compiler. make_backend (target)
114123 options = backend. parse_options (
115124 pydict (
116125 " num_warps" => num_warps,
117126 " num_stages" => num_stages,
127+ " num_ctas" => num_ctas,
118128 " extern_libs" => pytuple ((pytuple ((" libdevice" , Reactant_jll. libdevice)),)),
119129 ),
120130 )
@@ -123,8 +133,8 @@ function overlayed_pycall_with_triton(
123133 # we are compiling here + lowering again inside enzymejax
124134 ccinfo = triton. compile (src; target= target, options= options. __dict__)
125135
126- grid = normalize_grid_and_blocks (grid, ccinfo. metadata)
127- blocks = normalize_grid_and_blocks (blocks, ccinfo. metadata)
136+ grid = normalize_grid_and_blocks (grid, ccinfo. metadata, device_properties )
137+ blocks = normalize_grid_and_blocks (blocks, ccinfo. metadata, device_properties )
128138
129139 return @opcall triton_call (
130140 pyconvert (String, ccinfo. asm[" source" ]),
@@ -136,5 +146,8 @@ function overlayed_pycall_with_triton(
136146 block_x= @opcall (constant (blocks[1 ])),
137147 block_y= @opcall (constant (blocks[2 ])),
138148 block_z= @opcall (constant (blocks[3 ])),
149+ # The following are written to module attributes and restored later on
150+ num_ctas,
151+ num_warps,
139152 )
140153end
0 commit comments