22import numpy as np # type: ignore
33import iklayout # type: ignore
44import matplotlib .pyplot as plt # type: ignore
5- from ipywidgets import interactive , IntSlider # type: ignore
5+ import plotly . graph_objects as go # type: ignore
66from typing import List , Optional , Tuple , Dict , Set
77
88from . import Parameter , StatementDictionary , StatementValidationDictionary , StatementValidation , Computation
@@ -106,11 +106,10 @@ def plot_interactive_spectra(
106106 spectra : List [List [List [float ]]],
107107 wavelengths : List [float ],
108108 spectrum_labels : Optional [List [str ]] = None ,
109- slider_index : Optional [List [int ]] = None ,
110109 vlines : Optional [List [float ]] = None ,
111110 hlines : Optional [List [float ]] = None ,
112111):
113- """
112+ """"
114113 Creates an interactive plot of spectra with a slider to select different indices.
115114 Parameters:
116115 -----------
@@ -119,63 +118,118 @@ def plot_interactive_spectra(
119118 corresponding to the transmission of a single wavelength.
120119 wavelengths : list of float
121120 A list of wavelength values corresponding to the x-axis of the plot.
122- slider_index : list of int, optional
123- A list of indices for the slider. Defaults to range(len(spectra[0])).
124121 vlines : list of float, optional
125122 A list of x-values where vertical lines should be drawn. Defaults to an empty list.
126123 hlines : list of float, optional
127124 A list of y-values where horizontal lines should be drawn. Defaults to an empty list.
128- Returns:
129- --------
130- ipywidgets.widgets.interaction.interactive
131- An interactive widget that allows the user to select different indices using a slider.
132- Notes:
133- ------
134- - The function uses matplotlib for plotting and ipywidgets for creating the interactive
135- slider.
136- - The y-axis limits are fixed based on the global minimum and maximum values across all
137- spectra.
138- - Vertical and horizontal lines can be added to the plot using the `vlines` and `hlines`
139- parameters.
140125 """
141- # Calculate global y-limits across all arrays
142- y_min = min (min (min (arr2 ) for arr2 in arr1 ) for arr1 in spectra )
143- y_max = max (max (max (arr2 ) for arr2 in arr1 ) for arr1 in spectra )
144- if hlines :
145- y_min = min (hlines + [y_min ])* 0.95
146- y_max = max (hlines + [y_max ])* 1.05
147-
148- slider_index = slider_index or list (range (len (spectra [0 ])))
149- spectrum_labels = spectrum_labels or [f"Spectrum { i } " for i in range (len (spectra ))]
150- vlines = vlines or []
151- hlines = hlines or []
152126
153- # Function to update the plot
154- def plot_array (index = 0 ):
155- plt .close ("all" )
156- plt .figure (figsize = (8 , 4 ))
157- for i , array in enumerate (spectra ):
158- plt .plot (wavelengths , array [index ], lw = 2 , label = spectrum_labels [i ])
159- for x_val in vlines :
160- plt .axvline (
161- x = x_val , color = "red" , linestyle = "--" , label = f"Wavelength (x={ x_val } )"
162- ) # Add vertical line
163- for y_val in hlines :
164- plt .axhline (
165- y = y_val , color = "red" , linestyle = "--" , label = f"Transmission (y={ y_val } )"
166- ) # Add vertical line
167- plt .title (f"Iteration: { index } " )
168- plt .xlabel ("X" )
169- plt .ylabel ("Y" )
170- plt .ylim (y_min , y_max ) # Fix the y-limits
171- plt .legend ()
172- plt .grid (True )
173- plt .show ()
127+ # Defaults
128+ if spectrum_labels is None :
129+ spectrum_labels = [f"Spectrum { i } " for i in range (len (spectra ))]
130+ if vlines is None :
131+ vlines = []
132+ if hlines is None :
133+ hlines = []
134+
135+ # Adjust y-axis range
136+ all_vals = [val for spec in spectra for iteration in spec for val in iteration ]
137+ y_min = min (all_vals )
138+ y_max = max (all_vals )
139+ if hlines :
140+ y_min = min (hlines + [y_min ]) * 0.95
141+ y_max = max (hlines + [y_max ]) * 1.05
142+
143+ # Create hlines and vlines
144+ shapes = []
145+ for xv in vlines :
146+ shapes .append (dict (
147+ type = "line" ,
148+ xref = "x" , x0 = xv , x1 = xv ,
149+ yref = "paper" , y0 = 0 , y1 = 1 ,
150+ line = dict (color = "red" , dash = "dash" )
151+ ))
152+ for yh in hlines :
153+ shapes .append (dict (
154+ type = "line" ,
155+ xref = "paper" , x0 = 0 , x1 = 1 ,
156+ yref = "y" , y0 = yh , y1 = yh ,
157+ line = dict (color = "red" , dash = "dash" )
158+ ))
159+
160+
161+ # Create frames for each index
162+ slider_index = list (range (len (spectra [0 ])))
163+ fig = go .Figure ()
164+
165+ # Build initial figure for immediate display
166+ init_idx = slider_index [0 ]
167+ for i , spec in enumerate (spectra ):
168+ fig .add_trace (
169+ go .Scatter (
170+ x = wavelengths ,
171+ y = spec [init_idx ],
172+ mode = "lines" ,
173+ name = spectrum_labels [i ]
174+ )
175+ )
176+ # Build frames for animation
177+ frames = []
178+ for idx in slider_index :
179+ frame_data = []
180+ for i , spec in enumerate (spectra ):
181+ frame_data .append (
182+ go .Scatter (
183+ x = wavelengths ,
184+ y = spec [idx ],
185+ mode = "lines" ,
186+ name = spectrum_labels [i ]
187+ )
188+ )
189+ frames .append (
190+ go .Frame (
191+ data = frame_data ,
192+ name = str (idx ),
193+ )
194+ )
174195
175- slider = IntSlider (
176- value = 0 , min = 0 , max = len (spectra [0 ]) - 1 , step = 1 , description = "Index"
196+ fig .frames = frames
197+
198+
199+ # Create transition steps
200+ steps = []
201+ for idx in slider_index :
202+ steps .append (dict (
203+ method = "animate" ,
204+ args = [
205+ [str (idx )],
206+ {
207+ "mode" : "immediate" ,
208+ "frame" : {"duration" : 0 , "redraw" : True },
209+ "transition" : {"duration" : 0 }
210+ }
211+ ],
212+ label = str (idx ),
213+ ))
214+
215+ # Create the slider
216+ sliders = [dict (
217+ active = 0 ,
218+ currentvalue = {"prefix" : "Index: " },
219+ pad = {"t" : 50 },
220+ steps = steps
221+ )]
222+
223+ # Create the layout
224+ fig .update_layout (
225+ xaxis_title = "Wavelength" ,
226+ yaxis_title = "Transmission" ,
227+ shapes = shapes ,
228+ sliders = sliders ,
229+ yaxis = dict (range = [y_min , y_max ]),
177230 )
178- return interactive (plot_array , index = slider )
231+
232+ fig .show ()
179233
180234
181235def plot_parameter_history (parameters : List [Parameter ], parameter_history : List [dict ]):
0 commit comments