import argparse
import os
from functools import partial

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.parallel
import torch.utils.data.distributed

from monai.inferers import sliding_window_inference
from monai.losses import DiceCELoss, DiceLoss
from monai.metrics import DiceMetric
from monai.transforms import Activations, AsDiscrete, Compose
from monai.utils.enums import MetricReduction

# Project-local modules (network, scheduler, training loop, data loaders).
from networks.unetr import UNETR
from optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from trainer import run_training
from utils.data_utils import get_loader
from dataset.customDataset import getDatasetLoader

# Command-line interface for the UNETR segmentation training pipeline.
parser = argparse.ArgumentParser(description="UNETR segmentation pipeline")
parser.add_argument("--checkpoint", default=None, help="start training from saved checkpoint")
parser.add_argument("--logdir", default="test", type=str, help="directory to save the tensorboard logs")
parser.add_argument(
    "--pretrained_dir", default="./pretrained_models/", type=str, help="pretrained checkpoint directory"
)
parser.add_argument("--btcv", action="store_true", help="Use BTCV dataset")
parser.add_argument("--data_dir", default="./dataset/dataset0/", type=str, help="dataset directory")
parser.add_argument("--json_list", default="dataset_0.json", type=str, help="dataset json file")
parser.add_argument(
    "--pretrained_model_name", default="UNETR_model_best_acc.pth", type=str, help="pretrained model name"
)
# NOTE(review): `store_true` with default=True makes this flag a no-op (it is
# always True and cannot be disabled from the CLI). Kept as-is for backward
# compatibility; consider argparse.BooleanOptionalAction to allow
# --no-save_checkpoint if a Python >= 3.9 runtime is guaranteed.
parser.add_argument("--save_checkpoint", action="store_true", default=True, help="save checkpoint during training")
parser.add_argument("--max_epochs", default=100, type=int, help="max number of training epochs")
parser.add_argument("--batch_size", default=1, type=int, help="number of batch size")
parser.add_argument("--sw_batch_size", default=1, type=int, help="number of sliding window batch size")
parser.add_argument("--optim_lr", default=1e-4, type=float, help="optimization learning rate")
parser.add_argument("--optim_name", default="adamw", type=str, help="optimization algorithm")
parser.add_argument("--reg_weight", default=1e-5, type=float, help="regularization weight")
parser.add_argument("--momentum", default=0.99, type=float, help="momentum")
parser.add_argument("--noamp", action="store_true", help="do NOT use amp for training")
parser.add_argument("--val_every", default=10, type=int, help="validation frequency")
parser.add_argument("--distributed", action="store_true", help="start distributed training")
parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training")
parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
parser.add_argument("--dist-url", default="tcp://127.0.0.1:23456", type=str, help="distributed url")
parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
parser.add_argument("--workers", default=8, type=int, help="number of workers")
parser.add_argument("--model_name", default="unetr", type=str, help="model name")
parser.add_argument("--pos_embed", default="learnable", type=str, help="type of position embedding")
parser.add_argument("--norm_name", default="instance", type=str, help="normalization layer type in decoder")
parser.add_argument("--num_heads", default=12, type=int, help="number of attention heads in ViT encoder")
parser.add_argument("--mlp_dim", default=3072, type=int, help="mlp dimention in ViT encoder")
# Intensity-scaling bounds for ScaleIntensityRanged and spatial preprocessing
# parameters (voxel spacing and region-of-interest crop size per axis).
parser.add_argument("--a_max", default=250.0, type=float, help="a_max in ScaleIntensityRanged")
parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged")
parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged")
parser.add_argument("--space_x", default=1.0, type=float, help="spacing in x direction")
parser.add_argument("--space_y", default=1.0, type=float, help="spacing in y direction")
parser.add_argument("--space_z", default=1.0, type=float, help="spacing in z direction")
parser.add_argument("--roi_x", default=64, type=int, help="roi size in x direction")
parser.add_argument("--roi_y", default=64, type=int, help="roi size in y direction")
parser.add_argument("--roi_z", default=64, type=int, help="roi size in z direction")
parser.add_argument("--dropout_rate", default=0.0, type=float, help="dropout rate")
parser.add_argument("--RandFlipd_prob", default=0.2, type=float, help="RandFlipd aug probability")
parser.add_argument("--RandRotate90d_prob", default=0.2, type=float, help="RandRotate90d aug probability")
@@ -102,10 +104,9 @@ def main():
102104 print ("Found total gpus" , args .ngpus_per_node )
103105 args .world_size = args .ngpus_per_node * args .world_size
104106 mp .spawn (main_worker , nprocs = args .ngpus_per_node , args = (args ,))
105- else :
107+ else :
106108 main_worker (gpu = 0 , args = args )
107109
108-
109110def main_worker (gpu , args ):
110111 if args .distributed :
111112 torch .multiprocessing .set_start_method ("fork" , force = True )
@@ -119,7 +120,8 @@ def main_worker(gpu, args):
119120 torch .cuda .set_device (args .gpu )
120121 torch .backends .cudnn .benchmark = True
121122 args .test_mode = False
122- loader = get_loader (args )
123+ loader = get_loader (args ) if args .btcv else getDatasetLoader (args )
124+
123125 print (args .rank , " gpu" , args .gpu )
124126 if args .rank == 0 :
125127 print ("Batch size is:" , args .batch_size , "epochs" , args .max_epochs )
@@ -157,8 +159,8 @@ def main_worker(gpu, args):
157159 dice_loss = DiceCELoss (
158160 to_onehot_y = True , softmax = True , squared_pred = True , smooth_nr = args .smooth_nr , smooth_dr = args .smooth_dr
159161 )
160- post_label = AsDiscrete (to_onehot = True , n_classes = args .out_channels )
161- post_pred = AsDiscrete (argmax = True , to_onehot = True , n_classes = args .out_channels )
162+ post_label = AsDiscrete (to_onehot = args .out_channels )
163+ post_pred = AsDiscrete (argmax = True , to_onehot = args .out_channels )
162164 dice_acc = DiceMetric (include_background = True , reduction = MetricReduction .MEAN , get_not_nans = True )
163165 model_inferer = partial (
164166 sliding_window_inference ,
@@ -235,6 +237,5 @@ def main_worker(gpu, args):
235237 )
236238 return accuracy
237239

if __name__ == "__main__":
    main()