11import plotly .graph_objs as go
22from _plotly_utils .basevalidators import ColorscaleValidator
3- from ._core import apply_default_cascade , init_figure
3+ from ._core import apply_default_cascade , init_figure , configure_animation_controls
44from io import BytesIO
55import base64
66from .imshow_utils import rescale_intensity , _integer_ranges , _integer_types
@@ -133,7 +133,7 @@ def imshow(
133133 labels = {},
134134 x = None ,
135135 y = None ,
136- animation_frame = False ,
136+ animation_frame = None ,
137137 facet_col = None ,
138138 facet_col_wrap = None ,
139139 color_continuous_scale = None ,
@@ -353,13 +353,21 @@ def imshow(
353353
354354 # --------------- Starting from here img is always a numpy array --------
355355 img = np .asanyarray (img )
356+ slice_through = False
356357 if facet_col is not None :
357358 img = np .moveaxis (img , facet_col , 0 )
358359 facet_col = True
359-
360+ slice_through = True
361+ if animation_frame is not None :
362+ img = np .moveaxis (img , animation_frame , 0 )
363+ animation_frame = True
364+ args ["animation_frame" ] = "plane"
365+ slice_through = True
366+
367+ print ("slice_through" , slice_through )
360368 # Default behaviour of binary_string: True for RGB images, False for 2D
361369 if binary_string is None :
362- if facet_col :
370+ if slice_through :
363371 binary_string = img .ndim >= 4 and not is_dataframe
364372 else :
365373 binary_string = img .ndim >= 3 and not is_dataframe
@@ -391,7 +399,7 @@ def imshow(
391399 zmin = 0
392400
393401 # For 2d data, use Heatmap trace, unless binary_string is True
394- if (img .ndim == 2 or (img .ndim == 3 and facet_col )) and not binary_string :
402+ if (img .ndim == 2 or (img .ndim == 3 and slice_through )) and not binary_string :
395403 if y is not None and img .shape [0 ] != len (y ):
396404 raise ValueError (
397405 "The length of the y vector must match the length of the first "
@@ -402,10 +410,10 @@ def imshow(
402410 "The length of the x vector must match the length of the second "
403411 + "dimension of the img matrix."
404412 )
405- if facet_col :
413+ if slice_through :
406414 traces = [
407- go .Heatmap (x = x , y = y , z = img_slice , coloraxis = "coloraxis1" )
408- for img_slice in img
415+ go .Heatmap (x = x , y = y , z = img_slice , coloraxis = "coloraxis1" , name = str ( i ) )
416+ for i , img_slice in enumerate ( img )
409417 ]
410418 else :
411419 traces = [go .Heatmap (x = x , y = y , z = img , coloraxis = "coloraxis1" )]
@@ -429,7 +437,7 @@ def imshow(
429437 # For 2D+RGB data, use Image trace
430438 elif (
431439 img .ndim == 3
432- and (img .shape [- 1 ] in [3 , 4 ] or (facet_col and binary_string ))
440+ and (img .shape [- 1 ] in [3 , 4 ] or (slice_through and binary_string ))
433441 or (img .ndim == 2 and binary_string )
434442 ):
435443 rescale_image = True # to check whether image has been modified
@@ -442,7 +450,7 @@ def imshow(
442450 if zmin is None and zmax is None : # no rescaling, faster
443451 img_rescaled = img
444452 rescale_image = False
445- elif img .ndim == 2 or (img .ndim == 3 and facet_col ):
453+ elif img .ndim == 2 or (img .ndim == 3 and slice_through ):
446454 img_rescaled = rescale_intensity (
447455 img , in_range = (zmin [0 ], zmax [0 ]), out_range = np .uint8
448456 )
@@ -457,7 +465,7 @@ def imshow(
457465 for ch in range (img .shape [- 1 ])
458466 ]
459467 )
460- if facet_col :
468+ if slice_through :
461469 img_str = [
462470 _array_to_b64str (
463471 img_rescaled_slice ,
@@ -477,7 +485,7 @@ def imshow(
477485 ext = binary_format ,
478486 )
479487 ]
480- traces = [go .Image (source = img_str_slice ) for img_str_slice in img_str ]
488+ traces = [go .Image (source = img_str_slice , name = str ( i )) for i , img_str_slice in enumerate ( img_str ) ]
481489 else :
482490 colormodel = "rgb" if img .shape [- 1 ] == 3 else "rgba256"
483491 traces = [go .Image (z = img , zmin = zmin , zmax = zmax , colormodel = colormodel )]
@@ -498,8 +506,15 @@ def imshow(
498506 layout_patch ["title_text" ] = args ["title" ]
499507 elif args ["template" ].layout .margin .t is None :
500508 layout_patch ["margin" ] = {"t" : 60 }
509+
510+ frame_list = []
501511 for index , trace in enumerate (traces ):
502- fig .add_trace (trace , row = nrows - index // ncols , col = index % ncols + 1 )
512+ if facet_col or index == 0 :
513+ fig .add_trace (trace , row = nrows - index // ncols , col = index % ncols + 1 )
514+ if animation_frame :
515+ frame_list .append (dict (data = trace , layout = layout , name = str (index )))
516+ if animation_frame :
517+ fig .frames = frame_list
503518 fig .update_layout (layout )
504519 fig .update_layout (layout_patch )
505520 # Hover name, z or color
@@ -530,5 +545,6 @@ def imshow(
530545 fig .update_xaxes (title_text = labels ["x" ])
531546 if labels ["y" ]:
532547 fig .update_yaxes (title_text = labels ["y" ])
533- fig .update_layout (template = args ["template" ], overwrite = True )
548+ configure_animation_controls (args , go .Image , fig )
549+ #fig.update_layout(template=args["template"], overwrite=True)
534550 return fig
0 commit comments