File: SparseMatrix-mult.R

package info (click to toggle)
r-bioc-sparsearray 1.6.2%2Bdfsg-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,768 kB
  • sloc: ansic: 16,138; makefile: 2
file content (215 lines) | stat: -rw-r--r-- 6,848 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
### =========================================================================
### SparseMatrix crossprod(), tcrossprod(), and %*%
### -------------------------------------------------------------------------
###


### - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
### crossprod()
###

### Validate the type() of the input(s) to a crossprod-like operation.
### At the moment only "double" and "integer" inputs are accepted.
### TODO: Add "complex" and "logical" later (the error message below will
### need to be extended accordingly).
### Returns NULL invisibly when 'type' is accepted, and signals an error
### otherwise.
.check_crossprod_input_type <- function(type)
{
    accepted_types <- c("double", "integer")
    if (type %in% accepted_types)
        return(invisible(NULL))
    stop(wmsg("input objects must be of type() \"double\" or \"integer\""))
}

### Compute crossprod(x, y), i.e. t(x) %*% y, where 'x' is a SparseMatrix
### derivative and 'y' an ordinary matrix.
### When 'transpose.y' is TRUE, compute crossprod(x, t(y)) instead. The
### transposition of 'y' is not done here: it is requested from the C
### routine via the 'transpose.y' flag.
###   x:           a SparseMatrix derivative; coerced to SVT_SparseMatrix
###                if it is not one already.
###   y:           an ordinary matrix.
###   transpose.y: TRUE or FALSE.
### Signals "non-conformable arguments" when the inner dimensions don't
### match. The result is produced by C_crossprod2_SVT_mat with a requested
### type of "double" (NOTE(review): presumably an ordinary dense matrix --
### confirm against the C code).
.crossprod2_SparseMatrix_matrix <- function(x, y, transpose.y=FALSE)
{
    if (is(x, "SVT_SparseMatrix")) {
        check_svt_version(x)
    } else {
        x <- as(x, "SVT_SparseMatrix")
    }
    stopifnot(is.matrix(y), isTRUEorFALSE(transpose.y))
    ## The inner dimension is nrow(x) on the left, and nrow(y) (or ncol(y)
    ## when 'y' is to be transposed) on the right.
    if (transpose.y) {
        if (nrow(x) != ncol(y))
            stop(wmsg("non-conformable arguments"))
        ans_dimnames <- list(colnames(x), rownames(y))
    } else {
        if (nrow(x) != nrow(y))
            stop(wmsg("non-conformable arguments"))
        ans_dimnames <- list(colnames(x), colnames(y))
    }
    ## Both operands must share a supported type before going to C.
    if (type(x) == type(y)) {
        .check_crossprod_input_type(type(x))
    } else {
        ## Let R's usual coercion rules pick the common type
        ## (e.g. "integer" combined with "double" gives "double").
        xy_type <- type(c(vector(type(x)), vector(type(y))))
        .check_crossprod_input_type(xy_type)
        type(x) <- type(y) <- xy_type
    }
    ## A "double" result is requested regardless of the input type.
    ans_type <- "double"
    ans_dimnames <- S4Arrays:::simplify_NULL_dimnames(ans_dimnames)
    SparseArray.Call("C_crossprod2_SVT_mat",
                     x@dim, x@type, x@SVT, y, transpose.y,
                     ans_type, ans_dimnames)
}

