Skip to content

Commit b5b33e1

Browse files
1. Add customDatasetLoader 2. Add predict function in test.py
1 parent 21ed8e5 commit b5b33e1

8 files changed

Lines changed: 236 additions & 71 deletions

File tree

UNETR/BTCV/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
# Dataset location configuration.
# All paths are relative to the working directory the scripts are run from.
NIFTI_DATA_ROOT = 'data/images'    # directory holding the NIfTI training images
NIFTI_LABEL_ROOT = 'data/labels'   # directory holding the matching NIfTI label volumes
PREDICT_DATA_ROOT = 'data/predict' # directory holding images to run inference on
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import os
2+
from torch.utils.data import DataLoader
3+
from monai.data import Dataset
4+
import monai.transforms as transforms
5+
import torch
6+
7+
from config import NIFTI_DATA_ROOT, NIFTI_LABEL_ROOT, PREDICT_DATA_ROOT
8+
9+
def _get_collate_fn(isTrain:bool):
10+
def collate_fn(batch):
11+
'''collate function'''
12+
images = []
13+
labels = []
14+
if isTrain:
15+
for p in batch: # [ {"image": (C, H, W ,D), "label": (C, H, W ,D)} , ...]
16+
for i in range(len(p)): # list, RandCropByPosNegLabeld will produce multiple samples
17+
images.append(p[i]['image'])
18+
labels.append(p[i]['label'])
19+
else:
20+
for p in batch:
21+
images.append(p['image'])
22+
labels.append(p['label'])
23+
24+
images = torch.stack(images, dim=0)
25+
labels = torch.stack(labels, dim=0)
26+
27+
return [torch.Tensor(images), torch.Tensor(labels)]
28+
29+
return collate_fn
30+
31+
def getDatasetLoader(args):
    """Build the train/validation DataLoaders from the NIfTI directories.

    Pairs every file name found under NIFTI_LABEL_ROOT with the same name
    under NIFTI_DATA_ROOT, splits the pairs 80/20 via ``_splitList``, and
    applies MONAI augmentation (train) / preprocessing-only (val) pipelines.

    Args:
        args: parsed CLI namespace; reads spacing (space_x/y/z), intensity
            window (a_min/a_max/b_min/b_max), crop size (roi_x/y/z),
            augmentation probabilities, batch_size and workers.

    Returns:
        ``[trainLoader, valLoader]``.
    """
    # Sort the listing: os.listdir order is filesystem-dependent, so an
    # unsorted list would silently produce a different train/val split on
    # every machine (and across runs), making results irreproducible.
    dataName = sorted(os.listdir(NIFTI_LABEL_ROOT))
    dataDicts = [
        {"image": os.path.join(NIFTI_DATA_ROOT, d), "label": os.path.join(NIFTI_LABEL_ROOT, d)}
        for d in dataName
    ]
    trainDicts, valDicts = _splitList(dataDicts)

    train_transform = transforms.Compose(
        [
            transforms.LoadImaged(keys=["image", "label"]),
            transforms.EnsureChannelFirstd(keys=["image", "label"]),
            transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
            transforms.Spacingd(
                keys=["image", "label"], pixdim=(args.space_x, args.space_y, args.space_z), mode=("bilinear", "nearest")
            ),
            transforms.ScaleIntensityRanged(
                keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
            ),
            transforms.CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True),
            # Produces num_samples crops per subject; _get_collate_fn(isTrain=True)
            # flattens the resulting per-subject lists into one batch.
            transforms.RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(args.roi_x, args.roi_y, args.roi_z),
                pos=1,
                neg=1,
                num_samples=4,
                image_key="image",
                image_threshold=0,
            ),
            transforms.RandFlipd(keys=["image", "label"], prob=args.RandFlipd_prob, spatial_axis=0),
            transforms.RandFlipd(keys=["image", "label"], prob=args.RandFlipd_prob, spatial_axis=1),
            transforms.RandFlipd(keys=["image", "label"], prob=args.RandFlipd_prob, spatial_axis=2),
            transforms.RandRotate90d(keys=["image", "label"], prob=args.RandRotate90d_prob, max_k=3),
            transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=args.RandScaleIntensityd_prob),
            transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=args.RandShiftIntensityd_prob),
            transforms.ToTensord(keys=["image", "label"]),
        ]
    )

    val_transform = transforms.Compose(
        [
            transforms.LoadImaged(keys=["image", "label"]),
            transforms.EnsureChannelFirstd(keys=["image", "label"]),
            transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
            transforms.Spacingd(
                keys=["image", "label"], pixdim=(args.space_x, args.space_y, args.space_z), mode=("bilinear", "nearest")
            ),
            transforms.ScaleIntensityRanged(
                keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
            ),
            transforms.CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True),
            transforms.ToTensord(keys=["image", "label"]),
        ]
    )

    trainDataset = Dataset(data=trainDicts, transform=train_transform)
    valDataset = Dataset(data=valDicts, transform=val_transform)
    trainLoader = DataLoader(trainDataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, collate_fn=_get_collate_fn(isTrain=True))
    valLoader = DataLoader(valDataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, collate_fn=_get_collate_fn(isTrain=False))
    loader = [trainLoader, valLoader]

    return loader
