File: flarray-pointwise.rkt

package info (click to toggle)
racket 7.2%2Bdfsg1-2
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 125,432 kB
  • sloc: ansic: 258,980; pascal: 59,975; sh: 33,650; asm: 13,558; lisp: 7,124; makefile: 3,329; cpp: 2,889; exp: 499; python: 274; xml: 11
file content (139 lines) | stat: -rw-r--r-- 5,386 bytes parent folder | download | duplicates (8)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#lang typed/racket/base

(require racket/flonum
         (for-syntax racket/base)
         "../flonum/flvector-syntax.rkt"
         "../flonum/flvector.rkt"
         "array-struct.rkt"
         "array-broadcast.rkt"
         "array-pointwise.rkt"
         "mutable-array.rkt"
         "flarray-struct.rkt"
         "utils.rkt")

(provide
 ;; Mapping
 inline-flarray-map
 flarray-map
 ;; Pointwise operations
 flarray-scale
 flarray-sqr
 flarray-sqrt
 flarray-abs
 flarray+
 flarray*
 flarray-
 flarray/
 flarray-min
 flarray-max)

;; ===================================================================================================
;; Mapping

;; (inline-flarray-map f arr-expr ...)
;; Macro form of mapping over FlArrays.  `f` is placed in operator position in the expansion
;; and is also handed to inline-flvector-map.
;;   * no arrays:  calls (f) once; the result has shape #() and holds that single value
;;   * one array:  maps over the array's flvector (flarray-data) and reuses its shape
;;   * 2+ arrays:  if every shape is equal? to the first one, maps over the flvectors directly;
;;                 otherwise every argument is broadcast to the shape computed by
;;                 array-shape-broadcast and the result is built index-by-index
(define-syntax (inline-flarray-map stx)
  (syntax-case stx ()
    ;; Nullary case
    [(_ fn)
     (syntax/loc stx
       (unsafe-flarray #() (flvector (fn))))]
    ;; Unary case
    [(_ fn src-expr)
     (syntax/loc stx
       (let: ([src : FlArray  src-expr])
         (unsafe-flarray (array-shape src)
                         (inline-flvector-map fn (flarray-data src)))))]
    ;; N-ary case: one fresh temporary per extra argument for the array value, its shape,
    ;; and its (broadcast) index procedure
    [(_ fn src-expr more-exprs ...)
     (with-syntax ([(more ...)        (generate-temporaries #'(more-exprs ...))]
                   [(more-shape ...)  (generate-temporaries #'(more-exprs ...))]
                   [(more-ref ...)    (generate-temporaries #'(more-exprs ...))])
       (syntax/loc stx
         (let: ([src : FlArray  src-expr]
                [more : FlArray  more-exprs] ...)
           (define src-shape (array-shape src))
           (define more-shape (array-shape more)) ...
           (cond [(and (equal? src-shape more-shape) ...)
                  ;; All shapes agree: operate on the underlying flvectors
                  (unsafe-flarray
                   src-shape
                   (inline-flvector-map fn (flarray-data src) (flarray-data more) ...))]
                 [else
                  ;; Shapes differ: broadcast everything to a common shape first
                  (define out-shape (array-shape-broadcast (list src-shape more-shape ...)))
                  (define src-ref (unsafe-array-proc (array-broadcast src out-shape)))
                  (define more-ref (unsafe-array-proc (array-broadcast more out-shape))) ...
                  (array->flarray
                   (unsafe-build-array out-shape
                                       (λ: ([ix : Indexes])
                                         (fn (src-ref ix) (more-ref ix) ...))))]))))]))

;; Procedure counterpart of inline-flarray-map.  The zero-, one- and two-array cases expand
;; inline-flarray-map directly; the variadic case broadcasts every argument to the shape
;; computed by array-shape-broadcast and builds the result index-by-index with `apply`.
(: flarray-map (case-> ((-> Float) -> FlArray)
                       ((Float -> Float) FlArray -> FlArray)
                       ((Float Float Float * -> Float) FlArray FlArray FlArray * -> FlArray)))
(define flarray-map
  (case-lambda:
    [([f : (-> Float)])
     (inline-flarray-map f)]
    [([f : (Float -> Float)] [arr : FlArray])
     (inline-flarray-map f arr)]
    [([f : (Float Float -> Float)] [arr0 : FlArray] [arr1 : FlArray])
     (inline-flarray-map f arr0 arr1)]
    [([f : (Float Float Float * -> Float)] [arr0 : FlArray] [arr1 : FlArray] . [arrs : FlArray *])
     ;; Shapes of all arguments, in argument order
     (define shapes
       (map (λ: ([a : FlArray]) (array-shape a))
            (list* arr0 arr1 arrs)))
     (define out-shape (array-shape-broadcast shapes))
     ;; Broadcast one argument to out-shape and return its index procedure
     (: broadcast-ref (FlArray -> (Indexes -> Float)))
     (define (broadcast-ref a)
       (let: ([wide : (Array Float)  (array-broadcast a out-shape)])
         (unsafe-array-proc wide)))
     (define ref0 (broadcast-ref arr0))
     (define ref1 (broadcast-ref arr1))
     (define rest-refs (map broadcast-ref arrs))
     (array->flarray
      (unsafe-build-array
       out-shape
       (λ: ([ix : Indexes])
         (apply f
                (ref0 ix)
                (ref1 ix)
                (map (λ: ([ref : (Indexes -> Float)]) (ref ix))
                     rest-refs)))))]))

;; ===================================================================================================
;; Pointwise operations

;; (lift-flvector1 f) expands to a one-argument procedure that applies `f` to the argument's
;; flvector (flarray-data) and wraps the result in an FlArray with the argument's shape.
;; `f` is placed in operator position in the expansion.
(define-syntax-rule (lift-flvector1 vec-op)
  (λ (xs)
    (unsafe-flarray (array-shape xs)
                    (vec-op (flarray-data xs)))))

;; (lift-flvector2 f array-f) expands to a two-argument procedure on FlArrays:
;;   * shapes equal?  -> apply `f` to the two flvectors and reuse the first shape
;;   * otherwise      -> apply `array-f` to the arrays themselves and convert the result
;;                       with array->flarray
(define-syntax-rule (lift-flvector2 vec-op array-op)
  (λ (lhs rhs)
    (let ([lhs-shape (array-shape lhs)]
          [rhs-shape (array-shape rhs)])
      (if (equal? lhs-shape rhs-shape)
          (unsafe-flarray lhs-shape
                          (vec-op (flarray-data lhs) (flarray-data rhs)))
          (array->flarray (array-op lhs rhs))))))

;; Applies flvector-scale with factor `y` to the data of `arr`; the result has the shape of `arr`.
(: flarray-scale (FlArray Float -> FlArray))
(define (flarray-scale arr y)
  (unsafe-flarray (array-shape arr)
                  (flvector-scale (flarray-data arr) y)))

;; Applies flvector-sqr to the array's flvector; the result keeps the argument's shape.
(: flarray-sqr (FlArray -> FlArray))
(define (flarray-sqr arr)
  ((lift-flvector1 flvector-sqr) arr))

;; Applies flvector-sqrt to the array's flvector; the result keeps the argument's shape.
(: flarray-sqrt (FlArray -> FlArray))
(define (flarray-sqrt arr)
  ((lift-flvector1 flvector-sqrt) arr))

;; Applies flvector-abs to the array's flvector; the result keeps the argument's shape.
(: flarray-abs (FlArray -> FlArray))
(define (flarray-abs arr)
  ((lift-flvector1 flvector-abs) arr))

;; Pointwise sum.  Equal shapes go through flvector+ on the underlying data; differing
;; shapes fall back to array+ followed by array->flarray.
(: flarray+ (FlArray FlArray -> FlArray))
(define (flarray+ arr1 arr2)
  ((lift-flvector2 flvector+ array+) arr1 arr2))

;; Pointwise product.  Equal shapes go through flvector* on the underlying data; differing
;; shapes fall back to array* followed by array->flarray.
(: flarray* (FlArray FlArray -> FlArray))
(define (flarray* arr1 arr2)
  ((lift-flvector2 flvector* array*) arr1 arr2))

;; One argument: flvector- applied to the array's data alone, same shape.
;; Two arguments: flvector- on the two flvectors when the shapes are equal?, otherwise
;; array- followed by array->flarray.
(: flarray- (case-> (FlArray -> FlArray)
                    (FlArray FlArray -> FlArray)))
(define flarray-
  (case-lambda
    ;; Unary form
    [(xs)
     (unsafe-flarray (array-shape xs)
                     (flvector- (flarray-data xs)))]
    ;; Binary form
    [(lhs rhs)
     ((lift-flvector2 flvector- array-) lhs rhs)]))

;; One argument: flvector/ applied to the array's data alone, same shape.
;; Two arguments: flvector/ on the two flvectors when the shapes are equal?, otherwise
;; array/ followed by array->flarray.
(: flarray/ (case-> (FlArray -> FlArray)
                    (FlArray FlArray -> FlArray)))
(define flarray/
  (case-lambda
    ;; Unary form
    [(xs)
     (unsafe-flarray (array-shape xs)
                     (flvector/ (flarray-data xs)))]
    ;; Binary form
    [(lhs rhs)
     ((lift-flvector2 flvector/ array/) lhs rhs)]))

;; Lifts flvector-min to FlArrays.  Equal shapes use flvector-min on the underlying data;
;; differing shapes fall back to array-min followed by array->flarray.
(: flarray-min  (FlArray FlArray -> FlArray))
(define (flarray-min arr1 arr2)
  ((lift-flvector2 flvector-min array-min) arr1 arr2))

;; Lifts flvector-max to FlArrays.  Equal shapes use flvector-max on the underlying data;
;; differing shapes fall back to array-max followed by array->flarray.
(: flarray-max  (FlArray FlArray -> FlArray))
(define (flarray-max arr1 arr2)
  ((lift-flvector2 flvector-max array-max) arr1 arr2))