-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathDataLoad_Twoinputs.py
More file actions
78 lines (68 loc) · 2.8 KB
/
DataLoad_Twoinputs.py
File metadata and controls
78 lines (68 loc) · 2.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import numpy as np
import torch
# import torch.backends.cudnn as cudnn
# import torch.nn as nn
# import torch.nn.parallel
# import torch.optim as optim
import torch.utils.data as data
import torchvision.datasets as datasets
# import torchvision.models as models
import torchvision.transforms
import torchvision.transforms as transforms
from PIL import Image
# import glob
import os
import matplotlib.pyplot as plt
from torch.autograd import Variable
class DataLoader(object):
def __init__(self, data_dir, image_size, batch_size=50):
self.base_dir = os.path.dirname(os.path.abspath(__file__))
self.data_dir = data_dir
self.batch_size = batch_size
self.image_size = image_size
self.normalize_mean = [0.485, 0.456, 0.406]
self.normalize_std = [0.229, 0.224, 0.225]
self.data_transforms = {
'train': transforms.Compose([
transforms.Resize([224,224]),
transforms.ToTensor(),
transforms.Normalize(self.normalize_mean, self.normalize_std)
]),
'val': transforms.Compose([
transforms.Resize([224,224]),
transforms.CenterCrop(self.image_size),
# transforms.CenterCrop([224,224]),
transforms.ToTensor(),
transforms.Normalize(self.normalize_mean, self.normalize_std)
]),
}
self._init_data_sets()
def _init_data_sets(self):
self.data_sets = {x: datasets.ImageFolder(os.path.join(self.data_dir, x), self.data_transforms[x])
for x in ['train', 'val']}
self.data_loaders = {x: torch.utils.data.DataLoader(self.data_sets[x], batch_size=self.batch_size,
shuffle=False, num_workers=4)
for x in ['train', 'val']}
self.data_sizes = {x: len(self.data_sets[x]) for x in ['train', 'val']}
self.data_classes = self.data_sets['train'].classes
def load_data(self, data_set='train'):
return self.data_loaders[data_set]
def show_image(self, tensor, title=None):
inp = tensor.numpy().transpose((1, 2, 0))
# put it back as it solved before in transforms
inp = self.normalize_std * inp + self.normalize_mean
plt.imshow(inp)
if title is not None:
plt.title(title)
plt.show()
def make_predict_inputs(self, image_file):
"""
this will make a image to PyTorch inputs, as the same with training images.
this will return a tensor, default not using CUDA.
:param image_file:
:return:
"""
image = Image.open(image_file)
image_tensor = self.data_transforms['val'](image).float()
image_tensor.unsqueeze_(0)
return Variable(image_tensor)