1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253
|
package agent_tracker
import (
"context"
"fmt"
"strconv"
"sync"
"time"
"github.com/redis/rueidis"
"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tool/errz"
"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tool/logz"
"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tool/redistool"
"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tool/syncz"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
)
const (
	// refreshOverlap is subtracted from refreshPeriod when Run computes the
	// next-refresh time it hands to refreshRegistrations. That time is therefore
	// refreshOverlap earlier than the following refresh tick.
	// NOTE(review): presumably this gives slack so entries never lapse between
	// refreshes — confirm against redistool's Refresh semantics.
	refreshOverlap time.Duration = 5 * time.Second

	// connectedAgentsKey is the only key ever used with the connectedAgents
	// hash; connectedAgentsHashKey ignores the key value entirely.
	connectedAgentsKey int64 = 0
)
type (
	// ConnectedAgentInfoCallback is handed one connection record per call during a
	// Querier lookup. A true done result tells the lookup that no more records are
	// wanted. A non-nil err is returned from the callback to the underlying scan.
	// NOTE(review): whether that error is surfaced to the caller depends on
	// redistool's Scan — confirm.
	ConnectedAgentInfoCallback func(*ConnectedAgentInfo) (done bool, err error)

	// Registerer records agent connections as they appear and disappear.
	Registerer interface {
		// RegisterConnection adds the connection described by info to the tracker.
		RegisterConnection(ctx context.Context, info *ConnectedAgentInfo) error
		// UnregisterConnection removes the connection described by info from the tracker.
		UnregisterConnection(ctx context.Context, info *ConnectedAgentInfo) error
	}

	// Querier reads back what a Registerer has recorded.
	Querier interface {
		// GetConnectionsByAgentId feeds the connections of agentId to cb.
		GetConnectionsByAgentId(ctx context.Context, agentId int64, cb ConnectedAgentInfoCallback) error
		// GetConnectionsByProjectId feeds the connections of projectId to cb.
		GetConnectionsByProjectId(ctx context.Context, projectId int64, cb ConnectedAgentInfoCallback) error
		// GetConnectedAgentsCount reports how many agents are tracked as connected.
		GetConnectedAgentsCount(ctx context.Context) (int64, error)
	}

	// Tracker is a Registerer and a Querier that also needs Run to be driven for
	// its background upkeep (in RedisTracker: periodic refresh and GC until ctx is done).
	Tracker interface {
		Registerer
		Querier
		Run(ctx context.Context) error
	}
)
// RedisTracker implements Tracker on top of three Redis-backed expiring hashes.
// Run refreshes their entries every refreshPeriod and garbage-collects expired
// entries every gcPeriod.
type RedisTracker struct {
	log    *zap.Logger
	errRep errz.ErrReporter

	// Ticker intervals used by Run for refreshRegistrations and runGC respectively.
	refreshPeriod, gcPeriod time.Duration

	// refreshMu is held exclusively by refreshRegistrations and shared by
	// UnregisterConnection. Because the two can never overlap, a refresh cannot
	// write a connection that was just unregistered back into Redis.
	refreshMu syncz.RWMutex

	// mu is held while the IO operations (Set/Unset/Forget/Refresh/GC) on the
	// hashes declared after it are being built. Len and Scan are called without it.
	mu sync.Mutex

	// Hash layouts (hash key -> field -> value):
	//   connectionsByAgentId:   agentId            -> connectionId -> marshaled ConnectedAgentInfo
	//   connectionsByProjectId: projectId          -> connectionId -> marshaled ConnectedAgentInfo
	//   connectedAgents:        connectedAgentsKey -> agentId      -> nil value
	connectionsByAgentId, connectionsByProjectId, connectedAgents redistool.ExpiringHashInterface[int64, int64]
}
// NewRedisTracker creates a RedisTracker whose three expiring hashes use client,
// are keyed under agentKeyPrefix, encode their int64 field keys with int64ToStr and
// share the same ttl. refreshPeriod and gcPeriod set the tick intervals used by Run.
func NewRedisTracker(log *zap.Logger, errRep errz.ErrReporter, client rueidis.Client, agentKeyPrefix string, ttl, refreshPeriod, gcPeriod time.Duration) *RedisTracker {
	tracker := &RedisTracker{
		log:           log,
		errRep:        errRep,
		refreshPeriod: refreshPeriod,
		gcPeriod:      gcPeriod,
		refreshMu:     syncz.NewRWMutex(),
	}
	tracker.connectionsByAgentId = redistool.NewExpiringHash(client, connectionsByAgentIdHashKey(agentKeyPrefix), int64ToStr, ttl)
	tracker.connectionsByProjectId = redistool.NewExpiringHash(client, connectionsByProjectIdHashKey(agentKeyPrefix), int64ToStr, ttl)
	tracker.connectedAgents = redistool.NewExpiringHash(client, connectedAgentsHashKey(agentKeyPrefix), int64ToStr, ttl)
	return tracker
}
// Run blocks until ctx is done, calling refreshRegistrations on every refreshPeriod
// tick and runGC on every gcPeriod tick. It always returns nil.
func (t *RedisTracker) Run(ctx context.Context) error {
	refreshTick := time.NewTicker(t.refreshPeriod)
	defer refreshTick.Stop()
	gcTick := time.NewTicker(t.gcPeriod)
	defer gcTick.Stop()

	for {
		select {
		case <-ctx.Done():
			return nil
		case <-refreshTick.C:
			// Next-refresh time is refreshOverlap earlier than the upcoming tick.
			deadline := time.Now().Add(t.refreshPeriod - refreshOverlap)
			t.refreshRegistrations(ctx, deadline)
		case <-gcTick.C:
			if removed := t.runGC(ctx); removed > 0 {
				t.log.Info("Deleted expired agent connections records", logz.RemovedHashKeys(removed))
			}
		}
	}
}
// RegisterConnection marshals info and writes it into the per-project and per-agent
// hashes under info.ConnectionId. It also records info.AgentId in the connected-agents
// hash. The three writes run concurrently via runIOFuncs.
func (t *RedisTracker) RegisterConnection(ctx context.Context, info *ConnectedAgentInfo) error {
	payload, err := proto.Marshal(info)
	if err != nil {
		// This should never happen
		return fmt.Errorf("failed to marshal object: %w", err)
	}
	writes := func() []redistool.IOFunc {
		byProject := t.connectionsByProjectId.Set(info.ProjectId, info.ConnectionId, payload)
		byAgent := t.connectionsByAgentId.Set(info.AgentId, info.ConnectionId, payload)
		agent := t.connectedAgents.Set(connectedAgentsKey, info.AgentId, nil)
		return []redistool.IOFunc{byProject, byAgent, agent}
	}
	return t.runIOFuncs(ctx, writes)
}
// UnregisterConnection removes info.ConnectionId from the per-project and per-agent
// hashes and calls Forget for info.AgentId on the connected-agents hash (no Unset
// is issued for it here).
// NOTE(review): presumably the agent entry is left to expire because other
// connections of the same agent may still be registered — confirm.
// It returns ctx.Err() if ctx ends before refreshMu can be acquired.
func (t *RedisTracker) UnregisterConnection(ctx context.Context, info *ConnectedAgentInfo) error {
	// Shared hold on refreshMu: no refresh can run while we remove entries.
	if locked := t.refreshMu.RLock(ctx); !locked {
		return ctx.Err()
	}
	defer t.refreshMu.RUnlock()

	removals := func() []redistool.IOFunc {
		t.connectedAgents.Forget(connectedAgentsKey, info.AgentId)
		byProject := t.connectionsByProjectId.Unset(info.ProjectId, info.ConnectionId)
		byAgent := t.connectionsByAgentId.Unset(info.AgentId, info.ConnectionId)
		return []redistool.IOFunc{byProject, byAgent}
	}
	return t.runIOFuncs(ctx, removals)
}
// GetConnectionsByAgentId scans the per-agent hash for agentId and passes every
// decodable connection record to cb (see getConnectionsByKey).
func (t *RedisTracker) GetConnectionsByAgentId(ctx context.Context, agentId int64, cb ConnectedAgentInfoCallback) error {
	byAgent := t.connectionsByAgentId
	return t.getConnectionsByKey(ctx, byAgent, agentId, cb)
}
// GetConnectionsByProjectId scans the per-project hash for projectId and passes
// every decodable connection record to cb (see getConnectionsByKey).
func (t *RedisTracker) GetConnectionsByProjectId(ctx context.Context, projectId int64, cb ConnectedAgentInfoCallback) error {
	byProject := t.connectionsByProjectId
	return t.getConnectionsByKey(ctx, byProject, projectId, cb)
}
// GetConnectedAgentsCount returns the length of the connected-agents hash (the
// single entry addressed by connectedAgentsKey).
func (t *RedisTracker) GetConnectedAgentsCount(ctx context.Context) (int64, error) {
	count, err := t.connectedAgents.Len(ctx, connectedAgentsKey)
	return count, err
}
// refreshRegistrations asks each of the three hashes to Refresh its entries, passing
// nextRefresh. It holds refreshMu exclusively for the whole run and returns silently
// if ctx ends before the lock is obtained. Refresh failures are reported to errRep
// and the remaining hashes are still refreshed. A context-done error stops the run early.
func (t *RedisTracker) refreshRegistrations(ctx context.Context, nextRefresh time.Time) {
	// Exclusive hold: UnregisterConnection cannot run concurrently, so an entry it
	// just removed is never written back by this refresh.
	if !t.refreshMu.Lock(ctx) {
		return
	}
	defer t.refreshMu.Unlock()

	refreshes := syncz.RunWithMutex(&t.mu, func() []redistool.IOFunc {
		byProject := t.connectionsByProjectId.Refresh(nextRefresh)
		byAgent := t.connectionsByAgentId.Refresh(nextRefresh)
		agents := t.connectedAgents.Refresh(nextRefresh)
		return []redistool.IOFunc{byProject, byAgent, agents}
	})

	// Deliberately sequential: refresh is not urgent and should not compete with
	// more important work for RAM/CPU/Redis/network.
	for i := range refreshes {
		err := refreshes[i](ctx)
		if err == nil {
			continue
		}
		if errz.ContextDone(err) {
			t.log.Debug("Redis hash data refresh interrupted", logz.Error(err))
			return
		}
		// Report and carry on with the next hash.
		t.errRep.HandleProcessingError(ctx, t.log, "Failed to refresh hash data in Redis", err)
	}
}
// runGC runs the GC pass of each of the three hashes one after another and returns
// the total number of deleted keys reported by them. A failing pass is reported to
// errRep and its reported count is still added. A context-done error stops the run
// immediately and does not add that pass's count.
func (t *RedisTracker) runGC(ctx context.Context) int {
	passes := syncz.RunWithMutex(&t.mu, func() []func(context.Context) (int, error) {
		byProject := t.connectionsByProjectId.GC()
		byAgent := t.connectionsByAgentId.GC()
		agents := t.connectedAgents.GC()
		return []func(context.Context) (int, error){byProject, byAgent, agents}
	})

	total := 0
	// Deliberately sequential: GC is not urgent and should not compete with more
	// important work for RAM/CPU/Redis/network.
	for _, pass := range passes {
		removed, err := pass(ctx)
		if err == nil {
			total += removed
			continue
		}
		if errz.ContextDone(err) {
			t.log.Debug("Redis GC interrupted", logz.Error(err))
			return total
		}
		t.errRep.HandleProcessingError(ctx, t.log, "Failed to GC data in Redis", err)
		total += removed
	}
	return total
}
// runIOFuncs calls f while holding t.mu to obtain a batch of IO operations, then runs
// every operation in its own goroutine and waits for all of them. It returns the
// first non-nil error (errgroup.Group.Wait semantics); other operations are not cancelled.
func (t *RedisTracker) runIOFuncs(ctx context.Context, f func() []redistool.IOFunc) error {
	var group errgroup.Group
	for _, op := range syncz.RunWithMutex(&t.mu, f) {
		op := op // per-iteration copy for the goroutine (pre-Go 1.22 loop semantics)
		group.Go(func() error { return op(ctx) })
	}
	return group.Wait()
}
// getConnectionsByKey scans the entry of hash addressed by key, unmarshals each value
// into a ConnectedAgentInfo and hands it to cb, returning cb's (done, err) to the scan.
// Scan errors and values that fail to unmarshal are reported to errRep and answered
// with (false, nil).
// NOTE(review): presumably that means "skip and keep scanning" — confirm with redistool.
// Only the error from Scan itself is returned.
func (t *RedisTracker) getConnectionsByKey(ctx context.Context, hash redistool.ExpiringHashInterface[int64, int64], key int64, cb ConnectedAgentInfoCallback) error {
	visit := func(_ string, raw []byte, scanErr error) (bool, error) {
		if scanErr != nil {
			t.errRep.HandleProcessingError(ctx, t.log, "Redis hash scan", scanErr)
			return false, nil
		}
		info := &ConnectedAgentInfo{}
		if err := proto.Unmarshal(raw, info); err != nil {
			t.errRep.HandleProcessingError(ctx, t.log, "Redis proto.Unmarshal(ConnectedAgentInfo)", err)
			return false, nil
		}
		return cb(info)
	}
	_, err := hash.Scan(ctx, key, visit)
	return err
}
// connectionsByAgentIdHashKey builds the Redis key function for the per-agent hashes
// (agentId -> connectionId -> marshaled ConnectedAgentInfo). Keys are formed by
// redistool.PrefixedInt64Key from "<agentKeyPrefix>:conn_by_agent_id:" and the agent id.
func connectionsByAgentIdHashKey(agentKeyPrefix string) redistool.KeyToRedisKey[int64] {
	keyPrefix := agentKeyPrefix + ":conn_by_agent_id:"
	return func(agentId int64) string { return redistool.PrefixedInt64Key(keyPrefix, agentId) }
}
// connectionsByProjectIdHashKey returns a key function for the per-project hashes:
// projectId -> (connectionId -> marshaled ConnectedAgentInfo). Fields are connection
// ids, matching RegisterConnection's Set(info.ProjectId, info.ConnectionId, ...).
// Keys are formed by redistool.PrefixedInt64Key from
// "<agentKeyPrefix>:conn_by_project_id:" and the project id.
func connectionsByProjectIdHashKey(agentKeyPrefix string) redistool.KeyToRedisKey[int64] {
	prefix := agentKeyPrefix + ":conn_by_project_id:"
	return func(projectId int64) string {
		return redistool.PrefixedInt64Key(prefix, projectId)
	}
}
// connectedAgentsHashKey returns a key function that maps every key to the one
// Redis key "<agentKeyPrefix>:connected_agents". The int64 argument is ignored,
// which is why callers always pass connectedAgentsKey.
func connectedAgentsHashKey(agentKeyPrefix string) redistool.KeyToRedisKey[int64] {
	redisKey := agentKeyPrefix + ":connected_agents"
	return func(int64) string {
		return redisKey
	}
}
// ConnectedAgentInfoCollector accumulates the records passed to its Collect method.
// Collect matches the ConnectedAgentInfoCallback signature, so (&c).Collect can be
// passed as a Querier callback.
type ConnectedAgentInfoCollector []*ConnectedAgentInfo

// Collect appends info to the collection and always answers done=false, nil.
func (c *ConnectedAgentInfoCollector) Collect(info *ConnectedAgentInfo) (bool, error) {
	collected := *c
	collected = append(collected, info)
	*c = collected
	return false, nil
}
// int64ToStr renders key in base 10. NewRedisTracker passes it to every
// redistool.NewExpiringHash as the field-key encoder.
func int64ToStr(key int64) string {
	var buf [20]byte // len("-9223372036854775808") == 20, so no growth is ever needed
	return string(strconv.AppendInt(buf[:0], key, 10))
}
|