File: P256.hs

package info (click to toggle)
haskell-crypton 1.0.4-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 3,548 kB
  • sloc: haskell: 26,764; ansic: 22,294; makefile: 6
file content (458 lines) | stat: -rw-r--r-- 14,995 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# OPTIONS_GHC -fno-warn-unused-binds #-}

-- |
-- Module      : Crypto.PubKey.ECC.P256
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- P256 support
module Crypto.PubKey.ECC.P256 (
    Scalar,
    Point,

    -- * Point arithmetic
    pointBase,
    pointAdd,
    pointNegate,
    pointMul,
    pointDh,
    pointsMulVarTime,
    pointIsValid,
    pointIsAtInfinity,
    toPoint,
    pointX,
    pointToIntegers,
    pointFromIntegers,
    pointToBinary,
    pointFromBinary,
    unsafePointFromBinary,

    -- * Scalar arithmetic
    scalarGenerate,
    scalarZero,
    scalarN,
    scalarIsZero,
    scalarAdd,
    scalarSub,
    scalarMul,
    scalarInv,
    scalarInvSafe,
    scalarCmp,
    scalarFromBinary,
    scalarToBinary,
    scalarFromInteger,
    scalarToInteger,
) where

import Data.Word
import Foreign.C.Types
import Foreign.Ptr

import Crypto.Error
import Crypto.Internal.ByteArray
import qualified Crypto.Internal.ByteArray as B
import Crypto.Internal.Compat
import Crypto.Internal.Imports
import qualified Crypto.Number.Serialize as S (i2ospOf, os2ip)
import Crypto.Number.Serialize.Internal (i2ospOf, os2ip)
import Crypto.Random
import Data.Memory.PtrMethods (memSet)

-- | A P256 scalar.
--
-- The value lives in a 'ScrubbedBytes' buffer. Scalars built by this module
-- are 'scalarSize' bytes long and hold the representation that the foreign
-- @crypton_p256_*@ routines read and write; use 'scalarToBinary' and
-- 'scalarFromBinary' to convert to and from the big-endian encoding.
--
-- 'ByteArrayAccess' is derived through the newtype, so it exposes that raw
-- buffer as-is.
newtype Scalar = Scalar ScrubbedBytes
    deriving (Show, Eq, ByteArrayAccess, NFData)

-- | A P256 point.
--
-- Points built by this module are 'pointSize' bytes long: the x coordinate
-- occupies the first 'scalarSize' bytes and the y coordinate starts right
-- after it (see 'pxToPy'). An all-zero buffer is what 'pointIsAtInfinity'
-- reports as the point at infinity.
newtype Point = Point Bytes
    deriving (Show, Eq, NFData)

-- | Size in bytes of a scalar buffer and of a single point coordinate; also
-- the length of the binary scalar encoding accepted by 'scalarFromBinary'
-- and produced by 'scalarToBinary'.
scalarSize :: Int
scalarSize = 32

-- | Size in bytes of a point buffer: an x coordinate followed by a y
-- coordinate, i.e. twice 'scalarSize'. Also the length of the binary point
-- encoding handled by 'pointToBinary' and 'unsafePointFromBinary'.
pointSize :: Int
pointSize = 64

-- | Type of the single-digit arguments taken by some foreign routines
-- (see 'ccrypton_p256_add_d' and 'ccrypton_p256_modmul').
type P256Digit = Word32

-- Uninhabited tag types. They are only used as the phantom parameter of
-- 'Ptr', so that scalar, x-coordinate and y-coordinate pointers cannot be
-- interchanged without an explicit 'castPtr'.
data P256Scalar
data P256Y
data P256X

-- | The curve order as an 'Integer'; 'scalarN' is built from this value.
order :: Integer
order = 0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551

------------------------------------------------------------------------
-- Point methods
------------------------------------------------------------------------

-- | The base point (generator) of the P256 curve.
--
-- Obtained by lifting the scalar 1 with 'toPoint'. The failure branch is an
-- internal sanity check: it is only reached if 1 cannot be turned into a
-- 'Scalar'.
pointBase :: Point
pointBase = liftOne (scalarFromInteger 1)
  where
    liftOne (CryptoPassed one) = toPoint one
    liftOne (CryptoFailed _) = error "pointBase: assumption failed"

