File: _sag_fast.pyx

package info (click to toggle)
scikit-learn 0.23.2-5
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 21,892 kB
  • sloc: python: 132,020; cpp: 5,765; javascript: 2,201; ansic: 831; makefile: 213; sh: 44
file content (1358 lines) | stat: -rw-r--r-- 54,051 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358

#------------------------------------------------------------------------------

# cython: cdivision=True
# cython: boundscheck=False
# cython: wraparound=False
#
# Authors: Danny Sullivan <dbsullivan23@gmail.com>
#          Tom Dupre la Tour <tom.dupre-la-tour@m4x.org>
#          Arthur Mensch <arthur.mensch@m4x.org>
#
# License: BSD 3 clause

"""
SAG and SAGA implementation
WARNING: Do not edit .pyx file directly, it is generated from .pyx.tp
"""

cimport numpy as np
import numpy as np
from libc.math cimport fabs, exp, log
from libc.time cimport time, time_t

from ._sgd_fast cimport LossFunction
from ._sgd_fast cimport Log, SquaredLoss

from ..utils._seq_dataset cimport SequentialDataset32, SequentialDataset64

from libc.stdio cimport printf

# Initialise the NumPy C API for this extension module.
np.import_array()


# Finiteness checks declared in the C helper header "_sgd_fast_helpers.h".
# Both are declared nogil, so they can be called from GIL-free sections; the
# solver uses them to detect intercepts that became inf or NaN.
cdef extern from "_sgd_fast_helpers.h":
    bint skl_isfinite64(double) nogil
    bint skl_isfinite32(float) nogil


cdef inline double fmax64(double x, double y) nogil:
    """Return x when x > y, otherwise y.

    Because only ``x > y`` is tested, y is returned on ties and whenever the
    comparison is false (for instance if either argument is NaN).
    """
    return x if x > y else y

cdef inline float fmax32(float x, float y) nogil:
    """Return x when x > y, otherwise y (single-precision variant).

    Because only ``x > y`` is tested, y is returned on ties and whenever the
    comparison is false (for instance if either argument is NaN).
    """
    return x if x > y else y

cdef double _logsumexp64(double* arr, int n_classes) nogil:
    """Return log(sum(exp(arr[0:n_classes]))) for values in the log domain.

    The largest entry is subtracted from every element before
    exponentiating and added back after the logarithm, which limits the
    risk of overflow/underflow in exp.

    arr[0] is read unconditionally, so arr must hold at least one element.
    """
    cdef int k
    # Running maximum, used as the normalising shift.
    cdef double shift = arr[0]
    # Sum of the shifted exponentials.
    cdef double total = 0.0

    # First pass: locate the largest entry.
    for k in range(1, n_classes):
        if arr[k] > shift:
            shift = arr[k]

    # Second pass: accumulate exp of the shifted values (each is <= 0).
    for k in range(n_classes):
        total += exp(arr[k] - shift)

    return shift + log(total)

cdef float _logsumexp32(float* arr, int n_classes) nogil:
    """Return log(sum(exp(arr[0:n_classes]))) for values in the log domain.

    Single-precision variant: the largest entry is subtracted from every
    element before exponentiating and added back after the logarithm, which
    limits the risk of overflow/underflow in exp.

    arr[0] is read unconditionally, so arr must hold at least one element.
    """
    cdef int k
    # Running maximum, used as the normalising shift.
    cdef float shift = arr[0]
    # Sum of the shifted exponentials, kept in single precision.
    cdef float total = 0.0

    # First pass: locate the largest entry.
    for k in range(1, n_classes):
        if arr[k] > shift:
            shift = arr[k]

    # Second pass: accumulate exp of the shifted values (each is <= 0).
    for k in range(n_classes):
        total += exp(arr[k] - shift)

    return shift + log(total)

cdef class MultinomialLogLoss64:
    """Multinomial logistic loss and its gradient for double-precision data.

    The class holds no state; both C-level methods operate on raw pointers
    and are declared nogil.
    """

    cdef double _loss(self, double* prediction, double y, int n_classes,
                      double sample_weight) nogil:
        r"""Weighted multinomial logistic loss of one sample.

        For one sample the loss is:
        loss = - sw \sum_c \delta_{y,c} (prediction[c] - logsumexp(prediction))
             = sw (logsumexp(prediction) - prediction[y])

        where:
            prediction = dot(x_sample, weights) + intercept
            \delta_{y,c} = 1 if (y == c) else 0
            sw = sample_weight

        Parameters
        ----------
        prediction : pointer to a np.ndarray[double] of shape (n_classes,)
            Prediction of the multinomial classifier, for current sample.

        y : double, between 0 and n_classes - 1
            Index of the correct class for current sample (i.e. label
            encoded). It is converted with int() before indexing.

        n_classes : integer
            Total number of classes.

        sample_weight : double
            Weight of current sample.

        Returns
        -------
        loss : double
            Multinomial loss for current sample.

        Reference
        ---------
        Bishop, C. M. (2006). Pattern recognition and machine learning.
        Springer. (Chapter 4.3.4)
        """
        cdef double log_norm = _logsumexp64(prediction, n_classes)

        # y is the index of the correct class of the current sample.
        return (log_norm - prediction[int(y)]) * sample_weight

    cdef void _dloss(self, double* prediction, double y, int n_classes,
                     double sample_weight, double* gradient_ptr) nogil:
        r"""Weighted gradient of the multinomial logistic loss for one sample.

        For each class c the value written to gradient_ptr[c] is:
        grad_c = sw * (p[c] - \delta_{y,c})

        where:
            p[c] = exp(prediction[c] - logsumexp(prediction))
            prediction = dot(sample, weights) + intercept
            \delta_{y,c} = 1 if (y == c) else 0
            sw = sample_weight

        Note that to obtain the true gradient, this value has to be multiplied
        by the sample vector x.

        Parameters
        ----------
        prediction : pointer to a np.ndarray[double] of shape (n_classes,)
            Prediction of the multinomial classifier, for current sample.

        y : double, between 0 and n_classes - 1
            Index of the correct class for current sample (i.e. label encoded)

        n_classes : integer
            Total number of classes.

        sample_weight : double
            Weight of current sample.

        gradient_ptr : pointer to a np.ndarray[double] of shape (n_classes,)
            Gradient vector to be filled.

        Reference
        ---------
        Bishop, C. M. (2006). Pattern recognition and machine learning.
        Springer. (Chapter 4.3.4)
        """
        cdef double log_norm = _logsumexp64(prediction, n_classes)
        cdef double grad
        cdef int c

        for c in range(n_classes):
            # Normalised (softmax) probability of class c.
            grad = exp(prediction[c] - log_norm)

            # y is the index of the correct class: subtract the one-hot
            # target there.
            if c == y:
                grad -= 1.0

            gradient_ptr[c] = grad * sample_weight

    def __reduce__(self):
        """Pickle support: rebuild by calling the class with no arguments."""
        return MultinomialLogLoss64, ()

