@@ -32,6 +32,7 @@ defmodule EXLA.Defn do
3232
3333 @ doc false
3434 def __stream__ ( key , input , acc , vars , fun , [ args ] , options ) do
35+ { debug? , options } = Keyword . pop ( options , :debug , false )
3536 { run_options , compile_options } = Keyword . pop ( options , :run_options , [ ] )
3637
3738 { client_name , compile_options } =
@@ -51,24 +52,26 @@ defmodule EXLA.Defn do
5152 comp_fun =
5253 & to_stream_computation ( client , input_length , acc_length , & 1 , & 2 , & 3 , & 4 , compile_options )
5354
54- { executable , used_inputs , { output , acc_output } , outfeed , extra , debug? } =
55+ { executable , used_inputs , { output , acc_output } , outfeed , input_typespecs } =
5556 compile (
5657 client ,
57- { :stream , key } ,
58+ key ,
5859 vars ,
5960 fun ,
6061 compile_options ,
6162 used_buffers ,
6263 used_inputs ,
6364 _stream = true ,
65+ debug? ,
6466 comp_fun
6567 )
6668
67- { input_typespecs , input_indexes } = extra
69+ # Now discard the infeed from used inputs, similar to how it is done to buffers.
70+ # Note we discard all lazy transfers too, as they are not possible with streams.
71+ used_inputs = for { i , nil } <- used_inputs , i >= used_buffers , do: { i , nil } , into: % { }
6872
69- # Also discard the stream inputs from used inputs, similar to how it is done to buffers
70- # Note we discard all lazy transfers too, as they are not possible with streams
71- used_inputs = Enum . sort ( for { i , nil } <- used_inputs , i >= used_buffers , do: i )
73+ # And capture the typespecs for the infeed.
74+ input_typespecs = Enum . take_while ( input_typespecs , fn { i , _ } -> i < input_length end )
7275
7376 # Execution of streams requires the coordination of
7477 # multiple processes which is outlined below.
@@ -120,7 +123,6 @@ defmodule EXLA.Defn do
120123 outfeed_pid ,
121124 input ,
122125 input_typespecs ,
123- input_indexes ,
124126 output ,
125127 output_typespecs ,
126128 acc_output
@@ -151,9 +153,6 @@ defmodule EXLA.Defn do
151153 { input_typespecs , used_typespecs } =
152154 Enum . split_while ( used_typespecs , fn { i , _ } -> i < input_length end )
153155
154- # Get all input indexes and shape
155- input_indexes = Enum . map ( input_typespecs , & elem ( & 1 , 0 ) )
156-
157156 # Drop all accumulator entries from used_typespecs as we will handle it separately.
158157 { acc_typespecs , used_typespecs } = Enum . split ( used_typespecs , acc_length )
159158
@@ -166,13 +165,10 @@ defmodule EXLA.Defn do
166165 # The input will be read as part of the infeed.
167166 acc_typespecs_l = Enum . map ( acc_typespecs , & elem ( & 1 , 1 ) )
168167 acc_typespec = List . to_tuple ( acc_typespecs_l )
169-
170168 flag_typespec = Typespec . tensor ( { :pred , 8 } , { } )
171169
172170 args = EXLA.MLIR.Function . get_arguments ( builder )
173-
174171 { token , [ flag ] } = Value . infeed ( root_token , [ flag_typespec ] )
175-
176172 init = [ flag , token | args ]
177173
178174 arg_typespecs = Enum . map ( init , & Value . get_typespec / 1 )
@@ -186,11 +182,9 @@ defmodule EXLA.Defn do
186182 { body_computation , [ _flag , token | args ] } = Function . push_region ( builder , arg_typespecs )
187183
188184 { acc , constant } = Enum . split ( args , acc_length )
189-
190- { indices , input_typespecs } = Enum . unzip ( input_typespecs )
185+ { input_indices , input_typespecs } = Enum . unzip ( input_typespecs )
191186 { token , input } = Value . infeed ( token , input_typespecs )
192-
193- input_params = Enum . zip ( indices , input )
187+ input_params = Enum . zip ( input_indices , input )
194188
195189 { % Outfeed { token: token } = outfeed , acc } =
196190 case expr do
@@ -226,9 +220,7 @@ defmodule EXLA.Defn do
226220
227221 # Emit the stream hook to signal loop output
228222 { token , [ flag ] } = Value . infeed ( token , [ flag_typespec ] )
229-
230223 Value . return ( flag . function , [ flag , token | acc ] ++ List . flatten ( constant ) )
231-
232224 Function . pop_region ( builder )
233225
234226 [ _flag , out_token | results ] = Value . while ( builder , pred_computation , body_computation , init )
@@ -238,8 +230,7 @@ defmodule EXLA.Defn do
238230
239231 outfeed = outfeed |> Outfeed . with_token ( out_token ) |> Outfeed . close ( builder )
240232 Value . func_return ( builder , output )
241-
242- { { input_typespecs , input_indexes } , outfeed }
233+ outfeed
243234 end
244235
245236 @ doc false
@@ -249,6 +240,7 @@ defmodule EXLA.Defn do
249240
250241 @ doc false
251242 def __compile__ ( key , vars , fun , options ) do
243+ { debug? , options } = Keyword . pop ( options , :debug , false )
252244 { run_options , compile_options } = Keyword . pop ( options , :run_options , [ ] )
253245
254246 { client_name , compile_options } =
@@ -258,8 +250,8 @@ defmodule EXLA.Defn do
258250
259251 callback = & to_root_computation ( & 1 , & 2 , & 3 , & 4 , Keyword . put ( compile_options , :client , client ) )
260252
261- { executable , used_inputs , outputs , outfeed , :ok , debug ?} =
262- compile ( client , key , vars , fun , compile_options , 0 , [ ] , _stream = false , callback )
253+ { executable , used_inputs , outputs , outfeed , _input_typespecs ?} =
254+ compile ( client , key , vars , fun , compile_options , 0 , [ ] , _stream = false , debug? , callback )
263255
264256 fn [ args ] ->
265257 { time , lock } =
@@ -306,10 +298,8 @@ defmodule EXLA.Defn do
306298
307299 { res , cache } = recur_flatten ( expr , state , new_cache ( outfeed ) )
308300 outfeed = cache |> get_outfeed ( ) |> Outfeed . close ( function )
309-
310301 Value . func_return ( function , res )
311-
312- { :ok , outfeed }
302+ outfeed
313303 end
314304
315305 defp maybe_outfeed ( lock , executable , args , used_inputs , outputs , outfeed , run_options )
@@ -367,6 +357,7 @@ defmodule EXLA.Defn do
367357 used_buffers ,
368358 used_inputs ,
369359 stream? ,
360+ debug? ,
370361 to_computation
371362 ) do
372363 { { expr_cache_fun , comp_cache_fun } , options } =
@@ -379,8 +370,6 @@ defmodule EXLA.Defn do
379370 { { cache_fun , cache_fun } , options }
380371 end
381372
382- { debug? , options } = Keyword . pop ( options , :debug , false )
383-
384373 { args_key , reverse_args_identifiers } =
385374 Enum . map_reduce ( vars , [ ] , fn var , acc ->
386375 Nx.Defn.Composite . traverse ( var , acc , fn
@@ -396,7 +385,7 @@ defmodule EXLA.Defn do
396385
397386 { eval_time , { expr , { ref , outputs , { used_inputs , defined_hooks } } } } =
398387 :timer . tc ( fn ->
399- expr_cache_fun . ( { key , args_key , lazy_transfers } , fn ->
388+ expr_cache_fun . ( { key , stream? , args_key , lazy_transfers } , fn ->
400389 expr = fun . ( vars )
401390 inputs_and_hooks = Outfeed . used_inputs_and_hooks ( expr , used_inputs , lazy_transfers )
402391 { expr , { make_ref ( ) , Nx . to_template ( expr ) , inputs_and_hooks } }
@@ -412,12 +401,10 @@ defmodule EXLA.Defn do
412401 end
413402
414403 { hooks , options } = Keyword . pop ( options , :hooks , % { } )
415-
416404 outfeed = Outfeed . new ( hooks , defined_hooks )
417-
418405 comp_key = { ref , client . name , outfeed . used_hooks , lazy_transfers , options }
419406
420- { comp_time , { evaled , { xla_time , executable , extra , outfeed } } } =
407+ { comp_time , { evaled , { xla_time , executable , inputs_and_typespecs , outfeed } } } =
421408 :timer . tc ( fn ->
422409 comp_cache_fun . ( comp_key , fn ->
423410 { reverse_inputs_and_typespecs , reverse_infeeds } =
@@ -430,7 +417,7 @@ defmodule EXLA.Defn do
430417
431418 inputs_and_typespecs = Enum . reverse ( reverse_inputs_and_typespecs )
432419
433- comp_arg_typespecs =
420+ comp_typespecs =
434421 for { i , typespec } <- inputs_and_typespecs , i >= used_buffers , do: typespec
435422
436423 outputs =
@@ -451,7 +438,7 @@ defmodule EXLA.Defn do
451438 |> then ( & Typespec . tensor ( & 1 . type , & 1 . shape ) )
452439 end )
453440
454- EXLA.MLIR.Module . new ( comp_arg_typespecs , out_typespecs , fn builder ->
441+ EXLA.MLIR.Module . new ( comp_typespecs , out_typespecs , fn builder ->
455442 # Only create the token when we know it will actually be
456443 # used, that is: streaming, lazy transfers or hooks
457444 outfeed =
@@ -464,25 +451,20 @@ defmodule EXLA.Defn do
464451 end
465452
466453 expr = Nx.Defn.Composite . traverse ( expr || fun . ( vars ) , & Nx . devectorize / 1 )
467-
468- { extra , outfeed } =
469- to_computation . ( builder , expr , inputs_and_typespecs , outfeed )
454+ outfeed = to_computation . ( builder , expr , inputs_and_typespecs , outfeed )
470455
471456 { xla_time , executable } =
472457 :timer . tc ( fn ->
473- typespecs =
474- for { i , typespec } <- inputs_and_typespecs , i >= used_buffers , do: typespec
475-
476458 EXLA.MLIR.Module . compile (
477459 builder . module ,
478460 client ,
479- typespecs ,
461+ comp_typespecs ,
480462 builder . return_typespecs ,
481463 options
482464 )
483465 end )
484466
485- { :ok , { xla_time , executable , extra , % { outfeed | infeeds: [ ] } } }
467+ { :ok , { xla_time , executable , inputs_and_typespecs , % { outfeed | infeeds: [ ] } } }
486468 end )
487469 end )
488470 end )
@@ -511,7 +493,7 @@ defmodule EXLA.Defn do
511493 end
512494
513495 outfeed = Outfeed . with_user_hooks ( outfeed , hooks )
514- { executable , used_inputs , outputs , outfeed , extra , debug? }
496+ { executable , used_inputs , outputs , outfeed , inputs_and_typespecs }
515497 end
516498
517499 defp us_to_ms ( time ) , do: Float . round ( time / 1000 , 1 )
0 commit comments