File: tmap.go

package info (click to toggle)
trillian 1.7.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,600 kB
  • sloc: sh: 1,181; javascript: 474; sql: 330; makefile: 39
file content (378 lines) | stat: -rw-r--r-- 13,828 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
// Copyright 2020 Google LLC. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package batchmap is a library to be used within Beam pipelines to construct
// verifiable data structures.
package batchmap

//go:generate go install github.com/apache/beam/sdks/v2/go/cmd/starcgen
//go:generate starcgen --package=batchmap --identifiers=entryToNodeHashFn,partitionByPrefixLenFn,tileHashFn,leafShardFn,tileToNodeHashFn,tileUpdateFn

import (
	"context"
	"crypto"
	"fmt"

	"github.com/apache/beam/sdks/v2/go/pkg/beam"
	"github.com/apache/beam/sdks/v2/go/pkg/beam/register"

	"github.com/google/trillian/merkle/coniks"
	"github.com/google/trillian/merkle/smt"
	"github.com/google/trillian/merkle/smt/node"
)

// Beam counters reported by the DoFns in this file, all in the "batchmap"
// namespace:
//   - cntTilesHashed: incremented by tileHashFn for each group of leaves it
//     converts and then hashes into a tile.
//   - cntTilesCopied: incremented by tileUpdateFn when a tile receives no
//     deltas and the base tile is passed through unchanged.
//   - cntTilesCreated: incremented by tileUpdateFn when deltas land in a tile
//     that has no base tile.
//   - cntTilesUpdated: incremented by tileUpdateFn when deltas are merged into
//     an existing base tile.
var (
	cntTilesHashed  = beam.NewCounter("batchmap", "tiles-hashed")
	cntTilesCopied  = beam.NewCounter("batchmap", "tiles-copied")
	cntTilesCreated = beam.NewCounter("batchmap", "tiles-created")
	cntTilesUpdated = beam.NewCounter("batchmap", "tiles-updated")
)

// init registers this package's DoFns and pipeline functions with Beam's
// register package so that Beam can resolve and invoke them at runtime.
// The generic type arguments of each DoFn registration mirror that DoFn's
// ProcessElement signature. For example, leafShardFn.ProcessElement takes a
// nodeHash and returns ([]byte, nodeHash).
//
// NOTE(review): createStratum and updateStratum only build pipeline graph and
// are not DoFns, so registering them looks unnecessary. Confirm before
// removing. entryToNodeHashFn, partitionByPrefixLenFn and tileToNodeHashFn
// are not registered here. Presumably the starcgen shims from the go:generate
// directive cover them; verify.
func init() {
	register.DoFn1x2[nodeHash, []byte, nodeHash](&leafShardFn{})
	register.DoFn3x2[context.Context, []byte, func(*nodeHash) bool, *Tile, error](&tileHashFn{})
	register.DoFn4x2[context.Context, []byte, func(**Tile) bool, func(*nodeHash) bool, *Tile, error](&tileUpdateFn{})
	register.Function5x1(createStratum)
	register.Function6x1(updateStratum)
	register.Function1x2(tilePathFn)
}

// Create builds a brand new map from entries, a PCollection of *Entry, and
// returns every Merkle tree tile of that map as a PCollection of *Tile.
//
// Requirements on the input:
//   - every key must be 256 bits long, unique within entries, and uniformly
//     distributed;
//   - every value must be 256 bits long.
//
// treeID must stay the same for the whole lifetime of the map. It is mixed
// into the hashing to provide preimage resistance, and it must equal the
// Trillian tree ID if the tiles are later imported into Trillian for serving.
// hash selects the internal hash algorithm (SHA256 or SHA512_256), which the
// internal nodes use through the CONIKS strategy.
//
// prefixStrata is the number of 8-bit prefix strata and must lie in [0, 32).
// Every root-to-leaf path crosses prefixStrata+1 tiles. An error is returned
// if prefixStrata is out of range.
func Create(s beam.Scope, entries beam.PCollection, treeID int64, hash crypto.Hash, prefixStrata int) (beam.PCollection, error) {
	s = s.Scope("batchmap.Create")
	if prefixStrata < 0 || prefixStrata >= 32 {
		return beam.PCollection{}, fmt.Errorf("prefixStrata must be in [0, 32), got %d", prefixStrata)
	}

	// Build the strata bottom-up. The deepest stratum hashes the entries
	// themselves, and each shallower stratum hashes the tile roots of the
	// stratum directly beneath it.
	strata := make([]beam.PCollection, 0, prefixStrata+1)
	leaves := beam.ParDo(s, entryToNodeHashFn, entries)
	for depth := prefixStrata; ; depth-- {
		stratum := createStratum(s, leaves, treeID, hash, depth)
		strata = append(strata, stratum)
		if depth == 0 {
			break
		}
		leaves = beam.ParDo(s, tileToNodeHashFn, stratum)
	}

	// One PCollection holding the tiles of every stratum.
	return beam.Flatten(s, strata...), nil
}