cdef class MultinomialLogLoss32:
    """Multinomial logistic loss and its gradient for single-precision data.

    The class holds no state; both C-level methods operate on raw pointers
    and are declared nogil.
    """

    cdef float _loss(self, float* prediction, float y, int n_classes,
                      float sample_weight) nogil:
        r"""Weighted multinomial logistic loss of one sample.

        For one sample the loss is:
        loss = - sw \sum_c \delta_{y,c} (prediction[c] - logsumexp(prediction))
             = sw (logsumexp(prediction) - prediction[y])

        where:
            prediction = dot(x_sample, weights) + intercept
            \delta_{y,c} = 1 if (y == c) else 0
            sw = sample_weight

        Parameters
        ----------
        prediction : pointer to a np.ndarray[float] of shape (n_classes,)
            Prediction of the multinomial classifier, for current sample.

        y : float, between 0 and n_classes - 1
            Index of the correct class for current sample (i.e. label
            encoded). It is converted with int() before indexing.

        n_classes : integer
            Total number of classes.

        sample_weight : float
            Weight of current sample.

        Returns
        -------
        loss : float
            Multinomial loss for current sample.

        Reference
        ---------
        Bishop, C. M. (2006). Pattern recognition and machine learning.
        Springer. (Chapter 4.3.4)
        """
        cdef float log_norm = _logsumexp32(prediction, n_classes)

        # y is the index of the correct class of the current sample.
        return (log_norm - prediction[int(y)]) * sample_weight

    cdef void _dloss(self, float* prediction, float y, int n_classes,
                     float sample_weight, float* gradient_ptr) nogil:
        r"""Weighted gradient of the multinomial logistic loss for one sample.

        For each class c the value written to gradient_ptr[c] is:
        grad_c = sw * (p[c] - \delta_{y,c})

        where:
            p[c] = exp(prediction[c] - logsumexp(prediction))
            prediction = dot(sample, weights) + intercept
            \delta_{y,c} = 1 if (y == c) else 0
            sw = sample_weight

        Note that to obtain the true gradient, this value has to be multiplied
        by the sample vector x.

        Parameters
        ----------
        prediction : pointer to a np.ndarray[float] of shape (n_classes,)
            Prediction of the multinomial classifier, for current sample.

        y : float, between 0 and n_classes - 1
            Index of the correct class for current sample (i.e. label encoded)

        n_classes : integer
            Total number of classes.

        sample_weight : float
            Weight of current sample.

        gradient_ptr : pointer to a np.ndarray[float] of shape (n_classes,)
            Gradient vector to be filled.

        Reference
        ---------
        Bishop, C. M. (2006). Pattern recognition and machine learning.
        Springer. (Chapter 4.3.4)
        """
        cdef float log_norm = _logsumexp32(prediction, n_classes)
        # Kept in single precision so every intermediate is rounded exactly
        # as if it were stored in gradient_ptr.
        cdef float grad
        cdef int c

        for c in range(n_classes):
            # Normalised (softmax) probability of class c.
            grad = exp(prediction[c] - log_norm)

            # y is the index of the correct class: subtract the one-hot
            # target there.
            if c == y:
                grad -= 1.0

            gradient_ptr[c] = grad * sample_weight

    def __reduce__(self):
        """Pickle support: rebuild by calling the class with no arguments."""
        return MultinomialLogLoss32, ()

cdef inline double _soft_thresholding64(double x, double shrinkage) nogil:
    """Soft-threshold x: max(x - shrinkage, 0) - max(-x - shrinkage, 0).

    For shrinkage >= 0 this moves x towards zero by shrinkage and returns 0
    when |x| <= shrinkage.
    """
    cdef double positive_part = fmax64(x - shrinkage, 0)
    cdef double negative_part = fmax64(- x - shrinkage, 0)
    return positive_part - negative_part

cdef inline float _soft_thresholding32(float x, float shrinkage) nogil:
    """Soft-threshold x: max(x - shrinkage, 0) - max(-x - shrinkage, 0).

    Single-precision variant. For shrinkage >= 0 this moves x towards zero
    by shrinkage and returns 0 when |x| <= shrinkage.
    """
    cdef float positive_part = fmax32(x - shrinkage, 0)
    cdef float negative_part = fmax32(- x - shrinkage, 0)
    return positive_part - negative_part

