File tree Expand file tree Collapse file tree 1 file changed +11
-2
lines changed Expand file tree Collapse file tree 1 file changed +11
-2
lines changed Original file line number Diff line number Diff line change @@ -42,16 +42,25 @@ def get_parent_names(self, var: TensorVariable) -> Set[VarName]:
4242 if var .owner is None or var .owner .inputs is None :
4343 return set ()
4444
45+ def _filter_non_parameter_inputs (var ):
46+ node = var .owner
47+ if isinstance (node .op , RandomVariable ):
48+ # Filter out rng, dtype and size parameters or RandomVariable nodes
49+ return node .inputs [3 :]
50+ else :
51+ # Otherwise return all inputs
52+ return node .inputs
53+
4554 def _expand (x ):
4655 if x .name :
4756 return [x ]
4857 if isinstance (x .owner , Apply ):
49- return reversed (x . owner . inputs )
58+ return reversed (_filter_non_parameter_inputs ( x ) )
5059 return []
5160
5261 parents = {
5362 get_var_name (x )
54- for x in walk (nodes = var . owner . inputs , expand = _expand )
63+ for x in walk (nodes = _filter_non_parameter_inputs ( var ) , expand = _expand )
5564 # Only consider nodes that are in the named model variables.
5665 if x .name and x .name in self ._all_var_names
5766 }
You can’t perform that action at this time.
0 commit comments