1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381
|
# --- ELLIPTIC CURVE MATH ------------------------------------------------------
#
# curve definition: y^2 = x^3 - p*x - q
# over finite field: Z/nZ* (prime residue classes modulo a prime number n)
#
#
# COPYRIGHT (c) 2010 by Toni Mattis <solaris@live.de>
# ------------------------------------------------------------------------------
'''
Module for elliptic curve arithmetic over a prime field GF(n).
E(GF(n)) takes the form y**2 == x**3 - p*x - q (mod n) for a prime n.
0. Structures used by this module
PARAMETERS and SCALARS are non-negative (long) integers.
A POINT (x, y), usually denoted p1, p2, ...
is a pair of (long) integers where 0 <= x < n and 0 <= y < n
A POINT in PROJECTIVE COORDINATES, usually denoted jp1, jp2, ...
takes the form (X, Y, Z, Z**2, Z**3) where x = X / Z**2
and y = Y / Z**3. This form is called Jacobian coordinates.
The NEUTRAL element "0" or "O" is represented by None
in both coordinate systems.
1. Basic Functions
euclid() Is the Extended Euclidean Algorithm.
inv() Computes the multiplicative inverse modulo n.
curve_q() Finds the curve parameter q (mod n)
when p and a point are given.
element() Tests whether a point (x, y) is on the curve.
2. Point transformations
to_projective() Converts a point (x, y) to projective coordinates.
from_projective() Converts a point from projective coordinates
to (x, y) using the transformation described above.
neg() Computes the inverse point -P in both coordinate
systems.
3. Slow point arithmetic
These algorithms make use of basic geometry and modular arithmetic
thus being suitable for small numbers and academic study.
add() Computes the sum of two (x, y)-points
mul() Perform scalar multiplication using "double & add"
4. Fast point arithmetic
These algorithms make use of projective coordinates, signed binary
expansion and a JSF-like approach (joint sparse form).
The following functions consume and return projective coordinates:
addf() Optimized point addition.
doublef() Optimized point doubling.
mulf() Highly optimized scalar multiplication.
muladdf() Highly optimized addition of two products.
The following functions use the optimized ones above but consume
and output (x, y)-coordinates for a more convenient usage:
mulp() Encapsulates mulf()
muladdp() Encapsulates muladdf()
For single additions add() is generally faster than an encapsulation of
addf() which would involve expensive coordinate transformations.
Hence there is no addp() and doublep().
'''
# BASIC MATH -------------------------------------------------------------------
def euclid(a, b):
    '''Extended Euclidean algorithm.

    Return a tuple (x, y, g) such that x*a + y*b == g, where g is the
    greatest common divisor of a and b (for non-negative inputs).
    '''
    # Iterative formulation, so arbitrarily large operands cannot exhaust
    # the recursion limit.
    coef_a, next_a = 1, 0   # running Bezout coefficients of the original a
    coef_b, next_b = 0, 1   # running Bezout coefficients of the original b
    while b != 0:
        quotient, remainder = divmod(a, b)
        a, b = b, remainder
        coef_a, next_a = next_a, coef_a - quotient * next_a
        coef_b, next_b = next_b, coef_b - quotient * next_b
    return coef_a, coef_b, a
def inv(a, n):
    '''Return the multiplicative inverse of a modulo n.

    a and n should be COPRIME. This is deliberately not checked in favour
    of performance; for non-coprime arguments the result is meaningless.
    The result is always the canonical representative in [0, n).
    '''
    # euclid() yields x with x*a + y*n == gcd(a, n); for coprime arguments
    # that makes x the inverse of a. A single modular reduction normalises
    # a negative coefficient (and cannot loop forever like repeated
    # addition would for n <= 0).
    return euclid(a, n)[0] % n
def curve_q(x, y, p, n):
    '''Return the curve parameter q (mod n) for which (x, y) is on the curve.

    Solves y**2 == x**3 - p*x - q (mod n) for q.
    '''
    cubic_part = x * x * x - p * x
    return (cubic_part - y * y) % n
