@@ -290,14 +290,20 @@ def imshow(
290290 labels = labels .copy ()
291291 col_labels = []
292292 if facet_col is not None :
293+ if isinstance (facet_col , str ):
294+ facet_col = img .dims .index (facet_col )
293295 nslices = img .shape [facet_col ]
294296 ncols = int (facet_col_wrap ) if facet_col_wrap is not None else nslices
295297 nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols
296298 col_labels = ["plane = %d" % i for i in range (nslices )]
297299 else :
298300 nrows = 1
299301 ncols = 1
302+ if animation_frame is not None :
303+ if isinstance (animation_frame , str ):
304+ animation_frame = img .dims .index (animation_frame )
300305 slice_through = (facet_col is not None ) or (animation_frame is not None )
306+ plane_label = None
301307 fig = init_figure (args , "xy" , [], nrows , ncols , col_labels , [])
302308 # ----- Define x and y, set labels if img is an xarray -------------------
303309 if xarray_imported and isinstance (img , xarray .DataArray ):
@@ -307,7 +313,14 @@ def imshow(
307313 # "Please pass your data as a numpy array instead using"
308314 # "`img.values`"
309315 # )
310- y_label , x_label = img .dims [0 ], img .dims [1 ]
316+ dims = list (img .dims )
317+ print (dims )
318+ if slice_through :
319+ slice_index = facet_col if facet_col is not None else animation_frame
320+ _ = dims .pop (slice_index )
321+ plane_label = img .dims [slice_index ]
322+ y_label , x_label = dims [0 ], dims [1 ]
323+ print (y_label , x_label )
311324 # np.datetime64 is not handled correctly by go.Heatmap
312325 for ax in [x_label , y_label ]:
313326 if np .issubdtype (img .coords [ax ].dtype , np .datetime64 ):
@@ -322,6 +335,8 @@ def imshow(
322335 labels ["x" ] = x_label
323336 if labels .get ("y" , None ) is None :
324337 labels ["y" ] = y_label
338+ if labels .get ("plane" , None ) is None :
339+ labels ["plane" ] = plane_label
325340 if labels .get ("color" , None ) is None :
326341 labels ["color" ] = xarray .plot .utils .label_from_attrs (img )
327342 labels ["color" ] = labels ["color" ].replace ("\n " , "<br>" )
@@ -362,7 +377,9 @@ def imshow(
362377 if animation_frame is not None :
363378 img = np .moveaxis (img , animation_frame , 0 )
364379 animation_frame = True
365- args ["animation_frame" ] = "plane"
380+ args ["animation_frame" ] = (
381+ "plane" if labels .get ("plane" ) is None else labels ["plane" ]
382+ )
366383
367384 # Default behaviour of binary_string: True for RGB images, False for 2D
368385 if binary_string is None :
@@ -403,12 +420,14 @@ def imshow(
403420
404421 # For 2d data, use Heatmap trace, unless binary_string is True
405422 if (img .ndim == 2 or (img .ndim == 3 and slice_through )) and not binary_string :
406- if y is not None and img .shape [0 ] != len (y ):
423+ y_index = 1 if slice_through else 0
424+ if y is not None and img .shape [y_index ] != len (y ):
407425 raise ValueError (
408426 "The length of the y vector must match the length of the first "
409427 + "dimension of the img matrix."
410428 )
411- if x is not None and img .shape [1 ] != len (x ):
429+ x_index = 2 if slice_through else 1
430+ if x is not None and img .shape [x_index ] != len (x ):
412431 raise ValueError (
413432 "The length of the x vector must match the length of the second "
414433 + "dimension of the img matrix."
0 commit comments