|
5 | 5 | from os.path import dirname |
6 | 6 | from pathlib import Path |
7 | 7 |
|
8 | | -import numpy as np |
9 | 8 | from autograd.builtins import dict as dict_ag |
10 | 9 | from autograd.extend import defvjp, primitive |
11 | 10 |
|
|
47 | 46 | ) |
48 | 47 | from .forward import postprocess_fwd as _postprocess_fwd_impl |
49 | 48 | from .forward import setup_fwd as _setup_fwd_impl |
50 | | -from .io_utils import ( |
51 | | - get_vjp_traced_fields as _get_vjp_traced_fields_impl, |
52 | | -) |
53 | | -from .io_utils import ( |
54 | | - upload_sim_fields_keys as _upload_sim_fields_keys_impl, |
55 | | -) |
56 | | - |
57 | | -# if True, will plot the adjoint fields on the plane provided. used for debugging only |
58 | | -_INSPECT_ADJOINT_FIELDS = False |
59 | | -_INSPECT_ADJOINT_PLANE = td.Box(center=(0, 0, 0), size=(td.inf, td.inf, 0)) |
60 | 49 |
|
61 | 50 |
|
62 | 51 | def is_valid_for_autograd(simulation: td.Simulation) -> bool: |
@@ -623,21 +612,6 @@ def postprocess_fwd( |
623 | 612 | ) |
624 | 613 |
|
625 | 614 |
|
626 | | -def upload_sim_fields_keys(sim_fields_keys: list[tuple], task_id: str, verbose: bool = False): |
627 | | - """Upload traced simulation field keys for adjoint runs (delegated).""" |
628 | | - return _upload_sim_fields_keys_impl( |
629 | | - sim_fields_keys=sim_fields_keys, task_id=task_id, verbose=verbose |
630 | | - ) |
631 | | - |
632 | | - |
633 | | -""" VJP maker for ADJ pass.""" |
634 | | - |
635 | | - |
636 | | -def get_vjp_traced_fields(task_id_adj: str, verbose: bool) -> AutogradFieldMap: |
637 | | - """Fetch VJP traced fields for a completed adjoint job (delegated).""" |
638 | | - return _get_vjp_traced_fields_impl(task_id_adj=task_id_adj, verbose=verbose) |
639 | | - |
640 | | - |
641 | 615 | def _run_bwd( |
642 | 616 | data_fields_original: AutogradFieldMap, |
643 | 617 | sim_fields_original: AutogradFieldMap, |
@@ -919,22 +893,6 @@ def setup_adj( |
919 | 893 | ) |
920 | 894 |
|
921 | 895 |
|
922 | | -def _compute_eps_array(medium, frequencies): |
923 | | - """Deprecated shim, kept for backward compatibility; use ops_backward._compute_eps_array.""" |
924 | | - from .backward import _compute_eps_array as __impl |
925 | | - |
926 | | - return __impl(medium, frequencies) |
927 | | - |
928 | | - |
929 | | -def _slice_field_data( |
930 | | - field_data: dict, freqs: np.ndarray, component_indicator: typing.Optional[str] = None |
931 | | -) -> dict: |
932 | | - """Deprecated shim, kept for backward compatibility; use ops_backward._slice_field_data.""" |
933 | | - from .backward import _slice_field_data as __impl |
934 | | - |
935 | | - return __impl(field_data, freqs, component_indicator) |
936 | | - |
937 | | - |
938 | 896 | def postprocess_adj( |
939 | 897 | sim_data_adj: td.SimulationData, |
940 | 898 | sim_data_orig: td.SimulationData, |
|
0 commit comments