# File: optimized_pairing.py
#
# Package info: python-py-ecc 8.0.0-2
#   - links: PTS, VCS
#   - area: main
#   - in suites: forky, sid
#   - size: 796 kB
#   - sloc: python: 4,896; makefile: 237
# File content: 249 lines, mode -rw-r--r--, 6,079 bytes
# (line-number gutter 1-249 emitted by the source viewer; not part of the file)
from py_ecc.fields import (
    optimized_bls12_381_FQ as FQ,
    optimized_bls12_381_FQ2 as FQ2,
    optimized_bls12_381_FQ12 as FQ12,
)
from py_ecc.fields.field_properties import (
    field_properties,
)
from py_ecc.typing import (
    Optimized_Field,
    Optimized_Point2D,
    Optimized_Point3D,
)

from .optimized_curve import (
    G1,
    add,
    b,
    b2,
    curve_order,
    double,
    is_on_curve,
    multiply,
    normalize,
    twist,
)

# Field modulus taken from the "bls12_381" entry of the shared field_properties table
field_modulus = field_properties["bls12_381"]["field_modulus"]

# Loop scalar for the ate pairing; a 64-bit value whose highest set bit is bit 63
ate_loop_count = 15132376222941642752
# Position of the second-highest set bit (one below the top bit, 63)
log_ate_loop_count = 62
# Little-endian bits of ate_loop_count: entry i is the coefficient of 2**i.
# The set bits are 16, 48, 57, 60, 62 and 63.
pseudo_binary_encoding = (
    [0] * 16  # bits 0-15
    + [1]  # bit 16
    + [0] * 31  # bits 17-47
    + [1]  # bit 48
    + [0] * 8  # bits 49-56
    + [1]  # bit 57
    + [0, 0, 1, 0, 1, 1]  # bits 58-63
)


# Fail at import time if the bit list and the scalar ever get out of sync
if (
    sum(bit << position for position, bit in enumerate(pseudo_binary_encoding))
    != ate_loop_count
):
    raise ValueError("Pseudo binary encoding is incorrect")


def normalize1(
    p: Optimized_Point3D[Optimized_Field],
) -> Optimized_Point3D[Optimized_Field]:
    """
    Return ``p`` as a three-coordinate point whose last coordinate is one.

    The first two coordinates are whatever ``normalize(p)`` yields; the third
    is the ``one()`` element obtained from the first coordinate.
    """
    affine_x, affine_y = normalize(p)
    unit_z = affine_x.one()
    return affine_x, affine_y, unit_z


def linefunc(
    P1: Optimized_Point3D[Optimized_Field],
    P2: Optimized_Point3D[Optimized_Field],
    T: Optimized_Point3D[Optimized_Field],
) -> Optimized_Point2D[Optimized_Field]:
    """
    Evaluate at ``T`` the line through ``P1`` and ``P2``.

    All points are projective triples ``(x, y, z)`` standing for the affine
    point ``(x / z, y / z)``. The value is returned as a
    ``(numerator, denominator)`` pair so that no division is performed here.

    Three cases are handled:

    * different x coordinates: the chord through ``P1`` and ``P2``;
    * same x and same y: the tangent at ``P1``;
    * same x but different y: the vertical line through ``P1``.
    """
    zero = P1[0].zero()
    x1, y1, z1 = P1
    x2, y2, z2 = P2
    xt, yt, zt = T

    # Slope of the chord, m = (y2/z2 - y1/z1) / (x2/z2 - x1/z1), with numerator
    # and denominator both scaled by z1 * z2
    slope_num = y2 * z1 - y1 * z2
    slope_den = x2 * z1 - x1 * z2

    has_distinct_x = slope_den != zero
    if not has_distinct_x:
        if slope_num == zero:
            # Tangent slope m = 3(x/z)^2 / 2(y/z), scaled by z**2
            slope_num = 3 * x1 * x1
            slope_den = 2 * y1 * z1
        else:
            # Vertical line: only the x offset of T from P1 matters
            return xt * z1 - x1 * zt, z1 * zt

    # m * ((xt/zt) - (x1/z1)) - ((yt/zt) - (y1/z1)), cleared of denominators
    return (
        slope_num * (xt * z1 - x1 * zt) - slope_den * (yt * z1 - y1 * zt),
        slope_den * zt * z1,
    )


def cast_point_to_fq12(pt: Optimized_Point3D[FQ]) -> Optimized_Point3D[FQ12]:
    """
    Re-express a three-coordinate point with FQ12 coordinates.

    Each coordinate's ``.n`` attribute becomes the first coefficient of an
    FQ12 element whose remaining eleven coefficients are zero. ``None`` is
    returned unchanged.
    """
    if pt is None:
        return None

    def lift(coord: FQ) -> FQ12:
        # Place the value in coefficient 0 and pad the other 11 with zeros
        return FQ12([coord.n] + [0] * 11)

    x, y, z = pt
    return lift(x), lift(y), lift(z)


# Import-time self-check of the "line function" on small multiples of G1
one, two, three = G1, double(G1), multiply(G1, 3)
negone, negtwo, negthree = (multiply(G1, curve_order - k) for k in (1, 2, 3))


# Each row is (P1, P2, T, expected): `expected` states whether the numerator
# returned by linefunc(P1, P2, T) has to be zero
conditions = [
    (linefunc(p1, p2, t)[0] == FQ(0)) == expected
    for p1, p2, t, expected in (
        (one, two, one, True),
        (one, two, two, True),
        (one, two, three, False),
        (one, two, negthree, True),
        (one, negone, one, True),
        (one, negone, negone, True),
        (one, negone, two, False),
        (one, one, one, True),
        (one, one, two, False),
        (one, one, negtwo, True),
    )
]

