|
2 | 2 | from collections import OrderedDict |
3 | 3 | from copy import copy |
4 | 4 | from functools import partial |
5 | | -from typing import List, Optional, Sequence, cast |
| 5 | +from typing import Dict, List, Optional, Sequence, Tuple, cast |
6 | 6 |
|
7 | 7 | import pytensor.tensor as at |
8 | 8 | from pytensor import function |
@@ -81,6 +81,81 @@ def local_traverse(out): |
81 | 81 | return ret |
82 | 82 |
|
83 | 83 |
|
| 84 | +def construct_nominal_fgraph( |
| 85 | + inputs: Sequence[Variable], outputs: Sequence[Variable] |
| 86 | +) -> Tuple[ |
| 87 | + FunctionGraph, |
| 88 | + Sequence[Variable], |
| 89 | + Dict[Variable, Variable], |
| 90 | + Dict[Variable, Variable], |
| 91 | +]: |
| 92 | + """Construct an inner-`FunctionGraph` with ordered nominal inputs.""" |
| 93 | + dummy_inputs = [] |
| 94 | + for n, inp in enumerate(inputs): |
| 95 | + if ( |
| 96 | + not isinstance(inp, Variable) |
| 97 | + or isinstance(inp, Constant) |
| 98 | + or isinstance(inp, SharedVariable) |
| 99 | + ): |
| 100 | + raise TypeError( |
| 101 | + f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}" |
| 102 | + ) |
| 103 | + |
| 104 | + dummy_inputs.append(inp.type()) |
| 105 | + |
| 106 | + dummy_shared_inputs = [] |
| 107 | + shared_inputs = [] |
| 108 | + for var in graph_inputs(outputs, inputs): |
| 109 | + if isinstance(var, SharedVariable): |
| 110 | + # To correctly support shared variables the inner-graph should |
| 111 | + # not see them; otherwise, there will be problems with |
| 112 | + # gradients. |
| 113 | + # That's why we collect the shared variables and replace them |
| 114 | + # with dummies. |
| 115 | + shared_inputs.append(var) |
| 116 | + dummy_shared_inputs.append(var.type()) |
| 117 | + elif var not in inputs and not isinstance(var, Constant): |
| 118 | + raise MissingInputError(f"OpFromGraph is missing an input: {var}") |
| 119 | + |
| 120 | + replacements = dict(zip(inputs + shared_inputs, dummy_inputs + dummy_shared_inputs)) |
| 121 | + |
| 122 | + new = rebuild_collect_shared( |
| 123 | + cast(Sequence[Variable], outputs), |
| 124 | + inputs=inputs + shared_inputs, |
| 125 | + replace=replacements, |
| 126 | + copy_inputs_over=False, |
| 127 | + ) |
| 128 | + ( |
| 129 | + local_inputs, |
| 130 | + local_outputs, |
| 131 | + (clone_d, update_d, update_expr, new_shared_inputs), |
| 132 | + ) = new |
| 133 | + |
| 134 | + assert len(local_inputs) == len(inputs) + len(shared_inputs) |
| 135 | + assert len(local_outputs) == len(outputs) |
| 136 | + assert not update_d |
| 137 | + assert not update_expr |
| 138 | + assert not new_shared_inputs |
| 139 | + |
| 140 | + fgraph = FunctionGraph(local_inputs, local_outputs, clone=False) |
| 141 | + |
| 142 | + # The inputs need to be `NominalVariable`s so that we can merge |
| 143 | + # inner-graphs |
| 144 | + nominal_local_inputs = tuple( |
| 145 | + NominalVariable(n, var.type) for n, var in enumerate(local_inputs) |
| 146 | + ) |
| 147 | + |
| 148 | + fgraph.replace_all(zip(local_inputs, nominal_local_inputs)) |
| 149 | + |
| 150 | + for i, inp in enumerate(fgraph.inputs): |
| 151 | + nom_inp = nominal_local_inputs[i] |
| 152 | + fgraph.inputs[i] = nom_inp |
| 153 | + fgraph.clients.pop(inp, None) |
| 154 | + fgraph.add_input(nom_inp) |
| 155 | + |
| 156 | + return fgraph, shared_inputs, update_d, update_expr |
| 157 | + |
| 158 | + |
84 | 159 | class OpFromGraph(Op, HasInnerGraph): |
85 | 160 | r""" |
86 | 161 | This creates an `Op` from inputs and outputs lists of variables. |
@@ -338,76 +413,15 @@ def __init__( |
338 | 413 | f"Inputs and outputs must be Variable instances; got {out}" |
339 | 414 | ) |
340 | 415 |
|
341 | | - dummy_inputs = [] |
342 | | - for n, inp in enumerate(inputs): |
343 | | - if ( |
344 | | - not isinstance(inp, Variable) |
345 | | - or isinstance(inp, Constant) |
346 | | - or isinstance(inp, SharedVariable) |
347 | | - ): |
348 | | - raise TypeError( |
349 | | - f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}" |
350 | | - ) |
351 | | - |
352 | | - dummy_inputs.append(inp.type()) |
353 | | - |
354 | 416 | if "updates" in kwargs or "givens" in kwargs: |
355 | 417 | raise NotImplementedError("Updates and givens are not supported") |
356 | 418 |
|
357 | 419 | self.is_inline = inline |
358 | 420 |
|
359 | | - dummy_shared_inputs = [] |
360 | | - self.shared_inputs = [] |
361 | | - for var in graph_inputs(outputs, inputs): |
362 | | - if isinstance(var, SharedVariable): |
363 | | - # To correctly support shared variables the inner-graph should |
364 | | - # not see them; otherwise, there will be problems with |
365 | | - # gradients. |
366 | | - # That's why we collect the shared variables and replace them |
367 | | - # with dummies. |
368 | | - self.shared_inputs.append(var) |
369 | | - dummy_shared_inputs.append(var.type()) |
370 | | - elif var not in inputs and not isinstance(var, Constant): |
371 | | - raise MissingInputError(f"OpFromGraph is missing an input: {var}") |
372 | | - |
373 | | - replacements = dict( |
374 | | - zip(inputs + self.shared_inputs, dummy_inputs + dummy_shared_inputs) |
| 421 | + self.fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph( |
| 422 | + inputs, outputs |
375 | 423 | ) |
376 | 424 |
|
377 | | - new = rebuild_collect_shared( |
378 | | - cast(Sequence[Variable], outputs), |
379 | | - inputs=inputs + self.shared_inputs, |
380 | | - replace=replacements, |
381 | | - copy_inputs_over=False, |
382 | | - ) |
383 | | - ( |
384 | | - local_inputs, |
385 | | - local_outputs, |
386 | | - (clone_d, update_d, update_expr, shared_inputs), |
387 | | - ) = new |
388 | | - |
389 | | - assert len(local_inputs) == len(inputs) + len(self.shared_inputs) |
390 | | - assert len(local_outputs) == len(outputs) |
391 | | - assert not update_d |
392 | | - assert not update_expr |
393 | | - assert not shared_inputs |
394 | | - |
395 | | - self.fgraph = FunctionGraph(local_inputs, local_outputs, clone=False) |
396 | | - |
397 | | - # The inputs need to be `NominalVariable`s so that we can merge |
398 | | - # inner-graphs |
399 | | - nominal_local_inputs = tuple( |
400 | | - NominalVariable(n, var.type) for n, var in enumerate(local_inputs) |
401 | | - ) |
402 | | - |
403 | | - self.fgraph.replace_all(zip(local_inputs, nominal_local_inputs)) |
404 | | - |
405 | | - for i, inp in enumerate(self.fgraph.inputs): |
406 | | - nom_inp = nominal_local_inputs[i] |
407 | | - self.fgraph.inputs[i] = nom_inp |
408 | | - self.fgraph.clients.pop(inp, None) |
409 | | - self.fgraph.add_input(nom_inp) |
410 | | - |
411 | 425 | self.kwargs = kwargs |
412 | 426 | self.input_types = [inp.type for inp in inputs] |
413 | 427 | self.output_types = [out.type for out in outputs] |
|
0 commit comments