// Update applies delta, a PCollection of *Entry, to an existing base map
// given as a PCollection of *Tile. It returns the resulting map as a
// PCollection of *Tile.
// Deltas may add new keys or overwrite existing ones. Keys cannot be deleted,
// although their value can be overwritten with a sentinel value.
//
// treeID, hash and prefixStrata must equal the values used in the Create call
// that originally produced base. An error is returned if prefixStrata is
// outside [0, 32).
func Update(s beam.Scope, base, delta beam.PCollection, treeID int64, hash crypto.Hash, prefixStrata int) (beam.PCollection, error) {
	s = s.Scope("batchmap.Update")
	if prefixStrata < 0 || prefixStrata >= 32 {
		return beam.PCollection{}, fmt.Errorf("prefixStrata must be in [0, 32), got %d", prefixStrata)
	}

	// Every tile set produced by this library has tiles at each path length
	// in [0, prefixStrata], so partitioning by path length yields exactly one
	// partition per stratum.
	baseStrata := beam.Partition(s, prefixStrata+1, partitionByPrefixLenFn, base)

	// Walk the strata bottom-up. The deepest stratum receives the changed
	// entries, and each shallower one receives the recomputed tile roots of
	// the stratum below.
	strata := make([]beam.PCollection, 0, prefixStrata+1)
	changed := beam.ParDo(s, entryToNodeHashFn, delta)
	for depth := prefixStrata; ; depth-- {
		stratum := updateStratum(s, baseStrata[depth], changed, treeID, hash, depth)
		strata = append(strata, stratum)
		if depth == 0 {
			break
		}
		changed = beam.ParDo(s, tileToNodeHashFn, stratum)
	}

	// One PCollection holding the tiles of every stratum.
	return beam.Flatten(s, strata...), nil
}

// createStratum hashes one stratum of tiles whose roots sit rootDepth bytes
// below the map root. leaves is a PCollection of nodeHash holding the leaves
// of this stratum. The result is a PCollection of *Tile.
func createStratum(s beam.Scope, leaves beam.PCollection, treeID int64, hash crypto.Hash, rootDepth int) beam.PCollection {
	s = s.Scope(fmt.Sprintf("createStratum-%d", rootDepth))
	// Key each leaf by the root path of its tile, then gather and hash each tile.
	keyed := beam.ParDo(s, &leafShardFn{RootDepthBytes: rootDepth}, leaves)
	grouped := beam.GroupByKey(s, keyed)
	hasher := &tileHashFn{TreeID: treeID, Hash: hash}
	return beam.ParDo(s, hasher, grouped)
}

// updateStratum applies deltas to one stratum of tiles whose roots sit
// rootDepth bytes below the map root. base is a PCollection of *Tile holding
// this stratum's existing tiles. deltas is a PCollection of nodeHash holding
// the changed leaves of this stratum. The result is a PCollection of *Tile.
func updateStratum(s beam.Scope, base, deltas beam.PCollection, treeID int64, hash crypto.Hash, rootDepth int) beam.PCollection {
	s = s.Scope(fmt.Sprintf("updateStratum-%d", rootDepth))
	// Both sides are keyed by tile root path so that CoGroupByKey pairs each
	// base tile with the deltas that fall inside it.
	keyedBase := beam.ParDo(s, tilePathFn, base)
	keyedDeltas := beam.ParDo(s, &leafShardFn{RootDepthBytes: rootDepth}, deltas)
	joined := beam.CoGroupByKey(s, keyedBase, keyedDeltas)
	return beam.ParDo(s, &tileUpdateFn{TreeID: treeID, Hash: hash}, joined)
}

// tilePathFn keys a tile by its own path, so it can be co-grouped with the
// deltas that belong to it.
func tilePathFn(t *Tile) ([]byte, *Tile) {
	key := t.Path
	return key, t
}

