Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 120 additions & 81 deletions xdggs/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,16 @@
import ipywidgets
import numpy as np
import xarray as xr
from lonboard import BaseLayer, Map
from lonboard import BaseLayer
from lonboard import Map as LonboardMap


@dataclass
class Container:
obj: xr.DataArray
colorize_kwargs: dict[str, Any]
layer: BaseLayer
dimension_sliders: list[ipywidgets.IntSlider]


def on_slider_change(change, container):
Expand All @@ -16,89 +25,55 @@ def on_slider_change(change, container):

indexers = {
slider.description: slider.value
for slider in container.dimension_sliders.children
for slider in container.dimension_sliders
if slider.description != dim
} | {dim: change["new"]}
new_slice = container.obj.isel(indexers)

new_slice = container.obj.isel(indexers)
colors = colorize(new_slice.variable, **container.colorize_kwargs)

layer = container.map.layers[0]
layer = container.layer
layer.get_fill_color = colors


@dataclass
class MapContainer:
"""container for the map, any control widgets and the data object"""

dimension_sliders: ipywidgets.VBox
map: Map
obj: xr.DataArray

colorize_kwargs: dict[str, Any]
def render_map(
map_: Map, dimension_sliders: list[ipywidgets.IntSlider]
) -> Map | MapWithSliders:
if not dimension_sliders:
return map_

def render(self):
# add any additional control widgets here
control_box = ipywidgets.HBox([self.dimension_sliders])
slider_box = ipywidgets.VBox(dimension_sliders)
control_box = ipywidgets.HBox([slider_box])

return MapWithSliders(
[self.map, control_box], layout=ipywidgets.Layout(width="100%")
)
return MapWithSliders([map_, control_box], layout=ipywidgets.Layout(width="100%"))


def extract_maps(obj: MapGrid | MapWithSliders | Map):
if isinstance(obj, Map):
return obj
def extract_maps(obj: MapGrid | MapWithSliders | Map | LonboardMap):
if isinstance(obj, (Map, LonboardMap)):
return [obj]

return getattr(obj, "maps", (obj.map,))


class MapGrid(ipywidgets.GridBox):
def __init__(
self,
maps: MapWithSliders | Map = None,
n_columns: int = 2,
synchronize: bool = False,
):
self.n_columns = n_columns
self.synchronize = synchronize
class Map(LonboardMap):
def __or__(self, other: Map | MapWithSliders):
if isinstance(other, MapGrid):
return NotImplemented

column_width = 100 // n_columns
layout = ipywidgets.Layout(
width="100%", grid_template_columns=f"repeat({n_columns}, {column_width}%)"
)
return MapGrid([self, other])

if maps is None:
maps = []

if synchronize and maps:
all_maps = [getattr(m, "map", m) for m in maps]

first = all_maps[0]
for second in all_maps[1:]:
ipywidgets.jslink((first, "view_state"), (second, "view_state"))

super().__init__(maps, layout=layout)

def _replace_maps(self, maps):
return type(self)(maps, n_columns=self.n_columns, synchronize=self.synchronize)

def add_map(self, map_: MapWithSliders | Map):
return self._replace_maps(self.maps + (map_,))
def __and__(self, other):
if isinstance(other, (MapWithSliders, MapGrid)):
return NotImplemented

@property
def maps(self):
return self.children

def __or__(self, other: MapGrid | MapWithSliders | Map):
other_maps = extract_maps(other)

return self._replace_maps(self.maps + other_maps)
if isinstance(other, BaseLayer):
other_layers = [other]
else:
other_layers = list(other.layers)

def __ror__(self, other: MapWithSliders | Map):
other_maps = extract_maps(other)
layers = list(self.layers) + list(other_layers)

return self._replace_maps(self.maps + other_maps)
return type(self)(layers)


class MapWithSliders(ipywidgets.VBox):
Expand All @@ -122,6 +97,11 @@ def __or__(self, other: MapWithSliders | Map):

return MapGrid([self, other], synchronize=True)

def __ror__(self, other: Map):
[other_map] = extract_maps(other)

return MapGrid([other_map, self], synchronize=True)

def _merge(self, layers, sliders):
all_layers = list(self.map.layers) + list(layers)
new_map = Map(all_layers)
Expand All @@ -142,6 +122,9 @@ def add_layer(self, layer: BaseLayer):
self.map.add_layer(layer)

def __and__(self, other: MapWithSliders | Map | BaseLayer):
if isinstance(other, MapGrid):
return NotImplemented

if isinstance(other, BaseLayer):
layers = [other]
sliders = []
Expand All @@ -151,6 +134,63 @@ def __and__(self, other: MapWithSliders | Map | BaseLayer):

return self._merge(layers, sliders)

def __rand__(self, other: Map | BaseLayer):
return self & other


class MapGrid(ipywidgets.GridBox):
def __init__(
self,
maps: MapWithSliders | Map = None,
n_columns: int = 2,
synchronize: bool = False,
):
self.n_columns = n_columns
self.synchronize = synchronize

column_width = 100 // n_columns
layout = ipywidgets.Layout(
width="100%", grid_template_columns=f"repeat({n_columns}, {column_width}%)"
)

if maps is None:
maps = []

super().__init__(maps, layout=layout)

if synchronize and maps:
self.synchronize_maps()

def _replace_maps(self, maps):
return type(self)(maps, n_columns=self.n_columns, synchronize=self.synchronize)

def add_map(self, map_: MapWithSliders | Map):
return self._replace_maps(self.maps + (map_,))

@property
def maps(self):
return self.children

def synchronize_maps(self):
if not self.maps:
raise ValueError("no maps to synchronize found")

all_maps = [getattr(m, "map", m) for m in self.maps]

first = all_maps[0]
for second in all_maps[1:]:
ipywidgets.jslink((first, "view_state"), (second, "view_state"))

def __or__(self, other: MapGrid | MapWithSliders | Map):
other_maps = extract_maps(other)

return self._replace_maps(self.maps + other_maps)

def __ror__(self, other: MapWithSliders | Map):
other_maps = extract_maps(other)

return self._replace_maps(self.maps + other_maps)


def create_arrow_table(polygons, arr, coords=None):
from arro3.core import Array, ChunkedArray, Schema, Table
Expand Down Expand Up @@ -205,7 +245,6 @@ def explore(
alpha=None,
coords=None,
):
import lonboard
from lonboard import SolidPolygonLayer
from matplotlib import colormaps

Expand All @@ -227,29 +266,29 @@ def explore(
table = create_arrow_table(polygons, initial_arr, coords=coords)
layer = SolidPolygonLayer(table=table, filled=True, get_fill_color=colors)

map_ = lonboard.Map(layer)
map_ = LonboardMap(layer)

if not initial_indexers:
# 1D data
return map_
sliders = [
ipywidgets.IntSlider(min=0, max=arr.sizes[dim] - 1, description=dim)
for dim in arr.dims
if dim != cell_dim
]

sliders = ipywidgets.VBox(
[
ipywidgets.IntSlider(min=0, max=arr.sizes[dim] - 1, description=dim)
for dim in arr.dims
if dim != cell_dim
]
)
map_object = render_map(map_, sliders)

container = MapContainer(
sliders,
map_,
container = Container(
arr,
colorize_kwargs={"alpha": alpha, "center": center, "colormap": colormap},
colorize_kwargs={
"alpha": alpha,
"center": center,
"colormap": colormap,
},
layer=layer,
dimension_sliders=sliders,
)

# connect slider with map
for slider in sliders.children:
for slider in sliders:
slider.observe(partial(on_slider_change, container=container), names="value")

return container.render()
return map_object
Loading