11import plotly .graph_objs as go
2+ from _plotly_utils .basevalidators import ColorscaleValidator
23import numpy as np # is it fine to depend on np here?
34
45_float_types = []
@@ -54,7 +55,15 @@ def _infer_zmax_from_type(img):
5455 return 2 ** 32
5556
5657
57- def imshow (img , zmin = None , zmax = None , origin = None , colorscale = None ):
58+ def imshow (
59+ img ,
60+ zmin = None ,
61+ zmax = None ,
62+ origin = None ,
63+ color_continuous_scale = None ,
64+ color_continuous_midpoint = None ,
65+ range_color = None ,
66+ ):
5867 """
5968 Display an image, i.e. data on a 2D regular raster.
6069
@@ -74,16 +83,24 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
7483 zmin and zmax correspond to the min and max values of the datatype for integer
7584 datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.). For
7685 a multichannel image of floats, the max of the image is computed and zmax is the
77- smallest power of 256 (1, 255, 65535) greater than this max value,
86+ smallest power of 256 (1, 255, 65535) greater than this max value,
7887 with a 5% tolerance. For a single-channel image, the max of the image is used.
7988
8089 origin : str, 'upper' or 'lower' (default 'upper')
8190 position of the [0, 0] pixel of the image array, in the upper left or lower left
8291 corner. The convention 'upper' is typically used for matrices and images.
8392
84- colorscale : str
85- colormap used to map scalar data to colors (for a 2D image). This parameter is not used for
86- RGB or RGBA images.
93+ color_continuous_scale : str or list of str
94+ colormap used to map scalar data to colors (for a 2D image). This parameter is
95+ not used for RGB or RGBA images.
96+
97+ color_continuous_midpoint : number
98+ If set, computes the bounds of the continuous color scale to have the desired
99+ midpoint.
100+
101+ range_color : list of two numbers
102+ If provided, overrides auto-scaling on the continuous color scale, including
103+ overriding `color_continuous_midpoint`.
87104
88105 Returns
89106 -------
@@ -108,14 +125,21 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
108125
109126 # For 2d data, use Heatmap trace
110127 if img .ndim == 2 :
111- if colorscale is None :
112- colorscale = "gray"
113- trace = go .Heatmap (z = img , zmin = zmin , zmax = zmax , colorscale = colorscale )
128+ trace = go .Heatmap (z = img , zmin = zmin , zmax = zmax , coloraxis = "coloraxis1" )
114129 autorange = True if origin == "lower" else "reversed"
115130 layout = dict (
116131 xaxis = dict (scaleanchor = "y" , constrain = "domain" ),
117132 yaxis = dict (autorange = autorange , constrain = "domain" ),
118133 )
134+ colorscale_validator = ColorscaleValidator ("colorscale" , "imshow" )
135+ range_color = range_color or [None , None ]
136+ layout ["coloraxis1" ] = dict (
137+ colorscale = colorscale_validator .validate_coerce (color_continuous_scale ),
138+ cmid = color_continuous_midpoint ,
139+ cmin = range_color [0 ],
140+ cmax = range_color [1 ],
141+ )
142+
119143 # For 2D+RGB data, use Image trace
120144 elif img .ndim == 3 and img .shape [- 1 ] in [3 , 4 ]:
121145 if zmax is None and img .dtype is not np .uint8 :
@@ -127,7 +151,7 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
127151 layout ["yaxis" ] = dict (autorange = True )
128152 else :
129153 raise ValueError (
130- "px.imshow only accepts 2D grayscale , RGB or RGBA images. "
154+ "px.imshow only accepts 2D single-channel , RGB or RGBA images. "
131155 "An image of shape %s was provided" % str (img .shape )
132156 )
133157 fig = go .Figure (data = trace , layout = layout )
0 commit comments