File: sum.go

package info (click to toggle)
golang-github-cloudflare-circl 1.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 18,060 kB
  • sloc: asm: 20,492; ansic: 1,292; makefile: 68
file content (172 lines) | stat: -rw-r--r-- 4,222 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
// Package sum is a VDAF for aggregating integers in a pre-determined range.
package sum

import (
	"math/bits"

	"github.com/cloudflare/circl/vdaf/prio3/arith/fp64"
	"github.com/cloudflare/circl/vdaf/prio3/internal/cursor"
	"github.com/cloudflare/circl/vdaf/prio3/internal/flp"
	"github.com/cloudflare/circl/vdaf/prio3/internal/prio3"
)

// Field-level aliases: every value handled by the circuit is an element of
// the fp64 field, grouped into vectors and polynomials.
type (
	Fp   = fp64.Fp
	Vec  = fp64.Vec
	poly = fp64.Poly
)

// Protocol-level aliases: the generic prio3 message types specialised to
// the fp64 field used by Sum.
type (
	Nonce       = prio3.Nonce
	VerifyKey   = prio3.VerifyKey
	PublicShare = prio3.PublicShare
	InputShare  = prio3.InputShare[Vec, Fp]
	PrepState   = prio3.PrepState[Vec, Fp]
	PrepShare   = prio3.PrepShare[Vec, Fp]
	PrepMessage = prio3.PrepMessage
	OutShare    = prio3.OutShare[Vec, Fp]
	AggShare    = prio3.AggShare[Vec, Fp]
)

// Sum is a verifiable distributed aggregation function (VDAF) in which each
// measurement is an integer in the range [0, maxMeasurement], where
// maxMeasurement is the largest valid measurement. The aggregate result is
// the sum of all the measurements.
type Sum struct {
	// p is the generic Prio3 instance, specialised to the flpSum circuit,
	// to which every Sum method delegates.
	p prio3.Prio3[uint64, uint64, *flpSum, Vec, Fp, *Fp]
}

// New returns a Sum VDAF for numShares aggregators that accepts measurements
// in [0, maxMeasurement]. The context bytes are forwarded unchanged to
// prio3.New. An error is returned if the circuit or the Prio3 instance
// cannot be constructed.
func New(numShares uint8, maxMeasurement uint64, context []byte) (s *Sum, err error) {
	// Algorithm identifier handed to prio3.New; presumably the Prio3Sum ID
	// from the VDAF specification — confirm against the spec.
	const sumID = 2

	circuit, err := newFlpSum(maxMeasurement)
	if err != nil {
		return nil, err
	}

	inner, err := prio3.New(circuit, sumID, numShares, context)
	if err != nil {
		return nil, err
	}

	return &Sum{p: inner}, nil
}

// Params reports the parameters of the underlying Prio3 instance.
func (s *Sum) Params() prio3.Params {
	params := s.p.Params()
	return params
}

// Shard splits measurement into a public share and a slice of input shares
// by delegating to the underlying Prio3 instance. The nonce and rand bytes
// are passed through unchanged.
func (s *Sum) Shard(measurement uint64, nonce *Nonce, rand []byte) (PublicShare, []InputShare, error) {
	public, inputs, err := s.p.Shard(measurement, nonce, rand)
	return public, inputs, err
}

// PrepInit begins preparation of inputShare for the aggregator identified by
// aggID. It returns the aggregator's preparation state and the preparation
// share it must broadcast. It delegates to the underlying Prio3 instance.
func (s *Sum) PrepInit(
	verifyKey *VerifyKey,
	nonce *Nonce,
	aggID uint8,
	publicShare PublicShare,
	inputShare InputShare,
) (*PrepState, *PrepShare, error) {
	state, share, err := s.p.PrepInit(verifyKey, nonce, aggID, publicShare, inputShare)
	return state, share, err
}

// PrepSharesToPrep combines the aggregators' preparation shares into a single
// preparation message by delegating to the underlying Prio3 instance.
func (s *Sum) PrepSharesToPrep(prepShares []PrepShare) (*PrepMessage, error) {
	msg, err := s.p.PrepSharesToPrep(prepShares)
	return msg, err
}

// PrepNext finishes preparation using the aggregator's state and the combined
// preparation message, producing an output share. It delegates to the
// underlying Prio3 instance.
func (s *Sum) PrepNext(state *PrepState, msg *PrepMessage) (*OutShare, error) {
	out, err := s.p.PrepNext(state, msg)
	return out, err
}

// AggregateInit returns a fresh aggregate share from the underlying Prio3
// instance, to be updated with AggregateUpdate.
func (s *Sum) AggregateInit() AggShare {
	fresh := s.p.AggregateInit()
	return fresh
}

// AggregateUpdate folds outShare into aggShare in place by delegating to the
// underlying Prio3 instance.
func (s *Sum) AggregateUpdate(aggShare *AggShare, outShare *OutShare) {
	inner := &s.p
	inner.AggregateUpdate(aggShare, outShare)
}

// Unshard combines the aggregators' aggregate shares into the final sum.
// numMeas is the number of measurements that were aggregated. It delegates
// to the underlying Prio3 instance.
func (s *Sum) Unshard(aggShares []AggShare, numMeas uint) (aggregate *uint64, err error) {
	aggregate, err = s.p.Unshard(aggShares, numMeas)
	return aggregate, err
}

