File: count.go

package info (click to toggle)
golang-github-cloudflare-circl 1.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 18,060 kB
  • sloc: asm: 20,492; ansic: 1,292; makefile: 68
file content (143 lines) | stat: -rw-r--r-- 3,413 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
// Package count is a VDAF for counting Boolean measurements.
package count

import (
	"crypto/subtle"

	"github.com/cloudflare/circl/vdaf/prio3/arith"
	"github.com/cloudflare/circl/vdaf/prio3/arith/fp64"
	"github.com/cloudflare/circl/vdaf/prio3/internal/flp"
	"github.com/cloudflare/circl/vdaf/prio3/internal/prio3"
)

// Field types. Count works over the field implemented by package fp64;
// poly, Vec and Fp are shorthands for its polynomial, vector and element
// types.
type (
	poly = fp64.Poly
	Vec  = fp64.Vec
	Fp   = fp64.Fp
)

// Prio3 message and state types that are generic over the field, instantiated
// here with the fp64 vector and element types.
type (
	AggShare   = prio3.AggShare[Vec, Fp]
	InputShare = prio3.InputShare[Vec, Fp]
	OutShare   = prio3.OutShare[Vec, Fp]
	PrepShare  = prio3.PrepShare[Vec, Fp]
	PrepState  = prio3.PrepState[Vec, Fp]
)

// Prio3 types that do not depend on the field, re-exported so callers of this
// package need not import prio3 directly.
type (
	Nonce       = prio3.Nonce
	PrepMessage = prio3.PrepMessage
	PublicShare = prio3.PublicShare
	VerifyKey   = prio3.VerifyKey
)

// Count is a verifiable distributed aggregation function in which each
// measurement is either one or zero and the aggregate result is the sum of
// the measurements.
//
// Count is a thin wrapper: each of its methods forwards to the generic Prio3
// instance held in p. That instance is specialized for bool measurements
// (see Shard), uint64 aggregate results (see Unshard), the flpCount proof
// system, and the fp64 field types.
type Count struct {
	// p is the generic Prio3 instance that does all the work; it is set by New.
	p prio3.Prio3[bool, uint64, *flpCount, Vec, Fp, *Fp]
}

// New constructs a Count VDAF. It builds the flpCount proof system and passes
// it, together with the algorithm identifier 1, numShares and context, to
// prio3.New. The identifier is presumably the Prio3Count ID from the VDAF
// specification; confirm against the draft.
//
// If prio3.New fails, New returns a nil *Count and that error unchanged.
func New(numShares uint8, context []byte) (c *Count, err error) {
	const countID = 1

	engine, err := prio3.New(newFlpCount(), countID, numShares, context)
	if err != nil {
		return nil, err
	}
	return &Count{p: engine}, nil
}

// Params reports the parameters of the underlying Prio3 instance.
func (c *Count) Params() prio3.Params {
	return c.p.Params()
}

// Shard splits a Boolean measurement using the given nonce and the
// caller-supplied randomness rand. It returns the public share and the input
// shares produced by the underlying Prio3 instance, or its error.
func (c *Count) Shard(
	measurement bool, nonce *Nonce, rand []byte,
) (PublicShare, []InputShare, error) {
	return c.p.Shard(measurement, nonce, rand)
}

// PrepInit begins preparation of one report for the aggregator identified by
// aggID. It takes the verification key, the report nonce, and the report's
// public share and this aggregator's input share. It returns the
// aggregator's preparation state and preparation share as computed by the
// underlying Prio3 instance.
func (c *Count) PrepInit(verifyKey *VerifyKey, nonce *Nonce, aggID uint8,
	publicShare PublicShare, inputShare InputShare,
) (*PrepState, *PrepShare, error) {
	return c.p.PrepInit(verifyKey, nonce, aggID, publicShare, inputShare)
}

// PrepSharesToPrep combines the given preparation shares into a single
// preparation message, delegating to the underlying Prio3 instance.
func (c *Count) PrepSharesToPrep(
	prepShares []PrepShare,
) (*PrepMessage, error) {
	return c.p.PrepSharesToPrep(prepShares)
}

// PrepNext completes preparation. From an aggregator's preparation state and
// the preparation message it derives that aggregator's output share, or
// returns the error reported by the underlying Prio3 instance.
func (c *Count) PrepNext(
	state *PrepState, msg *PrepMessage,
) (*OutShare, error) {
	return c.p.PrepNext(state, msg)
}

// AggregateInit returns a fresh aggregate share from the underlying Prio3
// instance, presumably the starting point for AggregateUpdate.
func (c *Count) AggregateInit() AggShare {
	return c.p.AggregateInit()
}

