-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Expand file tree
/
Copy pathsuper_scale.cpp
More file actions
63 lines (53 loc) · 1.99 KB
/
super_scale.cpp
File metadata and controls
63 lines (53 loc) · 1.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Tencent is pleased to support the open source community by making WeChat QRCode available.
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
#include "../precomp.hpp"
#include "super_scale.hpp"
// Clamps x into [x1, x2] (expects x1 <= x2). This is a textual macro: the
// unqualified min/max are resolved at the point of use, and an argument may
// be evaluated more than once depending on how those functions are defined.
#define CLIP(x, x1, x2) max((x1), min((x), (x2)))
namespace cv {
namespace wechat_qrcode {
// Loads the super-resolution network from an ONNX model file.
//
// onnx_path: path handed to dnn::readNetFromONNX.
// Returns 0 when the loaded network contains layers, -1 when it is empty.
int SuperScale::init(const std::string &onnx_path) {
    srnet_ = dnn::readNetFromONNX(onnx_path);
    // Only flag the network as usable when it actually has layers;
    // processImageScale() checks this flag before attempting inference.
    net_loaded_ = !srnet_.empty();
    return net_loaded_ ? 0 : -1;
}
// Returns src rescaled by `scale`.
//
//   scale == 1.0 : src is returned unchanged.
//   scale == 2.0 : upsampled. The super-resolution network is used when use_sr
//                  is set, the network has been loaded and
//                  sqrt(src.cols * src.rows) is below sr_max_size; if that
//                  path is not taken or reports a non-zero status, bicubic
//                  interpolation is used instead.
//   scale <  1.0 : downsampled with area interpolation.
//   other values : src is returned unchanged.
Mat SuperScale::processImageScale(const Mat &src, float scale, const bool &use_sr,
                                  int sr_max_size) {
    Mat dst = src;
    if (scale == 1.0f) {  // nothing to do
        return dst;
    }

    if (scale == 2.0f) {  // upsample
        if (use_sr && net_loaded_) {
            // Promote to double before multiplying so that very large images
            // cannot overflow the int product cols * rows.
            const double area = static_cast<double>(src.cols) * src.rows;
            if (static_cast<int>(sqrt(area)) < sr_max_size &&
                superResoutionScale(src, dst) == 0) {
                return dst;
            }
        }
        // Super resolution disabled, not applicable, or failed: plain resize.
        resize(src, dst, Size(), scale, scale, INTER_CUBIC);
    } else if (scale < 1.0f) {  // downsample
        resize(src, dst, Size(), scale, scale, INTER_AREA);
    }
    return dst;
}
int SuperScale::superResoutionScale(const Mat &src, Mat &dst) {
Mat blob;
dnn::blobFromImage(src, blob, 1.0 / 255, Size(src.cols, src.rows), {0.0f}, false, false);
srnet_.setInput(blob);
auto prob = srnet_.forward();
dst = Mat(prob.size[2], prob.size[3], CV_8UC1);
for (int row = 0; row < prob.size[2]; row++) {
const float *prob_score = prob.ptr<float>(0, 0, row);
for (int col = 0; col < prob.size[3]; col++) {
float pixel = prob_score[col] * 255.0;
dst.at<uint8_t>(row, col) = static_cast<uint8_t>(CLIP(pixel, 0.0f, 255.0f));
}
}
return 0;
}
} // namespace wechat_qrcode
} // namespace cv