@@ -1580,7 +1580,8 @@ function compile_mlir!(
15801580 args,
15811581 compile_options:: CompileOptions ,
15821582 callcache= default_callcache (),
1583- sdycache= default_sdycache ();
1583+ sdycache= default_sdycache (),
1584+ sdygroupidcache= default_sdygroupidcache ();
15841585 fn_kwargs= (),
15851586 backend= " gpu" ,
15861587 runtime:: Union{Val{:PJRT},Val{:IFRT}} ,
@@ -1597,6 +1598,7 @@ function compile_mlir!(
15971598 MLIR. IR. activate! (MLIR. IR. body (mod))
15981599 activate_callcache! (callcache)
15991600 activate_sdycache! (sdycache)
1601+ activate_sdygroupidcache! (sdygroupidcache)
16001602
16011603 # Save in the TLS whether we are raising. We identify that condition by
16021604 # checking whether the user set an explicit list of passes, or chose
@@ -1623,6 +1625,7 @@ function compile_mlir!(
16231625 finally
16241626 deactivate_raising! (is_raising)
16251627 deactivate_sdycache! (sdycache)
1628+ deactivate_sdygroupidcache! (sdygroupidcache)
16261629 deactivate_callcache! (callcache)
16271630 MLIR. IR. deactivate! (MLIR. IR. body (mod))
16281631 MLIR. IR. deactivate! (mod)
@@ -3832,7 +3835,7 @@ function register_thunk(
38323835 )
38333836end
38343837
3835- for cache_type in (:callcache , :sdycache )
3838+ for cache_type in (:callcache , :sdycache , :sdygroupidcache )
38363839 activate_fn = Symbol (:activate_ , cache_type, :! )
38373840 deactivate_fn = Symbol (:deactivate_ , cache_type, :! )
38383841 has_fn = Symbol (:_has_ , cache_type)
@@ -3879,6 +3882,14 @@ function default_sdycache()
38793882 }()
38803883end
38813884
3885+ mutable struct SdyGroupIDCounter{T}
3886+ @atomic group_id:: T
3887+ end
3888+
3889+ function default_sdygroupidcache ()
3890+ return SdyGroupIDCounter {Int} (0 ), Base. IdDict {Union{TracedRArray,TracedRNumber},Int} ()
3891+ end
3892+
38823893function default_callcache ()
38833894 return Dict{
38843895 Vector,
0 commit comments