@@ -194,6 +194,40 @@ def test_imshow_xarray_slicethrough():
194194 assert np .all (np .array (fig .data [0 ].x ) == np .array (da .coords ["dim_2" ]))
195195
196196
197+ def test_imshow_xarray_facet_col_string ():
198+ img = np .random .random ((3 , 4 , 5 ))
199+ da = xr .DataArray (
200+ img , dims = ["str_dim" , "dim_1" , "dim_2" ], coords = {"str_dim" : ["A" , "B" , "C" ]}
201+ )
202+ fig = px .imshow (da , facet_col = "str_dim" )
203+ # Dimensions are used for axis labels and coordinates
204+ assert fig .layout .xaxis .title .text == "dim_2"
205+ assert fig .layout .yaxis .title .text == "dim_1"
206+ assert np .all (np .array (fig .data [0 ].x ) == np .array (da .coords ["dim_2" ]))
207+
208+
209+ def test_imshow_xarray_animation_frame_string ():
210+ img = np .random .random ((3 , 4 , 5 ))
211+ da = xr .DataArray (
212+ img , dims = ["str_dim" , "dim_1" , "dim_2" ], coords = {"str_dim" : ["A" , "B" , "C" ]}
213+ )
214+ fig = px .imshow (da , animation_frame = "str_dim" )
215+ # Dimensions are used for axis labels and coordinates
216+ assert fig .layout .xaxis .title .text == "dim_2"
217+ assert fig .layout .yaxis .title .text == "dim_1"
218+ assert np .all (np .array (fig .data [0 ].x ) == np .array (da .coords ["dim_2" ]))
219+
220+
221+ def test_imshow_xarray_animation_facet_slicethrough ():
222+ img = np .random .random ((3 , 4 , 5 , 6 ))
223+ da = xr .DataArray (img , dims = ["dim_0" , "dim_1" , "dim_2" , "dim_3" ])
224+ fig = px .imshow (da , facet_col = "dim_0" , animation_frame = "dim_1" )
225+ # Dimensions are used for axis labels and coordinates
226+ assert fig .layout .xaxis .title .text == "dim_3"
227+ assert fig .layout .yaxis .title .text == "dim_2"
228+ assert np .all (np .array (fig .data [0 ].x ) == np .array (da .coords ["dim_3" ]))
229+
230+
197231def test_imshow_labels_and_ranges ():
198232 fig = px .imshow (
199233 [[1 , 2 ], [3 , 4 ], [5 , 6 ]],
0 commit comments