1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
|
#lang racket/base
(require (for-syntax racket/base)
(only-in typed/racket/base assert index?)
"array-struct.rkt"
"array-pointwise.rkt"
"typed-array-fold.rkt")
;; ===================================================================================================
;; Standard folds
;; (define-axis-fold name f)
;; Defines `name` as a macro that forwards to `array-axis-fold` with `f` as the
;; folding operation:
;;   (name arr k)       ==> (array-axis-fold arr k f)
;;   (name arr k init)  ==> (array-axis-fold arr k f init)
;; `syntax/loc` is used so the generated call carries the source location of
;; the use site rather than of this template.
(define-syntax-rule (define-axis-fold name f)
  (define-syntax name
    (lambda (use-stx)
      (syntax-case use-stx ()
        ;; No initial value supplied.
        [(_ arr-expr k-expr)
         (syntax/loc use-stx
           (array-axis-fold arr-expr k-expr f))]
        ;; Explicit initial value, passed through as the last argument.
        [(_ arr-expr k-expr init-expr)
         (syntax/loc use-stx
           (array-axis-fold arr-expr k-expr f init-expr))]))))
;; (define-all-fold name f)
;; Defines `name` as a macro that forwards to `array-all-fold` with `f` as the
;; folding operation:
;;   (name arr)       ==> (array-all-fold arr f)
;;   (name arr init)  ==> (array-all-fold arr f init)
;; `syntax/loc` is used so the generated call carries the source location of
;; the use site rather than of this template.
(define-syntax-rule (define-all-fold name f)
  (define-syntax name
    (lambda (use-stx)
      (syntax-case use-stx ()
        ;; No initial value supplied.
        [(_ arr-expr)
         (syntax/loc use-stx
           (array-all-fold arr-expr f))]
        ;; Explicit initial value, passed through as the last argument.
        [(_ arr-expr init-expr)
         (syntax/loc use-stx
           (array-all-fold arr-expr f init-expr))]))))
;; Instantiate both fold generators for each of +, *, min and max.
;; Every name defined here is a macro, not a first-class function.

;; Folds with +
(define-axis-fold array-axis-sum +)
(define-all-fold array-all-sum +)

;; Folds with *
(define-axis-fold array-axis-prod *)
(define-all-fold array-all-prod *)

;; Folds with min
(define-axis-fold array-axis-min min)
(define-all-fold array-all-min min)

;; Folds with max
(define-axis-fold array-axis-max max)
(define-all-fold array-all-max max)
;; (array-count f arr ...)
;; Maps `f` over `arr ...` with `array-map`, turns each result into 1 (any
;; non-#f value) or 0 (#f), and sums those with `array-all-sum` starting
;; from 0. The total is then checked with `assert` against `index?`.
;; Note that the `f` and `arr ...` expressions are themselves evaluated inside
;; the `parameterize`, i.e. with `array-strictness` set to #f.
(define-syntax-rule (array-count pred-expr arr-expr ...)
  (assert
   (parameterize ([array-strictness #f])
     (array-all-sum
      (inline-array-map (lambda (truthy?) (if truthy? 1 0))
                        (array-map pred-expr arr-expr ...))
      0))
   index?))
;; (array-andmap pred? arr ...)
;; Expands to `array-all-and` applied to the result of mapping `pred?` over
;; `arr ...` with `array-map`. The whole expression, including the `pred?`
;; and `arr ...` subexpressions, is evaluated with `array-strictness` set to #f.
(define-syntax-rule (array-andmap test-expr arr-expr ...)
  (parameterize ([array-strictness #f])
    (array-all-and
     (array-map test-expr arr-expr ...))))
;; (array-ormap pred? arr ...)
;; Expands to `array-all-or` applied to the result of mapping `pred?` over
;; `arr ...` with `array-map`. The whole expression, including the `pred?`
;; and `arr ...` subexpressions, is evaluated with `array-strictness` set to #f.
(define-syntax-rule (array-ormap test-expr arr-expr ...)
  (parameterize ([array-strictness #f])
    (array-all-or
     (array-map test-expr arr-expr ...))))
;; Exports. Names with no definition in the visible part of this file are
;; expected to come from one of the required modules.
(provide
 ;; Per-axis folds
 array-axis-fold
 array-axis-sum array-axis-prod
 array-axis-min array-axis-max
 array-axis-count
 array-axis-and array-axis-or
 ;; General fold
 array-fold
 ;; "all" folds
 array-all-fold
 array-all-sum array-all-prod
 array-all-min array-all-max
 array-all-and array-all-or
 ;; Map-then-fold macros
 array-count
 array-andmap array-ormap
 ;; Remaining re-exports
 array-axis-reduce
 unsafe-array-axis-reduce
 array->list-array)
|