|
10 | 10 | import xarray as xr |
11 | 11 | from numpy.testing import assert_array_almost_equal, assert_array_equal |
12 | 12 |
|
| 13 | +import skyborn.interp.regridding as regridding_module |
13 | 14 | from skyborn.interp.regridding import ( |
14 | 15 | BilinearRegridder, |
15 | 16 | ConservativeRegridder, |
@@ -240,6 +241,14 @@ def test_nearest_neighbor_indices_function(self, sample_grids): |
240 | 241 | assert np.all(indices >= 0) |
241 | 242 | assert np.all(indices < source_grid.shape[0] * source_grid.shape[1]) |
242 | 243 |
|
| 244 | + def test_nearest_neighbor_2d_wrong_shape_raises(self, sample_grids): |
| 245 | + """Test the private 2D kernel rejects arrays with the wrong source shape.""" |
| 246 | + source_grid, target_grid = sample_grids |
| 247 | + regridder = NearestRegridder(source_grid, target_grid) |
| 248 | + |
| 249 | + with pytest.raises(ValueError, match="to match source.shape"): |
| 250 | + regridder._nearest_neighbor_2d(np.zeros((2, 2))) |
| 251 | + |
243 | 252 |
|
244 | 253 | class TestBilinearRegridder: |
245 | 254 | """Test bilinear regridding functionality.""" |
@@ -426,6 +435,42 @@ def test_longitude_weights_basic(self): |
426 | 435 | row_sums = np.sum(weights, axis=1) |
427 | 436 | assert_array_almost_equal(row_sums, np.ones(len(target_lon)), decimal=5) |
428 | 437 |
|
| 438 | + def test_latitude_weights_zero_sum_rows_fall_back_to_uniform(self, monkeypatch): |
| 439 | + """Test zero-overlap latitude rows use equal weights instead of dividing by zero.""" |
| 440 | + source_lat = np.deg2rad(np.array([-60, 0, 60])) |
| 441 | + target_lat = np.deg2rad(np.array([-30, 30])) |
| 442 | + |
| 443 | + monkeypatch.setattr( |
| 444 | + regridding_module, |
| 445 | + "_latitude_overlap", |
| 446 | + lambda source_points, target_points: np.zeros( |
| 447 | + (len(target_points), len(source_points)) |
| 448 | + ), |
| 449 | + ) |
| 450 | + |
| 451 | + weights = _conservative_latitude_weights(source_lat, target_lat) |
| 452 | + expected = np.full((len(target_lat), len(source_lat)), 1.0 / len(source_lat)) |
| 453 | + |
| 454 | + assert_array_almost_equal(weights, expected) |
| 455 | + |
| 456 | + def test_longitude_weights_zero_sum_rows_fall_back_to_uniform(self, monkeypatch): |
| 457 | + """Test zero-overlap longitude rows use equal weights instead of dividing by zero.""" |
| 458 | + source_lon = np.deg2rad(np.array([0, 90, 180, 270])) |
| 459 | + target_lon = np.deg2rad(np.array([45, 135])) |
| 460 | + |
| 461 | + monkeypatch.setattr( |
| 462 | + regridding_module, |
| 463 | + "_longitude_overlap", |
| 464 | + lambda first_points, second_points: np.zeros( |
| 465 | + (len(first_points), len(second_points)) |
| 466 | + ), |
| 467 | + ) |
| 468 | + |
| 469 | + weights = _conservative_longitude_weights(source_lon, target_lon) |
| 470 | + expected = np.full((len(target_lon), len(source_lon)), 1.0 / len(source_lon)) |
| 471 | + |
| 472 | + assert_array_almost_equal(weights, expected) |
| 473 | + |
429 | 474 | def test_latitude_overlap_function(self): |
430 | 475 | """Test latitude overlap calculation.""" |
431 | 476 | source_lat = np.deg2rad(np.array([-30, 0, 30])) |
|
0 commit comments