File: IndexSpace.hs

package info (click to toggle)
haskell-repa 3.4.1.5-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 304 kB
  • sloc: haskell: 3,135; makefile: 2
file content (207 lines) | stat: -rw-r--r-- 6,247 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
{-# LANGUAGE TypeOperators, ExplicitForAll, FlexibleContexts #-}

module Data.Array.Repa.Operators.IndexSpace
        ( reshape
        , append, (++)
        , transpose
        , extract
        , backpermute,         unsafeBackpermute
        , backpermuteDft,      unsafeBackpermuteDft
        , extend,              unsafeExtend 
        , slice,               unsafeSlice)
where
import Data.Array.Repa.Index
import Data.Array.Repa.Slice
import Data.Array.Repa.Base
import Data.Array.Repa.Repr.Delayed
import Data.Array.Repa.Operators.Traversal
import Data.Array.Repa.Shape            as S
import Prelude                          hiding ((++), traverse)
import qualified Prelude                as P 

-- | Fully-qualified name of this module. It is used as the prefix of the
--   error messages raised in this module (see 'reshape').
stage :: String
stage   = "Data.Array.Repa.Operators.IndexSpace"

-- Index space transformations ------------------------------------------------
-- | Impose a new shape on the elements of an array.
--
--   The result element at index @ix@ is read from the source at
--   @fromIndex (extent arr) (toIndex sh2 ix)@. In other words, every element
--   keeps its linear position and only the shape used to address it changes.
--
--   The new extent must have the same @size@ as the original extent.
--   Otherwise 'reshape' calls `error`.
reshape :: ( Shape sh1, Shape sh2
           , Source r1 e)
        => sh2
        -> Array r1 sh1 e
        -> Array D  sh2 e

reshape sh2 arr
        | S.size sh2 /= S.size srcExtent
        = error
        $ stage P.++ ".reshape: reshaped array will not match size of the original"

        | otherwise
        = fromFunction sh2 lookupSrc

        where   srcExtent       = extent arr

                -- Convert the new index to a linear offset, then back to an
                -- index in the source array's own shape.
                lookupSrc       = unsafeIndex arr . fromIndex srcExtent . toIndex sh2
{-# INLINE [2] reshape #-}
 

-- | Append two arrays along their innermost dimension.
--
--   * The innermost extent of the result is the sum of the two innermost
--     extents.
--   * The outer extents are combined with 'intersectDim'.
--
--   Let @len1@ be the innermost extent of the first array. Result elements
--   with inner index below @len1@ come from the first array. The remaining
--   elements come from the second array, with the inner index shifted down
--   by @len1@. @(++)@ is an operator synonym for 'append'.
append, (++)
        :: ( Shape sh
           , Source r1 e, Source r2 e)
        => Array r1 (sh :. Int) e
        -> Array r2 (sh :. Int) e
        -> Array D  (sh :. Int) e

append arr1 arr2
 = unsafeTraverse2 arr1 arr2 combineExtents pickElem
 where
        -- Innermost extent of the first array: the split point.
        (_ :. len1)     = extent arr1

        combineExtents (outer1 :. w1) (outer2 :. w2)
                = intersectDim outer1 outer2 :. (w1 + w2)

        pickElem get1 get2 (outer :. k)
                | k >= len1     = get2 (outer :. (k - len1))
                | otherwise     = get1 (outer :. k)
{-# INLINE [2] append #-}


arr1 ++ arr2 = append arr1 arr2
{-# INLINE (++) #-}


-- | Swap the two innermost dimensions of an array.
--
--   The element at @(rest :. r :. c)@ of the result is the element at
--   @(rest :. c :. r)@ of the source. Applying 'transpose' twice therefore
--   gives back the original arrangement of elements.
transpose
        :: (Shape sh, Source r e)
        => Array  r (sh :. Int :. Int) e
        -> Array  D (sh :. Int :. Int) e

transpose arr
 = unsafeTraverse arr swapExtent swapLookup
 where
        swapExtent (rest :. rows :. cols)
                = rest :. cols :. rows

        swapLookup get (rest :. r :. c)
                = get (rest :. c :. r)
{-# INLINE [2] transpose #-}


-- | Extract a sub-range of elements from an array.
--
--   The result has extent @sz@. Its element at index @ix@ is read from the
--   source at index @addDim start ix@.
--
--   The source is read with 'unsafeIndex' rather than 'index'. The caller is
--   therefore responsible for keeping the requested range inside the source
--   array's extent.
extract :: (Shape sh, Source r e)
        => sh                   -- ^ Starting index.
        -> sh                   -- ^ Size of result.
        -> Array r sh e
        -> Array D sh e
extract start sz arr
        = fromFunction sz fetch
        where   fetch   = unsafeIndex arr . addDim start
{-# INLINE [2] extract #-}


-- | Backwards permutation of an array's elements.
--
--   The result has the given extent. Its element at index @ix@ is the source
--   element at @perm ix@.
--
--   'backpermute' is built on 'traverse', and 'unsafeBackpermute' is built on
--   'unsafeTraverse'.
backpermute, unsafeBackpermute
        :: forall r sh1 sh2 e
        .  ( Shape sh1
           , Source r e)
        => sh2                  -- ^ Extent of result array.
        -> (sh2 -> sh1)         -- ^ Function mapping each index in the result array
                                --      to an index of the source array.
        -> Array r  sh1 e       -- ^ Source array.
        -> Array D  sh2 e

backpermute newExtent perm arr
        = traverse arr ignoreOld viaPerm
        where   ignoreOld _     = newExtent
                viaPerm get     = get . perm
{-# INLINE [2] backpermute #-}

unsafeBackpermute newExtent perm arr
        = unsafeTraverse arr ignoreOld viaPerm
        where   ignoreOld _     = newExtent
                viaPerm get     = get . perm
{-# INLINE [2] unsafeBackpermute #-}


-- | Default backwards permutation of an array's elements.
--
--   The result has the same extent as the default array (@arrDft@). For each
--   result index @ix@:
--
--   * if @fnIndex ix@ is @Just ix'@, the element is read from the source
--     array at @ix'@;
--   * if it is @Nothing@, the element is read from @arrDft@ at @ix@.
--
--   'backpermuteDft' reads with 'index', and 'unsafeBackpermuteDft' reads
--   with 'unsafeIndex'.
backpermuteDft, unsafeBackpermuteDft
        :: forall r1 r2 sh1 sh2 e
        .  ( Shape sh1,   Shape sh2
           , Source r1 e, Source r2 e)
        => Array r2 sh2 e       -- ^ Default values (@arrDft@)
        -> (sh2 -> Maybe sh1)   -- ^ Function mapping each index in the result array
                                --      to an index in the source array.
        -> Array r1 sh1 e       -- ^ Source array.
        -> Array D  sh2 e

backpermuteDft arrDft fnIndex arrSrc
        = fromFunction (extent arrDft) pick
        where   pick ix = maybe (arrDft `index` ix) (arrSrc `index`) (fnIndex ix)
{-# INLINE [2] backpermuteDft #-}

unsafeBackpermuteDft arrDft fnIndex arrSrc
        = fromFunction (extent arrDft) pick
        where   pick ix = maybe (arrDft `unsafeIndex` ix) (arrSrc `unsafeIndex`) (fnIndex ix)
{-# INLINE [2] unsafeBackpermuteDft #-}



-- | Extend an array, according to a given slice specification.
--
--   The result extent is @fullOfSlice sl (extent arr)@. Each result index is
--   mapped back to an index of @arr@ with @sliceOfFull sl@.
--
--   For example, to replicate the rows of an array use the following:
--
--   @extend (Any :. (5::Int) :. All) arr@
--
--   'extend' is built on 'backpermute', and 'unsafeExtend' is built on
--   'unsafeBackpermute'.
extend, unsafeExtend
        :: ( Slice sl
           , Shape (SliceShape sl)
           , Source r e)
        => sl
        -> Array r (SliceShape sl) e
        -> Array D (FullShape sl)  e

extend sl arr
        = backpermute bigExtent (sliceOfFull sl) arr
        where   bigExtent       = fullOfSlice sl (extent arr)
{-# INLINE [2] extend #-}

unsafeExtend sl arr
        = unsafeBackpermute bigExtent (sliceOfFull sl) arr
        where   bigExtent       = fullOfSlice sl (extent arr)
{-# INLINE [2] unsafeExtend #-}



-- | Take a slice from an array, according to a given specification.
--
--   The result extent is @sliceOfFull sl (extent arr)@. Each result index is
--   mapped back to an index of @arr@ with @fullOfSlice sl@.
--
--   For example, to take a row from a matrix use the following:
--
--   @slice arr (Any :. (5::Int) :. All)@
--
--   To take a column use:
--
--   @slice arr (Any :. (5::Int))@
--
--   'slice' is built on 'backpermute', and 'unsafeSlice' is built on
--   'unsafeBackpermute'.
slice, unsafeSlice
        :: ( Slice sl
           , Shape (FullShape sl)
           , Source r e)
        => Array r (FullShape sl) e
        -> sl
        -> Array D (SliceShape sl) e

slice arr sl
        = backpermute smallExtent (fullOfSlice sl) arr
        where   smallExtent     = sliceOfFull sl (extent arr)
{-# INLINE [2] slice #-}


unsafeSlice arr sl
        = unsafeBackpermute smallExtent (fullOfSlice sl) arr
        where   smallExtent     = sliceOfFull sl (extent arr)
{-# INLINE [2] unsafeSlice #-}