|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import functools |
6 | | -import warnings |
7 | 6 | from abc import ABC, abstractmethod |
8 | 7 | from math import isclose |
9 | 8 | from typing import Callable, Optional, Union |
10 | 9 |
|
11 | | -import autograd as ag |
12 | 10 | import autograd.numpy as np |
13 | 11 |
|
14 | 12 | # TODO: it's hard to figure out which functions need this, so for now all get it |
@@ -3425,50 +3423,52 @@ def loss_upper_bound(self) -> float: |
3425 | 3423 | ep = ep[~np.isnan(ep)] |
3426 | 3424 | return max(ep.imag) |
3427 | 3425 |
|
| 3426 | + @staticmethod |
| 3427 | + def _get_vjps_from_params( |
| 3428 | + dJ_deps_complex: Union[complex, np.ndarray], |
| 3429 | + poles_vals: list[tuple[Union[complex, np.ndarray], Union[complex, np.ndarray]]], |
| 3430 | + omega: float, |
| 3431 | + requested_paths: list[tuple], |
| 3432 | + ) -> AutogradFieldMap: |
| 3433 | + """ |
| 3434 | + Static helper to compute VJPs for the requested ``eps_inf`` and pole-parameter paths using the analytical chain rule. |
| 3435 | + """ |
| 3436 | + jw = 1j * omega |
| 3437 | + vjps = {} |
| 3438 | + |
| 3439 | + if ("eps_inf",) in requested_paths: |
| 3440 | + vjps[("eps_inf",)] = np.real(dJ_deps_complex) |
| 3441 | + |
| 3442 | + for i, (a_val, c_val) in enumerate(poles_vals): |
| 3443 | + if any(path[1] == i for path in requested_paths if path[0] == "poles"): |
| 3444 | + if ("poles", i, 0) in requested_paths: |
| 3445 | + deps_da = c_val / (jw + a_val) ** 2 |
| 3446 | + dJ_da = dJ_deps_complex * deps_da |
| 3447 | + vjps[("poles", i, 0)] = dJ_da |
| 3448 | + if ("poles", i, 1) in requested_paths: |
| 3449 | + deps_dc = -1 / (jw + a_val) |
| 3450 | + dJ_dc = dJ_deps_complex * deps_dc |
| 3451 | + vjps[("poles", i, 1)] = dJ_dc |
| 3452 | + |
| 3453 | + return vjps |
| 3454 | + |
3428 | 3455 | def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: |
3429 | | - """Compute adjoint derivatives for each of the ``fields`` given the multiplied E and D.""" |
| 3456 | + """Compute adjoint derivatives by preparing scalar data and calling the static helper.""" |
3430 | 3457 |
|
3431 | | - # compute all derivatives beforehand |
3432 | | - dJ_deps = self._derivative_eps_complex_volume( |
| 3458 | + dJ_deps_complex = self._derivative_eps_complex_volume( |
3433 | 3459 | E_der_map=derivative_info.E_der_map, |
3434 | 3460 | bounds=derivative_info.bounds, |
3435 | 3461 | freqs=np.atleast_1d(derivative_info.frequency), |
3436 | 3462 | ) |
3437 | 3463 |
|
3438 | | - dJ_deps = complex(dJ_deps) |
3439 | | - |
3440 | | - # TODO: fix for multi-frequency |
3441 | | - frequency = derivative_info.frequency |
3442 | | - poles_complex = [(complex(a), complex(c)) for a, c in self.poles] |
3443 | | - poles_complex = np.stack(poles_complex, axis=0) |
3444 | | - |
3445 | | - # compute gradients of eps_model with respect to eps_inf and poles |
3446 | | - grad_eps_model = ag.holomorphic_grad(self._eps_model, argnum=(0, 1)) |
3447 | | - with warnings.catch_warnings(): |
3448 | | - # ignore warnings about holomorphic grad being passed a non-complex input (poles) |
3449 | | - warnings.simplefilter("ignore") |
3450 | | - deps_deps_inf, deps_dpoles = grad_eps_model( |
3451 | | - complex(self.eps_inf), poles_complex, complex(frequency) |
3452 | | - ) |
3453 | | - |
3454 | | - # multiply with partial dJ/deps to give full gradients |
3455 | | - |
3456 | | - dJ_deps_inf = dJ_deps * deps_deps_inf |
3457 | | - dJ_dpoles = [(dJ_deps * a, dJ_deps * c) for a, c in deps_dpoles] |
3458 | | - |
3459 | | - # get vjps w.r.t. permittivity and conductivity of the bulk |
3460 | | - derivative_map = {} |
3461 | | - for field_path in derivative_info.paths: |
3462 | | - field_name, *rest = field_path |
| 3464 | + poles_vals = [(complex(a), complex(c)) for a, c in self.poles] |
3463 | 3465 |
|
3464 | | - if field_name == "eps_inf": |
3465 | | - derivative_map[field_path] = float(np.real(dJ_deps_inf)) |
3466 | | - |
3467 | | - elif field_name == "poles": |
3468 | | - pole_index, a_or_c = rest |
3469 | | - derivative_map[field_path] = complex(dJ_dpoles[pole_index][a_or_c]) |
3470 | | - |
3471 | | - return derivative_map |
| 3466 | + return self._get_vjps_from_params( |
| 3467 | + dJ_deps_complex=complex(dJ_deps_complex), |
| 3468 | + poles_vals=poles_vals, |
| 3469 | + omega=2 * np.pi * derivative_info.frequency, |
| 3470 | + requested_paths=derivative_info.paths, |
| 3471 | + ) |
3472 | 3472 |
|
3473 | 3473 | @classmethod |
3474 | 3474 | def _real_partial_fraction_decomposition( |
@@ -3903,73 +3903,28 @@ def _sel_custom_data_inside(self, bounds: Bound): |
3903 | 3903 | return self.updated_copy(eps_inf=eps_inf_reduced, poles=poles_reduced) |
3904 | 3904 |
|
3905 | 3905 | def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: |
3906 | | - """Compute adjoint derivatives for each of the ``fields`` given the multiplied E and D.""" |
| 3906 | + """Compute adjoint derivatives by preparing array data and calling the static helper.""" |
3907 | 3907 |
|
3908 | | - dJ_deps = 0.0 |
| 3908 | + dJ_deps_complex = 0.0 |
3909 | 3909 | for dim in "xyz": |
3910 | | - dJ_deps += self._derivative_field_cmp( |
| 3910 | + dJ_deps_complex += self._derivative_field_cmp( |
3911 | 3911 | E_der_map=derivative_info.E_der_map, |
3912 | 3912 | eps_data=self.eps_inf, |
3913 | 3913 | dim=dim, |
3914 | 3914 | freqs=np.atleast_1d(derivative_info.frequency), |
3915 | 3915 | ) |
3916 | 3916 |
|
3917 | | - # TODO: fix for multi-frequency |
3918 | | - frequency = derivative_info.frequency |
3919 | | - |
3920 | | - poles_complex = [ |
| 3917 | + poles_vals = [ |
3921 | 3918 | (np.array(a.values, dtype=complex), np.array(c.values, dtype=complex)) |
3922 | 3919 | for a, c in self.poles |
3923 | 3920 | ] |
3924 | | - poles_complex = np.stack(poles_complex, axis=0) |
3925 | | - |
3926 | | - def eps_model_r( |
3927 | | - eps_inf: complex, poles: list[tuple[complex, complex]], frequency: float |
3928 | | - ) -> float: |
3929 | | - """Real part of ``eps_model`` evaluated on ``self`` fields.""" |
3930 | | - return np.real(self._eps_model(eps_inf, poles, frequency)) |
3931 | | - |
3932 | | - def eps_model_i( |
3933 | | - eps_inf: complex, poles: list[tuple[complex, complex]], frequency: float |
3934 | | - ) -> float: |
3935 | | - """Imaginary part of ``eps_model`` evaluated on ``self`` fields.""" |
3936 | | - return np.imag(self._eps_model(eps_inf, poles, frequency)) |
3937 | | - |
3938 | | - # compute the gradients w.r.t. each real and imaginary parts for eps_inf and poles |
3939 | | - grad_eps_model_r = ag.elementwise_grad(eps_model_r, argnum=(0, 1)) |
3940 | | - grad_eps_model_i = ag.elementwise_grad(eps_model_i, argnum=(0, 1)) |
3941 | | - deps_deps_inf_r, deps_dpoles_r = grad_eps_model_r( |
3942 | | - self.eps_inf.values, poles_complex, frequency |
3943 | | - ) |
3944 | | - deps_deps_inf_i, deps_dpoles_i = grad_eps_model_i( |
3945 | | - self.eps_inf.values, poles_complex, frequency |
3946 | | - ) |
3947 | | - |
3948 | | - # multiply with dJ_deps partial derivative to give full gradients |
3949 | | - |
3950 | | - deps_deps_inf = deps_deps_inf_r + 1j * deps_deps_inf_i |
3951 | | - dJ_deps_inf = dJ_deps * deps_deps_inf / 3.0 # mysterious 3 |
3952 | 3921 |
|
3953 | | - dJ_dpoles = [] |
3954 | | - for (da_r, dc_r), (da_i, dc_i) in zip(deps_dpoles_r, deps_dpoles_i): |
3955 | | - da = da_r + 1j * da_i |
3956 | | - dc = dc_r + 1j * dc_i |
3957 | | - dJ_da = dJ_deps * da / 2.0 # mysterious 2 |
3958 | | - dJ_dc = dJ_deps * dc / 2.0 # mysterious 2 |
3959 | | - dJ_dpoles.append((dJ_da, dJ_dc)) |
3960 | | - |
3961 | | - derivative_map = {} |
3962 | | - for field_path in derivative_info.paths: |
3963 | | - field_name, *rest = field_path |
3964 | | - |
3965 | | - if field_name == "eps_inf": |
3966 | | - derivative_map[field_path] = np.real(dJ_deps_inf) |
3967 | | - |
3968 | | - elif field_name == "poles": |
3969 | | - pole_index, a_or_c = rest |
3970 | | - derivative_map[field_path] = dJ_dpoles[pole_index][a_or_c] |
3971 | | - |
3972 | | - return derivative_map |
| 3922 | + return PoleResidue._get_vjps_from_params( |
| 3923 | + dJ_deps_complex=dJ_deps_complex, |
| 3924 | + poles_vals=poles_vals, |
| 3925 | + omega=2 * np.pi * derivative_info.frequency, |
| 3926 | + requested_paths=derivative_info.paths, |
| 3927 | + ) |
3973 | 3928 |
|
3974 | 3929 |
|
3975 | 3930 | class Sellmeier(DispersiveMedium): |
|
0 commit comments