91+
92+
def _splitList(l, trainRatio:float = 0.8):
93+
totalNum = len(l)
94+
splitIdx = int(totalNum * trainRatio)
95+
96+
return l[:splitIdx], l[splitIdx :]
97+
98+
def getPredictLoader(args):
    """Build the inference DataLoader over PREDICT_DATA_ROOT.

    Applies the same deterministic preprocessing as validation (load,
    channel-first, RAS orientation, resampling, intensity windowing,
    foreground crop) but with no label key and no augmentation.

    Args:
        args: parsed CLI namespace; reads space_x/y/z, a_min/a_max,
            b_min/b_max, batch_size and workers.

    Returns:
        ``(valLoader, preTransform)`` — the transform is returned as well so
        callers can invert it when writing predictions back to native space.
    """
    # Sorted for a stable, reproducible prediction order (os.listdir order
    # is filesystem-dependent).
    dataName = sorted(os.listdir(PREDICT_DATA_ROOT))
    dataDicts = [{"image": os.path.join(PREDICT_DATA_ROOT, d)} for d in dataName]

    preTransform = transforms.Compose(
        [
            transforms.LoadImaged(keys=["image"]),
            transforms.EnsureChannelFirstd(keys=["image"]),
            transforms.Orientationd(keys=["image"], axcodes="RAS"),
            transforms.Spacingd(
                # mode is a plain string here (the original "(\"bilinear\")" was
                # parentheses around a string, not a tuple).
                keys=["image"], pixdim=(args.space_x, args.space_y, args.space_z), mode="bilinear"
            ),
            transforms.ScaleIntensityRanged(
                keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
            ),
            transforms.CropForegroundd(keys=["image"], source_key="image", allow_smaller=True),
            # transforms.ToTensord(keys=["image"],track_meta=True), # This transformation will transform MetaTensor to Tensor
        ]
    )
    valDataset = Dataset(data=dataDicts, transform=preTransform)
    valLoader = DataLoader(valDataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    return valLoader, preTransform

UNETR/BTCV/main.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,53 +12,55 @@
1212
import argparse
1313
import os
1414
from functools import partial
15-
1615
import numpy as np
1716
import torch
1817
import torch.distributed as dist
1918
import torch.multiprocessing as mp
2019
import torch.nn.parallel
2120
import torch.utils.data.distributed
22-
from networks.unetr import UNETR
23-
from optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
24-
from trainer import run_training
25-
from utils.data_utils import get_loader
2621

2722
from monai.inferers import sliding_window_inference
2823
from monai.losses import DiceCELoss, DiceLoss
2924
from monai.metrics import DiceMetric
3025
from monai.transforms import Activations, AsDiscrete, Compose
3126
from monai.utils.enums import MetricReduction
3227

28+
from networks.unetr import UNETR
29+
from optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
30+
from trainer import run_training
31+
from utils.data_utils import get_loader
32+
from dataset.customDataset import getDatasetLoader
33+
3334
parser = argparse.ArgumentParser(description="UNETR segmentation pipeline")
3435
parser.add_argument("--checkpoint", default=None, help="start training from saved checkpoint")
3536
parser.add_argument("--logdir", default="test", type=str, help="directory to save the tensorboard logs")
3637
parser.add_argument(
3738
"--pretrained_dir", default="./pretrained_models/", type=str, help="pretrained checkpoint directory"
3839
)
39-
parser.add_argument("--data_dir", default="/dataset/dataset0/", type=str, help="dataset directory")
40+
parser.add_argument("--btcv", action="store_true", help="Use BTCV dataset")
41+
parser.add_argument("--data_dir", default="./dataset/dataset0/", type=str, help="dataset directory")
4042
parser.add_argument("--json_list", default="dataset_0.json", type=str, help="dataset json file")
4143
parser.add_argument(
4244
"--pretrained_model_name", default="UNETR_model_best_acc.pth", type=str, help="pretrained model name"
4345
)
44-
parser.add_argument("--save_checkpoint", action="store_true", help="save checkpoint during training")
45-
parser.add_argument("--max_epochs", default=5000, type=int, help="max number of training epochs")
46+
parser.add_argument("--save_checkpoint", action="store_true", default=True, help="save checkpoint during training")
47+
parser.add_argument("--max_epochs", default=100, type=int, help="max number of training epochs")
4648
parser.add_argument("--batch_size", default=1, type=int, help="number of batch size")
4749
parser.add_argument("--sw_batch_size", default=1, type=int, help="number of sliding window batch size")
4850
parser.add_argument("--optim_lr", default=1e-4, type=float, help="optimization learning rate")
4951
parser.add_argument("--optim_name", default="adamw", type=str, help="optimization algorithm")
5052
parser.add_argument("--reg_weight", default=1e-5, type=float, help="regularization weight")
5153
parser.add_argument("--momentum", default=0.99, type=float, help="momentum")
5254
parser.add_argument("--noamp", action="store_true", help="do NOT use amp for training")
53-
parser.add_argument("--val_every", default=100, type=int, help="validation frequency")
55+
parser.add_argument("--val_every", default=10, type=int, help="validation frequency")
5456
parser.add_argument("--distributed", action="store_true", help="start distributed training")
5557
parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training")
5658
parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
5759
parser.add_argument("--dist-url", default="tcp://127.0.0.1:23456", type=str, help="distributed url")
5860
parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
5961
parser.add_argument("--workers", default=8, type=int, help="number of workers")
6062
parser.add_argument("--model_name", default="unetr", type=str, help="model name")
61-
parser.add_argument("--pos_embed", default="perceptron", type=str, help="type of position embedding")
63+
parser.add_argument("--pos_embed", default="learnable", type=str, help="type of position embedding")
6264
parser.add_argument("--norm_name", default="instance", type=str, help="normalization layer type in decoder")
6365
parser.add_argument("--num_heads", default=12, type=int, help="number of attention heads in ViT encoder")
6466
parser.add_argument("--mlp_dim", default=3072, type=int, help="mlp dimention in ViT encoder")
@@ -73,12 +75,12 @@
7375
parser.add_argument("--a_max", default=250.0, type=float, help="a_max in ScaleIntensityRanged")
7476
parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged")
7577
parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged")
76-
parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction")
77-
parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction")
78-
parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction")
79-
parser.add_argument("--roi_x", default=96, type=int, help="roi size in x direction")
80-
parser.add_argument("--roi_y", default=96, type=int, help="roi size in y direction")
81-
parser.add_argument("--roi_z", default=96, type=int, help="roi size in z direction")
78+
parser.add_argument("--space_x", default=1.0, type=float, help="spacing in x direction")
79+
parser.add_argument("--space_y", default=1.0, type=float, help="spacing in y direction")
80+
parser.add_argument("--space_z", default=1.0, type=float, help="spacing in z direction")
81+
parser.add_argument("--roi_x", default=64, type=int, help="roi size in x direction")
82+
parser.add_argument("--roi_y", default=64, type=int, help="roi size in y direction")
83+
parser.add_argument("--roi_z", default=64, type=int, help="roi size in z direction")
8284
parser.add_argument("--dropout_rate", default=0.0, type=float, help="dropout rate")
8385
parser.add_argument("--RandFlipd_prob", default=0.2, type=float, help="RandFlipd aug probability")
8486
parser.add_argument("--RandRotate90d_prob", default=0.2, type=float, help="RandRotate90d aug probability")
@@ -102,10 +104,9 @@ def main():
102104
print("Found total gpus", args.ngpus_per_node)
103105
args.world_size = args.ngpus_per_node * args.world_size
104106
mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args,))
105-
else:
107+
else:
106108
main_worker(gpu=0, args=args)
107109