// nodeHash describes a leaf to be included in a tile.
// This is logically the same as smt.Node; however, it has exported fields, so
// it is serializable by the default Beam coder. Keeping it separate also
// allows changes to be made to smt.Node without affecting this type, which
// improves decoupling.
//
// A nodeHash is produced either from an *Entry (entryToNodeHashFn) or from
// the root of a child-stratum *Tile (tileToNodeHashFn). It is converted back
// to an smt.Node by convertNodes.
type nodeHash struct {
	// Path from root of the map to this node. Equivalent to node.ID, but with
	// the significant benefit that it will be serialized properly without
	// writing a custom coder for nodeHash. Each byte carries 8 bits of the
	// path (see nodeID2Decode), so paths are always byte-aligned.
	Path []byte
	// Hash is the node's hash: either an entry's HashValue or a tile's
	// RootHash.
	Hash []byte
}

// partitionByPrefixLenFn sends a tile to the partition indexed by the length
// of its path in bytes, which is the stratum the tile belongs to.
func partitionByPrefixLenFn(t *Tile) int {
	depth := len(t.Path)
	return depth
}

// tileToNodeHashFn turns a tile into a leaf of the stratum above it. The
// tile's path becomes the leaf path and its root hash becomes the leaf hash.
func tileToNodeHashFn(t *Tile) nodeHash {
	var leaf nodeHash
	leaf.Path = t.Path
	leaf.Hash = t.RootHash
	return leaf
}

// entryToNodeHashFn turns an input entry into a leaf of the deepest stratum.
// The hashed key becomes the leaf path and the hashed value becomes the leaf
// hash.
func entryToNodeHashFn(e *Entry) nodeHash {
	leaf := nodeHash{
		Hash: e.HashValue,
		Path: e.HashKey,
	}
	return leaf
}

// leafShardFn keys each nodeHash by the first RootDepthBytes bytes of its
// path, which is the root path of the tile the leaf belongs to. A subsequent
// GroupByKey therefore gathers all leaves of one tile together.
type leafShardFn struct {
	RootDepthBytes int
}

// ProcessElement emits (tile root path, leaf). The key is a prefix view of
// leaf.Path and shares its backing array. Callers must ensure leaf.Path is at
// least RootDepthBytes long.
func (fn *leafShardFn) ProcessElement(leaf nodeHash) ([]byte, nodeHash) {
	tileRoot := leaf.Path[:fn.RootDepthBytes]
	return tileRoot, leaf
}

// tileHashFn hashes the group of leaves sharing one tile root path into a
// *Tile. TreeID and Hash are exported so Beam can serialize them with the
// DoFn. th is not serializable and is rebuilt in Setup.
type tileHashFn struct {
	TreeID int64
	Hash   crypto.Hash
	th     *tileHasher
}

// Setup creates the tile hasher for this DoFn instance.
func (fn *tileHashFn) Setup() {
	fn.th = &tileHasher{treeID: fn.TreeID, h: coniks.New(fn.Hash)}
}

// ProcessElement decodes the leaves grouped under rootPath and hashes them
// into the tile rooted at rootPath.
func (fn *tileHashFn) ProcessElement(ctx context.Context, rootPath []byte, leaves func(*nodeHash) bool) (*Tile, error) {
	converted, convErr := convertNodes(leaves)
	if convErr != nil {
		return nil, convErr
	}
	cntTilesHashed.Inc(ctx, 1)
	return fn.th.construct(rootPath, converted)
}

// convertNodes drains the Beam-style nodeHash iterator and returns the
// equivalent smt.Node values, in iteration order. The whole group is held in
// memory with no attempt to bound its size, so an oversized group will
// exhaust memory. Library clients must choose enough prefix strata to keep
// each tile small enough. An error is returned if a leaf path cannot be
// decoded into a node ID.
func convertNodes(leaves func(*nodeHash) bool) ([]smt.Node, error) {
	out := make([]smt.Node, 0)
	var cur nodeHash
	for {
		if !leaves(&cur) {
			break
		}
		id, decodeErr := nodeID2Decode(cur.Path)
		if decodeErr != nil {
			return nil, fmt.Errorf("failed to decode leaf ID: %v", decodeErr)
		}
		out = append(out, smt.Node{ID: id, Hash: cur.Hash})
	}
	return out, nil
}

// tileUpdateFn merges the base tile from the original map with the deltas
// that change it. Only additions and overwrites are supported, and a leaf
// cannot be deleted. TreeID and Hash are exported so Beam can serialize them
// with the DoFn. th is rebuilt in Setup.
type tileUpdateFn struct {
	TreeID int64
	Hash   crypto.Hash
	th     *tileHasher
}

// Setup creates the non-serializable tile hasher for this DoFn instance.
func (fn *tileUpdateFn) Setup() {
	fn.th = &tileHasher{treeID: fn.TreeID, h: coniks.New(fn.Hash)}
}

