@@ -326,8 +326,8 @@ class Function:
326326 def __init__ (
327327 self ,
328328 vm : "VM" ,
329- input_storage ,
330- output_storage ,
329+ input_storage : list [ Container ] ,
330+ output_storage : list [ Container ] ,
331331 indices ,
332332 outputs ,
333333 defaults ,
@@ -372,7 +372,6 @@ def __init__(
372372 name
373373 A string name.
374374 """
375- # TODO: Rename to `vm`
376375 self .vm = vm
377376 self .input_storage = input_storage
378377 self .output_storage = output_storage
@@ -388,31 +387,49 @@ def __init__(
388387 self .nodes_with_inner_function = []
389388 self .output_keys = output_keys
390389
391- # See if we have any mutable / borrow inputs
392- # TODO: this only need to be set if there is more than one input
393- self ._check_for_aliased_inputs = False
394- for i in maker .inputs :
395- # If the input is a shared variable, the memory region is
396- # under PyTensor control and so we don't need to check if it
397- # is aliased as we never do that.
398- if (
399- isinstance (i , In )
400- and not i .shared
401- and (getattr (i , "borrow" , False ) or getattr (i , "mutable" , False ))
390+ assert len (self .input_storage ) == len (self .maker .fgraph .inputs )
391+ assert len (self .output_storage ) == len (self .maker .fgraph .outputs )
392+
393+ # Group indexes of inputs that are potentially aliased to each other
394+ # Note: Historically, we only worried about aliasing inputs if they belonged to the same type,
395+ # even though there could be two distinct types that use the same kinds of underlying objects.
396+ potential_aliased_input_groups = []
397+ for inp in maker .inputs :
398+ # If the input is a shared variable, the memory region is under PyTensor control
399+ # and can't be aliased.
400+ if not (
401+ isinstance (inp , In )
402+ and inp .borrow
403+ and not inp .shared
404+ and hasattr (inp .variable .type , "may_share_memory" )
402405 ):
403- self ._check_for_aliased_inputs = True
404- break
406+ continue
407+
408+ for group in potential_aliased_input_groups :
409+ # If one is super of the other, that means one could be replaced by the other
410+ if any (
411+ inp .variable .type .is_super (other_inp .variable .type )
412+ or other_inp .variable .type .is_super (inp .variable .type )
413+ for other_inp in group
414+ ):
415+ group .append (inp )
416+ break
417+ else : # no break
418+ # Input makes a new group
419+ potential_aliased_input_groups .append ([inp ])
420+
421+ # Potential aliased inputs are those that belong to the same group
422+ self ._potential_aliased_input_groups : tuple [tuple [int , ...], ...] = tuple (
423+ tuple (maker .inputs .index (inp ) for inp in group )
424+ for group in potential_aliased_input_groups
425+ if len (group ) > 1
426+ )
405427
406428 # We will be popping stuff off this `containers` object. It is a copy.
407429 containers = list (self .input_storage )
408430 finder = {}
409431 inv_finder = {}
410432
411- def distribute (indices , cs , value ):
412- input .distribute (value , indices , cs )
413- for c in cs :
414- c .provided += 1
415-
416433 # Store the list of names of named inputs.
417434 named_inputs = []
418435 # Count the number of un-named inputs.
@@ -777,6 +794,13 @@ def checkSV(sv_ori, sv_rpl):
777794 f_cpy .maker .fgraph .name = name
778795 return f_cpy
779796
def _restore_defaults(self):
    """Write refeedable default values back into the function's inputs.

    Iterates over ``self.defaults``, whose entries unpack as
    ``(required, refeed, value)``. For each entry whose ``refeed`` flag is
    truthy, the default is assigned to ``self[position]``; a default that is
    a ``Container`` is first unwrapped to ``value.storage[0]``.
    """
    # NOTE(review): block nesting was reconstructed from a flattened diff;
    # the assignment is assumed to live inside the ``refeed`` branch - confirm.
    for position, (_required, refeed, default) in enumerate(self.defaults):
        if not refeed:
            continue
        # Hand the raw value, not the wrapping Container, to the input slot.
        restored = default.storage[0] if isinstance(default, Container) else default
        self[position] = restored
780804 def __call__ (self , * args , ** kwargs ):
781805 """
782806 Evaluates value of a function on given arguments.
@@ -805,52 +829,43 @@ def __call__(self, *args, **kwargs):
805829 List of outputs on indices/keys from ``output_subset`` or all of them,
806830 if ``output_subset`` is not passed.
807831 """
808-
809- def restore_defaults ():
810- for i , (required , refeed , value ) in enumerate (self .defaults ):
811- if refeed :
812- if isinstance (value , Container ):
813- value = value .storage [0 ]
814- self [i ] = value
815-
832+ input_storage = self .input_storage
816833 profile = self .profile
817- t0 = time .perf_counter ()
834+
835+ if profile :
836+ t0 = time .perf_counter ()
818837
819838 output_subset = kwargs .pop ("output_subset" , None )
820839 if output_subset is not None and self .output_keys is not None :
821840 output_subset = [self .output_keys .index (key ) for key in output_subset ]
822841
823842 # Reinitialize each container's 'provided' counter
824843 if self .trust_input :
825- i = 0
826- for arg in args :
827- s = self .input_storage [i ]
828- s .storage [0 ] = arg
829- i += 1
844+ for arg_container , arg in zip (input_storage , args , strict = False ):
845+ arg_container .storage [0 ] = arg
830846 else :
831- for c in self . input_storage :
832- c .provided = 0
847+ for arg_container in input_storage :
848+ arg_container .provided = 0
833849
834- if len (args ) + len (kwargs ) > len (self . input_storage ):
850+ if len (args ) + len (kwargs ) > len (input_storage ):
835851 raise TypeError ("Too many parameter passed to pytensor function" )
836852
837853 # Set positional arguments
838- i = 0
839- for arg in args :
840- # TODO: provide a option for skipping the filter if we really
841- # want speed.
842- s = self .input_storage [i ]
843- # see this emails for a discuation about None as input
854+ for arg_container , arg in zip (input_storage , args , strict = False ):
855+ # See discussion about None as input
844856 # https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
845857 if arg is None :
846- s .storage [0 ] = arg
858+ arg_container .storage [0 ] = arg
847859 else :
848860 try :
849- s .storage [0 ] = s .type .filter (
850- arg , strict = s .strict , allow_downcast = s .allow_downcast
861+ arg_container .storage [0 ] = arg_container .type .filter (
862+ arg ,
863+ strict = arg_container .strict ,
864+ allow_downcast = arg_container .allow_downcast ,
851865 )
852866
853867 except Exception as e :
868+ i = input_storage .index (arg_container )
854869 function_name = "pytensor function"
855870 argument_name = "argument"
856871 if self .name :
@@ -875,93 +890,74 @@ def restore_defaults():
875890 + function_name
876891 + f" at index { int (i )} (0-based). { where } "
877892 ) + e .args
878- restore_defaults ()
893+ self . _restore_defaults ()
879894 raise
880- s .provided += 1
881- i += 1
895+ arg_container .provided += 1
882896
883897 # Set keyword arguments
884898 if kwargs : # for speed, skip the items for empty kwargs
885899 for k , arg in kwargs .items ():
886900 self [k ] = arg
887901
888- if (
889- not self .trust_input
890- and
891- # The getattr is only needed for old pickle
892- getattr (self , "_check_for_aliased_inputs" , True )
893- ):
902+ if not self .trust_input :
894903 # Collect aliased inputs among the storage space
895- args_share_memory = []
896- for i in range (len (self .input_storage )):
897- i_var = self .maker .inputs [i ].variable
898- i_val = self .input_storage [i ].storage [0 ]
899- if hasattr (i_var .type , "may_share_memory" ):
900- is_aliased = False
901- for j in range (len (args_share_memory )):
902- group_j = zip (
903- [
904- self .maker .inputs [k ].variable
905- for k in args_share_memory [j ]
906- ],
907- [
908- self .input_storage [k ].storage [0 ]
909- for k in args_share_memory [j ]
910- ],
911- )
904+ for potential_group in self ._potential_aliased_input_groups :
905+ args_share_memory : list [list [int ]] = []
906+ for i in potential_group :
907+ i_type = self .maker .inputs [i ].variable .type
908+ i_val = input_storage [i ].storage [0 ]
909+
910+ # Check if value is aliased with any of the values in one of the groups
911+ for j_group in args_share_memory :
912912 if any (
913- (
914- var .type is i_var .type
915- and var .type .may_share_memory (val , i_val )
916- )
917- for (var , val ) in group_j
913+ i_type .may_share_memory (input_storage [j ].storage [0 ], i_val )
914+ for j in j_group
918915 ):
919- is_aliased = True
920- args_share_memory [j ].append (i )
916+ j_group .append (i )
921917 break
922-
923- if not is_aliased :
918+ else : # no break
919+ # Create a new group
924920 args_share_memory .append ([i ])
925921
926- # Check for groups of more than one argument that share memory
927- for group in args_share_memory :
928- if len (group ) > 1 :
929- # copy all but the first
930- for j in group [1 :]:
931- self . input_storage [j ].storage [0 ] = copy .copy (
932- self . input_storage [j ].storage [0 ]
933- )
922+ # Check for groups of more than one argument that share memory
923+ for group in args_share_memory :
924+ if len (group ) > 1 :
925+ # copy all but the first
926+ for i in group [1 :]:
927+ input_storage [i ].storage [0 ] = copy .copy (
928+ input_storage [i ].storage [0 ]
929+ )
934930
935- # Check if inputs are missing, or if inputs were set more than once, or
936- # if we tried to provide inputs that are supposed to be implicit.
937- if not self .trust_input :
938- for c in self .input_storage :
939- if c .required and not c .provided :
940- restore_defaults ()
931+ # Check if inputs are missing, or if inputs were set more than once, or
932+ # if we tried to provide inputs that are supposed to be implicit.
933+ for arg_container in input_storage :
934+ if arg_container .required and not arg_container .provided :
935+ self ._restore_defaults ()
941936 raise TypeError (
942- f"Missing required input: { getattr (self .inv_finder [c ], 'variable' , self .inv_finder [c ])} "
937+ f"Missing required input: { getattr (self .inv_finder [arg_container ], 'variable' , self .inv_finder [arg_container ])} "
943938 )
944- if c .provided > 1 :
945- restore_defaults ()
939+ if arg_container .provided > 1 :
940+ self . _restore_defaults ()
946941 raise TypeError (
947- f"Multiple values for input: { getattr (self .inv_finder [c ], 'variable' , self .inv_finder [c ])} "
942+ f"Multiple values for input: { getattr (self .inv_finder [arg_container ], 'variable' , self .inv_finder [arg_container ])} "
948943 )
949- if c .implicit and c .provided > 0 :
950- restore_defaults ()
944+ if arg_container .implicit and arg_container .provided > 0 :
945+ self . _restore_defaults ()
951946 raise TypeError (
952- f"Tried to provide value for implicit input: { getattr (self .inv_finder [c ], 'variable' , self .inv_finder [c ])} "
947+ f"Tried to provide value for implicit input: { getattr (self .inv_finder [arg_container ], 'variable' , self .inv_finder [arg_container ])} "
953948 )
954949
955950 # Do the actual work
956- t0_fn = time .perf_counter ()
951+ if profile :
952+ t0_fn = time .perf_counter ()
957953 try :
958954 outputs = (
959955 self .vm ()
960956 if output_subset is None
961957 else self .vm (output_subset = output_subset )
962958 )
963959 except Exception :
964- restore_defaults ()
960+ self . _restore_defaults ()
965961 if hasattr (self .vm , "position_of_error" ):
966962 # this is a new vm-provided function or c linker
967963 # they need this because the exception manipulation
@@ -979,26 +975,24 @@ def restore_defaults():
979975 # old-style linkers raise their own exceptions
980976 raise
981977
982- dt_fn = time .perf_counter () - t0_fn
983- self .maker .mode .fn_time += dt_fn
984978 if profile :
979+ dt_fn = time .perf_counter () - t0_fn
980+ self .maker .mode .fn_time += dt_fn
985981 profile .vm_call_time += dt_fn
986982
987983 # Retrieve the values that were computed
988984 if outputs is None :
989985 outputs = [x .data for x in self .output_storage ]
990- assert len (outputs ) == len (self .output_storage )
991986
992987 # Remove internal references to required inputs.
993988 # These cannot be re-used anyway.
994- for c in self . input_storage :
995- if c .required :
996- c .storage [0 ] = None
989+ for arg_container in input_storage :
990+ if arg_container .required :
991+ arg_container .storage [0 ] = None
997992
998993 # if we are allowing garbage collection, remove the
999994 # output reference from the internal storage cells
1000995 if getattr (self .vm , "allow_gc" , False ):
1001- assert len (self .output_storage ) == len (self .maker .fgraph .outputs )
1002996 for o_container , o_variable in zip (
1003997 self .output_storage , self .maker .fgraph .outputs
1004998 ):
@@ -1007,37 +1001,31 @@ def restore_defaults():
10071001 # WARNING: This circumvents the 'readonly' attribute in x
10081002 o_container .storage [0 ] = None
10091003
1010- # TODO: Get rid of this and `expanded_inputs`, since all the VMs now
1011- # perform the updates themselves
10121004 if getattr (self .vm , "need_update_inputs" , True ):
10131005 # Update the inputs that have an update function
10141006 for input , storage in reversed (
1015- list (zip (self .maker .expanded_inputs , self . input_storage ))
1007+ list (zip (self .maker .expanded_inputs , input_storage ))
10161008 ):
10171009 if input .update is not None :
10181010 storage .data = outputs .pop ()
10191011 else :
10201012 outputs = outputs [: self .n_returned_outputs ]
10211013
10221014 # Put default values back in the storage
1023- restore_defaults ()
1024- #
1025- # NOTE: This logic needs to be replicated in
1026- # scan.
1027- # grep for 'PROFILE_CODE'
1028- #
1029-
1030- dt_call = time .perf_counter () - t0
1031- pytensor .compile .profiling .total_fct_exec_time += dt_call
1032- self .maker .mode .call_time += dt_call
1015+ self ._restore_defaults ()
1016+
10331017 if profile :
1018+ dt_call = time .perf_counter () - t0
1019+ pytensor .compile .profiling .total_fct_exec_time += dt_call
1020+ self .maker .mode .call_time += dt_call
10341021 profile .fct_callcount += 1
10351022 profile .fct_call_time += dt_call
10361023 if hasattr (self .vm , "update_profile" ):
10371024 self .vm .update_profile (profile )
10381025 if profile .ignore_first_call :
10391026 profile .reset ()
10401027 profile .ignore_first_call = False
1028+
10411029 if self .return_none :
10421030 return None
10431031 elif self .unpack_single and len (outputs ) == 1 and output_subset is None :
0 commit comments