Skip to content

Commit f38a6a5

Browse files
authored
Merge pull request #25 from Axiomatic-AI/persistent-interactive-plot
Persistent interactive plots
2 parents c976789 + aa25517 commit f38a6a5

File tree

1 file changed

+106
-52
lines changed

1 file changed

+106
-52
lines changed

src/axiomatic/pic_helpers.py

Lines changed: 106 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np # type: ignore
33
import iklayout # type: ignore
44
import matplotlib.pyplot as plt # type: ignore
5-
from ipywidgets import interactive, IntSlider # type: ignore
5+
import plotly.graph_objects as go # type: ignore
66
from typing import List, Optional, Tuple, Dict, Set
77

88
from . import Parameter, StatementDictionary, StatementValidationDictionary, StatementValidation, Computation
@@ -106,11 +106,10 @@ def plot_interactive_spectra(
106106
spectra: List[List[List[float]]],
107107
wavelengths: List[float],
108108
spectrum_labels: Optional[List[str]] = None,
109-
slider_index: Optional[List[int]] = None,
110109
vlines: Optional[List[float]] = None,
111110
hlines: Optional[List[float]] = None,
112111
):
113-
"""
112+
""""
114113
Creates an interactive plot of spectra with a slider to select different indices.
115114
Parameters:
116115
-----------
@@ -119,63 +118,118 @@ def plot_interactive_spectra(
119118
corresponding to the transmission of a single wavelength.
120119
wavelengths : list of float
121120
A list of wavelength values corresponding to the x-axis of the plot.
122-
slider_index : list of int, optional
123-
A list of indices for the slider. Defaults to range(len(spectra[0])).
124121
vlines : list of float, optional
125122
A list of x-values where vertical lines should be drawn. Defaults to an empty list.
126123
hlines : list of float, optional
127124
A list of y-values where horizontal lines should be drawn. Defaults to an empty list.
128-
Returns:
129-
--------
130-
ipywidgets.widgets.interaction.interactive
131-
An interactive widget that allows the user to select different indices using a slider.
132-
Notes:
133-
------
134-
- The function uses matplotlib for plotting and ipywidgets for creating the interactive
135-
slider.
136-
- The y-axis limits are fixed based on the global minimum and maximum values across all
137-
spectra.
138-
- Vertical and horizontal lines can be added to the plot using the `vlines` and `hlines`
139-
parameters.
140125
"""
141-
# Calculate global y-limits across all arrays
142-
y_min = min(min(min(arr2) for arr2 in arr1) for arr1 in spectra)
143-
y_max = max(max(max(arr2) for arr2 in arr1) for arr1 in spectra)
144-
if hlines:
145-
y_min = min(hlines + [y_min])*0.95
146-
y_max = max(hlines + [y_max])*1.05
147-
148-
slider_index = slider_index or list(range(len(spectra[0])))
149-
spectrum_labels = spectrum_labels or [f"Spectrum {i}" for i in range(len(spectra))]
150-
vlines = vlines or []
151-
hlines = hlines or []
152126

153-
# Function to update the plot
154-
def plot_array(index=0):
155-
plt.close("all")
156-
plt.figure(figsize=(8, 4))
157-
for i, array in enumerate(spectra):
158-
plt.plot(wavelengths, array[index], lw=2, label=spectrum_labels[i])
159-
for x_val in vlines:
160-
plt.axvline(
161-
x=x_val, color="red", linestyle="--", label=f"Wavelength (x={x_val})"
162-
) # Add vertical line
163-
for y_val in hlines:
164-
plt.axhline(
165-
y=y_val, color="red", linestyle="--", label=f"Transmission (y={y_val})"
166-
) # Add vertical line
167-
plt.title(f"Iteration: {index}")
168-
plt.xlabel("X")
169-
plt.ylabel("Y")
170-
plt.ylim(y_min, y_max) # Fix the y-limits
171-
plt.legend()
172-
plt.grid(True)
173-
plt.show()
127+
# Defaults
128+
if spectrum_labels is None:
129+
spectrum_labels = [f"Spectrum {i}" for i in range(len(spectra))]
130+
if vlines is None:
131+
vlines = []
132+
if hlines is None:
133+
hlines = []
134+
135+
# Adjust y-axis range
136+
all_vals = [val for spec in spectra for iteration in spec for val in iteration]
137+
y_min = min(all_vals)
138+
y_max = max(all_vals)
139+
if hlines:
140+
y_min = min(hlines + [y_min]) * 0.95
141+
y_max = max(hlines + [y_max]) * 1.05
142+
143+
# Create hlines and vlines
144+
shapes = []
145+
for xv in vlines:
146+
shapes.append(dict(
147+
type="line",
148+
xref="x", x0=xv, x1=xv,
149+
yref="paper", y0=0, y1=1,
150+
line=dict(color="red", dash="dash")
151+
))
152+
for yh in hlines:
153+
shapes.append(dict(
154+
type="line",
155+
xref="paper", x0=0, x1=1,
156+
yref="y", y0=yh, y1=yh,
157+
line=dict(color="red", dash="dash")
158+
))
159+
160+
161+
# Create frames for each index
162+
slider_index = list(range(len(spectra[0])))
163+
fig = go.Figure()
164+
165+
# Build initial figure for immediate display
166+
init_idx = slider_index[0]
167+
for i, spec in enumerate(spectra):
168+
fig.add_trace(
169+
go.Scatter(
170+
x=wavelengths,
171+
y=spec[init_idx],
172+
mode="lines",
173+
name=spectrum_labels[i]
174+
)
175+
)
176+
# Build frames for animation
177+
frames = []
178+
for idx in slider_index:
179+
frame_data = []
180+
for i, spec in enumerate(spectra):
181+
frame_data.append(
182+
go.Scatter(
183+
x=wavelengths,
184+
y=spec[idx],
185+
mode="lines",
186+
name=spectrum_labels[i]
187+
)
188+
)
189+
frames.append(
190+
go.Frame(
191+
data=frame_data,
192+
name=str(idx),
193+
)
194+
)
174195

175-
slider = IntSlider(
176-
value=0, min=0, max=len(spectra[0]) - 1, step=1, description="Index"
196+
fig.frames = frames
197+
198+
199+
# Create transition steps
200+
steps = []
201+
for idx in slider_index:
202+
steps.append(dict(
203+
method="animate",
204+
args=[
205+
[str(idx)],
206+
{
207+
"mode": "immediate",
208+
"frame": {"duration": 0, "redraw": True},
209+
"transition": {"duration": 0}
210+
}
211+
],
212+
label=str(idx),
213+
))
214+
215+
# Create the slider
216+
sliders = [dict(
217+
active=0,
218+
currentvalue={"prefix": "Index: "},
219+
pad={"t": 50},
220+
steps=steps
221+
)]
222+
223+
# Create the layout
224+
fig.update_layout(
225+
xaxis_title="Wavelength",
226+
yaxis_title="Transmission",
227+
shapes=shapes,
228+
sliders=sliders,
229+
yaxis=dict(range=[y_min, y_max]),
177230
)
178-
return interactive(plot_array, index=slider)
231+
232+
fig.show()
179233

180234

181235
def plot_parameter_history(parameters: List[Parameter], parameter_history: List[dict]):

0 commit comments

Comments
 (0)