108-
109110
def main_worker(gpu, args):
110111
if args.distributed:
111112
torch.multiprocessing.set_start_method("fork", force=True)
@@ -119,7 +120,8 @@ def main_worker(gpu, args):
119120
torch.cuda.set_device(args.gpu)
120121
torch.backends.cudnn.benchmark = True
121122
args.test_mode = False
122-
loader = get_loader(args)
123+
loader = get_loader(args) if args.btcv else getDatasetLoader(args)
124+
123125
print(args.rank, " gpu", args.gpu)
124126
if args.rank == 0:
125127
print("Batch size is:", args.batch_size, "epochs", args.max_epochs)
@@ -157,8 +159,8 @@ def main_worker(gpu, args):
157159
dice_loss = DiceCELoss(
158160
to_onehot_y=True, softmax=True, squared_pred=True, smooth_nr=args.smooth_nr, smooth_dr=args.smooth_dr
159161
)
160-
post_label = AsDiscrete(to_onehot=True, n_classes=args.out_channels)
161-
post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=args.out_channels)
162+
post_label = AsDiscrete(to_onehot=args.out_channels)
163+
post_pred = AsDiscrete(argmax=True, to_onehot=args.out_channels)
162164
dice_acc = DiceMetric(include_background=True, reduction=MetricReduction.MEAN, get_not_nans=True)
163165
model_inferer = partial(
164166
sliding_window_inference,
@@ -235,6 +237,5 @@ def main_worker(gpu, args):
235237
)
236238
return accuracy
237239

238-
239240
if __name__ == "__main__":
240241
main()

UNETR/BTCV/networks/unetr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def __init__(
7373
if hidden_size % num_heads != 0:
7474
raise AssertionError("hidden size should be divisible by num_heads.")
7575

76-
if pos_embed not in ["conv", "perceptron"]:
76+
if pos_embed not in ['sincos', 'learnable', 'none']:
7777
raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")
7878

7979
self.num_layers = 12
@@ -93,7 +93,7 @@ def __init__(
9393
mlp_dim=mlp_dim,
9494
num_layers=self.num_layers,
9595
num_heads=num_heads,
96-
pos_embed=pos_embed,
96+
pos_embed_type=pos_embed,
9797
classification=self.classification,
9898
dropout_rate=dropout_rate,
9999
)

UNETR/BTCV/requirements.txt

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
torch==1.9.1
2-
monai==0.7.0
3-
nibabel==3.1.1
4-
tqdm==4.59.0
5-
einops==0.3.0
6-
tensorboardX==2.1
1+
monai==1.5.0
2+
numpy==2.3.2
3+
opencv_python
4+
simpleitk==2.5.2
5+
tensorboardx==2.6.4
6+
torch

0 commit comments

Comments
 (0)