def sag64(SequentialDataset64 dataset,
        np.ndarray[double, ndim=2, mode='c'] weights_array,
        np.ndarray[double, ndim=1, mode='c'] intercept_array,
        int n_samples,
        int n_features,
        int n_classes,
        double tol,
        int max_iter,
        str loss_function,
        double step_size,
        double alpha,
        double beta,
        np.ndarray[double, ndim=2, mode='c'] sum_gradient_init,
        np.ndarray[double, ndim=2, mode='c'] gradient_memory_init,
        np.ndarray[bint, ndim=1, mode='c'] seen_init,
        int num_seen,
        bint fit_intercept,
        np.ndarray[double, ndim=1, mode='c'] intercept_sum_gradient_init,
        double intercept_decay,
        bint saga,
        bint verbose):
    """Stochastic Average Gradient (SAG) and SAGA solvers.

    Used in Ridge and LogisticRegression.

    Runs at most max_iter epochs. Each epoch draws n_samples samples with
    dataset.random(), updates the stored per-sample gradient, the running
    gradient sums, the weights and (optionally) the intercept, then rescales
    the weights and checks the stopping criterion. The whole loop executes
    without the GIL; the input arrays are modified in place through their
    data pointers.

    Parameters
    ----------
    dataset : SequentialDataset64
        Source of samples; dataset.random() yields the data/index pointers,
        the number of non-zeros, the label and the weight of one sample.

    weights_array : ndarray of double, 2-d, C-contiguous
        Coefficients, updated in place. Addressed as a flat buffer at
        feature_ind * n_classes + class_ind.

    intercept_array : ndarray of double, 1-d, C-contiguous
        Intercept per class, updated in place when fit_intercept is set.

    n_samples : int
        Number of samples drawn per epoch; also the length of the
        cumulative-sum work buffers.

    n_features : int
        Number of features.

    n_classes : int
        Number of entries of the prediction and gradient vectors of a sample.

    tol : double
        An epoch ends the optimisation when
        max_change / max_weight <= tol, where both maxima are taken over
        all weights.

    max_iter : int
        Maximum number of epochs.

    loss_function : str
        One of 'log', 'squared' or 'multinomial'.

    step_size : double
        Constant step size.

    alpha : double
        L2 term: the weight scale is multiplied by 1 - step_size * alpha
        after every sample.

    beta : double
        Proximal term, only active when beta > 0 and saga is set
        (presumably the L1 penalty strength; the proximal step itself lives
        in helpers that are not part of this function).

    sum_gradient_init : ndarray of double, 2-d, C-contiguous
        Running sum of gradients, addressed like weights_array and updated
        in place.

    gradient_memory_init : ndarray of double, 2-d, C-contiguous
        Last gradient computed for each sample, addressed at
        sample_ind * n_classes + class_ind and updated in place.

    seen_init : ndarray of bint, 1-d, C-contiguous
        Per-sample flag telling whether the sample was already drawn;
        updated in place.

    num_seen : int
        Number of distinct samples drawn so far; incremented the first time
        a sample is drawn.

    fit_intercept : bint
        Whether the intercept is updated.

    intercept_sum_gradient_init : ndarray of double, 1-d, C-contiguous
        Running sum of intercept gradients per class, updated in place.

    intercept_decay : double
        Factor multiplying the averaged-gradient part of the intercept step.

    saga : bint
        Use the SAGA variant (adds the immediate gradient-correction update
        to weights and intercept).

    verbose : bint
        Print progress, rescaling and convergence messages.

    Returns
    -------
    num_seen : int
        Updated count of distinct samples drawn.

    n_iter : int
        Index of the last epoch entered, plus one.

    Raises
    ------
    ValueError
        If loss_function is not one of the supported names, or if a
        non-finite intercept was detected or a weight-update helper
        returned -1 (reported as floating-point under-/overflow).

    Reference
    ---------
    Schmidt, M., Roux, N. L., & Bach, F. (2013).
    Minimizing finite sums with the stochastic average gradient
    https://hal.inria.fr/hal-00860051/document
    (section 4.3)

    Defazio, A., Bach, F., Lacoste-Julien, S. (2014),
    SAGA: A Fast Incremental Gradient Method With Support
    for Non-Strongly Convex Composite Objectives
    https://arxiv.org/abs/1407.0202

    """
    # the data pointer for x, the current sample
    cdef double *x_data_ptr = NULL
    # the index pointer for the column of the data
    cdef int *x_ind_ptr = NULL
    # the number of non-zero features for current sample
    cdef int xnnz = -1
    # the label value for current sample
    cdef double y
    # the sample weight
    cdef double sample_weight

    # helper variable for indexes
    cdef int f_idx, s_idx, feature_ind, class_ind, j
    # the number of pass through all samples
    cdef int n_iter = 0
    # helper to track iterations through samples
    cdef int sample_itr
    # the index (row number) of the current sample
    cdef int sample_ind

    # the maximum change in weights, used to compute stopping criteria
    cdef double max_change
    # a holder variable for the max weight, used to compute stopping criteria
    cdef double max_weight

    # the start time of the fit
    cdef time_t start_time
    # the end time of the fit
    cdef time_t end_time

    # precomputation since the step size does not change in this implementation
    cdef double wscale_update = 1.0 - step_size * alpha

    # vector of booleans indicating whether this sample has been seen
    cdef bint* seen = <bint*> seen_init.data

    # helper for cumulative sum
    cdef double cum_sum

    # the pointer to the coef_ or weights
    cdef double* weights = <double * >weights_array.data
    # the pointer to the intercept_array
    cdef double* intercept = <double * >intercept_array.data

    # the pointer to the intercept_sum_gradient
    cdef double* intercept_sum_gradient = \
        <double * >intercept_sum_gradient_init.data

    # the sum of gradients for each feature
    cdef double* sum_gradient = <double*> sum_gradient_init.data
    # the previously seen gradient for each sample
    cdef double* gradient_memory = <double*> gradient_memory_init.data

    # the cumulative sums needed for JIT (just-in-time) params
    cdef np.ndarray[double, ndim=1] cumulative_sums_array = \
        np.empty(n_samples, dtype=np.float64, order="c")
    cdef double* cumulative_sums = <double*> cumulative_sums_array.data

    # the index for the last time this feature was updated
    cdef np.ndarray[int, ndim=1] feature_hist_array = \
        np.zeros(n_features, dtype=np.int32, order="c")
    cdef int* feature_hist = <int*> feature_hist_array.data

    # the previous weights to use to compute stopping criteria
    cdef np.ndarray[double, ndim=2] previous_weights_array = \
        np.zeros((n_features, n_classes), dtype=np.float64, order="c")
    cdef double* previous_weights = <double*> previous_weights_array.data

    # per-class prediction of the current sample
    cdef np.ndarray[double, ndim=1] prediction_array = \
        np.zeros(n_classes, dtype=np.float64, order="c")
    cdef double* prediction = <double*> prediction_array.data

    # per-class loss gradient of the current sample
    cdef np.ndarray[double, ndim=1] gradient_array = \
        np.zeros(n_classes, dtype=np.float64, order="c")
    cdef double* gradient = <double*> gradient_array.data

    # Intermediate variable that need declaration since cython cannot infer when templating
    cdef double val

    # Bias correction term in saga
    cdef double gradient_correction

    # the scalar used for multiplying z
    cdef double wscale = 1.0

    # return value (-1 if an error occurred, 0 otherwise)
    cdef int status = 0

    # the cumulative sums for each iteration for the sparse implementation
    cumulative_sums[0] = 0.0

    # the multiplicative scale needed for JIT params
    cdef np.ndarray[double, ndim=1] cumulative_sums_prox_array
    cdef double* cumulative_sums_prox

    # the proximal step is only used by SAGA with a strictly positive beta
    cdef bint prox = beta > 0 and saga

    # Loss function to optimize
    cdef LossFunction loss
    # Whether the loss function is multinomial
    cdef bint multinomial = False
    # Multinomial loss function
    cdef MultinomialLogLoss64 multiloss

    if loss_function == "multinomial":
        multinomial = True
        multiloss = MultinomialLogLoss64()
    elif loss_function == "log":
        loss = Log()
    elif loss_function == "squared":
        loss = SquaredLoss()
    else:
        raise ValueError("Invalid loss parameter: got %s instead of "
                         "one of ('log', 'squared', 'multinomial')"
                         % loss_function)

    # the proximal cumulative sums are only allocated when they are needed
    if prox:
        cumulative_sums_prox_array = np.empty(n_samples,
                                              dtype=np.float64, order="c")
        cumulative_sums_prox = <double*> cumulative_sums_prox_array.data
    else:
        cumulative_sums_prox = NULL

    with nogil:
        start_time = time(NULL)
        for n_iter in range(max_iter):
            for sample_itr in range(n_samples):
                # extract a random sample
                sample_ind = dataset.random(&x_data_ptr, &x_ind_ptr, &xnnz,
                                              &y, &sample_weight)

                # cached index for gradient_memory
                s_idx = sample_ind * n_classes

                # update the number of samples seen and the seen array
                if seen[sample_ind] == 0:
                    num_seen += 1
                    seen[sample_ind] = 1

                # make the weight updates
                if sample_itr > 0:
                   status = lagged_update64(weights, wscale, xnnz,
                                                  n_samples, n_classes,
                                                  sample_itr,
                                                  cumulative_sums,
                                                  cumulative_sums_prox,
                                                  feature_hist,
                                                  prox,
                                                  sum_gradient,
                                                  x_ind_ptr,
                                                  False,
                                                  n_iter)
                   if status == -1:
                       break

                # find the current prediction
                predict_sample64(x_data_ptr, x_ind_ptr, xnnz, weights, wscale,
                                       intercept, prediction, n_classes)

                # compute the gradient for this sample, given the prediction
                if multinomial:
                    multiloss._dloss(prediction, y, n_classes, sample_weight,
                                     gradient)
                else:
                    gradient[0] = loss._dloss(prediction[0], y) * sample_weight

                # L2 regularization by simply rescaling the weights
                wscale *= wscale_update

                # make the updates to the sum of gradients
                for j in range(xnnz):
                    feature_ind = x_ind_ptr[j]
                    val = x_data_ptr[j]
                    f_idx = feature_ind * n_classes
                    for class_ind in range(n_classes):
                        # difference between the new gradient and the one
                        # stored for this sample, scaled by the feature value
                        gradient_correction = \
                            val * (gradient[class_ind] -
                                   gradient_memory[s_idx + class_ind])
                        if saga:
                            # SAGA applies the correction immediately; the
                            # division by wscale compensates for the weights
                            # being stored in scaled form
                            weights[f_idx + class_ind] -= \
                                (gradient_correction * step_size
                                 * (1 - 1. / num_seen) / wscale)
                        sum_gradient[f_idx + class_ind] += gradient_correction

                # fit the intercept
                if fit_intercept:
                    for class_ind in range(n_classes):
                        gradient_correction = (gradient[class_ind] -
                                               gradient_memory[s_idx + class_ind])
                        intercept_sum_gradient[class_ind] += gradient_correction
                        gradient_correction *= step_size * (1. - 1. / num_seen)
                        if saga:
                            intercept[class_ind] -= \
                                (step_size * intercept_sum_gradient[class_ind] /
                                 num_seen * intercept_decay) + gradient_correction
                        else:
                            intercept[class_ind] -= \
                                (step_size * intercept_sum_gradient[class_ind] /
                                 num_seen * intercept_decay)

                        # check to see that the intercept is not inf or NaN
                        if not skl_isfinite64(intercept[class_ind]):
                            status = -1
                            break
                    # Break from the n_samples outer loop if an error happened
                    # in the fit_intercept n_classes inner loop
                    if status == -1:
                        break

                # update the gradient memory for this sample
                for class_ind in range(n_classes):
                    gradient_memory[s_idx + class_ind] = gradient[class_ind]

                # extend the cumulative sums consumed by the lagged
                # (just-in-time) weight updates
                if sample_itr == 0:
                    cumulative_sums[0] = step_size / (wscale * num_seen)
                    if prox:
                        cumulative_sums_prox[0] = step_size * beta / wscale
                else:
                    cumulative_sums[sample_itr] = \
                        (cumulative_sums[sample_itr - 1] +
                         step_size / (wscale * num_seen))
                    if prox:
                        cumulative_sums_prox[sample_itr] = \
                        (cumulative_sums_prox[sample_itr - 1] +
                             step_size * beta / wscale)
                # If wscale gets too small, we need to reset the scale.
                if wscale < 1e-9:
                    if verbose:
                        with gil:
                            print("rescaling...")
                    status = scale_weights64(
                        weights, &wscale, n_features, n_samples, n_classes,
                        sample_itr, cumulative_sums,
                        cumulative_sums_prox,
                        feature_hist,
                        prox, sum_gradient, n_iter)
                    if status == -1:
                        break

            # Break from the n_iter outer loop if an error happened in the
            # n_samples inner loop
            if status == -1:
                break

            # we scale the weights every n_samples iterations and reset the
            # just-in-time update system for numerical stability.
            status = scale_weights64(weights, &wscale, n_features,
                                           n_samples,
                                           n_classes, n_samples - 1,
                                           cumulative_sums,
                                           cumulative_sums_prox,
                                           feature_hist,
                                           prox, sum_gradient, n_iter)

            if status == -1:
                break
            # check if the stopping criteria is reached
            max_change = 0.0
            max_weight = 0.0
            # NOTE(review): idx has no cdef declaration in this function; it
            # presumably relies on Cython's type inference to be usable in
            # this nogil section - confirm, or declare it explicitly.
            for idx in range(n_features * n_classes):
                max_weight = fmax64(max_weight, fabs(weights[idx]))
                max_change = fmax64(max_change,
                                  fabs(weights[idx] -
                                       previous_weights[idx]))
                previous_weights[idx] = weights[idx]
            # converged when the relative change is within tol, or when both
            # the weights and their change are exactly zero
            if ((max_weight != 0 and max_change / max_weight <= tol)
                or max_weight == 0 and max_change == 0):
                if verbose:
                    end_time = time(NULL)
                    with gil:
                        print("convergence after %d epochs took %d seconds" %
                              (n_iter + 1, end_time - start_time))
                break
            elif verbose:
                # NOTE(review): this branch is also reached when
                # max_weight == 0 and max_change != 0, in which case the
                # ratio below divides by zero - confirm this is acceptable
                # for a purely informational message.
                printf('Epoch %d, change: %.8f\n', n_iter + 1,
                                                  max_change / max_weight)
    # convert the zero-based index of the last epoch into an epoch count
    n_iter += 1
    # We do the error treatment here based on error code in status to avoid
    # re-acquiring the GIL within the cython code, which slows the computation
    # when the sag/saga solver is used concurrently in multiple Python threads.
    if status == -1:
        raise ValueError(("Floating-point under-/overflow occurred at epoch"
                          " #%d. Scaling input data with StandardScaler or"
                          " MinMaxScaler might help.") % n_iter)

    if verbose and n_iter >= max_iter:
        end_time = time(NULL)
        print(("max_iter reached after %d seconds") %
              (end_time - start_time))

    return num_seen, n_iter

