diff --git a/modules/wechat_qrcode/include/opencv2/wechat_qrcode.hpp b/modules/wechat_qrcode/include/opencv2/wechat_qrcode.hpp index 676104cd022..2bd905e5f7f 100644 --- a/modules/wechat_qrcode/include/opencv2/wechat_qrcode.hpp +++ b/modules/wechat_qrcode/include/opencv2/wechat_qrcode.hpp @@ -25,18 +25,14 @@ class CV_EXPORTS_W WeChatQRCode { public: /** * @brief Initialize the WeChatQRCode. - * It includes two models, which are packaged with caffe format. - * Therefore, there are prototxt and caffe models (In total, four paramenters). + * It includes two CNN-based models in ONNX format: + * a detector model and a super resolution model. * - * @param detector_prototxt_path prototxt file path for the detector - * @param detector_caffe_model_path caffe model file path for the detector - * @param super_resolution_prototxt_path prototxt file path for the super resolution model - * @param super_resolution_caffe_model_path caffe file path for the super resolution model + * @param detector_model_path onnx model file path for the detector + * @param super_resolution_model_path onnx model file path for the super resolution model */ - CV_WRAP WeChatQRCode(const std::string& detector_prototxt_path = "", - const std::string& detector_caffe_model_path = "", - const std::string& super_resolution_prototxt_path = "", - const std::string& super_resolution_caffe_model_path = ""); + CV_WRAP WeChatQRCode(const std::string& detector_model_path = "", + const std::string& super_resolution_model_path = ""); ~WeChatQRCode(){}; /** diff --git a/modules/wechat_qrcode/perf/perf_wechat_qrcode_pipeline.cpp b/modules/wechat_qrcode/perf/perf_wechat_qrcode_pipeline.cpp index e074a9cb8f8..865db5f037a 100644 --- a/modules/wechat_qrcode/perf/perf_wechat_qrcode_pipeline.cpp +++ b/modules/wechat_qrcode/perf/perf_wechat_qrcode_pipeline.cpp @@ -23,15 +23,13 @@ std::string qrcode_images_multiple[] = {/*"2_qrcodes.png",*/ "3_qrcodes.png", "3 WeChatQRCode createQRDetectorWithDNN(std::string& model_path) { - string path_detect_prototxt, path_detect_caffemodel, path_sr_prototxt, path_sr_caffemodel; + string path_detect, path_sr; if (!model_path.empty()) { - path_detect_prototxt = findDataFile(model_path + "/detect.prototxt", false); - path_detect_caffemodel = findDataFile(model_path + "/detect.caffemodel", false); - path_sr_prototxt = findDataFile(model_path + "/sr.prototxt", false); - path_sr_caffemodel = findDataFile(model_path + "/sr.caffemodel", false); + path_detect = findDataFile(model_path + "/detect.onnx", false); + path_sr = findDataFile(model_path + "/sr.onnx", false); } - return WeChatQRCode(path_detect_prototxt, path_detect_caffemodel, path_sr_prototxt, path_sr_caffemodel); + return WeChatQRCode(path_detect, path_sr); } typedef ::perf::TestBaseWithParam< tuple< std::string,std::string > > Perf_Objdetect_QRCode; diff --git a/modules/wechat_qrcode/samples/qrcode_example.cpp b/modules/wechat_qrcode/samples/qrcode_example.cpp index 046525b90f2..0aa428d921a 100644 --- a/modules/wechat_qrcode/samples/qrcode_example.cpp +++ b/modules/wechat_qrcode/samples/qrcode_example.cpp @@ -24,19 +24,18 @@ int main(int argc, char* argv[]) { cout << " Usage: " << argv[0] << " " << endl; return 0; } - // The model is downloaded to ${CMAKE_BINARY_DIR}/downloads/wechat_qrcode if cmake runs without warnings, - // otherwise you can download them from https://github.com/WeChatCV/opencv_3rdparty/tree/wechat_qrcode. + // ONNX models: detect.onnx and sr.onnx + // available in opencv_extra/testdata/dnn/wechat_2021-01/ Ptr detector; try { - detector = makePtr("detect.prototxt", "detect.caffemodel", - "sr.prototxt", "sr.caffemodel"); + detector = makePtr("detect.onnx", "sr.onnx"); } catch (const std::exception& e) { cout << "\n---------------------------------------------------------------\n" "Failed to initialize WeChatQRCode.\n" - "Please, download 'detector.*' and 'sr.*' from\n" - "https://github.com/WeChatCV/opencv_3rdparty/tree/wechat_qrcode\n" + "Please, provide 'detect.onnx' and 'sr.onnx' from\n" + "opencv_extra/testdata/dnn/wechat_2021-01/\n" "and put them into the current directory.\n" "---------------------------------------------------------------\n"; cout << e.what() << endl; diff --git a/modules/wechat_qrcode/samples/qrcode_example_without_nn.cpp b/modules/wechat_qrcode/samples/qrcode_example_without_nn.cpp index 7428d0dd6be..ac6ffa03f46 100644 --- a/modules/wechat_qrcode/samples/qrcode_example_without_nn.cpp +++ b/modules/wechat_qrcode/samples/qrcode_example_without_nn.cpp @@ -30,7 +30,7 @@ int main(int argc, char* argv[]) { Ptr detector; try { - detector = makePtr("", "", "", ""); + detector = makePtr(); } catch (const std::exception& e) { cout << "\n---------------------------------------------------------------\n" diff --git a/modules/wechat_qrcode/src/detector/ssd_detector.cpp b/modules/wechat_qrcode/src/detector/ssd_detector.cpp index 52e8a635d56..2d28d65e5c6 100644 --- a/modules/wechat_qrcode/src/detector/ssd_detector.cpp +++ b/modules/wechat_qrcode/src/detector/ssd_detector.cpp @@ -9,8 +9,8 @@ #define CLIP(x, x1, x2) max(x1, min(x, x2)) namespace cv { namespace wechat_qrcode { -int SSDDetector::init(const string& proto_path, const string& model_path) { - net_ = dnn::readNetFromCaffe(proto_path, model_path); +int SSDDetector::init(const string& onnx_path) { + net_ = dnn::readNetFromONNX(onnx_path); return 0; } @@ -22,7 +22,7 @@ vector SSDDetector::forward(Mat img, const int target_width, const int targ dnn::blobFromImage(input, input, 1.0 / 255, Size(input.cols, input.rows), {0.0f, 0.0f, 0.0f}, false, false); - net_.setInput(input, "data"); + net_.setInput(input); auto prob = net_.forward(); vector point_list; diff --git a/modules/wechat_qrcode/src/detector/ssd_detector.hpp b/modules/wechat_qrcode/src/detector/ssd_detector.hpp index e510cb32477..b4a928a76e1 100644 --- a/modules/wechat_qrcode/src/detector/ssd_detector.hpp +++ b/modules/wechat_qrcode/src/detector/ssd_detector.hpp @@ -19,7 +19,7 @@ class SSDDetector { public: SSDDetector(){}; ~SSDDetector(){}; - int init(const std::string& proto_path, const std::string& model_path); + int init(const std::string& onnx_path); std::vector forward(Mat img, const int target_width, const int target_height); private: diff --git a/modules/wechat_qrcode/src/scale/super_scale.cpp b/modules/wechat_qrcode/src/scale/super_scale.cpp index 8b3b11383a1..fe9b98337c1 100644 --- a/modules/wechat_qrcode/src/scale/super_scale.cpp +++ b/modules/wechat_qrcode/src/scale/super_scale.cpp @@ -11,8 +11,8 @@ #define CLIP(x, x1, x2) max(x1, min(x, x2)) namespace cv { namespace wechat_qrcode { -int SuperScale::init(const std::string &proto_path, const std::string &model_path) { - srnet_ = dnn::readNetFromCaffe(proto_path, model_path); +int SuperScale::init(const std::string &onnx_path) { + srnet_ = dnn::readNetFromONNX(onnx_path); net_loaded_ = true; return 0; } diff --git a/modules/wechat_qrcode/src/scale/super_scale.hpp b/modules/wechat_qrcode/src/scale/super_scale.hpp index 2717932c555..d6c8404c4f6 100644 --- a/modules/wechat_qrcode/src/scale/super_scale.hpp +++ b/modules/wechat_qrcode/src/scale/super_scale.hpp @@ -18,7 +18,7 @@ class SuperScale { public: SuperScale(){}; ~SuperScale(){}; - int init(const std::string &proto_path, const std::string &model_path); + int init(const std::string &onnx_path); Mat processImageScale(const Mat &src, float scale, const bool &use_sr, int sr_max_size = 160); private: diff --git a/modules/wechat_qrcode/src/wechat_qrcode.cpp b/modules/wechat_qrcode/src/wechat_qrcode.cpp index 64aad73610b..5726f67feeb 100644 --- a/modules/wechat_qrcode/src/wechat_qrcode.cpp +++ b/modules/wechat_qrcode/src/wechat_qrcode.cpp @@ -47,34 +47,25 @@ class WeChatQRCode::Impl { float scaleFactor = -1.f; }; -WeChatQRCode::WeChatQRCode(const String& detector_prototxt_path, - const String& detector_caffe_model_path, - const String& super_resolution_prototxt_path, - const String& super_resolution_caffe_model_path) { +WeChatQRCode::WeChatQRCode(const String& detector_model_path, + const String& super_resolution_model_path) { p = makePtr(); - if (!detector_caffe_model_path.empty() && !detector_prototxt_path.empty()) { - // initialize detector model (caffe) + if (!detector_model_path.empty()) { p->use_nn_detector_ = true; - CV_Assert(utils::fs::exists(detector_prototxt_path)); - CV_Assert(utils::fs::exists(detector_caffe_model_path)); + CV_Assert(utils::fs::exists(detector_model_path)); p->detector_ = make_shared(); - auto ret = p->detector_->init(detector_prototxt_path, detector_caffe_model_path); + auto ret = p->detector_->init(detector_model_path); CV_Assert(ret == 0); } else { p->use_nn_detector_ = false; p->detector_ = NULL; } - // initialize super_resolution_model - // it could also support non model weights by cubic resizing - // so, we initialize it first. + // super resolution supports fallback to cubic resize when no model is given p->super_resolution_model_ = make_shared(); - if (!super_resolution_prototxt_path.empty() && !super_resolution_caffe_model_path.empty()) { + if (!super_resolution_model_path.empty()) { p->use_nn_sr_ = true; - // initialize dnn model (caffe format) - CV_Assert(utils::fs::exists(super_resolution_prototxt_path)); - CV_Assert(utils::fs::exists(super_resolution_caffe_model_path)); - auto ret = p->super_resolution_model_->init(super_resolution_prototxt_path, - super_resolution_caffe_model_path); + CV_Assert(utils::fs::exists(super_resolution_model_path)); + auto ret = p->super_resolution_model_->init(super_resolution_model_path); CV_Assert(ret == 0); } else { p->use_nn_sr_ = false; diff --git a/modules/wechat_qrcode/test/test_qrcode.cpp b/modules/wechat_qrcode/test/test_qrcode.cpp index ddc2828e716..c93be473823 100644 --- a/modules/wechat_qrcode/test/test_qrcode.cpp +++ b/modules/wechat_qrcode/test/test_qrcode.cpp @@ -237,20 +237,16 @@ typedef testing::TestWithParam Objdetect_QRCode_Multi; TEST_P(Objdetect_QRCode_Multi, regression) { const std::string name_current_image = GetParam(); const std::string root = "qrcode/multiple/"; - string path_detect_prototxt, path_detect_caffemodel, path_sr_prototxt, path_sr_caffemodel; - string model_version = "_2021-01"; - path_detect_prototxt = findDataFile("dnn/wechat"+model_version+"/detect.prototxt", false); - path_detect_caffemodel = findDataFile("dnn/wechat"+model_version+"/detect.caffemodel", false); - path_sr_prototxt = findDataFile("dnn/wechat"+model_version+"/sr.prototxt", false); - path_sr_caffemodel = findDataFile("dnn/wechat"+model_version+"/sr.caffemodel", false); + const string model_dir = "dnn/wechat_2021-01"; + string path_detect = findDataFile(model_dir + "/detect.onnx", false); + string path_sr = findDataFile(model_dir + "/sr.onnx", false); std::string image_path = findDataFile(root + name_current_image); Mat src = imread(image_path); ASSERT_FALSE(src.empty()) << "Can't read image: " << image_path; vector points; - auto detector = wechat_qrcode::WeChatQRCode(path_detect_prototxt, path_detect_caffemodel, path_sr_prototxt, - path_sr_caffemodel); + auto detector = wechat_qrcode::WeChatQRCode(path_detect, path_sr); vector decoded_info = detector.detectAndDecode(src, points); const std::string dataset_config = findDataFile(root + "dataset_config.json"); @@ -287,15 +283,11 @@ TEST_P(Objdetect_QRCode_Multi, regression) { } TEST(Objdetect_QRCode_points_position, rotate45) { - string path_detect_prototxt, path_detect_caffemodel, path_sr_prototxt, path_sr_caffemodel; - string model_version = "_2021-01"; - path_detect_prototxt = findDataFile("dnn/wechat"+model_version+"/detect.prototxt", false); - path_detect_caffemodel = findDataFile("dnn/wechat"+model_version+"/detect.caffemodel", false); - path_sr_prototxt = findDataFile("dnn/wechat"+model_version+"/sr.prototxt", false); - path_sr_caffemodel = findDataFile("dnn/wechat"+model_version+"/sr.caffemodel", false); + const string model_dir = "dnn/wechat_2021-01"; + string path_detect = findDataFile(model_dir + "/detect.onnx", false); + string path_sr = findDataFile(model_dir + "/sr.onnx", false); - auto detector = wechat_qrcode::WeChatQRCode(path_detect_prototxt, path_detect_caffemodel, path_sr_prototxt, - path_sr_caffemodel); + auto detector = wechat_qrcode::WeChatQRCode(path_detect, path_sr); const cv::String expect_msg = "OpenCV"; QRCodeEncoder::Params params; @@ -348,15 +340,11 @@ INSTANTIATE_TEST_CASE_P(/**/, Objdetect_QRCode_Curved, testing::ValuesIn(qrcode_ INSTANTIATE_TEST_CASE_P(/**/, Objdetect_QRCode_Multi, testing::ValuesIn(qrcode_images_multiple)); TEST(Objdetect_QRCode_Big, regression) { - string path_detect_prototxt, path_detect_caffemodel, path_sr_prototxt, path_sr_caffemodel; - string model_version = "_2021-01"; - path_detect_prototxt = findDataFile("dnn/wechat"+model_version+"/detect.prototxt", false); - path_detect_caffemodel = findDataFile("dnn/wechat"+model_version+"/detect.caffemodel", false); - path_sr_prototxt = findDataFile("dnn/wechat"+model_version+"/sr.prototxt", false); - path_sr_caffemodel = findDataFile("dnn/wechat"+model_version+"/sr.caffemodel", false); + const string model_dir = "dnn/wechat_2021-01"; + string path_detect = findDataFile(model_dir + "/detect.onnx", false); + string path_sr = findDataFile(model_dir + "/sr.onnx", false); - auto detector = wechat_qrcode::WeChatQRCode(path_detect_prototxt, path_detect_caffemodel, path_sr_prototxt, - path_sr_caffemodel); + auto detector = wechat_qrcode::WeChatQRCode(path_detect, path_sr); const cv::String expect_msg = "OpenCV"; QRCodeEncoder::Params params; @@ -379,15 +367,11 @@ TEST(Objdetect_QRCode_Big, regression) { } TEST(Objdetect_QRCode_Tiny, regression) { - string path_detect_prototxt, path_detect_caffemodel, path_sr_prototxt, path_sr_caffemodel; - string model_version = "_2021-01"; - path_detect_prototxt = findDataFile("dnn/wechat"+model_version+"/detect.prototxt", false); - path_detect_caffemodel = findDataFile("dnn/wechat"+model_version+"/detect.caffemodel", false); - path_sr_prototxt = findDataFile("dnn/wechat"+model_version+"/sr.prototxt", false); - path_sr_caffemodel = findDataFile("dnn/wechat"+model_version+"/sr.caffemodel", false); + const string model_dir = "dnn/wechat_2021-01"; + string path_detect = findDataFile(model_dir + "/detect.onnx", false); + string path_sr = findDataFile(model_dir + "/sr.onnx", false); - auto detector = wechat_qrcode::WeChatQRCode(path_detect_prototxt, path_detect_caffemodel, path_sr_prototxt, - path_sr_caffemodel); + auto detector = wechat_qrcode::WeChatQRCode(path_detect, path_sr); const cv::String expect_msg = "OpenCV"; QRCodeEncoder::Params params; @@ -411,18 +395,15 @@ TEST(Objdetect_QRCode_Tiny, regression) { typedef testing::TestWithParam Objdetect_QRCode_Easy_Multi; TEST_P(Objdetect_QRCode_Easy_Multi, regression) { - string path_detect_prototxt, path_detect_caffemodel, path_sr_prototxt, path_sr_caffemodel; + string path_detect, path_sr; string model_path = GetParam(); if (!model_path.empty()) { - path_detect_prototxt = findDataFile(model_path + "/detect.prototxt", false); - path_detect_caffemodel = findDataFile(model_path + "/detect.caffemodel", false); - path_sr_prototxt = findDataFile(model_path + "/sr.prototxt", false); - path_sr_caffemodel = findDataFile(model_path + "/sr.caffemodel", false); + path_detect = findDataFile(model_path + "/detect.onnx", false); + path_sr = findDataFile(model_path + "/sr.onnx", false); } - auto detector = wechat_qrcode::WeChatQRCode(path_detect_prototxt, path_detect_caffemodel, path_sr_prototxt, - path_sr_caffemodel); + auto detector = wechat_qrcode::WeChatQRCode(path_detect, path_sr); const cv::String expect_msg1 = "OpenCV1", expect_msg2 = "OpenCV2"; QRCodeEncoder::Params params;