11import plotly .graph_objs as go
22from _plotly_utils .basevalidators import ColorscaleValidator
3+ from ._core import apply_default_cascade
34import numpy as np # is it fine to depend on np here?
45
56_float_types = []
@@ -63,6 +64,10 @@ def imshow(
6364 color_continuous_scale = None ,
6465 color_continuous_midpoint = None ,
6566 range_color = None ,
67+ title = None ,
68+ template = None ,
69+ width = None ,
70+ height = None ,
6671):
6772 """
6873 Display an image, i.e. data on a 2D regular raster.
@@ -118,6 +123,9 @@ def imshow(
118123 In order to update and customize the returned figure, use
119124 `go.Figure.update_traces` or `go.Figure.update_layout`.
120125 """
126+ args = locals ()
127+ apply_default_cascade (args )
128+
121129 img = np .asanyarray (img )
122130 # Cast bools to uint8 (also one byte)
123131 if img .dtype == np .bool :
@@ -134,7 +142,9 @@ def imshow(
134142 colorscale_validator = ColorscaleValidator ("colorscale" , "imshow" )
135143 range_color = range_color or [None , None ]
136144 layout ["coloraxis1" ] = dict (
137- colorscale = colorscale_validator .validate_coerce (color_continuous_scale ),
145+ colorscale = colorscale_validator .validate_coerce (
146+ args ["color_continuous_scale" ]
147+ ),
138148 cmid = color_continuous_midpoint ,
139149 cmin = range_color [0 ],
140150 cmax = range_color [1 ],
@@ -154,5 +164,14 @@ def imshow(
154164 "px.imshow only accepts 2D single-channel, RGB or RGBA images. "
155165 "An image of shape %s was provided" % str (img .shape )
156166 )
167+
168+ layout_patch = dict ()
169+ for v in ["title" , "height" , "width" ]:
170+ if args [v ]:
171+ layout_patch [v ] = args [v ]
172+ if "title" not in layout_patch and args ["template" ].layout .margin .t is None :
173+ layout_patch ["margin" ] = {"t" : 60 }
157174 fig = go .Figure (data = trace , layout = layout )
175+ fig .update_layout (layout_patch )
176+ fig .update_layout (template = args ["template" ], overwrite = True )
158177 return fig
0 commit comments