22
33from textwrap import indent
44
5- import numpy as np
6-
75from pytensor .gradient import DisconnectedType
8- from pytensor .graph .basic import Apply , Variable
6+ from pytensor .graph .basic import Apply , Constant , Variable
97from pytensor .graph .replace import _vectorize_node
108from pytensor .link .c .op import COp
119from pytensor .link .c .params_type import ParamsType
1210from pytensor .link .c .type import Generic
13- from pytensor .scalar .basic import ScalarType
11+ from pytensor .scalar .basic import ScalarType , as_scalar
1412from pytensor .tensor .type import DenseTensorType
1513
1614
@@ -56,18 +54,6 @@ def __str__(self):
5654 msg = self .msg
5755 return f"{ name } {{raises={ exc_name } , msg='{ msg } '}}"
5856
59- def __eq__ (self , other ):
60- if type (self ) is not type (other ):
61- return False
62-
63- if self .msg == other .msg and self .exc_type == other .exc_type :
64- return True
65-
66- return False
67-
68- def __hash__ (self ):
69- return hash ((self .msg , self .exc_type ))
70-
7157 def make_node (self , value : Variable , * conds : Variable ):
7258 """
7359
@@ -84,12 +70,10 @@ def make_node(self, value: Variable, *conds: Variable):
8470 if not isinstance (value , Variable ):
8571 value = pt .as_tensor_variable (value )
8672
87- conds = [
88- pt .as_tensor_variable (c ) if not isinstance (c , Variable ) else c
89- for c in conds
90- ]
91-
92- assert all (c .type .ndim == 0 for c in conds )
73+ conds = [as_scalar (c ) for c in conds ]
74+ for i , cond in enumerate (conds ):
75+ if cond .dtype != "bool" :
76+ conds [i ] = cond .astype ("bool" )
9377
9478 return Apply (
9579 self ,
def perform(self, node, inputs, outputs):
    """Pass the value input through unchanged, raising if any condition is falsy.

    The first input is the value to forward; every remaining input is a
    0-d boolean condition (coerced in ``make_node``). If any condition is
    false, ``self.exc_type`` is raised with ``self.msg``.
    """
    (out_storage,) = outputs
    value, *conditions = inputs
    out_storage[0] = value
    # Builtin all() suffices here: each condition is scalar, so its
    # truthiness is well-defined without np.all.
    if not all(bool(cond) for cond in conditions):
        raise self.exc_type(self.msg)
10690
10791 def grad (self , input , output_gradients ):
@@ -117,38 +101,20 @@ def c_code(self, node, name, inames, onames, props):
117101 )
118102 value_name , * cond_names = inames
119103 out_name = onames [0 ]
120- check = []
121104 fail_code = props ["fail" ]
122105 param_struct_name = props ["params" ]
123106 msg = self .msg .replace ('"' , '\\ "' ).replace ("\n " , "\\ n" )
124107
125- for idx , cond_name in enumerate (cond_names ):
126- if isinstance (node .inputs [0 ].type , DenseTensorType ):
127- check .append (
128- f"""
129- if(PyObject_IsTrue((PyObject *){ cond_name } ) == 0) {{
130- PyObject * exc_type = { param_struct_name } ->exc_type;
131- Py_INCREF(exc_type);
132- PyErr_SetString(exc_type, "{ msg } ");
133- Py_XDECREF(exc_type);
134- { indent (fail_code , " " * 4 )}
135- }}
136- """
137- )
138- else :
139- check .append (
140- f"""
141- if({ cond_name } == 0) {{
142- PyObject * exc_type = { param_struct_name } ->exc_type;
143- Py_INCREF(exc_type);
144- PyErr_SetString(exc_type, "{ msg } ");
145- Py_XDECREF(exc_type);
146- { indent (fail_code , " " * 4 )}
147- }}
148- """
149- )
150-
151- check = "\n " .join (check )
108+ all_conds = " && " .join (cond_names )
109+ check = f"""
110+ if(!({ all_conds } )) {{
111+ PyObject * exc_type = { param_struct_name } ->exc_type;
112+ Py_INCREF(exc_type);
113+ PyErr_SetString(exc_type, "{ msg } ");
114+ Py_XDECREF(exc_type);
115+ { indent (fail_code , " " * 4 )}
116+ }}
117+ """
152118
153119 if isinstance (node .inputs [0 ].type , DenseTensorType ):
154120 res = f"""
@@ -162,14 +128,19 @@ def c_code(self, node, name, inames, onames, props):
162128 { check }
163129 { out_name } = { value_name } ;
164130 """
165- return res
131+
132+ return "\n " .join ((check , res ))
166133
def c_code_cache_version(self):
    """Return the version tag of the generated C code.

    Bumped to ``(2,)`` because the condition-checking C code changed
    (conditions are now combined into a single ``&&`` expression).
    """
    return (2,)
169136
def infer_shape(self, fgraph, node, input_shapes):
    """The output has the same shape as the value input.

    The remaining inputs are scalar conditions and contribute nothing
    to the output shape.
    """
    value_shape = input_shapes[0]
    return [value_shape]
172139
def do_constant_folding(self, fgraph, node):
    """Allow constant folding only when the check is known to pass.

    Folding a failing assert would silently drop the error, so we fold
    only if every condition input is a ``Constant`` whose data is truthy.
    """
    conditions = node.inputs[1:]
    return all(isinstance(cond, Constant) and bool(cond.data) for cond in conditions)
173144
174145class Assert (CheckAndRaise ):
175146 """Implements assertion in a computational graph.
0 commit comments