Skip to content

Commit 048c418

Browse files
author
sa_ka_na
committed
surface_matching: add binary save/load for PPF3DDetector
1 parent 3544a75 commit 048c418

File tree

2 files changed

+205
-13
lines changed

2 files changed

+205
-13
lines changed

modules/surface_matching/include/opencv2/surface_matching/ppf_match_3d.hpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,19 @@ class CV_EXPORTS_W PPF3DDetector
140140
*/
141141
CV_WRAP void match(const Mat& scene, CV_OUT std::vector<Pose3DPtr> &results, const double relativeSceneSampleStep=1.0/5.0, const double relativeSceneDistance=0.03);
142142

143-
void read(const FileNode& fn);
144-
void write(FileStorage& fs) const;
143+
/**
144+
* \brief Save trained model to a binary file.
145+
* \param filename Path to the output file.
146+
* \throws cv::Exception if the model is not trained or file cannot be opened.
147+
*/
148+
CV_WRAP void saveModel(const std::string& filename) const;
149+
150+
/**
151+
* \brief Load a previously saved model from a binary file.
152+
* \param filename Path to the input file.
153+
* \throws cv::Exception if the file cannot be opened, data is invalid, or memory allocation fails.
154+
*/
155+
CV_WRAP void loadModel(const std::string& filename);
145156

146157
protected:
147158

@@ -160,6 +171,7 @@ class CV_EXPORTS_W PPF3DDetector
160171
void clearTrainingModels();
161172

162173
private:
174+
hashnode_i* node_pool_;
163175
void computePPFFeatures(const Vec3d& p1, const Vec3d& n1,
164176
const Vec3d& p2, const Vec3d& n2,
165177
Vec4d& f);

modules/surface_matching/src/ppf_match_3d.cpp

Lines changed: 191 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040

4141
#include "precomp.hpp"
4242
#include "hash_murmur.hpp"
43+
#include <fstream>
4344