if any(not passed for passed in conditions):
    raise ValueError("Line function is inconsistent")


# Main miller loop
def miller_loop(
    Q: Optimized_Point3D[FQ2], P: Optimized_Point3D[FQ], final_exponentiate: bool = True
) -> FQ12:
    """
    Run the Miller loop for ``Q`` and ``P`` and return the resulting FQ12 value.

    The accumulator is kept as a separate numerator and denominator, and a
    single division is performed after the loop.

    :param Q: point with FQ2 coordinates, or ``None``
    :param P: point with FQ coordinates, or ``None``
    :param final_exponentiate: when true, the loop output is raised to the
        power ``(field_modulus**12 - 1) // curve_order`` before being returned.
        NOTE: inside this function the parameter hides the module-level
        ``final_exponentiate`` function.
    :return: ``FQ12.one()`` when either input is ``None``; otherwise the
        (optionally exponentiated) loop output
    """
    if Q is None or P is None:
        return FQ12.one()
    cast_P = cast_point_to_fq12(P)
    twist_R = twist_Q = twist(Q)
    R: Optimized_Point3D[FQ2] = Q
    f_num, f_den = FQ12.one(), FQ12.one()
    # R starts out equal to Q, so the scan begins at bit log_ate_loop_count
    # (the named constant replaces a duplicated literal 62) and walks the
    # little-endian bit list down to bit 0.
    for v in pseudo_binary_encoding[log_ate_loop_count::-1]:
        # Doubling step: tangent line at R, then R <- 2R
        _n, _d = linefunc(twist_R, twist_R, cast_P)
        f_num = f_num * f_num * _n
        f_den = f_den * f_den * _d
        R = double(R)
        twist_R = twist(R)
        if v == 1:
            # Addition step: line through R and Q, then R <- R + Q
            _n, _d = linefunc(twist_R, twist_Q, cast_P)
            f_num = f_num * _n
            f_den = f_den * _d
            R = add(R, Q)
            twist_R = twist(R)
    # No further line evaluations are applied after the loop: the extra steps
    # on copies of Q with coordinates raised to the field_modulus power were
    # only ever present here as commented-out code.
    f = f_num / f_den
    if final_exponentiate:
        return f ** ((field_modulus**12 - 1) // curve_order)
    else:
        return f


# Pairing computation
def pairing(
    Q: Optimized_Point3D[FQ2], P: Optimized_Point3D[FQ], final_exponentiate: bool = True
) -> FQ12:
    """
    Validate ``Q`` and ``P`` and compute their pairing via ``miller_loop``.

    :param Q: point checked with ``is_on_curve`` against ``b2``
    :param P: point checked with ``is_on_curve`` against ``b``
    :param final_exponentiate: forwarded unchanged to ``miller_loop``
    :return: ``FQ12.one()`` if the last coordinate of ``P`` or of ``Q`` is
        zero; otherwise the result of ``miller_loop``
    :raises ValueError: if either point fails its ``is_on_curve`` check
    """
    curve_checks = (
        (Q, b2, "Invalid input - point Q is not on the correct curve"),
        (P, b, "Invalid input - point P is not on the correct curves"),
    )
    for point, coefficient, message in curve_checks:
        if not is_on_curve(point, coefficient):
            raise ValueError(message)

    # A zero last coordinate short-circuits to the identity; P is tested first
    if any(point[-1] == point[-1].zero() for point in (P, Q)):
        return FQ12.one()
    return miller_loop(Q, P, final_exponentiate=final_exponentiate)


# exptable[i] is the FQ12 element with a single 1 in coefficient i (all other
# coefficients 0), raised to the field_modulus power
exptable = [
    FQ12([int(slot == i) for slot in range(12)]) ** field_modulus for i in range(12)
]


def exp_by_p(x: FQ12) -> FQ12:
    """
    Return the sum of ``exptable[i] * int(x.coeffs[i])`` over all coefficients.

    Because ``exptable[i]`` holds the i-th basis element raised to
    ``field_modulus``, this linear combination is the table-driven way of
    computing ``x ** field_modulus``.
    """
    weighted_pairs = zip(exptable, x.coeffs)
    total = FQ12.zero()
    for basis_power, coefficient in weighted_pairs:
        total = total + basis_power * int(coefficient)
    return total


def final_exponentiate(p: FQ12) -> FQ12:
    """
    Raise ``p`` to a large power in three stages using ``exp_by_p``.

    With ``q = field_modulus`` and ``exp_by_p(x)`` standing for ``x ** q``,
    the stages are ``p ** (q**2 + 1)``, then that value to ``q**6 - 1``, then
    the result to ``(q**4 - q**2 + 1) // curve_order``.
    """
    # Stage 1: two applications of exp_by_p, times p
    stage_one = exp_by_p(exp_by_p(p)) * p

    # Stage 2: six applications of exp_by_p, divided by the stage-1 value
    powered = stage_one
    for _ in range(6):
        powered = exp_by_p(powered)
    stage_two = powered / stage_one

    # Stage 3: plain exponentiation by the remaining integer factor
    remaining_exponent = (field_modulus**4 - field_modulus**2 + 1) // curve_order
    return stage_two**remaining_exponent