## File: probabilistic.cpp

package info (click to toggle)
mldemos 0.5.1-3
 `123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123` ``````// Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #include #include #include #include #include #include #include "../stl_checked.h" #include "../array.h" #include "../rand.h" #include "checkerboard.h" #include #include "tester.h" #include namespace { using namespace test; using namespace dlib; using namespace std; logger dlog("test.probabilistic"); // ---------------------------------------------------------------------------------------- class test_probabilistic : public tester { public: test_probabilistic ( ) : tester ("test_probabilistic", "Runs tests on the probabilistic trainer adapter.") {} void perform_test ( ) { print_spinner(); typedef double scalar_type; typedef matrix sample_type; std::vector x; std::vector > x_linearized; std::vector y; get_checkerboard_problem(x,y, 1000, 2); random_subset_selector rx; random_subset_selector ry; rx.set_max_size(x.size()); ry.set_max_size(x.size()); dlog << LINFO << "pos labels: "<< sum(vector_to_matrix(y) == +1); dlog << LINFO << "neg labels: "<< sum(vector_to_matrix(y) == -1); for (unsigned long i = 0; i < x.size(); ++i) { rx.add(x[i]); ry.add(y[i]); } const scalar_type gamma = 2.0; typedef radial_basis_kernel kernel_type; krr_trainer krr_trainer; krr_trainer.use_classification_loss_for_loo_cv(); krr_trainer.set_kernel(kernel_type(gamma)); krr_trainer.set_basis(randomly_subsample(x, 100)); probabilistic_decision_function df; dlog << LINFO << "cross validation: " << cross_validate_trainer(krr_trainer, rx,ry, 4); print_spinner(); running_stats rs_pos, rs_neg; print_spinner(); df = probabilistic(krr_trainer,3).train(x, y); for (unsigned long i = 0; i < x.size(); ++i) { if (y[i] > 0) rs_pos.add(df(x[i])); else rs_neg.add(df(x[i])); } dlog << LINFO << "rs_pos.mean(): "<< rs_pos.mean(); dlog << LINFO << "rs_neg.mean(): "<< rs_neg.mean(); DLIB_TEST_MSG(rs_pos.mean() > 0.95, rs_pos.mean()); DLIB_TEST_MSG(rs_neg.mean() < 0.05, rs_neg.mean()); rs_pos.clear(); rs_neg.clear(); print_spinner(); df = probabilistic(krr_trainer,3).train(rx, ry); for (unsigned long i = 0; i < x.size(); ++i) { if (y[i] > 0) rs_pos.add(df(x[i])); else rs_neg.add(df(x[i])); } dlog << LINFO << "rs_pos.mean(): "<< rs_pos.mean(); dlog << LINFO << "rs_neg.mean(): "<< rs_neg.mean(); DLIB_TEST_MSG(rs_pos.mean() > 0.95, rs_pos.mean()); DLIB_TEST_MSG(rs_neg.mean() < 0.05, rs_neg.mean()); rs_pos.clear(); rs_neg.clear(); } } a; } ``````