File: array-broadcast.rkt

package info (click to toggle)
racket 7.2+dfsg1-2
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 125,432 kB
  • sloc: ansic: 258,980; pascal: 59,975; sh: 33,650; asm: 13,558; lisp: 7,124; makefile: 3,329; cpp: 2,889; exp: 499; python: 274; xml: 11
file content (108 lines) | stat: -rw-r--r-- 4,570 bytes parent folder | download | duplicates (10)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#lang typed/racket

(require racket/fixnum
         "array-struct.rkt"
         "../unsafe.rkt"
         "utils.rkt")

(provide array-broadcasting
         array-broadcast
         array-shape-broadcast)

;; Parameter selecting how `array-shape-broadcast` reconciles differing shapes:
;;   #f           -- no broadcasting; shapes that are not `equal?` are rejected
;;   #t (default) -- normal broadcasting (an axis of length 1 stretches to match)
;;   'permissive  -- permissive broadcasting (nonzero lengths combine by maximum)
(define array-broadcasting : (Parameterof (U #f #t 'permissive))
  (make-parameter #t))

;; Returns an array of shape `new-ds` that views `arr` through broadcasting.
;; `arr`'s axes line up with the trailing axes of `new-ds`; the extra leading axes of
;; `new-ds` are ignored when indexing `arr`. Along each aligned axis, the index is
;; reduced modulo `arr`'s length on that axis, which repeats `arr`'s elements.
;; Raises an error if `new-ds` has fewer axes than `arr`. Also raises an error if a
;; zero-length axis of `arr` would have to map onto a nonzero length: there is nothing
;; to repeat, and the `unsafe-fxmodulo` below would otherwise divide by zero.
(: shift-stretch-axes (All (A) ((Array A) Indexes -> (Array A))))
(define (shift-stretch-axes arr new-ds)
  (define old-ds (array-shape arr))
  (define old-dims (vector-length old-ds))
  (define new-dims (vector-length new-ds))
  ;; Number of leading axes of `new-ds` that `arr` does not have.
  (define shift
    (let ([shift  (- new-dims old-dims)])
      (cond [(index? shift)  shift]
            [else  (error 'array-broadcast
                          "cannot broadcast to a lower-dimensional shape; given ~e and ~e"
                          arr new-ds)])))
  ;; Validate eagerly, because the per-element procedure below uses unsafe arithmetic.
  (let: check : Void ([k : Nonnegative-Fixnum  0])
    (cond [(k . < . old-dims)
           (when (and (= (unsafe-vector-ref old-ds k) 0)
                      (not (= (unsafe-vector-ref new-ds (+ k shift)) 0)))
             (error 'array-broadcast
                    "cannot broadcast a zero-length axis to a nonzero length; given ~e and ~e"
                    arr new-ds))
           (check (+ k 1))]
          [else  (void)]))
  ;; Scratch index vector for `arr`, obtained per call via `make-thread-local-indexes`.
  (define old-js (make-thread-local-indexes old-dims))
  (define old-f (unsafe-array-proc arr))
  (unsafe-build-array
   new-ds
   (λ: ([new-js : Indexes])
     (let ([old-js  (old-js)])
       (let: loop : A ([k : Nonnegative-Fixnum  0])
         (cond [(k . < . old-dims)
                ;; Axis k of `arr` corresponds to axis (k + shift) of the result.
                (define new-jk (unsafe-vector-ref new-js (+ k shift)))
                (define old-dk (unsafe-vector-ref old-ds k))
                ;; old-dk > 0 here whenever this procedure is reachable (checked above).
                (define old-jk (unsafe-fxmodulo new-jk old-dk))
                (unsafe-vector-set! old-js k old-jk)
                (loop (+ k 1))]
               [else  (old-f old-js)]))))))

;; Broadcasts `arr` to the shape `ds` and returns `arr` unchanged if it already has
;; that shape. The stretched array from `shift-stretch-axes` is returned as-is when
;; `arr` is strict, or when the result has no more elements than `arr`. Otherwise it
;; is passed through `array-default-strict`. NOTE(review): presumably this avoids
;; re-running a non-strict `arr`'s procedure for every repeated element -- confirm.
(: array-broadcast (All (A) ((Array A) Indexes -> (Array A))))
(define (array-broadcast arr ds)
  (if (equal? ds (array-shape arr))
      arr
      (let ([stretched  (shift-stretch-axes arr ds)])
        (cond [(array-strict? arr)  stretched]
              [(fx<= (array-size stretched) (array-size arr))  stretched]
              [else  (array-default-strict stretched)]))))

;; Prepends `n` axes of length 1 to the shape `ds`; for example, #(2 3) with n = 2
;; becomes #(1 1 2 3). Used to raise a lower-rank shape to another shape's rank
;; before the two are compared axis by axis. `make-vector` rejects a negative `n`.
(: shape-insert-axes (Indexes Integer -> Indexes))
(define (shape-insert-axes ds n)
  (define: leading-ones : Indexes ((inst make-vector Index) n 1))
  (vector-append leading-ones ds))

;; Permissive broadcasting of two shapes of equal rank `dims`, axis by axis:
;;   - equal lengths (including both 0) keep that length;
;;   - two different nonzero lengths combine to the larger one (the shorter axis is
;;     later repeated by modular indexing);
;;   - a zero length paired with a nonzero length is incompatible, and `fail` is called.
;; Accepting equal lengths first makes permissive mode accept every shape pair that
;; `shape-normal-broadcast` accepts. Previously, any zero-length axis failed here, even
;; when both shapes agreed on it. Returns a freshly allocated shape vector.
(: shape-permissive-broadcast (Indexes Indexes Index (-> Nothing) -> Indexes))
(define (shape-permissive-broadcast ds1 ds2 dims fail)
  (define: new-ds : Indexes (make-vector dims 0))
  (let loop ([#{k : Nonnegative-Fixnum} 0])
    (cond [(k . < . dims)
           (define dk1 (unsafe-vector-ref ds1 k))
           (define dk2 (unsafe-vector-ref ds2 k))
           (unsafe-vector-set!
            new-ds k
            (cond [(= dk1 dk2)  dk1]
                  ;; A zero-length axis has no elements to repeat.
                  [(or (= dk1 0) (= dk2 0))  (fail)]
                  [else  (fxmax dk1 dk2)]))
           (loop (+ k 1))]
          [else  new-ds])))

;; Normal broadcasting of two shapes of equal rank `dims`, axis by axis. Equal lengths
;; keep that length. A length of 1 paired with a positive length takes the positive
;; length. Any other combination calls `fail` (1 paired with 0 is rejected too).
;; Returns a freshly allocated shape vector.
(: shape-normal-broadcast (Indexes Indexes Index (-> Nothing) -> Indexes))
(define (shape-normal-broadcast ds1 ds2 dims fail)
  (define: result : Indexes (make-vector dims 0))
  (let loop ([#{axis : Nonnegative-Fixnum} 0])
    (if (axis . < . dims)
        (let ([len1  (unsafe-vector-ref ds1 axis)]
              [len2  (unsafe-vector-ref ds2 axis)])
          (unsafe-vector-set! result axis
                              (cond [(= len1 len2)  len1]
                                    [(and (= len1 1) (len2 . > . 0))  len2]
                                    [(and (= len2 1) (len1 . > . 0))  len1]
                                    [else  (fail)]))
          (loop (+ axis 1)))
        result)))

;; Broadcasts two shapes against each other. `equal?` shapes are returned unchanged.
;; With `broadcasting` = #f, any difference calls `fail`. Otherwise the lower-rank
;; shape gets leading axes of length 1, and the shapes are combined axis by axis
;; using the permissive rule (for 'permissive) or the normal rule (for #t).
(: shape-broadcast2 (Indexes Indexes (-> Nothing) (U #f #t 'permissive) -> Indexes))
(define (shape-broadcast2 ds1 ds2 fail broadcasting)
  (cond
    [(equal? ds1 ds2)  ds1]
    [(not broadcasting)  (fail)]
    [else
     (define rank-gap (- (vector-length ds2) (vector-length ds1)))
     (define broadcast-axes
       (if (eq? broadcasting 'permissive)
           shape-permissive-broadcast
           shape-normal-broadcast))
     (cond
       [(rank-gap . > . 0)
        (broadcast-axes (shape-insert-axes ds1 rank-gap) ds2 (vector-length ds2) fail)]
       [(rank-gap . < . 0)
        (broadcast-axes ds1 (shape-insert-axes ds2 (- rank-gap)) (vector-length ds1) fail)]
       [else
        (broadcast-axes ds1 ds2 (vector-length ds1) fail)])]))

;; Folds `shape-broadcast2` left-to-right over the shapes in `dss` and returns the
;; combined shape. An empty list yields #(). `broadcasting` defaults to the current
;; value of `array-broadcasting`. On incompatibility, raises an `array-shape-broadcast`
;; error that lists every input shape.
(: array-shape-broadcast (case-> ((Listof Indexes) -> Indexes)
                                 ((Listof Indexes) (U #f #t 'permissive) -> Indexes)))
(define (array-shape-broadcast dss [broadcasting (array-broadcasting)])
  (: fail (-> Nothing))
  (define (fail)
    (error 'array-shape-broadcast
           "incompatible array shapes (array-broadcasting ~v): ~a"
           broadcasting
           (string-join (map (λ (ds) (format "~e" ds)) dss) ", ")))
  (cond
    [(null? dss)  #()]
    [else
     (let: fold : Indexes ([acc : Indexes (car dss)]
                           [pending : (Listof Indexes) (cdr dss)])
       (if (null? pending)
           acc
           (fold (shape-broadcast2 acc (car pending) fail broadcasting)
                 (cdr pending))))]))