-
Notifications
You must be signed in to change notification settings - Fork 387
(UNETR) : Add predict label function and custom dataloader which can train with own data . #420
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
b5b33e1
832618e
f0b3dc4
9a75007
18d35a9
4b366df
30bc5b1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| NIFTI_DATA_ROOT = 'data/images' # nifti image directory | ||
| NIFTI_LABEL_ROOT = 'data/labels' # nifti label directory | ||
| PREDICT_DATA_ROOT = 'data/predict' # predict image directory |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,120 @@ | ||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||
| from torch.utils.data import DataLoader | ||||||||||||||||||||||||||||
| from monai.data import Dataset | ||||||||||||||||||||||||||||
| import monai.transforms as transforms | ||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| from config import NIFTI_DATA_ROOT, NIFTI_LABEL_ROOT, PREDICT_DATA_ROOT | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def _get_collate_fn(isTrain:bool): | ||||||||||||||||||||||||||||
| def collate_fn(batch): | ||||||||||||||||||||||||||||
| '''collate function''' | ||||||||||||||||||||||||||||
| images = [] | ||||||||||||||||||||||||||||
| labels = [] | ||||||||||||||||||||||||||||
| if isTrain: | ||||||||||||||||||||||||||||
| for p in batch: # [ {"image": (C, H, W ,D), "label": (C, H, W ,D)} , ...] | ||||||||||||||||||||||||||||
| for i in range(len(p)): # list, RandCropByPosNegLabeld will produce multiple samples | ||||||||||||||||||||||||||||
| images.append(p[i]['image']) | ||||||||||||||||||||||||||||
| labels.append(p[i]['label']) | ||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||
| for p in batch: | ||||||||||||||||||||||||||||
| images.append(p['image']) | ||||||||||||||||||||||||||||
| labels.append(p['label']) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| images = torch.stack(images, dim=0) | ||||||||||||||||||||||||||||
| labels = torch.stack(labels, dim=0) | ||||||||||||||||||||||||||||
| # keep images float and labels long for loss functions | ||||||||||||||||||||||||||||
| return [images.float(), labels.long()] | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| return collate_fn | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def getDatasetLoader(args): | ||||||||||||||||||||||||||||
| dataName = [d for d in os.listdir(NIFTI_LABEL_ROOT)] | ||||||||||||||||||||||||||||
| dataDicts = [{"image": f"{os.path.join(NIFTI_DATA_ROOT, d)}", "label": f"{os.path.join(NIFTI_LABEL_ROOT, d)}"} for d in dataName] | ||||||||||||||||||||||||||||
| trainDicts, valDicts = _splitList(dataDicts) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
coderabbitai[bot] marked this conversation as resolved.
|
||||||||||||||||||||||||||||
| train_transform = transforms.Compose( | ||||||||||||||||||||||||||||
| [ | ||||||||||||||||||||||||||||
| transforms.LoadImaged(keys=["image", "label"]), | ||||||||||||||||||||||||||||
| transforms.EnsureChannelFirstd(keys=["image", "label"]), | ||||||||||||||||||||||||||||
| transforms.Orientationd(keys=["image", "label"], axcodes="RAS"), | ||||||||||||||||||||||||||||
| transforms.Spacingd( | ||||||||||||||||||||||||||||
| keys=["image", "label"], pixdim=(args.space_x, args.space_y, args.space_z), mode=("bilinear", "nearest") | ||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||
| transforms.ScaleIntensityRanged( | ||||||||||||||||||||||||||||
| keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True | ||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||
| transforms.CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True), | ||||||||||||||||||||||||||||
| transforms.RandCropByPosNegLabeld( | ||||||||||||||||||||||||||||
| keys=["image", "label"], | ||||||||||||||||||||||||||||
| label_key="label", | ||||||||||||||||||||||||||||
| spatial_size=(args.roi_x, args.roi_y, args.roi_z), | ||||||||||||||||||||||||||||
| pos=1, | ||||||||||||||||||||||||||||
| neg=1, | ||||||||||||||||||||||||||||
| num_samples=4, | ||||||||||||||||||||||||||||
| image_key="image", | ||||||||||||||||||||||||||||
| image_threshold=0, | ||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||
| transforms.RandFlipd(keys=["image", "label"], prob=args.RandFlipd_prob, spatial_axis=0), | ||||||||||||||||||||||||||||
| transforms.RandFlipd(keys=["image", "label"], prob=args.RandFlipd_prob, spatial_axis=1), | ||||||||||||||||||||||||||||
| transforms.RandFlipd(keys=["image", "label"], prob=args.RandFlipd_prob, spatial_axis=2), | ||||||||||||||||||||||||||||
| transforms.RandRotate90d(keys=["image", "label"], prob=args.RandRotate90d_prob, max_k=3), | ||||||||||||||||||||||||||||
| transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=args.RandScaleIntensityd_prob), | ||||||||||||||||||||||||||||
| transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=args.RandShiftIntensityd_prob), | ||||||||||||||||||||||||||||
| transforms.ToTensord(keys=["image", "label"]), | ||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| val_transform = transforms.Compose( | ||||||||||||||||||||||||||||
| [ | ||||||||||||||||||||||||||||
| transforms.LoadImaged(keys=["image", "label"]), | ||||||||||||||||||||||||||||
| transforms.EnsureChannelFirstd(keys=["image", "label"]), | ||||||||||||||||||||||||||||
| transforms.Orientationd(keys=["image", "label"], axcodes="RAS"), | ||||||||||||||||||||||||||||
| transforms.Spacingd( | ||||||||||||||||||||||||||||
| keys=["image", "label"], pixdim=(args.space_x, args.space_y, args.space_z), mode=("bilinear", "nearest") | ||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||
| transforms.ScaleIntensityRanged( | ||||||||||||||||||||||||||||
| keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True | ||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||
| transforms.CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True), | ||||||||||||||||||||||||||||
| transforms.ToTensord(keys=["image", "label"]), | ||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| trainDataset = Dataset(data=trainDicts, transform=train_transform) | ||||||||||||||||||||||||||||
| valDataset = Dataset(data=valDicts, transform=val_transform) | ||||||||||||||||||||||||||||
| trainLoader = DataLoader(trainDataset,batch_size=args.batch_size,shuffle=True,num_workers=args.workers, collate_fn=_get_collate_fn(isTrain=True)) | ||||||||||||||||||||||||||||
| valLoader = DataLoader(valDataset,batch_size=args.batch_size,shuffle=False,num_workers=args.workers, collate_fn=_get_collate_fn(isTrain=False)) | ||||||||||||||||||||||||||||
| loader = [trainLoader, valLoader] | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| return loader | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def _splitList(l, trainRatio:float = 0.8): | ||||||||||||||||||||||||||||
| totalNum = len(l) | ||||||||||||||||||||||||||||
| splitIdx = int(totalNum * trainRatio) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| return l[:splitIdx], l[splitIdx :] | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def getPredictLoader(args): | ||||||||||||||||||||||||||||
| dataName = [d for d in os.listdir(PREDICT_DATA_ROOT)] | ||||||||||||||||||||||||||||
| dataDicts = [{"image": f"{os.path.join(PREDICT_DATA_ROOT, d)}" } for d in dataName] | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
Comment on lines
+105
to
+107
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Filter predict files to NIfTI and validate presence. Unfiltered os.listdir may include non-NIfTI files (.DS_Store, JSON, etc.) and will break LoadImaged. -def getPredictLoader(args):
- dataName = [d for d in os.listdir(PREDICT_DATA_ROOT)]
- dataDicts = [{"image": f"{os.path.join(PREDICT_DATA_ROOT, d)}" } for d in dataName]
+def getPredictLoader(args):
+ exts = (".nii", ".nii.gz")
+ if not os.path.isdir(PREDICT_DATA_ROOT):
+ raise FileNotFoundError(f"PREDICT_DATA_ROOT does not exist: {PREDICT_DATA_ROOT}")
+ files = sorted(
+ f for f in os.listdir(PREDICT_DATA_ROOT)
+ if f.endswith(exts) and os.path.isfile(os.path.join(PREDICT_DATA_ROOT, f))
+ )
+ if not files:
+ raise FileNotFoundError(f"No NIfTI files (.nii, .nii.gz) found in {PREDICT_DATA_ROOT}")
+ dataDicts = [{"image": os.path.join(PREDICT_DATA_ROOT, f)} for f in files]📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||
| preTransform = transforms.Compose( | ||||||||||||||||||||||||||||
| [ | ||||||||||||||||||||||||||||
| transforms.LoadImaged(keys=["image"]), | ||||||||||||||||||||||||||||
| transforms.EnsureChannelFirstd(keys=["image"]), | ||||||||||||||||||||||||||||
| transforms.Orientationd(keys=["image"], axcodes="RAS"), | ||||||||||||||||||||||||||||
| transforms.Spacingd( | ||||||||||||||||||||||||||||
| keys=["image"], pixdim=(args.space_x, args.space_y, args.space_z), mode=("bilinear") | ||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||
| transforms.ScaleIntensityRanged( | ||||||||||||||||||||||||||||
| keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True | ||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||
| transforms.CropForegroundd(keys=["image"], source_key="image", allow_smaller=True), | ||||||||||||||||||||||||||||
| transforms.EnsureTyped(keys=["image"], track_meta=True), | ||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
tylin7111095022 marked this conversation as resolved.
|
||||||||||||||||||||||||||||
| valDataset = Dataset(data=dataDicts, transform=preTransform) | ||||||||||||||||||||||||||||
| valLoader = DataLoader(valDataset,batch_size=args.batch_size,shuffle=False,num_workers=args.workers) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| return valLoader, preTransform | ||||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.