File: expiring_hash.go

package info (click to toggle)
gitlab-agent 16.1.3-2
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 6,324 kB
  • sloc: makefile: 175; sh: 52; ruby: 3
file content (329 lines) | stat: -rw-r--r-- 10,275 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
package redistool

import (
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/redis/rueidis"
	"golang.org/x/sync/errgroup"
	"google.golang.org/protobuf/proto"
)

// KeyToRedisKey converts a typed key into the string that represents it in Redis.
// ExpiringHash is configured with two such functions: one for the key of the hash itself (key1) and
// one for the keys inside the hash (key2), i.e. the respective arguments of: HSET key1 key2 value.
type KeyToRedisKey[K any] func(key K) string
// ScanCallback is invoked by ExpiringHash.Scan for entries found in the scanned hash.
// rawHashKey is the hash key exactly as it is stored in Redis.
// For an entry that was decoded successfully, value is its payload and err is nil.
// For an entry that could not be unmarshaled, value is nil and err describes the failure.
// Returning done == true or a non-nil error stops the scan; the returned error is passed back to the caller of Scan.
type ScanCallback func(rawHashKey string, value []byte, err error) (bool /* done */, error)

// IOFunc is a function that should be called to perform the I/O of the requested operation.
// ExpiringHash returns it from Set, Unset and Refresh after the in-memory state has been updated, so the
// caller decides when (and on which goroutine) the backing store is contacted.
// It is safe to call concurrently as it does not interfere with the hash's operation.
// The returned error is the error of the backing store operation, if any.
type IOFunc func(ctx context.Context) error

// ExpiringHashInterface represents a two-level hash: key K1 -> hashKey K2 -> value []byte.
// key identifies the hash; hashKey identifies the key in the hash; value is the value for the hashKey.
// It is not safe for concurrent use directly, but it allows to perform I/O with backing store concurrently by
// returning functions for doing that.
type ExpiringHashInterface[K1 any, K2 any] interface {
	// Set records value under key/hashKey in memory and returns a function that writes it to the backing store.
	Set(key K1, hashKey K2, value []byte) IOFunc
	// Unset removes the item from the in-memory map and returns a function that removes it from the backing store.
	Unset(key K1, hashKey K2) IOFunc
	// Forget only removes the item from the in-memory map.
	Forget(key K1, hashKey K2)
	// Scan iterates the hash identified by key in the backing store and calls cb for its entries.
	// Expired entries are not passed to cb; they are deleted from the backing store instead.
	// It returns the number of entries deleted that way.
	Scan(ctx context.Context, key K1, cb ScanCallback) (int /* keysDeleted */, error)
	// Len returns the number of entries the backing store holds in the hash identified by key.
	// NOTE(review): the count comes straight from the store, so it presumably includes expired entries
	// that have not been garbage collected yet.
	Len(ctx context.Context, key K1) (int64, error)
	// GC returns a function that iterates all relevant stored data and deletes expired entries.
	// The returned function can be called concurrently as it does not interfere with the hash's operation.
	// The function returns number of deleted Redis (hash) keys, including when an error occurs.
	// It only inspects/GCs hashes where it has entries. Other concurrent clients GC same and/or other corresponding hashes.
	// Hashes that don't have a corresponding client (e.g. because it crashed) will expire because of TTL on the hash key.
	GC() func(context.Context) (int /* keysDeleted */, error)
	// Clear clears all data in this hash and deletes it from the backing store.
	// It returns the number of hash keys it asked the backing store to delete.
	Clear(context.Context) (int, error)
	// Refresh returns a function that re-writes, with a new expiration time, the entries that would
	// expire at or before nextRefresh. Entries that expire later are left for a subsequent refresh.
	Refresh(nextRefresh time.Time) IOFunc
}

