1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207
|
{-# LANGUAGE TypeOperators, ExplicitForAll, FlexibleContexts #-}
module Data.Array.Repa.Operators.IndexSpace
( reshape
, append, (++)
, transpose
, extract
, backpermute, unsafeBackpermute
, backpermuteDft, unsafeBackpermuteDft
, extend, unsafeExtend
, slice, unsafeSlice)
where
import Data.Array.Repa.Index
import Data.Array.Repa.Slice
import Data.Array.Repa.Base
import Data.Array.Repa.Repr.Delayed
import Data.Array.Repa.Operators.Traversal
import Data.Array.Repa.Shape as S
import Prelude hiding ((++), traverse)
import qualified Prelude as P
-- Name of this module, used as the prefix of the error messages raised here.
stage :: String
stage = "Data.Array.Repa.Operators.IndexSpace"
-- Index space transformations ------------------------------------------------
-- | Impose a new shape on the elements of an array.
--
--   The new extent must describe the same number of elements as the
--   original, otherwise this calls `error`. Each result index is converted
--   to a linear index with 'toIndex' on the new extent, then back to a
--   source index with 'fromIndex' on the original extent.
reshape :: ( Shape sh1, Shape sh2
           , Source r1 e)
        => sh2                  -- ^ New extent.
        -> Array r1 sh1 e       -- ^ Source array.
        -> Array D sh2 e
reshape sh2 arr
        | S.size sh2 /= S.size srcExtent
        = error
        $ stage P.++ ".reshape: reshaped array will not match size of the original"
        | otherwise
        = fromFunction sh2 (unsafeIndex arr . fromIndex srcExtent . toIndex sh2)
        where
          srcExtent = extent arr
{-# INLINE [2] reshape #-}
-- | Append two arrays along their innermost dimension.
--
--   The innermost extent of the result is the sum of both innermost
--   extents; the outer part of the extent is the 'intersectDim' of the two
--   inputs' outer parts. Innermost indices below the first array's length
--   read from the first array, the remaining ones from the second.
append, (++)
        :: ( Shape sh
           , Source r1 e, Source r2 e)
        => Array r1 (sh :. Int) e
        -> Array r2 (sh :. Int) e
        -> Array D  (sh :. Int) e
append arr1 arr2
        = unsafeTraverse2 arr1 arr2 joinExtents pickElem
        where
          -- Innermost length of the first array: the split point.
          (_ :. split) = extent arr1

          joinExtents (outer1 :. len1) (outer2 :. len2)
                  = intersectDim outer1 outer2 :. (len1 + len2)

          pickElem get1 get2 (outer :. i)
                  = if i < split
                        then get1 (outer :. i)
                        else get2 (outer :. (i - split))
{-# INLINE [2] append #-}

-- Operator synonym for 'append'.
(++) = append
{-# INLINE (++) #-}
-- | Transpose the lowest two dimensions of an array.
--
--   The element at index @sh :. i :. j@ of the result is the element at
--   @sh :. j :. i@ of the source, so transposing twice yields the original.
transpose
        :: (Shape sh, Source r e)
        => Array r (sh :. Int :. Int) e
        -> Array D (sh :. Int :. Int) e
transpose arr
        = unsafeTraverse arr swapExtent swapIndex
        where
          swapExtent (sh :. rows :. cols) = sh :. cols :. rows
          swapIndex get (sh :. i :. j)    = get (sh :. j :. i)
{-# INLINE [2] transpose #-}
-- | Extract a sub-range of elements from an array.
--
--   Index @ix@ of the result reads index @addDim start ix@ of the source.
--   The source is read with 'unsafeIndex', so out-of-range requests are not
--   caught here (NOTE(review): callers presumably guarantee the range fits).
extract :: (Shape sh, Source r e)
        => sh           -- ^ Starting index.
        -> sh           -- ^ Size of result.
        -> Array r sh e -- ^ Source array.
        -> Array D sh e
extract start sz arr
        = fromFunction sz fetch
        where
          fetch ix = unsafeIndex arr (addDim start ix)
{-# INLINE [2] extract #-}
-- | Backwards permutation of an array's elements.
--
--   Every index of the result is mapped through the given function to an
--   index of the source array, and the element found there is used.
--   'backpermute' is built on 'traverse', 'unsafeBackpermute' on
--   'unsafeTraverse'.
backpermute, unsafeBackpermute
        :: forall r sh1 sh2 e
        .  ( Shape sh1
           , Source r e)
        => sh2                  -- ^ Extent of result array.
        -> (sh2 -> sh1)         -- ^ Function mapping each index in the result array
                                --   to an index of the source array.
        -> Array r sh1 e        -- ^ Source array.
        -> Array D sh2 e
backpermute newExtent perm arr
        = traverse arr (\_ -> newExtent) (\get -> get . perm)
{-# INLINE [2] backpermute #-}

unsafeBackpermute newExtent perm arr
        = unsafeTraverse arr (\_ -> newExtent) (\get -> get . perm)
{-# INLINE [2] unsafeBackpermute #-}
-- | Default backwards permutation of an array's elements.
--
--   The result has the extent of the default array (@arrDft@). For each
--   result index, if the mapping function returns @Just ix'@ the element is
--   read from the source array at @ix'@; if it returns `Nothing` the value
--   is taken from the default array at the same index.
--   'backpermuteDft' reads with 'index', 'unsafeBackpermuteDft' with
--   'unsafeIndex'.
backpermuteDft, unsafeBackpermuteDft
        :: forall r1 r2 sh1 sh2 e
        .  ( Shape sh1, Shape sh2
           , Source r1 e, Source r2 e)
        => Array r2 sh2 e       -- ^ Default values (@arrDft@).
        -> (sh2 -> Maybe sh1)   -- ^ Function mapping each index in the result array
                                --   to an index in the source array.
        -> Array r1 sh1 e       -- ^ Source array.
        -> Array D  sh2 e
backpermuteDft arrDft fnIndex arrSrc
        = fromFunction (extent arrDft) pick
        where
          pick ix = maybe (arrDft `index` ix) (arrSrc `index`) (fnIndex ix)
{-# INLINE [2] backpermuteDft #-}

unsafeBackpermuteDft arrDft fnIndex arrSrc
        = fromFunction (extent arrDft) pick
        where
          pick ix = maybe (arrDft `unsafeIndex` ix) (arrSrc `unsafeIndex`) (fnIndex ix)
{-# INLINE [2] unsafeBackpermuteDft #-}
-- | Extend an array, according to a given slice specification.
--
--   The result extent is @fullOfSlice sl (extent arr)@, and every result
--   index is projected back onto the source with @sliceOfFull sl@.
--   'extend' is built on 'backpermute', 'unsafeExtend' on
--   'unsafeBackpermute'.
--
--   For example, to replicate the rows of an array use the following:
--
--   @extend (Any :. (5::Int) :. All) arr@
--
extend, unsafeExtend
        :: ( Slice sl
           , Shape (SliceShape sl)
           , Source r e)
        => sl
        -> Array r (SliceShape sl) e
        -> Array D (FullShape sl) e
extend sl arr
        = backpermute fullExtent (sliceOfFull sl) arr
        where
          fullExtent = fullOfSlice sl (extent arr)
{-# INLINE [2] extend #-}

unsafeExtend sl arr
        = unsafeBackpermute fullExtent (sliceOfFull sl) arr
        where
          fullExtent = fullOfSlice sl (extent arr)
{-# INLINE [2] unsafeExtend #-}
-- | Take a slice from an array, according to a given specification.
--
--   The result extent is @sliceOfFull sl (extent arr)@, and every result
--   index is expanded back to a source index with @fullOfSlice sl@.
--   'slice' is built on 'backpermute', 'unsafeSlice' on
--   'unsafeBackpermute'.
--
--   For example, to take a row from a matrix use the following:
--
--   @slice arr (Any :. (5::Int) :. All)@
--
--   To take a column use:
--
--   @slice arr (Any :. (5::Int))@
--
slice, unsafeSlice
        :: ( Slice sl
           , Shape (FullShape sl)
           , Source r e)
        => Array r (FullShape sl) e
        -> sl
        -> Array D (SliceShape sl) e
slice arr sl
        = backpermute sliceExtent (fullOfSlice sl) arr
        where
          sliceExtent = sliceOfFull sl (extent arr)
{-# INLINE [2] slice #-}

unsafeSlice arr sl
        = unsafeBackpermute sliceExtent (fullOfSlice sl) arr
        where
          sliceExtent = sliceOfFull sl (extent arr)
{-# INLINE [2] unsafeSlice #-}
|