Skip to content

Commit 0c017a5

Browse files
committed
fix: add tests for nearest neighbor regridding shape validation and zero-sum weight handling
1 parent a3c250a commit 0c017a5

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

tests/test_interp_regridding.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import xarray as xr
1111
from numpy.testing import assert_array_almost_equal, assert_array_equal
1212

13+
import skyborn.interp.regridding as regridding_module
1314
from skyborn.interp.regridding import (
1415
BilinearRegridder,
1516
ConservativeRegridder,
@@ -240,6 +241,14 @@ def test_nearest_neighbor_indices_function(self, sample_grids):
240241
assert np.all(indices >= 0)
241242
assert np.all(indices < source_grid.shape[0] * source_grid.shape[1])
242243

244+
def test_nearest_neighbor_2d_wrong_shape_raises(self, sample_grids):
245+
"""Test the private 2D kernel rejects arrays with the wrong source shape."""
246+
source_grid, target_grid = sample_grids
247+
regridder = NearestRegridder(source_grid, target_grid)
248+
249+
with pytest.raises(ValueError, match="to match source.shape"):
250+
regridder._nearest_neighbor_2d(np.zeros((2, 2)))
251+
243252

244253
class TestBilinearRegridder:
245254
"""Test bilinear regridding functionality."""
@@ -426,6 +435,42 @@ def test_longitude_weights_basic(self):
426435
row_sums = np.sum(weights, axis=1)
427436
assert_array_almost_equal(row_sums, np.ones(len(target_lon)), decimal=5)
428437

438+
def test_latitude_weights_zero_sum_rows_fall_back_to_uniform(self, monkeypatch):
439+
"""Test zero-overlap latitude rows use equal weights instead of dividing by zero."""
440+
source_lat = np.deg2rad(np.array([-60, 0, 60]))
441+
target_lat = np.deg2rad(np.array([-30, 30]))
442+
443+
monkeypatch.setattr(
444+
regridding_module,
445+
"_latitude_overlap",
446+
lambda source_points, target_points: np.zeros(
447+
(len(target_points), len(source_points))
448+
),
449+
)
450+
451+
weights = _conservative_latitude_weights(source_lat, target_lat)
452+
expected = np.full((len(target_lat), len(source_lat)), 1.0 / len(source_lat))
453+
454+
assert_array_almost_equal(weights, expected)
455+
456+
def test_longitude_weights_zero_sum_rows_fall_back_to_uniform(self, monkeypatch):
457+
"""Test zero-overlap longitude rows use equal weights instead of dividing by zero."""
458+
source_lon = np.deg2rad(np.array([0, 90, 180, 270]))
459+
target_lon = np.deg2rad(np.array([45, 135]))
460+
461+
monkeypatch.setattr(
462+
regridding_module,
463+
"_longitude_overlap",
464+
lambda first_points, second_points: np.zeros(
465+
(len(first_points), len(second_points))
466+
),
467+
)
468+
469+
weights = _conservative_longitude_weights(source_lon, target_lon)
470+
expected = np.full((len(target_lon), len(source_lon)), 1.0 / len(source_lon))
471+
472+
assert_array_almost_equal(weights, expected)
473+
429474
def test_latitude_overlap_function(self):
430475
"""Test latitude overlap calculation."""
431476
source_lat = np.deg2rad(np.array([-30, 0, 30]))

0 commit comments

Comments
 (0)