1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
|
----------------------------------------------------------------------------
-- |
-- Module : TestUtils
-- Copyright : (c) Sergey Vinokurov 2022
-- License : Apache-2.0 (see LICENSE)
-- Maintainer : serg.foo@gmail.com
----------------------------------------------------------------------------
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE NamedFieldPuns #-}
module TestUtils
( Delay(..)
, sleep
, Iterations(..)
, callN
, Thread(..)
, runThread
, Threads(..)
, spawnAndCall
) where
import Control.Concurrent
import Control.Concurrent.Async
import Control.Monad
import Control.Monad.IO.Class
import Data.Foldable
import Data.List.NonEmpty (NonEmpty(..))
import GHC.Generics (Generic)
import Test.QuickCheck
-- | A pause length in microseconds, i.e. the unit that 'threadDelay'
-- expects; 'sleep' hands the wrapped 'Int' to it unchanged.
newtype Delay = Delay {unDelay :: Int}
  deriving (Show, Eq)
-- | Pause the current thread for the given 'Delay'.
--
-- A delay of exactly zero short-circuits to @pure ()@ without calling
-- 'threadDelay' at all; any other value is passed to 'threadDelay' as is.
sleep :: MonadIO m => Delay -> m ()
sleep (Delay micros)
  | micros == 0 = pure ()
  | otherwise   = liftIO (threadDelay micros)
-- | Random delays are drawn uniformly from 0..10 microseconds.
-- Shrink candidates come from shrinking the underlying 'Int', keeping only
-- those in the range 0..25.
--
-- NOTE(review): the shrink upper bound (25) differs from the generation
-- upper bound (10). This is harmless for generated values, whose shrinks are
-- never larger than the original, but confirm which bound is intended.
instance Arbitrary Delay where
  arbitrary = fmap Delay (chooseInt (0, 10))
  shrink (Delay micros) =
    [ Delay candidate
    | candidate <- shrink micros
    , candidate >= 0
    , candidate <= 25
    ]
-- | How many times an action should be repeated. See 'callN', which treats
-- zero or negative counts as "do nothing".
newtype Iterations = Iterations {unIterations :: Int}
  deriving (Show, Eq)
-- | Random repetition counts are drawn uniformly from 0..50.
-- Shrinking shrinks the underlying 'Int' and keeps candidates within 0..50.
instance Arbitrary Iterations where
  arbitrary = fmap Iterations (chooseInt (0, 50))
  shrink (Iterations count) =
    [ Iterations candidate
    | candidate <- shrink count
    , candidate >= 0
    , candidate <= 50
    ]
-- | Sequence @action@ the given number of times, discarding every result.
--
-- A count of zero or less performs no effects and yields @pure ()@.
-- The count is forced as soon as 'callN' is applied.
callN :: Applicative m => Iterations -> m a -> m ()
callN (Iterations !count) = replicateM_ count
-- | Description of one simulated worker, as executed by 'runThread'.
data Thread = Thread
  { tDelay      :: Delay
    -- ^ Passed to the sleep callback after every call of the worker action.
  , tIncrement  :: Int
    -- ^ Argument handed to the worker action on every iteration.
  , tIterations :: Iterations
    -- ^ Number of (action, sleep) rounds to perform.
  }
  deriving (Show, Eq, Generic)
-- | Random threads combine an arbitrary 'Delay', an increment in
-- -1000..1000 and arbitrary 'Iterations'. Shrinking uses the generic,
-- field-wise shrinker and discards candidates whose increment falls outside
-- -1000..1000.
instance Arbitrary Thread where
  arbitrary =
    Thread
      <$> arbitrary
      <*> chooseInt (-1000, 1000)
      <*> arbitrary
  shrink thread =
    [ candidate
    | candidate <- genericShrink thread
    , abs (tIncrement candidate) <= 1000
    ]
-- | Execute one simulated 'Thread'.
--
-- Each of the thread's 'tIterations' rounds first calls @f@ with the
-- thread's 'tIncrement' and then calls @doSleep@ with its 'tDelay'.
-- The results of both callbacks are discarded. Zero or negative iteration
-- counts perform nothing, as documented for 'callN'.
--
-- Only 'Applicative' sequencing is required, so any 'MonadIO' monad works.
-- Pure monads work too, which is handy for checking the exact call sequence
-- deterministically.
runThread :: Applicative m => Thread -> (Delay -> m a) -> (Int -> m b) -> m ()
runThread Thread{tDelay, tIncrement, tIterations} doSleep f =
  callN tIterations (f tIncrement *> doSleep tDelay)
-- | A batch of simulated workers; the 'NonEmpty' wrapper guarantees there is
-- always at least one.
newtype Threads = Threads {unThreads :: NonEmpty Thread}
  deriving (Show, Eq)
-- | Random batches contain one mandatory thread plus 0..31 extra ones,
-- i.e. between 1 and 32 threads in total. Shrinking applies the generic
-- shrinker to the underlying 'NonEmpty'.
instance Arbitrary Threads where
  arbitrary = do
    extra <- chooseInt (0, 31)
    fmap Threads ((:|) <$> arbitrary <*> replicateM extra arbitrary)
  shrink (Threads ts) = Threads <$> genericShrink ts
-- | Build a shared resource with @mkRes@ and run @action res t@ for every
-- element @t@ of @threads@. Each action runs in its own thread. The function
-- waits for all actions to finish and then returns the resource, so the
-- caller can inspect it.
--
-- 'mapConcurrently_' is used instead of forking with 'async' and then
-- 'wait'ing on each handle in order. If any action throws, the remaining
-- threads are cancelled and the exception is re-thrown. With the manual
-- approach, surviving threads would keep running in the background, and an
-- exception from a later thread would go unnoticed while an earlier thread
-- was still blocked.
spawnAndCall :: Traversable f => f b -> IO a -> (a -> b -> IO ()) -> IO a
spawnAndCall threads mkRes action = do
  res <- mkRes
  mapConcurrently_ (action res) threads
  pure res
|