1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
|
package redistool
import (
"context"
"encoding/base64"
"errors"
"unsafe"
"github.com/redis/rueidis"
"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tool/logz"
"go.uber.org/zap"
"k8s.io/utils/clock"
)
// RPCAPI exposes the per-request facilities TokenLimiter relies on: the key that
// identifies whose events are counted, a logger, and an error-reporting hook.
type RPCAPI interface {
	// RequestKey returns the bytes identifying the rate-limited caller. TokenLimiter
	// base64-encodes them into the Redis key of each minute bucket.
	RequestKey() []byte
	// Log returns the logger TokenLimiter writes its debug messages to.
	Log() *zap.Logger
	// HandleProcessingError is invoked by TokenLimiter with a description and the
	// underlying error whenever a Redis operation fails.
	HandleProcessingError(msg string, err error)
}
// TokenLimiter is a redis-based rate limiter implementing the algorithm in https://redislabs.com/redis-best-practices/basic-rate-limiting/
//
// Each request key gets one Redis counter per minute of the hour (UTC). Allow
// increments the counter for every permitted event and sets it to expire 59 seconds
// after that increment, so a bucket is gone well before the same minute comes around
// again an hour later.
type TokenLimiter struct {
	// redisClient reads and updates the per-minute counters.
	redisClient rueidis.Client
	// clock supplies the current time; its UTC minute-of-hour selects the bucket.
	clock clock.PassiveClock
	// keyPrefix is prepended to every Redis key this limiter touches.
	keyPrefix string
	// limitPerMinute is the bucket count at which Allow starts rejecting events.
	limitPerMinute uint64
	// getAPI resolves the per-request RPCAPI from the request context.
	getAPI func(context.Context) RPCAPI
}
// NewTokenLimiter builds a TokenLimiter that stores its per-minute counters in Redis
// through redisClient, under keys beginning with keyPrefix. It permits up to
// limitPerMinute events per request key in each UTC minute bucket. getAPI is used to
// obtain the per-request RPCAPI from a context. The limiter reads time from the real
// system clock.
func NewTokenLimiter(redisClient rueidis.Client, keyPrefix string,
	limitPerMinute uint64, getAPI func(context.Context) RPCAPI) *TokenLimiter {
	limiter := new(TokenLimiter)
	limiter.redisClient = redisClient
	limiter.clock = clock.RealClock{}
	limiter.keyPrefix = keyPrefix
	limiter.limitPerMinute = limitPerMinute
	limiter.getAPI = getAPI
	return limiter
}
// Allow consumes one limitable event from the token in the context.
//
// It reads the counter of the current UTC minute bucket for the caller's request key.
// If that counter has already reached limitPerMinute, the event is rejected and a
// debug message is logged. Otherwise the counter is incremented and its expiry set to
// 59 seconds inside a single MULTI/EXEC transaction, and the event is allowed.
// A missing counter counts as zero. Any other Redis error is passed to
// HandleProcessingError and the event is rejected.
//
// NOTE: the read and the increment are separate round trips. Concurrent callers can
// therefore each observe a count below the limit and together push it past the limit.
func (l *TokenLimiter) Allow(ctx context.Context) bool {
	api := l.getAPI(ctx)
	requestKey := api.RequestKey()
	minute := byte(l.clock.Now().UTC().Minute())
	key := buildTokenLimiterKey(l.keyPrefix, requestKey, minute)
	client := l.redisClient

	count, err := client.Do(ctx, client.B().Get().Key(key).Build()).AsUint64()
	switch {
	case err == nil:
		// The bucket exists; use its current value.
	case rueidis.IsRedisNil(err):
		// No event has been recorded in this minute bucket yet.
		count = 0
	default:
		api.HandleProcessingError("redistool.TokenLimiter: error retrieving minute bucket count", err)
		return false
	}

	if count >= l.limitPerMinute {
		api.Log().Debug("redistool.TokenLimiter: rate limit exceeded",
			logz.RedisKey([]byte(key)), logz.U64Count(count), logz.TokenLimit(l.limitPerMinute))
		return false
	}

	replies := client.DoMulti(ctx,
		client.B().Multi().Build(),
		client.B().Incr().Key(key).Build(),
		client.B().Expire().Key(key).Seconds(59).Build(),
		client.B().Exec().Build(),
	)
	if txErr := errors.Join(MultiErrors(replies)...); txErr != nil {
		api.HandleProcessingError("redistool.TokenLimiter: error while incrementing token key count", txErr)
		return false
	}
	return true
}
// buildTokenLimiterKey returns the Redis key of the minute bucket for requestKey.
//
// The key has the form "<keyPrefix>:<base64(requestKey)>:<m>". requestKey is encoded
// with standard, padded base64, and m is currentMinute appended as a single raw byte
// (not as decimal text).
//
// The key is assembled in one exactly-sized buffer. Sizing it from the encoded length
// (rather than the raw length of requestKey) and encoding in place avoids a
// reallocation and an intermediate string.
func buildTokenLimiterKey(keyPrefix string, requestKey []byte, currentMinute byte) string {
	encodedLen := base64.StdEncoding.EncodedLen(len(requestKey))
	// prefix + ':' + encoded request key + ':' + minute byte
	result := make([]byte, 0, len(keyPrefix)+1+encodedLen+1+1)
	result = append(result, keyPrefix...)
	result = append(result, ':')
	// Extend into the reserved capacity and encode directly into it.
	start := len(result)
	result = result[:start+encodedLen]
	base64.StdEncoding.Encode(result[start:], requestKey)
	result = append(result, ':', currentMinute)
	// result is never written to after this point, so aliasing it as a string is safe.
	return unsafe.String(unsafe.SliceData(result), len(result)) //nolint: gosec
}
|