Skip to content

Commit 8ab5236

Browse files
authored
Merge pull request #13 from Axiomatic-AI/add-spectrum-plotter
Add spectrum plotter
2 parents fc709e4 + 732a81f commit 8ab5236

File tree

3 files changed

+167
-4
lines changed

3 files changed

+167
-4
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ jobs:
1010
- name: Set up python
1111
uses: actions/setup-python@v4
1212
with:
13-
python-version: 3.8
13+
python-version: "3.10"
1414
- name: Bootstrap poetry
1515
run: |
1616
curl -sSL https://install.python-poetry.org | python - -y --version 1.5.1
@@ -26,7 +26,7 @@ jobs:
2626
- name: Set up python
2727
uses: actions/setup-python@v4
2828
with:
29-
python-version: 3.8
29+
python-version: "3.10"
3030
- name: Bootstrap poetry
3131
run: |
3232
curl -sSL https://install.python-poetry.org | python - -y --version 1.5.1
@@ -46,7 +46,7 @@ jobs:
4646
- name: Set up python
4747
uses: actions/setup-python@v4
4848
with:
49-
python-version: 3.8
49+
python-version: "3.10"
5050
- name: Bootstrap poetry
5151
run: |
5252
curl -sSL https://install.python-poetry.org | python - -y --version 1.5.1

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ packages = [
3434
Repository = 'https://github.com/axiomatic-ai/axiomatic-python-sdk'
3535

3636
[tool.poetry.dependencies]
37-
python = "^3.8"
37+
python = "^3.10"
3838
httpx = ">=0.21.2"
3939
pydantic = ">= 1.9.2"
4040
pydantic-core = "^2.18.2"

src/axiomatic/pic_helpers.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
import iklayout # type: ignore
2+
import matplotlib.pyplot as plt # type: ignore
3+
from ipywidgets import interactive, IntSlider # type: ignore
4+
from typing import List
25

36

47
def plot_circuit(component):
@@ -17,3 +20,163 @@ def plot_circuit(component):
1720
path = component.write_gds().absolute()
1821

1922
return iklayout.show(path)
23+
24+
25+
def plot_losses(
26+
losses: List[float], iterations: List[int] | None = None, return_fig: bool = True
27+
):
28+
"""
29+
Plot a list of losses with labels.
30+
31+
Args:
32+
losses: List of loss values.
33+
"""
34+
iterations = iterations or list(range(len(losses)))
35+
plt.clf()
36+
plt.figure(figsize=(10, 5))
37+
plt.title("Losses vs. Iterations")
38+
plt.xlabel("Iterations")
39+
plt.ylabel("Losses")
40+
plt.plot(iterations, losses)
41+
if return_fig:
42+
return plt.gcf()
43+
plt.show()
44+
45+
46+
def plot_constraints(
47+
constraints: List[List[float]],
48+
constraints_labels: List[str] | None = None,
49+
iterations: List[int] | None = None,
50+
return_fig: bool = True,
51+
):
52+
"""
53+
Plot a list of constraints with labels.
54+
55+
Args:
56+
constraints: List of constraint values.
57+
labels: List of labels for each constraint value.
58+
"""
59+
60+
constraints_labels = constraints_labels or [
61+
f"Constraint {i}" for i in range(len(constraints[0]))
62+
]
63+
iterations = iterations or list(range(len(constraints[0])))
64+
65+
66+
plt.clf()
67+
plt.figure(figsize=(10, 5))
68+
plt.title("Losses vs. Iterations")
69+
plt.xlabel("Iterations")
70+
plt.ylabel("Constraints")
71+
for i, constraint in enumerate(constraints):
72+
plt.plot(iterations, constraint, label=constraints_labels[i])
73+
plt.legend()
74+
plt.grid(True)
75+
if return_fig:
76+
return plt.gcf()
77+
plt.show()
78+
79+
80+
def plot_single_spectrum(
81+
spectrum: List[float],
82+
wavelengths: List[float],
83+
vlines: List[float] | None = None,
84+
hlines: List[float] | None = None,
85+
return_fig: bool = True,
86+
):
87+
"""
88+
Plot a single spectrum with vertical and horizontal lines.
89+
"""
90+
hlines = hlines or []
91+
vlines = vlines or []
92+
93+
plt.clf()
94+
plt.figure(figsize=(10, 5))
95+
plt.title("Losses vs. Iterations")
96+
plt.xlabel("Iterations")
97+
plt.ylabel("Losses")
98+
plt.plot(wavelengths, spectrum)
99+
for x_val in vlines:
100+
plt.axvline(
101+
x=x_val, color="red", linestyle="--", label=f"Wavelength (x={x_val})"
102+
) # Add vertical line
103+
for y_val in hlines:
104+
plt.axhline(
105+
x=y_val, color="red", linestyle="--", label=f"Transmission (y={y_val})"
106+
) # Add vertical line
107+
if return_fig:
108+
return plt.gcf()
109+
plt.show()
110+
111+
112+
def plot_interactive_spectrums(
113+
spectrums: List[List[List[float]]],
114+
wavelengths: List[float],
115+
spectrum_labels: List[str] | None = None,
116+
slider_index: List[int] | None = None,
117+
vlines: List[float] | None = None,
118+
hlines: List[float] | None = None,
119+
):
120+
"""
121+
Creates an interactive plot of spectrums with a slider to select different indices.
122+
Parameters:
123+
-----------
124+
spectrums : list of list of float
125+
A list of spectrums, where each spectrum is a list of lists of float values, each
126+
corresponding to the transmission of a single wavelength.
127+
wavelengths : list of float
128+
A list of wavelength values corresponding to the x-axis of the plot.
129+
slider_index : list of int, optional
130+
A list of indices for the slider. Defaults to range(len(spectrums[0])).
131+
vlines : list of float, optional
132+
A list of x-values where vertical lines should be drawn. Defaults to an empty list.
133+
hlines : list of float, optional
134+
A list of y-values where horizontal lines should be drawn. Defaults to an empty list.
135+
Returns:
136+
--------
137+
ipywidgets.widgets.interaction.interactive
138+
An interactive widget that allows the user to select different indices using a slider.
139+
Notes:
140+
------
141+
- The function uses matplotlib for plotting and ipywidgets for creating the interactive
142+
slider.
143+
- The y-axis limits are fixed based on the global minimum and maximum values across all
144+
spectrums.
145+
- Vertical and horizontal lines can be added to the plot using the `vlines` and `hlines`
146+
parameters.
147+
"""
148+
# Calculate global y-limits across all arrays
149+
y_min = min(min(min(arr2) for arr2 in arr1) for arr1 in spectrums)
150+
y_max = max(max(max(arr2) for arr2 in arr1) for arr1 in spectrums)
151+
152+
slider_index = slider_index or list(range(len(spectrums[0])))
153+
spectrum_labels = spectrum_labels or [f"Spectrum{i}" for i in range(len(spectrums))]
154+
vlines = vlines or []
155+
hlines = hlines or []
156+
157+
# Function to update the plot
158+
def plot_array(index=0):
159+
plt.close("all")
160+
plt.figure(figsize=(8, 4))
161+
for i, array in enumerate(spectrums):
162+
plt.plot(wavelengths, array[index], lw=2, label=spectrum_labels[i])
163+
for x_val in vlines:
164+
plt.axvline(
165+
x=x_val, color="red", linestyle="--", label=f"Wavelength (x={x_val})"
166+
) # Add vertical line
167+
for y_val in hlines:
168+
plt.axhline(
169+
x=y_val, color="red", linestyle="--", label=f"Transmission (y={y_val})"
170+
) # Add vertical line
171+
plt.title(f"Iteration: {index}")
172+
plt.xlabel("X")
173+
plt.ylabel("Y")
174+
plt.ylim(y_min, y_max) # Fix the y-limits
175+
plt.legend()
176+
plt.grid(True)
177+
plt.show()
178+
179+
slider = IntSlider(
180+
value=0, min=0, max=len(spectrums[0]) - 1, step=1, description="Index"
181+
)
182+
return interactive(plot_array, index=slider)

0 commit comments

Comments
 (0)