-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathkd_forest.cpp
More file actions
125 lines (106 loc) · 3.86 KB
/
kd_forest.cpp
File metadata and controls
125 lines (106 loc) · 3.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#include <algorithm>
#include <cstddef>
#include <filesystem>
#include <functional>
#include <iomanip>
#include <iostream>
#include <string>
#include <vector>

#include <pico_toolshed/format/format_bin.hpp>
#include <pico_toolshed/scoped_timer.hpp>
#include <pico_tree/array_traits.hpp>
#include <pico_tree/kd_tree.hpp>
#include <pico_tree/vector_traits.hpp>
#include <pico_understory/kd_forest.hpp>

#include "mnist.hpp"
#include "sift.hpp"
// Builds a kd_forest over `train`, runs one nearest neighbor query per test
// point and counts how often the forest returns the same index as the
// reference result.
//
// train:  points the forest is built from. The forest refers to this vector
//         through a std::reference_wrapper, so it is not copied here.
// test:   query points; the result for test[i] is compared against nns[i].
// nns:    reference nearest neighbors, expected to hold one entry per test
//         point.
// forest_size, forest_max_leaf_size:
//         forwarded to the kd_forest constructor.
// forest_max_leaves_visited:
//         forwarded to every search_nn call.
//
// Returns the number of queries for which the forest's index equals the
// reference index. If `nns` and `test` differ in size, only the first
// min(nns.size(), test.size()) pairs are compared.
template <typename Vector_, typename Scalar_>
std::size_t run_kd_forest(
    std::vector<Vector_> const& train,
    std::vector<Vector_> const& test,
    std::vector<pico_tree::neighbor<int, Scalar_>> const& nns,
    std::size_t forest_size,
    std::size_t forest_max_leaf_size,
    std::size_t forest_max_leaves_visited) {
  using space = std::reference_wrapper<std::vector<Vector_> const>;

  // Immediately invoked lambda: the build timer t0 only lives for the
  // duration of the forest construction.
  auto rkd_tree = [&train, &forest_max_leaf_size, &forest_size]() {
    pico_tree::scoped_timer t0("kd_forest build");
    return pico_tree::kd_forest<space>(
        train, forest_max_leaf_size, forest_size);
  }();

  // test[i] is unchecked, so never let the loop run past the end of `test`
  // when the reference vector happens to be longer (e.g. a ground truth file
  // that was generated for a different test set).
  std::size_t const query_count = std::min(nns.size(), test.size());

  pico_tree::scoped_timer t1("kd_forest query");
  pico_tree::neighbor<int, Scalar_> nn;
  std::size_t equal = 0;
  for (std::size_t i = 0; i < query_count; ++i) {
    rkd_tree.search_nn(test[i], forest_max_leaves_visited, nn);
    if (nns[i].index == nn.index) {
      ++equal;
    }
  }
  return equal;
}
// Fills `nns` with the reference nearest neighbor of every point in `test`.
//
// The result is cached on disk: when the file `fn_nns_gt` already exists it
// is read back into `nns`; otherwise a kd_tree is built over `train`, every
// test point is queried, and the results are written to `fn_nns_gt`.
//
// train:              points the kd_tree is built from.
// test:               query points.
// fn_nns_gt:          path of the cache file.
// tree_max_leaf_size: forwarded to the kd_tree constructor.
// nns:                output; resized to test.size() before anything else.
template <typename Vector_, typename Scalar_>
void run_kd_tree(
    std::vector<Vector_> const& train,
    std::vector<Vector_> const& test,
    std::string const& fn_nns_gt,
    pico_tree::max_leaf_size_t tree_max_leaf_size,
    std::vector<pico_tree::neighbor<int, Scalar_>>& nns) {
  using space = std::reference_wrapper<std::vector<Vector_> const>;

  nns.resize(test.size());

  // Cheap path first: reuse previously stored results.
  if (std::filesystem::exists(fn_nns_gt)) {
    pico_tree::read_bin(fn_nns_gt, nns);
    std::cout << "kd_tree not created. Read " << fn_nns_gt << " instead."
              << std::endl;
    return;
  }

  std::cout << "Creating " << fn_nns_gt
            << " using the kd_tree. Be *very* patient." << std::endl;

  // The build timer is confined to the lambda so that it covers construction
  // of the tree only.
  auto exact_tree = [&]() {
    pico_tree::scoped_timer t0("kd_tree build");
    return pico_tree::kd_tree<space>(train, tree_max_leaf_size);
  }();

  {
    pico_tree::scoped_timer t1("kd_tree query");
    // nns was resized to test.size() above, so both ranges advance in
    // lockstep.
    auto query = test.begin();
    for (auto& result : nns) {
      exact_tree.search_nn(*query, result);
      ++query;
    }
  }

  pico_tree::write_bin(fn_nns_gt, nns);
}
// A kd_forest takes roughly forest_size times longer to build compared to
// building a kd_tree. However, the kd_forest is usually a lot faster with
// queries in high dimensions with the added trade-off that the exact nearest
// neighbor may not be found.
template <typename Dataset_>
void run_dataset(
std::size_t tree_max_leaf_size,
std::size_t forest_size,
std::size_t forest_max_leaf_size,
std::size_t forest_max_leaves_visited) {
using Point = typename Dataset_::point_type;
using scalar_type = typename Point::value_type;
auto train = Dataset_::read_train();
auto test = Dataset_::read_test();
std::vector<pico_tree::neighbor<int, scalar_type>> nns;
std::string fn_nns_gt = Dataset_::dataset_name + "_nns_gt.bin";
run_kd_tree(train, test, fn_nns_gt, tree_max_leaf_size, nns);
std::size_t equal = run_kd_forest(
train,
test,
nns,
forest_size,
forest_max_leaf_size,
forest_max_leaves_visited);
std::cout << std::setprecision(10);
std::cout << "Precision: "
<< (static_cast<float>(equal) / static_cast<float>(nns.size()))
<< std::endl;
}
// Runs the kd_tree versus kd_forest comparison on the MNIST and SIFT
// datasets. The precision figures quoted below are the ones noted for the
// respective forest sizes.
int main() {
  {
    // MNIST with forest_max_leaf_size = 16 and forest_max_leaves_visited = 16:
    //   forest_size 8:  a precision of around 0.915.
    //   forest_size 16: a precision of around 0.976.
    constexpr std::size_t tree_max_leaf_size = 16;
    constexpr std::size_t forest_size = 8;
    constexpr std::size_t forest_max_leaf_size = 16;
    constexpr std::size_t forest_max_leaves_visited = 16;
    run_dataset<mnist>(
        tree_max_leaf_size,
        forest_size,
        forest_max_leaf_size,
        forest_max_leaves_visited);
  }

  {
    // SIFT with forest_max_leaf_size = 32 and forest_max_leaves_visited = 64:
    //   forest_size 8:   a precision of around 0.884.
    //   forest_size 16:  a precision of around 0.940.
    //   forest_size 128: out of memory :'(
    constexpr std::size_t tree_max_leaf_size = 16;
    constexpr std::size_t forest_size = 8;
    constexpr std::size_t forest_max_leaf_size = 32;
    constexpr std::size_t forest_max_leaves_visited = 64;
    run_dataset<sift>(
        tree_max_leaf_size,
        forest_size,
        forest_max_leaf_size,
        forest_max_leaves_visited);
  }

  return 0;
}