File: tensor.R

package info (click to toggle)
r-cran-tensor 1.5-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, bullseye, forky, sid, trixie
  • size: 76 kB
  • sloc: makefile: 2
file content (72 lines) | stat: -rw-r--r-- 1,625 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Generalized tensor (inner) product of two arrays.
#
# Contracts A and B over the paired dimensions alongA[i] <-> alongB[i],
# i.e. sums the elementwise product over those shared extents.
#
# A, B    Objects coerced with as.array(); a plain vector becomes a 1-d array.
# alongA  Indices of the dimensions of A to contract over.
# alongB  Indices of the dimensions of B to contract over. Must have the
#         same length as alongA, and dim(B)[alongB] must equal
#         dim(A)[alongA] elementwise.
#
# Returns an array whose dimensions are the uncontracted dimensions of A
# (in their original order) followed by the uncontracted dimensions of B.
# Dimnames are carried over when either input has any. With no contraction
# dimensions the result is the outer product, with dim c(dim(A), dim(B)).
# When every dimension of both inputs is contracted, the 1 x 1 result is
# dropped to a length-one vector.
"tensor" <-
function(A, B, alongA = integer(0), alongB = integer(0))
{
  A <- as.array(A)
  dimA <- dim(A)
  dnA <- dimnames(A)
  # Use placeholder NULL dimnames so that dnA[-alongA] indexes cleanly below.
  if (nnA <- is.null(dnA))
    dnA <- rep(list(NULL), length(dimA))

  B <- as.array(B)
  dimB <- dim(B)
  dnB <- dimnames(B)
  if (nnB <- is.null(dnB))
    dnB <- rep(list(NULL), length(dimB))

  if (length(alongA) != length(alongB))
    stop("\"along\" vectors must be same length")

  # Special case: no contraction dimensions -> plain outer product.
  if (length(alongA) == 0) {
    R <- as.vector(A) %*% t(as.vector(B))
    dim(R) <- c(dimA, dimB)
    if (!(nnA && nnB))
      dimnames(R) <- c(dnA, dnB)
    return(R)
  }

  # Reject NA, out-of-range or repeated indices up front. Otherwise they
  # surface later as NA comparisons or cryptic aperm() errors.
  check_along <- function(along, ndim, what) {
    if (anyNA(along) || any(along < 1 | along > ndim) ||
        anyDuplicated(along) > 0) {
      stop(sprintf("\"%s\" must hold distinct dimension indices in 1..%d",
                   what, ndim))
    }
  }
  check_along(alongA, length(dimA), "alongA")
  check_along(alongB, length(dimB), "alongB")

  if (!all(dimA[alongA] == dimB[alongB]))
    stop("Mismatch in \"along\" dimensions")

  # Move A's contracted dimensions last, then flatten to (kept, contracted).
  seqA <- seq_along(dimA)
  allA <- length(seqA) == length(alongA)
  permA <- c(seqA[-alongA], alongA)
  if (!all(seqA == permA))
    A <- aperm(A, permA)
  dim(A) <- c(
    if (allA) 1 else prod(dimA[-alongA]),
    prod(dimA[alongA])
  )

  # Move B's contracted dimensions first, then flatten to (contracted, kept).
  seqB <- seq_along(dimB)
  allB <- length(seqB) == length(alongB)
  permB <- c(alongB, seqB[-alongB])
  if (!all(seqB == permB))
    B <- aperm(B, permB)
  dim(B) <- c(
    prod(dimB[alongB]),
    if (allB) 1 else prod(dimB[-alongB])
  )

  # A single matrix product performs the whole contraction.
  R <- A %*% B

  if (allA && allB) {
    # Full contraction: collapse the 1 x 1 matrix to a length-one vector.
    R <- drop(R)
  } else {
    # Unflatten into the kept dimensions of A followed by those of B.
    dim(R) <- c(
      if (allA) integer(0) else dimA[-alongA],
      if (allB) integer(0) else dimB[-alongB]
    )
    if (!(nnA && nnB))
      dimnames(R) <- c(dnA[-alongA], dnB[-alongB])
  }
  R
}

# x %*t% y: for matrices this equals x %*% t(y). It contracts the second
# dimension of x against the second dimension of y via tensor().
`%*t%` <- function(x, y) {
  tensor(x, y, alongA = 2, alongB = 2)
}

# x %t*% y: for matrices this equals t(x) %*% y. It contracts the first
# dimension of x against the first dimension of y via tensor().
`%t*%` <- function(x, y) {
  tensor(x, y, alongA = 1, alongB = 1)
}

# x %t*t% y: for matrices this equals t(x) %*% t(y). It contracts the first
# dimension of x against the second dimension of y via tensor().
`%t*t%` <- function(x, y) {
  tensor(x, y, alongA = 1, alongB = 2)
}