-- | Lift a scalar to the curve.
--
-- Multiplies the curve generator @G@ by the given scalar:
--
-- > scalar * G
--
-- Calls 'error' when the scalar is zero.
toPoint :: Scalar -> Point
toPoint s =
    if scalarIsZero s
        then error "cannot create point from zero"
        else withNewPoint mulBase
  where
    -- Write @s * G@ into the freshly allocated output coordinates.
    mulBase outX outY =
        withScalar s (\k -> ccrypton_p256_basepoint_mul k outX outY)

-- | Add two points, returning the sum as a new 'Point'.
pointAdd :: Point -> Point -> Point
pointAdd a b = withNewPoint addInto
  where
    -- Inputs are passed first, the output coordinates last.
    addInto sumX sumY =
        withPoint a $ \ax ay ->
            withPoint b $ \bx by ->
                ccrypton_p256e_point_add ax ay bx by sumX sumY

-- | Negate a point, returning the result as a new 'Point'.
pointNegate :: Point -> Point
pointNegate a = withNewPoint negateInto
  where
    negateInto outX outY =
        withPoint a (\ax ay -> ccrypton_p256e_point_negate ax ay outX outY)

-- | Multiply a point by a scalar, returning the product as a new 'Point'.
--
-- warning: variable time
pointMul :: Scalar -> Point -> Point
pointMul scalar p = withNewPoint mulInto
  where
    mulInto outX outY =
        withScalar scalar $ \k ->
            withPoint p $ \px py ->
                ccrypton_p256e_point_mul k px py outX outY

-- | Like 'pointMul', but only the x coordinate of the product is returned,
-- serialized as 'scalarSize' bytes.
-- When scalar is multiple of point order the result is all zero.
--
-- The intermediate point is computed in a temporary buffer obtained from
-- 'withTempPoint' rather than in a 'Point' value.
pointDh :: ByteArray binary => Scalar -> Point -> binary
pointDh scalar p = B.unsafeCreate scalarSize writeSharedX
  where
    -- Compute the product into the temporary point, then serialize its x
    -- coordinate into the output buffer.
    writeSharedX out = withTempPoint $ \tmpX tmpY ->
        multiplyInto tmpX tmpY >> ccrypton_p256_to_bin (castPtr tmpX) out
    multiplyInto tmpX tmpY =
        withScalar scalar $ \k ->
            withPoint p $ \px py ->
                ccrypton_p256e_point_mul k px py tmpX tmpY

-- | Multiply the point @p@ by @n2@ and add the curve generator multiplied
-- by @n1@:
--
-- > n1 * G + n2 * p
--
-- warning: variable time
pointsMulVarTime :: Scalar -> Scalar -> Point -> Point
pointsMulVarTime n1 n2 p = withNewPoint combine
  where
    combine outX outY =
        withScalar n1 $ \k1 ->
            withScalar n2 $ \k2 ->
                withPoint p $ \px py ->
                    ccrypton_p256_points_mul_vartime k1 k2 px py outX outY

-- | Check whether a 'Point' is valid.
--
-- The decision is made by the foreign validity check; any non-zero result
-- is reported as 'True'.
pointIsValid :: Point -> Bool
pointIsValid p = unsafeDoIO (withPoint p check)
  where
    check px py = (/= 0) <$> ccrypton_p256_is_valid_point px py

-- | Check whether a 'Point' is the point at infinity, i.e. whether its
-- underlying buffer consists only of zero bytes.
pointIsAtInfinity :: Point -> Bool
pointIsAtInfinity p =
    case p of
        Point raw -> constAllZero raw

-- | Return the x coordinate of a point as a 'Scalar', or 'Nothing' when the
-- point is at infinity.
--
-- The coordinate is passed through @crypton_p256_mod@ with
-- 'ccrypton_SECP256r1_n' as modulus before being returned.
pointX :: Point -> Maybe Scalar
pointX p =
    if pointIsAtInfinity p
        then Nothing
        else Just (withNewScalarFreeze reduceX)
  where
    -- The y coordinate is not needed; x is reinterpreted as a scalar pointer.
    reduceX out =
        withPoint p $ \px _ ->
            ccrypton_p256_mod ccrypton_SECP256r1_n (castPtr px) out