### Compute crossprod(x, y), i.e. t(x) %*% y, where 'x' is an ordinary
### matrix and 'y' a SparseMatrix derivative.
### When 'transpose.x' is TRUE, compute crossprod(t(x), y) instead, which
### is x %*% y. The transposition of 'x' is not done here: it is requested
### from the C routine via the 'transpose.x' flag.
###   x:           an ordinary matrix.
###   y:           a SparseMatrix derivative; coerced to SVT_SparseMatrix
###                if it is not one already.
###   transpose.x: TRUE or FALSE.
### Signals "non-conformable arguments" when the inner dimensions don't
### match. The result is produced by C_crossprod2_mat_SVT with a requested
### type of "double" (NOTE(review): presumably an ordinary dense matrix --
### confirm against the C code).
.crossprod2_matrix_SparseMatrix <- function(x, y, transpose.x=FALSE)
{
    if (is(y, "SVT_SparseMatrix")) {
        check_svt_version(y)
    } else {
        y <- as(y, "SVT_SparseMatrix")
    }
    stopifnot(is.matrix(x), isTRUEorFALSE(transpose.x))
    ## The inner dimension is nrow(y) on the right, and nrow(x) (or ncol(x)
    ## when 'x' is to be transposed) on the left.
    if (transpose.x) {
        if (ncol(x) != nrow(y))
            stop(wmsg("non-conformable arguments"))
        ans_dimnames <- list(rownames(x), colnames(y))
    } else {
        if (nrow(x) != nrow(y))
            stop(wmsg("non-conformable arguments"))
        ans_dimnames <- list(colnames(x), colnames(y))
    }
    ## Both operands must share a supported type before going to C.
    if (type(x) == type(y)) {
        .check_crossprod_input_type(type(x))
    } else {
        ## Let R's usual coercion rules pick the common type
        ## (e.g. "integer" combined with "double" gives "double").
        xy_type <- type(c(vector(type(x)), vector(type(y))))
        .check_crossprod_input_type(xy_type)
        type(x) <- type(y) <- xy_type
    }
    ## A "double" result is requested regardless of the input type.
    ans_type <- "double"
    ans_dimnames <- S4Arrays:::simplify_NULL_dimnames(ans_dimnames)
    SparseArray.Call("C_crossprod2_mat_SVT",
                     x, y@dim, y@type, y@SVT, transpose.x,
                     ans_type, ans_dimnames)
}

### Compute crossprod(x, y), i.e. t(x) %*% y, where both operands are
### coerced to SVT_SparseMatrix if they are not SVT_SparseMatrix objects
### already.
### The 'y=NULL' default is presumably there only so that the formals
### match those of the crossprod() generic, since this function is also
### registered directly as a crossprod() method.
### NOTE(review): a NULL 'y' is passed to as(y, "SVT_SparseMatrix") like
### any other value; it does NOT mean 'y=x' here.
### Signals "non-conformable arguments" when nrow(x) != nrow(y). The result
### is produced by C_crossprod2_SVT_SVT with a requested type of "double".
.crossprod2_SparseMatrix_SparseMatrix <- function(x, y=NULL)
{
    if (is(x, "SVT_SparseMatrix")) {
        check_svt_version(x)
    } else {
        x <- as(x, "SVT_SparseMatrix")
    }
    if (is(y, "SVT_SparseMatrix")) {
        check_svt_version(y)
    } else {
        y <- as(y, "SVT_SparseMatrix")
    }
    if (nrow(x) != nrow(y))
        stop(wmsg("non-conformable arguments"))
    ## Both operands must share a supported type before going to C.
    if (type(x) == type(y)) {
        .check_crossprod_input_type(type(x))
    } else {
        ## Let R's usual coercion rules pick the common type
        ## (e.g. "integer" combined with "double" gives "double").
        xy_type <- type(c(vector(type(x)), vector(type(y))))
        .check_crossprod_input_type(xy_type)
        type(x) <- type(y) <- xy_type
    }
    ## A "double" result is requested regardless of the input type.
    ans_type <- "double"
    ans_dimnames <- list(colnames(x), colnames(y))
    ans_dimnames <- S4Arrays:::simplify_NULL_dimnames(ans_dimnames)
    SparseArray.Call("C_crossprod2_SVT_SVT",
                     x@dim, x@type, x@SVT, y@dim, y@type, y@SVT,
                     ans_type, ans_dimnames)
}

### Compute crossprod(x), i.e. t(x) %*% x, for a SparseMatrix derivative
### (coerced to SVT_SparseMatrix if it is not one already).
### 'y' must be NULL. It presumably exists only so that the formals match
### those of the crossprod() generic, since this function is also
### registered directly as a crossprod() method.
### The result is produced by C_crossprod1_SVT with a requested type of
### "double", and carries colnames(x) along both dimensions.
.crossprod1_SparseMatrix <- function(x, y=NULL)
{
    if (is(x, "SVT_SparseMatrix")) {
        check_svt_version(x)
    } else {
        x <- as(x, "SVT_SparseMatrix")
    }
    stopifnot(is.null(y))
    .check_crossprod_input_type(type(x))
    ## A "double" result is requested regardless of the input type.
    ans_type <- "double"
    ans_dimnames <- list(colnames(x), colnames(x))
    ans_dimnames <- S4Arrays:::simplify_NULL_dimnames(ans_dimnames)
    SparseArray.Call("C_crossprod1_SVT",
                     x@dim, x@type, x@SVT,
                     ans_type, ans_dimnames)
}

### crossprod(x, y) is t(x) %*% y. Operands that are not SVT_SparseMatrix
### objects are coerced by the .crossprod* helpers.

