Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions src/grid_strategy/_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class GridStrategy(metaclass=ABCMeta):
def __init__(self, alignment="center"):
self.alignment = alignment

def get_grid(self, n):
def get_grid(self, n, figure=None):
"""
Return a list of axes designed according to the strategy.
Grid arrangements are tuples with the same length as the number of rows,
Expand All @@ -28,32 +28,34 @@ def get_grid(self, n):
x x x
x x
where each x would be a subplot.

If `figure` is None, creates a new figure.
"""

grid_arrangement = self.get_grid_arrangement(n)
return self.get_gridspec(grid_arrangement)
return self.get_gridspec(grid_arrangement, figure)

@classmethod
@abstractmethod
def get_grid_arrangement(cls, n): # pragma: nocover
pass

def get_gridspec(self, grid_arrangement):
def get_gridspec(self, grid_arrangement, figure=None):
nrows = len(grid_arrangement)
ncols = max(grid_arrangement)

# If it has justified alignment, will not be the same as the other alignments
if self.alignment == "justified":
return self._justified(nrows, grid_arrangement)
return self._justified(nrows, grid_arrangement, figure)
else:
return self._ragged(nrows, ncols, grid_arrangement)
return self._ragged(nrows, ncols, grid_arrangement, figure)

def _justified(self, nrows, grid_arrangement):
def _justified(self, nrows, grid_arrangement, figure=None):
ax_specs = []
num_small_cols = np.lcm.reduce(grid_arrangement)
gs = gridspec.GridSpec(
nrows, num_small_cols, figure=plt.figure(constrained_layout=True)
)
if figure is None:
figure = plt.figure(constrained_layout=True)
gs = gridspec.GridSpec(nrows, num_small_cols, figure=figure)
for r, row_cols in enumerate(grid_arrangement):
skip = num_small_cols // row_cols
for col in range(row_cols):
Expand All @@ -63,15 +65,15 @@ def _justified(self, nrows, grid_arrangement):
ax_specs.append(gs[r, s:e])
return ax_specs

def _ragged(self, nrows, ncols, grid_arrangement):
def _ragged(self, nrows, ncols, grid_arrangement, figure=None):
if len(set(grid_arrangement)) > 1:
col_width = 2
else:
col_width = 1

gs = gridspec.GridSpec(
nrows, ncols * col_width, figure=plt.figure(constrained_layout=True)
)
if figure is None:
figure = plt.figure(constrained_layout=True)
gs = gridspec.GridSpec(nrows, ncols * col_width, figure=figure)

ax_specs = []
for r, row_cols in enumerate(grid_arrangement):
Expand Down