4445
namespace cv
4546
{
@@ -125,6 +126,7 @@ PPF3DDetector::PPF3DDetector()
125126
angle_step_radians = (360.0/angle_step_relative)*M_PI/180.0;
126127
angle_step = angle_step_radians;
127128
trained = false;
129+
node_pool_ = nullptr;
128130

129131
hash_table = NULL;
130132
hash_nodes = NULL;
@@ -141,6 +143,7 @@ PPF3DDetector::PPF3DDetector(const double RelativeSamplingStep, const double Rel
141143
//SceneSampleStep = 1.0/RelativeSceneSampleStep;
142144
angle_step = angle_step_radians;
143145
trained = false;
146+
node_pool_ = nullptr;
144147

145148
hash_table = NULL;
146149
hash_nodes = NULL;
@@ -181,17 +184,25 @@ void PPF3DDetector::computePPFFeatures(const Vec3d& p1, const Vec3d& n1,
181184

182185
void PPF3DDetector::clearTrainingModels()
183186
{
184-
if (this->hash_nodes)
185-
{
186-
free(this->hash_nodes);
187-
this->hash_nodes=0;
188-
}
189-
190-
if (this->hash_table)
191-
{
192-
hashtableDestroy(this->hash_table);
193-
this->hash_table=0;
194-
}
187+
if (hash_table) {
188+
hashtable_int* ht = (hashtable_int*)hash_table;
189+
// 如果使用节点池,则清空桶以防 hashtableDestroy 释放池内存
190+
if (node_pool_) {
191+
for (size_t i = 0; i < ht->size; ++i) {
192+
ht->nodes[i] = nullptr;
193+
}
194+
}
195+
hashtableDestroy(hash_table);
196+
hash_table = nullptr;
197+
}
198+
if (hash_nodes) {
199+
free(hash_nodes);
200+
hash_nodes = nullptr;
201+
}
202+
if (node_pool_) {
203+
free(node_pool_);
204+
node_pool_ = nullptr;
205+
}
195206
}
196207

197208
PPF3DDetector::~PPF3DDetector()
@@ -594,6 +605,175 @@ void PPF3DDetector::match(const Mat& pc, std::vector<Pose3DPtr>& results, const
594605
clusterPoses(poseList, numPosesAdded, results);
595606
}
596607

608+
void PPF3DDetector::saveModel(const std::string& filename) const
609+
{
610+
if (!trained) {
611+
CV_Error(cv::Error::StsError, "Model not trained, nothing to save.");
612+
}
613+
614+
std::ofstream ofs(filename, std::ios::binary);
615+
if (!ofs.is_open()) {
616+
CV_Error(cv::Error::StsError, "Cannot open file for writing.");
617+
}
618+
619+
// Save training parameters
620+
ofs.write(reinterpret_cast<const char*>(&sampling_step_relative), sizeof(sampling_step_relative));
621+
ofs.write(reinterpret_cast<const char*>(&distance_step_relative), sizeof(distance_step_relative));
622+
ofs.write(reinterpret_cast<const char*>(&angle_step_relative), sizeof(angle_step_relative));
623+
ofs.write(reinterpret_cast<const char*>(&angle_step_radians), sizeof(angle_step_radians));
624+
ofs.write(reinterpret_cast<const char*>(&angle_step), sizeof(angle_step));
625+
ofs.write(reinterpret_cast<const char*>(&distance_step), sizeof(distance_step));
626+
ofs.write(reinterpret_cast<const char*>(&num_ref_points), sizeof(num_ref_points));
627+
628+
// Save sampled point cloud
629+
int rows = sampled_pc.rows;
630+
int cols = sampled_pc.cols;
631+
ofs.write(reinterpret_cast<const char*>(&rows), sizeof(rows));
632+
ofs.write(reinterpret_cast<const char*>(&cols), sizeof(cols));
633+
ofs.write(reinterpret_cast<const char*>(sampled_pc.data), rows * cols * sizeof(float));
634+
635+
// Save PPF matrix
636+
rows = ppf.rows;
637+
cols = ppf.cols;
638+
ofs.write(reinterpret_cast<const char*>(&rows), sizeof(rows));
639+
ofs.write(reinterpret_cast<const char*>(&cols), sizeof(cols));
640+
ofs.write(reinterpret_cast<const char*>(ppf.data), rows * cols * sizeof(float));
641+
642+
// Save hash_nodes array
643+
size_t numNodes = static_cast<size_t>(num_ref_points) * num_ref_points;
644+
ofs.write(reinterpret_cast<const char*>(&numNodes), sizeof(numNodes));
645+
ofs.write(reinterpret_cast<const char*>(hash_nodes), numNodes * sizeof(THash));
646+
647+
// Save bucket information for fast hash table reconstruction
648+
hashtable_int* ht = (hashtable_int*)hash_table;
649+
size_t tableSize = ht->size;
650+
ofs.write(reinterpret_cast<const char*>(&tableSize), sizeof(tableSize));
651+
652+
for (size_t i = 0; i < tableSize; ++i) {
653+
hashnode_i* node = ht->nodes[i];
654+
std::vector<int> indices;
655+
while (node) {
656+
THash* th = (THash*)node->data;
657+
ptrdiff_t idx = th - hash_nodes; // Index within hash_nodes array
658+
indices.push_back(static_cast<int>(idx));
659+
node = node->next;
660+
}
661+
int count = static_cast<int>(indices.size());
662+
ofs.write(reinterpret_cast<const char*>(&count), sizeof(count));
663+
ofs.write(reinterpret_cast<const char*>(indices.data()), count * sizeof(int));
664+
}
665+
}
666+
667+
void PPF3DDetector::loadModel(const std::string& filename)
668+
{
669+
std::ifstream ifs(filename, std::ios::binary);
670+
if (!ifs.is_open()) {
671+
CV_Error(cv::Error::StsError, "Cannot open file for reading.");
672+
}
673+
674+
// Clear existing model to ensure safe loading
675+
clearTrainingModels();
676+
677+
// Load training parameters
678+
ifs.read(reinterpret_cast<char*>(&sampling_step_relative), sizeof(sampling_step_relative));
679+
ifs.read(reinterpret_cast<char*>(&distance_step_relative), sizeof(distance_step_relative));
680+
ifs.read(reinterpret_cast<char*>(&angle_step_relative), sizeof(angle_step_relative));
681+
ifs.read(reinterpret_cast<char*>(&angle_step_radians), sizeof(angle_step_radians));
682+
ifs.read(reinterpret_cast<char*>(&angle_step), sizeof(angle_step));
683+
ifs.read(reinterpret_cast<char*>(&distance_step), sizeof(distance_step));
684+
ifs.read(reinterpret_cast<char*>(&num_ref_points), sizeof(num_ref_points));
685+
686+
// Load sampled point cloud
687+
int rows, cols;
688+
ifs.read(reinterpret_cast<char*>(&rows), sizeof(rows));
689+
ifs.read(reinterpret_cast<char*>(&cols), sizeof(cols));
690+
sampled_pc.create(rows, cols, CV_32F);
691+
ifs.read(reinterpret_cast<char*>(sampled_pc.data), rows * cols * sizeof(float));
692+
693+
// Load PPF matrix
694+
ifs.read(reinterpret_cast<char*>(&rows), sizeof(rows));
695+
ifs.read(reinterpret_cast<char*>(&cols), sizeof(cols));
696+
ppf.create(rows, cols, CV_32F);
697+
ifs.read(reinterpret_cast<char*>(ppf.data), rows * cols * sizeof(float));
698+
699+
// Load hash_nodes array
700+
size_t numNodes;
701+
ifs.read(reinterpret_cast<char*>(&numNodes), sizeof(numNodes));
702+
if (numNodes != static_cast<size_t>(num_ref_points) * num_ref_points) {
703+
CV_Error(cv::Error::StsError, "Invalid number of hash nodes.");
704+
}
705+
706+
// Free old hash_nodes and node pool
707+
if (hash_nodes) {
708+
free(hash_nodes);
709+
hash_nodes = nullptr;
710+
}
711+
if (node_pool_) {
712+
free(node_pool_);
713+
node_pool_ = nullptr;
714+
}
715+
716+
hash_nodes = static_cast<THash*>(malloc(numNodes * sizeof(THash)));
717+
if (!hash_nodes) {
718+
CV_Error(cv::Error::StsNoMem, "Failed to allocate memory for hash nodes.");
719+
}
720+
ifs.read(reinterpret_cast<char*>(hash_nodes), numNodes * sizeof(THash));
721+
722+
// Reconstruct hash table
723+
if (hash_table) {
724+
hashtableDestroy(hash_table);
725+
hash_table = nullptr;
726+
}
727+
728+
// Read number of buckets
729+
size_t tableSize;
730+
ifs.read(reinterpret_cast<char*>(&tableSize), sizeof(tableSize));
731+
732+
// Create hash table (allocate only the nodes array)
733+
hash_table = hashtableCreate(static_cast<int>(tableSize), nullptr);
734+
hashtable_int* ht = (hashtable_int*)hash_table;
735+
736+
// Pre-allocate node pool for hashnode_i objects
737+
node_pool_ = static_cast<hashnode_i*>(malloc(numNodes * sizeof(hashnode_i)));
738+
if (!node_pool_) {
739+
CV_Error(cv::Error::StsNoMem, "Failed to allocate node pool.");
740+
}
741+
742+
// Rebuild linked lists per bucket
743+
for (size_t i = 0; i < tableSize; ++i) {
744+
int count;
745+
ifs.read(reinterpret_cast<char*>(&count), sizeof(count));
746+
if (count == 0) {
747+
ht->nodes[i] = nullptr;
748+
continue;
749+
}
750+
751+
std::vector<int> indices(count);
752+
ifs.read(reinterpret_cast<char*>(indices.data()), count * sizeof(int));
753+
754+
hashnode_i* prev = nullptr;
755+
hashnode_i* head = nullptr;
756+
for (int j = 0; j < count; ++j) {
757+
int idx = indices[j];
758+
THash* th = &hash_nodes[idx];
759+
hashnode_i* node = &node_pool_[idx];
760+
node->key = th->id;
761+
node->data = th;
762+
node->next = nullptr;
763+
764+
if (prev) {
765+
prev->next = node;
766+
} else {
767+
head = node;
768+
}
769+
prev = node;
770+
}
771+
ht->nodes[i] = head;
772+
}
773+
774+
trained = true;
775+
}
776+
597777
} // namespace ppf_match_3d
598778

599779
} // namespace cv

0 commit comments

Comments
 (0)