File: Sqrt.hs

package info (click to toggle)
haskell-arithmoi 0.13.2.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 964 kB
  • sloc: haskell: 10,379; makefile: 3
file content (273 lines) | stat: -rw-r--r-- 9,288 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
-- |
-- Module:      Math.NumberTheory.Moduli.Sqrt
-- Copyright:   (c) 2011 Daniel Fischer
-- Licence:     MIT
-- Maintainer:  Andrew Lelechenko <andrew.lelechenko@gmail.com>
--
-- Modular square roots and
-- <https://en.wikipedia.org/wiki/Jacobi_symbol Jacobi symbol>.
--

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE PostfixOperators #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ViewPatterns #-}

module Math.NumberTheory.Moduli.Sqrt
  ( -- * Modular square roots
    sqrtsMod
  , sqrtsModFactorisation
  , sqrtsModPrimePower
  , sqrtsModPrime
    -- * Jacobi symbol
  , JacobiSymbol(..)
  , jacobi
  , symbolToNum
  ) where

import Control.Monad (liftM2)
import Data.Bits
import Data.Constraint
import Data.List.Infinite (Infinite(..), (...))
import qualified Data.List.Infinite as Inf
import Data.Maybe
import Data.Mod
import Data.Proxy
import GHC.TypeNats (KnownNat, SomeNat(..), natVal, someNatVal, Nat)
import Numeric.Natural (Natural)

import Math.NumberTheory.Moduli.Chinese
import Math.NumberTheory.Moduli.JacobiSymbol
import Math.NumberTheory.Moduli.Singleton
import Math.NumberTheory.Primes
import Math.NumberTheory.Utils (shiftToOddCount, splitOff)
import Math.NumberTheory.Utils.FromIntegral

-- | List all modular square roots.
--
-- >>> :set -XDataKinds
-- >>> sqrtsMod sfactors (1 :: Mod 60)
-- [(1 `modulo` 60),(49 `modulo` 60),(41 `modulo` 60),(29 `modulo` 60),(31 `modulo` 60),(19 `modulo` 60),(11 `modulo` 60),(59 `modulo` 60)]
sqrtsMod :: SFactors Integer m -> Mod m -> [Mod m]
sqrtsMod sm a = case proofFromSFactors sm of
  Sub Dict -> map fromInteger $ sqrtsModFactorisation (toInteger (unMod a)) (unSFactors sm)

-- | List all square roots modulo a number, the factorisation of which is
-- passed as a second argument.
--
-- >>> sqrtsModFactorisation 1 (factorise 60)
-- [1,49,41,29,31,19,11,59]
sqrtsModFactorisation :: Integer -> [(Prime Integer, Word)] -> [Integer]
sqrtsModFactorisation _ []  = [0]
sqrtsModFactorisation n pps = map fst $ foldl1 (liftM2 comb) cs
  where
    ms :: [Integer]
    ms = map (\(p, pow) -> unPrime p ^ pow) pps

    rs :: [[Integer]]
    rs = map (uncurry (sqrtsModPrimePower n)) pps

    cs :: [[(Integer, Integer)]]
    cs = zipWith (\l m -> map (, m) l) rs ms

    comb t1 t2 = (if ch < 0 then ch + m else ch, m)
      where
        (ch, m) = fromJust $ chinese t1 t2

