|
30 | 30 | from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop |
31 | 31 | from pytensor.tensor.basic import ( |
32 | 32 | MakeVector, |
33 | | - alloc, |
34 | | - cast, |
35 | 33 | constant, |
36 | | - get_underlying_scalar_constant_value, |
37 | 34 | ) |
38 | 35 | from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise |
39 | | -from pytensor.tensor.exceptions import NotScalarConstantError |
40 | 36 | from pytensor.tensor.math import add, exp, mul |
41 | 37 | from pytensor.tensor.rewriting.basic import ( |
42 | 38 | alloc_like, |
43 | 39 | broadcasted_by, |
44 | 40 | register_canonicalize, |
45 | 41 | register_specialize, |
46 | 42 | ) |
47 | | -from pytensor.tensor.shape import shape_padleft |
48 | 43 | from pytensor.tensor.variable import TensorConstant, TensorVariable |
49 | 44 |
|
50 | 45 |
|
@@ -434,66 +429,40 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): |
434 | 429 |
|
435 | 430 | """ |
436 | 431 | if len(node.outputs) > 1: |
437 | | - return |
438 | | - try: |
439 | | - shape_i = fgraph.shape_feature.shape_i |
440 | | - except AttributeError: |
441 | | - shape_i = None |
442 | | - if isinstance(node.op, Elemwise): |
443 | | - scalar_op = node.op.scalar_op |
444 | | - # print "aa", scalar_op.output_types_preference |
445 | | - if getattr(scalar_op, "output_types_preference", None) in ( |
446 | | - ps.upgrade_to_float, |
447 | | - ps.upcast_out, |
448 | | - ): |
449 | | - # this is the kind of op that we can screw with the input |
450 | | - # dtypes by upcasting explicitly |
451 | | - output_dtype = node.outputs[0].type.dtype |
452 | | - new_inputs = [] |
453 | | - for i in node.inputs: |
454 | | - if i.type.dtype == output_dtype: |
455 | | - new_inputs.append(i) |
456 | | - else: |
457 | | - try: |
458 | | - cval_i = get_underlying_scalar_constant_value( |
459 | | - i, only_process_constants=True |
460 | | - ) |
461 | | - if all(i.broadcastable): |
462 | | - new_inputs.append( |
463 | | - shape_padleft(cast(cval_i, output_dtype), i.ndim) |
464 | | - ) |
465 | | - else: |
466 | | - if shape_i is None: |
467 | | - return |
468 | | - new_inputs.append( |
469 | | - alloc( |
470 | | - cast(cval_i, output_dtype), |
471 | | - *[shape_i(d)(i) for d in range(i.ndim)], |
472 | | - ) |
473 | | - ) |
474 | | - # print >> sys.stderr, "AAA", |
475 | | - # *[Shape_i(d)(i) for d in range(i.ndim)] |
476 | | - except NotScalarConstantError: |
477 | | - # for the case of a non-scalar |
478 | | - if isinstance(i, TensorConstant): |
479 | | - new_inputs.append(cast(i, output_dtype)) |
480 | | - else: |
481 | | - new_inputs.append(i) |
| 432 | + return None |
| 433 | + |
| 434 | + if getattr(node.op.scalar_op, "output_types_preference", None) not in ( |
| 435 | + ps.upgrade_to_float, |
| 436 | + ps.upcast_out, |
| 437 | + ): |
| 438 | + return None |
482 | 439 |
|
483 | | - if new_inputs != node.inputs: |
484 | | - rval = [node.op(*new_inputs)] |
485 | | - if not node.outputs[0].type.is_super(rval[0].type): |
486 | | - # This can happen for example when floatX=float32 |
487 | | - # and we do the true division between and int64 |
488 | | - # and a constant that will get typed as int8. |
| 440 | +    # This is the kind of op whose input dtypes we can
| 441 | +    # safely upcast explicitly.
| 442 | + [old_out] = node.outputs |
| 443 | + output_dtype = old_out.type.dtype |
| 444 | + new_inputs = list(node.inputs) |
| 445 | + changed = False |
| 446 | + for i, inp in enumerate(node.inputs): |
| 447 | + if inp.type.dtype != output_dtype and isinstance(inp, TensorConstant): |
| 448 | + new_inputs[i] = constant(inp.data.astype(output_dtype)) |
| 449 | + changed = True |
| 450 | + |
| 451 | + if not changed: |
| 452 | + return None |
489 | 453 |
|
490 | | - # As this is just to allow merging more case, if |
491 | | - # the upcast don't work, we can just skip it. |
492 | | - return |
| 454 | + rval = node.op(*new_inputs) |
| 455 | + if not old_out.type.is_super(rval.type): |
| 456 | +        # This can happen, for example, when floatX=float32
| 457 | +        # and we do true division between an int64
| 458 | +        # and a constant that will get typed as int8.
| 459 | +        # As this is just to allow merging more cases, if
| 460 | +        # the upcast doesn't work, we can just skip it.
| 461 | + return None |
493 | 462 |
|
494 | | - # Copy over output stacktrace from before upcasting |
495 | | - copy_stack_trace(node.outputs[0], rval) |
496 | | - return rval |
| 463 | + # Copy over output stacktrace from before upcasting |
| 464 | + copy_stack_trace(old_out, rval) |
| 465 | + return [rval] |
497 | 466 |
|
498 | 467 |
|
499 | 468 | @node_rewriter([add, mul]) |
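
A minimal sketch of the behaviour the simplified rewrite implements, using only the public PyTensor API (the float32/int8 combination and the variable names are illustrative, not taken from this diff):

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.basic import constant

# An Elemwise whose constant input has a narrower dtype than the output.
x = pt.vector("x", dtype="float32")
c = constant(np.int8(2))   # small integer constant, typed as int8
out = x / c                # true division; the output dtype is float32

# The rewrite upcasts the constant to the output dtype so the Elemwise
# sees inputs that already match, which helps later constant merging.
new_c = constant(np.asarray(c.data).astype(out.type.dtype))
new_out = x / new_c

# The replacement is only kept when the old output type still covers the
# new one; otherwise the rewrite bails out, as in the diff above.
assert out.type.is_super(new_out.type)
```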
|