def sag32(SequentialDataset32 dataset,
        np.ndarray[float, ndim=2, mode='c'] weights_array,
        np.ndarray[float, ndim=1, mode='c'] intercept_array,
        int n_samples,
        int n_features,
        int n_classes,
        double tol,
        int max_iter,
        str loss_function,
        double step_size,
        double alpha,
        double beta,
        np.ndarray[float, ndim=2, mode='c'] sum_gradient_init,
        np.ndarray[float, ndim=2, mode='c'] gradient_memory_init,
        np.ndarray[bint, ndim=1, mode='c'] seen_init,
        int num_seen,
        bint fit_intercept,
        np.ndarray[float, ndim=1, mode='c'] intercept_sum_gradient_init,
        double intercept_decay,
        bint saga,
        bint verbose):
    """Stochastic Average Gradient (SAG) and SAGA solvers (float32 version).

    Used in Ridge and LogisticRegression.

    Parameters
    ----------
    dataset : SequentialDataset32
        Source of the training samples; one sample is drawn per inner
        iteration with ``dataset.random``.

    weights_array : ndarray of float32, C-contiguous
        Coefficients, accessed as a flat buffer at
        ``feature_ind * n_classes + class_ind``. Updated in place.

    intercept_array : ndarray of float32
        One intercept per class. Updated in place when ``fit_intercept``.

    n_samples, n_features, n_classes : int
        Problem dimensions. One epoch is ``n_samples`` random draws.

    tol : double
        The solver stops once ``max_change / max_weight <= tol`` at the end
        of an epoch.

    max_iter : int
        Maximum number of epochs.

    loss_function : str
        One of 'log', 'squared' or 'multinomial'.

    step_size : double
        Constant step size.

    alpha : double
        Enters the per-sample weight rescaling factor
        ``1 - step_size * alpha`` (L2 regularization).

    beta : double
        Strength of the proximal term. It is only active when ``saga`` is
        true and ``beta > 0``.

    sum_gradient_init : ndarray of float32
        Running sum of gradients, indexed like the weights. Updated in
        place.

    gradient_memory_init : ndarray of float32
        Last gradient stored for each sample, accessed at
        ``sample_ind * n_classes + class_ind``. Updated in place.

    seen_init : ndarray
        Per-sample flags telling whether the sample was already drawn.
        Updated in place.

    num_seen : int
        Number of samples already flagged in ``seen_init`` on entry.

    fit_intercept : bint
        Whether the intercept is updated.

    intercept_sum_gradient_init : ndarray of float32
        Running sum of the intercept gradients, one per class. Updated in
        place.

    intercept_decay : double
        Multiplier applied to the averaged-gradient intercept step.

    saga : bint
        Use the SAGA variant instead of SAG.

    verbose : bint
        Print progress information.

    Returns
    -------
    num_seen : int
        Updated number of distinct samples seen.

    n_iter : int
        One plus the zero-based index of the last epoch that was entered.

    Raises
    ------
    ValueError
        If ``loss_function`` is not recognised, or if a weight or an
        intercept becomes non-finite.

    Reference
    ---------
    Schmidt, M., Roux, N. L., & Bach, F. (2013).
    Minimizing finite sums with the stochastic average gradient
    https://hal.inria.fr/hal-00860051/document
    (section 4.3)

    Defazio, A., Bach, F., Lacoste-Julien, S. (2014),
    SAGA: A Fast Incremental Gradient Method With Support
    for Non-Strongly Convex Composite Objectives
    https://arxiv.org/abs/1407.0202

    """
    # the data pointer for x, the current sample
    cdef float *x_data_ptr = NULL
    # the index pointer for the column of the data
    cdef int *x_ind_ptr = NULL
    # the number of non-zero features for current sample
    cdef int xnnz = -1
    # the label value for current sample
    cdef float y
    # the sample weight
    cdef float sample_weight

    # helper variables for indexes; idx is declared explicitly so that the
    # convergence loop below does not depend on type inference under nogil
    cdef int f_idx, s_idx, feature_ind, class_ind, j, idx
    # the number of passes through all samples
    cdef int n_iter = 0
    # helper to track iterations through samples
    cdef int sample_itr
    # the index (row number) of the current sample
    cdef int sample_ind

    # the maximum change in weights, used to compute stopping criteria
    cdef float max_change
    # a holder variable for the max weight, used to compute stopping criteria
    cdef float max_weight

    # the start time of the fit
    cdef time_t start_time
    # the end time of the fit
    cdef time_t end_time

    # precomputation since the step size does not change in this implementation
    cdef float wscale_update = 1.0 - step_size * alpha

    # vector of booleans indicating whether this sample has been seen
    cdef bint* seen = <bint*> seen_init.data

    # the pointer to the coef_ or weights
    cdef float* weights = <float * >weights_array.data
    # the pointer to the intercept_array
    cdef float* intercept = <float * >intercept_array.data

    # the pointer to the intercept_sum_gradient
    cdef float* intercept_sum_gradient = \
        <float * >intercept_sum_gradient_init.data

    # the sum of gradients for each feature
    cdef float* sum_gradient = <float*> sum_gradient_init.data
    # the previously seen gradient for each sample
    cdef float* gradient_memory = <float*> gradient_memory_init.data

    # the cumulative sums needed for JIT params
    cdef np.ndarray[float, ndim=1] cumulative_sums_array = \
        np.empty(n_samples, dtype=np.float32, order="c")
    cdef float* cumulative_sums = <float*> cumulative_sums_array.data

    # the index for the last time this feature was updated
    cdef np.ndarray[int, ndim=1] feature_hist_array = \
        np.zeros(n_features, dtype=np.int32, order="c")
    cdef int* feature_hist = <int*> feature_hist_array.data

    # the previous weights to use to compute stopping criteria
    cdef np.ndarray[float, ndim=2] previous_weights_array = \
        np.zeros((n_features, n_classes), dtype=np.float32, order="c")
    cdef float* previous_weights = <float*> previous_weights_array.data

    # per-class prediction buffer for the current sample
    cdef np.ndarray[float, ndim=1] prediction_array = \
        np.zeros(n_classes, dtype=np.float32, order="c")
    cdef float* prediction = <float*> prediction_array.data

    # per-class gradient buffer for the current sample
    cdef np.ndarray[float, ndim=1] gradient_array = \
        np.zeros(n_classes, dtype=np.float32, order="c")
    cdef float* gradient = <float*> gradient_array.data

    # Intermediate variable that needs a declaration since cython cannot
    # infer its type when templating
    cdef float val

    # Bias correction term in saga
    cdef float gradient_correction

    # the scalar used for multiplying z
    cdef float wscale = 1.0

    # return value (-1 if an error occurred, 0 otherwise)
    cdef int status = 0

    # the cumulative sums for each iteration for the sparse implementation
    cumulative_sums[0] = 0.0

    # the multiplicative scale needed for JIT params
    cdef np.ndarray[float, ndim=1] cumulative_sums_prox_array
    cdef float* cumulative_sums_prox

    # the proximal step is only used by SAGA with a positive beta
    cdef bint prox = beta > 0 and saga

    # Loss function to optimize
    cdef LossFunction loss
    # Whether the loss function is multinomial
    cdef bint multinomial = False
    # Multinomial loss function
    cdef MultinomialLogLoss32 multiloss

    if loss_function == "multinomial":
        multinomial = True
        multiloss = MultinomialLogLoss32()
    elif loss_function == "log":
        loss = Log()
    elif loss_function == "squared":
        loss = SquaredLoss()
    else:
        raise ValueError("Invalid loss parameter: got %s instead of "
                         "one of ('log', 'squared', 'multinomial')"
                         % loss_function)

    if prox:
        cumulative_sums_prox_array = np.empty(n_samples,
                                              dtype=np.float32, order="c")
        cumulative_sums_prox = <float*> cumulative_sums_prox_array.data
    else:
        cumulative_sums_prox = NULL

    with nogil:
        start_time = time(NULL)
        for n_iter in range(max_iter):
            for sample_itr in range(n_samples):
                # extract a random sample
                sample_ind = dataset.random(&x_data_ptr, &x_ind_ptr, &xnnz,
                                            &y, &sample_weight)

                # cached index for gradient_memory
                s_idx = sample_ind * n_classes

                # update the number of samples seen and the seen array
                if seen[sample_ind] == 0:
                    num_seen += 1
                    seen[sample_ind] = 1

                # make the weight updates: bring the weights of the non-zero
                # features of this sample up to date before predicting
                if sample_itr > 0:
                    status = lagged_update32(weights, wscale, xnnz,
                                             n_samples, n_classes,
                                             sample_itr,
                                             cumulative_sums,
                                             cumulative_sums_prox,
                                             feature_hist,
                                             prox,
                                             sum_gradient,
                                             x_ind_ptr,
                                             False,
                                             n_iter)
                    if status == -1:
                        break

                # find the current prediction
                predict_sample32(x_data_ptr, x_ind_ptr, xnnz, weights, wscale,
                                 intercept, prediction, n_classes)

                # compute the gradient for this sample, given the prediction
                if multinomial:
                    multiloss._dloss(prediction, y, n_classes, sample_weight,
                                     gradient)
                else:
                    gradient[0] = loss._dloss(prediction[0], y) * sample_weight

                # L2 regularization by simply rescaling the weights
                wscale *= wscale_update

                # make the updates to the sum of gradients
                for j in range(xnnz):
                    feature_ind = x_ind_ptr[j]
                    val = x_data_ptr[j]
                    f_idx = feature_ind * n_classes
                    for class_ind in range(n_classes):
                        gradient_correction = \
                            val * (gradient[class_ind] -
                                   gradient_memory[s_idx + class_ind])
                        if saga:
                            weights[f_idx + class_ind] -= \
                                (gradient_correction * step_size
                                 * (1 - 1. / num_seen) / wscale)
                        sum_gradient[f_idx + class_ind] += gradient_correction

                # fit the intercept
                if fit_intercept:
                    for class_ind in range(n_classes):
                        gradient_correction = (gradient[class_ind] -
                                               gradient_memory[s_idx + class_ind])
                        intercept_sum_gradient[class_ind] += gradient_correction
                        gradient_correction *= step_size * (1. - 1. / num_seen)
                        if saga:
                            intercept[class_ind] -= \
                                (step_size * intercept_sum_gradient[class_ind] /
                                 num_seen * intercept_decay) + gradient_correction
                        else:
                            intercept[class_ind] -= \
                                (step_size * intercept_sum_gradient[class_ind] /
                                 num_seen * intercept_decay)

                        # check to see that the intercept is not inf or NaN
                        if not skl_isfinite32(intercept[class_ind]):
                            status = -1
                            break
                    # Break from the n_samples outer loop if an error happened
                    # in the fit_intercept n_classes inner loop
                    if status == -1:
                        break

                # update the gradient memory for this sample
                for class_ind in range(n_classes):
                    gradient_memory[s_idx + class_ind] = gradient[class_ind]

                # record the step taken at this iteration so that features
                # not touched now can be caught up later
                if sample_itr == 0:
                    cumulative_sums[0] = step_size / (wscale * num_seen)
                    if prox:
                        cumulative_sums_prox[0] = step_size * beta / wscale
                else:
                    cumulative_sums[sample_itr] = \
                        (cumulative_sums[sample_itr - 1] +
                         step_size / (wscale * num_seen))
                    if prox:
                        cumulative_sums_prox[sample_itr] = \
                            (cumulative_sums_prox[sample_itr - 1] +
                             step_size * beta / wscale)
                # If wscale gets too small, we need to reset the scale.
                if wscale < 1e-9:
                    if verbose:
                        with gil:
                            print("rescaling...")
                    status = scale_weights32(
                        weights, &wscale, n_features, n_samples, n_classes,
                        sample_itr, cumulative_sums,
                        cumulative_sums_prox,
                        feature_hist,
                        prox, sum_gradient, n_iter)
                    if status == -1:
                        break

            # Break from the n_iter outer loop if an error happened in the
            # n_samples inner loop
            if status == -1:
                break

            # we scale the weights every n_samples iterations and reset the
            # just-in-time update system for numerical stability.
            status = scale_weights32(weights, &wscale, n_features,
                                     n_samples,
                                     n_classes, n_samples - 1,
                                     cumulative_sums,
                                     cumulative_sums_prox,
                                     feature_hist,
                                     prox, sum_gradient, n_iter)

            if status == -1:
                break
            # check if the stopping criteria is reached
            max_change = 0.0
            max_weight = 0.0
            for idx in range(n_features * n_classes):
                max_weight = fmax32(max_weight, fabs(weights[idx]))
                max_change = fmax32(max_change,
                                    fabs(weights[idx] -
                                         previous_weights[idx]))
                previous_weights[idx] = weights[idx]
            if ((max_weight != 0 and max_change / max_weight <= tol)
                or max_weight == 0 and max_change == 0):
                if verbose:
                    end_time = time(NULL)
                    with gil:
                        print("convergence after %d epochs took %d seconds" %
                              (n_iter + 1, end_time - start_time))
                break
            elif verbose:
                printf('Epoch %d, change: %.8f\n', n_iter + 1,
                                                  max_change / max_weight)
    # turn the zero-based epoch index into a count
    n_iter += 1
    # We do the error treatment here based on error code in status to avoid
    # re-acquiring the GIL within the cython code, which slows the computation
    # when the sag/saga solver is used concurrently in multiple Python threads.
    if status == -1:
        raise ValueError(("Floating-point under-/overflow occurred at epoch"
                          " #%d. Scaling input data with StandardScaler or"
                          " MinMaxScaler might help.") % n_iter)

    # NOTE: n_iter == max_iter also holds when convergence was detected during
    # the last allowed epoch, so this message can follow the convergence
    # message printed above.
    if verbose and n_iter >= max_iter:
        end_time = time(NULL)
        print(("max_iter reached after %d seconds") %
              (end_time - start_time))

    return num_seen, n_iter

