-
Notifications
You must be signed in to change notification settings - Fork 946
Expand file tree
/
Copy pathchanges.py
More file actions
168 lines (161 loc) · 6.98 KB
/
changes.py
File metadata and controls
168 lines (161 loc) · 6.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import json
import os
import torch
from mmf.common.sample import Sample, SampleList
from mmf.datasets.builders.vqa2 import VQA2Dataset
from mmf.datasets.databases.scene_graph_database import SceneGraphDatabase
from mmf.utils.configuration import get_mmf_env
from mmf.utils.general import get_absolute_path
# Field names used when reading entries of the annotation database; the
# "image_id_key" entry names the field that holds an annotation's image id.
_CONSTANTS = {"image_id_key": "image_id"}
class VisualGenomeDataset(VQA2Dataset):
    """VQA2-style dataset for Visual Genome with optional scene graph data.

    On top of what ``VQA2Dataset.load_item`` returns, a sample can carry the
    image's ``objects``, ``relationships`` and ``scene_graph`` (regions),
    controlled by the ``return_objects``, ``return_relationships`` and
    ``return_scene_graph`` config flags. The scene graph database is only
    opened when at least one of those flags is truthy.

    NOTE(review): the processors used below (``vg_answer_preprocessor``,
    ``answer_processor``, ``synset_processor``, ``name_processor``,
    ``attribute_processor``, ``predicate_processor``, ``text_processor``) are
    not assigned in this class; presumably they are attached by the base
    dataset from the config -- confirm.
    """

    def __init__(self, config, dataset_type, imdb_file_index, *args, **kwargs):
        """Set up the base dataset and, if requested, the scene graph database.

        Args:
            config: Dataset config. Must provide ``return_scene_graph``,
                ``return_objects`` and ``return_relationships``; ``no_unk`` is
                optional (defaults to ``False``). ``scene_graph_files`` is
                read only when one of the ``return_*`` flags is truthy.
            dataset_type: Key into ``config.scene_graph_files``; also
                forwarded to the base class.
            imdb_file_index: Index into the per-type scene graph file list;
                also forwarded to the base class.
            *args: Extra positional arguments forwarded to the base class.
            **kwargs: Extra keyword arguments forwarded to the base class.
        """
        super().__init__(
            config,
            dataset_type,
            imdb_file_index,
            *args,
            dataset_name="visual_genome",
            **kwargs
        )

        self._return_scene_graph = config.return_scene_graph
        self._return_objects = config.return_objects
        self._return_relationships = config.return_relationships
        self._no_unk = config.get("no_unk", False)
        self.scene_graph_db = None

        needs_scene_graph_db = (
            self._return_scene_graph
            or self._return_objects
            or self._return_relationships
        )

        if needs_scene_graph_db:
            scene_graph_file = config.scene_graph_files[dataset_type][imdb_file_index]
            scene_graph_file = self._get_absolute_path(scene_graph_file)
            self.scene_graph_db = SceneGraphDatabase(config, scene_graph_file)

    def load_item(self, idx):
        """Load the sample at ``idx``, skipping UNK answers when ``no_unk`` is set.

        The annotation entry is updated in place (preprocessed ``answers`` and
        a ``question_id`` copied from ``id``) before the base class builds the
        sample. When the entry is rejected by ``_check_unk``, the next index
        (wrapping around) is tried instead.

        Args:
            idx: Index into ``self.annotation_db``.

        Returns:
            The sample built by the base class, with scene graph fields added
            by ``_load_scene_graph``.

        Raises:
            RuntimeError: If ``no_unk`` is set and every annotation was
                rejected after one full pass over the database.
        """
        current_idx = idx
        rejected = 0

        # Iterate instead of recursing: a long run of UNK answers must not
        # exhaust the interpreter's recursion limit.
        while True:
            sample_info = self.annotation_db[current_idx]
            sample_info = self._preprocess_answer(sample_info)
            sample_info["question_id"] = sample_info["id"]

            if not self._check_unk(sample_info):
                break

            rejected += 1
            num_samples = len(self.annotation_db)
            if rejected >= num_samples:
                raise RuntimeError(
                    "no_unk is set but every answer in the annotation "
                    "database maps to the UNK index"
                )
            current_idx = (current_idx + 1) % num_samples

        current_sample = super().load_item(current_idx)
        current_sample = self._load_scene_graph(current_idx, current_sample)

        return current_sample

    def _get_absolute_path(self, scene_graph_file):
        """Resolve ``scene_graph_file`` against the configured MMF data dir."""
        data_dir = get_mmf_env(key="data_dir")
        absolute_scene_graph_file = get_absolute_path(
            os.path.join(data_dir, scene_graph_file)
        )
        return absolute_scene_graph_file

    def _get_image_id(self, idx):
        """Return the image id stored on annotation ``idx``."""
        return self.annotation_db[idx][_CONSTANTS["image_id_key"]]

    def _get_image_info(self, idx):
        """Return a private copy of the scene graph entry for annotation ``idx``."""
        # Deep copy so that we can directly update the nested dicts
        return copy.deepcopy(self.scene_graph_db[self._get_image_id(idx)])

    def _preprocess_answer(self, sample_info):
        """Replace ``sample_info["answers"]`` with its preprocessed first answer.

        Only the first answer is kept; it is run through
        ``vg_answer_preprocessor`` with a fixed removal list. The dict is
        modified in place and also returned.
        """
        sample_info["answers"] = [
            self.vg_answer_preprocessor(
                {"text": sample_info["answers"][0]},
                remove=["?", ",", ".", "a", "an", "the"],
            )["text"]
        ]

        return sample_info

    def _check_unk(self, sample_info):
        """Return True when the sample must be skipped for having an UNK answer.

        Always False unless ``no_unk`` is enabled.
        """
        if not self._no_unk:
            return False

        index = self.answer_processor.word2idx(sample_info["answers"][0])
        return index == self.answer_processor.answer_vocab.UNK_INDEX

    def _load_scene_graph(self, idx, sample):
        """Attach the requested scene graph fields to ``sample``.

        Does nothing when no scene graph database was opened. Objects are
        always built (relationships and regions reference them) but each
        field is set on the sample only when its ``return_*`` flag is truthy.

        Args:
            idx: Annotation index whose image is looked up.
            sample: Sample to annotate; modified in place.

        Returns:
            The same ``sample`` object.
        """
        if self.scene_graph_db is None:
            return sample

        objects, object_map = self._load_objects(idx)
        if self._return_objects:
            sample.objects = objects

        relationships, relationship_map = self._load_relationships(idx, object_map)
        if self._return_relationships:
            sample.relationships = relationships

        regions, _ = self._load_regions(idx, object_map, relationship_map)
        if self._return_scene_graph:
            sample.scene_graph = regions

        return sample

    def _load_objects(self, idx):
        """Build processed objects for the image of annotation ``idx``.

        Each object gets processed ``synsets``, ``names`` and ``attributes``,
        its ``h``/``w`` replaced by ``height``/``width`` divided by the image
        size, and its ``x``/``y`` divided by the image size as well.

        Returns:
            Tuple ``(objects, object_map)``: a ``SampleList`` of the objects
            and a dict from ``object_id`` to the same ``Sample`` instances.
        """
        image_info = self._get_image_info(idx)
        image_height = image_info["height"]
        image_width = image_info["width"]
        object_map = {}
        objects = []

        for obj in image_info["objects"]:
            obj["synsets"] = self.synset_processor({"tokens": obj["synsets"]})["text"]
            obj["names"] = self.name_processor({"tokens": obj["names"]})["text"]
            obj["height"] = obj.pop("h") / image_height
            obj["width"] = obj.pop("w") / image_width
            obj["y"] /= image_height
            obj["x"] /= image_width
            obj["attributes"] = self.attribute_processor({"tokens": obj["attributes"]})[
                "text"
            ]
            obj = Sample(obj)
            object_map[obj["object_id"]] = obj
            objects.append(obj)
        objects = SampleList(objects)

        return objects, object_map

    def _load_relationships(self, idx, object_map):
        """Build processed relationships for the image of annotation ``idx``.

        Skipped (returns ``(None, None)``) when neither relationships nor the
        scene graph were requested, since nothing would consume the result.

        Args:
            idx: Annotation index whose image is looked up.
            object_map: Dict from ``object_id`` to object ``Sample``, as
                returned by ``_load_objects``.

        Returns:
            Tuple ``(relationships, relationship_map)``: a ``SampleList`` and
            a dict from ``relationship_id`` to the same ``Sample`` instances,
            or ``(None, None)`` when skipped.
        """
        # The flags are booleans copied from the config, so test truthiness;
        # an ``is None`` check would never skip the work.
        if not self._return_relationships and not self._return_scene_graph:
            return None, None

        image_info = self._get_image_info(idx)
        relationship_map = {}
        relationships = []

        for relationship in image_info["relationships"]:
            relationship["synsets"] = self.synset_processor(
                {"tokens": relationship["synsets"]}
            )["text"]
            relationship["predicate"] = self.predicate_processor(
                {"tokens": relationship["predicate"]}
            )["text"]
            relationship["object"] = object_map[relationship["object_id"]]
            relationship["subject"] = object_map[relationship["subject_id"]]

            relationship = Sample(relationship)
            relationship_map[relationship["relationship_id"]] = relationship
            relationships.append(relationship)

        relationships = SampleList(relationships)
        return relationships, relationship_map

    def _load_regions(self, idx, object_map, relationship_map):
        """Build processed regions for the image of annotation ``idx``.

        Skipped (returns ``(None, None)``) when the scene graph was not
        requested. Region geometry is divided by the image size, and the
        region's relationship and object ids are replaced by the matching
        ``Sample`` instances from the given maps.

        Args:
            idx: Annotation index whose image is looked up.
            object_map: Dict from ``object_id`` to object ``Sample``.
            relationship_map: Dict from ``relationship_id`` to relationship
                ``Sample``.

        Returns:
            Tuple ``(regions, region_map)``: a ``SampleList`` and a dict from
            ``region_id`` to the same ``Sample`` instances, or
            ``(None, None)`` when skipped.
        """
        # Boolean flag from the config: test truthiness, not ``is None``.
        if not self._return_scene_graph:
            return None, None

        image_info = self._get_image_info(idx)
        image_height = image_info["height"]
        image_width = image_info["width"]
        region_map = {}
        regions = []

        for region in image_info["regions"]:
            for synset in region["synsets"]:
                synset["entity_name"] = self.name_processor(
                    {"tokens": [synset["entity_name"]]}
                )["text"]
                synset["synset_name"] = self.synset_processor(
                    {"tokens": [synset["synset_name"]]}
                )["text"]

            region["height"] /= image_height
            region["width"] /= image_width
            region["y"] /= image_height
            region["x"] /= image_width

            region["relationships"] = [
                relationship_map[relationship_idx]
                for relationship_idx in region["relationships"]
            ]
            region["objects"] = [
                object_map[object_idx] for object_idx in region["objects"]
            ]
            region["phrase"] = self.text_processor({"text": region["phrase"]})["text"]

            region = Sample(region)
            region_map[region["region_id"]] = region
            regions.append(region)

        regions = SampleList(regions)
        return regions, region_map