1010
1111from theano import gof
1212
13- from sympy import Array as SympyArray
14- from sympy .printing import latex as sympy_latex
13+ try :
14+ from sympy import Array as SympyArray
15+ from sympy .printing import latex as sympy_latex
16+
17+ def latex_print_array (data ): # pragma: no cover
18+ return sympy_latex (SympyArray (data ))
19+
20+
21+ except ImportError : # pragma: no cover
22+
23+ def latex_print_array (data ):
24+ return data
25+
1526
1627from .opt import FunctionGraph
1728from .ops import RandomVariable
@@ -60,16 +71,16 @@ def process_param(self, idx, sform, pstate):
6071 The printer state.
6172
6273 """
63- return sform
74+ return sform # pragma: no cover
6475
6576 def process (self , output , pstate ):
6677 if output in pstate .memo :
6778 return pstate .memo [output ]
6879
6980 pprinter = pstate .pprinter
70- node = output . owner
81+ node = getattr ( output , " owner" , None )
7182
72- if node is None or not isinstance (node .op , RandomVariable ):
83+ if node is None or not isinstance (node .op , RandomVariable ): # pragma: no cover
7384 raise TypeError (
7485 "Function %s cannot represent a variable that is "
7586 "not the result of a RandomVariable operation" % self .name
@@ -78,7 +89,7 @@ def process(self, output, pstate):
7889 op_name = self .name or getattr (node .op , "print_name" , None )
7990 op_name = op_name or getattr (node .op , "name" , None )
8091
81- if op_name is None :
92+ if op_name is None : # pragma: no cover
8293 raise ValueError (f"Could not find a name for { node .op } " )
8394
8495 # Allow `Op`s to specify their ascii and LaTeX formats (in a tuple/list
@@ -144,7 +155,7 @@ def process(self, output, pstate):
144155
145156class GenericSubtensorPrinter (object ):
146157 def process (self , r , pstate ):
147- if r . owner is None :
158+ if getattr ( r , " owner" , None ) is None : # pragma: no cover
148159 raise TypeError ("Can only print Subtensor." )
149160
150161 output_latex = getattr (pstate , "latex" , False )
@@ -161,13 +172,13 @@ def process(self, r, pstate):
161172 if isinstance (entry , slice ):
162173 s_parts = ["" ] * 2
163174 if entry .start is not None :
164- s_parts [0 ] = entry . start
175+ s_parts [0 ] = pstate . pprinter . process ( inputs . pop ())
165176
166177 if entry .stop is not None :
167- s_parts [1 ] = entry . stop
178+ s_parts [1 ] = pstate . pprinter . process ( inputs . pop ())
168179
169180 if entry .step is not None :
170- s_parts .append (entry . stop )
181+ s_parts .append (pstate . pprinter . process ( inputs . pop ()) )
171182
172183 sidxs .append (":" .join (s_parts ))
173184 else :
@@ -215,16 +226,22 @@ def process(cls, output, pstate):
215226 using_latex = getattr (pstate , "latex" , False )
216227 # Crude--but effective--means of stopping print-outs for large
217228 # arrays.
218- constant = isinstance (output , tt .TensorConstant )
229+ constant = isinstance (output , ( tt .TensorConstant , theano . scalar . basic . ScalarConstant ) )
219230 too_large = constant and (output .data .size > cls .max_line_width * cls .max_line_height )
220231
221232 if constant and not too_large :
222233 # Print constants that aren't too large
223234 if using_latex and output .ndim > 0 :
224- out_name = sympy_latex ( SympyArray ( output .data ) )
235+ out_name = latex_print_array ( output .data )
225236 else :
226237 out_name = str (output .data )
227- elif isinstance (output , tt .TensorVariable ) or constant :
238+ elif (
239+ isinstance (
240+ output ,
241+ (tt .TensorVariable , theano .scalar .basic .Scalar , theano .scalar .basic .ScalarVariable ),
242+ )
243+ or constant
244+ ):
228245 # Process name and shape
229246
230247 # Attempt to get the original variable, in case this is a cloned
@@ -238,7 +255,7 @@ def process(cls, output, pstate):
238255
239256 shape_strings = pstate .preamble_dict .setdefault ("shape_strings" , OrderedDict ())
240257 shape_strings [output ] = shape_info
241- else :
258+ else : # pragma: no cover
242259 raise TypeError (f"Type { type (output )} not handled by variable printer" )
243260
244261 pstate .memo [output ] = out_name
@@ -268,7 +285,7 @@ def process_variable_name(cls, output, pstate):
268285 _ = [available_names .pop (v .name , None ) for v in fgraph .variables ]
269286 setattr (pstate , "available_names" , available_names )
270287
271- if output . name :
288+ if getattr ( output , " name" , None ) :
272289 # Observed an existing name; remove it.
273290 out_name = output .name
274291 available_names .pop (out_name , None )
@@ -524,11 +541,18 @@ def __call__(self, *args, latex_env="equation", latex_label=None):
524541
525542# The order here is important!
526543tt_pprint .printers .insert (
527- 0 , (lambda pstate , r : isinstance (r , tt .Variable ), VariableWithShapePrinter )
544+ 0 ,
545+ (
546+ lambda pstate , r : isinstance (r , (theano .scalar .basic .Scalar , tt .Variable )),
547+ VariableWithShapePrinter ,
548+ ),
528549)
529550tt_pprint .printers .insert (
530551 0 ,
531- (lambda pstate , r : r .owner and isinstance (r .owner .op , RandomVariable ), RandomVariablePrinter ()),
552+ (
553+ lambda pstate , r : getattr (r , "owner" , None ) and isinstance (r .owner .op , RandomVariable ),
554+ RandomVariablePrinter (),
555+ ),
532556)
533557
534558
@@ -538,9 +562,9 @@ def process(self, output, pstate):
538562 return pstate .memo [output ]
539563
540564 pprinter = pstate .pprinter
541- node = output . owner
565+ node = getattr ( output , " owner" , None )
542566
543- if node is None or not isinstance (node .op , Observed ):
567+ if node is None or not isinstance (node .op , Observed ): # pragma: no cover
544568 raise TypeError (f"Node Op is not of type `Observed`: { node .op } " )
545569
546570 val = node .inputs [0 ]
0 commit comments