def element(point, p, q, n):
    '''Return True if point lies on the curve (p, q, n).

    The neutral element (None) is always considered part of the curve.
    '''
    if not point:
        return True
    x, y = point
    lhs = (y * y) % n
    rhs = (x * x * x - p * x - q) % n
    return rhs == lhs
def to_projective(p):
    '''Convert an (x, y) point into Jacobian form (X, Y, Z, Z**2, Z**3).

    The neutral element (None) is passed through unchanged.
    '''
    if not p:
        return None  # identity point (0)
    x, y = p[0], p[1]
    # Z = 1, hence Z**2 = Z**3 = 1 as well.
    return (x, y, 1, 1, 1)
def from_projective(jp, n):
    '''Convert a Jacobian point (X, Y, Z, Z**2, Z**3) back to (x, y) mod n.

    Uses x = X / Z**2 and y = Y / Z**3. The neutral element (None) is
    passed through unchanged.
    '''
    if not jp:
        return None  # identity point (0)
    x = (jp[0] * inv(jp[3], n)) % n
    y = (jp[1] * inv(jp[4], n)) % n
    return x, y
def neg(p, n):
    '''Return the additive inverse -p of a point in either coordinate system.

    Only the second component (y) is mirrored modulo n; any further
    components (the Z-terms of a projective point) are kept as they are.
    '''
    if not p:
        return None
    mirrored_y = (n - p[1]) % n
    return (p[0], mirrored_y) + p[2:]
# POINT ADDITION ---------------------------------------------------------------
# addition of points in y**2 = x**3 - p*x - q over <Z/nZ*; +>
def add(p, q, n, p1, p2):
    '''Add the (x, y)-points p1 and p2 over curve (p, q, n).

    None represents the neutral element: adding it returns the other
    operand, and adding a point to its own inverse returns None.
    The returned coordinates are always reduced to the range [0, n).
    (q is not needed by the addition formulas; it is accepted so that all
    curve functions share the same leading parameters.)
    '''
    if not (p1 and p2):
        # At least one operand is the neutral element: return the other
        # one (which may itself be None).
        return p1 if p1 else p2

    x1, y1 = p1
    x2, y2 = p2
    if (x1 - x2) % n:
        # Distinct x-coordinates: slope of the secant through p1 and p2.
        s = ((y1 - y2) * inv(x1 - x2, n)) % n
    elif (y1 + y2) % n:
        # Same point: slope of the tangent, obtained by differentiation.
        s = ((3 * x1 * x1 - p) * inv(2 * y1, n)) % n
    else:
        # p2 == -p1 (this includes doubling a point with y == 0).
        return None

    # Third intersection of the line with the curve, mirrored on the x-axis.
    x = (s * s - x1 - x2) % n
    # Negate *before* reducing: `n - (...) % n` would produce y == n
    # instead of 0 whenever the intersection has y == 0.
    y = (-(y1 + s * (x - x1))) % n
    return (x, y)
# faster addition: redundancy in projective coordinates eliminates
# expensive inversions mod n.
def addf(p, q, n, jp1, jp2):
    '''Add two points given in projective (Jacobian) coordinates.

    Both operands and the result have the form (X, Y, Z, Z**2, Z**3);
    None represents the neutral element. No modular inversion is needed.
    '''
    # Neutral element on either side: the other operand is the sum.
    if not jp1:
        return jp2
    if not jp2:
        return jp1

    xa, ya, za, za_sq, za_cu = jp1
    xb, yb, zb, zb_sq, zb_cu = jp2

    # Bring both points to the common denominator (za * zb).
    u_a = (xa * zb_sq) % n
    u_b = (xb * za_sq) % n
    s_a = (ya * zb_cu) % n
    s_b = (yb * za_cu) % n

    if (u_a - u_b) % n == 0:
        # Same affine x: either the points are inverse to each other ...
        if (s_a + s_b) % n == 0:
            return None
        # ... or identical, in which case the doubling formula applies.
        return doublef(p, q, n, jp1)

    h = (u_b - u_a) % n
    r = (s_b - s_a) % n
    h_sq = (h * h) % n
    h_cu = (h_sq * h) % n
    u_h_sq = u_a * h_sq

    x3 = (r * r - h_cu - 2 * u_h_sq) % n
    y3 = (r * (u_h_sq - x3) - s_a * h_cu) % n
    z3 = (za * zb * h) % n
    z3_sq = (z3 * z3) % n
    z3_cu = (z3_sq * z3) % n
    return (x3, y3, z3, z3_sq, z3_cu)
