File: softmax_regression.cpp

/*******************************************************
 * Copyright (c) 2014, ArrayFire
 * All rights reserved.
 *
 * This file is distributed under 3-clause BSD license.
 * The complete license agreement can be obtained at:
 * http://arrayfire.com/licenses/BSD-3-Clause
 ********************************************************/

#include <arrayfire.h>
#include <stdio.h>
#include <iostream>
#include <vector>
#include <string>
#include <af/util.h>
#include <math.h>
#include "mnist_common.h"

using namespace af;

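// Percentage of samples whose predicted class (the row-wise argmax of
// the outputs) matches the argmax of the one-hot targets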
float accuracy(const array& predicted, const array& target)
{
    array val, plabels, tlabels;
    max(val, tlabels, target, 1);
    max(val, plabels, predicted, 1);

    return 100 * count<float>(plabels == tlabels) / tlabels.elements();
}

float abserr(const array& predicted, const array& target)
{
    // Mean absolute error, scaled to a percentage
    return 100 * sum<float>(abs(predicted - target)) / predicted.elements();
}

// Element-wise division; used with batchFunc to broadcast the per-row
// normalizer across the columns of the exponentiated activations
array divide(const array &a, const array &b)
{
    return a / b;
}

// Predict based on given parameters
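// The result is the row-wise softmax of Z = X * Weights:
//   H(i, j) = exp(Z(i, j)) / sum_k exp(Z(i, k)),
// so each row sums to 1. For large activations, subtracting the
// row-wise max from Z before exp() would be numerically safer; this
// example omits that step.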
array predict(const array &X, const array &Weights)
{
    array Z = matmul(X, Weights);
    array EZ = exp(Z);
    array nrm = sum(EZ, 1);
    return batchFunc(EZ, nrm, divide);
}

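// Per-class regularized cross-entropy cost. sum() reduces along dim 0,
// so J comes out as a 1 x num_classes row:
//   J(j)  = (1/m) * ( -sum_i Y(i,j) * log(H(i,j))
//                     + (1/2) * sum_k lambdat(k,j) * Weights(k,j)^2 )
//   dJ    = (1/m) * ( X^T * (H - Y) + lambdat .* Weights )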
void cost(array &J, array &dJ, const array &Weights,
          const array &X, const array &Y, double lambda = 1.0)
{
    // Number of samples
    int m = Y.dims(0);

    // Make the lambda corresponding to Weights(0) == 0
    array lambdat = constant(lambda, Weights.dims());

    // No regularization for bias weights
    lambdat(0, span) = 0;

    // Get the prediction
    array H = predict(X, Weights);

    // Cost of misprediction
    array Jerr =  -sum(Y * log(H));

    // Regularization cost
    array Jreg = 0.5 * sum(lambdat * Weights * Weights);

    // Total cost
    J = (Jerr + Jreg) / m;

    // Find the gradient of cost
    array D = (H - Y);
    dJ = (matmulTN(X, D) + lambdat * Weights) / m;
}

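// Batch gradient descent: step the weights against the gradient,
// Weights <- Weights - alpha * dJ, until the largest per-class cost
// drops below maxerr or maxiter iterations have run.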
array train(const array &X, const array &Y,
            double alpha = 0.1,
            double lambda = 1.0,
            double maxerr = 0.01,
            int maxiter = 1000,
            bool verbose = false)
{
    // Initialize parameters to 0
    array Weights = constant(0, X.dims(1), Y.dims(1));

    array J, dJ;
    float err = 0;

    for (int i = 0; i < maxiter; i++) {

        // Get the cost and gradient
        cost(J, dJ, Weights, X, Y, lambda);

        // J is a 1 x num_classes row; track its largest entry
        err = max<float>(abs(J));
        if (err < maxerr) {
            printf("Iteration %4d Err: %.4f\n", i + 1, err);
            printf("Training converged\n");
            return Weights;
        }

        if (verbose && ((i + 1) % 10 == 0)) {
            printf("Iteration %4d Err: %.4f\n", i + 1, err);
        }

        // Update the parameters via gradient descent
        Weights = Weights - alpha * dJ;
    }

    printf("Training stopped after %d iterations\n", maxiter);
    return Weights;
}

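// Time training once and prediction over many runs. ArrayFire evaluates
// lazily: eval() forces the computation and af::sync() blocks until the
// device finishes, so the timer measures the actual work.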
void benchmark_softmax_regression(const array &train_feats,
                                  const array &train_targets,
                                  const array &test_feats)
{
    timer::start();
    array Weights = train(train_feats, train_targets, 0.1, 1.0, 0.01, 1000);
    af::sync();
    printf("Training time: %4.4lf s\n", timer::stop());

    timer::start();
    const int iter = 100;
    for (int i = 0; i < iter; i++) {
        array test_outputs  = predict(test_feats , Weights);
        test_outputs.eval();
    }
    af::sync();
    printf("Prediction time: %4.4lf s\n", timer::stop() / iter);
}

// Demo of softmax (multinomial logistic) regression
int logit_demo(bool console, int perc)
{
    array train_images, train_targets;
    array test_images, test_targets;
    int num_train, num_test, num_classes;

    // Load mnist data
    float frac = (float)(perc) / 100.0f;
    setup_mnist<true>(&num_classes, &num_train, &num_test,
                      train_images, test_images,
                      train_targets, test_targets, frac);

    // Reshape images into feature vectors
    int feature_length = train_images.elements() / num_train;
    array train_feats = moddims(train_images, feature_length, num_train).T();
    array test_feats  = moddims(test_images , feature_length, num_test ).T();

    train_targets = train_targets.T();
    test_targets  = test_targets.T();

    // Add a bias that is always 1
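    // Weights(0, span) will then hold the per-class biases, which is
    // why cost() exempts them from regularization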
    train_feats = join(1, constant(1, num_train, 1), train_feats);
    test_feats  = join(1, constant(1, num_test , 1), test_feats );

    // Train logistic regression parameters
    array Weights = train(train_feats, train_targets,
                          0.1,  // learning rate (aka alpha)
                          1.0,  // regularization constant (aka weight decay, aka lambda)
                          0.01, // maximum error
                          1000, // maximum iterations
                          true);// verbose

    // Predict the results
    array train_outputs = predict(train_feats, Weights);
    array test_outputs  = predict(test_feats , Weights);

    printf("Accuracy on training data: %2.2f\n",
           accuracy(train_outputs, train_targets ));

    printf("Accuracy on testing data: %2.2f\n",
           accuracy(test_outputs , test_targets ));

    printf("Maximum error on testing data: %2.2f\n",
           abserr(test_outputs , test_targets ));

    benchmark_softmax_regression(train_feats, train_targets, test_feats);

    if (!console) {
        test_outputs = test_outputs.T();
        // Get 20 random test images.
        display_results<true>(test_images, test_outputs, test_targets.T(), 20);
    }

    return 0;
}

int main(int argc, char** argv)
{
    int device   = argc > 1 ? atoi(argv[1]) : 0;          // ArrayFire device index
    bool console = argc > 2 ? argv[2][0] == '-' : false;  // '-' skips the image display
    int perc     = argc > 3 ? atoi(argv[3]) : 60;         // percentage of MNIST data to use

    try {

        af::setDevice(device);
        af::info();
        return logit_demo(console, perc);

    } catch (af::exception &ae) {
        std::cerr << ae.what() << std::endl;
    }

    return 0;
}
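
// Example run (binary name assumed from the file name):
//   ./softmax_regression 0 - 60
// uses device 0, console-only output (no image display), and 60% of
// the MNIST data.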