@@ -364,7 +364,6 @@ def imshow(
364364 args ["animation_frame" ] = "plane"
365365 slice_through = True
366366
367- print ("slice_through" , slice_through )
368367 # Default behaviour of binary_string: True for RGB images, False for 2D
369368 if binary_string is None :
370369 if slice_through :
@@ -382,7 +381,11 @@ def imshow(
382381
383382 # -------- Contrast rescaling: either minmax or infer ------------------
384383 if contrast_rescaling is None :
385- contrast_rescaling = "minmax" if img .ndim == 2 else "infer"
384+ contrast_rescaling = (
385+ "minmax"
386+ if (img .ndim == 2 or (img .ndim == 3 and slice_through ))
387+ else "infer"
388+ )
386389
387390 # We try to set zmin and zmax only if necessary, because traces have good defaults
388391 if contrast_rescaling == "minmax" :
@@ -436,10 +439,8 @@ def imshow(
436439
437440 # For 2D+RGB data, use Image trace
438441 elif (
439- img .ndim == 3
440- and (img .shape [- 1 ] in [3 , 4 ] or (slice_through and binary_string ))
441- or (img .ndim == 2 and binary_string )
442- ):
442+ img .ndim >= 3 and (img .shape [- 1 ] in [3 , 4 ] or slice_through and binary_string )
443+ ) or (img .ndim == 2 and binary_string ):
443444 rescale_image = True # to check whether image has been modified
444445 if zmin is not None and zmax is not None :
445446 zmin , zmax = (
@@ -455,15 +456,16 @@ def imshow(
455456 img , in_range = (zmin [0 ], zmax [0 ]), out_range = np .uint8
456457 )
457458 else :
458- img_rescaled = np .dstack (
459+ img_rescaled = np .stack (
459460 [
460461 rescale_intensity (
461462 img [..., ch ],
462463 in_range = (zmin [ch ], zmax [ch ]),
463464 out_range = np .uint8 ,
464465 )
465466 for ch in range (img .shape [- 1 ])
466- ]
467+ ],
468+ axis = - 1 ,
467469 )
468470 if slice_through :
469471 img_str = [
@@ -485,10 +487,19 @@ def imshow(
485487 ext = binary_format ,
486488 )
487489 ]
488- traces = [go .Image (source = img_str_slice , name = str (i )) for i , img_str_slice in enumerate (img_str )]
490+ traces = [
491+ go .Image (source = img_str_slice , name = str (i ))
492+ for i , img_str_slice in enumerate (img_str )
493+ ]
489494 else :
490495 colormodel = "rgb" if img .shape [- 1 ] == 3 else "rgba256"
491- traces = [go .Image (z = img , zmin = zmin , zmax = zmax , colormodel = colormodel )]
496+ if slice_through :
497+ traces = [
498+ go .Image (z = img_slice , zmin = zmin , zmax = zmax , colormodel = colormodel )
499+ for img_slice in img
500+ ]
501+ else :
502+ traces = [go .Image (z = img , zmin = zmin , zmax = zmax , colormodel = colormodel )]
492503 layout = {}
493504 if origin == "lower" :
494505 layout ["yaxis" ] = dict (autorange = True )
@@ -546,5 +557,5 @@ def imshow(
546557 if labels ["y" ]:
547558 fig .update_yaxes (title_text = labels ["y" ])
548559 configure_animation_controls (args , go .Image , fig )
549- #fig.update_layout(template=args["template"], overwrite=True)
560+ # fig.update_layout(template=args["template"], overwrite=True)
550561 return fig
0 commit comments