// ProcessElement handles one CoGroupByKey result for rootPath, consisting of
// at most one base tile and any number of delta leaves. There are three cases:
//   - no deltas: the base tile is returned unchanged;
//   - deltas but no base tile: a new tile is hashed from the deltas alone;
//   - otherwise: the deltas are merged into the base tile.
func (fn *tileUpdateFn) ProcessElement(ctx context.Context, rootPath []byte, bases func(**Tile) bool, deltas func(*nodeHash) bool) (*Tile, error) {
	base, err := getOptionalTile(bases)
	if err != nil {
		return nil, fmt.Errorf("failed precondition getOptionalTile at %x: %v", rootPath, err)
	}
	changes, err := convertNodes(deltas)
	if err != nil {
		return nil, err
	}

	switch {
	case len(changes) == 0:
		// Nothing to apply, so pass the base tile through untouched.
		cntTilesCopied.Inc(ctx, 1)
		return base, nil
	case base == nil:
		// No base tile exists at rootPath, so build one from the deltas.
		cntTilesCreated.Inc(ctx, 1)
		return fn.th.construct(rootPath, changes)
	default:
		cntTilesUpdated.Inc(ctx, 1)
		return fn.updateTile(rootPath, base, changes)
	}
}

// updateTile rebuilds base with deltas applied. The leaves stored in base
// carry paths relative to rootPath, so each one is turned back into an
// absolute leaf ID before th.update merges it with deltas. Where a delta and
// a base leaf share a position, the delta wins.
//
// An error is returned if a base leaf ID cannot be decoded, or if th.update
// fails (for example, on duplicate positions among deltas).
func (fn *tileUpdateFn) updateTile(rootPath []byte, base *Tile, deltas []smt.Node) (*Tile, error) {
	baseNodes := make([]smt.Node, 0, len(base.Leaves))
	for _, l := range base.Leaves {
		// Assemble the absolute path in a freshly allocated buffer. A bare
		// append(rootPath, l.Path...) would write into rootPath's backing
		// array whenever it has spare capacity. That clobbers bytes past
		// len(rootPath) that other slices may share, and it rewrites the same
		// memory on every iteration.
		leafPath := make([]byte, 0, len(rootPath)+len(l.Path))
		leafPath = append(leafPath, rootPath...)
		leafPath = append(leafPath, l.Path...)
		lidx, err := nodeID2Decode(leafPath)
		if err != nil {
			return nil, fmt.Errorf("failed to decode leaf ID: %v", err)
		}
		baseNodes = append(baseNodes, smt.Node{ID: lidx, Hash: l.Hash})
	}

	return fn.th.update(rootPath, baseNodes, deltas)
}

// tileHasher is an smt.NodeAccessor used for computing node hashes of a tile.
// This is not serializable and must be constructed within each worker stage.
// NOTE: it is built with unkeyed composite literals elsewhere in this file,
// so the field order below is significant.
type tileHasher struct {
	// treeID is passed to HashEmpty when computing empty-subtree hashes in Get.
	treeID int64
	// h supplies HashChildren for hashTile and HashEmpty for Get.
	h      *coniks.Hasher
}

// construct builds the tile rooted at rootPath whose leaves are nodes.
// All nodes are expected to sit at the depth of nodes[0], which is the depth
// smt.Prepare is given. Each leaf is stored in the tile with its path made
// relative to rootPath, and the tile's root hash is computed over the leaves.
// nodes is modified in place during hashing and must not be used by the
// caller afterwards.
//
// An error is returned if nodes is empty, if smt.Prepare rejects the nodes,
// if a leaf ID is not byte-aligned, if a leaf lies above the tile root, or if
// hashing fails.
func (th *tileHasher) construct(rootPath []byte, nodes []smt.Node) (*Tile, error) {
	if len(nodes) == 0 {
		// nodes[0] is read below (and again in hashTile). Report a clear error
		// instead of panicking with an index out of range.
		return nil, fmt.Errorf("cannot construct tile at %x: no leaves", rootPath)
	}
	rootDepthBytes := len(rootPath)
	if err := smt.Prepare(nodes, nodes[0].ID.BitLen()); err != nil {
		return nil, fmt.Errorf("smt.Prepare: %v", err)
	}

	// N.B. This needs to be done after Prepare but BEFORE HStar3 because it
	// fiddles around with the nodes and makes their IDs invalid afterwards.
	tls := make([]*TileLeaf, len(nodes))
	for i, n := range nodes {
		nPath, err := nodeID2Encode(n.ID)
		if err != nil {
			return nil, fmt.Errorf("failed to encode leaf ID: %v", err)
		}
		if len(nPath) < rootDepthBytes {
			// Slicing below would panic; a leaf must lie beneath its tile root.
			return nil, fmt.Errorf("leaf path %x is shallower than tile root %x", nPath, rootPath)
		}
		tls[i] = &TileLeaf{
			Path: nPath[rootDepthBytes:],
			Hash: n.Hash,
		}
	}

	rootHash, err := th.hashTile(uint(8*rootDepthBytes), nodes)
	if err != nil {
		return nil, fmt.Errorf("failed to hash tile: %v", err)
	}

	return &Tile{
		Path:     rootPath,
		Leaves:   tls,
		RootHash: rootHash,
	}, nil
}

