File: Selection.hs

package info (click to toggle)
haskell-repa 3.4.1.5-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 304 kB
  • sloc: haskell: 3,135; makefile: 2
file content (131 lines) | stat: -rw-r--r-- 5,030 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
{-# LANGUAGE BangPatterns, ExplicitForAll, ScopedTypeVariables, PatternGuards #-}
module Data.Array.Repa.Eval.Selection
        (selectChunkedS, selectChunkedP)
where
import Data.Array.Repa.Eval.Gang
import Data.Array.Repa.Shape
import Data.Vector.Unboxed                      as V
import Data.Vector.Unboxed.Mutable              as VM
import GHC.Base                                 (remInt, quotInt)
import Prelude                                  as P
import Control.Monad                            as P
import Data.IORef


-- | Sequentially select the indices of an extent that satisfy a predicate.
--
--   * Every linear index from @0@ up to (but not including) @size shSize@ is
--     converted to a shape index with 'fromIndex' and tested, in increasing
--     linear order.
--
--   * For each matching index the produced element is handed to the update
--     function together with that same (source) index.
--     NOTE(review): the update function receives the source index, not a
--     packed destination position — confirm this is what callers expect.
--
selectChunkedS
        :: Shape sh
        => (sh -> a -> IO ())   -- ^ Update function, called once per matching index.
        -> (sh -> Bool)         -- ^ Predicate deciding whether an index is selected.
        -> (sh -> a)            -- ^ Produces the element for a selected index.
        -> sh                   -- ^ Extent whose indices are tested.
        -> IO Int               -- ^ Number of indices that matched (one write each).

{-# INLINE selectChunkedS #-}
selectChunkedS fnWrite fnMatch fnProduce !shSize
 = go 0 0
 where  -- Total number of linear indices to visit.
        total   = size shSize

        -- pos:   next linear index to test.
        -- count: how many indices have matched so far.
        go !pos !count
         | pos >= total
         = return count

         | otherwise
         = case fromIndex shSize pos of
            ix | fnMatch ix
               -> do   fnWrite ix (fnProduce ix)
                       go (pos + 1) (count + 1)

               | otherwise
               -> go (pos + 1) count


-- | Select indices matching a predicate, in parallel.
--
--   * This primitive can be useful for writing filtering functions.
--
--   * The linear index range @[0 .. len - 1]@ is split into contiguous
--     chunks, with one chunk being given to each thread of 'theGang'.
--
--   * One result chunk is returned per gang thread, so the number of chunks
--     (and how the selected elements are distributed between them) depends
--     on how many threads you're running the program with.
--
--   * Each returned chunk is sliced down to exactly the elements that were
--     selected from its part of the index range; a thread that was given an
--     empty range yields an empty chunk.
--
selectChunkedP
        :: forall a
        .  Unbox a
        => (Int -> Bool)        -- ^ See if this predicate matches.
        -> (Int -> a)           -- ^  .. and apply fn to the matching index
        -> Int                  -- ^ Extent of indices to apply to predicate.
        -> IO [IOVector a]      -- ^ Chunks containing array elements.

{-# INLINE selectChunkedP #-}
selectChunkedP fnMatch fnProduce !len
 = do
        -- Make IORefs that the threads will write their result chunks to.
        -- They start out holding empty placeholder vectors: every thread
        -- allocates its own working chunk in makeChunk, so allocating real
        -- storage here would only produce garbage.
        refs    <- P.replicateM threads
                $ do    vec     <- VM.new 0
                        newIORef vec

        -- Fire off a thread to fill each chunk.
        -- The lookup with (!!) relies on the thread index being below
        -- 'threads' (the number of refs created above) — presumably gangIO
        -- guarantees this; verify against Data.Array.Repa.Eval.Gang.
        gangIO theGang
         $ \thread -> makeChunk (refs !! thread)
                        (splitIx thread)
                        (splitIx (thread + 1) - 1)

        -- Read the result chunks back from the IORefs.
        -- For every thread that had a non-empty range these are the vectors
        -- written by makeChunk, not the placeholders created above.
        P.mapM readIORef refs

 where  -- See how many threads we have available.
        !threads        = gangSize theGang
        !chunkLen       = len `quotInt` threads
        !chunkLeftover  = len `remInt`  threads


        -- Decide where to split the source array.
        -- The first chunkLeftover threads each take one extra index, so the
        -- remainder of the division is spread over the leading chunks.
        {-# INLINE splitIx #-}
        splitIx thread
         | thread < chunkLeftover = thread * (chunkLen + 1)
         | otherwise              = thread * chunkLen  + chunkLeftover


        -- Fill the given chunk with elements selected from this range of
        -- indices (both ends inclusive).
        makeChunk :: IORef (IOVector a) -> Int -> Int -> IO ()
        makeChunk !ref !ixSrc !ixSrcEnd
         -- Nothing to select from: keep the empty placeholder in the IORef.
         | ixSrc > ixSrcEnd
         = return ()

         -- Start with room for chunkLen elements; fillChunk grows the
         -- vector if more elements than that are selected.
         | otherwise
         = do  vecDst   <- VM.new chunkLen
               vecDst'  <- fillChunk ixSrc ixSrcEnd vecDst 0 (VM.length vecDst)
               writeIORef ref vecDst'


        -- The main filling loop.
        --   ixSrc, ixSrcEnd : remaining source range (inclusive).
        --   vecDst          : destination chunk being filled.
        --   ixDst           : next free position in vecDst.
        --   ixDstLen        : current capacity of vecDst.
        fillChunk :: Int -> Int -> IOVector a -> Int -> Int -> IO (IOVector a)
        fillChunk !ixSrc !ixSrcEnd !vecDst !ixDst !ixDstLen
         -- If we've finished selecting elements, then slice the vector down
         -- so it doesn't have any empty space at the end.
         | ixSrc > ixSrcEnd
         =      return  $ VM.slice 0 ixDst vecDst

         -- If we've run out of space in the chunk then grow it some more.
         -- VM.grow takes the number of elements to ADD, not the new total
         -- length: adding (length + 2) gives a capacity of (length + 1) * 2.
         -- The capacity passed on is read back from the grown vector so the
         -- tracked value always matches the real allocation.
         | ixDst >= ixDstLen
         = do   vecDst'         <- VM.grow vecDst (VM.length vecDst + 2)
                fillChunk ixSrc ixSrcEnd vecDst' ixDst (VM.length vecDst')

         -- We've got a matching element, so add it to the chunk.
         -- The unchecked write is in bounds because ixDst < ixDstLen here.
         | fnMatch ixSrc
         = do   VM.unsafeWrite vecDst ixDst (fnProduce ixSrc)
                fillChunk (ixSrc + 1) ixSrcEnd vecDst (ixDst + 1) ixDstLen

         -- The element doesn't match, so keep going.
         | otherwise
         =      fillChunk (ixSrc + 1) ixSrcEnd vecDst ixDst ixDstLen