1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
|
#lang typed/racket/base
(require racket/flonum
(for-syntax racket/base)
"../flonum/flvector-syntax.rkt"
"../flonum/flvector.rkt"
"array-struct.rkt"
"array-broadcast.rkt"
"array-pointwise.rkt"
"mutable-array.rkt"
"flarray-struct.rkt"
"utils.rkt")
(provide
;; Mapping
inline-flarray-map
flarray-map
;; Pointwise operations
flarray-scale
flarray-sqr
flarray-sqrt
flarray-abs
flarray+
flarray*
flarray-
flarray/
flarray-min
flarray-max)
;; ===================================================================================================
;; Mapping
;; (inline-flarray-map f arr-expr ...)
;; Elementwise map over FlArrays, expanded inline: `f` is spliced into the
;; expansion as syntax rather than being bound to a variable first.
;;  - No arrays:    calls (f) once and wraps the value in an array of shape #().
;;  - One array:    maps f over the array's flvector data; the shape is reused.
;;  - Several arrays: if every shape is `equal?` to the first array's shape, f is
;;    mapped directly over the flvector data of all arrays; otherwise all arrays
;;    are broadcast to the shape returned by `array-shape-broadcast`, a result is
;;    built elementwise with `unsafe-build-array`, and it is converted with
;;    `array->flarray`.
;; The array expressions are each evaluated once, left to right (they are bound by `let:`).
;; NOTE(review): in the broadcast branch the `f` expression is placed inside the
;; per-element λ, so an `f` that is a non-trivial expression (not an identifier or
;; a λ) would be re-evaluated for every element — callers presumably pass an
;; identifier or λ; confirm before passing anything else.
(define-syntax (inline-flarray-map stx)
(syntax-case stx ()
;; Nullary: a zero-dimensional array holding the single value (f).
[(_ f) (syntax/loc stx (unsafe-flarray #() (flvector (f))))]
;; Unary: same shape as the input, data mapped through f.
[(_ f arr-expr)
(syntax/loc stx
(let: ([arr : FlArray arr-expr])
(unsafe-flarray (array-shape arr) (inline-flvector-map f (flarray-data arr)))))]
;; N-ary: one fresh identifier per extra array expression for the array itself,
;; its shape, and its (broadcast) element procedure.
[(_ f arr-expr arr-exprs ...)
(with-syntax ([(arrs ...) (generate-temporaries #'(arr-exprs ...))]
[(dss ...) (generate-temporaries #'(arr-exprs ...))]
[(procs ...) (generate-temporaries #'(arr-exprs ...))])
(syntax/loc stx
(let: ([arr : FlArray arr-expr]
[arrs : FlArray arr-exprs] ...)
(define ds (array-shape arr))
(define dss (array-shape arrs)) ...
;; Fast path: identical shapes need no broadcasting, so map over raw data.
(cond [(and (equal? ds dss) ...)
(unsafe-flarray
ds (inline-flvector-map f (flarray-data arr) (flarray-data arrs) ...))]
[else
;; Slow path: broadcast every array to the common shape and read
;; elements through the broadcast arrays' index procedures.
(define new-ds (array-shape-broadcast (list ds dss ...)))
(define proc (unsafe-array-proc (array-broadcast arr new-ds)))
(define procs (unsafe-array-proc (array-broadcast arrs new-ds))) ...
(array->flarray
(unsafe-build-array new-ds (λ: ([js : Indexes])
(f (proc js) (procs js) ...))))]))))]))
(: flarray-map (case-> ((-> Float) -> FlArray)
                       ((Float -> Float) FlArray -> FlArray)
                       ((Float Float Float * -> Float) FlArray FlArray FlArray * -> FlArray)))
;; Applies `f` elementwise across zero or more FlArrays and returns a new FlArray.
;;  - With no arrays, `f` is called once; the result has shape #().
;;  - With one or two arrays, the work is delegated to `inline-flarray-map`.
;;  - With three or more arrays, every argument is broadcast to the shape given by
;;    `array-shape-broadcast` (even when the shapes already agree) and `f` is applied
;;    to the corresponding elements via `apply`.
(define flarray-map
  (case-lambda:
    [([f : (-> Float)])
     (inline-flarray-map f)]
    [([f : (Float -> Float)] [arr : FlArray])
     (inline-flarray-map f arr)]
    [([f : (Float Float -> Float)] [arr0 : FlArray] [arr1 : FlArray])
     (inline-flarray-map f arr0 arr1)]
    [([f : (Float Float Float * -> Float)] [arr0 : FlArray] [arr1 : FlArray] . [arrs : FlArray *])
     ;; Common shape of all arguments, in argument order.
     (define target-shape
       (array-shape-broadcast
        (map (λ: ([a : FlArray]) (array-shape a)) (list* arr0 arr1 arrs))))
     ;; Broadcast views of every argument at the common shape.
     (define: b0 : (Array Float) (array-broadcast arr0 target-shape))
     (define: b1 : (Array Float) (array-broadcast arr1 target-shape))
     (define: b-rest : (Listof (Array Float))
       (map (λ: ([a : FlArray]) (array-broadcast a target-shape)) arrs))
     ;; Element-lookup procedures for those views.
     (define get0 (unsafe-array-proc b0))
     (define get1 (unsafe-array-proc b1))
     (define get-rest (map (λ: ([b : (Array Float)]) (unsafe-array-proc b)) b-rest))
     (array->flarray
      (unsafe-build-array
       target-shape
       (λ: ([js : Indexes])
         (apply f (get0 js) (get1 js)
                (map (λ: ([get : (Indexes -> Float)]) (get js)) get-rest)))))]))
;; ===================================================================================================
;; Pointwise operations
;; Lifts a unary flvector operation `f` to a unary FlArray operation: the result
;; reuses the argument's shape and holds `f` applied to the argument's data.
;; `f` may itself be a macro, since it only ever appears in operator position.
(define-syntax-rule (lift-flvector1 f)
  (λ (arr)
    (define shape (array-shape arr))
    (unsafe-flarray shape (f (flarray-data arr)))))
;; Lifts a binary flvector operation `f` to a binary FlArray operation.
;; When both shapes are `equal?`, `f` is applied directly to the raw data vectors;
;; otherwise the general array operation `array-f` is used and its result is
;; converted back with `array->flarray`.
(define-syntax-rule (lift-flvector2 f array-f)
  (λ (arr1 arr2)
    (if (equal? (array-shape arr1) (array-shape arr2))
        (unsafe-flarray (array-shape arr1)
                        (f (flarray-data arr1) (flarray-data arr2)))
        (array->flarray (array-f arr1 arr2)))))
(: flarray-scale (FlArray Float -> FlArray))
;; Multiplies every element of `arr` by `y` using `flvector-scale` on the raw data;
;; the result has the same shape as `arr`.
(define (flarray-scale arr y)
  (unsafe-flarray (array-shape arr)
                  (flvector-scale (flarray-data arr) y)))
;; Elementwise unary operations; each result keeps the argument's shape and is
;; computed by the corresponding flvector operation on the raw data.
(: flarray-sqr (FlArray -> FlArray))
(define (flarray-sqr arr)
  ((lift-flvector1 flvector-sqr) arr))

(: flarray-sqrt (FlArray -> FlArray))
(define (flarray-sqrt arr)
  ((lift-flvector1 flvector-sqrt) arr))

(: flarray-abs (FlArray -> FlArray))
(define (flarray-abs arr)
  ((lift-flvector1 flvector-abs) arr))
;; Elementwise sum and product. Equal shapes use the flvector operation on the raw
;; data; differing shapes fall back to the broadcasting `array+` / `array*`.
(: flarray+ (FlArray FlArray -> FlArray))
(define (flarray+ a b)
  ((lift-flvector2 flvector+ array+) a b))

(: flarray* (FlArray FlArray -> FlArray))
(define (flarray* a b)
  ((lift-flvector2 flvector* array*) a b))
(: flarray- (case-> (FlArray -> FlArray)
                    (FlArray FlArray -> FlArray)))
;; One argument: applies unary `flvector-` to the data, keeping the shape.
;; Two arguments: elementwise difference, falling back to `array-` when the
;; shapes differ.
(define flarray-
  (case-lambda
    [(a)
     (unsafe-flarray (array-shape a) (flvector- (flarray-data a)))]
    [(a b)
     ((lift-flvector2 flvector- array-) a b)]))
(: flarray/ (case-> (FlArray -> FlArray)
                    (FlArray FlArray -> FlArray)))
;; One argument: applies unary `flvector/` to the data, keeping the shape.
;; Two arguments: elementwise quotient, falling back to `array/` when the
;; shapes differ.
(define flarray/
  (case-lambda
    [(a)
     (unsafe-flarray (array-shape a) (flvector/ (flarray-data a)))]
    [(a b)
     ((lift-flvector2 flvector/ array/) a b)]))
;; Elementwise minimum and maximum of two FlArrays. Equal shapes use the flvector
;; operation on the raw data; differing shapes fall back to `array-min` / `array-max`.
(: flarray-min (FlArray FlArray -> FlArray))
(define (flarray-min a b)
  ((lift-flvector2 flvector-min array-min) a b))

(: flarray-max (FlArray FlArray -> FlArray))
(define (flarray-max a b)
  ((lift-flvector2 flvector-max array-max) a b))
|