// ExpiringHash is a Redis-backed two-level hash whose entries carry an expiration time.
// It keeps a copy of the entries it has set in memory (data) so that it can periodically re-write them
// with a new expiration time (Refresh), garbage collect the hashes it touches (GC) and remove them (Clear).
//
// client is used for all Redis I/O. key1ToRedisKey and key2ToRedisKey convert the typed hash key and the typed
// key-in-the-hash into their Redis string form. ttl is used both to compute the expiration time of each
// entry and as the expiration of the Redis hash key itself.
//
// The methods access data without any locking; only the functions they return are meant to run concurrently.
type ExpiringHash[K1 comparable, K2 comparable] struct {
	client         rueidis.Client
	key1ToRedisKey KeyToRedisKey[K1]
	key2ToRedisKey KeyToRedisKey[K2]
	ttl            time.Duration
	data           map[K1]map[K2]*ExpiringValue // key -> hash key -> value
}

// NewExpiringHash constructs an empty ExpiringHash.
// client performs the Redis I/O; key1ToRedisKey and key2ToRedisKey turn the typed hash key and the typed
// key-in-the-hash into Redis strings; ttl is the lifetime given to stored entries and to the Redis hash key.
func NewExpiringHash[K1 comparable, K2 comparable](client rueidis.Client, key1ToRedisKey KeyToRedisKey[K1],
	key2ToRedisKey KeyToRedisKey[K2], ttl time.Duration) *ExpiringHash[K1, K2] {
	h := new(ExpiringHash[K1, K2])
	h.client = client
	h.key1ToRedisKey = key1ToRedisKey
	h.key2ToRedisKey = key2ToRedisKey
	h.ttl = ttl
	// The in-memory index must be non-nil: setData writes into it.
	h.data = map[K1]map[K2]*ExpiringValue{}
	return h
}

// Set records value under key/hashKey in the in-memory map, with an expiration time of now+ttl,
// and returns a function that writes the same entry to Redis.
//
// The returned function works on a snapshot of the entry taken here, not on the instance stored in the
// in-memory map: that instance is mutated by Refresh (it bumps ExpiresAt), so reading it lazily from the
// returned function would race with a concurrent Refresh call.
func (h *ExpiringHash[K1, K2]) Set(key K1, hashKey K2, value []byte) IOFunc {
	expiresAt := time.Now().Add(h.ttl).Unix()
	h.setData(key, hashKey, &ExpiringValue{
		ExpiresAt: expiresAt,
		Value:     value,
	})
	return func(ctx context.Context) error {
		// Build a fresh message from the captured immutable values; proto messages must not be copied.
		return h.refreshKey(ctx, key, []refreshKey[K2]{
			{
				hashKey: hashKey,
				value:   ExpiringValue{ExpiresAt: expiresAt, Value: value},
			},
		})
	}
}

// Unset drops key/hashKey from the in-memory map right away and returns a function that
// issues the matching HDEL against Redis.
func (h *ExpiringHash[K1, K2]) Unset(key K1, hashKey K2) IOFunc {
	h.unsetData(key, hashKey)
	return func(ctx context.Context) error {
		forHash := h.client.B().Hdel().Key(h.key1ToRedisKey(key))
		cmd := forHash.Field(h.key2ToRedisKey(hashKey)).Build()
		resp := h.client.Do(ctx, cmd)
		return resp.Error()
	}
}

// Forget removes key/hashKey from the in-memory map only; nothing is sent to Redis.
// As a consequence the entry is no longer re-written by Refresh.
func (h *ExpiringHash[K1, K2]) Forget(key K1, hashKey K2) {
	h.unsetData(key, hashKey)
}

// Len asks Redis (HLEN) how many entries the hash identified by key currently holds.
// The in-memory map is not consulted.
func (h *ExpiringHash[K1, K2]) Len(ctx context.Context, key K1) (size int64, retErr error) {
	cmd := h.client.B().Hlen().Key(h.key1ToRedisKey(key)).Build()
	size, retErr = h.client.Do(ctx, cmd).AsInt64()
	return size, retErr
}

