1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
|
/*******************************************************
* Copyright (c) 2014, ArrayFire
* All rights reserved.
*
* This file is distributed under 3-clause BSD license.
* The complete license agreement can be obtained at:
* http://arrayfire.com/licenses/BSD-3-Clause
********************************************************/
#include <sstream>
#include <algorithm>
#include <utility>
#include "../common/idxio.h"
/**
 * Strict "greater-than" ordering on the float member of a (float, int) pair,
 * suitable for sorting such pairs in descending order of the float.
 *
 * The comparison is deliberately strict: standard sorting algorithms require
 * a strict weak ordering, so compare(x, x) must be false. Pairs whose floats
 * are equal compare as equivalent, which lets std::stable_sort keep them in
 * their original relative order.
 *
 * @param l left-hand pair
 * @param r right-hand pair
 * @return true when l.first is strictly greater than r.first
 */
bool compare(const std::pair<float, int> l,
             const std::pair<float, int> r)
{
    return l.first > r.first;
}
// A float value paired with an int; this is the pair type that compare()
// orders by its float member.
typedef std::pair<float, int> sort_type;
/**
 * Render the label stored for sample k as a decimal string.
 *
 * expand_labels == true : column k of arr is converted to f32 and the row
 *                         index of its largest value is returned (the
 *                         earliest row wins a tie). If the column has no
 *                         elements, an empty string is returned.
 * expand_labels == false: element k of arr is converted to f32, truncated to
 *                         int and returned.
 *
 * @param arr array holding the labels / per-class values
 * @param k   sample index (a column when expand_labels, an element otherwise)
 * @return the label as text
 */
template<bool expand_labels>
std::string classify(af::array arr, int k)
{
    std::stringstream ss;
    if (expand_labels) {
        af::array vec = arr(af::span, k).as(f32);
        const int count = (int)vec.elements();
        if (count > 0) {
            float *h_vec = vec.host<float>();

            // Linear arg-max; the strict '>' keeps the first maximum on ties.
            int best = 0;
            for (int i = 1; i < count; i++) {
                if (h_vec[i] > h_vec[best]) best = i;
            }

            // Release the host copy the same way the other host<T>() buffers
            // in this file are released; previously it was never freed.
            delete[] h_vec;

            ss << best;
        }
    } else {
        ss << (int)(arr(k).as(f32).scalar<float>());
    }
    return ss.str();
}
/**
 * Read the MNIST subset under ASSETS_DIR and split it at random into a
 * training set and a test set.
 *
 * One uniform random number is drawn per image; an image goes to the training
 * set when its number is below min(frac, 0.8), otherwise to the test set.
 * Image values are divided by 255.
 *
 * expand_labels == true : each label array is (num_classes x sample count),
 *                         zero everywhere except a 1 in the row given by the
 *                         sample's label.
 * expand_labels == false: each label array holds the label values picked out
 *                         for the corresponding samples.
 *
 * @param[out] num_classes  set to 10
 * @param[out] num_train    number of training images (train_images.dims(2))
 * @param[out] num_test     number of test images (test_images.dims(2))
 * @param[out] train_images images selected for training
 * @param[out] test_images  remaining images
 * @param[out] train_labels labels matching train_images
 * @param[out] test_labels  labels matching test_images
 * @param      frac         requested training fraction; capped at 0.8
 */
