File: expiring_hash_api.go

package info (click to toggle)
gitlab-agent 16.11.5-1
  • links: PTS, VCS
  • area: contrib
  • in suites: experimental
  • size: 7,072 kB
  • sloc: makefile: 193; sh: 55; ruby: 3
file content (240 lines) | stat: -rw-r--r-- 8,207 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
package redistool

import (
	"context"
	"errors"
	"fmt"
	"time"
	"unsafe"

	"github.com/redis/rueidis"
	"go.opentelemetry.io/otel/attribute"
	otelmetric "go.opentelemetry.io/otel/metric"
	"google.golang.org/protobuf/proto"
)

// Tuning knobs for hash iteration and garbage collection.
// scanCount is the COUNT hint sent with every HSCAN request.
// maxKeyGCAttempts is the number of attempts allowed for the transactional GC of a single hash.
const (
	scanCount        = 1000
	maxKeyGCAttempts = 2
)

// Metric names and the attribute key that labels those metrics with the hash's name.
const (
	gcDeletedKeysMetricName               = "redis_expiring_hash_api_gc_deleted_keys_count"
	gcConflictMetricName                  = "redis_expiring_hash_api_gc_conflict"
	expiringHashNameKey     attribute.Key = "expiring_hash_name"
)

// ExpiringHashAPI is a low-level API over a two-level hash: key K1 -> hashKey K2 -> value []byte.
// A key selects one hash, a hashKey selects one field inside that hash, and the value is what is stored
// under that field.
type ExpiringHashAPI[K1 any, K2 any] interface {
	// SetBuilder returns a builder used to write entries.
	SetBuilder() SetBuilder[K1, K2]

	// Unset removes hashKey from the hash identified by key.
	Unset(ctx context.Context, key K1, hashKey K2) error

	// Scan invokes cb for the entries of the hash identified by key.
	Scan(ctx context.Context, key K1, cb ScanCallback) error

	// GCFor returns a function that garbage collects expired entries of the hashes identified by keys
	// and reports how many entries it deleted.
	GCFor(keys []K1, transactional bool) func(context.Context) (int /* keysDeleted */, error)
}

// ScanCallback receives one hash entry at a time from Scan.
// rawHashKey is the field name as stored in Redis. For RedisExpiringHashAPI.Scan, either value holds the
// decoded payload and err is nil, or value is nil and err explains why the stored value could not be decoded.
// Returning done=true stops the scan early; returning a non-nil error stops it and makes Scan return that error.
type ScanCallback func(rawHashKey string, value []byte, err error) (done bool, cbErr error)

// RedisExpiringHashAPI stores two-level hashes in Redis hashes. Values are expected to be marshaled
// ExpiringValue messages. Create instances with NewRedisExpiringHashAPI.
type RedisExpiringHashAPI[K1 any, K2 any] struct {
	// client issues every Redis command.
	client rueidis.Client

	// key1ToRedisKey maps a K1 to the Redis key of its hash.
	key1ToRedisKey KeyToRedisKey[K1]

	// key2ToRedisKey maps a K2 to a field name inside a hash.
	key2ToRedisKey KeyToRedisKey[K2]

	// gcCounter accumulates the number of hash fields removed by GC.
	gcCounter otelmetric.Int64Counter

	// gcConflict is passed to transaction; presumably incremented when a GC transaction is aborted
	// by a concurrent change to the hash (see its metric description).
	gcConflict otelmetric.Int64Counter

	// metricAttributes labels every metric with the hash's name.
	metricAttributes attribute.Set
}