// scan walks the Redis hash identified by key with HSCAN and hands every field/value pair to cb.
// cb reports whether scanning should stop (done), whether the field should be deleted, and an error;
// a non-nil error stops the scan and is returned.
//
// Fields marked for deletion are collected and removed with a single HDEL when scan returns, no matter
// how it returns. keysDeleted is the number of collected fields if that HDEL succeeded and 0 otherwise.
// An HDEL failure is returned only if there is no earlier error to report.
func (h *ExpiringHash[K1, K2]) scan(ctx context.Context, key K1, cb func(k, v string) (bool /*done*/, bool /*delete*/, error)) (keysDeleted int, retErr error) {
	redisKey := h.key1ToRedisKey(key)
	var doomed []string
	defer func() {
		if len(doomed) == 0 {
			return
		}
		hdelCmd := h.client.B().Hdel().Key(redisKey).Field(doomed...).Build()
		if delErr := h.client.Do(ctx, hdelCmd).Error(); delErr != nil {
			// Keep the original failure, if any; it is the more relevant one.
			if retErr == nil {
				retErr = delErr
			}
			return
		}
		// Overrides the 0 set by the return statements below.
		keysDeleted = len(doomed)
	}()
	// Scan keys of a hash. See https://redis.io/commands/scan
	var se rueidis.ScanEntry // zero value carries cursor 0, which starts a new iteration
	for {
		var err error
		hscanCmd := h.client.B().Hscan().Key(redisKey).Cursor(se.Cursor).Build()
		se, err = h.client.Do(ctx, hscanCmd).AsScanEntry()
		if err != nil {
			return 0, err
		}
		pairs := se.Elements
		if len(pairs)%2 != 0 {
			// This shouldn't happen
			return 0, errors.New("invalid Redis reply")
		}
		// The reply is a flat list: field, value, field, value, ...
		for i := 0; i < len(pairs); i += 2 {
			field := pairs[i]
			done, del, cbErr := cb(field, pairs[i+1])
			if del {
				doomed = append(doomed, field)
			}
			if cbErr != nil || done {
				return 0, cbErr
			}
		}
		// Cursor 0 in the reply means the iteration is complete.
		if se.Cursor == 0 {
			return 0, nil
		}
	}
}

// Scan iterates the Redis hash identified by key and calls cb for its entries.
// Entries whose expiration time is in the past are not passed to cb; they are deleted from Redis and
// counted in keysDeleted. Entries that cannot be unmarshaled are reported to cb via its err argument
// and are left in place.
func (h *ExpiringHash[K1, K2]) Scan(ctx context.Context, key K1, cb ScanCallback) (keysDeleted int, retErr error) {
	cutoff := time.Now().Unix()
	// One message is reused for all entries; Unmarshal resets it before decoding.
	var entry ExpiringValue
	visit := func(rawHashKey, rawValue string) (bool /*done*/, bool /*delete*/, error) {
		if err := proto.Unmarshal([]byte(rawValue), &entry); err != nil {
			wrapped := fmt.Errorf("failed to unmarshal hash value from hashkey 0x%x: %w", rawHashKey, err)
			done, cbErr := cb(rawHashKey, nil, wrapped)
			return done, false, cbErr
		}
		if entry.ExpiresAt >= cutoff {
			// Still valid, hand it to the caller.
			done, cbErr := cb(rawHashKey, entry.Value, nil)
			return done, false, cbErr
		}
		// Expired: skip the callback and have scan delete it.
		return false, true, nil
	}
	return h.scan(ctx, key, visit)
}

// GC returns a function that deletes expired entries from every Redis hash this ExpiringHash
// currently has entries in.
//
// The set of hashes is captured when GC is called, so the returned function does not touch the in-memory
// map and can run concurrently with other operations on the hash.
// The returned function stops at the first error. It always returns the number of entries deleted so far,
// including those deleted while processing the hash that produced the error.
func (h *ExpiringHash[K1, K2]) GC() func(context.Context) (int, error) {
	// Copy keys for safe concurrent access.
	keys := make([]K1, 0, len(h.data))
	for key := range h.data {
		keys = append(keys, key)
	}
	return func(ctx context.Context) (int, error) {
		var deletedKeys int
		for _, key := range keys {
			deleted, err := h.gcHash(ctx, key)
			// Count before checking the error: gcHash can delete entries and still report an error
			// (e.g. when some other entry of the same hash could not be unmarshaled).
			deletedKeys += deleted
			if err != nil {
				return deletedKeys, err
			}
		}
		return deletedKeys, nil
	}
}