template<bool expand_labels>
static void setup_mnist(int *num_classes, int *num_train, int *num_test,
                        af::array &train_images, af::array &test_images,
                        af::array &train_labels, af::array &test_labels, float frac)
{
    std::vector<dim_t> idims;
    std::vector<float > idata;
    read_idx(idims, idata, ASSETS_DIR"/examples/data/mnist/images-subset");

    std::vector<dim_t> ldims;
    std::vector<unsigned> ldata;
    read_idx(ldims, ldata, ASSETS_DIR"/examples/data/mnist/labels-subset");

    // The dimension list is reversed before the array is built.
    // NOTE(review): presumably read_idx reports the slowest-varying dimension
    // first -- confirm against idxio.h.
    std::reverse(idims.begin(), idims.end());
    unsigned numdims = idims.size();
    af::array images = af::array(af::dim4(numdims, &idims[0]), &idata[0]);

    // Draw one random number per loaded image (images are stacked along
    // dimension 2) instead of assuming the data set holds exactly 10000.
    af::array R = af::randu(images.dims(2), 1);
    af::array cond = R < std::min(frac, 0.8f);
    af::array train_indices = where( cond);
    af::array test_indices  = where(!cond);

    train_images = lookup(images, train_indices, 2) / 255;
    test_images  = lookup(images, test_indices , 2) / 255;

    *num_classes = 10;
    *num_train = train_images.dims(2);
    *num_test  = test_images.dims(2);

    if (expand_labels) {
        train_labels = af::constant(0, *num_classes, *num_train);
        test_labels  = af::constant(0, *num_classes, *num_test );

        // Host copies of the selected indices, used to look up each sample's
        // label in ldata; released with delete[] below.
        unsigned *h_train_idx = train_indices.host<unsigned>();
        unsigned *h_test_idx  =  test_indices.host<unsigned>();

        for (int ii = 0; ii < *num_train; ii++) {
            train_labels(ldata[h_train_idx[ii]], ii) = 1;
        }

        for (int ii = 0; ii < *num_test; ii++) {
            test_labels(ldata[ h_test_idx[ii]], ii) = 1;
        }

        delete[] h_train_idx;
        delete[] h_test_idx;
    } else {
        af::array labels = af::array(ldims[0], &ldata[0]);
        train_labels = labels(train_indices);
        test_labels  = labels( test_indices);
    }

    return;
}
#if 0
// Disabled helper: returns the first `num` entries of a randomly selected set
// of positions in [0, total). Each position is kept when its uniform random
// draw is below 2*num/total; the draw is repeated until at least `num`
// positions were kept.
static af::array randidx(int num, int total)
{
    const float threshold = float(num * 2) / total;

    af::array picked = af::where(af::randu(total, 1) < threshold);
    while (picked.elements() < num) {
        picked = af::where(af::randu(total, 1) < threshold);
    }

    return picked(af::seq(num));
}
#endif
/**
 * Print the predicted and actual label of the first num_display test samples
 * together with a text rendering of each image, then wait for a character on
 * stdin (getchar) before moving on to the next sample.
 *
 * A pixel is drawn as a filled block when its value exceeds 0.1 and as a
 * space otherwise.
 *
 * @param test_images images, stacked along dimension 2
 * @param test_output predicted labels, in the form classify<expand_labels> reads
 * @param test_actual true labels, in the same form
 * @param num_display number of samples to show
 */
template<bool expand_labels>
static void display_results(const af::array &test_images,
                            const af::array &test_output,
                            const af::array &test_actual,
                            int num_display)
{
#if 0
    // Disabled figure-based display path; kept unchanged.
    af::array locs = randidx(num_display, test_images.dims(2));
    af::array disp_in = test_images(af::span, af::span, locs);
    af::array disp_out = expand_labels ? test_output(af::span, locs) : test_output(locs);
    for (int i = 0; i < 5; i++) {
        int imgs_per_iter = num_display / 5;
        for (int j = 0; j < imgs_per_iter; j++) {
            int k = i * imgs_per_iter + j;
            af::fig("sub", 2, imgs_per_iter / 2, j+1);
            af::image(disp_in(af::span, af::span, k).T());
            std::string pred_name = std::string("Predicted: ");
            pred_name = pred_name + classify<expand_labels>(disp_out, k);
            af::fig("title", pred_name.c_str());
        }
        printf("Press any key to see next set");
        getchar();
    }
#else
    // Take the image extent from the array itself rather than assuming 28x28.
    const int dim0 = (int)test_images.dims(0);
    const int dim1 = (int)test_images.dims(1);

    for (int i = 0; i < num_display; i++) {
        std::cout << "Predicted: " << classify<expand_labels>(test_output, i) << std::endl;
        std::cout << "Actual: " << classify<expand_labels>(test_actual, i) << std::endl;

        // Threshold the image and copy it to the host as bytes (non-zero = ink).
        unsigned char* img = (test_images(af::span, af::span, i) > 0.1f).as(u8).host<unsigned char>();

        // Same traversal as the former fixed 28x28 loop: one printed line per
        // index along dimension 1, with dimension 0 as the fast index.
        // NOTE(review): assumes the host copy stores dimension 0 contiguously -- confirm.
        for (int j = 0; j < dim1; j++) {
            for (int k = 0; k < dim0; k++) {
                std::cout << (img[j*dim0+k] ? "\u2588" : " ") << " ";
            }
            std::cout << std::endl;
        }

        delete[] img;
        getchar();
    }
#endif
}
|