cdef int scale_weights64(double* weights, double* wscale,
                               int n_features,
                               int n_samples, int n_classes, int sample_itr,
                               double* cumulative_sums,
                               double* cumulative_sums_prox,
                               int* feature_hist,
                               bint prox,
                               double* sum_gradient,
                               int n_iter) nogil:
    """Fold ``wscale`` into the weights and restart the lazy-update state.

    wscale = (1 - step_size * alpha) ** (n_iter * n_samples + sample_itr)
    can become very small, so it is brought back to 1.0 every n_samples
    iterations for numerical stability. Before the scale can be changed,
    every coefficient has to receive its pending updates and the
    just-in-time update system has to be reset; both are delegated to
    ``lagged_update64`` called in reset mode over all ``n_features``
    features. This also limits the size of `cumulative_sums`.

    ``wscale`` is a pointer: the pointed-to value is set to 1.0 only when the
    lagged update succeeds.

    Returns
    -------
    int
        0 on success, otherwise the non-zero status reported by
        ``lagged_update64``.
    """
    # All features are visited (no index array is needed, hence NULL) and
    # the pending steps are counted up to and including ``sample_itr``.
    cdef int flush_status = lagged_update64(
        weights, wscale[0], n_features, n_samples, n_classes,
        sample_itr + 1, cumulative_sums, cumulative_sums_prox,
        feature_hist, prox, sum_gradient, NULL, True, n_iter)

    # on failure, leave wscale untouched and propagate the error code
    if flush_status != 0:
        return flush_status

    wscale[0] = 1.0
    return 0

