File: Uniformity.hs

package info (click to toggle)
haskell-splitmix 0.1.0.5-2
  • links: PTS
  • area: main
  • in suites: forky, sid, trixie
  • size: 196 kB
  • sloc: haskell: 1,366; ansic: 151; sh: 53; makefile: 9
file content (134 lines) | stat: -rw-r--r-- 4,065 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE DeriveFunctor       #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | Chi-Squared test for uniformity.
module Uniformity (testUniformity) where

import Data.List                    (intercalate)
import Data.List                    (foldl')
import Numeric                      (showFFloat)
import Numeric.SpecFunctions        (incompleteGamma)
import Test.Framework.Providers.API (Test, TestName)

import qualified Data.Map as Map
import qualified Data.Map.Strict as MapStrict

import MiniQC as QC

-- | Cumulative distribution function of the chi-squared distribution with
-- @k - 1@ degrees of freedom, evaluated at @x@:
--
-- \[ F(x) = P\left(\tfrac{k - 1}{2}, \tfrac{x}{2}\right) \]
--
-- where \(P\) is 'incompleteGamma' (the normalized lower incomplete gamma
-- function). For a chi-squared statistic \(V\) over \(k\) categories this is
-- the limit \( \lim_{n\to\infty} \mathrm{Pr}(V \le x) \).
chiDist
    :: Int     -- ^ k, number of categories
    -> Double  -- ^ x, value of the statistic
    -> Double
chiDist categories x = incompleteGamma halfDegrees halfX
  where
    -- degrees of freedom is one less than the number of categories
    halfDegrees = fromIntegral (categories - 1) / 2
    halfX       = x / 2

-- | Chi-squared goodness-of-fit against the uniform distribution.
--
-- For observed counts \(Y_s\) over \(k\) buckets with total \(n\), the
-- statistic
--
-- \[
-- V = \frac{1}{n} \sum_{s = 1}^k \frac{Y_s^2}{p_s} - n
-- \]
--
-- simplifies, when every \(p_s = \frac{1}{k}\), to
--
-- \[
-- V = \frac{k}{n} \sum_{s=1}^k Y_s^2 - n .
-- \]
--
-- Buckets absent from the map are treated as \(Y_s = 0\) (they add nothing
-- to the sum of squares). Despite its name, the function returns 'chiDist'
-- applied to \(V\), i.e. the chi-squared CDF at \(V\), not \(V\) itself.
calculateV :: Int -> Map.Map k Int -> Double
calculateV buckets counts = chiDist buckets statistic
  where
    -- One strict left pass: total count and sum of squared counts.
    V2 total sumSquares = Map.foldl' step (V2 0 0) counts
    step (V2 acc accSq) y = V2 (acc + y) (accSq + y * y)

    statistic =
        fromIntegral buckets * fromIntegral sumSquares / fromIntegral total
            - fromIntegral total

-- | A pair of strict 'Int's used as a left-fold accumulator (in 'calculateV'
-- it carries the running total and the running sum of squares). Both fields
-- are strict, so forcing the pair forces its contents.
data V2
    = V2
        !Int  -- first component
        !Int  -- second component

-- | Tally the first @n@ elements of a stream into a frequency map
-- (element to number of occurrences). Returns an empty map when @n <= 0@;
-- the stream is not inspected at all in that case.
countStream :: Ord a => Stream a -> Int -> Map.Map a Int
countStream = go Map.empty
  where
    -- The bang keeps the map spine evaluated, and the value-strict
    -- 'MapStrict.insertWith' forces each count on insertion. The lazy
    -- 'Map.insertWith' would otherwise leave a chain of (+) thunks per key.
    go !acc s n
        | n <= 0    = acc
        | otherwise = case s of
            x :> xs -> go (MapStrict.insertWith (+) x 1 acc) xs (pred n)

-- | Chi-squared uniformity check over the first @k * 128@ elements of a
-- stream that is expected to take at most @k@ distinct values.
--
-- * @Left@ if more than @k@ distinct values were seen. The message reports
--   how many were seen and up to five of them.
-- * @Left@ if the chi-squared CDF of the statistic (from 'calculateV')
--   exceeds @0.999999@. The message includes a table of up to 20 observed
--   values with their relative frequencies.
-- * @Right@ with that CDF value otherwise.
testUniformityRaw :: forall a. (Ord a, Show a) => Int -> Stream a -> Either String Double
testUniformityRaw k s
    | Map.size m > k = Left $
        "Got more elements (" ++ show (Map.size m, take 5 $ Map.keys m)
        ++ ") than expected (" ++ show k ++ ")"
    | p > 0.999999   = Left $
        "Too improbable p-value: " ++ show p ++ "\n" ++ table
        [ [ show x, showFFloat (Just 3) (fromIntegral y / fromIntegral n :: Double) "" ]
        | (x, y) <- take 20 $ Map.toList m
        ]
    | otherwise      = Right p
  where
    -- Sample size: roughly 128 elements per bucket if the stream is uniform.
    n :: Int
    n = k * 128

    -- Occurrence count of each distinct value among the first n elements.
    m :: Map.Map a Int
    m = countStream s n

    -- Chi-squared CDF of the statistic. Values close to 1 mean the counts
    -- deviate from uniform far more than chance would suggest.
    p :: Double
    p = calculateV k m

-- | Turn the result of 'testUniformityRaw' into a property. A @Left@
-- diagnostic becomes a failing counterexample carrying the message; a
-- @Right@ result passes, whatever the CDF value.
testUniformityQC :: (Ord a, Show a) => Int -> Stream a -> QC.Property
testUniformityQC k s =
    either failWith (const (QC.property True)) (testUniformityRaw k s)
  where
    failWith msg = QC.counterexample msg False

-- | Test that a generator produces values uniformly.
--
-- The test draws an infinite stream from the generator and maps every value
-- through the partitioning function. It then runs a chi-squared uniformity
-- check ('testUniformityQC') over the resulting partition labels, expecting
-- at most @k@ distinct labels.
--
-- NOTE(review): an earlier comment claimed "the size is scaled to be at
-- least 20"; nothing in this module does that. Confirm whether MiniQC's
-- 'QC.testMiniProperty' does.
testUniformity
    :: forall a b. (Ord b, Show b)
    => TestName
    -> QC.Gen a  -- ^ Generator to test
    -> (a -> b)    -- ^ Partitioning function
    -> Int         -- ^ Number of partitions
    -> Test
testUniformity name gen f k = QC.testMiniProperty name
    $ QC.forAllBlind (streamGen gen)
    $ testUniformityQC k . fmap f

-------------------------------------------------------------------------------
-- Infinite stream
-------------------------------------------------------------------------------

-- | Infinite stream: a head element followed by the rest of the stream.
-- There is no empty constructor, so matching on ':>' can never fail.
data Stream a
    = a :> Stream a
    deriving (Functor)

-- Right-associative, so @x :> y :> rest@ parses as @x :> (y :> rest)@.
infixr 5 :>

-- | Generator for an infinite 'Stream' whose elements are each drawn from @g@.
--
-- The definition is self-referential: @gs@ draws one element and then binds
-- @gs@ again for the tail. It can only produce a value if 'QC.Gen''s bind
-- does not force the rest of the computation eagerly.
-- NOTE(review): this presumably relies on MiniQC's 'Gen' behaving like
-- QuickCheck's (splitting the random seed rather than threading it through).
-- Confirm in MiniQC before restructuring, since a different combinator shape
-- could change which values are generated.
streamGen :: QC.Gen a -> QC.Gen (Stream a)
streamGen g = gs where
    gs = do
        x <- g
        xs <- gs
        return (x :> xs)

-------------------------------------------------------------------------------
-- Table
-------------------------------------------------------------------------------

-- | Render rows of cells as left-aligned columns separated by three spaces,
-- one line per row. Each cell is right-padded with spaces to the width of
-- the longest cell in its column. Rows may have different lengths; shorter
-- rows simply end early.
table :: [[String]] -> String
table cells = unlines (map renderRow cells)
  where
    -- Width of each column: the longest cell at that position across all
    -- rows. The list is infinite (trailing zeros), so zipping against any
    -- row keeps every cell of that row.
    colWidths :: [Int]
    colWidths = foldr widen (repeat 0) cells

    widen :: [String] -> [Int] -> [Int]
    widen row ws = zipWith max ws (map length row ++ repeat 0)

    renderRow :: [String] -> String
    renderRow row = intercalate "   " (zipWith pad row colWidths)

    pad :: String -> Int -> String
    pad s w = s ++ replicate (w - length s) ' '