@@ -52,7 +52,7 @@ defmodule EXLA.Defn do
5252 comp_fun =
5353 & to_stream_computation ( client , input_length , acc_length , & 1 , & 2 , & 3 , & 4 , compile_options )
5454
55- { executable , used_inputs , { output , acc_output } , outfeed , input_typespecs } =
55+ { executable , { used_inputs , { output , acc_output } , outfeed , input_typespecs } } =
5656 compile (
5757 client ,
5858 key ,
@@ -84,9 +84,7 @@ defmodule EXLA.Defn do
8484 EXLA.Defn.Lock . lock ( run_key ( executable ) )
8585 end )
8686
87- if debug? do
88- Logger . debug ( "EXLA device #{ executable . device_id } lock in #{ us_to_ms ( time ) } ms" )
89- end
87+ debug? && Logger . debug ( "EXLA device #{ executable . device_id } lock in #{ us_to_ms ( time ) } ms" )
9088
9189 { time , streams } =
9290 :timer . tc ( fn ->
@@ -131,9 +129,8 @@ defmodule EXLA.Defn do
131129 [ stream ]
132130 end )
133131
134- if debug? do
132+ debug? &&
135133 Logger . debug ( "EXLA stream start on device #{ executable . device_id } in #{ us_to_ms ( time ) } ms" )
136- end
137134
138135 streams
139136 end
@@ -250,7 +247,7 @@ defmodule EXLA.Defn do
250247
251248 callback = & to_root_computation ( & 1 , & 2 , & 3 , & 4 , Keyword . put ( compile_options , :client , client ) )
252249
253- { executable , used_inputs , outputs , outfeed , _input_typespecs? } =
250+ { executable , { used_inputs , outputs , outfeed , _input_typespecs? } } =
254251 compile ( client , key , vars , fun , compile_options , 0 , [ ] , _stream = false , debug? , callback )
255252
256253 fn [ args ] ->
@@ -259,18 +256,15 @@ defmodule EXLA.Defn do
259256 EXLA.Defn.Lock . lock ( run_key ( executable ) )
260257 end )
261258
262- if debug? do
263- Logger . debug ( "EXLA device #{ executable . device_id } lock in #{ us_to_ms ( time ) } ms" )
264- end
259+ debug? && Logger . debug ( "EXLA device #{ executable . device_id } lock in #{ us_to_ms ( time ) } ms" )
265260
266261 { time , res } =
267262 :timer . tc ( fn ->
268263 maybe_outfeed ( lock , executable , args , used_inputs , outputs , outfeed , run_options )
269264 end )
270265
271- if debug? do
266+ debug? &&
272267 Logger . debug ( "EXLA execution on device #{ executable . device_id } in #{ us_to_ms ( time ) } ms" )
273- end
274268
275269 res
276270 end
@@ -360,15 +354,9 @@ defmodule EXLA.Defn do
360354 debug? ,
361355 to_computation
362356 ) do
363- { { expr_cache_fun , comp_cache_fun } , options } =
364- case Keyword . pop ( options , :cache , true ) do
365- { true , options } ->
366- Keyword . pop ( options , EXLA , { & EXLA.Defn.LockedCache . run / 2 , & EXLA.Defn.LockedCache . run / 2 } )
367-
368- { false , options } ->
369- cache_fun = fn _key , fun -> fun . ( ) end
370- { { cache_fun , cache_fun } , options }
371- end
357+ { cache , options } = Keyword . pop ( options , :cache , true )
358+ { hooks , options } = Keyword . pop ( options , :hooks , % { } )
359+ { lazy_transfers , options } = Keyword . pop ( options , :lazy_transfers , :opt_in )
372360
373361 { args_key , reverse_args_identifiers } =
374362 Enum . map_reduce ( vars , [ ] , fn var , acc ->
@@ -381,119 +369,134 @@ defmodule EXLA.Defn do
381369 end )
382370 end )
383371
384- { lazy_transfers , options } = Keyword . pop ( options , :lazy_transfers , :opt_in )
372+ disk_key = % {
373+ client: client . name ,
374+ args: args_key ,
375+ lazy_transfers: lazy_transfers ,
376+ hooks: Map . keys ( hooks ) ,
377+ options: options
378+ }
385379
386- { eval_time , { expr , { ref , outputs , { used_inputs , defined_hooks } } } } =
387- :timer . tc ( fn ->
388- expr_cache_fun . ( { key , stream? , args_key , lazy_transfers } , fn ->
389- expr = fun . ( vars )
390- inputs_and_hooks = Outfeed . used_inputs_and_hooks ( expr , used_inputs , lazy_transfers )
391- { expr , { make_ref ( ) , Nx . to_template ( expr ) , inputs_and_hooks } }
380+ EXLA.Defn.Disk . cache ( cache , client , disk_key , debug? , fn ->
381+ { { expr_cache_fun , comp_cache_fun } , options } =
382+ if cache do
383+ Keyword . pop ( options , EXLA , { & EXLA.Defn.LockedCache . run / 2 , & EXLA.Defn.LockedCache . run / 2 } )
384+ else
385+ cache_fun = fn _key , fun -> fun . ( ) end
386+ { { cache_fun , cache_fun } , Keyword . delete ( options , EXLA ) }
387+ end
388+
389+ { eval_time , { expr , { ref , outputs , { used_inputs , defined_hooks } } } } =
390+ :timer . tc ( fn ->
391+ expr_cache_fun . ( { key , stream? , args_key , lazy_transfers } , fn ->
392+ expr = fun . ( vars )
393+ inputs_and_hooks = Outfeed . used_inputs_and_hooks ( expr , used_inputs , lazy_transfers )
394+ { expr , { make_ref ( ) , Nx . to_template ( expr ) , inputs_and_hooks } }
395+ end )
392396 end )
393- end )
394397
395- if debug? do
396- hit_or_miss = if expr , do: "miss" , else: "hit"
398+ if debug? do
399+ hit_or_miss = if expr , do: "miss" , else: "hit"
397400
398- Logger . debug (
399- "EXLA defn evaluation #{ inspect ( key ) } cache #{ hit_or_miss } in #{ us_to_ms ( eval_time ) } ms"
400- )
401- end
401+ Logger . debug (
402+ "EXLA defn evaluation #{ inspect ( key ) } cache #{ hit_or_miss } in #{ us_to_ms ( eval_time ) } ms"
403+ )
404+ end
402405
403- { hooks , options } = Keyword . pop ( options , :hooks , % { } )
404- outfeed = Outfeed . new ( hooks , defined_hooks )
405- comp_key = { ref , client . name , outfeed . used_hooks , lazy_transfers , options }
406+ outfeed = Outfeed . new ( hooks , defined_hooks )
407+ comp_key = { ref , client . name , outfeed . used_hooks , lazy_transfers , options }
406408
407- { comp_time , { evaled , { xla_time , executable , inputs_and_typespecs , outfeed } } } =
408- :timer . tc ( fn ->
409- comp_cache_fun . ( comp_key , fn ->
410- { reverse_inputs_and_typespecs , reverse_infeeds } =
411- reverse_args_identifiers
412- |> Enum . reverse ( )
413- |> EXLA.Defn.Buffers . split_by_value ( used_inputs , fn
414- { type , shape , _names } , i , nil -> { i , Typespec . tensor ( type , shape ) }
415- { type , shape , _names } , i , depth -> { i , depth , Typespec . tensor ( type , shape ) }
416- end )
409+ { comp_time , { evaled , { xla_time , executable , inputs_and_typespecs , outfeed } } } =
410+ :timer . tc ( fn ->
411+ comp_cache_fun . ( comp_key , fn ->
412+ { reverse_inputs_and_typespecs , reverse_infeeds } =
413+ reverse_args_identifiers
414+ |> Enum . reverse ( )
415+ |> EXLA.Defn.Buffers . split_by_value ( used_inputs , fn
416+ { type , shape , _names } , i , nil -> { i , Typespec . tensor ( type , shape ) }
417+ { type , shape , _names } , i , depth -> { i , depth , Typespec . tensor ( type , shape ) }
418+ end )
417419
418- inputs_and_typespecs = Enum . reverse ( reverse_inputs_and_typespecs )
419-
420- comp_typespecs =
421- for { i , typespec } <- inputs_and_typespecs , i >= used_buffers , do: typespec
422-
423- outputs =
424- if stream? do
425- # The computation returns the final accumulator value
426- { _chunk_result , acc } = outputs
427- acc
428- else
429- outputs
430- end
431-
432- out_typespecs =
433- [ outputs ]
434- |> Nx.Defn.Composite . flatten_list ( )
435- |> Enum . map ( fn t ->
436- t
437- |> Nx . devectorize ( )
438- |> then ( & Typespec . tensor ( & 1 . type , & 1 . shape ) )
439- end )
420+ inputs_and_typespecs = Enum . reverse ( reverse_inputs_and_typespecs )
421+
422+ comp_typespecs =
423+ for { i , typespec } <- inputs_and_typespecs , i >= used_buffers , do: typespec
440424
441- EXLA.MLIR.Module . new ( comp_typespecs , out_typespecs , fn builder ->
442- # Only create the token when we know it will actually be
443- # used, that is: streaming, lazy transfers or hooks
444- outfeed =
445- if stream? or reverse_infeeds != [ ] or hooks != % { } or defined_hooks != % { } do
446- outfeed
447- |> Outfeed . with_token ( Value . create_token ( builder ) )
448- |> Outfeed . add_infeeds ( builder , reverse_infeeds )
425+ outputs =
426+ if stream? do
427+ # The computation returns the final accumulator value
428+ { _chunk_result , acc } = outputs
429+ acc
449430 else
450- outfeed
431+ outputs
451432 end
452433
453- expr = Nx.Defn.Composite . traverse ( expr || fun . ( vars ) , & Nx . devectorize / 1 )
454- outfeed = to_computation . ( builder , expr , inputs_and_typespecs , outfeed )
455-
456- { xla_time , executable } =
457- :timer . tc ( fn ->
458- EXLA.MLIR.Module . compile (
459- builder . module ,
460- client ,
461- comp_typespecs ,
462- builder . return_typespecs ,
463- options
464- )
434+ out_typespecs =
435+ [ outputs ]
436+ |> Nx.Defn.Composite . flatten_list ( )
437+ |> Enum . map ( fn t ->
438+ t
439+ |> Nx . devectorize ( )
440+ |> then ( & Typespec . tensor ( & 1 . type , & 1 . shape ) )
465441 end )
466442
467- { :ok , { xla_time , executable , inputs_and_typespecs , % { outfeed | infeeds: [ ] } } }
443+ EXLA.MLIR.Module . new ( comp_typespecs , out_typespecs , fn builder ->
444+ # Only create the token when we know it will actually be
445+ # used, that is: streaming, lazy transfers or hooks
446+ outfeed =
447+ if stream? or reverse_infeeds != [ ] or hooks != % { } or defined_hooks != % { } do
448+ outfeed
449+ |> Outfeed . with_token ( Value . create_token ( builder ) )
450+ |> Outfeed . add_infeeds ( builder , reverse_infeeds )
451+ else
452+ outfeed
453+ end
454+
455+ expr = Nx.Defn.Composite . traverse ( expr || fun . ( vars ) , & Nx . devectorize / 1 )
456+ outfeed = to_computation . ( builder , expr , inputs_and_typespecs , outfeed )
457+
458+ { xla_time , executable } =
459+ :timer . tc ( fn ->
460+ EXLA.MLIR.Module . compile (
461+ builder . module ,
462+ client ,
463+ comp_typespecs ,
464+ builder . return_typespecs ,
465+ options
466+ )
467+ end )
468+
469+ { :ok , { xla_time , executable , inputs_and_typespecs , % { outfeed | infeeds: [ ] } } }
470+ end )
468471 end )
469472 end )
470- end )
471473
472- cond do
473- not debug? ->
474- :ok
474+ cond do
475+ not debug? ->
476+ :ok
475477
476- evaled ->
477- Logger . debug (
478- "EXLA compilation #{ inspect ( key ) } cache miss in #{ us_to_ms ( comp_time ) } ms (#{ us_to_ms ( xla_time ) } ms in XLA)"
479- )
478+ evaled ->
479+ Logger . debug (
480+ "EXLA compilation #{ inspect ( key ) } cache miss in #{ us_to_ms ( comp_time ) } ms (#{ us_to_ms ( xla_time ) } ms in XLA)"
481+ )
480482
481- true ->
482- Logger . debug ( "EXLA compilation #{ inspect ( key ) } cache hit in #{ us_to_ms ( comp_time ) } ms" )
483- end
483+ true ->
484+ Logger . debug ( "EXLA compilation #{ inspect ( key ) } cache hit in #{ us_to_ms ( comp_time ) } ms" )
485+ end
484486
485- if expr || evaled do
486- measurements = % {
487- eval_time: eval_time ,
488- compile_time: comp_time ,
489- total_time: eval_time + comp_time
490- }
487+ if expr || evaled do
488+ measurements = % {
489+ eval_time: eval_time ,
490+ compile_time: comp_time ,
491+ total_time: eval_time + comp_time
492+ }
491493
492- :telemetry . execute ( [ :exla , :compilation ] , measurements , % { key: key } )
493- end
494+ :telemetry . execute ( [ :exla , :compilation ] , measurements , % { key: key } )
495+ end
494496
495- outfeed = Outfeed . with_user_hooks ( outfeed , hooks )
496- { executable , used_inputs , outputs , outfeed , inputs_and_typespecs }
497+ outfeed = Outfeed . with_user_hooks ( outfeed , hooks )
498+ { executable , { used_inputs , outputs , outfeed , inputs_and_typespecs } }
499+ end )
497500 end
498501
# Converts a duration in microseconds (the unit of the :timer.tc results passed in by callers) to milliseconds, rounded to one decimal place for log output.
499502 defp us_to_ms ( time ) , do: Float . round ( time / 1000 , 1 )
0 commit comments