-- | Convert a point to its @(x, y)@ coordinates as 'Integer's.
--
-- Each coordinate is serialized in turn into one scratch buffer of
-- 'scalarSize' bytes and read back from it as an 'Integer'.
pointToIntegers :: Point -> (Integer, Integer)
pointToIntegers p = unsafeDoIO $ withPoint p $ \px py ->
    -- The scratch buffer is sized with 'scalarSize' (not a literal) so that it
    -- always matches the length 'coordinate' reads back.
    allocTemp scalarSize $ \temp -> do
        x <- coordinate temp (castPtr px)
        y <- coordinate temp (castPtr py)
        return (x, y)
  where
    -- Serialize one coordinate into @temp@, then decode it. The buffer is
    -- reused, so the value must be read before the next call overwrites it.
    coordinate :: Ptr Word8 -> Ptr P256Scalar -> IO Integer
    coordinate temp src = do
        ccrypton_p256_to_bin src temp
        os2ip temp scalarSize

-- | Build a point from its @(x, y)@ coordinates given as 'Integer's.
--
-- The resulting point is not checked with 'pointIsValid'. Calls 'error'
-- when 'i2ospOf' reports that a coordinate could not be written into
-- 'scalarSize' bytes.
pointFromIntegers :: (Integer, Integer) -> Point
pointFromIntegers (x, y) = withNewPoint $ \dx dy ->
    allocTemp scalarSize $ \scratch -> do
        store scratch (castPtr dx) x
        store scratch (castPtr dy) y
  where
    -- Write @n@ to @scratch@ in big endian format, then convert @scratch@
    -- into the P256 scalar format at @dest@.
    store :: Ptr Word8 -> Ptr P256Scalar -> Integer -> IO ()
    store scratch dest n = do
        memSet scratch 0 scalarSize
        written <- i2ospOf n scratch scalarSize
        case written of
            -- a result of 0 means nothing was written
            0 -> error "pointFromIntegers: filling failed"
            _ -> ccrypton_p256_from_bin scratch dest

-- | Convert a point to its binary representation.
--
-- The output is 'pointSize' bytes: the serialized x coordinate followed by
-- the serialized y coordinate, each 'scalarSize' bytes. This is the layout
-- read back by 'unsafePointFromBinary'.
pointToBinary :: ByteArray ba => Point -> ba
pointToBinary p = B.unsafeCreate pointSize $ \dst -> withPoint p $ \px py -> do
    ccrypton_p256_to_bin (castPtr px) dst
    -- y goes right after x; the offset is 'scalarSize', the same one used when
    -- decoding, instead of a separately maintained literal.
    ccrypton_p256_to_bin (castPtr py) (dst `plusPtr` scalarSize)

-- | Decode a point from binary and make sure it is valid.
--
-- Fails like 'unsafePointFromBinary' on a wrong input size, and with
-- 'CryptoError_PointCoordinatesInvalid' when the decoded point does not pass
-- 'pointIsValid'.
pointFromBinary :: ByteArrayAccess ba => ba -> CryptoFailable Point
pointFromBinary ba = do
    p <- unsafePointFromBinary ba
    if pointIsValid p
        then CryptoPassed p
        else CryptoFailed CryptoError_PointCoordinatesInvalid

-- | Decode a point from binary without validating it; the result may not
-- be a valid point.
--
-- The input must be exactly 'pointSize' bytes (x followed by y), otherwise
-- 'CryptoError_PublicKeySizeInvalid' is returned.
unsafePointFromBinary :: ByteArrayAccess ba => ba -> CryptoFailable Point
unsafePointFromBinary ba =
    if B.length ba == pointSize
        then CryptoPassed (withNewPoint decode)
        else CryptoFailed CryptoError_PublicKeySizeInvalid
  where
    -- x is read from the start of the input, y from offset 'scalarSize'.
    decode :: Ptr P256X -> Ptr P256Y -> IO ()
    decode px py = B.withByteArray ba $ \src -> do
        ccrypton_p256_from_bin src (castPtr px)
        ccrypton_p256_from_bin (src `plusPtr` scalarSize) (castPtr py)

