-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathPreprocessing.py
More file actions
148 lines (130 loc) · 4.6 KB
/
Preprocessing.py
File metadata and controls
148 lines (130 loc) · 4.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""Preprocessing script.
This script walks over the directories and dumps the frames into CSV files.
"""
import os
import csv
import sys
import random
import scipy
import numpy as np
import dicom
from skimage import io, transform
def mkdir(fname):
    """Create directory *fname* on a best-effort basis.

    Filesystem failures (most commonly the directory already existing) are
    ignored so callers can invoke this unconditionally before writing output.

    Args:
        fname: Path of the directory to create. Only the last path component
            is created (``os.mkdir``); parent directories must already exist.
    """
    try:
        os.mkdir(fname)
    except OSError:
        # Only filesystem errors are suppressed. A bare ``except`` would also
        # hide KeyboardInterrupt, SystemExit and programming errors such as
        # passing a non-path argument.
        pass
def get_frames(root_path, num_frames=30):
    """Collect the frame paths of every complete SAX slice under *root_path*.

    A directory counts as a slice when its path contains ``"sax"`` and it
    holds a full run of files named ``<prefix>-0001.dcm`` through
    ``<prefix>-<num_frames>.dcm`` (4-digit, zero-padded). ``<prefix>`` is
    everything before the last ``-`` of the alphabetically first ``.dcm``
    file in that directory.

    Args:
        root_path: Directory tree to walk.
        num_frames: Number of consecutive frames a slice must contain.

    Returns:
        A list with one entry per complete slice. Each entry is the list of
        ``num_frames`` frame paths (using ``/`` separators) in frame order.
        Entries are sorted by their first path.
    """
    print('get_frames')
    ret = []
    for root, _, files in os.walk(root_path):
        # Normalise Windows separators so the paths can be split on "/".
        root = root.replace('\\', '/')
        if "sax" not in root:
            continue
        # Keep files that really end in ".dcm" (a substring test would also
        # accept names like "x.dcm.bak"). Sort them because os.walk lists
        # files in arbitrary order and the prefix comes from the first one.
        dcm_files = sorted(s for s in files if s.endswith(".dcm"))
        if not dcm_files:
            continue
        prefix = dcm_files[0].rsplit('-', 1)[0]
        fileset = set(dcm_files)
        expected = ["%s-%04d.dcm" % (prefix, i + 1) for i in range(num_frames)]
        if all(x in fileset for x in expected):
            ret.append([root + "/" + x for x in expected])
    # sort for reproducibility
    return sorted(ret, key=lambda x: x[0])
def get_label_map(fname):
    """Read the label CSV and index its raw lines by id.

    The first line of the file is treated as a header and skipped. Every
    following non-blank line is stored unmodified (line terminator included)
    under the integer found in its first comma-separated column.

    Args:
        fname: Path of the label CSV file.

    Returns:
        dict mapping each integer id to the full text of its line.

    Raises:
        ValueError: If the first column of a data line is not an integer.
    """
    print('get_label_map')
    labelmap = {}
    # newline='' keeps each line's terminator exactly as stored in the file,
    # since the lines are later written out verbatim.
    with open(fname, newline='') as fi:
        fi.readline()  # skip the header row
        for line in fi:
            if not line.strip():
                # A blank (e.g. trailing) line carries no id to parse.
                continue
            labelmap[int(line.split(',', 1)[0])] = line
    return labelmap
def write_label_csv(fname, frames, label_map):
    """Write one label line per slice to *fname*.

    The id of a slice is the integer in the fourth ``/``-separated component
    of its first frame path (index 3, e.g. ``./data/train/<id>/...``).

    Args:
        fname: Output file path; overwritten if it already exists.
        frames: List of slices, each a list of frame paths.
        label_map: dict mapping id to a ready-to-write label line (including
            its line terminator), or ``None`` to write ``"<id>,0,0"``
            placeholder lines instead.

    Raises:
        KeyError: If ``label_map`` is given but lacks the id of a slice.
        ValueError: If the path component at index 3 is not an integer.
    """
    print('write_label_csv')
    # The context manager closes the file even when a lookup fails part-way.
    with open(fname, "w") as fo:
        for lst in frames:
            index = int(lst[0].split("/")[3])
            if label_map is not None:
                fo.write(label_map[index])
            else:
                fo.write("%d,0,0\n" % index)
def write_data_csv(fname, frames, preproc):
    """Write the preprocessed pixels of every slice to a CSV file.

    For each frame the DICOM file is read, its pixel array is divided by its
    maximum value, and the result is passed through ``preproc``. The
    preprocessed image is saved next to the source file with the extension
    replaced by ``.64x64.jpg``. All frames of a slice are then cast to uint8,
    flattened and written as a single CSV row.

    Args:
        fname: Output CSV path; overwritten if it already exists.
        frames: List of slices, each a list of DICOM file paths.
        preproc: Callable that takes the normalised float image and returns
            the image to store.

    Returns:
        List of the paths of all saved ``.64x64.jpg`` images, in processing
        order.
    """
    print('write_data_csv')
    counter = 0
    result = []
    # The csv module expects files opened with newline='' so that it alone
    # controls row terminators (otherwise Windows gets blank lines between
    # rows). The context manager closes the file even if a frame fails.
    with open(fname, "w", newline='') as fdata:
        dwriter = csv.writer(fdata)
        for lst in frames:
            data = []
            for path in lst:
                f = dicom.read_file(path)
                pixels = f.pixel_array
                peak = np.max(pixels)
                # An all-zero frame has peak 0; divide by 1 instead so the
                # image stays zero rather than turning into NaNs.
                img = preproc(pixels.astype(float) / (peak if peak else 1))
                dst_path = path.rsplit(".", 1)[0] + ".64x64.jpg"
                # scipy.misc.imsave was removed in SciPy 1.2; skimage's
                # writer (already imported by this file) is used instead.
                io.imsave(dst_path, img)
                result.append(dst_path)
                data.append(img)
            data = np.array(data, dtype=np.uint8)
            data = data.reshape(data.size)
            dwriter.writerow(data)
            counter += 1
            if counter % 100 == 0:
                print("%d slices processed" % counter)
    print("All finished, %d slices in total" % counter)
    return result
def crop_resize(img, size):
    """Centre-crop *img* to a square and resize it to ``size`` x ``size``.

    When the first axis is shorter than the second, the image is transposed
    before cropping. The resized image is multiplied by 255 and returned as
    ``uint8``.
    NOTE(review): the x255 scaling presumes the input is in [0, 1] - confirm
    against the caller.
    """
    # Put the longer of the first two axes first.
    first_axis_shorter = img.shape[0] < img.shape[1]
    if first_axis_shorter:
        img = img.T
    height, width = img.shape[0], img.shape[1]
    side = min(img.shape[:2])
    # Offsets that centre a side x side window inside the image.
    top = (height - side) // 2
    left = (width - side) // 2
    square = img[top:top + side, left:left + side]
    scaled = transform.resize(square, (size, size))
    scaled *= 255
    return scaled.astype("uint8")
def local_split(train_index):
    """Split the distinct ids in *train_index* into train and test sets.

    The global ``random`` generator is reseeded with 0, the sorted unique ids
    are shuffled, and the first third (rounded down) becomes the test set;
    the rest is the train set.

    Returns:
        Tuple ``(train_set, test_set)`` of two disjoint sets.
    """
    random.seed(0)
    unique_ids = sorted(set(train_index))
    holdout_count = len(unique_ids) // 3
    random.shuffle(unique_ids)
    held_out = unique_ids[:holdout_count]
    kept = unique_ids[holdout_count:]
    return set(kept), set(held_out)
def split_csv(src_csv, split_to_train, train_csv, test_csv):
    """Distribute the rows of *src_csv* between a train and a test file.

    Blank lines in the source are skipped. The i-th non-blank line is copied
    verbatim to ``train_csv`` when ``split_to_train[i]`` is truthy and to
    ``test_csv`` otherwise. The number of rows routed is printed at the end.

    Args:
        src_csv: Path of the file to split.
        split_to_train: Sequence of flags, one per non-blank source row.
        train_csv: Output path for rows flagged as train; overwritten.
        test_csv: Output path for the remaining rows; overwritten.

    Raises:
        IndexError: If the source has more non-blank rows than
            ``split_to_train`` has entries.
    """
    cnt = 0
    # Context managers close all three files, including the source handle,
    # even if routing fails part-way through.
    with open(src_csv) as fsrc, \
            open(train_csv, "w") as ftrain, \
            open(test_csv, "w") as ftest:
        for line in fsrc:
            if not line.strip():
                # Blank lines carry no row and must not consume a flag.
                continue
            if split_to_train[cnt]:
                ftrain.write(line)
            else:
                ftest.write(line)
            cnt = cnt + 1
    print(cnt)
# Collect the training slices and shuffle them. The RNG is seeded first so
# the shuffled order is the same on every run.
random.seed(10)
train_frames = get_frames("./data/train")
random.shuffle(train_frames)
# Validation slices keep the sorted order returned by get_frames.
validate_frames = get_frames("./data/validate")
# Write one label line per slice, in the same order as the frame lists.
# Training labels are looked up in ./data/train.csv; passing None for the
# validation set writes "<id>,0,0" placeholder lines.
write_label_csv("./train-label.csv", train_frames, get_label_map("./data/train.csv"))
write_label_csv("./validate-label.csv", validate_frames, None)
# Dump the pixels of each slice into a CSV file, one row per slice, after
# running every frame through crop_resize with a target size of 64.
train_lst = write_data_csv("./train-64x64-data.csv", train_frames, lambda x: crop_resize(x, 64))
valid_lst = write_data_csv("./validate-64x64-data.csv", validate_frames, lambda x: crop_resize(x, 64))
# Generate a local train/test split, which you can use to tune your model locally.
# The ids come from the first column of the label file just written, so
# split_to_train holds one flag per label row.
train_index = np.loadtxt("./train-label.csv", delimiter=",")[:,0].astype("int")
train_set, test_set = local_split(train_index)
split_to_train = [x in train_set for x in train_index]
# Apply the same flags to the label file and the data file.
split_csv("./train-label.csv", split_to_train, "./local_train-label.csv", "./local_test-label.csv")
split_csv("./train-64x64-data.csv", split_to_train, "./local_train-64x64-data.csv", "./local_test-64x64-data.csv")