{-# LANGUAGE CPP
, DeriveDataTypeable
, NoImplicitPrelude
, ImpredicativeTypes
, RankNTypes #-}
#if __GLASGOW_HASKELL__ >= 701
{-# LANGUAGE Trustworthy #-}
#endif
--------------------------------------------------------------------------------
-- |
-- Module : Control.Concurrent.Thread.Group
-- Copyright : (c) 2010-2012 Bas van Dijk & Roel van Dijk
-- License : BSD3 (see the file LICENSE)
-- Maintainer : Bas van Dijk <v.dijk.bas@gmail.com>
-- , Roel van Dijk <vandijk.roel@gmail.com>
--
-- This module extends @Control.Concurrent.Thread@ with the ability to wait for
-- a group of threads to terminate.
--
-- This module exports functions with the same names as functions in
-- @Control.Concurrent@ (and @GHC.Conc@) and in @Control.Concurrent.Thread@.
-- To avoid ambiguities, import this module qualified. We suggest:
--
-- @
-- import Control.Concurrent.Thread.Group ( ThreadGroup )
-- import qualified Control.Concurrent.Thread.Group as ThreadGroup ( ... )
-- @
--
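-- A minimal usage sketch (the doubling workers are purely illustrative):
-- fork a few threads into one group, wait for all of them to terminate, and
-- then collect their results.
--
-- > import qualified Control.Concurrent.Thread.Group as ThreadGroup
-- >
-- > main :: IO ()
-- > main = do
-- >   tg <- ThreadGroup.new
-- >   handles <- mapM (ThreadGroup.forkIO tg . return . (* 2)) [1 .. 5 :: Int]
-- >   ThreadGroup.wait tg
-- >   results <- mapM snd handles
-- >   print results  -- [Right 2,Right 4,Right 6,Right 8,Right 10]
--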
--------------------------------------------------------------------------------
module Control.Concurrent.Thread.Group
( ThreadGroup
, new
, nrOfRunning
, wait
, waitN
-- * Forking threads
, forkIO
, forkOS
, forkOn
, forkIOWithUnmask
, forkOnWithUnmask
) where
--------------------------------------------------------------------------------
-- Imports
--------------------------------------------------------------------------------
-- from base:
import qualified Control.Concurrent ( forkOS
                                    , forkIOWithUnmask
                                    , forkOnWithUnmask
                                    )
import Control.Concurrent ( ThreadId )
import Control.Concurrent.MVar ( newEmptyMVar, putMVar, readMVar )
import Control.Exception ( try, mask )
import Control.Monad ( return, (>>=), when )
import Data.Function ( (.), ($) )
import Data.Functor ( fmap )
import Data.Eq ( Eq )
import Data.Ord ( (>=) )
import Data.Int ( Int )
import Data.Typeable ( Typeable )
import Prelude ( ($!), (+), subtract )
import System.IO ( IO )
-- from stm:
import Control.Concurrent.STM.TVar ( TVar, newTVarIO, readTVar, writeTVar )
import Control.Concurrent.STM ( STM, atomically, retry )
-- from threads:
import Control.Concurrent.Thread ( Result )
import Control.Concurrent.Raw ( rawForkIO, rawForkOn )
#ifdef __HADDOCK_VERSION__
import qualified Control.Concurrent.Thread as Thread ( forkIO
                                                     , forkOS
                                                     , forkOn
                                                     , forkIOWithUnmask
                                                     , forkOnWithUnmask
                                                     )
#endif
--------------------------------------------------------------------------------
-- * Thread groups
--------------------------------------------------------------------------------
{-| A @ThreadGroup@ can be understood as a counter which counts the number of
threads that were added to the group minus the ones that have terminated.

More formally, a @ThreadGroup@ has the following semantics:

* 'new' initializes the counter to 0.

* Forking a thread increments the counter.

* When a forked thread terminates, whether normally or by raising an exception,
  the counter is decremented.

* 'nrOfRunning' yields a transaction that returns the counter.

* 'wait' blocks as long as the counter is greater than 0.

* 'waitN' blocks as long as the counter is greater than or equal to the
  specified number.
-}
newtype ThreadGroup = ThreadGroup (TVar Int) deriving (Eq, Typeable)
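
-- Example (a sketch, not part of the API): the counter semantics above can be
-- observed with 'nrOfRunning'. The gate MVar only keeps the forked thread
-- alive until we have looked at the counter.
--
-- > import Control.Concurrent.MVar ( newEmptyMVar, putMVar, takeMVar )
-- > import Control.Concurrent.STM ( atomically )
-- > import qualified Control.Concurrent.Thread.Group as ThreadGroup
-- >
-- > main :: IO ()
-- > main = do
-- >   tg <- ThreadGroup.new
-- >   atomically (ThreadGroup.nrOfRunning tg) >>= print  -- 0
-- >   gate <- newEmptyMVar
-- >   _ <- ThreadGroup.forkIO tg (takeMVar gate)
-- >   atomically (ThreadGroup.nrOfRunning tg) >>= print  -- 1
-- >   putMVar gate ()
-- >   ThreadGroup.wait tg
-- >   atomically (ThreadGroup.nrOfRunning tg) >>= print  -- 0
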
-- | Create an empty group of threads.
new :: IO ThreadGroup
new = fmap ThreadGroup $ newTVarIO 0
{-| Yield a transaction that returns the number of running threads in the
group.

Because this function yields an 'STM' computation, the returned number is
consistent with every other read performed inside the same transaction.
-}
nrOfRunning :: ThreadGroup -> STM Int
nrOfRunning (ThreadGroup numThreadsTV) = readTVar numThreadsTV
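
-- Because 'nrOfRunning' is an 'STM' action, it composes with other
-- transactions. For instance, a sketch of a hypothetical helper (not exported
-- by this module) that blocks until two groups together run fewer than @k@
-- threads, reading both counters in one consistent snapshot:
--
-- > import Control.Concurrent.STM ( atomically, retry )
-- > import Control.Monad ( when )
-- > import Control.Concurrent.Thread.Group ( ThreadGroup )
-- > import qualified Control.Concurrent.Thread.Group as ThreadGroup
-- >
-- > waitTotalBelow :: Int -> ThreadGroup -> ThreadGroup -> IO ()
-- > waitTotalBelow k tg1 tg2 = atomically $ do
-- >   n1 <- ThreadGroup.nrOfRunning tg1
-- >   n2 <- ThreadGroup.nrOfRunning tg2
-- >   -- Block (via retry) until one of the counters changes, then re-check.
-- >   when (n1 + n2 >= k) retry
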
-- | Block until all threads in the group have terminated.
--
-- Note that @wait = 'waitN' 1@.
wait :: ThreadGroup -> IO ()
wait = waitN 1
-- | Block until there are fewer than @N@ running threads in the group, where
-- @N@ is the first argument.
waitN :: Int -> ThreadGroup -> IO ()
waitN i tg = atomically $ nrOfRunning tg >>= \n -> when (n >= i) retry
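
-- A sketch of using 'waitN' to bound concurrency. The helper
-- @boundedForM_@ is hypothetical (not exported by this module); it assumes
-- @k >= 1@ and that it is the only thread forking into its group, so the
-- counter can only drop between 'waitN' returning and the next fork.
-- Exceptions raised by @act@ are captured in each thread's 'Result' and
-- ignored here.
--
-- > import Control.Monad ( forM_ )
-- > import qualified Control.Concurrent.Thread.Group as ThreadGroup
-- >
-- > boundedForM_ :: Int -> [a] -> (a -> IO ()) -> IO ()
-- > boundedForM_ k xs act = do
-- >   tg <- ThreadGroup.new
-- >   forM_ xs $ \x -> do
-- >     ThreadGroup.waitN k tg          -- wait until fewer than k are running
-- >     _ <- ThreadGroup.forkIO tg (act x)
-- >     return ()
-- >   ThreadGroup.wait tg               -- wait for the remaining threads
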
--------------------------------------------------------------------------------
-- * Forking threads
--------------------------------------------------------------------------------
-- | Same as @Control.Concurrent.Thread.'Thread.forkIO'@ but additionally adds
-- the thread to the group.
forkIO :: ThreadGroup -> IO a -> IO (ThreadId, IO (Result a))
forkIO = fork rawForkIO
-- | Same as @Control.Concurrent.Thread.'Thread.forkOS'@ but additionally adds
-- the thread to the group.
forkOS :: ThreadGroup -> IO a -> IO (ThreadId, IO (Result a))
forkOS = fork Control.Concurrent.forkOS
-- | Same as @Control.Concurrent.Thread.'Thread.forkOn'@ but
-- additionally adds the thread to the group.
forkOn :: Int -> ThreadGroup -> IO a -> IO (ThreadId, IO (Result a))
forkOn = fork . rawForkOn
-- | Same as @Control.Concurrent.Thread.'Thread.forkIOWithUnmask'@ but
-- additionally adds the thread to the group.
forkIOWithUnmask
  :: ThreadGroup
  -> ((forall b. IO b -> IO b) -> IO a)
  -> IO (ThreadId, IO (Result a))
forkIOWithUnmask = forkWithUnmask Control.Concurrent.forkIOWithUnmask
-- | Same as @Control.Concurrent.Thread.'Thread.forkOnWithUnmask'@ but
-- additionally adds the thread to the group.
forkOnWithUnmask
  :: Int
  -> ThreadGroup
  -> ((forall b. IO b -> IO b) -> IO a)
  -> IO (ThreadId, IO (Result a))
forkOnWithUnmask = forkWithUnmask . Control.Concurrent.forkOnWithUnmask
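
-- A sketch of the usual 'forkIOWithUnmask' pattern. The forked action runs
-- with the caller's masking state, so the caller masks first; the thread body
-- then starts masked and only the part wrapped in @unmask@ can receive
-- asynchronous exceptions. The helper @spawn@ and its arguments are
-- hypothetical:
--
-- > import Control.Concurrent ( ThreadId )
-- > import Control.Exception ( finally, mask_ )
-- > import Control.Concurrent.Thread.Group ( ThreadGroup )
-- > import qualified Control.Concurrent.Thread.Group as ThreadGroup
-- >
-- > -- @work@ may be interrupted (e.g. by killThread); @cleanup@ runs
-- > -- however @work@ ends.
-- > spawn :: ThreadGroup -> IO () -> IO () -> IO ThreadId
-- > spawn tg work cleanup =
-- >   fmap fst $ mask_ $
-- >     ThreadGroup.forkIOWithUnmask tg (\unmask -> unmask work `finally` cleanup)
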
--------------------------------------------------------------------------------
-- Utils
--------------------------------------------------------------------------------
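fork :: (IO () -> IO ThreadId)
     -> ThreadGroup
     -> IO a
     -> IO (ThreadId, IO (Result a))
fork doFork (ThreadGroup numThreadsTV) a = do
  -- The MVar that will receive the thread's result (or exception).
  res <- newEmptyMVar
  -- Increment the counter and fork with asynchronous exceptions masked, so
  -- that no exception can arrive between the increment and the fork. The
  -- child inherits the masked state; 'restore' runs the user's action with
  -- the caller's original masking state.
  tid <- mask $ \restore -> do
    atomically $ modifyTVar numThreadsTV (+ 1)
    doFork $ do
      -- Store the result before decrementing, so that once 'wait' returns
      -- every result in the group is available.
      try (restore a) >>= putMVar res
      atomically $ modifyTVar numThreadsTV (subtract 1)
  return (tid, readMVar res)

forkWithUnmask
  :: (((forall b. IO b -> IO b) -> IO ()) -> IO ThreadId)
  -> ThreadGroup
  -> ((forall b. IO b -> IO b) -> IO a)
  -> IO (ThreadId, IO (Result a))
forkWithUnmask doForkWithUnmask = \(ThreadGroup numThreadsTV) f -> do
  -- Same protocol as 'fork', but the user's function also receives the
  -- unmasking function supplied by the underlying fork primitive.
  res <- newEmptyMVar
  tid <- mask $ \restore -> do
    atomically $ modifyTVar numThreadsTV (+ 1)
    doForkWithUnmask $ \unmask -> do
      try (restore $ f unmask) >>= putMVar res
      atomically $ modifyTVar numThreadsTV (subtract 1)
  return (tid, readMVar res)
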
-- | Strictly modify the contents of a 'TVar'.
modifyTVar :: TVar a -> (a -> a) -> STM ()
modifyTVar tv f = readTVar tv >>= writeTVar tv .! f
-- | Strict function composition
(.!) :: (b -> c) -> (a -> b) -> (a -> c)
f .! g = \x -> f $! g x
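
-- For comparison: newer versions of stm (2.3 and later) export a strict
-- 'modifyTVar'' from Control.Concurrent.STM.TVar that, like the local helper
-- above, forces the new value to weak head normal form before writing it, so
-- repeated increments do not build up a chain of thunks. A sketch:
--
-- > import Control.Concurrent.STM ( atomically )
-- > import Control.Concurrent.STM.TVar ( newTVarIO, readTVarIO, modifyTVar' )
-- > import Control.Monad ( replicateM_ )
-- >
-- > main :: IO ()
-- > main = do
-- >   tv <- newTVarIO (0 :: Int)
-- >   replicateM_ 1000 (atomically (modifyTVar' tv (+ 1)))
-- >   readTVarIO tv >>= print  -- 1000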