File: tensor.scm

package info (click to toggle)
jacal 1c8-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, forky, sid, trixie
  • size: 1,064 kB
  • sloc: lisp: 6,648; sh: 419; makefile: 315
file content (283 lines) | stat: -rw-r--r-- 7,326 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
;; TENSOR.SCM -- Tensor-like support functions for JACAL
;; Copyright (C) 1993 Jerry D. Hedden
;;
;; This program is free software; you can redistribute it and/or modify
;; it under the terms of the GNU General Public License as published by
;; the Free Software Foundation, either version 3 of the License, or (at
;; your option) any later version.
;; 
;; This program is distributed in the hope that it will be useful, but
;; WITHOUT ANY WARRANTY; without even the implied warranty of
;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
;; General Public License for more details.
;; 
;; You should have received a copy of the GNU General Public License
;; along with this program; if not, write to the Free Software
;; Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.

(require 'common-list-functions)

; Does not support the notion of contra-/covariant indices.
;  Users must keep track of this information themselves.

; Assumes that all matrices are "proper" (i.e., that all "dimensions"
;  of the matrix are the same length, e.g., 4x4x4) and "compatible"
;  (e.g., a 3x3 matrix is not compatible with a 4x4x4 matrix).


; One-line descriptions for the four tensor commands that are installed
;  with DEFBLTN further down in this file.  DEFINFO itself is defined
;  elsewhere in JACAL (presumably it feeds the interactive help/info
;  facility -- confirm against its definition).
(definfo 'indexshift
  "Shifts an index within a tensor")

(definfo 'indexswap
  "Swaps two indices within a tensor")

(definfo 'contract
  "Tensor contraction")

(definfo 'tmult
  "Tensor multiplication")


; TNSR:RANK -- nesting depth of M, i.e. how many times CAR can be taken
;  while the value is still a bunch.  Only the first element at each level
;  is inspected, so this relies on every element of a level having the
;  same shape.  A non-bunch (scalar) has rank 0.
(define (tnsr:rank m)
  (do ((probe m (car probe))
       (depth 0 (+ 1 depth)))
      ((not (bunch? probe)) depth)))


(defbltn 'indexshift 1 #f
  (lambda (m . args)
    ; INDEXSHIFT -- move one index of M to a new position, sliding the
    ;  indices in between over by one place.
    ;  args: optional FROM and TO positions (1-based); TO defaults to
    ;        FROM+1.  Each is clamped into [1, rank], and if they end up
    ;        equal they are forced one apart.
    ;  Rank 0 is returned unchanged, rank 1 becomes a bunch of one-element
    ;  bunches, and rank 2 (or any call without positions) is transposed.
    (let ((rank  (tnsr:rank m)))
      (define (transpose x) (apply map list x))
      (define (clamp k)
	(cond ((< k 1) 1)
	      ((> k rank) rank)
	      (else k)))
      (cond
	((= rank 0) m)
	((= rank 1) (map list m))
	((or (= rank 2) (null? args)) (transpose m))
	(else
	 (let ((from  (clamp (car args)))
	       (to    (clamp (if (null? (cdr args))
				 (+ 1 (car args))
				 (cadr args)))))
	   (if (= from to)
	       (if (= from rank)
		   (set! from (+ -1 to))
		   (set! to (+ 1 from))))
	   ; In both walkers N is the 1-based position of T's outermost
	   ;  index within the whole tensor.
	   (letrec
	       ((shift-right
		 ; descend to FROM, then carry that index down to TO by
		 ;  transposing on the way in
		 (lambda (t n)
		   (cond ((< n from)
			  (map (lambda (s) (shift-right s (+ 1 n))) t))
			 ((= (+ 1 n) to)
			  (transpose t))
			 (else
			  (map (lambda (s) (shift-right s (+ 1 n)))
			       (transpose t))))))
		(shift-left
		 ; descend to TO, then carry the index at FROM up to TO by
		 ;  transposing on the way out
		 (lambda (t n)
		   (cond ((< n to)
			  (map (lambda (s) (shift-left s (+ 1 n))) t))
			 ((= (+ 1 n) from)
			  (transpose t))
			 (else
			  (transpose
			   (map (lambda (s) (shift-left s (+ 1 n))) t)))))))
	     (if (< from to)
		 (shift-right m 1)
		 (shift-left m 1)))))))))


(defbltn 'indexswap 1 #f
  (lambda (m . args)
    ; INDEXSWAP -- exchange two indices of M.
    ;  args: optional positions I and J (1-based); J defaults to I+1.
    ;        The pair is put in increasing order, each end is clamped
    ;        into [1, rank], and a pair that collapses onto one position
    ;        is pushed one apart.
    ;  Rank 0 is returned unchanged, rank 1 becomes a bunch of one-element
    ;  bunches, and rank 2 (or any call without indices) is transposed.
    (let ((rank  (tnsr:rank m)))
      (cond
	((= rank 0)
	   ; scalar -- no-op
	   m)
	((= rank 1)
	   ; vector -- returns a pseudo-tensor
	   (map list m))
	((or (= rank 2) (null? args))
	   ; matrix -- transpose
	   (apply map list m))
	(else
	   ; tensor
	   ; set and constrain the indices to be swapped
	   (let* ((a  (car args))
		  (b  (if (null? (cdr args))
			(+ 1 a)
			(cadr args))))
	     (if (< b a)
	       (let ((c  a))
		 (set! a b)
		 (set! b c)))
	     ; Clamp each end on its own.  Both ends must always be checked:
	     ;  an out-of-range pair such as (0 5) on a rank-3 tensor would
	     ;  otherwise leave B past the last index, and the walk below
	     ;  would end up transposing scalars.
	     (cond ((< a 1)    (set! a 1))
		   ((> a rank) (set! a rank)))
	     (cond ((< b 1)    (set! b 1))
		   ((> b rank) (set! b rank)))
	     ; the two positions must be distinct
	     (if (= a b)
	       (if (= a rank)
		 (set! a (+ -1 b))
		 (set! b (+ 1 a))))

	     ; perform the swapping operation: descend to A, then walk from
	     ;  A+1 to B, transposing on the way in (carrying index A down
	     ;  to B) and again on the way out (carrying index B up to A)
	     (let swap1 ((ma  m)
			 (n   1))
	       (if (< n a)
		   (map (lambda (mm) (swap1 mm (+ 1 n))) ma)
		   (let swap2 ((mb  ma)
			       (aa  (+ 1 a)))
		     (if (= aa b)
			 (apply map list mb)
			 (apply map list
				(map (lambda (mm) (swap2 mm (+ 1 aa)))
				     (apply map list mb)))))))))))))


; TNSR:CONTRACT -- trace of the bunch-of-rows M (assumed square), i.e.
;  M[0][0] + M[1][1] + ..., with the additions done by (app* $1+$2 ...).
;  Each step drops the first row and the first column of what is left,
;  so the next diagonal entry is always the head of the head.
(define (tnsr:contract m)
  (do ((sub  (map cdr (cdr m))  (map cdr (cdr sub)))
       (sum  (car (car m))      (app* $1+$2 sum (car (car sub)))))
      ((null? sub) sum)))


(defbltn 'contract 1 #f
  (lambda (m . args)
    ; CONTRACT -- sum M over a pair of its indices.
    ;  args: optional positions I and J (1-based).  I defaults to 1 and
    ;        J to I+1.  The pair is put in increasing order, each end is
    ;        clamped into [1, rank], and a pair that collapses onto one
    ;        position is pushed one apart.
    ;  Rank 0 is returned unchanged, a vector has its elements summed, and
    ;  a matrix yields its trace (TNSR:CONTRACT).  For higher ranks the
    ;  remaining indices keep their original relative order.
    (let ((rank  (tnsr:rank m)))
      (cond
	((= rank 0)
	   ; scalar -- no-op
	   m)
	((= rank 1)
	   ; vector -- sum elements
	   (reduce (lambda (x y) (app* $1+$2 x y)) m))
	((= rank 2)
	   ; matrix -- sum diagonal elements
	   (tnsr:contract m))
	(else
	   ; tensor
	   ; set and constrain the indices for the contraction operation;
	   ;  with no indices given, contract indices 1 and 2 rather than
	   ;  failing on (car '())
	   (let* ((a  (if (null? args) 1 (car args)))
		  (b  (if (or (null? args) (null? (cdr args)))
			(+ 1 a)
			(cadr args))))
	     (if (< b a)
	       (let ((c  a))
		 (set! a b)
		 (set! b c)))
	     ; Clamp each end on its own.  Both ends must always be checked:
	     ;  an out-of-range pair such as (0 5) on a rank-3 tensor would
	     ;  otherwise leave B past the last index.
	     (cond ((< a 1)    (set! a 1))
		   ((> a rank) (set! a rank)))
	     (cond ((< b 1)    (set! b 1))
		   ((> b rank) (set! b rank)))
	     ; the two positions must be distinct
	     (if (= a b)
	       (if (= a rank)
		 (set! a (+ -1 b))
		 (set! b (+ 1 a))))

	     ; perform the contraction operation
	     (let cxt1 ((ma  m)
			(nn  1))
	       (if (< nn a)
		 ; keep indices 1..A-1 where they are
		 (map (lambda (mm) (cxt1 mm (+ 1 nn))) ma)
		 (let cxt2 ((mb  ma)
			    (aa  (+ 1 a)))
		   (if (< aa b)
		     ; lift each of indices A+1..B-1 above index A
		     (map (lambda (mm) (cxt2 mm (+ 1 aa))) (apply map list mb))
		     (let cxt3 ((mc  mb)
				(bb  b))
		       (if (= bb rank)
			 ; only indices A and B remain: take the trace
			 (tnsr:contract mc)
			 ; lift the index following B above both A and B
			 (map (lambda (mm) (cxt3 mm (+ 1 bb)))
			      (apply map list (map (lambda (mx)
						     (apply map list mx))
						   mc)))))))))))))))


(defbltn 'tmult 2 #f
  (lambda (m1 m2 . args)
    ; TMULT -- tensor product of M1 and M2.
    ;  If either operand is a scalar (rank 0), this is plain multiplication.
    ;  With no index arguments it is the outer product: each element at
    ;   M1's innermost level is replaced by that element times M2, so the
    ;   result carries M1's indices followed by M2's.
    ;  With index arguments I [J] it is the inner product over index I of
    ;   M1 and index J of M2 (J defaults to I).  I is clamped into
    ;   [1, rank M1] and J into [1, rank M2].  The remaining indices of M1,
    ;   in order, come first, followed by those of M2.
    (let ((r1  (tnsr:rank m1))
	  (r2  (tnsr:rank m2)))
      (define (clamp k hi)
	(cond ((< k 1) 1)
	      ((> k hi) hi)
	      (else k)))
      ; Rebuild tensor T of rank R with its K-th index moved innermost,
      ;  handing each innermost vector (running along that index) to LEAF.
      ;  N is the 1-based position of T's outermost index.
      (define (reindex t n k r leaf)
	(cond ((< n k)
	       (map (lambda (s) (reindex s (+ n 1) k r leaf)) t))
	      ((< n r)
	       (map (lambda (s) (reindex s (+ n 1) k r leaf))
		    (apply map list t)))
	      (else
	       (leaf t))))
      (cond
	((or (= r1 0) (= r2 0))
	 ; scalar multiplication
	 (app* $1*$2 m1 m2))
	((null? args)
	 ; outer product
	 (let descend ((t m1) (depth 1))
	   (if (< depth r1)
	       (map (lambda (s) (descend s (+ depth 1))) t)
	       (map (lambda (x) (app* $1*$2 x m2)) t))))
	(else
	 ; inner product: dot each M1 vector along index I with each M2
	 ;  vector along index J
	 (let ((i (clamp (car args) r1))
	       (j (clamp (if (null? (cdr args)) (car args) (cadr args)) r2)))
	   (reindex m1 1 i r1
		    (lambda (v)
		      (reindex m2 1 j r2
			       (lambda (w)
				 (reduce (lambda (x y) (app* $1+$2 x y))
					 (map (lambda (x y) (app* $1*$2 x y))
					      v w))))))))))))