// NewRedisExpiringHashAPI builds a RedisExpiringHashAPI for the hash family called name.
// name becomes the expiring_hash_name attribute on every metric. The GC metrics are created on m.
// An error from creating either counter is returned unchanged, and no API is returned.
func NewRedisExpiringHashAPI[K1 any, K2 any](name string, client rueidis.Client, key1ToRedisKey KeyToRedisKey[K1], key2ToRedisKey KeyToRedisKey[K2], m otelmetric.Meter) (*RedisExpiringHashAPI[K1, K2], error) {
	api := &RedisExpiringHashAPI[K1, K2]{
		client:           client,
		key1ToRedisKey:   key1ToRedisKey,
		key2ToRedisKey:   key2ToRedisKey,
		metricAttributes: attribute.NewSet(expiringHashNameKey.String(name)),
	}

	var err error
	api.gcCounter, err = m.Int64Counter(
		gcDeletedKeysMetricName,
		otelmetric.WithDescription("Number of keys that have been garbage collected in a single pass"),
	)
	if err != nil {
		return nil, err
	}

	api.gcConflict, err = m.Int64Counter(
		gcConflictMetricName,
		otelmetric.WithDescription("Number of times garbage collection was aborted due to a concurrent hash mutation"),
	)
	if err != nil {
		return nil, err
	}

	return api, nil
}

// SetBuilder returns a fresh RedisSetBuilder that shares this API's client and key-mapping functions.
func (h *RedisExpiringHashAPI[K1, K2]) SetBuilder() SetBuilder[K1, K2] {
	b := &RedisSetBuilder[K1, K2]{}
	b.client = h.client
	b.key1ToRedisKey = h.key1ToRedisKey
	b.key2ToRedisKey = h.key2ToRedisKey
	return b
}

// Unset removes the field for hashKey from the hash for key using a single HDEL.
// It returns the command's error, if any.
func (h *RedisExpiringHashAPI[K1, K2]) Unset(ctx context.Context, key K1, hashKey K2) error {
	redisKey := h.key1ToRedisKey(key)
	field := h.key2ToRedisKey(hashKey)
	cmd := h.client.B().Hdel().Key(redisKey).Field(field).Build()
	result := h.client.Do(ctx, cmd)
	return result.Error()
}

// Scan walks every field of the hash identified by key and hands non-expired entries to cb.
// Each stored value is decoded as an ExpiringValue. Entries whose ExpiresAt is earlier than the Unix time
// captured when Scan starts are skipped silently. A value that fails to decode is reported to cb with a nil
// value and a wrapped error, and cb decides whether to continue. Scan returns whatever error cb or the
// underlying HSCAN produces.
func (h *RedisExpiringHashAPI[K1, K2]) Scan(ctx context.Context, key K1, cb ScanCallback) error {
	nowUnix := time.Now().Unix()
	redisKey := h.key1ToRedisKey(key)

	visit := func(rawKey, rawValue string) (bool /*done*/, error) {
		// View the string's bytes without copying them; the slice must be treated as read-only.
		data := unsafe.Slice(unsafe.StringData(rawValue), len(rawValue)) //nolint: gosec
		var ev ExpiringValue
		if err := proto.Unmarshal(data, &ev); err != nil {
			return cb(rawKey, nil, fmt.Errorf("failed to unmarshal hash value from hashkey 0x%x: %w", rawKey, err))
		}
		if ev.ExpiresAt >= nowUnix {
			return cb(rawKey, ev.Value, nil)
		}
		// Expired entries are hidden from the caller.
		return false, nil
	}

	return scan(ctx, redisKey, h.client, visit)
}

// GCFor returns a function that garbage collects expired entries from the hashes identified by keys.
// When transactional is true, each hash is cleaned inside a transaction on a dedicated connection, so a
// field overwritten concurrently by another client is not removed. Otherwise expired fields are removed
// with a plain HDEL per hash.
// The returned function reports how many hash fields it deleted. That number is always added to the
// deleted-keys counter, including when an error is returned as well.
func (h *RedisExpiringHashAPI[K1, K2]) GCFor(keys []K1, transactional bool) func(context.Context) (int /* keysDeleted */, error) {
	return func(ctx context.Context) (deletedKeys int, retErr error) {
		// Named results let the deferred call see the final count on every return path.
		// context.Background() is used, presumably so that recording is unaffected by ctx cancellation.
		defer func() { //nolint:contextcheck
			h.gcCounter.Add(context.Background(), int64(deletedKeys), otelmetric.WithAttributeSet(h.metricAttributes))
		}()

		if transactional {
			return h.gcForTransactional(ctx, keys)
		}
		return h.gcForNonTransactional(ctx, keys)
	}
}

