-
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Focal loss optimisation #1236
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
vedantdalimkar
wants to merge
4
commits into
qubvel-org:main
Choose a base branch
from
vedantdalimkar:focal_loss_optimisation
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Focal loss optimisation #1236
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
195 changes: 195 additions & 0 deletions
195
segmentation_models_pytorch/losses/focal_loss_optimisation_benchmarking.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,195 @@ | ||
| { | ||
| "cells": [ | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "245a88c9", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "from segmentation_models_pytorch.losses import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE\n", | ||
| "from time import time\n", | ||
| "from typing import Optional\n", | ||
| "import torch\n", | ||
| "import segmentation_models_pytorch\n", | ||
| "\n", | ||
| "class FocalLossVectorised(segmentation_models_pytorch.losses.FocalLoss):\n", | ||
| " def __init__(\n", | ||
| " self,\n", | ||
| " mode: str,\n", | ||
| " alpha: Optional[float] = None,\n", | ||
| " gamma: Optional[float] = 2.0,\n", | ||
| " ignore_index: Optional[int] = None,\n", | ||
| " reduction: Optional[str] = \"mean\",\n", | ||
| " normalized: bool = False,\n", | ||
| " reduced_threshold: Optional[float] = None,\n", | ||
| " ):\n", | ||
| " \n", | ||
| " super().__init__(mode = mode,alpha = alpha,gamma = gamma, ignore_index = ignore_index,reduction = reduction,\n", | ||
| " normalized = normalized,reduced_threshold = reduced_threshold)\n", | ||
| " \n", | ||
| " def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:\n", | ||
| " if self.mode in {BINARY_MODE, MULTILABEL_MODE}:\n", | ||
| " y_true = y_true.view(-1)\n", | ||
| " y_pred = y_pred.view(-1)\n", | ||
| "\n", | ||
| " if self.ignore_index is not None:\n", | ||
| " # Filter predictions with ignore label from loss computation\n", | ||
| " not_ignored = y_true != self.ignore_index\n", | ||
| " y_pred = y_pred[not_ignored]\n", | ||
| " y_true = y_true[not_ignored]\n", | ||
| "\n", | ||
| " loss = self.focal_loss_fn(y_pred, y_true)\n", | ||
| "\n", | ||
| " elif self.mode == MULTICLASS_MODE:\n", | ||
| " num_classes = y_pred.size(1)\n", | ||
| "\n", | ||
| " if self.ignore_index is not None:\n", | ||
| " y_true[y_true == self.ignore_index] = num_classes\n", | ||
| " y_true_one_hot = torch.nn.functional.one_hot(y_true,num_classes = num_classes + 1)\n", | ||
| " y_true_one_hot = y_true_one_hot[ : , : , : , : -1]\n", | ||
| "\n", | ||
| " else: \n", | ||
| " y_true_one_hot = torch.nn.functional.one_hot(y_true,num_classes = num_classes)\n", | ||
| "\n", | ||
| " y_true_one_hot = torch.permute(y_true_one_hot,(0,3,1,2))\n", | ||
| " loss = num_classes * self.focal_loss_fn(y_pred, y_true_one_hot)\n", | ||
| "\n", | ||
| " return loss" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
|
Check failure on line 62 in segmentation_models_pytorch/losses/focal_loss_optimisation_benchmarking.ipynb
|
||
| "execution_count": 2, | ||
| "id": "9c64c3ea", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "num_classes = 20\n", | ||
| "batch_size = 128\n", | ||
| "resolution = 512\n", | ||
| "device = 'cuda:1' if torch.cuda.is_available() else 'cpu'" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": 3, | ||
| "id": "d4d3a5f5", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "vectorised_loss_fn = FocalLossVectorised(mode = 'multiclass',ignore_index = num_classes)\n", | ||
| "loss_fn = segmentation_models_pytorch.losses.FocalLoss(mode = 'multiclass',ignore_index = num_classes)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": 4, | ||
| "id": "5e49a5b8", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "predictions = torch.randn((batch_size,num_classes,resolution,resolution)).to(device = device)\n", | ||
| "labels = torch.randint(low = 0,high = num_classes+1,size = (batch_size,resolution,resolution)).to(device = device)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": 5, | ||
| "id": "36a0f89b", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "def benchmark(function,predictions,labels,benchmark_iterations = 100):\n", | ||
| " start_time = time()\n", | ||
| "\n", | ||
| " for _ in range(benchmark_iterations):\n", | ||
| " loss = function(predictions,labels)\n", | ||
| "\n", | ||
| " end_time = time()\n", | ||
| "\n", | ||
| " average_time_taken = (end_time - start_time) / (benchmark_iterations)\n", | ||
| "\n", | ||
| " print(f\"Average time taken by function {function} is {average_time_taken} seconds\")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": 6, | ||
| "id": "8de20667", | ||
| "metadata": {}, | ||
| "outputs": [ | ||
| { | ||
| "name": "stdout", | ||
| "output_type": "stream", | ||
| "text": [ | ||
| "Average time taken by function FocalLoss() is 0.3390256547927856 seconds\n" | ||
| ] | ||
| } | ||
| ], | ||
| "source": [ | ||
| "benchmark(loss_fn,predictions,labels)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": 7, | ||
| "id": "6da16fc2", | ||
| "metadata": {}, | ||
| "outputs": [ | ||
| { | ||
| "name": "stdout", | ||
| "output_type": "stream", | ||
| "text": [ | ||
| "Average time taken by function FocalLossVectorised() is 0.11771584510803222 seconds\n" | ||
| ] | ||
| } | ||
| ], | ||
| "source": [ | ||
| "benchmark(vectorised_loss_fn,predictions,labels)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "a1083cd4", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "##### CHECKING THAT OUTPUT OF NEW CLASS IS CONSISTENT WITH THE OLD ONE" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": 8, | ||
| "id": "8fba3182", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "output_from_vectorised_fn = vectorised_loss_fn(predictions,labels)\n", | ||
| "output_from_old_fn = loss_fn(predictions,labels)\n", | ||
| "\n", | ||
| "assert torch.allclose(output_from_vectorised_fn,output_from_old_fn)" | ||
| ] | ||
| } | ||
| ], | ||
| "metadata": { | ||
| "kernelspec": { | ||
| "display_name": "torch", | ||
| "language": "python", | ||
| "name": "python3" | ||
| }, | ||
| "language_info": { | ||
| "codemirror_mode": { | ||
| "name": "ipython", | ||
| "version": 3 | ||
| }, | ||
| "file_extension": ".py", | ||
| "mimetype": "text/x-python", | ||
| "name": "python", | ||
| "nbconvert_exporter": "python", | ||
| "pygments_lexer": "ipython3", | ||
| "version": "3.9.19" | ||
| } | ||
| }, | ||
| "nbformat": 4, | ||
| "nbformat_minor": 5 | ||
| } | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.