// AggregateUpdate folds outShare into *aggShare. The update happens in place
// through the pointer; nothing is returned.
func (c *Count) AggregateUpdate(
	aggShare *AggShare, outShare *OutShare,
) {
	c.p.AggregateUpdate(aggShare, outShare)
}

// Unshard combines the aggregate shares into the aggregate result, the count
// of true measurements as a *uint64. numMeas is forwarded unchanged to the
// underlying Prio3 instance.
func (c *Count) Unshard(
	aggShares []AggShare, numMeas uint,
) (aggregate *uint64, err error) {
	return c.p.Unshard(aggShares, numMeas)
}

// flpCount is the FLP (fully linear proof) system used by Count. It embeds
// the generic flp.FLP, instantiated with the GadgetMulFp64 gadget and the
// fp64 field types. It supplies the Count-specific Eval, Encode, Truncate and
// Decode methods defined in this file.
type flpCount struct {
	flp.FLP[flp.GadgetMulFp64, poly, Vec, Fp, *Fp]
}

// newFlpCount builds the proof system for Count. The validity circuit takes a
// one-element measurement and needs no joint randomness. It produces a
// one-element output and a one-element evaluation, and calls the
// GadgetMulFp64 gadget exactly once.
func newFlpCount() *flpCount {
	fc := new(flpCount)

	// Shape of the validity circuit.
	fc.Valid.MeasurementLen = 1
	fc.Valid.OutputLen = 1
	fc.Valid.EvalOutputLen = 1
	fc.Valid.JointRandLen = 0 // this circuit consumes no joint randomness

	// Eval invokes the gadget once, on (meas[0], meas[0]).
	fc.Gadget = flp.GadgetMulFp64{}
	fc.NumGadgetCalls = 1

	// flpCount declares a method named Eval, which hides the promoted
	// function field of the same name. The field therefore has to be
	// selected through the embedded FLP.
	fc.FLP.Eval = fc.Eval
	return fc
}

// Eval evaluates the Count validity circuit and stores the result in out[0]:
//
//	out[0] = g(meas[0], meas[0]) - meas[0]
//
// newFlpCount installs GadgetMulFp64. Assuming that gadget multiplies its two
// inputs, this is m*m - m = m(m-1), which is zero in a field exactly when m
// is 0 or 1. A non-zero out[0] therefore marks a non-Boolean measurement.
// numCalls, jointRand and numShares are not used by this circuit.
func (c *flpCount) Eval(
	out Vec, g flp.Gadget[poly, Vec, Fp, *Fp], numCalls uint,
	meas, jointRand Vec, numShares uint8,
) {
	result := &out[0]
	g.Eval(result, Vec{meas[0], meas[0]})
	// Subtract through &meas[0] (not a copy) so the result is the same even
	// if out and meas share storage.
	result.SubAssign(&meas[0])
}

// Encode maps a Boolean measurement to a one-element field vector. true
// becomes the field's one. false becomes the element decoded from an all-zero
// byte string, presumably the field's zero.
//
// The byte encoding of one is selected with subtle.ConstantTimeCopy instead
// of a branch, so choosing between the two encodings does not depend on the
// measurement through control flow. Errors from marshalling or unmarshalling
// the field element are returned unchanged.
func (c *flpCount) Encode(measurement bool) (Vec, error) {
	var unit Fp
	unit.SetOne()

	unitBytes, err := unit.MarshalBinary()
	if err != nil {
		return nil, err
	}

	selector := 0
	if measurement {
		selector = 1
	}

	// buf starts zeroed and receives unitBytes only when selector == 1.
	var buf [fp64.Size]byte
	subtle.ConstantTimeCopy(selector, buf[:], unitBytes)

	encoded := arith.NewVec[Vec](1)
	if err := encoded[0].UnmarshalBinary(buf[:]); err != nil {
		return nil, err
	}
	return encoded, nil
}

// Truncate is the identity map for Count. The encoded measurement is already
// the output, so meas is returned as-is (the same slice, not a copy).
func (c *flpCount) Truncate(meas Vec) Vec {
	return meas
}

// Decode turns the aggregated output vector into the final count.
//
// It returns flp.ErrOutputLen when output holds fewer than Valid.OutputLen
// elements. Otherwise it returns output[0] converted with GetUint64, or any
// error from that conversion. numMeas is not needed for Count and is ignored.
func (c *flpCount) Decode(output Vec, numMeas uint) (*uint64, error) {
	if int(c.Valid.OutputLen) > len(output) {
		return nil, flp.ErrOutputLen
	}

	total, err := output[0].GetUint64()
	if err != nil {
		return nil, err
	}
	return &total, nil
}