File: Basic.hs

package info (click to toggle)
haskell-crypton 1.0.4-3
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 3,548 kB
  • sloc: haskell: 26,764; ansic: 22,294; makefile: 6
file content (137 lines) | stat: -rw-r--r-- 4,763 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
{-# LANGUAGE BangPatterns #-}

-- |
-- Module      : Crypto.Number.Basic
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : Good
module Crypto.Number.Basic (
    sqrti,
    gcde,
    areEven,
    log2,
    numBits,
    numBytes,
    asPowerOf2AndOdd,
) where

import Data.Bits

import Crypto.Number.Compat

-- | Bracket the square root of a non-negative integer.
--
-- @sqrti n@ returns @(lo, hi)@ such that @lo * lo <= n <= hi * hi@, i.e.
-- @lo <= sqrt n <= hi@. The bracket is at most one wide, but it is not
-- always the tightest one: for instance @sqrti 4 == (1, 2)@.
--
-- A starting guess is derived from the number of decimal digits of @n@.
-- The guess is doubled (or halved) until it brackets the root, and the
-- bracket is then narrowed by bisection.
--
-- Calls 'error' on negative input.
sqrti :: Integer -> (Integer, Integer)
sqrti i
    | i < 0 = error "cannot compute negative square root"
    | i == 0 || i == 1 = (i, i)
    | i == 2 = (1, 2)
    | otherwise = refine guess
  where
    square v = v * v

    digitCount = length (show i)

    -- Initial estimate from the decimal length: 2 * 10^k for an even digit
    -- count, 6 * 10^k for an odd one.
    guess
        | even digitCount = 2 * 10 ^ ((digitCount - 2) `div` 2)
        | otherwise = 6 * 10 ^ ((digitCount - 1) `div` 2)

    -- Decide whether the guess is below, on, or above the root.
    refine g = case compare (square g) i of
        EQ -> (g, g)
        LT -> grow g
        GT -> shrink g

    -- lo is known to be too small: double until the square reaches i.
    grow lo
        | square hi >= i = bisect lo hi
        | otherwise = grow hi
      where
        hi = 2 * lo

    -- hi is known to be large enough: halve until the square drops below i.
    shrink hi
        | square lo < i = bisect lo hi
        | otherwise = shrink lo
      where
        lo = hi `div` 2

    -- Invariant: square lo < i <= square hi (and lo < hi).
    bisect lo hi
        | hi - lo <= 1 = (lo, hi)
        | square mid >= i = bisect lo (hi - half)
        | otherwise = bisect mid hi
      where
        half = (hi - lo) `div` 2
        mid = lo + half

-- | Extended greatest common divisor.
--
-- @gcde a b@ returns @(x, y, d)@ with @a * x + b * y == d@, where @d@ is a
-- greatest common divisor of @a@ and @b@. 'gmpGcde' is tried first through
-- 'onGmpUnsupported'. The pure fallback below runs the extended Euclidean
-- algorithm with 'divMod' and flips all three signs when needed, so that
-- its @d@ is non-negative.
gcde :: Integer -> Integer -> (Integer, Integer, Integer)
gcde a b = onGmpUnsupported (gmpGcde a b) fallback
  where
    fallback
        | g < 0 = (negate u, negate v, negate g)
        | otherwise = (u, v, g)

    (g, u, v) = euclid (a, 1, 0) (b, 0, 1)

    -- Every triple (r, s, t) satisfies a * s + b * t == r; the last non-zero
    -- remainder is the gcd and carries its Bezout coefficients.
    euclid prev@(r0, s0, t0) cur@(r1, s1, t1)
        | r1 == 0 = prev
        | otherwise =
            let (q, r2) = r0 `divMod` r1
             in euclid cur (r2, s0 - q * s1, t0 - q * t1)

-- | Check whether every integer in the list is even.
--
-- An empty list yields 'True'.
areEven :: [Integer] -> Bool
areEven = all even

-- | Integer binary logarithm: the floor of @logBase 2 n@ for @n >= 1@.
--
-- 'gmpLog2' is tried first through 'onGmpUnsupported'. The pure fallback
-- returns 0 for every @n < 2@, including zero and negative inputs.
log2 :: Integer -> Int
log2 n = onGmpUnsupported (gmpLog2 n) (intLog 2 n)
  where
    -- Integer logarithm via repeated squaring of the base, after
    -- http://www.haskell.org/pipermail/haskell-cafe/2008-February/039465.html
    -- The recursive call on base^2 yields an exponent k with base^k <= x;
    -- the remaining factors of base are then peeled off one at a time.
    intLog base x
        | x < base = 0
        | otherwise = countDivs (x `div` (base ^ k)) k
      where
        k = 2 * intLog (base * base) x
        countDivs y acc
            | y < base = acc
            | otherwise = countDivs (y `div` base) (acc + 1)
{-# INLINE log2 #-}

-- | Number of bits needed to represent the magnitude of an integer.
--
-- 'gmpSizeInBits' is tried first through 'onGmpUnsupported'. In the pure
-- fallback, 0 is reported as 1 bit and the sign of a negative argument is
-- ignored (its absolute value is measured). Previously the fallback never
-- terminated on negative input, because @(-1) `divMod` 256 == (-1, 255)@.
-- NOTE(review): ignoring the sign is intended to mirror GMP's sign-less
-- size-in-base semantics — confirm against 'gmpSizeInBits'.
numBits :: Integer -> Int
numBits n = gmpSizeInBits n `onGmpUnsupported` fallback
  where
    fallback
        | n == 0 = 1
        | otherwise = countBits 0 (abs n)

    -- Drop whole bytes while the value is at least 256, then finish bit by
    -- bit on the leading byte. Only called with a non-negative argument.
    countBits :: Int -> Integer -> Int
    countBits !acc i
        | i >= 256 = countBits (acc + 8) (i `shiftR` 8)
        | i == 0 = acc
        | otherwise = countBits (acc + 1) (i `shiftR` 1)

-- | Number of bytes needed to hold an integer.
--
-- 'gmpSizeInBytes' is tried first through 'onGmpUnsupported'. Otherwise the
-- result is 'numBits' rounded up to a whole number of bytes.
numBytes :: Integer -> Int
numBytes n = onGmpUnsupported (gmpSizeInBytes n) viaBits
  where
    -- ceiling (numBits n / 8)
    viaBits = (numBits n + 7) `div` 8

-- | Split an integer into a power of two and an odd factor.
--
-- For non-zero @a@, @asPowerOf2AndOdd a@ returns @(e, o)@ with
-- @a == 2 ^ e * o@ and @o@ odd; the sign of @a@ is carried by @o@.
-- Zero has no such decomposition and is mapped to @(0, 0)@.
asPowerOf2AndOdd :: Integer -> (Int, Integer)
asPowerOf2AndOdd a
    | a == 0 = (0, 0)
    | odd a = (0, a)
    | a < 0 = let (e, a1) = asPowerOf2AndOdd $ abs a in (e, -a1)
    -- Fast path: an exact power of two is fully described by its logarithm.
    | isPowerOf2 a = (log2 a, 1)
    | otherwise = loop a 0
  where
    isPowerOf2 n = (n /= 0) && ((n .&. (n - 1)) == 0)

    -- Strip factors of two from a positive, even argument. The bang keeps
    -- the counter evaluated instead of building a chain of (+ 1) thunks;
    -- for positive n, shiftR by 1 equals `div` 2.
    loop :: Integer -> Int -> (Int, Integer)
    loop n !pw
        | even n = loop (n `shiftR` 1) (pw + 1)
        | otherwise = (pw, n)