1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
|
-- |
-- Module : Foundation.Random.XorShift
-- License : BSD-style
--
-- XorShift variant: Xoroshiro128+
-- <https://en.wikipedia.org/wiki/Xoroshiro128%2B>
--
-- C implementation at:
-- <http://xoroshiro.di.unimi.it/xoroshiro128plus.c>
--
{-# LANGUAGE MagicHash #-}
module Foundation.Random.XorShift
( State
, initialize
, next
, nextList
, nextDouble
) where
import Basement.Imports
import Basement.PrimType
import Basement.Types.OffsetSize
import Foundation.Numerical
import Foundation.Bits
import Foundation.Random.Class
import Foundation.Random.DRG
import Basement.Compat.Bifunctor
import Basement.Compat.ExtList (reverse)
import qualified Basement.UArray as A
import qualified Prelude
import GHC.Prim
import GHC.Float
-- | Internal state of the Xoroshiro128+ generator: two 64-bit words,
-- held strictly and unpacked into the constructor.
--
-- Note: when both words are zero, 'next' maps the state to itself and
-- always outputs 0, so an all-zero state should never be used as a seed.
data State
    = State {-# UNPACK #-} !Word64 -- ^ first state word (s0)
            {-# UNPACK #-} !Word64 -- ^ second state word (s1)
-- | 'RandomGen' instance backed by Xoroshiro128+.
--
-- * 'randomNew' seeds both state words from 'getRandomWord64'.
-- * 'randomNewFrom' accepts exactly 16 bytes, reinterpreted as two 'Word64'
--   values via 'A.recast' (presumably in host byte order — confirm);
--   any other length yields 'Nothing'.
-- * The remaining methods delegate to 'generate', 'next', 'nextFloat'
--   and 'nextDouble'.
instance RandomGen State where
    randomNew = initialize <$> getRandomWord64 <*> getRandomWord64

    randomNewFrom bs
        | A.length bs == 16 = Just (State (A.index ws 0) (A.index ws 1))
        | otherwise         = Nothing
      where
        -- only forced on the 16-byte branch
        ws = A.recast bs

    randomGenerate c st     = generate c st
    randomGenerateWord64 st = next st
    randomGenerateF32 st    = nextFloat st
    randomGenerateF64 st    = nextDouble st
-- | Build a generator state directly from two 64-bit seed words.
--
-- The words are stored verbatim, with no mixing. If both are zero the
-- generator stays stuck at zero forever, so choose seeds accordingly.
initialize :: Word64 -> Word64 -> State
initialize = State
-- | Generate exactly @c@ pseudo-random bytes.
--
-- The byte count is rounded up to a whole number of 'Word64's, and that
-- many words are drawn with 'nextList'. The packed words are then
-- reinterpreted as bytes and truncated back to @c@.
generate :: CountOf Word8 -> State -> (UArray Word8, State)
generate c st = (A.take c (A.unsafeRecast (fromList ws)), st')
  where
    wordCount = sizeRecast (countOfRoundUp 8 c)
    (ws, st') = nextList wordCount st
-- | Advance the generator by one step.
--
-- The output is the sum of the two current state words (Word64 addition,
-- which wraps). The new state follows the Xoroshiro128+ update with
-- rotate/shift constants 55, 14 and 36.
next :: State -> (Word64, State)
next (State a b) =
    let !t = a `xor` b
        a' = rotateL a 55 `xor` t `xor` (t .<<. 14)
        b' = rotateL t 36
     in (a + b, State a' b')
-- | Draw @c@ successive outputs of 'next'.
--
-- Returns the outputs in generation order, together with the state after
-- the last draw.
nextList :: CountOf Word64 -> State -> ([Word64], State)
nextList c = go 0 []
  where
    -- i: how many words have been drawn so far
    -- ws: the words drawn, newest first
    go i ws g
        | i .==# c  = (reverse ws, g)
        | otherwise = case next g of
            ~(w, g') -> go (i + 1) (w : ws) g'
-- | Draw a uniformly distributed 'Float' in the half-open range [0, 1).
--
-- The top 23 bits of one 'next' output become the mantissa of a
-- single-precision value whose exponent field is that of 1.0, giving a
-- number in [1, 2). Subtracting 1.0 (an exact operation) maps it to
-- [0, 1 - 2^-23].
--
-- The Float is built directly rather than narrowed from a [0, 1) 'Double'.
-- Round-to-nearest would turn every Double in [1 - 2^-25, 1) into exactly
-- 1.0, which would break the upper bound.
nextFloat :: State -> (Float, State)
nextFloat !st = (f - 1.0, st')
  where
    !(w, st') = next st
    -- IEEE-754 binary32 encoding of 1.0: sign 0, biased exponent 127, mantissa 0
    exponentOne = 0x3F800000 :: Word32
    -- top 23 bits of the 64-bit output; always < 2^23, so the narrowing is lossless
    mantissa :: Word32
    mantissa = Prelude.fromIntegral (w .>>. 41)
    -- reinterpret the bits, do not convert numerically
    f :: Float
    f = castWord32ToFloat (exponentOne .|. mantissa)
-- | Draw a uniformly distributed 'Double' in the half-open range [0, 1).
--
-- The low 52 bits of one 'next' output become the mantissa of a double
-- whose exponent field is that of 1.0, giving a value in [1, 2).
-- Subtracting 1.0 (an exact operation) maps it to [0, 1 - 2^-52].
--
-- The 64-bit pattern must be reinterpreted as a 'Double' with
-- 'castWord64ToDouble'. A numeric conversion such as 'fromIntegral' would
-- instead yield roughly 4.6e18 (the integer value of 0x3FF0000000000000).
nextDouble :: State -> (Double, State)
nextDouble !st = (d - 1.0, st')
  where
    !(w, st') = next st
    -- IEEE-754 binary64 encoding of 1.0: sign 0, biased exponent 1023, mantissa 0
    exponentOne  = 0x3FF0000000000000 :: Word64
    -- the 52 mantissa bits
    mantissaMask = 0x000FFFFFFFFFFFFF :: Word64
    d :: Double
    d = castWord64ToDouble (exponentOne .|. (w .&. mantissaMask))
|