cdef int scale_weights32(float* weights, float* wscale,
                               int n_features,
                               int n_samples, int n_classes, int sample_itr,
                               float* cumulative_sums,
                               float* cumulative_sums_prox,
                               int* feature_hist,
                               bint prox,
                               float* sum_gradient,
                               int n_iter) nogil:
    """Fold ``wscale`` into the weights and restart the lazy-update state.

    wscale = (1 - step_size * alpha) ** (n_iter * n_samples + sample_itr)
    can become very small, so it is brought back to 1.0 every n_samples
    iterations for numerical stability. Before the scale can be changed,
    every coefficient has to receive its pending updates and the
    just-in-time update system has to be reset; both are delegated to
    ``lagged_update32`` called in reset mode over all ``n_features``
    features. This also limits the size of `cumulative_sums`.

    ``wscale`` is a pointer: the pointed-to value is set to 1.0 only when the
    lagged update succeeds.

    Returns
    -------
    int
        0 on success, otherwise the non-zero status reported by
        ``lagged_update32``.
    """
    # All features are visited (no index array is needed, hence NULL) and
    # the pending steps are counted up to and including ``sample_itr``.
    cdef int flush_status = lagged_update32(
        weights, wscale[0], n_features, n_samples, n_classes,
        sample_itr + 1, cumulative_sums, cumulative_sums_prox,
        feature_hist, prox, sum_gradient, NULL, True, n_iter)

    # on failure, leave wscale untouched and propagate the error code
    if flush_status != 0:
        return flush_status

    wscale[0] = 1.0
    return 0

