11import plotly .graph_objs as go
2+ from _plotly_utils .basevalidators import ColorscaleValidator
3+ from ._core import apply_default_cascade
24import numpy as np # is it fine to depend on np here?
35
46_float_types = []
@@ -54,7 +56,19 @@ def _infer_zmax_from_type(img):
5456 return 2 ** 32
5557
5658
57- def imshow (img , zmin = None , zmax = None , origin = None , colorscale = None ):
59+ def imshow (
60+ img ,
61+ zmin = None ,
62+ zmax = None ,
63+ origin = None ,
64+ color_continuous_scale = None ,
65+ color_continuous_midpoint = None ,
66+ range_color = None ,
67+ title = None ,
68+ template = None ,
69+ width = None ,
70+ height = None ,
71+ ):
5872 """
5973 Display an image, i.e. data on a 2D regular raster.
6074
@@ -74,16 +88,38 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
7488 zmin and zmax correspond to the min and max values of the datatype for integer
7589 datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.). For
7690 a multichannel image of floats, the max of the image is computed and zmax is the
77- smallest power of 256 (1, 255, 65535) greater than this max value,
91+ smallest power of 256 (1, 255, 65535) greater than this max value,
7892 with a 5% tolerance. For a single-channel image, the max of the image is used.
7993
8094 origin : str, 'upper' or 'lower' (default 'upper')
8195 position of the [0, 0] pixel of the image array, in the upper left or lower left
8296 corner. The convention 'upper' is typically used for matrices and images.
8397
84- colorscale : str
85- colormap used to map scalar data to colors (for a 2D image). This parameter is not used for
86- RGB or RGBA images.
98+ color_continuous_scale : str or list of str
99+ colormap used to map scalar data to colors (for a 2D image). This parameter is
100+ not used for RGB or RGBA images. If a string is provided, it should be the name
101+ of a known color scale, and if a list is provided, it should be a list of CSS-
102+ compatible colors.
103+
104+ color_continuous_midpoint : number
105+ If set, computes the bounds of the continuous color scale to have the desired
106+ midpoint.
107+
108+ range_color : list of two numbers
109+ If provided, overrides auto-scaling on the continuous color scale, including
110+ overriding `color_continuous_midpoint`.
111+
112+ title : str
113+ The figure title.
114+
115+ template : str or dict or plotly.graph_objects.layout.Template instance
116+ The figure template name or definition.
117+
118+ width : number
119+ The figure width in pixels.
120+
121+ height: number
122+ The figure height in pixels, defaults to 600.
87123
88124 Returns
89125 -------
@@ -101,21 +137,33 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
101137 In order to update and customize the returned figure, use
102138 `go.Figure.update_traces` or `go.Figure.update_layout`.
103139 """
140+ args = locals ()
141+ apply_default_cascade (args )
142+
104143 img = np .asanyarray (img )
105144 # Cast bools to uint8 (also one byte)
106145 if img .dtype == np .bool :
107146 img = 255 * img .astype (np .uint8 )
108147
109148 # For 2d data, use Heatmap trace
110149 if img .ndim == 2 :
111- if colorscale is None :
112- colorscale = "gray"
113- trace = go .Heatmap (z = img , zmin = zmin , zmax = zmax , colorscale = colorscale )
150+ trace = go .Heatmap (z = img , zmin = zmin , zmax = zmax , coloraxis = "coloraxis1" )
114151 autorange = True if origin == "lower" else "reversed"
115152 layout = dict (
116153 xaxis = dict (scaleanchor = "y" , constrain = "domain" ),
117154 yaxis = dict (autorange = autorange , constrain = "domain" ),
118155 )
156+ colorscale_validator = ColorscaleValidator ("colorscale" , "imshow" )
157+ range_color = range_color or [None , None ]
158+ layout ["coloraxis1" ] = dict (
159+ colorscale = colorscale_validator .validate_coerce (
160+ args ["color_continuous_scale" ]
161+ ),
162+ cmid = color_continuous_midpoint ,
163+ cmin = range_color [0 ],
164+ cmax = range_color [1 ],
165+ )
166+
119167 # For 2D+RGB data, use Image trace
120168 elif img .ndim == 3 and img .shape [- 1 ] in [3 , 4 ]:
121169 if zmax is None and img .dtype is not np .uint8 :
@@ -127,8 +175,17 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
127175 layout ["yaxis" ] = dict (autorange = True )
128176 else :
129177 raise ValueError (
130- "px.imshow only accepts 2D grayscale , RGB or RGBA images. "
178+ "px.imshow only accepts 2D single-channel , RGB or RGBA images. "
131179 "An image of shape %s was provided" % str (img .shape )
132180 )
181+
182+ layout_patch = dict ()
183+ for v in ["title" , "height" , "width" ]:
184+ if args [v ]:
185+ layout_patch [v ] = args [v ]
186+ if "title" not in layout_patch and args ["template" ].layout .margin .t is None :
187+ layout_patch ["margin" ] = {"t" : 60 }
133188 fig = go .Figure (data = trace , layout = layout )
189+ fig .update_layout (layout_patch )
190+ fig .update_layout (template = args ["template" ], overwrite = True )
134191 return fig
0 commit comments