# explicit point doubling using redundant coordinates
def doublef(p, q, n, jp):
    '''Double the point jp given in projective (Jacobian) coordinates.

    jp and the result have the form (X, Y, Z, Z**2, Z**3); None represents
    the neutral element. Doubling the neutral element, or a point whose
    Y-coordinate is 0 (mod n), yields None.
    (q is not needed by the doubling formula; it is accepted so that all
    curve functions share the same leading parameters.)
    '''
    if not jp:
        return None
    x1, y1, z1, z1p2, z1p3 = jp
    if y1 % n == 0:
        # The tangent is vertical, so 2*jp is the neutral element. Without
        # this check the formulas below would return a tuple with Z == 0,
        # which later additions and from_projective() cannot handle.
        return None
    y1p2 = (y1 * y1) % n
    a = (4 * x1 * y1p2) % n
    # Tangent slope numerator 3*X**2 - p*Z**4 (curve is y**2 = x**3 - p*x - q).
    b = (3 * x1 * x1 - p * z1p3 * z1) % n
    x3 = (b * b - 2 * a) % n
    y3 = (b * (a - x3) - 8 * y1p2 * y1p2) % n
    z3 = (2 * y1 * z1) % n
    z3p2 = (z3 * z3) % n
    return x3, y3, z3, z3p2, (z3p2 * z3) % n
# SCALAR MULTIPLICATION --------------------------------------------------------
# scalar multiplication p1 * c = p1 + p1 + ... + p1 (c times) in O(log(c)) point operations
def mul(p, q, n, p1, c):
    '''Multiply point p1 by the scalar c over curve (p, q, n).

    Plain "double & add" on (x, y)-coordinates, consuming c from its least
    significant bit upwards. Returns None (the neutral element) for c <= 0.
    '''
    total = None
    addend = p1
    while c > 0:
        c, low_bit = divmod(c, 2)
        if low_bit:
            total = add(p, q, n, total, addend)
        addend = add(p, q, n, addend, addend)  # addend = addend * 2
    return total
# This helper allows _signed_bin() to choose between +1 and -1. It will select
# the sign which leaves the higher number of zeroes in the binary
# representation (the higher GBD).
def _gbd(n):
'''Compute second greatest base-2 divisor'''
i = 1
if n <= 0: return 0
while not n % i:
i <<= 1
return i >> 2
# This method transforms n into a binary representation having signed bits.
# A signed binary expansion contains more zero-bits hence reducing the number
# of additions required by a multiplication algorithm.
#
# Example: 15 ( 0b1111 ) can be written as 16 - 1, resulting in (1,0,0,0,-1)
# and saving 2 additions. Subtraction can be performed as
# efficiently as addition.
def _signed_bin(n):
'''Transform n into an optimized signed binary representation'''
r = []
while n > 1:
if n & 1:
cp = _gbd(n + 1)
cn = _gbd(n - 1)
if cp > cn: # -1 leaves more zeroes -> subtract -1 (= +1)
r.append(-1)
n += 1
else: # +1 leaves more zeroes -> subtract +1 (= -1)
r.append(+1)
n -= 1
else:
r.append(0) # be glad about one more zero
n >>= 1
r.append(n)
return r[::-1]
# This multiplication algorithm combines signed binary expansion and
# fast addition using projective coordinates resulting in 5 to 10 times
# faster multiplication.
def mulf(p, q, n, jp1, c):
    '''Multiply point jp1 by the scalar c in projective coordinates.

    Walks the signed binary expansion of c from the most significant
    digit: double for every digit, then add jp1 for +1 or -jp1 for -1.
    '''
    digits = _signed_bin(c)
    plus = jp1
    minus = neg(jp1, n)  # additive inverse of jp1, used for digit -1
    acc = None
    for digit in digits:
        acc = doublef(p, q, n, acc)
        if digit > 0:
            acc = addf(p, q, n, acc, plus)
        elif digit < 0:
            acc = addf(p, q, n, acc, minus)
    return acc
