|
| 1 | +from typing import Any, Dict, Union |
| 2 | + |
1 | 3 | import torch |
2 | 4 | import torch.nn as nn |
3 | 5 |
|
|
7 | 9 | InPlaceABN = None |
8 | 10 |
|
9 | 11 |
|
def get_norm_layer(
    use_norm: Union[bool, str, Dict[str, Any]], out_channels: int
) -> nn.Module:
    """Build a normalization layer from a flexible user-facing spec.

    Args:
        use_norm: One of
            - bool: ``True`` -> BatchNorm2d, ``False`` -> Identity.
            - str: a name from the supported set (case-insensitive);
              ``"inplace"`` selects InPlaceABN with leaky-ReLU defaults.
            - dict: ``{"type": <name>, **kwargs}`` where kwargs are forwarded
              to the layer constructor.
        out_channels: Number of feature channels the layer normalizes.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: On an unrecognized type name or a malformed spec.
        RuntimeError: If ``"inplace"`` is requested but inplace_abn is absent.
    """
    supported_norms = ("inplace", "batchnorm", "identity", "layernorm", "instancenorm")

    # Step 1. Convert to dict representation.
    if use_norm is True:
        params: Dict[str, Any] = {"type": "batchnorm"}
    elif use_norm is False:
        params = {"type": "identity"}
    elif isinstance(use_norm, str):
        lowered = use_norm.lower()
        if lowered == "inplace":
            # InPlaceABN fuses norm + activation; activation_param=0.0 makes
            # the leaky ReLU behave like a plain ReLU.
            params = {
                "type": "inplace",
                "activation": "leaky_relu",
                "activation_param": 0.0,
            }
        elif lowered in supported_norms:
            params = {"type": lowered}
        else:
            raise ValueError(
                f"Unrecognized normalization type string provided: {use_norm}. Should be in "
                f"{supported_norms}"
            )
    elif isinstance(use_norm, dict):
        params = use_norm
    else:
        raise ValueError(
            f"Invalid type for use_norm should either be a bool (batchnorm/identity), "
            f"a string in {supported_norms}, or a dict like {{'type': 'batchnorm', **kwargs}}"
        )

    # Step 2. Validate the dict spec before construction.
    if "type" not in params:
        raise ValueError(
            f"Malformed dictionary given in use_norm: {use_norm}. Should contain key 'type'."
        )
    if params["type"] not in supported_norms:
        raise ValueError(
            f"Unrecognized normalization type string provided: {use_norm}. Should be in {supported_norms}"
        )
    if params["type"] == "inplace" and InPlaceABN is None:
        raise RuntimeError(
            "In order to use `use_norm='inplace'` the inplace_abn package must be installed. Use:\n"
            " $ pip install -U wheel setuptools\n"
            " $ pip install inplace_abn --no-build-isolation\n"
            "Also see: https://github.com/mapillary/inplace_abn"
        )

    # Step 3. Construct the layer.
    norm_type = params["type"]
    norm_kwargs = {key: value for key, value in params.items() if key != "type"}

    # Identity takes no channel argument; InPlaceABN is an optional dependency.
    if norm_type == "identity":
        return nn.Identity()
    if norm_type == "inplace":
        return InPlaceABN(out_channels, **norm_kwargs)

    factories = {
        "batchnorm": nn.BatchNorm2d,
        "layernorm": nn.LayerNorm,
        "instancenorm": nn.InstanceNorm2d,
    }
    if norm_type not in factories:
        raise ValueError(f"Unrecognized normalization type: {norm_type}")
    return factories[norm_type](out_channels, **norm_kwargs)
| 87 | + |
| 88 | + |
class Conv2dReLU(nn.Sequential):
    """Conv2d -> normalization -> ReLU block.

    The normalization layer is resolved via ``get_norm_layer`` from the
    ``use_norm`` spec (bool, string, or dict). When the resolved layer is
    InPlaceABN, its fused activation replaces the explicit ReLU.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        padding: int = 0,
        stride: int = 1,
        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
    ):
        norm = get_norm_layer(use_norm, out_channels)

        # A conv bias is redundant when a real norm layer follows, so only
        # enable it for the Identity (no-norm) case.
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=isinstance(norm, nn.Identity),
        )

        # InPlaceABN applies its own activation internally; skip the ReLU then.
        if InPlaceABN is not None and isinstance(norm, InPlaceABN):
            activation: nn.Module = nn.Identity()
        else:
            activation = nn.ReLU(inplace=True)

        super().__init__(conv, norm, activation)
47 | 115 |
|
48 | 116 |
|
49 | 117 | class SCSEModule(nn.Module): |
|
0 commit comments