File: CondensedTable.hs

package info (click to toggle)
haskell-mwc-random 0.15.1.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 388 kB
  • sloc: haskell: 1,333; makefile: 2
file content (285 lines) | stat: -rw-r--r-- 10,618 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
{-# LANGUAGE FlexibleContexts #-}
-- |
-- Module    : System.Random.MWC.CondensedTable
-- Copyright : (c) 2012 Aleksey Khudyakov
-- License   : BSD3
--
-- Maintainer  : bos@serpentine.com
-- Stability   : experimental
-- Portability : portable
--
-- Table-driven generation of random variates.  This approach can
-- generate random variates in /O(1)/ time for the supported
-- distributions, at a modest cost in initialization time.
module System.Random.MWC.CondensedTable (
    -- * Condensed tables
    CondensedTable
  , CondensedTableV
  , CondensedTableU
  , genFromTable
    -- * Constructors for tables
  , tableFromProbabilities
  , tableFromWeights
  , tableFromIntWeights
    -- ** Discrete distributions
  , tablePoisson
  , tableBinomial
    -- * References
    -- $references
  ) where

import Control.Arrow           (second,(***))

import Data.Word
import Data.Int
import Data.Bits
import qualified Data.Vector.Generic         as G
import           Data.Vector.Generic           ((++))
import qualified Data.Vector.Generic.Mutable as M
import qualified Data.Vector.Unboxed         as U
import qualified Data.Vector                 as V
import Data.Vector.Generic (Vector)
import Numeric.SpecFunctions (logFactorial)
import System.Random.Stateful

import Prelude hiding ((++))



-- | A lookup table for arbitrary discrete distributions. It allows
-- the generation of random variates in /O(1)/. Note that probability
-- is quantized in units of @1/2^32@, and all distributions with
-- infinite support (e.g. Poisson) should be truncated.
--
-- The constructor stores four outcome tables together with three
-- cumulative limits. A sample below the first limit is looked up in
-- the first table, otherwise a sample below the second limit is
-- looked up in the second table, and so on; any sample at or above
-- the third limit is looked up in the last table.
data CondensedTable v a =
  CondensedTable
  {-# UNPACK #-} !Word64 !(v a) -- Upper limit for, and contents of, the first table
  {-# UNPACK #-} !Word64 !(v a) -- Upper limit for, and contents of, the second table
  {-# UNPACK #-} !Word64 !(v a) -- Upper limit for, and contents of, the third table
  !(v a)                        -- Last table, used for all remaining samples

-- Implementation note. We have to store the lookup limit in Word64
-- since we need to accommodate two cases. The first is when we have
-- no values in the lookup table, the second is when all elements are
-- there.
--
-- Both are pretty easy to realize. For the first one the probability
-- of every outcome should be less than 1/256; the latter arises when
-- the probabilities of two outcomes are [0.5,0.5].

-- | A 'CondensedTable' whose outcome tables are unboxed vectors
-- ("Data.Vector.Unboxed").
type CondensedTableU = CondensedTable U.Vector

-- | A 'CondensedTable' whose outcome tables are boxed vectors
-- ("Data.Vector"), and which is therefore able to hold any type of
-- element.
type CondensedTableV = CondensedTable V.Vector



-- | Draw a single random outcome from a condensed table.
--
-- One uniformly distributed 'Word32' is taken from the generator,
-- widened, and used as the key for the table lookup. The result is
-- forced before it is returned.
genFromTable :: (StatefulGen g m, Vector v a) => CondensedTable v a -> g -> m a
{-# INLINE genFromTable #-}
genFromTable table gen =
  uniformM gen >>= \w ->
    return $! lookupTable table (fromIntegral (w :: Word32))

-- Pick the outcome that corresponds to a sample. The three limits
-- are tried in order; the first one the sample falls below selects
-- the table, and the offset into that table's range is shifted right
-- by 24, 16 or 8 bits to obtain the index. Samples at or above the
-- third limit index the last table directly by their offset.
--
-- Indexing is unchecked, so the table must have been built
-- consistently with its limits.
lookupTable :: Vector v a => CondensedTable v a -> Word64 -> a
{-# INLINE lookupTable #-}
lookupTable (CondensedTable lim1 tbl1 lim2 tbl2 lim3 tbl3 tbl4) key
  | key < lim1 = pick tbl1 (key `shiftR` 24)
  | key < lim2 = pick tbl2 ((key - lim1) `shiftR` 16)
  | key < lim3 = pick tbl3 ((key - lim2) `shiftR` 8)
  | otherwise  = pick tbl4 (key - lim3)
  where
    pick vec ix = G.unsafeIndex vec (fromIntegral ix)


----------------------------------------------------------------
-- Table generation
----------------------------------------------------------------

-- | Build a condensed lookup table from outcomes paired with their
-- probabilities. The vector should be non-empty and the probabilities
-- should be non-negative and sum to 1. If that is not the case, the
-- resulting table describes some distribution that may bear no
-- resemblance to the intended one.
--
-- Entries with non-positive probability are dropped; if nothing is
-- left afterwards, 'error' is called.
tableFromProbabilities
    :: (Vector v (a,Word32), Vector v (a,Double), Vector v a, Vector v Word32)
       => v (a, Double) -> CondensedTable v a
{-# INLINE tableFromProbabilities #-}
tableFromProbabilities v
  | G.null positive = pkgError "tableFromProbabilities" "empty vector of outcomes"
  | otherwise       = tableFromIntWeights (G.map quantize positive)
  where
    -- Only outcomes with strictly positive probability are kept.
    positive = G.filter (\(_, p) -> p > 0) v
    -- 2^32. N.B. This number is exactly representable as a Double.
    scale = 4.294967296e9
    -- Rescale a probability to units of 1/2^32.
    quantize (x, p) = (x, clamp (p * scale))
    -- Round to Word32 while guarding against overflow, which is a
    -- real danger when one probability is approximately 1 and the
    -- others are 0.
    clamp w
      | w > scale - 1 = 2^(32::Int) - 1
      | otherwise     = round w


-- | Same as 'tableFromProbabilities' but treats the numbers as
-- weights rather than probabilities. Non-positive weights are
-- discarded, and the remaining ones are normalized so that they sum
-- to 1. Calls 'error' if no positive weight is left.
tableFromWeights
    :: (Vector v (a,Word32), Vector v (a,Double), Vector v a, Vector v Word32)
       => v (a, Double) -> CondensedTable v a
{-# INLINE tableFromWeights #-}
tableFromWeights =
  tableFromProbabilities . rescale . G.filter (\(_, w) -> w > 0)
  where
    rescale ws
      | G.null ws = pkgError "tableFromWeights" "no positive weights"
      | otherwise = G.map (\(x, w) -> (x, w / total)) ws
      where
        -- Summed with an explicit fold over the pairs so that no
        -- 'Vector v Double' constraint is required.
        total = G.foldl' (\acc (_, w) -> w + acc) 0 ws


-- | Generate a condensed lookup table from integer weights. Weights
-- should sum to @2^32@ at least approximately. This function will
-- correct small deviations from @2^32@ such as those arising from
-- rounding errors. But for large deviations it is likely to produce
-- an incorrect result with terrible performance.
--
-- Outcomes with zero weight are discarded. Calls 'error' if no
-- outcome with nonzero weight remains.
tableFromIntWeights :: (Vector v (a,Word32), Vector v a, Vector v Word32)
                    => v (a, Word32)
                    -> CondensedTable v a
{-# INLINE tableFromIntWeights #-}
tableFromIntWeights v
  | n == 0    = pkgError "tableFromIntWeights" "empty table"
    -- Single element tables should be treated separately. Otherwise
    -- they will confuse correctWeights.
    --
    -- The limit has to be 2^32, not 2^32-1: samples cover the whole
    -- Word32 range, so with a limit of 2^32-1 the sample 0xFFFFFFFF
    -- would fall through every comparison in lookupTable and index
    -- the empty last table. 2^32 fits because limits are Word64.
  | n == 1    = let m = 2^(32::Int)
                in CondensedTable
                   m (G.replicate 256 $ fst $ G.head tbl)
                   m  G.empty
                   m  G.empty
                      G.empty
  | otherwise = CondensedTable
                na aa
                nb bb
                nc cc
                   dd
  where
    -- We must filter out zero-probability outcomes because they may
    -- confuse the weight correction algorithm
    tbl   = G.filter ((/=0) . snd) v
    n     = G.length tbl
    -- Table with weights corrected to sum to 2^32
    table = uncurry G.zip $ id *** correctWeights $ G.unzip tbl
    -- Make the d-th condensed table: every outcome is repeated as
    -- many times as the d-th base-256 digit of its weight
    mkTable  d =
      G.concatMap (\(x,w) -> G.replicate (fromIntegral $ digit d w) x) table
    len = fromIntegral . G.length
    -- Tables
    aa = mkTable 0
    bb = mkTable 1
    cc = mkTable 2
    dd = mkTable 3
    -- Cumulative limits: an entry of the first table covers 2^24
    -- samples, of the second 2^16 and of the third 2^8
    na =       len aa `shiftL` 24
    nb = na + (len bb `shiftL` 16)
    nc = nb + (len cc `shiftL` 8)


-- Extract the N'th base-256 digit of a 32-bit word, where digit 0 is
-- the most significant byte and digit 3 the least significant one.
-- Any other digit number is an internal error.
digit :: Int -> Word32 -> Word32
digit k x = case k of
  0 -> x `shiftR` 24
  1 -> byteAt 16
  2 -> byteAt 8
  3 -> byteAt 0
  _ -> pkgError "digit" "the impossible happened!?"
  where
    byteAt s = (x `shiftR` s) .&. 0xff
{-# INLINE digit #-}

-- Adjust integer weights so that they sum up to 2^32. The array of
-- weights should contain at least 2 elements.
correctWeights :: G.Vector v Word32 => v Word32 -> v Word32
{-# INLINE correctWeights #-}
correctWeights weights = G.create $ do
  buf <- G.thaw weights
  let
    -- Number of weights
    size  = G.length weights
    -- Sum of the weights, accumulated in a wider type
    total = G.foldl' (\acc w -> fromIntegral w + acc) 0 weights :: Int64
    -- Walk over the array changing eligible entries by one until the
    -- excess is used up. During the first pass only entries that are
    -- at least `threshold' (255) are touched; on the second and
    -- subsequent passes the threshold is 1.
    --
    -- It is possible to make this algorithm loop endlessly if all
    -- weights are 1 or 0.
    go threshold ix excess
      | excess == 0 = return ()
      | ix >= size  = go 1 0 excess
      | otherwise   = do
          w <- M.read buf ix
          if w < threshold
            then go threshold (ix + 1) excess
            else if excess < 0
              then M.write buf ix (w + 1) >> go threshold (ix + 1) (excess + 1)
              else M.write buf ix (w - 1) >> go threshold (ix + 1) (excess - 1)
  go 255 0 (total - 2^(32::Int))
  return buf


-- | Create a lookup table for the Poisson distribution with the
-- given λ. Note that table construction may have significant cost.
-- For λ < 100 it takes as much time to build the table as the
-- generation of 1000-30000 variates.
--
-- Calls 'error' if λ is negative.
tablePoisson :: Double -> CondensedTableU Int
tablePoisson lambda
  | lambda < 0 = pkgError "tablePoisson" "negative lambda"
  | otherwise  = tableFromProbabilities outcomes
  where
    -- For small λ start at 0 and walk upwards. Otherwise start at
    -- the most probable value and walk in both directions; the
    -- downward walk repeats the starting point, hence the 'U.tail'.
    outcomes
      | lambda < 22.8 = U.unfoldr ascend (exp (negate lambda), 0)
      | otherwise     = U.unfoldr ascend (peakP, peakN)
                     ++ U.tail (U.unfoldr descend (peakP, peakN))
    -- Number with the highest probability and its probability
    --
    -- FIXME: this is not ideal precision-wise. Check if code
    --        from statistics gives better precision.
    peakN = floor lambda :: Int
    peakP = exp $ fromIntegral peakN * log lambda - lambda - logFactorial peakN
    -- Emit (k, p) and move on to k+1 until the probability drops
    -- below the cutoff
    ascend (p, k)
      | p < cutoff = Nothing
      | otherwise  = Just ((k, p), (p * lambda / fromIntegral (k + 1), k + 1))
    -- Emit (k, p) and move on to k-1 until the probability drops
    -- below the cutoff
    descend (p, k)
      | p < cutoff = Nothing
      | otherwise  = Just ((k, p), (p / lambda * fromIntegral k, k - 1))
    -- Minimal representable probability for condensed tables
    cutoff = 1.1641532182693481e-10 -- 2**(-33)

-- | Create a lookup table for the binomial distribution.
--
-- Calls 'error' if the number of tries is not positive or if the
-- probability of success is not in the range [0,1].
tableBinomial :: Int            -- ^ Number of tries
              -> Double         -- ^ Probability of success
              -> CondensedTableU Int
tableBinomial n p = tableFromProbabilities outcomes
  where
    outcomes
      | n <= 0         = pkgError "tableBinomial" "non-positive number of tries"
      | p == 0         = U.singleton (0, 1)
      | p == 1         = U.singleton (n, 1)
      | 0 < p && p < 1 = U.unfoldrN (n + 1) step ((1 - p) ^ n, 0)
      | otherwise      = pkgError "tableBinomial" "probability is out of range"
    -- Odds of success
    odds = p / (1 - p)
    -- Emit the probability of k successes and derive the one for
    -- k+1 successes from it
    step (prob, k) =
      Just ( (k, prob)
           , (prob * fromIntegral (n + 1 - next) * odds / fromIntegral next, next)
           )
      where
        next = k + 1

-- Abort with an error message that is prefixed with the name of this
-- module and of the function that reports the problem.
pkgError :: String -> String -> a
pkgError func err =
  error
    . showString "System.Random.MWC.CondensedTable."
    . showString func
    . showString ": "
    $ err

-- $references
--
-- * Wang, J.; Tsang, W. W.; G. Marsaglia (2004), Fast Generation of
--   Discrete Random Variables, /Journal of Statistical Software,
--   American Statistical Association/, vol. 11(i03).
--   <http://ideas.repec.org/a/jss/jstsof/11i03.html>