# Encapsulates mulf() in order to enable flat coordinates (x, y)
def mulp(p, q, n, p1, c):
    '''Multiply the (x, y)-point p1 by c using fast multiplication.

    Thin wrapper around mulf(): converts p1 to projective coordinates,
    multiplies, and converts the product back to (x, y).
    '''
    jp1 = to_projective(p1)
    product = mulf(p, q, n, jp1, c)
    return from_projective(product, n)
# Sum of two products using Shamir's trick and signed binary expansion
def muladdf(p, q, n, jp1, c1, jp2, c2):
    '''Efficiently compute c1 * jp1 + c2 * jp2 in projective coordinates.

    Both scalars are expanded into signed binary digits and processed in
    lockstep (Shamir's trick): one doubling per digit position plus at
    most one addition of a precomputed combination of +-jp1 and +-jp2.
    '''
    digits1 = _signed_bin(c1)
    digits2 = _signed_bin(c2)

    # Left-pad the shorter expansion with zeroes so both have equal length.
    width = max(len(digits1), len(digits2))
    digits1 = [0] * (width - len(digits1)) + digits1
    digits2 = [0] * (width - len(digits2)) + digits2

    neg1 = neg(jp1, n)
    neg2 = neg(jp2, n)
    both = addf(p, q, n, jp1, jp2)    # jp1 + jp2
    diff = addf(p, q, n, jp1, neg2)   # jp1 - jp2

    # table[d1][d2] == d1 * jp1 + d2 * jp2 for d1, d2 in (0, 1, -1).
    # Index -1 deliberately addresses the last row / column.
    row_zero = (None, jp2, neg2)
    row_plus = (jp1, both, diff)
    row_minus = (neg1, neg(diff, n), neg(both, n))
    table = (row_zero, row_plus, row_minus)

    acc = None
    for d1, d2 in zip(digits1, digits2):
        acc = doublef(p, q, n, acc)
        if d1 or d2:
            acc = addf(p, q, n, acc, table[d1][d2])
    return acc
# Encapsulate muladdf()
def muladdp(p, q, n, p1, c1, p2, c2):
    '''Efficiently compute c1 * p1 + c2 * p2 in (x, y)-coordinates.

    Thin wrapper around muladdf(): converts both points to projective
    coordinates and converts the sum of products back to (x, y).
    '''
    jp1 = to_projective(p1)
    jp2 = to_projective(p2)
    total = muladdf(p, q, n, jp1, c1, jp2, c2)
    return from_projective(total, n)
# POINT COMPRESSION ------------------------------------------------------------
# Compute the square root modulo n
# Determine the sign-bit of a point allowing to reconstruct y-coordinates
# when x and the sign-bit are given:
def sign_bit(p1):
    '''Return the signedness of a point p1.

    This is the parity (0 or 1) of the point's y-coordinate; the neutral
    element (None) yields 0.
    '''
    if not p1:
        return 0
    return p1[1] % 2