// gcForNonTransactional garbage collects each hash in keys in order, without transactions.
// It stops at the first error and returns it together with the number of fields deleted up to and
// including the failing hash.
func (h *RedisExpiringHashAPI[K1, K2]) gcForNonTransactional(ctx context.Context, keys []K1) (int /* keysDeleted */, error) {
	total := 0
	for i := range keys {
		n, err := h.gcHashNonTransactional(ctx, keys[i])
		total += n
		if err != nil {
			return total, err
		}
	}
	return total, nil
}

// gcHashNonTransactional iterates a hash and removes all expired values without using a transaction.
// It assumes that values are marshaled ExpiringValue.
// Because no transaction is used, a field that another client overwrites between the scan and the HDEL
// can be removed even though its new value has not expired.
// The returned count is the number of fields Redis reports as removed. It is 0 when the HDEL itself fails.
// Decode/scan errors and the HDEL error are joined into the returned error.
func (h *RedisExpiringHashAPI[K1, K2]) gcHashNonTransactional(ctx context.Context, key K1) (int /* keysDeleted */, error) {
	redisKey := h.key1ToRedisKey(key)
	keysToDelete, errs := h.getKeysToGC(ctx, redisKey, h.client)
	if len(keysToDelete) == 0 {
		return 0, errors.Join(errs...)
	}
	delCmd := h.client.B().Hdel().Key(redisKey).Field(keysToDelete...).Build()
	// HDEL replies with the number of fields it actually removed. That number can be lower than
	// len(keysToDelete) if another client deleted some of those fields concurrently.
	deleted, err := h.client.Do(ctx, delCmd).AsInt64()
	if err != nil {
		// Nothing is known to have been deleted; don't report (and count in metrics) phantom deletions.
		errs = append(errs, err)
		return 0, errors.Join(errs...)
	}
	return int(deleted), errors.Join(errs...)
}

// gcForTransactional garbage collects each hash in keys with gcHashTransactional. All hashes share one
// dedicated client, which is released when the function returns.
// A hash whose GC gives up with attemptsExceeded is skipped, and the next key is processed.
// Any other error stops the loop and is returned together with the number of fields deleted so far.
func (h *RedisExpiringHashAPI[K1, K2]) gcForTransactional(ctx context.Context, keys []K1) (int /* keysDeleted */, error) {
	if len(keys) == 0 {
		// Nothing to collect: don't take a connection out of the pool only to hand it straight back.
		return 0, nil
	}
	client, cancel := h.client.Dedicate()
	defer cancel()

	var deletedKeys int
	for _, key := range keys {
		deleted, err := h.gcHashTransactional(ctx, key, client)
		deletedKeys += deleted
		// Direct comparison: gcHashTransactional propagates attemptsExceeded unwrapped.
		switch err { //nolint:errorlint
		case nil, attemptsExceeded:
			// Try to GC next key on conflicts
		default:
			return deletedKeys, err
		}
	}
	return deletedKeys, nil
}