------------------------------------------------------------------------
-- Scalar methods
------------------------------------------------------------------------

-- | Generate a new random scalar.
--
-- Draws 'scalarSize' random bytes into a 'ScrubbedBytes' buffer and decodes
-- them with 'scalarFromBinary'. That decoding only rejects inputs of the
-- wrong length, so requesting exactly 'scalarSize' bytes is what makes the
-- failure branch unreachable.
--
-- NOTE(review): the random bytes are not compared against the curve order
-- here; confirm whether callers rely on the result being in @[1, n-1]@.
scalarGenerate :: MonadRandom randomly => randomly Scalar
scalarGenerate = toScalar <$> getRandomBytes scalarSize
  where
    -- The signature pins the random buffer type to 'ScrubbedBytes'.
    toScalar :: ScrubbedBytes -> Scalar
    toScalar raw =
        case scalarFromBinary raw of
            CryptoPassed s -> s
            CryptoFailed _ -> error "scalarGenerate: assumption failed"

-- | The scalar representing 0, produced by running the foreign
-- initialisation routine on a fresh scalar buffer.
scalarZero :: Scalar
scalarZero = withNewScalarFreeze ccrypton_p256_init

-- | The scalar representing the curve order, built from 'order'.
--
-- Evaluating it throws if 'order' cannot be converted to a 'Scalar'.
scalarN :: Scalar
scalarN = throwCryptoError encodedOrder
  where
    encodedOrder = scalarFromInteger order

-- | Check whether the scalar is 0.
--
-- Any non-zero result from the foreign zero test is reported as 'True'.
scalarIsZero :: Scalar -> Bool
scalarIsZero s = unsafeDoIO (withScalar s isZero)
  where
    isZero d = (/= 0) <$> ccrypton_p256_is_zero d

-- | Add two scalars, using 'ccrypton_SECP256r1_n' as the modulus:
--
-- > a + b
scalarAdd :: Scalar -> Scalar -> Scalar
scalarAdd a b = withNewScalarFreeze addInto
  where
    addInto out =
        withScalar a $ \pa ->
            withScalar b $ \pb ->
                ccrypton_p256e_modadd ccrypton_SECP256r1_n pa pb out

-- | Subtract one scalar from another, using 'ccrypton_SECP256r1_n' as the
-- modulus:
--
-- > a - b
scalarSub :: Scalar -> Scalar -> Scalar
scalarSub a b = withNewScalarFreeze subInto
  where
    subInto out =
        withScalar a $ \pa ->
            withScalar b $ \pb ->
                ccrypton_p256e_modsub ccrypton_SECP256r1_n pa pb out

-- | Multiply two scalars, using 'ccrypton_SECP256r1_n' as the modulus:
--
-- > a * b
scalarMul :: Scalar -> Scalar -> Scalar
scalarMul a b = withNewScalarFreeze mulInto
  where
    -- The literal 0 fills the 'P256Digit' argument that sits between the two
    -- operands (presumably an extra top word of @a@ -- TODO confirm against
    -- the C prototype).
    mulInto out =
        withScalar a $ \pa ->
            withScalar b $ \pb ->
                ccrypton_p256_modmul ccrypton_SECP256r1_n pa 0 pb out

-- | Compute the inverse of a scalar, using 'ccrypton_SECP256r1_n' as the
-- modulus:
--
-- > 1 / a
--
-- warning: variable time
scalarInv :: Scalar -> Scalar
scalarInv a = withNewScalarFreeze invertInto
  where
    invertInto out =
        withScalar a (\pa -> ccrypton_p256_modinv_vartime ccrypton_SECP256r1_n pa out)

-- | Compute the inverse of a scalar using safe exponentiation:
--
-- > 1 / a
--
-- Unlike 'scalarInv', no modulus is passed to the foreign routine.
scalarInvSafe :: Scalar -> Scalar
scalarInvSafe a = withNewScalarFreeze invertInto
  where
    invertInto out =
        withScalar a (\pa -> ccrypton_p256e_scalar_invert pa out)