// update merges deltaNodes over baseNodes and builds the resulting tile rooted
// at rootPath. Deltas are inserted first so that two deltas at the same
// position are reported as an error. Base leaves then fill only the positions
// no delta touched.
func (th *tileHasher) update(rootPath []byte, baseNodes, deltaNodes []smt.Node) (*Tile, error) {
	merged := make(map[node.ID]smt.Node, len(deltaNodes)+len(baseNodes))
	for _, d := range deltaNodes {
		if prev, dup := merged[d.ID]; dup {
			return nil, fmt.Errorf("found duplicate values at leaf tile position %s: {%x, %x}", d.ID, prev.Hash, d.Hash)
		}
		merged[d.ID] = d
	}

	for _, b := range baseNodes {
		if _, overridden := merged[b.ID]; overridden {
			continue
		}
		merged[b.ID] = b
	}

	// Map order is irrelevant here: construct sorts the leaves via smt.Prepare.
	all := make([]smt.Node, 0, len(merged))
	for _, n := range merged {
		all = append(all, n)
	}
	return th.construct(rootPath, all)
}

// hashTile returns the root hash of the tile whose root is depthBits bits
// below the map root. leaves must already have been prepared with
// smt.Prepare. Hashing rewrites leaves, so the slice MUST NOT be used after
// calling this method. An error is returned if HStar3 fails or does not yield
// exactly one root.
func (th *tileHasher) hashTile(depthBits uint, leaves []smt.Node) ([]byte, error) {
	leafDepth := uint(leaves[0].ID.BitLen())
	hs, err := smt.NewHStar3(leaves, th.h.HashChildren, leafDepth, depthBits)
	if err != nil {
		return nil, err
	}
	roots, err := hs.Update(th)
	if err != nil {
		return nil, err
	}
	if n := len(roots); n != 1 {
		return nil, fmt.Errorf("expected single root but got %d", n)
	}
	return roots[0].Hash, nil
}

// Get is the read side of the node accessor that hashTile hands to HStar3.
// Every requested node is treated as the root of an empty subtree, so Get
// returns the empty-subtree hash for id under th.treeID. It never fails.
func (th tileHasher) Get(id node.ID) ([]byte, error) {
	emptyHash := th.h.HashEmpty(th.treeID, id)
	return emptyHash, nil
}

// Set discards the hash. Computed node hashes are not retained, because only
// the tile's root hash is returned by hashTile.
func (th tileHasher) Set(id node.ID, hash []byte) {}

// nodeID2Encode converts a byte-aligned node ID into its byte-path form.
// When LastByte reports 0 bits (presumably the empty/root ID), it returns an
// empty slice. When the last byte is complete, it returns the ID's full bytes
// followed by that last byte. Any other bit length is rejected with an error.
func nodeID2Encode(n node.ID) ([]byte, error) {
	last, bits := n.LastByte()
	switch bits {
	case 0:
		return []byte{}, nil
	case 8:
		return append([]byte(n.FullBytes()), last), nil
	default:
		return nil, fmt.Errorf("node ID bit length is not aligned to bytes: %d", n.BitLen())
	}
}

// nodeID2Decode builds a node ID from a byte path, treating every byte of bs
// as 8 bits of path. It currently always returns a nil error.
func nodeID2Decode(bs []byte) (node.ID, error) {
	bitLen := uint(len(bs)) * 8
	id := node.NewID(string(bs), bitLen)
	return id, nil
}

// getOptionalTile reads at most two tiles from the Beam-style iterator and
// returns:
//   - nil if the iterator is empty;
//   - the only tile if there is exactly one;
//   - an error if there is more than one.
func getOptionalTile(iter func(**Tile) bool) (*Tile, error) {
	var first *Tile
	if !iter(&first) {
		return first, nil // nothing was yielded
	}
	var second *Tile
	if !iter(&second) {
		return first, nil
	}
	return nil, fmt.Errorf("unexpectedly found multiple tiles at %x", first.Path)
}