// gcHash walks the Redis hash identified by key and deletes every entry whose expiration time has passed.
// Values are expected to be marshaled ExpiringValue messages; only the timestamp is decoded.
// Entries that fail to decode are left in place and do not stop the walk; the first such failure is
// returned, unless the scan itself failed, in which case that error wins.
// The returned count is the number of entries deleted.
func (h *ExpiringHash[K1, K2]) gcHash(ctx context.Context, key K1) (int, error) {
	var (
		stamp        ExpiringValueTimestamp
		unmarshalErr error
	)
	cutoff := time.Now().Unix()
	deleted, err := h.scan(ctx, key, func(_, rawValue string) (bool /*done*/, bool /*delete*/, error) {
		decodeErr := proto.UnmarshalOptions{
			DiscardUnknown: true, // the payload field is present in the data, but it is not needed here
		}.Unmarshal([]byte(rawValue), &stamp)
		if decodeErr == nil {
			return false, stamp.ExpiresAt < cutoff, nil
		}
		if unmarshalErr == nil {
			unmarshalErr = decodeErr
		}
		return false, false, nil
	})
	if err == nil {
		err = unmarshalErr
	}
	return deleted, err
}

// Clear empties the in-memory map and deletes all the entries it held from Redis, using one HDEL per
// hash, all sent together. It returns the number of entries it asked Redis to delete (also when the
// request failed) and the first error among the replies.
func (h *ExpiringHash[K1, K2]) Clear(ctx context.Context) (int, error) {
	total := 0
	hdels := make([]rueidis.Completed, 0, len(h.data))
	var fields []string // scratch space shared by all iterations
	for k1, inner := range h.data {
		fields = fields[:0]
		for k2 := range inner {
			fields = append(fields, h.key2ToRedisKey(k2))
		}
		total += len(fields)
		hdel := h.client.B().Hdel().Key(h.key1ToRedisKey(k1)).Field(fields...).Build()
		hdels = append(hdels, hdel)
		// Removing the entry being visited is allowed while ranging over a map.
		delete(h.data, k1)
	}
	return total, MultiFirstError(h.client.DoMulti(ctx, hdels...))
}

// Refresh extends the expiration time of the in-memory entries that would expire at or before nextRefresh
// and returns a function that writes those entries to Redis, one goroutine per hash.
// All the data the returned function needs is prepared here, so it does not touch the in-memory map.
// The returned function waits for all writes and returns the first error, if any.
func (h *ExpiringHash[K1, K2]) Refresh(nextRefresh time.Time) IOFunc {
	pending := make(map[K1][]refreshKey[K2], len(h.data))
	for key, hashData := range h.data {
		// Hashes with nothing due for refresh are left out entirely.
		if args := h.prepareRefreshKey(hashData, nextRefresh); len(args) != 0 {
			pending[key] = args
		}
	}
	return func(ctx context.Context) error {
		var g errgroup.Group
		for key, args := range pending {
			key, args := key, args // per-iteration copies for the goroutine below
			g.Go(func() error {
				return h.refreshKey(ctx, key, args)
			})
		}
		return g.Wait()
	}
}

