|
9 | 9 | SoftCrossEntropyLoss, |
10 | 10 | TverskyLoss, |
11 | 11 | MCCLoss, |
| 12 | + FocalLoss, |
12 | 13 | ) |
13 | 14 |
|
14 | 15 |
|
@@ -332,3 +333,116 @@ def test_binary_mcc_loss(): |
332 | 333 |
|
333 | 334 | loss = criterion(y_pred, y_true) |
334 | 335 | assert float(loss) == pytest.approx(0.5, abs=eps) |
| 336 | + |
| 337 | + |
@torch.inference_mode()
def test_class_weights_uniform_equivalent_to_no_weights_multiclass():
    """All-ones class_weights must match the unweighted loss in multiclass mode."""
    tol = 1e-5
    torch.manual_seed(0)
    logits = torch.randn(2, 3, 4, 4)
    targets = torch.randint(0, 3, (2, 4, 4))

    for loss_cls in (DiceLoss, JaccardLoss, TverskyLoss):
        baseline = loss_cls(mode=smp.losses.MULTICLASS_MODE)(logits, targets)
        weighted = loss_cls(
            mode=smp.losses.MULTICLASS_MODE, class_weights=[1.0, 1.0, 1.0]
        )(logits, targets)
        assert torch.allclose(baseline, weighted, atol=tol), (
            f"Uniform weights should be equivalent to no weights for {loss_cls.__name__}"
        )
| 354 | + |
| 355 | + |
@torch.inference_mode()
def test_class_weights_uniform_equivalent_to_no_weights_multilabel():
    """All-ones class_weights must match the unweighted loss in multilabel mode."""
    tol = 1e-5
    torch.manual_seed(0)
    logits = torch.randn(2, 3, 4, 4)
    targets = torch.randint(0, 2, (2, 3, 4, 4)).float()

    for loss_cls in (DiceLoss, JaccardLoss, TverskyLoss):
        baseline = loss_cls(mode=smp.losses.MULTILABEL_MODE)(logits, targets)
        weighted = loss_cls(
            mode=smp.losses.MULTILABEL_MODE, class_weights=[1.0, 1.0, 1.0]
        )(logits, targets)
        assert torch.allclose(baseline, weighted, atol=tol), (
            f"Uniform weights should be equivalent to no weights for {loss_cls.__name__}"
        )
| 372 | + |
| 373 | + |
@torch.inference_mode()
def test_class_weights_nonuniform_changes_loss_multiclass():
    """Skewed class_weights must yield a different loss than the unweighted one (multiclass)."""
    torch.manual_seed(0)
    logits = torch.randn(2, 3, 4, 4)
    targets = torch.randint(0, 3, (2, 4, 4))

    for loss_cls in (DiceLoss, JaccardLoss, TverskyLoss):
        baseline = loss_cls(mode=smp.losses.MULTICLASS_MODE)(logits, targets)
        skewed = loss_cls(
            mode=smp.losses.MULTICLASS_MODE, class_weights=[1.0, 2.0, 0.5]
        )(logits, targets)
        # A difference larger than 1e-6 means the weights actually took effect.
        assert not torch.allclose(baseline, skewed, atol=1e-6), (
            f"Non-uniform weights should change the loss for {loss_cls.__name__}"
        )
| 389 | + |
| 390 | + |
@torch.inference_mode()
def test_class_weights_scale_invariant_multiclass():
    """Multiplying every class weight by the same factor must leave the loss unchanged."""
    tol = 1e-5
    torch.manual_seed(0)
    logits = torch.randn(2, 3, 4, 4)
    targets = torch.randint(0, 3, (2, 4, 4))

    for loss_cls in (DiceLoss, JaccardLoss, TverskyLoss):
        # Same relative weights, second set scaled by 10x.
        base_weighted = loss_cls(
            mode=smp.losses.MULTICLASS_MODE, class_weights=[1.0, 2.0, 0.5]
        )(logits, targets)
        scaled_weighted = loss_cls(
            mode=smp.losses.MULTICLASS_MODE, class_weights=[10.0, 20.0, 5.0]
        )(logits, targets)
        assert torch.allclose(base_weighted, scaled_weighted, atol=tol), (
            f"Loss should be scale-invariant w.r.t. class_weights for {loss_cls.__name__}"
        )
| 409 | + |
| 410 | + |
@torch.inference_mode()
def test_class_weights_binary_mode_raises():
    """Passing class_weights in binary mode must raise ValueError at construction time."""
    for loss_cls in (DiceLoss, JaccardLoss, TverskyLoss):
        with pytest.raises(ValueError):
            loss_cls(mode=smp.losses.BINARY_MODE, class_weights=[1.0, 2.0])
| 417 | + |
| 418 | + |
@torch.inference_mode()
def test_focal_class_weights_uniform_equivalent_to_no_weights():
    """FocalLoss with all-ones class_weights must match the unweighted FocalLoss."""
    tol = 1e-5
    torch.manual_seed(0)
    logits = torch.randn(2, 3, 4, 4)
    targets = torch.randint(0, 3, (2, 4, 4))

    baseline = FocalLoss(mode=smp.losses.MULTICLASS_MODE)(logits, targets)
    weighted = FocalLoss(
        mode=smp.losses.MULTICLASS_MODE, class_weights=[1.0, 1.0, 1.0]
    )(logits, targets)
    assert torch.allclose(baseline, weighted, atol=tol)
| 432 | + |
| 433 | + |
@torch.inference_mode()
def test_focal_class_weights_scale_invariant():
    """Scaling every FocalLoss class weight by one factor must leave the loss unchanged."""
    tol = 1e-5
    torch.manual_seed(0)
    logits = torch.randn(2, 3, 4, 4)
    targets = torch.randint(0, 3, (2, 4, 4))

    # Same relative weights; the second set is the first multiplied by 10.
    base_weighted = FocalLoss(
        mode=smp.losses.MULTICLASS_MODE, class_weights=[1.0, 2.0, 0.5]
    )(logits, targets)
    scaled_weighted = FocalLoss(
        mode=smp.losses.MULTICLASS_MODE, class_weights=[10.0, 20.0, 5.0]
    )(logits, targets)
    assert torch.allclose(base_weighted, scaled_weighted, atol=tol)
0 commit comments