Skip to content

Commit f0b3dc4

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 832618e commit f0b3dc4

6 files changed

Lines changed: 10 additions & 10 deletions

File tree

UNETR/BTCV/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
NIFTI_DATA_ROOT = 'data/images' # nifti image directory
22
NIFTI_LABEL_ROOT = 'data/labels' # nifti label directory
3-
PREDICT_DATA_ROOT = 'data/predict' # predict image directory
3+
PREDICT_DATA_ROOT = 'data/predict' # predict image directory

UNETR/BTCV/dataset/customDataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ def collate_fn(batch):
2323

2424
images = torch.stack(images, dim=0)
2525
labels = torch.stack(labels, dim=0)
26-
26+
2727
return [torch.Tensor(images), torch.Tensor(labels)]
28-
28+
2929
return collate_fn
3030

3131
def getDatasetLoader(args):
@@ -80,7 +80,7 @@ def getDatasetLoader(args):
8080
transforms.ToTensord(keys=["image", "label"]),
8181
]
8282
)
83-
83+
8484
trainDataset = Dataset(data=trainDicts, transform=train_transform)
8585
valDataset = Dataset(data=valDicts, transform=val_transform)
8686
trainLoader = DataLoader(trainDataset,batch_size=args.batch_size,shuffle=True,num_workers=args.workers, collate_fn=_get_collate_fn(isTrain=True))

UNETR/BTCV/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def main():
104104
print("Found total gpus", args.ngpus_per_node)
105105
args.world_size = args.ngpus_per_node * args.world_size
106106
mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args,))
107-
else:
107+
else:
108108
main_worker(gpu=0, args=args)
109109

110110
def main_worker(gpu, args):
@@ -121,7 +121,7 @@ def main_worker(gpu, args):
121121
torch.backends.cudnn.benchmark = True
122122
args.test_mode = False
123123
loader = get_loader(args) if args.btcv else getDatasetLoader(args)
124-
124+
125125
print(args.rank, " gpu", args.gpu)
126126
if args.rank == 0:
127127
print("Batch size is:", args.batch_size, "epochs", args.max_epochs)

UNETR/BTCV/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ numpy==2.3.2
33
opencv_python
44
simpleitk==2.5.2
55
tensorboardx==2.6.4
6-
torch
6+
torch

UNETR/BTCV/test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,9 @@ def main():
138138
print_log=True,
139139
)
140140
])
141-
141+
142142
for d in loader:
143-
143+
144144
input_data = d['image'].cuda() # (b, c, h, w, d)
145145
predict_raw = inference(input_data, model, args) # shape: (B, H, W, D)
146146
predict_tensor = torch.from_numpy(predict_raw.astype(np.float32)) # shape: (B, H, W, D)

UNETR/BTCV/utils/data_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,4 +165,4 @@ def get_loader(args):
165165
)
166166
loader = [train_loader, val_loader]
167167

168-
return loader
168+
return loader

0 commit comments

Comments
 (0)