@@ -60,6 +60,7 @@ defmodule EXLA.Defn do
6060 compile_options ,
6161 used_buffers ,
6262 used_inputs ,
63+ _stream = true ,
6364 comp_fun
6465 )
6566
@@ -258,7 +259,7 @@ defmodule EXLA.Defn do
258259 callback = & to_root_computation ( & 1 , & 2 , & 3 , & 4 , Keyword . put ( compile_options , :client , client ) )
259260
260261 { executable , used_inputs , outputs , outfeed , :ok , debug? } =
261- compile ( client , key , vars , fun , compile_options , 0 , [ ] , callback )
262+ compile ( client , key , vars , fun , compile_options , 0 , [ ] , _stream = false , callback )
262263
263264 fn [ args ] ->
264265 { time , lock } =
@@ -357,7 +358,17 @@ defmodule EXLA.Defn do
357358
358359 ## Compile
359360
360- defp compile ( client , key , vars , fun , options , used_buffers , used_inputs , to_computation ) do
361+ defp compile (
362+ client ,
363+ key ,
364+ vars ,
365+ fun ,
366+ options ,
367+ used_buffers ,
368+ used_inputs ,
369+ stream? ,
370+ to_computation
371+ ) do
361372 { { expr_cache_fun , comp_cache_fun } , options } =
362373 case Keyword . pop ( options , :cache , true ) do
363374 { true , options } ->
@@ -385,7 +396,7 @@ defmodule EXLA.Defn do
385396
386397 { eval_time , { expr , { ref , outputs , { used_inputs , defined_hooks } } } } =
387398 :timer . tc ( fn ->
388- expr_cache_fun . ( { key , args_key } , fn ->
399+ expr_cache_fun . ( { key , args_key , lazy_transfers } , fn ->
389400 expr = fun . ( vars )
390401 inputs_and_hooks = Outfeed . used_inputs_and_hooks ( expr , used_inputs , lazy_transfers )
391402 { expr , { make_ref ( ) , Nx . to_template ( expr ) , inputs_and_hooks } }
@@ -432,10 +443,16 @@ defmodule EXLA.Defn do
432443 end )
433444
434445 EXLA.MLIR.Module . new ( comp_arg_typespecs , out_typespecs , fn builder ->
446+ # Only create the token when we know it will actually be
447+ # used, that is: streaming, lazy transfers or hooks
435448 outfeed =
436- outfeed
437- |> Outfeed . with_token ( Value . create_token ( builder ) )
438- |> Outfeed . add_infeeds ( builder , reverse_infeeds )
449+ if stream? or reverse_infeeds != [ ] or hooks != % { } or defined_hooks != % { } do
450+ outfeed
451+ |> Outfeed . with_token ( Value . create_token ( builder ) )
452+ |> Outfeed . add_infeeds ( builder , reverse_infeeds )
453+ else
454+ outfeed
455+ end
439456
440457 expr = Nx.Defn.Composite . traverse ( expr || fun . ( vars ) , & Nx . devectorize / 1 )
441458
@@ -520,19 +537,30 @@ defmodule EXLA.Defn do
520537 cache
521538 ) do
522539 [ initial_arg , _arg , pred , body ] = args
523- initial_with_token = { get_token ( cache ) , initial_arg }
524540
525- { initial , cache } = recur_composite ( initial_with_token , state , cache )
541+ initial =
542+ if token = get_token ( cache ) do
543+ { token , initial_arg }
544+ else
545+ initial_arg
546+ end
547+
548+ { initial , cache } = recur_composite ( initial , state , cache )
526549
527550 { pred_computation , cache } = mlir_while_computation ( pred , initial , { :pred , 8 } , state , cache )
528551 { body_computation , cache } = mlir_while_computation ( body , initial , :with_token , state , cache )
529552
530- [ token | results ] =
553+ results =
531554 Value . while ( function , pred_computation , body_computation , List . flatten ( initial ) )
532555
533- result = wrap_tuple_result ( results , initial_arg )
534-
535- { result , update_token ( cache , token ) }
556+ if get_token ( cache ) do
557+ [ token | results ] = results
558+ result = wrap_tuple_result ( results , initial_arg )
559+ { result , update_token ( cache , token ) }
560+ else
561+ result = wrap_tuple_result ( results , initial_arg )
562+ { result , cache }
563+ end
536564 end
537565
538566 defp cached_recur_operator ( :cond , % T { data: % Expr { args: args } } = t , state , cache ) do
@@ -688,16 +716,19 @@ defmodule EXLA.Defn do
688716 { computation , cache }
689717
690718 % { } ->
691- { computation , cache } = token_computation ( "optional" , call_args , expr , state , cache )
719+ { computation , cache } = optional_computation ( "optional" , call_args , expr , state , cache )
692720 { computation , Map . put ( cache , key , computation ) }
693721 end
694722
695- typespecs = [ Typespec . token ( ) | container_to_typespecs ( expr ) ]
696-
697- [ token | result ] =
698- Value . call ( state . builder , [ get_token ( cache ) | call_args ] , call_body , typespecs )
699-
700- { wrap_tuple_result ( result , expr ) , update_token ( cache , token ) }
723+ if token = get_token ( cache ) do
724+ typespecs = [ Typespec . token ( ) | container_to_typespecs ( expr ) ]
725+ [ token | result ] = Value . call ( state . builder , [ token | call_args ] , call_body , typespecs )
726+ { wrap_tuple_result ( result , expr ) , update_token ( cache , token ) }
727+ else
728+ typespecs = container_to_typespecs ( expr )
729+ result = Value . call ( state . builder , call_args , call_body , typespecs )
730+ { wrap_tuple_result ( result , expr ) , cache }
731+ end
701732 end
702733
703734 defp cached_recur_operator ( :attach_token , % T { data: % Expr { args: [ token , expr ] } } , state , cache ) do
@@ -1553,7 +1584,17 @@ defmodule EXLA.Defn do
15531584 defp mlir_while_computation ( expr , initial , type , state , cache ) do
15541585 arg_typespecs = Enum . map ( List . flatten ( initial ) , & Value . get_typespec / 1 )
15551586
1556- { region , [ arg_token | arg_params ] } = Function . push_region ( state . builder , arg_typespecs )
1587+ { region , args } = Function . push_region ( state . builder , arg_typespecs )
1588+
1589+ outer_token = get_token ( cache )
1590+
1591+ { inner_token , arg_params } =
1592+ if outer_token do
1593+ [ arg_token | arg_params ] = args
1594+ { arg_token , arg_params }
1595+ else
1596+ { nil , args }
1597+ end
15571598
15581599 params = Enum . with_index ( arg_params , & { & 2 , & 1 } )
15591600
@@ -1570,11 +1611,15 @@ defmodule EXLA.Defn do
15701611 expr
15711612 end
15721613
1573- { res , comp_cache } = recur_composite ( expr , & & 1 , state , reset_token ( cache , arg_token ) )
1614+ { res , comp_cache } = recur_composite ( expr , & & 1 , state , reset_token ( cache , inner_token ) )
15741615
15751616 res =
15761617 if type == :with_token do
1577- [ get_token ( comp_cache ) | List . flatten ( res ) ]
1618+ if outer_token do
1619+ [ get_token ( comp_cache ) | List . flatten ( res ) ]
1620+ else
1621+ List . flatten ( res )
1622+ end
15781623 else
15791624 Enum . map ( res , & to_type ( & 1 , type ) )
15801625 end
@@ -1585,21 +1630,34 @@ defmodule EXLA.Defn do
15851630 { region , merge_outfeed ( cache , comp_cache ) }
15861631 end
15871632
1588- defp token_computation ( name , args , expr , % { builder: % Function { } } = state , cache ) do
1633+ defp optional_computation ( name , args , expr , % { builder: % Function { } } = state , cache ) do
15891634 % Function { module: module , name: name } = subbuilder ( state . builder , name )
15901635
1591- token_typespec = Typespec . token ( )
15921636 arg_typespecs = Enum . map ( args , & Value . get_typespec / 1 )
15931637 out_typespecs = container_to_typespecs ( expr )
15941638
1595- function =
1596- EXLA.MLIR.Module . add_function ( module , name , [ token_typespec | arg_typespecs ] , [
1597- token_typespec | out_typespecs
1598- ] )
1639+ outer_token = get_token ( cache )
1640+ token_typespec = Typespec . token ( )
1641+
1642+ { arg_typespecs , out_typespecs } =
1643+ if outer_token do
1644+ { [ token_typespec | arg_typespecs ] , [ token_typespec | out_typespecs ] }
1645+ else
1646+ { arg_typespecs , out_typespecs }
1647+ end
15991648
1600- [ arg_token | tail ] = EXLA.MLIR.Function . get_arguments ( function )
1649+ function = EXLA.MLIR.Module . add_function ( module , name , arg_typespecs , out_typespecs )
1650+ args = EXLA.MLIR.Function . get_arguments ( function )
16011651
1602- params = Enum . with_index ( tail , fn param , i -> { i , param } end )
1652+ { inner_token , args } =
1653+ if outer_token do
1654+ [ arg_token | args ] = args
1655+ { arg_token , args }
1656+ else
1657+ { nil , args }
1658+ end
1659+
1660+ params = Enum . with_index ( args , fn param , i -> { i , param } end )
16031661
16041662 state = % {
16051663 state
@@ -1608,9 +1666,13 @@ defmodule EXLA.Defn do
16081666 scope_ids: Tree . scope_ids ( expr )
16091667 }
16101668
1611- { res , comp_cache } = recur_composite ( expr , state , reset_token ( cache , arg_token ) )
1669+ { res , comp_cache } = recur_composite ( expr , state , reset_token ( cache , inner_token ) )
16121670
1613- Value . return ( function , [ get_token ( comp_cache ) | List . flatten ( res ) ] )
1671+ if outer_token do
1672+ Value . return ( function , [ get_token ( comp_cache ) | List . flatten ( res ) ] )
1673+ else
1674+ Value . return ( function , List . flatten ( res ) )
1675+ end
16141676
16151677 { function , merge_outfeed ( cache , comp_cache ) }
16161678 end
@@ -1786,10 +1848,10 @@ defmodule EXLA.Defn do
17861848
17871849 out_typespecs = container_to_typespecs ( on_true )
17881850
1789- in_token = get_token ( cache )
1851+ outer_token = get_token ( cache )
17901852
17911853 result_typespecs =
1792- if in_token do
1854+ if outer_token do
17931855 [ Typespec . token ( ) | out_typespecs ]
17941856 else
17951857 out_typespecs
@@ -1799,7 +1861,7 @@ defmodule EXLA.Defn do
17991861 { false_computation , cache } = to_mlir_if_branch ( on_false , false_ids , state , cache )
18001862 if_results = Value . if_op ( pred_op , true_computation , false_computation , result_typespecs )
18011863
1802- if in_token do
1864+ if outer_token do
18031865 [ token | results ] = if_results
18041866 { wrap_tuple_result ( results , on_true ) , update_token ( cache , token ) }
18051867 else
0 commit comments