setMethod("crossprod", c("SparseMatrix", "matrix"),
    .crossprod2_SparseMatrix_matrix
)

setMethod("crossprod", c("matrix", "SparseMatrix"),
    .crossprod2_matrix_SparseMatrix
)

setMethod("crossprod", c("SparseMatrix", "SparseMatrix"),
    .crossprod2_SparseMatrix_SparseMatrix
)

### An explicit 'y=NULL' dispatches here because class "NULL" extends
### "ANY". Like base::crossprod(), treat it as 'y=x' rather than trying to
### coerce NULL to a SparseMatrix.
setMethod("crossprod", c("SparseMatrix", "ANY"),
    function(x, y=NULL)
    {
        if (is.null(y))
            return(.crossprod1_SparseMatrix(x))
        .crossprod2_SparseMatrix_SparseMatrix(x, y)
    }
)

setMethod("crossprod", c("ANY", "SparseMatrix"),
    function(x, y=NULL) .crossprod2_SparseMatrix_SparseMatrix(x, y)
)

### crossprod(x) with 'y' omitted.
setMethod("crossprod", c("SparseMatrix", "missing"),
    .crossprod1_SparseMatrix
)


### - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
### tcrossprod()
###

### tcrossprod(x, y) is x %*% t(y), computed here as crossprod(t(x), t(y)).
### For the dense operand, the transposition is requested from the helper
### via 'transpose.y=TRUE' / 'transpose.x=TRUE' instead of calling t() on it.

setMethod("tcrossprod", c("SparseMatrix", "matrix"),
    function(x, y=NULL) .crossprod2_SparseMatrix_matrix(t(x), y,
                                                        transpose.y=TRUE)
)

setMethod("tcrossprod", c("matrix", "SparseMatrix"),
    function(x, y=NULL) .crossprod2_matrix_SparseMatrix(x, t(y),
                                                        transpose.x=TRUE)
)

setMethod("tcrossprod", c("SparseMatrix", "SparseMatrix"),
    function(x, y=NULL) .crossprod2_SparseMatrix_SparseMatrix(t(x), t(y))
)

### An explicit 'y=NULL' dispatches here because class "NULL" extends
### "ANY". Like base::tcrossprod(), treat it as 'y=x' rather than calling
### t(NULL).
setMethod("tcrossprod", c("SparseMatrix", "ANY"),
    function(x, y=NULL)
    {
        if (is.null(y))
            return(.crossprod1_SparseMatrix(t(x)))
        .crossprod2_SparseMatrix_SparseMatrix(t(x), t(y))
    }
)

setMethod("tcrossprod", c("ANY", "SparseMatrix"),
    function(x, y=NULL) .crossprod2_SparseMatrix_SparseMatrix(t(x), t(y))
)

### tcrossprod(x) with 'y' omitted: x %*% t(x) == crossprod(t(x)).
setMethod("tcrossprod", c("SparseMatrix", "missing"),
    function(x, y=NULL) .crossprod1_SparseMatrix(t(x))
)


### - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
### Matrix multiplication
###

### x %*% y is computed as crossprod(t(x), y). The one exception is the
### <matrix> %*% <SparseMatrix> case: the dense left operand is not
### transposed here; the helper is asked to treat it as transposed via
### 'transpose.x=TRUE' instead.

setMethod("%*%", c("SparseMatrix", "matrix"),
    function(x, y)
    {
        left <- t(x)
        .crossprod2_SparseMatrix_matrix(left, y)
    }
)

setMethod("%*%", c("matrix", "SparseMatrix"),
    function(x, y)
    {
        .crossprod2_matrix_SparseMatrix(x, y, transpose.x=TRUE)
    }
)

setMethod("%*%", c("SparseMatrix", "SparseMatrix"),
    function(x, y)
    {
        left <- t(x)
        .crossprod2_SparseMatrix_SparseMatrix(left, y)
    }
)

### Mixed cases with some other matrix-like operand: the helper coerces
### the non-SVT operand(s) to SVT_SparseMatrix.
setMethod("%*%", c("SparseMatrix", "ANY"),
    function(x, y)
    {
        left <- t(x)
        .crossprod2_SparseMatrix_SparseMatrix(left, y)
    }
)

setMethod("%*%", c("ANY", "SparseMatrix"),
    function(x, y)
    {
        left <- t(x)
        .crossprod2_SparseMatrix_SparseMatrix(left, y)
    }
)