-- | List all square roots modulo the power of a prime.
--
-- >>> import Data.Maybe
-- >>> import Math.NumberTheory.Primes
-- >>> sqrtsModPrimePower 7 (fromJust (isPrime 3)) 2
-- [4,5]
-- >>> sqrtsModPrimePower 9 (fromJust (isPrime 3)) 3
-- [3,12,21,24,6,15]
sqrtsModPrimePower :: Integer -> Prime Integer -> Word -> [Integer]
sqrtsModPrimePower nn p 1 = sqrtsModPrime nn p
sqrtsModPrimePower nn (unPrime -> prime) expo = let primeExpo = prime ^ expo in
  case splitOff prime (nn `mod` primeExpo) of
    (_, 0) -> [0, prime ^ ((expo + 1) `quot` 2) .. primeExpo - 1]
    (kk, n)
      | odd kk    -> []
      | otherwise -> case (if prime == 2 then sqM2P n expo' else sqrtModPP' n prime expo') of
        Nothing -> []
        Just r  -> let rr = r * prime ^ k in
          if prime == 2 && k + 1 == t
          then go rr os
          else go rr os ++ go (primeExpo - rr) os
      where
        k = kk `quot` 2
        t = (if prime == 2 then expo - k - 1 else expo - k) `max` ((expo + 1) `quot` 2)
        expo' = expo - 2 * k
        os = [0, prime ^ t .. primeExpo - 1]

        -- equivalent to map ((`mod` primeExpo) . (+ r)) rs,
        -- but avoids division
        go r rs = map (+ r) ps ++ map (+ (r - primeExpo)) qs
          where
            (ps, qs) = span (< primeExpo - r) rs

-- | List all square roots by prime modulo.
--
-- >>> import Data.Maybe
-- >>> import Math.NumberTheory.Primes
-- >>> sqrtsModPrime 1 (fromJust (isPrime 5))
-- [1,4]
-- >>> sqrtsModPrime 0 (fromJust (isPrime 5))
-- [0]
-- >>> sqrtsModPrime 2 (fromJust (isPrime 5))
-- []
sqrtsModPrime :: Integer -> Prime Integer -> [Integer]
sqrtsModPrime n (unPrime -> 2) = [n `mod` 2]
sqrtsModPrime n (unPrime -> prime) = case jacobi n prime of
  MinusOne -> []
  Zero     -> [0]
  One      -> case someNatVal (fromInteger prime) of
    SomeNat (_ :: Proxy p) -> let r = toInteger (unMod (sqrtModP' @p (fromInteger n))) in [r, prime - r]

-------------------------------------------------------------------------------
-- Internals

-- | @sqrtModP' square prime@ finds a square root of @square@ modulo
--   prime. @prime@ /must/ be a (positive) prime, and @square@ /must/ be a positive
--   quadratic residue modulo @prime@, i.e. @'jacobi square prime == 1@.
sqrtModP' :: KnownNat p => Mod p -> Mod p
sqrtModP' square
  | prime == 2         = square
  | rem4 prime == 3    = square ^ ((prime + 1) `quot` 4)
  | square == maxBound = sqrtOfMinusOne
  | otherwise          = tonelliShanks square
  where
    prime = natVal square

-- | @p@ must be of form @4k + 1@
sqrtOfMinusOne :: forall (p :: Nat). KnownNat p => Mod p
sqrtOfMinusOne = case results of
  [] -> error "sqrtOfMinusOne: internal invariant violated"
  hd : _ -> hd
  where
    p :: Natural
    p = natVal (Proxy :: Proxy p)

    k :: Natural
    k = (p - 1) `quot` 4

    results :: [Mod p]
    results = dropWhile (\n -> n == 1 || n == maxBound) $
      map (^ k) [2 .. maxBound - 1]

-- | @tonelliShanks square prime@ calculates a square root of @square@
--   modulo @prime@, where @prime@ is a prime of the form @4*k + 1@ and
--   @square@ is a positive quadratic residue modulo @prime@, using the
--   Tonelli-Shanks algorithm.
tonelliShanks :: forall p. KnownNat p => Mod p -> Mod p
tonelliShanks square = loop rc t1 generator log2
  where
    prime = natVal square
    (log2, q) = shiftToOddCount (prime - 1)
    generator = findNonSquare ^ q
    rc = square ^ ((q + 1) `quot` 2)
    t1 = square ^ q

    msquare 0 x = x
    msquare k x = msquare (k-1) (x * x)

    findPeriod per 1 = per
    findPeriod per x = findPeriod (per + 1) (x * x)

    loop :: Mod p -> Mod p -> Mod p -> Word -> Mod p
    loop !r t c m
        | t == 1    = r
        | otherwise = loop nextR nextT nextC nextM
          where
            nextM = findPeriod 0 t
            b     = msquare (m - 1 - nextM) c
            nextR = r * b
            nextC = b * b
            nextT = t * nextC

-- | prime must be odd, n must be coprime with prime
sqrtModPP' :: Integer -> Integer -> Word -> Maybe Integer
sqrtModPP' n prime expo = case jacobi n prime of
  MinusOne -> Nothing
  Zero     -> Nothing
  One      -> case someNatVal (fromInteger prime) of
    SomeNat (_ :: Proxy p) -> Just $ fixup $ sqrtModP' @p (fromInteger n)
  where
    fixup :: KnownNat p => Mod p -> Integer
    fixup r
      | diff' == 0 = r'
      | expo <= e  = r'
      | otherwise  = hoist (recip (2 * r)) r' (fromInteger q) (prime^e)
      where
        r' = toInteger (unMod r)
        diff' = r' * r' - n
        (e, q) = splitOff prime diff'

    hoist :: KnownNat p => Mod p -> Integer -> Mod p -> Integer -> Integer
    hoist inv root elim pp
      | diff' == 0    = root'
      | expo <= ex    = root'
      | otherwise     = hoist inv root' (fromInteger nelim) (prime ^ ex)
        where
          root' = root + toInteger (unMod (inv * negate elim)) * pp
          diff' = root' * root' - n
          (ex, nelim) = splitOff prime diff'

-- dirty, dirty
sqM2P :: Integer -> Word -> Maybe Integer
sqM2P n e
    | e < 2     = Just (n `mod` 2)
    | n' == 0   = Just 0
    | odd k     = Nothing
    | otherwise = (`mod` mdl) . (`shiftL` wordToInt k2) <$> solve s e2
      where
        mdl = 1 `shiftL` wordToInt e
        n' = n `mod` mdl
        (k, s) = shiftToOddCount n'
        k2 = k `quot` 2
        e2 = e - k
        solve _ 1 = Just 1
        solve 1 _ = Just 1
        solve r _
            | rem4 r == 3   = Nothing  -- otherwise r ≡ 1 (mod 4)
            | rem8 r == 5   = Nothing  -- otherwise r ≡ 1 (mod 8)
            | otherwise     = fixup r (fst $ shiftToOddCount (r-1))
              where
                fixup x pw
                    | pw >= e2  = Just x
                    | otherwise = fixup x' pw'
                      where
                        x' = x + (1 `shiftL` (wordToInt pw - 1))
                        d = x'*x' - r
                        pw' = if d == 0 then e2 else fst (shiftToOddCount d)

-------------------------------------------------------------------------------
-- Utilities

rem4 :: Integral a => a -> Int
rem4 n = fromIntegral n .&. 3

rem8 :: Integral a => a -> Int
rem8 n = fromIntegral n .&. 7

findNonSquare :: forall (n :: Nat). KnownNat n => Mod n
findNonSquare
  | rem8 n == 3 || rem8 n == 5 = 2
  | otherwise = fromIntegral $ Inf.head $
    Inf.dropWhile (\p -> jacobi p n /= MinusOne) candidates
  where
    n = natVal (Proxy :: Proxy n)

    -- It is enough to consider only prime candidates, but
    -- the probability that the smallest non-residue is > 67
    -- is small and 'jacobi' test is fast,
    -- so we use [71..n] instead of filter isPrime [71..n].
    candidates :: Infinite Natural
    candidates = 3 :< 5 :< 7 :< 11 :< 13 :< 17 :< 19 :< 23 :< 29 :< 31 :<
      37 :< 41 :< 43 :< 47 :< 53 :< 59 :< 61 :< 67 :< (71...)