cdef int lagged_update64(double* weights, double wscale, int xnnz,
                               int n_samples, int n_classes, int sample_itr,
                               double* cumulative_sums,
                               double* cumulative_sums_prox,
                               int* feature_hist,
                               bint prox,
                               double* sum_gradient,
                               int* x_ind_ptr,
                               bint reset,
                               int n_iter) nogil:
    """Apply the pending just-in-time (lagged) updates to a set of features.

    The steps that have not yet been applied to a feature are reconstructed
    from ``cumulative_sums`` (and ``cumulative_sums_prox`` when ``prox`` is
    set) together with ``feature_hist``, which stores for each feature the
    iteration at which it was last brought up to date. See the original SAGA
    paper (Defazio et al. 2014) for details.

    Parameters
    ----------
    weights : pointer
        Weights, accessed at ``feature * n_classes + class``. Updated in
        place.

    wscale : double
        Current scale of the weights. Only used when ``reset`` is true, in
        which case each visited weight is multiplied by it.

    xnnz : int
        Number of features to visit.

    n_samples : int
        Used to wrap ``sample_itr`` when storing it in ``feature_hist`` in
        reset mode.

    n_classes : int
        Number of weights per feature.

    sample_itr : int
        Pending steps are accumulated up to index ``sample_itr - 1``.

    cumulative_sums, cumulative_sums_prox : pointer
        Running sums of the gradient and proximal step sizes.
        ``cumulative_sums_prox`` is only read when ``prox`` is true.

    feature_hist : pointer
        Iteration of the last update of each feature. Updated in place.

    prox : bint
        Whether proximal steps are interleaved with the gradient steps.

    sum_gradient : pointer
        Sum of gradients, indexed like ``weights``.

    x_ind_ptr : pointer
        Indices of the features to visit. Not read when ``reset`` is true;
        features ``0 .. xnnz - 1`` are visited instead.

    reset : bint
        If true, additionally multiply every visited weight by ``wscale``,
        check it for finiteness, and zero the entry ``sample_itr - 1`` of
        the cumulative sums. ``wscale`` itself is passed by value and is not
        modified here.

    n_iter : int
        Not used in the body.

    Returns
    -------
    int
        0 on success, -1 as soon as a non-finite weight is found in reset
        mode.
    """
    cdef int pos, feat, k, w_idx, row_start, step_ind, oldest_ind, prev_ind
    cdef double pending_grad, pending_prox, step_grad, step_prox

    for pos in range(xnnz):
        # pick the feature: sequential in reset mode, otherwise taken from
        # the index array
        if reset:
            feat = pos
        else:
            feat = x_ind_ptr[pos]
        row_start = feat * n_classes
        prev_ind = feature_hist[feat]

        # total step accumulated since this feature was last updated
        pending_grad = cumulative_sums[sample_itr - 1]
        if prox:
            pending_prox = cumulative_sums_prox[sample_itr - 1]
        if prev_ind != 0:
            pending_grad -= cumulative_sums[prev_ind - 1]
            if prox:
                pending_prox -= cumulative_sums_prox[prev_ind - 1]

        for k in range(n_classes):
            w_idx = row_start + k
            if not prox:
                # plain gradient catch-up in a single step
                weights[w_idx] -= pending_grad * sum_gradient[w_idx]
            elif fabs(sum_gradient[w_idx] * pending_grad) < pending_prox:
                # In this case, all the gradient steps followed by all the
                # proximal steps can be performed at once, which is more
                # efficient than unrolling all the lagged updates.
                # Idea taken from scikit-learn-contrib/lightning.
                weights[w_idx] -= pending_grad * sum_gradient[w_idx]
                weights[w_idx] = _soft_thresholding64(weights[w_idx],
                                                      pending_prox)
            else:
                # replay the missed iterations one by one, from the most
                # recent one back to the last update of this feature
                oldest_ind = prev_ind
                if oldest_ind == -1:
                    oldest_ind = sample_itr - 1
                for step_ind in range(sample_itr - 1,
                                      oldest_ind - 1, -1):
                    if step_ind > 0:
                        step_grad = (cumulative_sums[step_ind]
                                     - cumulative_sums[step_ind - 1])
                        step_prox = (cumulative_sums_prox[step_ind]
                                     - cumulative_sums_prox[step_ind - 1])
                    else:
                        step_grad = cumulative_sums[step_ind]
                        step_prox = cumulative_sums_prox[step_ind]
                    weights[w_idx] -= sum_gradient[w_idx] * step_grad
                    weights[w_idx] = _soft_thresholding64(weights[w_idx],
                                                          step_prox)

            if reset:
                weights[w_idx] *= wscale
                # check that the weight is not inf or NaN; returning here
                # does not require the gil as the return type is a C integer
                if not skl_isfinite64(weights[w_idx]):
                    return -1

        # remember when this feature was brought up to date
        if not reset:
            feature_hist[feat] = sample_itr
        else:
            feature_hist[feat] = sample_itr % n_samples

    if not reset:
        return 0

    cumulative_sums[sample_itr - 1] = 0.0
    if prox:
        cumulative_sums_prox[sample_itr - 1] = 0.0
    return 0

cdef int lagged_update32(float* weights, float wscale, int xnnz,
                               int n_samples, int n_classes, int sample_itr,
                               float* cumulative_sums,
                               float* cumulative_sums_prox,
                               int* feature_hist,
                               bint prox,
                               float* sum_gradient,
                               int* x_ind_ptr,
                               bint reset,
                               int n_iter) nogil:
    """Apply the pending just-in-time (lagged) updates to a set of features.

    The steps that have not yet been applied to a feature are reconstructed
    from ``cumulative_sums`` (and ``cumulative_sums_prox`` when ``prox`` is
    set) together with ``feature_hist``, which stores for each feature the
    iteration at which it was last brought up to date. See the original SAGA
    paper (Defazio et al. 2014) for details.

    Parameters
    ----------
    weights : pointer
        Weights, accessed at ``feature * n_classes + class``. Updated in
        place.

    wscale : float
        Current scale of the weights. Only used when ``reset`` is true, in
        which case each visited weight is multiplied by it.

    xnnz : int
        Number of features to visit.

    n_samples : int
        Used to wrap ``sample_itr`` when storing it in ``feature_hist`` in
        reset mode.

    n_classes : int
        Number of weights per feature.

    sample_itr : int
        Pending steps are accumulated up to index ``sample_itr - 1``.

    cumulative_sums, cumulative_sums_prox : pointer
        Running sums of the gradient and proximal step sizes.
        ``cumulative_sums_prox`` is only read when ``prox`` is true.

    feature_hist : pointer
        Iteration of the last update of each feature. Updated in place.

    prox : bint
        Whether proximal steps are interleaved with the gradient steps.

    sum_gradient : pointer
        Sum of gradients, indexed like ``weights``.

    x_ind_ptr : pointer
        Indices of the features to visit. Not read when ``reset`` is true;
        features ``0 .. xnnz - 1`` are visited instead.

    reset : bint
        If true, additionally multiply every visited weight by ``wscale``,
        check it for finiteness, and zero the entry ``sample_itr - 1`` of
        the cumulative sums. ``wscale`` itself is passed by value and is not
        modified here.

    n_iter : int
        Not used in the body.

    Returns
    -------
    int
        0 on success, -1 as soon as a non-finite weight is found in reset
        mode.
    """
    cdef int pos, feat, k, w_idx, row_start, step_ind, oldest_ind, prev_ind
    cdef float pending_grad, pending_prox, step_grad, step_prox

    for pos in range(xnnz):
        # pick the feature: sequential in reset mode, otherwise taken from
        # the index array
        if reset:
            feat = pos
        else:
            feat = x_ind_ptr[pos]
        row_start = feat * n_classes
        prev_ind = feature_hist[feat]

        # total step accumulated since this feature was last updated
        pending_grad = cumulative_sums[sample_itr - 1]
        if prox:
            pending_prox = cumulative_sums_prox[sample_itr - 1]
        if prev_ind != 0:
            pending_grad -= cumulative_sums[prev_ind - 1]
            if prox:
                pending_prox -= cumulative_sums_prox[prev_ind - 1]

        for k in range(n_classes):
            w_idx = row_start + k
            if not prox:
                # plain gradient catch-up in a single step
                weights[w_idx] -= pending_grad * sum_gradient[w_idx]
            elif fabs(sum_gradient[w_idx] * pending_grad) < pending_prox:
                # In this case, all the gradient steps followed by all the
                # proximal steps can be performed at once, which is more
                # efficient than unrolling all the lagged updates.
                # Idea taken from scikit-learn-contrib/lightning.
                weights[w_idx] -= pending_grad * sum_gradient[w_idx]
                weights[w_idx] = _soft_thresholding32(weights[w_idx],
                                                      pending_prox)
            else:
                # replay the missed iterations one by one, from the most
                # recent one back to the last update of this feature
                oldest_ind = prev_ind
                if oldest_ind == -1:
                    oldest_ind = sample_itr - 1
                for step_ind in range(sample_itr - 1,
                                      oldest_ind - 1, -1):
                    if step_ind > 0:
                        step_grad = (cumulative_sums[step_ind]
                                     - cumulative_sums[step_ind - 1])
                        step_prox = (cumulative_sums_prox[step_ind]
                                     - cumulative_sums_prox[step_ind - 1])
                    else:
                        step_grad = cumulative_sums[step_ind]
                        step_prox = cumulative_sums_prox[step_ind]
                    weights[w_idx] -= sum_gradient[w_idx] * step_grad
                    weights[w_idx] = _soft_thresholding32(weights[w_idx],
                                                          step_prox)

            if reset:
                weights[w_idx] *= wscale
                # check that the weight is not inf or NaN; returning here
                # does not require the gil as the return type is a C integer
                if not skl_isfinite32(weights[w_idx]):
                    return -1

        # remember when this feature was brought up to date
        if not reset:
            feature_hist[feat] = sample_itr
        else:
            feature_hist[feat] = sample_itr % n_samples

    if not reset:
        return 0

    cumulative_sums[sample_itr - 1] = 0.0
    if prox:
        cumulative_sums_prox[sample_itr - 1] = 0.0
    return 0

