|
12 | 12 | from ..test_data.test_data_arrays import FS, MODE_SPEC, SIZE_2D |
13 | 13 | from ..utils import AssertLogLevel |
14 | 14 | from tidy3d.plugins.mode import ModeSolver |
| 15 | +from tidy3d.plugins.smatrix.ports.wave import DEFAULT_WAVE_PORT_INTERP_SPEC |
| 16 | + |
15 | 17 |
|
16 | 18 | # Shared test constants |
17 | 19 | FREQS_DENSE = np.linspace(1e14, 2e14, 20) |
@@ -48,6 +50,52 @@ def test_interp_spec_cubic_needs_4_points(): |
48 | 50 | td.ModeInterpSpec(num_points=3, method="cubic") |
49 | 51 |
|
50 | 52 |
|
| 53 | +def test_interp_spec_valid_cheb(): |
| 54 | + """Test creating valid ModeInterpSpec with Chebyshev interpolation.""" |
| 55 | + spec = td.ModeInterpSpec(num_points=10, method="cheb") |
| 56 | + assert spec.num_points == 10 |
| 57 | + assert spec.method == "cheb" |
| 58 | + |
| 59 | + |
| 60 | +def test_interp_spec_cheb_needs_3_points(): |
| 61 | + """Test that Chebyshev interpolation requires at least 3 points.""" |
| 62 | + with pytest.raises(pydantic.ValidationError, match="Chebyshev interpolation requires at least 3"): |
| 63 | + td.ModeInterpSpec(num_points=2, method="cheb") |
| 64 | + |
| 65 | + |
| 66 | +def test_interp_spec_sampling_points_linear(): |
| 67 | + """Test sampling_points for linear interpolation.""" |
| 68 | + spec = td.ModeInterpSpec(num_points=5, method="linear") |
| 69 | + freqs = np.linspace(1e14, 2e14, 100) |
| 70 | + sampling = spec.sampling_points(freqs) |
| 71 | + |
| 72 | + assert len(sampling) == 5 |
| 73 | + assert np.isclose(sampling[0], 1e14) |
| 74 | + assert np.isclose(sampling[-1], 2e14) |
| 75 | + # Check uniform spacing |
| 76 | + diffs = np.diff(sampling) |
| 77 | + assert np.allclose(diffs, diffs[0]) |
| 78 | + |
| 79 | + |
| 80 | +def test_interp_spec_sampling_points_cheb(): |
| 81 | + """Test sampling_points for Chebyshev interpolation.""" |
| 82 | + spec = td.ModeInterpSpec(num_points=5, method="cheb") |
| 83 | + freqs = np.linspace(1e14, 2e14, 100) |
| 84 | + sampling = spec.sampling_points(freqs) |
| 85 | + |
| 86 | + assert len(sampling) == 5 |
| 87 | + # Chebyshev nodes should include endpoints |
| 88 | + assert np.isclose(sampling.min(), 1e14) |
| 89 | + assert np.isclose(sampling.max(), 2e14) |
| 90 | + |
| 91 | + # Verify they are Chebyshev nodes |
| 92 | + f_min, f_max = 1e14, 2e14 |
| 93 | + k = np.arange(5) |
| 94 | + expected_normalized = np.cos(k * np.pi / 4) |
| 95 | + expected = 0.5 * (f_min + f_max) + 0.5 * (f_max - f_min) * expected_normalized |
| 96 | + assert np.allclose(np.sort(sampling), np.sort(expected)) |
| 97 | + |
| 98 | + |
51 | 99 | def test_interp_spec_min_2_points(): |
52 | 100 | """Test that at least 2 points are required.""" |
53 | 101 | with pytest.raises(pydantic.ValidationError): |
@@ -380,6 +428,107 @@ def test_mode_solver_data_interp_cubic(): |
380 | 428 | assert data_interp.n_complex.shape[0] == 20 |
381 | 429 |
|
382 | 430 |
|
| 431 | +def test_mode_solver_data_interp_cheb(): |
| 432 | + """Test Chebyshev interpolation on ModeSolverData.""" |
| 433 | + # Create data with frequencies at Chebyshev nodes |
| 434 | + interp_spec = td.ModeInterpSpec(num_points=5, method="cheb") |
| 435 | + freqs_all = np.linspace(1e14, 2e14, 50) |
| 436 | + freqs_cheb = interp_spec.sampling_points(freqs_all) |
| 437 | + |
| 438 | + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) |
| 439 | + monitor = td.ModeSolverMonitor( |
| 440 | + center=(0, 0, 0), |
| 441 | + size=SIZE_2D, |
| 442 | + freqs=freqs_cheb, |
| 443 | + mode_spec=mode_spec, |
| 444 | + name="test_cheb", |
| 445 | + ) |
| 446 | + |
| 447 | + from ..test_data.test_data_arrays import make_scalar_mode_field_data_array |
| 448 | + from ..test_data.test_monitor_data import N_COMPLEX |
| 449 | + |
| 450 | + mode_data = td.ModeSolverData( |
| 451 | + monitor=monitor, |
| 452 | + Ex=make_scalar_mode_field_data_array("Ex"), |
| 453 | + Ey=make_scalar_mode_field_data_array("Ey"), |
| 454 | + Ez=make_scalar_mode_field_data_array("Ez"), |
| 455 | + Hx=make_scalar_mode_field_data_array("Hx"), |
| 456 | + Hy=make_scalar_mode_field_data_array("Hy"), |
| 457 | + Hz=make_scalar_mode_field_data_array("Hz"), |
| 458 | + n_complex=N_COMPLEX.copy(), |
| 459 | + symmetry=(0, 0, 0), |
| 460 | + symmetry_center=(0, 0, 0), |
| 461 | + grid_expanded=td.Grid(boundaries=td.Coords(x=[0, 1], y=[0, 1], z=[0, 1])), |
| 462 | + ) |
| 463 | + |
| 464 | + # Interpolate to 50 frequencies |
| 465 | + data_interp = mode_data.interp(freqs=freqs_all, method="cheb") |
| 466 | + |
| 467 | + # Check frequency dimension |
| 468 | + assert len(data_interp.monitor.freqs) == 50 |
| 469 | + assert data_interp.n_complex.shape[0] == 50 |
| 470 | + |
| 471 | + |
| 472 | +def test_mode_solver_data_interp_cheb_needs_3_source(): |
| 473 | + """Test that Chebyshev interpolation fails with too few source frequencies.""" |
| 474 | + # Create data with only 2 frequencies |
| 475 | + freqs = np.linspace(1e14, 2e14, 2) |
| 476 | + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) |
| 477 | + monitor = td.ModeSolverMonitor( |
| 478 | + center=(0, 0, 0), |
| 479 | + size=SIZE_2D, |
| 480 | + freqs=freqs, |
| 481 | + mode_spec=mode_spec, |
| 482 | + name="test", |
| 483 | + ) |
| 484 | + |
| 485 | + from ..test_data.test_data_arrays import make_scalar_mode_field_data_array |
| 486 | + from ..test_data.test_monitor_data import N_COMPLEX |
| 487 | + |
| 488 | + mode_data = td.ModeSolverData( |
| 489 | + monitor=monitor, |
| 490 | + Ex=make_scalar_mode_field_data_array("Ex"), |
| 491 | + n_complex=N_COMPLEX.copy(), |
| 492 | + symmetry=(0, 0, 0), |
| 493 | + symmetry_center=(0, 0, 0), |
| 494 | + grid_expanded=td.Grid(boundaries=td.Coords(x=[0, 1], y=[0, 1], z=[0, 1])), |
| 495 | + ) |
| 496 | + |
| 497 | + freqs_dense = np.linspace(1e14, 2e14, 10) |
| 498 | + with pytest.raises(td.exceptions.DataError, match="at least 3 source"): |
| 499 | + mode_data.interp(freqs=freqs_dense, method="cheb") |
| 500 | + |
| 501 | + |
| 502 | +def test_mode_solver_data_interp_cheb_validates_nodes(): |
| 503 | + """Test that Chebyshev interpolation validates source frequencies are Chebyshev nodes.""" |
| 504 | + # Create data with uniform (not Chebyshev) nodes |
| 505 | + freqs_uniform = np.linspace(1e14, 2e14, 5) |
| 506 | + mode_spec = td.ModeSpec(num_modes=2, sort_spec=td.ModeSortSpec(track_freq="central")) |
| 507 | + monitor = td.ModeSolverMonitor( |
| 508 | + center=(0, 0, 0), |
| 509 | + size=SIZE_2D, |
| 510 | + freqs=freqs_uniform, |
| 511 | + mode_spec=mode_spec, |
| 512 | + name="test", |
| 513 | + ) |
| 514 | + |
| 515 | + from ..test_data.test_data_arrays import make_scalar_mode_field_data_array |
| 516 | + from ..test_data.test_monitor_data import N_COMPLEX |
| 517 | + |
| 518 | + mode_data = td.ModeSolverData( |
| 519 | + monitor=monitor, |
| 520 | + Ex=make_scalar_mode_field_data_array("Ex"), |
| 521 | + n_complex=N_COMPLEX.copy(), |
| 522 | + symmetry=(0, 0, 0), |
| 523 | + symmetry_center=(0, 0, 0), |
| 524 | + grid_expanded=td.Grid(boundaries=td.Coords(x=[0, 1], y=[0, 1], z=[0, 1])), |
| 525 | + ) |
| 526 | + |
| 527 | + freqs_dense = np.linspace(1e14, 2e14, 10) |
| 528 | + with pytest.raises(td.exceptions.DataError, match="must be at Chebyshev nodes"): |
| 529 | + mode_data.interp(freqs=freqs_dense, method="cheb") |
| 530 | + |
| 531 | + |
383 | 532 | def test_mode_solver_data_interp_preserves_modes(): |
384 | 533 | """Test that interpolation preserves mode count.""" |
385 | 534 | mode_data = get_mode_solver_data() |
@@ -573,6 +722,32 @@ def test_mode_solver_interp_cubic(): |
573 | 722 | assert data.n_complex.shape[0] == 10 |
574 | 723 |
|
575 | 724 |
|
| 725 | +def test_mode_solver_interp_cheb(): |
| 726 | + """Test that ModeSolver works with Chebyshev interpolation.""" |
| 727 | + sim = get_simple_sim() |
| 728 | + |
| 729 | + freqs = np.linspace(1e14, 2e14, 20) |
| 730 | + mode_spec = td.ModeSpec( |
| 731 | + num_modes=2, |
| 732 | + sort_spec=td.ModeSortSpec(track_freq="central") |
| 733 | + ) |
| 734 | + |
| 735 | + # Chebyshev interpolation requires at least 3 points |
| 736 | + interp_spec = td.ModeInterpSpec(num_points=5, method="cheb") |
| 737 | + |
| 738 | + solver = ModeSolver( |
| 739 | + simulation=sim, |
| 740 | + plane=td.Box(center=(0, 0, 0), size=SIZE_2D), |
| 741 | + freqs=freqs, |
| 742 | + mode_spec=mode_spec, |
| 743 | + interp_spec=interp_spec, |
| 744 | + ) |
| 745 | + |
| 746 | + data = solver.data_raw |
| 747 | + assert len(data.monitor.freqs) == 20 |
| 748 | + assert data.n_complex.shape[0] == 20 |
| 749 | + |
| 750 | + |
576 | 751 | def test_mode_solver_without_interp_returns_full_data(): |
577 | 752 | """Test that solver without interp_spec computes at all frequencies.""" |
578 | 753 | sim = get_simple_sim() |
@@ -755,6 +930,74 @@ def test_mode_monitor_interp_spec_none(): |
755 | 930 | assert monitor.interp_spec is None |
756 | 931 |
|
757 | 932 |
|
| 933 | +# ============================================================================ |
| 934 | +# WavePort interp_spec Tests |
| 935 | +# ============================================================================ |
| 936 | + |
| 937 | +def make_wave_port(): |
| 938 | + """Make a WavePort.""" |
| 939 | + from tidy3d.plugins.smatrix.ports.wave import WavePort |
| 940 | + from tidy3d.components.microwave.path_integrals.integrals.current import AxisAlignedCurrentIntegral |
| 941 | + return WavePort( |
| 942 | + center=(0, 0, 0), |
| 943 | + size=(1, 1, 0), |
| 944 | + direction="+", |
| 945 | + name="port1", |
| 946 | + current_integral=AxisAlignedCurrentIntegral( |
| 947 | + center=(0, 0, 0), |
| 948 | + size=(1, 1, 0), |
| 949 | + sign="+", |
| 950 | + extrapolate_to_endpoints=True, |
| 951 | + snap_contour_to_grid=True, |
| 952 | + ) |
| 953 | + ) |
| 954 | + |
| 955 | + |
| 956 | +def test_wave_port_to_monitors_propagates_default_interp_spec(): |
| 957 | + """Test that WavePort.to_monitors() propagates default interp_spec to ModeMonitor.""" |
| 958 | + |
| 959 | + port = make_wave_port() |
| 960 | + |
| 961 | + freqs = np.linspace(1e14, 2e14, 20) |
| 962 | + monitors = port.to_monitors(freqs=freqs) |
| 963 | + |
| 964 | + assert len(monitors) == 1 |
| 965 | + monitor = monitors[0] |
| 966 | + assert isinstance(monitor, td.ModeMonitor) |
| 967 | + assert monitor.interp_spec is not None |
| 968 | + assert monitor.interp_spec.num_points == DEFAULT_WAVE_PORT_INTERP_SPEC.num_points |
| 969 | + assert monitor.interp_spec.method == DEFAULT_WAVE_PORT_INTERP_SPEC.method |
| 970 | + |
| 971 | + |
| 972 | +def test_wave_port_to_monitors_propagates_custom_interp_spec(): |
| 973 | + """Test that WavePort.to_monitors() propagates custom interp_spec to ModeMonitor.""" |
| 974 | + custom_interp = td.ModeInterpSpec(num_points=8, method="cheb") |
| 975 | + port = make_wave_port().updated_copy(interp_spec=custom_interp) |
| 976 | + |
| 977 | + freqs = np.linspace(1e14, 2e14, 50) |
| 978 | + monitors = port.to_monitors(freqs=freqs) |
| 979 | + |
| 980 | + assert len(monitors) == 1 |
| 981 | + monitor = monitors[0] |
| 982 | + assert isinstance(monitor, td.ModeMonitor) |
| 983 | + assert monitor.interp_spec is not None |
| 984 | + assert monitor.interp_spec.num_points == 8 |
| 985 | + assert monitor.interp_spec.method == "cheb" |
| 986 | + |
| 987 | + |
| 988 | +def test_wave_port_to_monitors_propagates_none_interp_spec(): |
| 989 | + """Test that WavePort.to_monitors() propagates interp_spec=None to ModeMonitor.""" |
| 990 | + port = make_wave_port().updated_copy(interp_spec=None) |
| 991 | + |
| 992 | + freqs = np.linspace(1e14, 2e14, 20) |
| 993 | + monitors = port.to_monitors(freqs=freqs) |
| 994 | + |
| 995 | + assert len(monitors) == 1 |
| 996 | + monitor = monitors[0] |
| 997 | + assert isinstance(monitor, td.ModeMonitor) |
| 998 | + assert monitor.interp_spec is None |
| 999 | + |
| 1000 | + |
758 | 1001 | # ============================================================================ |
759 | 1002 | # Placeholder tests for future phases |
760 | 1003 | # ============================================================================ |
|
0 commit comments