File: test_bayes.cpp

package info (click to toggle)
opencv 4.10.0+dfsg-5
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 282,092 kB
  • sloc: cpp: 1,178,079; xml: 682,621; python: 49,092; lisp: 31,150; java: 25,469; ansic: 11,039; javascript: 6,085; sh: 1,214; cs: 601; perl: 494; objc: 210; makefile: 173
file content (56 lines) | stat: -rw-r--r-- 1,505 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#include "test_precomp.hpp"

namespace opencv_test { namespace {

// Regression test for issue 5911: NormalBayesClassifier::predictProb must give
// identical results whether samples are predicted one at a time or in bulk,
// and must honour the row stride of non-continuous output matrices.
TEST(ML_NBAYES, regression_5911)
{
    const int N = 12;        // number of samples (rows of X_data below)
    const int nclasses = 3;  // number of distinct labels in Y_data below
    Ptr<ml::NormalBayesClassifier> nb = cv::ml::NormalBayesClassifier::create();

    // data: three groups of 4 identical 4-feature rows
    float X_data[] = {
        1,2,3,4,  1,2,3,4,   1,2,3,4,    1,2,3,4,
        5,5,5,5,  5,5,5,5,   5,5,5,5,    5,5,5,5,
        4,3,2,1,  4,3,2,1,   4,3,2,1,    4,3,2,1
    };
    Mat_<float> X(N, 4, X_data);

    // labels: one class per group of 4 rows
    int Y_data[] = { 0,0,0,0, 1,1,1,1, 2,2,2,2 };
    Mat_<int> Y(N, 1, Y_data);

    nb->train(X, ml::ROW_SAMPLE, Y);

    // single prediction: reference results, one row at a time
    Mat R1,P1;
    for (int i=0; i<N; i++)
    {
        Mat r,p;
        nb->predictProb(X.row(i), r, p);
        R1.push_back(r);
        P1.push_back(p);
    }

    // bulk prediction (continuous memory):
    Mat R2,P2;
    nb->predictProb(X, R2, P2);

    // Check sizes first so a mismatch is reported as a failed assertion
    // instead of an exception thrown by the element-wise comparison.
    ASSERT_EQ(R1.size(), R2.size());
    ASSERT_EQ(P1.size(), P2.size());
    // Exact equality is intended: bulk and single predictions should be identical.
    EXPECT_EQ(0, countNonZero(R1 != R2));
    EXPECT_EQ(0, countNonZero(P1 != P2));

    // bulk prediction, with non-continuous memory storage: each output is a
    // column ROI of a wider matrix, so its rows are not adjacent in memory.
    // Pre-fill with a sentinel so unwritten cells are deterministic (no reads
    // of uninitialized memory) and writes outside the ROI can be detected.
    Mat R3_(N, 1+1, CV_32S, Scalar(-1)),
        P3_(N, nclasses+1, CV_32F, Scalar(-1));
    nb->predictProb(X, R3_.col(0), P3_.colRange(0,nclasses));
    Mat R3 = R3_.col(0).clone(),
        P3 = P3_.colRange(0,nclasses).clone();

    ASSERT_EQ(R1.size(), R3.size());
    ASSERT_EQ(P1.size(), P3.size());
    EXPECT_EQ(0, countNonZero(R1 != R3));
    EXPECT_EQ(0, countNonZero(P1 != P3));

    // The padding column just past each ROI must still hold the sentinel.
    EXPECT_EQ(0, countNonZero(R3_.col(1) != -1));
    EXPECT_EQ(0, countNonZero(P3_.col(nclasses) != -1));
}

}} // namespace