1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
|
{-# LANGUAGE BangPatterns, ExplicitForAll, ScopedTypeVariables, PatternGuards #-}
module Data.Array.Repa.Eval.Selection
(selectChunkedS, selectChunkedP)
where
import Data.Array.Repa.Eval.Gang
import Data.Array.Repa.Shape
import Data.Vector.Unboxed as V
import Data.Vector.Unboxed.Mutable as VM
import GHC.Base (remInt, quotInt)
import Prelude as P
import Control.Monad as P
import Data.IORef
-- | Sequentially select the indices of an extent that satisfy a predicate.
--
--   * This primitive can be useful for writing filtering functions.
--
--   * Every linear index from @0@ up to (but excluding) @size shSize@ is
--     converted to a shape index with 'fromIndex' and tested in increasing
--     order. For each index that matches, the write function is called with
--     that (source) index and the value produced from it.
--
selectChunkedS
        :: Shape sh
        => (sh -> a -> IO ())   -- ^ Update function to write into result.
        -> (sh -> Bool)         -- ^ See if this predicate matches.
        -> (sh -> a)            -- ^ .. and apply fn to the matching index
        -> sh                   -- ^ Extent of indices to apply to predicate.
        -> IO Int               -- ^ Number of elements written to destination array.
{-# INLINE selectChunkedS #-}
selectChunkedS fnWrite fnMatch fnProduce !shSize
 = go 0 0
 where  -- Total number of linear indices covered by the extent.
        total   = size shSize

        -- Walk the linear indices, carrying the count of matches so far.
        go !srcPos !written
         | srcPos >= total
         = return written

         | otherwise
         = let  ix      = fromIndex shSize srcPos
           in   if fnMatch ix
                 then fnWrite ix (fnProduce ix)
                        >> go (srcPos + 1) (written + 1)
                 else go (srcPos + 1) written
-- | Select indices matching a predicate, in parallel.
--
--   * This primitive can be useful for writing filtering functions.
--
--   * The range of indices @[0 .. len - 1]@ is split into contiguous linear
--     chunks, with one chunk being given to each thread of the gang.
--
--   * The result holds exactly one chunk per gang thread, so the number of
--     chunks (and how the selected elements are distributed between them)
--     depends on how many threads you're running the program with.
--     Chunks for empty index ranges have length zero.
--
--   * Each returned chunk is sliced down to the number of elements that were
--     actually selected into it.
--
selectChunkedP
        :: forall a
        .  Unbox a
        => (Int -> Bool)        -- ^ See if this predicate matches.
        -> (Int -> a)           -- ^ .. and apply fn to the matching index
        -> Int                  -- ^ Extent of indices to apply to predicate.
        -> IO [IOVector a]      -- ^ Chunks containing array elements.
{-# INLINE selectChunkedP #-}
selectChunkedP fnMatch fnProduce !len
 = do
        -- Make IORefs that the threads will write their result chunks to.
        -- Every thread allocates its own chunk and overwrites its ref (see
        -- makeChunk), so the refs only need an empty placeholder here.
        refs    <- P.replicateM threads
                $  do   vec     <- VM.new 0
                        newIORef vec

        -- Fire off a thread to fill each chunk.
        -- NOTE(review): assumes gangIO runs the action once for every thread
        -- index in [0 .. threads - 1] and waits for completion -- confirm
        -- against Data.Array.Repa.Eval.Gang.
        gangIO theGang
         $ \thread -> makeChunk (refs !! thread)
                        (splitIx thread)
                        (splitIx (thread + 1) - 1)

        -- Read the result chunks back from the IORefs.
        P.mapM readIORef refs

 where  -- See how many threads we have available.
        !threads        = gangSize theGang
        !chunkLen       = len `quotInt` threads
        !chunkLeftover  = len `remInt`  threads

        -- Decide where to split the source range.
        -- The first chunkLeftover threads each take one extra element.
        {-# INLINE splitIx #-}
        splitIx thread
         | thread < chunkLeftover = thread * (chunkLen + 1)
         | otherwise              = thread * chunkLen  + chunkLeftover

        -- Fill the given chunk with elements selected from this
        -- (inclusive) range of source indices.
        makeChunk :: IORef (IOVector a) -> Int -> Int -> IO ()
        makeChunk !ref !ixSrc !ixSrcEnd
         -- Empty range: publish an empty chunk.
         | ixSrc > ixSrcEnd
         = do   vecDst  <- VM.new 0
                writeIORef ref vecDst

         -- Start with a chunk size proportional to the share of the source
         -- range each thread gets; fillChunk grows it if it runs out of space.
         | otherwise
         = do   vecDst  <- VM.new chunkLen
                vecDst' <- fillChunk ixSrc ixSrcEnd vecDst 0 (VM.length vecDst)
                writeIORef ref vecDst'

        -- The main filling loop.
        -- ixDst is the next free slot in vecDst, and ixDstLen its capacity.
        fillChunk :: Int -> Int -> IOVector a -> Int -> Int -> IO (IOVector a)
        fillChunk !ixSrc !ixSrcEnd !vecDst !ixDst !ixDstLen
         -- If we've finished selecting elements, then slice the vector down
         -- so it doesn't have any empty space at the end.
         | ixSrc > ixSrcEnd
         = return $ VM.slice 0 ixDst vecDst

         -- If we've run out of space in the chunk then grow it some more.
         -- VM.grow takes the number of elements to ADD, not the new length:
         -- adding (n + 2) to a vector of length n yields (n + 1) * 2 slots.
         -- The capacity is read back from the grown vector so that it cannot
         -- drift from the real allocation.
         | ixDst >= ixDstLen
         = do   vecDst' <- VM.grow vecDst (VM.length vecDst + 2)
                fillChunk ixSrc ixSrcEnd vecDst' ixDst (VM.length vecDst')

         -- We've got a matching element, so add it to the chunk.
         -- The unchecked write is in bounds because ixDst < ixDstLen here.
         | fnMatch ixSrc
         = do   VM.unsafeWrite vecDst ixDst (fnProduce ixSrc)
                fillChunk (ixSrc + 1) ixSrcEnd vecDst (ixDst + 1) ixDstLen

         -- The element doesn't match, so keep going.
         | otherwise
         =      fillChunk (ixSrc + 1) ixSrcEnd vecDst ixDst ixDstLen
|