22import numpy as np # type: ignore
33import iklayout # type: ignore
44import matplotlib .pyplot as plt # type: ignore
5- import plotly .graph_objects as go # type: ignore
5+ import plotly .graph_objects as go # type: ignore
66from typing import List , Optional , Tuple , Dict , Set
77
88from . import Parameter , StatementDictionary , StatementValidationDictionary , StatementValidation , Computation
@@ -56,9 +56,7 @@ def plot_constraints(
5656 labels: List of labels for each constraint value.
5757 """
5858
59- constraints_labels = constraints_labels or [
60- f"Constraint { i } " for i in range (len (constraints [0 ]))
61- ]
59+ constraints_labels = constraints_labels or [f"Constraint { i } " for i in range (len (constraints [0 ]))]
6260 iterations = iterations or list (range (len (constraints [0 ])))
6361
6462 plt .clf ()
@@ -92,13 +90,9 @@ def plot_single_spectrum(
9290 plt .ylabel ("Losses" )
9391 plt .plot (wavelengths , spectrum )
9492 for x_val in vlines :
95- plt .axvline (
96- x = x_val , color = "red" , linestyle = "--" , label = f"Wavelength (x={ x_val } )"
97- ) # Add vertical line
93+ plt .axvline (x = x_val , color = "red" , linestyle = "--" , label = f"Wavelength (x={ x_val } )" ) # Add vertical line
9894 for y_val in hlines :
99- plt .axhline (
100- y = y_val , color = "red" , linestyle = "--" , label = f"Transmission (y={ y_val } )"
101- ) # Add vertical line
95+ plt .axhline (y = y_val , color = "red" , linestyle = "--" , label = f"Transmission (y={ y_val } )" ) # Add vertical line
10296 return plt .gcf ()
10397
10498
@@ -109,7 +103,7 @@ def plot_interactive_spectra(
109103 vlines : Optional [List [float ]] = None ,
110104 hlines : Optional [List [float ]] = None ,
111105):
112- """"
106+ """ "
113107 Creates an interactive plot of spectra with a slider to select different indices.
114108 Parameters:
115109 -----------
@@ -131,7 +125,7 @@ def plot_interactive_spectra(
131125 vlines = []
132126 if hlines is None :
133127 hlines = []
134-
128+
135129 # Adjust y-axis range
136130 all_vals = [val for spec in spectra for iteration in spec for val in iteration ]
137131 y_min = min (all_vals )
@@ -143,49 +137,28 @@ def plot_interactive_spectra(
143137 # Create hlines and vlines
144138 shapes = []
145139 for xv in vlines :
146- shapes .append (dict (
147- type = "line" ,
148- xref = "x" , x0 = xv , x1 = xv ,
149- yref = "paper" , y0 = 0 , y1 = 1 ,
150- line = dict (color = "red" , dash = "dash" )
151- ))
140+ shapes .append (
141+ dict (type = "line" , xref = "x" , x0 = xv , x1 = xv , yref = "paper" , y0 = 0 , y1 = 1 , line = dict (color = "red" , dash = "dash" ))
142+ )
152143 for yh in hlines :
153- shapes .append (dict (
154- type = "line" ,
155- xref = "paper" , x0 = 0 , x1 = 1 ,
156- yref = "y" , y0 = yh , y1 = yh ,
157- line = dict (color = "red" , dash = "dash" )
158- ))
159-
160-
144+ shapes .append (
145+ dict (type = "line" , xref = "paper" , x0 = 0 , x1 = 1 , yref = "y" , y0 = yh , y1 = yh , line = dict (color = "red" , dash = "dash" ))
146+ )
147+
161148 # Create frames for each index
162149 slider_index = list (range (len (spectra [0 ])))
163150 fig = go .Figure ()
164151
165152 # Build initial figure for immediate display
166153 init_idx = slider_index [0 ]
167154 for i , spec in enumerate (spectra ):
168- fig .add_trace (
169- go .Scatter (
170- x = wavelengths ,
171- y = spec [init_idx ],
172- mode = "lines" ,
173- name = spectrum_labels [i ]
174- )
175- )
155+ fig .add_trace (go .Scatter (x = wavelengths , y = spec [init_idx ], mode = "lines" , name = spectrum_labels [i ]))
176156 # Build frames for animation
177157 frames = []
178158 for idx in slider_index :
179159 frame_data = []
180160 for i , spec in enumerate (spectra ):
181- frame_data .append (
182- go .Scatter (
183- x = wavelengths ,
184- y = spec [idx ],
185- mode = "lines" ,
186- name = spectrum_labels [i ]
187- )
188- )
161+ frame_data .append (go .Scatter (x = wavelengths , y = spec [idx ], mode = "lines" , name = spectrum_labels [i ]))
189162 frames .append (
190163 go .Frame (
191164 data = frame_data ,
@@ -195,30 +168,22 @@ def plot_interactive_spectra(
195168
196169 fig .frames = frames
197170
198-
199171 # Create transition steps
200172 steps = []
201173 for idx in slider_index :
202- steps .append (dict (
203- method = "animate" ,
204- args = [
205- [str (idx )],
206- {
207- "mode" : "immediate" ,
208- "frame" : {"duration" : 0 , "redraw" : True },
209- "transition" : {"duration" : 0 }
210- }
211- ],
212- label = str (idx ),
213- ))
174+ steps .append (
175+ dict (
176+ method = "animate" ,
177+ args = [
178+ [str (idx )],
179+ {"mode" : "immediate" , "frame" : {"duration" : 0 , "redraw" : True }, "transition" : {"duration" : 0 }},
180+ ],
181+ label = str (idx ),
182+ )
183+ )
214184
215185 # Create the slider
216- sliders = [dict (
217- active = 0 ,
218- currentvalue = {"prefix" : "Index: " },
219- pad = {"t" : 50 },
220- steps = steps
221- )]
186+ sliders = [dict (active = 0 , currentvalue = {"prefix" : "Index: " }, pad = {"t" : 50 }, steps = steps )]
222187
223188 # Create the layout
224189 fig .update_layout (
@@ -253,25 +218,32 @@ def plot_parameter_history(parameters: List[Parameter], parameter_history: List[
253218 plt .xlabel ("Iterations" )
254219 plt .ylabel (param .path )
255220 split_param = param .path .split ("," )
256- plt .plot (
257- [
258- parameter_history [i ][split_param [0 ]][split_param [1 ]]
259- for i in range (len (parameter_history ))
260- ]
261- )
221+ if "," in param .path :
222+ split_param = param .path .split ("," )
223+ plt .plot ([parameter_history [i ][split_param [0 ]][split_param [1 ]] for i in range (len (parameter_history ))])
224+ else :
225+ plt .plot ([parameter_history [i ][param .path ] for i in range (len (parameter_history ))])
262226 plt .show ()
263227
264228
265- def print_statements (statements : StatementDictionary , validation : Optional [StatementValidationDictionary ] = None , only_formalized : bool = False ):
229+ def print_statements (
230+ statements : StatementDictionary ,
231+ validation : Optional [StatementValidationDictionary ] = None ,
232+ only_formalized : bool = False ,
233+ ):
266234 """
267235 Print a list of statements in nice readable format.
268236 """
269237
270238 validation = StatementValidationDictionary (
271- cost_functions = (validation .cost_functions if validation is not None else None ) or [StatementValidation ()]* len (statements .cost_functions or []),
272- parameter_constraints = (validation .parameter_constraints if validation is not None else None ) or [StatementValidation ()]* len (statements .parameter_constraints or []),
273- structure_constraints = (validation .structure_constraints if validation is not None else None ) or [StatementValidation ()]* len (statements .structure_constraints or []),
274- unformalizable_statements = (validation .unformalizable_statements if validation is not None else None ) or [StatementValidation ()]* len (statements .unformalizable_statements or [])
239+ cost_functions = (validation .cost_functions if validation is not None else None )
240+ or [StatementValidation ()] * len (statements .cost_functions or []),
241+ parameter_constraints = (validation .parameter_constraints if validation is not None else None )
242+ or [StatementValidation ()] * len (statements .parameter_constraints or []),
243+ structure_constraints = (validation .structure_constraints if validation is not None else None )
244+ or [StatementValidation ()] * len (statements .structure_constraints or []),
245+ unformalizable_statements = (validation .unformalizable_statements if validation is not None else None )
246+ or [StatementValidation ()] * len (statements .unformalizable_statements or []),
275247 )
276248
277249 if len (validation .cost_functions or []) != len (statements .cost_functions or []):
@@ -299,8 +271,7 @@ def print_statements(statements: StatementDictionary, validation: Optional[State
299271 if computation is not None :
300272 args_str = ", " .join (
301273 [
302- f"{ argname } ="
303- + (f"'{ argvalue } '" if isinstance (argvalue , str ) else str (argvalue ))
274+ f"{ argname } =" + (f"'{ argvalue } '" if isinstance (argvalue , str ) else str (argvalue ))
304275 for argname , argvalue in computation .arguments .items ()
305276 ]
306277 )
@@ -326,8 +297,7 @@ def print_statements(statements: StatementDictionary, validation: Optional[State
326297 if computation is not None :
327298 args_str = ", " .join (
328299 [
329- f"{ argname } ="
330- + (f"'{ argvalue } '" if isinstance (argvalue , str ) else str (argvalue ))
300+ f"{ argname } =" + (f"'{ argvalue } '" if isinstance (argvalue , str ) else str (argvalue ))
331301 for argname , argvalue in computation .arguments .items ()
332302 ]
333303 )
@@ -382,9 +352,7 @@ def _str_units_to_float(str_units: str) -> float:
382352 return float (numeric_value * unit_conversions [unit ])
383353
384354
385- def get_wavelengths_to_plot (
386- statements : StatementDictionary , num_samples : int = 100
387- ) -> Tuple [List [float ], List [float ]]:
355+ def get_wavelengths_to_plot (statements : StatementDictionary , num_samples : int = 100 ) -> Tuple [List [float ], List [float ]]:
388356 """
389357 Get the wavelengths to plot based on the statements.
390358
@@ -401,10 +369,16 @@ def update_wavelengths(mapping: Dict[str, Optional[Computation]], min_wl: float,
401369 continue
402370 if "wavelengths" in comp .arguments :
403371 vlines = vlines | {
404- _str_units_to_float (wl ) for wl in (comp .arguments ["wavelengths" ] if isinstance (comp .arguments ["wavelengths" ], list ) else []) if isinstance (wl , str )
372+ _str_units_to_float (wl )
373+ for wl in (comp .arguments ["wavelengths" ] if isinstance (comp .arguments ["wavelengths" ], list ) else [])
374+ if isinstance (wl , str )
405375 }
406376 if "wavelength_range" in comp .arguments :
407- if isinstance (comp .arguments ["wavelength_range" ], list ) and len (comp .arguments ["wavelength_range" ]) == 2 and all (isinstance (wl , str ) for wl in comp .arguments ["wavelength_range" ]):
377+ if (
378+ isinstance (comp .arguments ["wavelength_range" ], list )
379+ and len (comp .arguments ["wavelength_range" ]) == 2
380+ and all (isinstance (wl , str ) for wl in comp .arguments ["wavelength_range" ])
381+ ):
408382 min_wl = min (min_wl , _str_units_to_float (comp .arguments ["wavelength_range" ][0 ]))
409383 max_wl = max (max_wl , _str_units_to_float (comp .arguments ["wavelength_range" ][1 ]))
410384 return min_wl , max_wl , vlines
0 commit comments