cdef void predict_sample64(double* x_data_ptr, int* x_ind_ptr, int xnnz,
                                 double* w_data_ptr, double wscale,
                                 double* intercept, double* prediction,
                                 int n_classes) nogil:
    """Evaluate the linear model on one sparse sample for every class.

    For each class the result written to ``prediction`` is
    ``wscale * <w, x> + intercept``, where the dot product only runs over
    the stored (non-zero) entries of x.

    Parameters
    ----------
    x_data_ptr : pointer
        Values of the stored entries of the sample x

    x_ind_ptr : pointer
        Feature indices of the stored entries of the sample x

    xnnz : int
        Number of stored entries of the sample x

    w_data_ptr : pointer
        Weights w, read at ``feature * n_classes + class``

    wscale : double
        Scale applied to the weights w

    intercept : pointer
        Intercepts, one per class

    prediction : pointer
        Output buffer receiving one value per class

    n_classes : int
        Number of classes in multinomial case. Equals 1 in binary case.

    """
    cdef int k, nz_pos
    cdef double dot

    for k in range(n_classes):
        # accumulate the unscaled dot product for class k
        dot = 0.0
        for nz_pos in range(xnnz):
            dot += (w_data_ptr[x_ind_ptr[nz_pos] * n_classes + k] *
                    x_data_ptr[nz_pos])

        prediction[k] = wscale * dot + intercept[k]


cdef void predict_sample32(float* x_data_ptr, int* x_ind_ptr, int xnnz,
                                 float* w_data_ptr, float wscale,
                                 float* intercept, float* prediction,
                                 int n_classes) nogil:
    """Evaluate the linear model on one sparse sample for every class.

    For each class the result written to ``prediction`` is
    ``wscale * <w, x> + intercept``, where the dot product only runs over
    the stored (non-zero) entries of x.

    Parameters
    ----------
    x_data_ptr : pointer
        Values of the stored entries of the sample x

    x_ind_ptr : pointer
        Feature indices of the stored entries of the sample x

    xnnz : int
        Number of stored entries of the sample x

    w_data_ptr : pointer
        Weights w, read at ``feature * n_classes + class``

    wscale : float
        Scale applied to the weights w

    intercept : pointer
        Intercepts, one per class

    prediction : pointer
        Output buffer receiving one value per class

    n_classes : int
        Number of classes in multinomial case. Equals 1 in binary case.

    """
    cdef int k, nz_pos
    cdef float dot

    for k in range(n_classes):
        # accumulate the unscaled dot product for class k
        dot = 0.0
        for nz_pos in range(xnnz):
            dot += (w_data_ptr[x_ind_ptr[nz_pos] * n_classes + k] *
                    x_data_ptr[nz_pos])

        prediction[k] = wscale * dot + intercept[k]



def _multinomial_grad_loss_all_samples(
        SequentialDataset64 dataset,
        np.ndarray[double, ndim=2, mode='c'] weights_array,
        np.ndarray[double, ndim=1, mode='c'] intercept_array,
        int n_samples, int n_features, int n_classes):
    """Accumulate the multinomial loss and gradient over ``n_samples`` samples.

    Samples are read sequentially with ``dataset.next``. Used for testing
    purpose only.

    Returns
    -------
    sum_loss : double
        Sum of the per-sample losses.

    sum_gradient_array : ndarray of shape (n_features, n_classes)
        Sum of the per-sample gradients with respect to the weights.
    """
    # raw views on the model parameters
    cdef double* intercept = <double * >intercept_array.data
    cdef double* weights = <double * >weights_array.data

    # outputs of dataset.next for the current sample
    cdef double *x_data_ptr = NULL
    cdef int *x_ind_ptr = NULL
    cdef int xnnz = -1
    cdef double y
    cdef double sample_weight

    # the weights are used as they are, without rescaling
    cdef double wscale = 1.0
    cdef int sample_pos, nz_pos, k, row_start
    cdef double x_val
    cdef double sum_loss = 0.0

    cdef MultinomialLogLoss64 multiloss = MultinomialLogLoss64()

    # per-class work buffers for the current sample
    cdef np.ndarray[double, ndim=1] prediction_array = \
        np.zeros(n_classes, dtype=np.double, order="c")
    cdef double* prediction = <double*> prediction_array.data

    cdef np.ndarray[double, ndim=1] gradient_array = \
        np.zeros(n_classes, dtype=np.double, order="c")
    cdef double* gradient = <double*> gradient_array.data

    # accumulator returned to the caller
    cdef np.ndarray[double, ndim=2] sum_gradient_array = \
        np.zeros((n_features, n_classes), dtype=np.double, order="c")
    cdef double* sum_gradient = <double*> sum_gradient_array.data

    with nogil:
        for sample_pos in range(n_samples):
            # fetch the next sample of the dataset
            dataset.next(&x_data_ptr, &x_ind_ptr, &xnnz,
                         &y, &sample_weight)

            # class scores of the multinomial classifier for this sample
            predict_sample64(x_data_ptr, x_ind_ptr, xnnz, weights, wscale,
                             intercept, prediction, n_classes)

            # gradient with respect to the scores, then the loss itself
            multiloss._dloss(prediction, y, n_classes, sample_weight, gradient)
            sum_loss += multiloss._loss(prediction, y, n_classes, sample_weight)

            # scatter the gradient onto the rows of the non-zero features
            for nz_pos in range(xnnz):
                row_start = x_ind_ptr[nz_pos] * n_classes
                x_val = x_data_ptr[nz_pos]
                for k in range(n_classes):
                    sum_gradient[row_start + k] += gradient[k] * x_val

    return sum_loss, sum_gradient_array