// flpSum is the validity circuit behind Sum. A measurement is encoded as two
// bit vectors of length bits: one for the measurement itself and one for the
// measurement plus offset.
type flpSum struct {
	flp.FLP[flp.GadgetPolyEvalx2x, poly, Vec, Fp, *Fp]

	// offset is 2^bits - 1 - maxMeasurement, stored as a field element.
	offset Fp
	// bits is the bit length of maxMeasurement.
	bits uint
}

// newFlpSum builds the Sum validity circuit for measurements in
// [0, maxMeasurement]. It returns an error if the offset cannot be
// represented as a field element.
func newFlpSum(maxMeasurement uint64) (*flpSum, error) {
	numBits := uint(bits.Len64(maxMeasurement))

	// offset = 2^numBits - 1 - maxMeasurement. The uint64 arithmetic wraps
	// modulo 2^64 (1<<64 == 0), which still yields the correct value when
	// numBits == 64.
	offset := uint64(1)<<numBits - 1 - maxMeasurement

	s := new(flpSum)
	if err := s.offset.SetUint64(offset); err != nil {
		return nil, err
	}
	s.bits = numBits

	// Two bit vectors per measurement; each bit gets one gadget call and one
	// circuit output, plus one extra output for the range check.
	s.Valid.MeasurementLen = 2 * numBits
	s.Valid.JointRandLen = 0
	s.Valid.OutputLen = 1
	s.Valid.EvalOutputLen = 2*numBits + 1
	s.Gadget = flp.GadgetPolyEvalx2x{}
	s.NumGadgetCalls = 2 * numBits
	s.FLP.Eval = s.Eval

	return s, nil
}

// Eval evaluates the Sum validity circuit on meas and writes len(meas)+1
// values into out:
//   - out[i] for each i < len(meas) is the gadget applied to meas[i]. The
//     gadget is presumably x^2 - x, which vanishes exactly on bits — confirm
//     against flp.GadgetPolyEvalx2x.
//   - out[len(meas)] is the range check offset/numShares + a - b, where a and
//     b are the integers encoded by the first and second s.bits entries.
//
// Scaling offset by 1/numShares presumably lets the constant be counted
// exactly once when the evaluations over the numShares shares are summed.
// numCalls and jointRand are unused by this circuit.
func (s *flpSum) Eval(
	out Vec, g flp.Gadget[poly, Vec, Fp, *Fp], numCalls uint,
	meas, jointRand Vec, numShares uint8,
) {
	var arg [1]Fp
	for idx, bit := range meas {
		arg[0] = bit
		g.Eval(&out[idx], arg[:])
	}

	cur := cursor.New(meas)
	value := cur.Next(s.bits).JoinBits()
	shifted := cur.Next(s.bits).JoinBits()

	var shareInv Fp
	shareInv.InvUint64(uint64(numShares))

	var check Fp
	check.Mul(&s.offset, &shareInv)
	check.AddAssign(&value)
	check.SubAssign(&shifted)
	out[len(meas)] = check
}

// sumError is a constant-friendly error type for this package's encoding
// errors.
type sumError string

func (e sumError) Error() string { return string(e) }

// errMeasurementTooLarge is returned by Encode when measurement exceeds the
// maxMeasurement the circuit was built for and the excess cannot be detected
// by bit decomposition alone.
const errMeasurementTooLarge sumError = "sum: measurement exceeds maxMeasurement"

// Encode converts measurement into the circuit's input vector. The result is
// two s.bits-long bit vectors: the bits of measurement, followed by the bits
// of measurement+offset.
//
// It returns an error if measurement is outside [0, maxMeasurement], or if
// the stored offset cannot be read back as a uint64.
func (s *flpSum) Encode(measurement uint64) (Vec, error) {
	offset, err := s.offset.GetUint64()
	if err != nil {
		return nil, err
	}

	out := make(Vec, s.Valid.MeasurementLen)
	outCur := cursor.New(out)
	err = outCur.Next(s.bits).SplitBits(measurement)
	if err != nil {
		return nil, err
	}

	// Compute measurement+offset without silent wrap-around. When
	// maxMeasurement >= 2^63 (s.bits == 64), an out-of-range measurement
	// would otherwise wrap modulo 2^64 into a value that still fits in s.bits
	// bits. SplitBits would then accept it, producing an encoding of an
	// invalid measurement. A carry here means exactly measurement >
	// maxMeasurement.
	shifted, carry := bits.Add64(measurement, offset, 0)
	if carry != 0 {
		return nil, errMeasurementTooLarge
	}

	err = outCur.Next(s.bits).SplitBits(shifted)
	if err != nil {
		return nil, err
	}

	return out, nil
}

// Truncate maps an encoded measurement to the circuit output. It keeps only
// the integer encoded by the first s.bits entries (the measurement itself)
// and drops the offset-shifted copy used for the range check.
func (s *flpSum) Truncate(meas Vec) Vec {
	value := meas[:s.bits].JoinBits()
	return Vec{value}
}

// Decode interprets the first element of the aggregated output as the sum of
// the measurements. It returns flp.ErrOutputLen if output is shorter than
// s.Valid.OutputLen, or the error from GetUint64 if the element cannot be
// converted. numMeas is unused.
func (s *flpSum) Decode(output Vec, numMeas uint) (*uint64, error) {
	if int(s.Valid.OutputLen) > len(output) {
		return nil, flp.ErrOutputLen
	}

	var (
		total uint64
		err   error
	)
	if total, err = output[0].GetUint64(); err != nil {
		return nil, err
	}

	return &total, nil
}