|
6 | 6 |
|
7 | 7 | import autograd as ag |
8 | 8 | import autograd.numpy as np |
| 9 | +import numpy |
9 | 10 | import pytest |
10 | 11 | import xarray.testing as xrt |
11 | 12 | from autograd.test_util import check_grads |
@@ -516,3 +517,66 @@ def test_with_updated_data_shape(): |
516 | 517 |
|
517 | 518 | with pytest.raises(ValueError): |
518 | 519 | arr2 = arr._with_updated_data(data=data, coords=coords) |
| 520 | + |
| 521 | + |
| 522 | +@pytest.mark.parametrize("method", ["nearest", "linear"]) |
| 523 | +def test_interpn_with_extrapolation(rng, method): |
| 524 | + """Checks that the extrapolation in `interpn` works as expected and that |
| 525 | + it is autograd compatible.""" |
| 526 | + arr = td.SpatialDataArray( |
| 527 | + rng.random((1, 3, 4, 5), dtype=np.float64), |
| 528 | + coords={"x": [1], "y": [1, 2, 3], "z": [2, 3, 4, 5], "f": [0, 1, 2, 3, 4]}, |
| 529 | + ) |
| 530 | + |
| 531 | + for coord in arr.dims: |
| 532 | + endpoints = [ |
| 533 | + arr.coords[coord].values[0] - 1.0, |
| 534 | + arr.coords[coord].values[-1] + 1.0, |
| 535 | + ] |
| 536 | + |
| 537 | + method_coord = method if (len(arr.coords[coord]) > 1) else "nearest" |
| 538 | + |
| 539 | + offset_interp_coords = arr.coords[coord].values + 0.5 |
| 540 | + coords_interp = {coord: [endpoints[0], *offset_interp_coords, endpoints[1]]} |
| 541 | + |
| 542 | + extrapolate = arr._ag_interp( |
| 543 | + coords_interp, method=method_coord, kwargs={"fill_value": "extrapolate"} |
| 544 | + ) |
| 545 | + |
| 546 | + compare = arr.interp( |
| 547 | + coords_interp, method=method_coord, kwargs={"fill_value": "extrapolate"} |
| 548 | + ) |
| 549 | + |
| 550 | + numpy.testing.assert_allclose( |
| 551 | + extrapolate.data, compare.data, err_msg="Expected data to be close!" |
| 552 | + ) |
| 553 | + |
| 554 | + def f(params): |
| 555 | + arr = td.SpatialDataArray( |
| 556 | + params.reshape((1, 3, 4, 5)), |
| 557 | + coords={"x": [1], "y": [1, 2, 3], "z": [2, 3, 4, 5], "f": [0, 1, 2, 3, 4]}, |
| 558 | + ) |
| 559 | + |
| 560 | + result = 0.0 |
| 561 | + |
| 562 | + for coord in arr.dims: |
| 563 | + endpoints = [ |
| 564 | + arr.coords[coord].values[0] - 1.0, |
| 565 | + arr.coords[coord].values[-1] + 1.0, |
| 566 | + ] |
| 567 | + |
| 568 | + method_coord = method if (len(arr.coords[coord]) > 1) else "nearest" |
| 569 | + |
| 570 | + offset_interp_coords = arr.coords[coord].values + 0.5 |
| 571 | + coords_interp = {coord: [endpoints[0], *offset_interp_coords, endpoints[1]]} |
| 572 | + |
| 573 | + interp_data = arr.interp( |
| 574 | + coords_interp, method=method_coord, kwargs={"fill_value": "extrapolate"} |
| 575 | + ) |
| 576 | + |
| 577 | + result += np.sum(interp_data.data) |
| 578 | + |
| 579 | + return result |
| 580 | + |
| 581 | + data = rng.random((1, 3, 4, 5), dtype=np.float64) |
| 582 | + check_grads(f, order=1, modes=["rev"])(data) |
0 commit comments