|
| 1 | +import re |
| 2 | +import numpy as np # type: ignore |
1 | 3 | import iklayout # type: ignore |
2 | 4 | import matplotlib.pyplot as plt # type: ignore |
3 | 5 | from ipywidgets import interactive, IntSlider # type: ignore |
4 | | -from typing import List, Optional |
| 6 | +from typing import List, Optional, Tuple, Dict, Set |
5 | 7 |
|
6 | | -from . import Parameter, StatementDictionary, StatementValidationDictionary, StatementValidation |
| 8 | +from . import Parameter, StatementDictionary, StatementValidationDictionary, StatementValidation, Computation |
7 | 9 |
|
8 | 10 |
|
9 | 11 | def plot_circuit(component): |
@@ -250,7 +252,7 @@ def print_statements(statements: StatementDictionary, validation: Optional[State |
250 | 252 | print(f"Satisfiable: {val.satisfiable}") |
251 | 253 | print(val.message) |
252 | 254 | print("\n-----------------------------------\n") |
253 | | - for param_stmt, param_val in zip(statements.cost_functions or [], validation.cost_functions or []): |
| 255 | + for param_stmt, param_val in zip(statements.parameter_constraints or [], validation.parameter_constraints or []): |
254 | 256 | print("Type:", param_stmt.type) |
255 | 257 | print("Statement:", param_stmt.text) |
256 | 258 | print("Formalization:", end=" ") |
@@ -301,3 +303,66 @@ def print_statements(statements: StatementDictionary, validation: Optional[State |
301 | 303 | print("Statement:", unf_stmt.text) |
302 | 304 | print("Formalization: UNFORMALIZABLE") |
303 | 305 | print("\n-----------------------------------\n") |
| 306 | + |
| 307 | + |
| 308 | +def _str_units_to_float(str_units: str) -> float: |
| 309 | + unit_conversions = { |
| 310 | + "nm": 1e-3, |
| 311 | + "um": 1, |
| 312 | + "mm": 1e3, |
| 313 | + "m": 1e6, |
| 314 | + } |
| 315 | + match = re.match(r"([\d\.]+)\s*([a-zA-Z]+)", str_units) |
| 316 | + numeric_value = float(match.group(1) if match else 1.55) |
| 317 | + unit = match.group(2) if match else "um" |
| 318 | + return float(numeric_value * unit_conversions[unit]) |
| 319 | + |
| 320 | + |
| 321 | +def get_wavelengths_to_plot( |
| 322 | + statements: StatementDictionary, num_samples: int = 100 |
| 323 | +) -> Tuple[List[float], List[float]]: |
| 324 | + """ |
| 325 | + Get the wavelengths to plot based on the statements. |
| 326 | +
|
| 327 | + Returns a list of wavelengths to plot the spectra and a list of vertical lines to plot on top the spectra. |
| 328 | + """ |
| 329 | + |
| 330 | + min_wl = float("inf") |
| 331 | + max_wl = float("-inf") |
| 332 | + vlines: set = set() |
| 333 | + |
| 334 | + def update_wavelengths(mapping: Dict[str, Optional[Computation]], min_wl: float, max_wl: float, vlines: Set): |
| 335 | + for comp in mapping.values(): |
| 336 | + if comp is None: |
| 337 | + continue |
| 338 | + if "wavelengths" in comp.arguments: |
| 339 | + vlines = vlines | { |
| 340 | + _str_units_to_float(wl) for wl in (comp.arguments["wavelengths"] if isinstance(comp.arguments["wavelengths"], list) else []) if isinstance(wl, str) |
| 341 | + } |
| 342 | + if "wavelength_range" in comp.arguments: |
| 343 | + if isinstance(comp.arguments["wavelength_range"], list) and len(comp.arguments["wavelength_range"]) == 2 and all(isinstance(wl, str) for wl in comp.arguments["wavelength_range"]): |
| 344 | + min_wl = min(min_wl, _str_units_to_float(comp.arguments["wavelength_range"][0])) |
| 345 | + max_wl = max(max_wl, _str_units_to_float(comp.arguments["wavelength_range"][1])) |
| 346 | + return min_wl, max_wl, vlines |
| 347 | + |
| 348 | + for cost_stmt in statements.cost_functions or []: |
| 349 | + if cost_stmt.formalization is not None and cost_stmt.formalization.mapping is not None: |
| 350 | + min_wl, max_wl, vlines = update_wavelengths(cost_stmt.formalization.mapping, min_wl, max_wl, vlines) |
| 351 | + |
| 352 | + for param_stmt in statements.parameter_constraints or []: |
| 353 | + if param_stmt.formalization is not None and param_stmt.formalization.mapping is not None: |
| 354 | + min_wl, max_wl, vlines = update_wavelengths(param_stmt.formalization.mapping, min_wl, max_wl, vlines) |
| 355 | + |
| 356 | + if vlines: |
| 357 | + min_wl = min(min_wl, min(vlines)) |
| 358 | + max_wl = max(max_wl, max(vlines)) |
| 359 | + if min_wl >= max_wl: |
| 360 | + avg_wl = sum(vlines) / len(vlines) if vlines else 1550 |
| 361 | + min_wl, max_wl = avg_wl - 0.1, avg_wl + 0.1 |
| 362 | + else: |
| 363 | + range_size = max_wl - min_wl |
| 364 | + min_wl -= 0.2 * range_size |
| 365 | + max_wl += 0.2 * range_size |
| 366 | + |
| 367 | + wls = np.linspace(min_wl, max_wl, num_samples) |
| 368 | + return [float(wl) for wl in wls], list(vlines) |
0 commit comments