// Sample: scene text detection with the DB (Differentiable Binarization) model
// through OpenCV's dnn module (cv::dnn::TextDetectionModel_DB).
#include <iostream>
#include <fstream>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/dnn/dnn.hpp>
using namespace cv;
using namespace cv::dnn;
// Command-line option table handed to cv::CommandLineParser in main().
// Each entry has the form "{ name alias | default value | description }".
// Only the descriptions are free text; names, aliases and defaults are what
// main() looks up, so they must stay in sync with the parser.get<>() calls.
std::string keys =
    "{ help h | | Print help message. }"
    "{ inputImage i | | Path to an input image. Required when evaluate is false. }"
    "{ modelPath mp | | Path to a binary .onnx file containing the trained DB detector model. "
    "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}"
    "{ inputHeight ih |736| Image height of the model input. It should be a multiple of 32.}"
    "{ inputWidth iw |736| Image width of the model input. It should be a multiple of 32.}"
    "{ binaryThreshold bt |0.3| Confidence threshold of the binary map. }"
    "{ polygonThreshold pt |0.5| Confidence threshold of polygons. }"
    "{ maxCandidate max |200| Max candidates of polygons. }"
    "{ unclipRatio ratio |2.0| unclip ratio. }"
    "{ evaluate e |false| false: predict with input images; true: evaluate on benchmarks. }"
    "{ evalDataPath edp | | Path to benchmarks for evaluation. "
    "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}";
// Tokenizes `s` on `delimiter` and stores the pieces in `elems`.
//
// `elems` is emptied first. Empty tokens between two adjacent delimiters (or
// before a leading delimiter) are kept; an empty token after a trailing
// delimiter is not emitted, so "a,b," yields {"a", "b"} and "" yields {}.
static
void split(const std::string& s, char delimiter, std::vector<std::string>& elems)
{
    elems.clear();

    std::string::size_type tokenStart = 0;
    for (;;)
    {
        const std::string::size_type tokenEnd = s.find(delimiter, tokenStart);
        if (tokenEnd == std::string::npos)
            break;

        elems.push_back(s.substr(tokenStart, tokenEnd - tokenStart));
        tokenStart = tokenEnd + 1;
    }

    // Whatever follows the last delimiter is the final token, unless nothing is left.
    if (tokenStart < s.size())
        elems.push_back(s.substr(tokenStart));
}
int main(int argc, char** argv)
{
// Parse arguments
CommandLineParser parser(argc, argv, keys);
parser.about("Use this script to run the official PyTorch implementation (https://github.com/MhLiao/DB) of "
"Real-time Scene Text Detection with Differentiable Binarization (https://arxiv.org/abs/1911.08947)\n"
"The current version of this script is a variant of the original network without deformable convolution");
if (argc == 1 || parser.has("help"))
{
parser.printMessage();
return 0;
}
float binThresh = parser.get<float>("binaryThreshold");
float polyThresh = parser.get<float>("polygonThreshold");
uint maxCandidates = parser.get<uint>("maxCandidate");
String modelPath = parser.get<String>("modelPath");
double unclipRatio = parser.get<double>("unclipRatio");
int height = parser.get<int>("inputHeight");
int width = parser.get<int>("inputWidth");
if (!parser.check())
{
parser.printErrors();
return 1;
}
// Load the network
CV_Assert(!modelPath.empty());
TextDetectionModel_DB detector(modelPath);
detector.setBinaryThreshold(binThresh)
.setPolygonThreshold(polyThresh)
.setUnclipRatio(unclipRatio)
.setMaxCandidates(maxCandidates);
double scale = 1.0 / 255.0;
Size inputSize = Size(width, height);
Scalar mean = Scalar(122.67891434, 116.66876762, 104.00698793);
detector.setInputParams(scale, inputSize, mean);
// Create a window
static const std::string winName = "TextDetectionModel";
if (parser.get<bool>("evaluate")) {
// for evaluation
String evalDataPath = parser.get<String>("evalDataPath");
CV_Assert(!evalDataPath.empty());
String testListPath = evalDataPath + "/test_list.txt";
std::ifstream testList;
testList.open(testListPath);
CV_Assert(testList.is_open());
// Create a window for showing groundtruth
static const std::string winNameGT = "GT";
String testImgPath;
while (std::getline(testList, testImgPath)) {
String imgPath = evalDataPath + "/test_images/" + testImgPath;
std::cout << "Image Path: " << imgPath << std::endl;
Mat frame = imread(samples::findFile(imgPath), IMREAD_COLOR);
CV_Assert(!frame.empty());
Mat src = frame.clone();
// Inference
std::vector<std::vector<Point>> results;
detector.detect(frame, results);
polylines(frame, results, true, Scalar(0, 255, 0), 2);
imshow(winName, frame);
// load groundtruth
String imgName = testImgPath.substr(0, testImgPath.length() - 4);
String gtPath = evalDataPath + "/test_gts/" + imgName + ".txt";
// std::cout << gtPath << std::endl;
std::ifstream gtFile;
gtFile.open(gtPath);
CV_Assert(gtFile.is_open());
std::vector<std::vector<Point>> gts;
String gtLine;
while (std::getline(gtFile, gtLine)) {
size_t splitLoc = gtLine.find_last_of(',');
String text = gtLine.substr(splitLoc+1);
if ( text == "###\r" || text == "1") {
// ignore difficult instances
continue;
}
gtLine = gtLine.substr(0, splitLoc);
std::vector<std::string> v;
split(gtLine, ',', v);
std::vector<int> loc;
std::vector<Point> pts;
for (auto && s : v) {
loc.push_back(atoi(s.c_str()));
}
for (size_t i = 0; i < loc.size() / 2; i++) {
pts.push_back(Point(loc[2 * i], loc[2 * i + 1]));
}
gts.push_back(pts);
}
polylines(src, gts, true, Scalar(0, 255, 0), 2);
imshow(winNameGT, src);
waitKey();
}
} else {
// Open an image file
CV_Assert(parser.has("inputImage"));
Mat frame = imread(samples::findFile(parser.get<String>("inputImage")));
CV_Assert(!frame.empty());
// Detect
std::vector<std::vector<Point>> results;
detector.detect(frame, results);
polylines(frame, results, true, Scalar(0, 255, 0), 2);
imshow(winName, frame);
waitKey();
}
return 0;
}
|