// gcHashTransactional removes all expired values from one hash inside a transaction run on c.
// It assumes that values are marshaled ExpiringValue. The hash's Redis key is passed to transaction,
// presumably to be watched, so that a mapping just overwritten by another client is not deleted.
// At most maxKeyGCAttempts attempts are made, to bound the time spent on GC.
// If transaction fails (including with attemptsExceeded), it returns 0 and that error unchanged.
// Otherwise it returns the number of fields deleted by the final attempt and any decode/scan errors
// from that attempt, joined.
func (h *RedisExpiringHashAPI[K1, K2]) gcHashTransactional(ctx context.Context, key K1, c rueidis.DedicatedClient) (int /* keysDeleted */, error) {
	redisKey := h.key1ToRedisKey(key)
	var (
		scanErrs []error
		deleted  int
	)
	txErr := transaction(ctx, maxKeyGCAttempts, c, h.gcConflict, h.metricAttributes, func(txCtx context.Context) ([]rueidis.Completed, error) {
		// The callback may run more than once, so every attempt overwrites the captured results.
		expired, errs := h.getKeysToGC(txCtx, redisKey, c)
		scanErrs = errs
		deleted = len(expired)
		if deleted == 0 {
			// Nothing to delete; scanErrs is reported after the transaction completes.
			return nil, nil
		}
		hdel := c.B().Hdel().Key(redisKey).Field(expired...).Build()
		return []rueidis.Completed{hdel}, nil
	}, redisKey)
	if txErr != nil {
		return 0, txErr
	}
	return deleted, errors.Join(scanErrs...)
}

// getKeysToGC scans the hash at redisKey through c and returns the fields whose ExpiresAt is earlier than
// the Unix time captured at the start. Values that fail to decode are skipped, and their errors are
// collected. An HSCAN failure is appended to the collected errors as well.
func (h *RedisExpiringHashAPI[K1, K2]) getKeysToGC(ctx context.Context, redisKey string, c rueidis.CoreClient) ([]string, []error) {
	var (
		expired []string
		errs    []error
	)
	cutoff := time.Now().Unix()
	// Only the timestamp is decoded. DiscardUnknown drops the remaining stored field(s) instead of keeping
	// them as unknown fields (ExpiringValueTimestamp presumably declares just the expiry field).
	opts := proto.UnmarshalOptions{DiscardUnknown: true}

	scanErr := scan(ctx, redisKey, c, func(k, v string) (bool /*done*/, error) {
		// View the string's bytes without copying them; the slice must be treated as read-only.
		raw := unsafe.Slice(unsafe.StringData(v), len(v)) //nolint: gosec
		var ts ExpiringValueTimestamp
		if err := opts.Unmarshal(raw, &ts); err != nil {
			errs = append(errs, err)
		} else if ts.ExpiresAt < cutoff {
			expired = append(expired, k)
		}
		return false, nil
	})
	if scanErr != nil {
		errs = append(errs, scanErr)
	}
	return expired, errs
}

// scanCb is called by scan once per field/value pair of a hash.
// Returning done=true or a non-nil error ends the scan; scan then returns that error (nil if only done was set).
type scanCb func(field, value string) (done bool, err error)

// scan pages through the hash at redisKey with HSCAN (COUNT scanCount) and calls cb for every field/value pair.
// It stops when the server returns cursor 0, when cb reports done, or on the first error. Errors come from
// the HSCAN command, from a reply with an odd number of elements, or from cb itself.
// see https://redis.io/commands/scan
func scan(ctx context.Context, redisKey string, c rueidis.CoreClient, cb scanCb) error {
	var cursor uint64 // 0 starts a new iteration
	for {
		cmd := c.B().Hscan().Key(redisKey).Cursor(cursor).Count(scanCount).Build()
		page, err := c.Do(ctx, cmd).AsScanEntry()
		if err != nil {
			return err
		}
		pairs := page.Elements
		if len(pairs)%2 != 0 {
			// This shouldn't happen: HSCAN replies with alternating field/value elements.
			return errors.New("invalid Redis reply")
		}
		for len(pairs) > 0 {
			done, cbErr := cb(pairs[0], pairs[1])
			if cbErr != nil || done {
				return cbErr
			}
			pairs = pairs[2:]
		}
		if page.Cursor == 0 {
			return nil
		}
		cursor = page.Cursor
	}
}