1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176
|
-- | A simple index operation representation. Every operation corresponds to a
-- constructor.
module Futhark.IR.Mem.IxFun.Alg
( IxFun (..),
iota,
offsetIndex,
permute,
reshape,
coerce,
slice,
flatSlice,
expand,
shape,
index,
disjoint,
)
where
import Data.List qualified as L
import Data.Set qualified as S
import Futhark.IR.Pretty ()
import Futhark.IR.Prop
import Futhark.IR.Syntax
( DimIndex (..),
FlatDimIndex (..),
FlatSlice (..),
Slice (..),
flatSliceDims,
sliceDims,
unitSlice,
)
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Prelude hiding (div, mod, span)
type Shape num = [num]
type Indices num = [num]
type Permutation = [Int]
data IxFun num
= Direct (Shape num)
| Permute (IxFun num) Permutation
| Index (IxFun num) (Slice num)
| FlatIndex (IxFun num) (FlatSlice num)
| Reshape (IxFun num) (Shape num)
| Coerce (IxFun num) (Shape num)
| OffsetIndex (IxFun num) num
| Expand num num (IxFun num)
deriving (Eq, Show)
instance (Pretty num) => Pretty (IxFun num) where
pretty (Direct dims) =
"Direct" <> parens (commasep $ map pretty dims)
pretty (Permute fun perm) = pretty fun <> pretty perm
pretty (Index fun is) = pretty fun <> pretty is
pretty (FlatIndex fun is) = pretty fun <> pretty is
pretty (Reshape fun oldshape) =
pretty fun
<> "->reshape"
<> parens (pretty oldshape)
pretty (Coerce fun oldshape) =
pretty fun
<> "->coerce"
<> parens (pretty oldshape)
pretty (OffsetIndex fun i) =
pretty fun <> "->offset_index" <> parens (pretty i)
pretty (Expand o p fun) =
"expand(" <> pretty o <> "," <+> pretty p <> "," <+> pretty fun <> ")"
iota :: Shape num -> IxFun num
iota = Direct
offsetIndex :: IxFun num -> num -> IxFun num
offsetIndex = OffsetIndex
permute :: IxFun num -> Permutation -> IxFun num
permute = Permute
slice :: IxFun num -> Slice num -> IxFun num
slice = Index
flatSlice :: IxFun num -> FlatSlice num -> IxFun num
flatSlice = FlatIndex
expand :: num -> num -> IxFun num -> IxFun num
expand = Expand
reshape :: IxFun num -> Shape num -> IxFun num
reshape = Reshape
coerce :: IxFun num -> Shape num -> IxFun num
coerce = Reshape
shape ::
(IntegralExp num) =>
IxFun num ->
Shape num
shape (Direct dims) =
dims
shape (Permute ixfun perm) =
rearrangeShape perm $ shape ixfun
shape (Index _ how) =
sliceDims how
shape (FlatIndex ixfun how) =
flatSliceDims how <> tail (shape ixfun)
shape (Reshape _ dims) =
dims
shape (Coerce _ dims) =
dims
shape (OffsetIndex ixfun _) =
shape ixfun
shape (Expand _ _ ixfun) =
shape ixfun
index ::
(Eq num, IntegralExp num) =>
IxFun num ->
Indices num ->
num
index (Direct dims) is =
sum $ zipWith (*) is slicesizes
where
slicesizes = drop 1 $ sliceSizes dims
index (Permute fun perm) is_new =
index fun is_old
where
is_old = rearrangeShape (rearrangeInverse perm) is_new
index (Index fun (Slice js)) is =
index fun (adjust js is)
where
adjust (DimFix j : js') is' = j : adjust js' is'
adjust (DimSlice j _ s : js') (i : is') = j + i * s : adjust js' is'
adjust _ _ = []
index (FlatIndex fun (FlatSlice offset js)) is =
index fun $ sum (offset : zipWith f is js) : drop (length js) is
where
f i (FlatDimIndex _ s) = i * s
index (Reshape fun newshape) is =
let new_indices = reshapeIndex (shape fun) newshape is
in index fun new_indices
index (Coerce fun _) is =
index fun is
index (OffsetIndex fun i) is =
case shape fun of
d : ds ->
index (Index fun (Slice (DimSlice i (d - i) 1 : map (unitSlice 0) ds))) is
[] -> error "index: OffsetIndex: underlying index function has rank zero"
index (Expand o p ixfun) is =
o + p * index ixfun is
allPoints :: (IntegralExp num, Enum num) => [num] -> [[num]]
allPoints dims =
let total = product dims
strides = drop 1 $ L.reverse $ scanl (*) 1 $ L.reverse dims
in map (unflatInd strides) [0 .. total - 1]
where
unflatInd strides x =
fst $
foldl
( \(res, acc) span ->
(res ++ [acc `div` span], acc `mod` span)
)
([], x)
strides
disjoint :: (IntegralExp num, Ord num, Enum num) => IxFun num -> IxFun num -> Bool
disjoint ixf1 ixf2 =
let shp1 = shape ixf1
points1 = S.fromList $ allPoints shp1
allIdxs1 = S.map (index ixf1) points1
shp2 = shape ixf2
points2 = S.fromList $ allPoints shp2
allIdxs2 = S.map (index ixf2) points2
in S.disjoint allIdxs1 allIdxs2
|