File: TestUtils.hs

package info (click to toggle)
haskell-atomic-counter 0.1.2.4-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 144 kB
  • sloc: haskell: 490; makefile: 6
file content (89 lines) | stat: -rw-r--r-- 2,515 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
----------------------------------------------------------------------------
-- |
-- Module      :  TestUtils
-- Copyright   :  (c) Sergey Vinokurov 2022
-- License     :  Apache-2.0 (see LICENSE)
-- Maintainer  :  serg.foo@gmail.com
----------------------------------------------------------------------------

{-# LANGUAGE BangPatterns   #-}
{-# LANGUAGE DeriveGeneric  #-}
{-# LANGUAGE NamedFieldPuns #-}

module TestUtils
  ( Delay(..)
  , sleep
  , Iterations(..)
  , callN
  , Thread(..)
  , runThread
  , Threads(..)
  , spawnAndCall
  ) where

import Control.Concurrent
import Control.Concurrent.Async
import Control.Monad
import Control.Monad.IO.Class
import Data.Foldable
import Data.List.NonEmpty (NonEmpty(..))
import GHC.Generics (Generic)
import Test.QuickCheck

-- | Length of a pause, in microseconds (the unit 'threadDelay' expects).
newtype Delay = Delay
  { unDelay :: Int
  }
  deriving (Show, Eq)

-- | Pause the current thread for the given 'Delay'. A zero delay returns
-- immediately without calling 'threadDelay' at all.
sleep :: MonadIO m => Delay -> m ()
sleep (Delay micros)
  | micros == 0 = pure ()
  | otherwise   = liftIO (threadDelay micros)

-- | Inclusive bounds, in microseconds, for delays produced and shrunk by the
-- 'Arbitrary' instance of 'Delay'. Both 'arbitrary' and 'shrink' read this
-- single value so their ranges cannot drift apart.
delayRange :: (Int, Int)
delayRange = (0, 10)

-- | Generates short delays within 'delayRange' so property tests stay fast.
-- Shrinking keeps only candidates inside that same range, so every shrunk
-- value is one that 'arbitrary' could have produced.
instance Arbitrary Delay where
  arbitrary = Delay <$> chooseInt delayRange
  shrink    = map Delay . filter withinRange . shrink . unDelay
    where
      (lo, hi)      = delayRange
      withinRange x = lo <= x && x <= hi

-- | How many times an action should be repeated (see 'callN').
newtype Iterations = Iterations
  { unIterations :: Int
  }
  deriving (Show, Eq)

-- | Generates iteration counts from 0 to 50 inclusive; shrinking discards any
-- candidate that falls outside that same range.
instance Arbitrary Iterations where
  arbitrary = fmap Iterations (chooseInt (0, 50))
  shrink (Iterations count) =
    [ Iterations smaller
    | smaller <- shrink count
    , smaller >= 0
    , smaller <= 50
    ]

-- | Sequence @action@ the requested number of times, discarding its results.
-- A count of zero or less performs no effects.
callN :: Applicative m => Iterations -> m a -> m ()
callN (Iterations !count) action = replicateM_ count action

-- | Description of one simulated worker, as executed by 'runThread': it
-- performs 'tIterations' steps, each applying the caller's action to
-- 'tIncrement' and then the caller's sleep action to 'tDelay'.
data Thread = Thread
  { tDelay      :: Delay       -- ^ passed to the sleep action after every step
  , tIncrement  :: Int         -- ^ passed to the per-step action
  , tIterations :: Iterations  -- ^ number of steps to perform
  }
  deriving (Show, Eq, Generic)

-- | Generates a worker whose increment lies in [-1000, 1000]. Shrinking goes
-- through 'genericShrink' and drops any candidate whose increment magnitude
-- exceeds 1000.
instance Arbitrary Thread where
  arbitrary =
    Thread
      <$> arbitrary
      <*> chooseInt (-1000, 1000)
      <*> arbitrary
  shrink thread =
    [ smaller
    | smaller <- genericShrink thread
    , abs (tIncrement smaller) <= 1000
    ]

-- | Execute one 'Thread' description. It repeats the step "apply @f@ to
-- 'tIncrement', then apply @doSleep@ to 'tDelay'" 'tIterations' times and
-- discards the results of both callbacks.
--
-- Only 'Applicative' is needed (for '*>' and 'callN'), so this works in any
-- applicative. Every 'MonadIO' monad still qualifies, so existing callers
-- are unaffected.
runThread :: Applicative m => Thread -> (Delay -> m a) -> (Int -> m b) -> m ()
runThread Thread{tDelay, tIncrement, tIterations} doSleep f =
  callN tIterations (f tIncrement *> doSleep tDelay)

-- | A non-empty collection of 'Thread' descriptions.
newtype Threads = Threads
  { unThreads :: NonEmpty Thread
  }
  deriving (Show, Eq)

-- | Generates between 1 and 32 workers: one mandatory head plus 0 to 31
-- additional ones. Shrinking delegates to 'genericShrink' on the underlying
-- 'NonEmpty'.
instance Arbitrary Threads where
  arbitrary = do
    extraCount <- chooseInt (0, 31)
    let rest = replicateM extraCount arbitrary
    fmap Threads ((:|) <$> arbitrary <*> rest)
  shrink (Threads workers) = map Threads (genericShrink workers)

-- | Create a shared resource with @mkRes@, run @action res x@ concurrently for
-- every element @x@ of @threads@, wait until all of them finish, and return
-- the resource.
--
-- The workers run under 'Concurrently'. If any worker throws, the others are
-- cancelled and the exception is rethrown here, so no worker is left running
-- (and touching @res@) after 'spawnAndCall' has exited.
spawnAndCall :: Traversable f => f b -> IO a -> (a -> b -> IO ()) -> IO a
spawnAndCall threads mkRes action = do
  res <- mkRes
  runConcurrently $ traverse_ (Concurrently . action res) threads
  pure res