// prepareRefreshKey selects the entries of hashData that expire at or before nextRefresh, moves their
// in-memory expiration time to now+ttl and returns independent copies of them, ready to be written to Redis.
// Entries that expire after nextRefresh are skipped: a later refresh will pick them up.
func (h *ExpiringHash[K1, K2]) prepareRefreshKey(hashData map[K2]*ExpiringValue, nextRefresh time.Time) []refreshKey[K2] {
	due := make([]refreshKey[K2], 0, len(hashData))
	newExpiresAt := time.Now().Add(h.ttl).Unix()
	deadline := nextRefresh.Unix()
	for hashKey, stored := range hashData {
		if stored.ExpiresAt <= deadline {
			stored.ExpiresAt = newExpiresAt
			// The copy is detached from the mutable instance in hashData, which makes it safe to
			// use from another goroutine.
			due = append(due, refreshKey[K2]{
				hashKey: hashKey,
				value:   ExpiringValue{ExpiresAt: newExpiresAt, Value: stored.Value},
			})
		}
	}
	return due
}

// refreshKey writes args into the Redis hash identified by key and resets the TTL of the hash key.
// HSET and PEXPIRE are sent inside a MULTI/EXEC transaction so that they are applied together.
//
// Values that fail to marshal are skipped; the first such failure is returned if the Redis I/O itself
// succeeded, and also when nothing could be marshaled at all. With empty args it does nothing and returns nil.
func (h *ExpiringHash[K1, K2]) refreshKey(ctx context.Context, key K1, args []refreshKey[K2]) error {
	var marshalErr error
	redisKey := h.key1ToRedisKey(key)
	hsetCmd := h.client.B().Hset().Key(redisKey).FieldValue()
	empty := true
	// Iterate indexes to avoid copying the value which has inlined proto message, which shouldn't be copied.
	for i := range args {
		redisValue, err := proto.Marshal(&args[i].value)
		if err != nil {
			// This should never happen
			if marshalErr == nil {
				marshalErr = fmt.Errorf("failed to marshal ExpiringValue: %w", err)
			}
			continue // skip this value
		}
		// Keep the builder returned by FieldValue rather than relying on it sharing state with the receiver.
		hsetCmd = hsetCmd.FieldValue(h.key2ToRedisKey(args[i].hashKey), rueidis.BinaryString(redisValue))
		empty = false
	}
	if empty {
		// Nothing to send. marshalErr is nil if args was empty and non-nil if every value was skipped,
		// in which case the failure must not be swallowed.
		return marshalErr
	}
	resp := h.client.DoMulti(ctx,
		h.client.B().Multi().Build(),
		hsetCmd.Build(),
		h.client.B().Pexpire().Key(redisKey).Milliseconds(h.ttl.Milliseconds()).Build(),
		h.client.B().Exec().Build(),
	)
	err := MultiFirstError(resp)
	if err != nil {
		return err
	}
	return marshalErr
}

// setData stores value under key/hashKey in the in-memory map, creating the inner map on first use.
func (h *ExpiringHash[K1, K2]) setData(key K1, hashKey K2, value *ExpiringValue) {
	inner := h.data[key]
	if inner != nil {
		inner[hashKey] = value
		return
	}
	// First entry for this key.
	h.data[key] = map[K2]*ExpiringValue{hashKey: value}
}

// unsetData removes key/hashKey from the in-memory map and drops the inner map once it is empty,
// so that the outer map only lists hashes that still have entries. Unknown keys are ignored.
func (h *ExpiringHash[K1, K2]) unsetData(key K1, hashKey K2) {
	inner, known := h.data[key]
	if !known {
		return
	}
	delete(inner, hashKey)
	if len(inner) > 0 {
		return
	}
	delete(h.data, key)
}

// refreshKey is a single entry to be written to a Redis hash: the typed key in the hash and the
// value to store under it. value is held inline (not by pointer), so it is independent of the
// instance kept in ExpiringHash's in-memory map; as a proto message it must not be copied.
type refreshKey[K2 any] struct {
	hashKey K2
	value   ExpiringValue
}

// PrefixedInt64Key returns prefix followed by the 8 bytes of key in little-endian order.
// The result is a binary string, not a printable one.
func PrefixedInt64Key(prefix string, key int64) string {
	var id [8]byte
	binary.LittleEndian.PutUint64(id[:], uint64(key))

	var b strings.Builder
	b.Grow(len(prefix) + len(id))
	b.WriteString(prefix)
	b.Write(id[:])
	return b.String()
}