File: linear_model_fit.h

package info (click to toggle)
r-bioc-scuttle 1.0.4%2Bdfsg-5
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 728 kB
  • sloc: cpp: 356; sh: 17; makefile: 2
file content (84 lines) | stat: -rw-r--r-- 2,180 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#ifndef LINEAR_MODEL_FIT_H 
#define LINEAR_MODEL_FIT_H

#include "Rcpp.h"
#include "R_ext/BLAS.h"
#include "R_ext/Lapack.h"

#include <stdexcept>

namespace scuttle {

// Writing R Extensions asks C/C++ callers of Fortran routines to append one
// FCONE per character argument, so that the hidden string lengths are passed
// when the build defines USE_FC_LEN_T. R headers that predate the macro do not
// define it, so fall back to an empty expansion.
#ifndef FCONE
#define FCONE
#endif

/**
 * Multiplies single columns by the orthogonal factor Q of a QR decomposition,
 * via LAPACK's dormqr.
 *
 * The decomposition is supplied in compact form: 'qr' and 'qraux' are handed
 * to dormqr as its A and TAU arguments, presumably as produced by an R-level
 * qr() call (verify against the caller). The constructor performs a dormqr
 * workspace query once, so that multiply() can be called repeatedly without
 * reallocating.
 */
class QR_multiplier {
public:
    /**
     * @param qr Compact QR matrix; its row count is taken as the number of
     *   observations and its column count as the number of coefficients.
     * @param qraux Auxiliary vector, which must have one entry per column of 'qr'.
     * @param tr Passed to dormqr as TRANS: 'T' applies the transpose of Q, 'N' applies Q itself.
     * @throws std::runtime_error if the dimensions are inconsistent or the workspace query fails.
     */
    QR_multiplier(Rcpp::NumericMatrix qr, Rcpp::NumericVector qraux, const char tr='T') :
        QR(qr), AUX(qraux), qrptr(QR.begin()), qxptr(AUX.begin()), nobs(QR.nrow()), ncoef(QR.ncol()), trans(tr)
    {
        if (AUX.size()!=ncoef) {
            throw std::runtime_error("QR auxiliary vector should be of length 'ncol(Q)'");
        }

        // dormqr documents 0 <= K <= M when SIDE='L'. Reject a violation here as a
        // C++ exception instead of leaving it to LAPACK's own argument-error handler.
        if (ncoef > nobs) {
            throw std::runtime_error("QR decomposition should not have more columns than rows");
        }

        // Workspace query: with lwork=-1, dormqr only reports the optimal workspace
        // size in its WORK argument. 'work' temporarily stands in for the C argument.
        work.resize(nobs);
        double optimal=0;
        F77_CALL(dormqr)(&side, &trans, &nobs, &ncol, &ncoef, qrptr, &nobs,
            qxptr, work.data(), &nobs, &optimal, &lwork, &info FCONE FCONE);
        if (info) {
            throw std::runtime_error("workspace query failed for 'dormqr'");
        }

        // The size comes back as a double; round to the nearest integer.
        lwork=static_cast<int>(optimal+0.5);
        work.resize(lwork);
    }

    /**
     * Overwrites 'rhs' in place with its product with Q or its transpose,
     * according to the 'tr' value given at construction.
     *
     * @param rhs Array that is read and written as a single column with 'nobs' rows.
     * @throws std::runtime_error if dormqr reports a non-zero status.
     */
    void multiply(double* rhs) {
        F77_CALL(dormqr)(&side, &trans, &nobs, &ncol, &ncoef, qrptr, &nobs,
            qxptr, rhs, &nobs, work.data(), &lwork, &info FCONE FCONE);
        if (info) {
            throw std::runtime_error("residual calculations failed for 'dormqr'");
        }
    }

    /** @return Number of rows of the QR matrix. */
    int get_nobs() const {
        return nobs;
    }

    /** @return Number of columns of the QR matrix. */
    int get_ncoefs() const {
        return ncoef;
    }
protected:
    // The Rcpp objects are stored alongside the raw pointers taken from them.
    Rcpp::NumericMatrix QR;
    Rcpp::NumericVector AUX;
    const double* qrptr, * qxptr;
    const int nobs, ncoef;
    const char trans;

    // LAPACK status flag and workspace; lwork starts at -1 to request the size query.
    int info=0, lwork=-1;
    std::vector<double> work;

    // Fixed dormqr arguments: one right-hand-side column, Q applied from the left.
    const int ncol=1;
    const char side='L';
};

// Writing R Extensions asks C/C++ callers of Fortran routines to append one
// FCONE per character argument, so that the hidden string lengths are passed
// when the build defines USE_FC_LEN_T. R headers that predate the macro do not
// define it, so fall back to an empty expansion.
#ifndef FCONE
#define FCONE
#endif

/**
 * Extends QR_multiplier (fixed to apply the transpose of Q) with a
 * triangular solve against the upper-triangular part of the compact QR matrix.
 */
class linear_model_fit : public QR_multiplier {
public:
    /**
     * @param qr Compact QR matrix, forwarded to QR_multiplier.
     * @param qraux Auxiliary vector, forwarded to QR_multiplier.
     */
    linear_model_fit(Rcpp::NumericMatrix qr, Rcpp::NumericVector qraux) : QR_multiplier(qr, qraux, 'T') {}

    /**
     * Solves the upper-triangular system in place with LAPACK's dtrtrs.
     * The system is the leading ncoef-by-ncoef block of the QR matrix
     * (leading dimension 'nobs'), applied to a single right-hand side.
     *
     * @param rhs Right-hand side, passed with leading dimension 'nobs'.
     *   dtrtrs overwrites its leading 'ncoef' entries with the solution.
     *   Presumably this is the output of multiply() - verify against callers.
     * @throws std::runtime_error if dtrtrs reports a non-zero status,
     *   which per its documentation includes a zero on the diagonal.
     */
    void solve(double* rhs) {
        F77_CALL(dtrtrs)(&uplo, &xtrans, &diag, &(this->ncoef), &(this->ncol), this->qrptr, &(this->nobs),
            rhs, &(this->nobs), &(this->info) FCONE FCONE FCONE);

        if (this->info) {
            throw std::runtime_error("coefficient calculations failed for 'dtrtrs'");
        }
    }
private:
    // Fixed dtrtrs arguments: upper triangle, no transpose, non-unit diagonal.
    const char uplo='U', xtrans='N', diag='N';
};

}

#endif