1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
|
#include <iostream>
#include <fstream>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/dnn/dnn.hpp>
using namespace cv;
using namespace cv::dnn;
// Command-line options understood by cv::CommandLineParser.
// Each entry has the form "{ name alias | default value | help message }".
String keys =
    "{ help h | | Print help message. }"
    "{ inputImage i | | Path to an input image. Required unless evaluate is true. }"
    "{ modelPath mp | | Path to a binary .onnx file containing the trained CRNN text recognition model. "
    "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}"
    "{ RGBInput rgb |0| 0: imread with flags=IMREAD_GRAYSCALE; 1: imread with flags=IMREAD_COLOR. }"
    "{ evaluate e |false| false: predict with input images; true: evaluate on benchmarks. }"
    "{ evalDataPath edp | | Path to benchmarks for evaluation. "
    "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}"
    "{ vocabularyPath vp | alphabet_36.txt | Path to recognition vocabulary. "
    "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}";
// Normalizes a prediction for comparison with benchmark labels; defined after main().
String convertForEval(String &input);
int main(int argc, char** argv)
{
// Parse arguments
CommandLineParser parser(argc, argv, keys);
parser.about("Use this script to run the PyTorch implementation of "
"An End-to-End Trainable Neural Network for Image-based SequenceRecognition and Its Application to Scene Text Recognition "
"(https://arxiv.org/abs/1507.05717)");
if (argc == 1 || parser.has("help"))
{
parser.printMessage();
return 0;
}
String modelPath = parser.get<String>("modelPath");
String vocPath = parser.get<String>("vocabularyPath");
int imreadRGB = parser.get<int>("RGBInput");
if (!parser.check())
{
parser.printErrors();
return 1;
}
// Load the network
CV_Assert(!modelPath.empty());
TextRecognitionModel recognizer(modelPath);
// Load vocabulary
CV_Assert(!vocPath.empty());
std::ifstream vocFile;
vocFile.open(samples::findFile(vocPath));
CV_Assert(vocFile.is_open());
String vocLine;
std::vector<String> vocabulary;
while (std::getline(vocFile, vocLine)) {
vocabulary.push_back(vocLine);
}
recognizer.setVocabulary(vocabulary);
recognizer.setDecodeType("CTC-greedy");
// Set parameters
double scale = 1.0 / 127.5;
Scalar mean = Scalar(127.5, 127.5, 127.5);
Size inputSize = Size(100, 32);
recognizer.setInputParams(scale, inputSize, mean);
if (parser.get<bool>("evaluate"))
{
// For evaluation
String evalDataPath = parser.get<String>("evalDataPath");
CV_Assert(!evalDataPath.empty());
String gtPath = evalDataPath + "/test_gts.txt";
std::ifstream evalGts;
evalGts.open(gtPath);
CV_Assert(evalGts.is_open());
String gtLine;
int cntRight=0, cntAll=0;
TickMeter timer;
timer.reset();
while (std::getline(evalGts, gtLine)) {
size_t splitLoc = gtLine.find_first_of(' ');
String imgPath = evalDataPath + '/' + gtLine.substr(0, splitLoc);
String gt = gtLine.substr(splitLoc+1);
// Inference
Mat frame = imread(samples::findFile(imgPath), imreadRGB);
CV_Assert(!frame.empty());
timer.start();
std::string recognitionResult = recognizer.recognize(frame);
timer.stop();
if (gt == convertForEval(recognitionResult))
cntRight++;
cntAll++;
}
std::cout << "Accuracy(%): " << (double)(cntRight) / (double)(cntAll) << std::endl;
std::cout << "Average Inference Time(ms): " << timer.getTimeMilli() / (double)(cntAll) << std::endl;
}
else
{
// Create a window
static const std::string winName = "Input Cropped Image";
// Open an image file
CV_Assert(parser.has("inputImage"));
Mat frame = imread(samples::findFile(parser.get<String>("inputImage")), imreadRGB);
CV_Assert(!frame.empty());
// Recognition
std::string recognitionResult = recognizer.recognize(frame);
imshow(winName, frame);
std::cout << "Predition: '" << recognitionResult << "'" << std::endl;
waitKey();
}
return 0;
}
// Normalizes a recognition result for comparison with benchmark labels:
// ASCII letters are kept and folded to lower case; every other character
// (digits, punctuation, spaces, non-ASCII bytes) is dropped.
// Used only by the evaluation mode.
String convertForEval(String & input)
{
    String normalized;
    for (const char c : input)
    {
        if (c >= 'A' && c <= 'Z')
        {
            // Map an upper-case ASCII letter onto its lower-case counterpart.
            normalized.push_back(static_cast<char>(c - 'A' + 'a'));
        }
        else if (c >= 'a' && c <= 'z')
        {
            normalized.push_back(c);
        }
        // Any other character is discarded.
    }
    return normalized;
}
|