# Reconstruct the y-coordinate when curve parameters, x and the sign-bit of
# the y coordinate are given:
def _sqrt_mod(a, n):
    '''Return a square root of a modulo the prime n, or None if a has none.'''
    a %= n
    if a == 0:
        return 0
    if n == 2:
        return a
    # Euler's criterion: a is a square iff a**((n - 1) / 2) == 1 (mod n).
    if pow(a, (n - 1) // 2, n) != 1:
        return None
    if n % 4 == 3:
        # Shortcut for n = 3 (mod 4): a single exponentiation suffices.
        return pow(a, (n + 1) // 4, n)

    # Tonelli-Shanks for the general case. Write n - 1 = odd * 2**exp2.
    odd, exp2 = n - 1, 0
    while odd % 2 == 0:
        odd //= 2
        exp2 += 1
    # Any quadratic non-residue works as generator of the 2-Sylow subgroup.
    z = 2
    while pow(z, (n - 1) // 2, n) != n - 1:
        z += 1
    m = exp2
    c = pow(z, odd, n)
    t = pow(a, odd, n)
    r = pow(a, (odd + 1) // 2, n)
    while t != 1:
        # Find the least i with t**(2**i) == 1 (mod n).
        i, probe = 0, t
        while probe != 1:
            probe = (probe * probe) % n
            i += 1
        b = pow(c, 1 << (m - i - 1), n)
        m = i
        c = (b * b) % n
        t = (t * c) % n
        r = (r * b) % n
    return r


def y_from_x(x, p, q, n, sign):
    '''Return the y coordinate over curve (p, q, n) for given (x, sign).

    sign is the parity of the wanted y-coordinate, as returned by
    sign_bit(). n must be prime. Returns None when no point with the given
    x and sign exists on the curve.
    '''
    # optimized form of (x**3 - p*x - q) % n
    a = (((x * x) % n - p) * x - q) % n

    root = _sqrt_mod(a, n)
    if root is None:
        return None  # right-hand side is not a square: x is not on the curve
    # The two roots y and n - y have opposite parity (for odd n and y != 0),
    # so the sign bit selects exactly one of them.
    for candidate in (root, (n - root) % n):
        if candidate % 2 == sign % 2:
            return candidate
    return None
if __name__ == "__main__":
    import rsa
    import time

    # Self-test: a Diffie-Hellman style exchange on a random curve.
    # NOTE(review): rsa is a project-local module; get_prime() presumably
    # takes (length in bytes, number of primality test rounds) -- confirm.
    t = time.time()
    n = rsa.get_prime(256 // 8, 20)  # floor division: an int on Python 2 and 3
    tp = time.time() - t             # time spent on prime generation
    p = rsa.random.randint(1, n)
    p1 = (rsa.random.randint(1, n), rsa.random.randint(1, n))
    q = curve_q(p1[0], p1[1], p, n)  # choose q so that p1 lies on the curve
    r1 = rsa.random.randint(1, n)
    r2 = rsa.random.randint(1, n)
    q1 = mulp(p, q, n, p1, r1)
    q2 = mulp(p, q, n, p1, r2)
    s1 = mulp(p, q, n, q1, r2)
    s2 = mulp(p, q, n, q2, r1)
    # r2 * (r1 * p1) must equal r1 * (r2 * p1); the original compared the
    # two values but discarded the result.
    assert s1 == s2, "scalar multiplication does not commute"
    tt = time.time() - t             # total time including prime generation

    def test(tcount, bits=256):
        '''Benchmark c1*p1 + c2*p2: two mulf() calls vs. one muladdp().

        Runs tcount random scalar pairs on a random curve over a prime of
        the given bit size and returns (tcount, t1, t2), where t1 is the
        time for the naive approach and t2 the time for muladdp().
        '''
        n = rsa.get_prime(bits // 8, 20)
        p = rsa.random.randint(1, n)
        p1 = (rsa.random.randint(1, n), rsa.random.randint(1, n))
        q = curve_q(p1[0], p1[1], p, n)
        p2 = mulp(p, q, n, p1, rsa.random.randint(1, n))
        c1 = [rsa.random.randint(1, n) for _ in range(tcount)]
        c2 = [rsa.random.randint(1, n) for _ in range(tcount)]
        # Materialise the pairs: they are iterated twice below, and a bare
        # zip() iterator would be exhausted after the first loop on Python 3.
        c = list(zip(c1, c2))

        t = time.time()
        for i, j in c:
            from_projective(addf(p, q, n,
                                 mulf(p, q, n, to_projective(p1), i),
                                 mulf(p, q, n, to_projective(p2), j)), n)
        t1 = time.time() - t

        t = time.time()
        for i, j in c:
            muladdp(p, q, n, p1, i, p2, j)
        t2 = time.time() - t

        return tcount, t1, t2
|