Skip to content

Commit 1841f16

Browse files
committed
Get wavelengths to plot
1 parent 12e8f1e commit 1841f16

File tree

1 file changed

+68
-3
lines changed

1 file changed

+68
-3
lines changed

src/axiomatic/pic_helpers.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import re
2+
import numpy as np # type: ignore
13
import iklayout # type: ignore
24
import matplotlib.pyplot as plt # type: ignore
35
from ipywidgets import interactive, IntSlider # type: ignore
4-
from typing import List, Optional
6+
from typing import List, Optional, Tuple, Dict, Set
57

6-
from . import Parameter, StatementDictionary, StatementValidationDictionary, StatementValidation
8+
from . import Parameter, StatementDictionary, StatementValidationDictionary, StatementValidation, Computation
79

810

911
def plot_circuit(component):
@@ -250,7 +252,7 @@ def print_statements(statements: StatementDictionary, validation: Optional[State
250252
print(f"Satisfiable: {val.satisfiable}")
251253
print(val.message)
252254
print("\n-----------------------------------\n")
253-
for param_stmt, param_val in zip(statements.cost_functions or [], validation.cost_functions or []):
255+
for param_stmt, param_val in zip(statements.parameter_constraints or [], validation.parameter_constraints or []):
254256
print("Type:", param_stmt.type)
255257
print("Statement:", param_stmt.text)
256258
print("Formalization:", end=" ")
@@ -301,3 +303,66 @@ def print_statements(statements: StatementDictionary, validation: Optional[State
301303
print("Statement:", unf_stmt.text)
302304
print("Formalization: UNFORMALIZABLE")
303305
print("\n-----------------------------------\n")
306+
307+
308+
def _str_units_to_float(str_units: str) -> float:
309+
unit_conversions = {
310+
"nm": 1e-3,
311+
"um": 1,
312+
"mm": 1e3,
313+
"m": 1e6,
314+
}
315+
match = re.match(r"([\d\.]+)\s*([a-zA-Z]+)", str_units)
316+
numeric_value = float(match.group(1) if match else 1.55)
317+
unit = match.group(2) if match else "um"
318+
return float(numeric_value * unit_conversions[unit])
319+
320+
321+
def get_wavelengths_to_plot(
322+
statements: StatementDictionary, num_samples: int = 100
323+
) -> Tuple[List[float], List[float]]:
324+
"""
325+
Get the wavelengths to plot based on the statements.
326+
327+
Returns a list of wavelengths to plot the spectra and a list of vertical lines to plot on top the spectra.
328+
"""
329+
330+
min_wl = float("inf")
331+
max_wl = float("-inf")
332+
vlines: set = set()
333+
334+
def update_wavelengths(mapping: Dict[str, Optional[Computation]], min_wl: float, max_wl: float, vlines: Set):
335+
for comp in mapping.values():
336+
if comp is None:
337+
continue
338+
if "wavelengths" in comp.arguments:
339+
vlines = vlines | {
340+
_str_units_to_float(wl) for wl in (comp.arguments["wavelengths"] if isinstance(comp.arguments["wavelengths"], list) else []) if isinstance(wl, str)
341+
}
342+
if "wavelength_range" in comp.arguments:
343+
if isinstance(comp.arguments["wavelength_range"], list) and len(comp.arguments["wavelength_range"]) == 2 and all(isinstance(wl, str) for wl in comp.arguments["wavelength_range"]):
344+
min_wl = min(min_wl, _str_units_to_float(comp.arguments["wavelength_range"][0]))
345+
max_wl = max(max_wl, _str_units_to_float(comp.arguments["wavelength_range"][1]))
346+
return min_wl, max_wl, vlines
347+
348+
for cost_stmt in statements.cost_functions or []:
349+
if cost_stmt.formalization is not None and cost_stmt.formalization.mapping is not None:
350+
min_wl, max_wl, vlines = update_wavelengths(cost_stmt.formalization.mapping, min_wl, max_wl, vlines)
351+
352+
for param_stmt in statements.parameter_constraints or []:
353+
if param_stmt.formalization is not None and param_stmt.formalization.mapping is not None:
354+
min_wl, max_wl, vlines = update_wavelengths(param_stmt.formalization.mapping, min_wl, max_wl, vlines)
355+
356+
if vlines:
357+
min_wl = min(min_wl, min(vlines))
358+
max_wl = max(max_wl, max(vlines))
359+
if min_wl >= max_wl:
360+
avg_wl = sum(vlines) / len(vlines) if vlines else 1550
361+
min_wl, max_wl = avg_wl - 0.1, avg_wl + 0.1
362+
else:
363+
range_size = max_wl - min_wl
364+
min_wl -= 0.2 * range_size
365+
max_wl += 0.2 * range_size
366+
367+
wls = np.linspace(min_wl, max_wl, num_samples)
368+
return [float(wl) for wl in wls], list(vlines)

0 commit comments

Comments
 (0)