-- | Compare two scalars.
--
-- The integer returned by the foreign comparison is turned into an
-- 'Ordering' by comparing it with 0.
scalarCmp :: Scalar -> Scalar -> Ordering
scalarCmp a b = unsafeDoIO (withScalar a againstB)
  where
    againstB pa =
        withScalar b $ \pb ->
            (`compare` 0) <$> ccrypton_p256_cmp pa pb

-- | Convert a scalar from its binary representation.
--
-- The input must be exactly 'scalarSize' bytes, otherwise
-- 'CryptoError_SecretKeySizeInvalid' is returned. No other check is made
-- on the input here.
scalarFromBinary :: ByteArrayAccess ba => ba -> CryptoFailable Scalar
scalarFromBinary ba =
    if B.length ba == scalarSize
        then CryptoPassed (withNewScalarFreeze load)
        else CryptoFailed CryptoError_SecretKeySizeInvalid
  where
    load :: Ptr P256Scalar -> IO ()
    load dst = B.withByteArray ba (\src -> ccrypton_p256_from_bin src dst)
{-# NOINLINE scalarFromBinary #-}

-- | Convert a scalar to its binary representation, a byte array of
-- 'scalarSize' bytes.
scalarToBinary :: ByteArray ba => Scalar -> ba
scalarToBinary s = B.unsafeCreate scalarSize encode
  where
    encode :: Ptr Word8 -> IO ()
    encode out = withScalar s (\k -> ccrypton_p256_to_bin k out)
{-# NOINLINE scalarToBinary #-}

-- | Convert an 'Integer' to a P256 'Scalar'.
--
-- The integer is first serialized to 'scalarSize' bytes and then decoded
-- with 'scalarFromBinary'. When the serialization yields 'Nothing' the
-- result is 'CryptoError_SecretKeySizeInvalid'.
scalarFromInteger :: Integer -> CryptoFailable Scalar
scalarFromInteger i =
    maybe
        (CryptoFailed CryptoError_SecretKeySizeInvalid)
        scalarFromBinary
        -- Sized with 'scalarSize' so the buffer always has the length that
        -- 'scalarFromBinary' checks for.
        (S.i2ospOf scalarSize i :: Maybe Bytes)

-- | Convert a P256 'Scalar' to an 'Integer', by decoding its binary
-- representation.
scalarToInteger :: Scalar -> Integer
scalarToInteger = S.os2ip . encode
  where
    -- The signature fixes the intermediate byte array type.
    encode :: Scalar -> Bytes
    encode = scalarToBinary

------------------------------------------------------------------------
-- Memory Helpers
------------------------------------------------------------------------
-- | Allocate a new 'Point' of 'pointSize' bytes and let the callback fill
-- it through pointers to its x and y coordinates.
withNewPoint :: (Ptr P256X -> Ptr P256Y -> IO ()) -> Point
withNewPoint f = Point (B.unsafeCreate pointSize fill)
  where
    fill px = f px (pxToPy px)
{-# NOINLINE withNewPoint #-}

-- | Run an action with pointers to the x and y coordinates of an existing
-- 'Point'.
withPoint :: Point -> (Ptr P256X -> Ptr P256Y -> IO a) -> IO a
withPoint (Point raw) f = B.withByteArray raw run
  where
    run px = f px (pxToPy px)

-- | Pointer to the y coordinate of a point, given the pointer to its x
-- coordinate: y is stored 'scalarSize' bytes after x.
pxToPy :: Ptr P256X -> Ptr P256Y
pxToPy px = px `plusPtr` scalarSize -- 'plusPtr' already retags the pointer

-- | Allocate a new 'Scalar' of 'scalarSize' bytes and let the callback fill
-- it through a pointer to its buffer.
withNewScalarFreeze :: (Ptr P256Scalar -> IO ()) -> Scalar
withNewScalarFreeze f = Scalar (B.allocAndFreeze scalarSize f)
{-# NOINLINE withNewScalarFreeze #-}

-- | Run an action with pointers to the x and y coordinates of a temporary
-- point buffer of 'pointSize' bytes obtained from 'allocTempScrubbed'.
withTempPoint :: (Ptr P256X -> Ptr P256Y -> IO a) -> IO a
withTempPoint f = allocTempScrubbed pointSize run
  where
    run raw = f px (pxToPy px)
      where
        px = castPtr raw

-- | Run an action with a pointer to the buffer of an existing 'Scalar'.
withScalar :: Scalar -> (Ptr P256Scalar -> IO a) -> IO a
withScalar (Scalar raw) = B.withByteArray raw

-- | Run an action with a temporary buffer of @n@ bytes and return only the
-- action's result. The buffer is a 'Bytes' value that is dropped afterwards.
allocTemp :: Int -> (Ptr Word8 -> IO a) -> IO a
allocTemp n f = fmap dropBuffer (B.allocRet n f)
  where
    -- The signature selects 'Bytes' as the buffer type.
    dropBuffer :: (r, Bytes) -> r
    dropBuffer (result, _) = result

-- | Like 'allocTemp', but the temporary buffer is a 'ScrubbedBytes' value.
allocTempScrubbed :: Int -> (Ptr Word8 -> IO a) -> IO a
allocTempScrubbed n f = fmap dropBuffer (B.allocRet n f)
  where
    -- The signature selects 'ScrubbedBytes' as the buffer type.
    dropBuffer :: (r, ScrubbedBytes) -> r
    dropBuffer (result, _) = result

------------------------------------------------------------------------
-- Foreign bindings
------------------------------------------------------------------------
-- Addresses of constants exported by the C code (the leading @&@ imports
-- the address of the symbol instead of calling it).
--
-- 'ccrypton_SECP256r1_n' is the modulus handed to the modular scalar
-- routines and to the reduction in 'pointX'; presumably the curve order.
foreign import ccall "&crypton_SECP256r1_n"
    ccrypton_SECP256r1_n :: Ptr P256Scalar
-- Not referenced in this module; presumably the field prime (TODO confirm).
foreign import ccall "&crypton_SECP256r1_p"
    ccrypton_SECP256r1_p :: Ptr P256Scalar
-- Not referenced in this module; presumably the curve coefficient b (TODO confirm).
foreign import ccall "&crypton_SECP256r1_b"
    ccrypton_SECP256r1_b :: Ptr P256Scalar

-- Initialise the scalar at the given pointer; 'scalarZero' relies on the
-- result representing 0.
foreign import ccall "crypton_p256_init"
    ccrypton_p256_init :: Ptr P256Scalar -> IO ()
-- Zero test; 'scalarIsZero' reads any non-zero return value as "is zero".
foreign import ccall "crypton_p256_is_zero"
    ccrypton_p256_is_zero :: Ptr P256Scalar -> IO CInt
-- Not called from this module; presumably wipes a scalar (TODO confirm).
foreign import ccall "crypton_p256_clear"
    ccrypton_p256_clear :: Ptr P256Scalar -> IO ()
-- Modular addition. 'scalarAdd' passes: modulus, a, b, destination.
foreign import ccall "crypton_p256e_modadd"
    ccrypton_p256e_modadd
        :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO ()
-- Not called from this module; judging by its signature it combines a scalar
-- with a single digit (TODO confirm against the C prototype).
foreign import ccall "crypton_p256_add_d"
    ccrypton_p256_add_d :: Ptr P256Scalar -> P256Digit -> Ptr P256Scalar -> IO CInt
-- Modular subtraction. 'scalarSub' passes: modulus, a, b, destination.
foreign import ccall "crypton_p256e_modsub"
    ccrypton_p256e_modsub
        :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO ()
-- Comparison of two scalars; 'scalarCmp' compares the returned integer with
-- 0 to obtain an 'Ordering'.
foreign import ccall "crypton_p256_cmp"
    ccrypton_p256_cmp :: Ptr P256Scalar -> Ptr P256Scalar -> IO CInt
-- Reduction. 'pointX' passes: modulus, input, destination.
foreign import ccall "crypton_p256_mod"
    ccrypton_p256_mod :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO ()
-- Modular multiplication. Argument roles below follow the call made by
-- 'scalarMul'.
foreign import ccall "crypton_p256_modmul"
    ccrypton_p256_modmul
        :: Ptr P256Scalar -- modulus
        -> Ptr P256Scalar -- a
        -> P256Digit -- always 0 in this module
        -> Ptr P256Scalar -- b
        -> Ptr P256Scalar -- destination
        -> IO ()
-- Scalar inversion used by 'scalarInvSafe': input, then destination. No
-- modulus argument is taken.
foreign import ccall "crypton_p256e_scalar_invert"
    ccrypton_p256e_scalar_invert :: Ptr P256Scalar -> Ptr P256Scalar -> IO ()

-- The binding below is kept commented out and is not used by this module.
-- foreign import ccall "crypton_p256_modinv"
--    ccrypton_p256_modinv :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO ()

-- Modular inversion used by 'scalarInv', which passes: modulus, input,
-- destination.
foreign import ccall "crypton_p256_modinv_vartime"
    ccrypton_p256_modinv_vartime
        :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO ()
-- Multiplication of the base point by a scalar, used by 'toPoint'. Note that
-- the C symbol is spelled @base_point@ while the Haskell name is
-- @basepoint@.
foreign import ccall "crypton_p256_base_point_mul"
    ccrypton_p256_basepoint_mul
        :: Ptr P256Scalar -- n
        -> Ptr P256X -- out_x
        -> Ptr P256Y -- out_y
        -> IO ()

-- Point addition, used by 'pointAdd'. Argument roles follow that call.
foreign import ccall "crypton_p256e_point_add"
    ccrypton_p256e_point_add
        :: Ptr P256X
        -> Ptr P256Y -- first input point
        -> Ptr P256X
        -> Ptr P256Y -- second input point
        -> Ptr P256X
        -> Ptr P256Y -- out_{x,y}
        -> IO ()

-- Point negation, used by 'pointNegate'. Argument roles follow that call.
foreign import ccall "crypton_p256e_point_negate"
    ccrypton_p256e_point_negate
        :: Ptr P256X
        -> Ptr P256Y -- in_{x,y}
        -> Ptr P256X
        -> Ptr P256Y -- out_{x,y}
        -> IO ()

-- compute (out_x,out_y) = n * (in_x,in_y)
--
-- Used by 'pointMul' and 'pointDh'.
foreign import ccall "crypton_p256e_point_mul"
    ccrypton_p256e_point_mul
        :: Ptr P256Scalar -- n
        -> Ptr P256X
        -> Ptr P256Y -- in_{x,y}
        -> Ptr P256X
        -> Ptr P256Y -- out_{x,y}
        -> IO ()

-- compute (out_x,out_y) = n1 * G + n2 * (in_x,in_y)
--
-- Used by 'pointsMulVarTime'.
foreign import ccall "crypton_p256_points_mul_vartime"
    ccrypton_p256_points_mul_vartime
        :: Ptr P256Scalar -- n1
        -> Ptr P256Scalar -- n2
        -> Ptr P256X
        -> Ptr P256Y -- in_{x,y}
        -> Ptr P256X
        -> Ptr P256Y -- out_{x,y}
        -> IO ()
-- Point validity check; 'pointIsValid' reads any non-zero return value as
-- "valid".
foreign import ccall "crypton_p256_is_valid_point"
    ccrypton_p256_is_valid_point :: Ptr P256X -> Ptr P256Y -> IO CInt

-- Serialize the value at the first pointer into the byte buffer at the
-- second. Callers in this module always provide 'scalarSize' bytes of room,
-- and cast coordinate pointers to 'Ptr P256Scalar' to reuse this routine.
foreign import ccall "crypton_p256_to_bin"
    ccrypton_p256_to_bin :: Ptr P256Scalar -> Ptr Word8 -> IO ()

-- Inverse of 'ccrypton_p256_to_bin': load the value at the second pointer
-- from the byte buffer at the first. Callers in this module always supply
-- 'scalarSize' readable bytes.
foreign import ccall "crypton_p256_from_bin"
    ccrypton_p256_from_bin :: Ptr Word8 -> Ptr P256Scalar -> IO ()