41 | 41 | ) |
42 | 42 | from pytensor.graph.rewriting.db import RewriteDatabase |
43 | 43 | from pytensor.raise_op import Assert, CheckAndRaise, assert_op |
| 44 | +from pytensor.scalar.basic import Second |
44 | 45 | from pytensor.tensor.basic import ( |
45 | 46 | Alloc, |
46 | 47 | AllocEmpty, |
@@ -320,56 +321,52 @@ def dimshuffled_alloc(i): |
320 | 321 | return new_outs |
321 | 322 |
322 | 323 |
323 | | -@register_canonicalize("shape_unsafe") |
324 | 324 | @node_rewriter([Elemwise]) |
325 | 325 | def local_fill_sink(fgraph, node): |
326 | 326 | """ |
327 | 327 | f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e))) |
328 | 328 | f needs to be an elemwise that isn't a fill. |
329 | 329 | """ |
330 | | - if not hasattr(node, "op") or not isinstance(node.op, Elemwise) or node.op == fill: |
| 330 | + if isinstance(node.op.scalar_op, Second): |
331 | 331 | return False |
| 332 | + |
332 | 333 | models = [] |
333 | 334 | inputs = [] |
334 | 335 | for inp in node.inputs: |
335 | 336 | if inp.owner and inp.owner.op == fill: |
336 | | - models.append(inp.owner.inputs[0]) |
337 | | - inputs.append(inp.owner.inputs[1]) |
| 337 | + a, b = inp.owner.inputs |
| 338 | + if b.type.dtype != inp.dtype: |
| 339 | + # The input was implicitly cast by the fill operation |
| 340 | + b = b.cast(inp.dtype) |
| 341 | + models.append(a) |
| 342 | + inputs.append(b) |
338 | 343 | else: |
339 | 344 | inputs.append(inp) |
| 345 | + |
340 | 346 | if not models: |
341 | 347 | return False |
342 | | - c = node.op(*inputs) |
343 | | - for model in models: |
344 | | - if ( |
345 | | - model.type.dtype != c.type.dtype |
346 | | - or model.type.broadcastable != c.type.broadcastable |
347 | | - ): |
348 | | - c = fill(model, c) |
349 | 348 |
350 | | - # The newly created node c doesn't has 'clients', |
351 | | - # so this iteration is took place with node.outputs[0] |
352 | | - # TODO: This should just be a WalkingGraphRewrite! |
353 | | - replacements = {node.outputs[0]: c} |
354 | | - for client, cl_idx in fgraph.clients[node.outputs[0]]: |
355 | | - if ( |
356 | | - hasattr(client, "op") |
357 | | - and isinstance(client.op, Elemwise) |
358 | | - and client.op != fill |
359 | | - ): |
360 | | - client_inputs = client.inputs[:] |
361 | | - client_inputs[cl_idx] = c |
362 | | - new_client = client.op(*client_inputs) |
363 | | - |
364 | | - # Add clients to new_client |
365 | | - fgraph.clients[new_client.owner.outputs[0]] = fgraph.clients[ |
366 | | - client.outputs[0] |
367 | | - ] |
368 | | - r = local_fill_sink.transform(fgraph, new_client.owner) |
369 | | - if not r: |
370 | | - continue |
371 | | - replacements.update(r) |
372 | | - return replacements |
| 349 | + outputs = node.op.make_node(*inputs).outputs |
| 350 | + |
| 351 | + # Check if we need to propagate the fill to the new outputs |
| 352 | + # It's enough to check the first output, as Elemwise outputs must all have the same shapes |
| 353 | + # Note: There are orderings that may require fewer fills. |
| 354 | + old_bcast_pattern = node.outputs[0].type.broadcastable |
| 355 | + models_iter = iter(models) |
| 356 | + while old_bcast_pattern != outputs[0].type.broadcastable: |
| 357 | + model = next(models_iter) |
| 358 | + # Only apply this model if it would actually do anything |
| 359 | + if broadcasted_by(outputs[0], model): |
| 360 | + outputs = [fill(model, output) for output in outputs] |
| 361 | + |
| 362 | + return outputs |
| 363 | + |
| 364 | + |
| 365 | +# The rewrite is wrapped in an in2out GraphRewriter |
| 366 | +# so that fill can be sunk all the way to the terminal nodes in a single pass through the graph |
| 367 | +# without triggering other rewrites after each local substitution |
| 368 | +topological_fill_sink = in2out(local_fill_sink) |
| 369 | +register_canonicalize(topological_fill_sink, "shape_unsafe") |
373 | 370 |
374 | 371 |
375 | 372 | @register_specialize("shape_unsafe") |
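
A minimal usage sketch of the new `topological_fill_sink` pass on a toy graph, assuming the file in this diff is `pytensor/tensor/rewriting/basic.py`; the import path and the printed graphs in the comments are illustrative rather than exact.

```python
import pytensor
import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.basic import fill

# Import path assumed from the hunk shown above (pytensor/tensor/rewriting/basic.py)
from pytensor.tensor.rewriting.basic import topological_fill_sink

x = pt.vector("x")
m = pt.matrix("m")

# fill(m, x) broadcasts x to m's shape; exp and the add then operate on the broadcasted value
out = pt.exp(fill(m, x)) + 1

fgraph = FunctionGraph([x, m], [out], clone=True)
pytensor.dprint(fgraph)  # before: roughly Add(Exp(Second(m, x)), 1)

# One input-to-output pass of local_fill_sink over the whole graph
topological_fill_sink.rewrite(fgraph)
pytensor.dprint(fgraph)  # after: the fill is sunk to the output, roughly Second(m, Add(Exp(x), 1))
```

Because the pass is an `in2out` walker, each `fill` gets pushed through every downstream `Elemwise` in a single traversal, rather than relying on the equilibrium rewriter to re-trigger `local_fill_sink` after each individual substitution.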