@@ -112,9 +112,9 @@ def plot_interactive_spectra(
112112 A list of spectra, where each spectrum is a list of lists of float values, each
113113 corresponding to the transmission of a single wavelength.
114114 wavelengths : list of float
115- A list of wavelength values corresponding to the x-axis of the plot.
115+ A list of wavelength values corresponding to the x-axis of the plot, in nm .
116116 vlines : list of float, optional
117- A list of x-values where vertical lines should be drawn. Defaults to an empty list.
117+ A list of x-values where vertical lines should be drawn, in nm . Defaults to an empty list.
118118 hlines : list of float, optional
119119 A list of y-values where horizontal lines should be drawn. Defaults to an empty list.
120120 """
@@ -149,9 +149,16 @@ def plot_interactive_spectra(
149149 all_vals = [val for spec in spectra for iteration in spec for val in iteration ]
150150 y_min = min (all_vals )
151151 y_max = max (all_vals )
152- if hlines :
153- y_min = min (hlines + [y_min ]) * 0.95
154- y_max = max (hlines + [y_max ]) * 1.05
152+
153+ # dB scale
154+ if y_max <= 0 :
155+ y_max = 0
156+ db = True
157+ else :
158+ db = False
159+ if hlines :
160+ y_min = min (hlines + [y_min ]) * 0.95
161+ y_max = max (hlines + [y_max ]) * 1.05
155162
156163 # Create hlines and vlines
157164 shapes = []
@@ -187,8 +194,8 @@ def plot_interactive_spectra(
187194
188195 # Create the layout
189196 fig .update_layout (
190- xaxis_title = "Wavelength" ,
191- yaxis_title = "Transmission" ,
197+ xaxis_title = "Wavelength (nm) " ,
198+ yaxis_title = "Transmission " + "(dB)" if db else "(linear) " ,
192199 shapes = shapes ,
193200 sliders = sliders ,
194201 yaxis = dict (range = [y_min , y_max ]),
@@ -454,10 +461,10 @@ def print_statements(
454461
455462def _str_units_to_float (str_units : str ) -> Optional [float ]:
456463 unit_conversions = {
457- "nm" : 1e-3 ,
458- "um" : 1 ,
459- "mm" : 1e3 ,
460- "m" : 1e6 ,
464+ "nm" : 1 ,
465+ "um" : 1e3 ,
466+ "mm" : 1e6 ,
467+ "m" : 1e9 ,
461468 }
462469 match = re .match (r"([\d\.]+)\s*([a-zA-Z]+)" , str_units )
463470 numeric_value = float (match .group (1 )) if match else None
@@ -469,7 +476,7 @@ def get_wavelengths_to_plot(statements: StatementDictionary, num_samples: int =
469476 """
470477 Get the wavelengths to plot based on the statements.
471478
472- Returns a list of wavelengths to plot the spectra and a list of vertical lines to plot on top the spectra.
479+ Returns a list of wavelengths to plot the spectra and a list of vertical lines to plot on top the spectra, in nm .
473480 """
474481
475482 min_wl = float ("inf" )
@@ -511,8 +518,8 @@ def update_wavelengths(mapping: Dict[str, Optional[Computation]], min_wl: float,
511518 min_wl = min (min_wl , min (vlines ))
512519 max_wl = max (max_wl , max (vlines ))
513520 if min_wl >= max_wl :
514- avg_wl = sum (vlines ) / len (vlines ) if vlines else 1.55
515- min_wl , max_wl = avg_wl - 0.01 , avg_wl + 0.01
521+ avg_wl = sum (vlines ) / len (vlines ) if vlines else _str_units_to_float ( "1550 nm" )
522+ min_wl , max_wl = avg_wl - _str_units_to_float ( "10 nm" ) , avg_wl + _str_units_to_float ( "10 nm" )
516523 else :
517524 range_size = max_wl - min_wl
518525 min_wl -= 0.2 * range_size
0 commit comments