File: realip.go

package info (click to toggle)
golang-github-grpc-ecosystem-go-grpc-middleware 2.1.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,464 kB
  • sloc: makefile: 107; sh: 9
file content (180 lines) | stat: -rw-r--r-- 5,747 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
// Copyright (c) The go-grpc-middleware Authors.
// Licensed under the Apache License 2.0.

package realip

import (
	"context"
	"net"
	"net/netip"
	"strings"

	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/peer"
)

// XRealIp, XForwardedFor and TrueClientIp are metadata/header keys that
// proxies and load balancers commonly use to report the address of the
// originating client. Any of them may be passed in the header list given to
// the interceptors in this package.
const (
	XRealIp       = "X-Real-IP"
	XForwardedFor = "X-Forwarded-For"
	TrueClientIp  = "True-Client-IP"
)

// noIP is the zero (invalid) netip.Addr; it is the "no address found" result
// used throughout this file.
var noIP netip.Addr

// realipKey is the unexported context key under which the interceptors store
// the extracted client IP.
type realipKey struct{}

// FromContext returns the client IP that this package's interceptors stored
// in ctx. The boolean is false when ctx carries no such value, in which case
// the returned address is the zero netip.Addr.
func FromContext(ctx context.Context) (netip.Addr, bool) {
	stored := ctx.Value(realipKey{})
	if addr, found := stored.(netip.Addr); found {
		return addr, true
	}
	return netip.Addr{}, false
}

// remotePeer returns the network address of the gRPC peer recorded in ctx,
// or nil when ctx carries no peer information.
func remotePeer(ctx context.Context) net.Addr {
	if p, found := peer.FromContext(ctx); found {
		return p.Addr
	}
	return nil
}

// ipInNets reports whether ip is contained in at least one prefix of nets.
// An empty or nil nets never matches. Matching follows netip.Prefix.Contains,
// so an IPv4-mapped IPv6 address does not match an IPv4 prefix.
func ipInNets(ip netip.Addr, nets []netip.Prefix) bool {
	matched := false
	for i := 0; i < len(nets) && !matched; i++ {
		matched = nets[i].Contains(ip)
	}
	return matched
}

// getHeader returns the first value of the incoming gRPC metadata entry named
// key, or "" when ctx has no incoming metadata or the entry has no values.
// The lookup is case-insensitive: metadata.MD.Get lower-cases key, which is
// how gRPC stores metadata keys.
func getHeader(ctx context.Context, key string) string {
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return ""
	}

	// Check the length rather than nil-ness: an entry that is present but
	// holds an empty slice would otherwise cause an index-out-of-range panic
	// inside the interceptor.
	vals := md.Get(key)
	if len(vals) == 0 {
		return ""
	}
	return vals[0]
}

// ipFromXForwardedFoR walks the X-Forwarded-For entries in ips from position
// idx towards index 0 (right to left). It returns the first entry that parses
// as an IP and is not inside trustedProxies.
//
// It returns noIP as soon as an inspected entry fails to parse. It also
// returns noIP when every inspected entry belongs to a trusted proxy, or when
// idx is negative. idx must be < len(ips).
func ipFromXForwardedFoR(trustedProxies []netip.Prefix, ips []string, idx int) netip.Addr {
	for ; idx >= 0; idx-- {
		candidate, err := netip.ParseAddr(strings.TrimSpace(ips[idx]))
		switch {
		case err != nil:
			// Stop at the first malformed entry instead of skipping past it.
			return noIP
		case !ipInNets(candidate, trustedProxies):
			return candidate
		}
	}
	return noIP
}

// ipFromHeaders looks for a client IP in the configured headers, trying them
// in the order given. It returns noIP if none yields one.
//
// Absent or empty headers are skipped so the next configured header is tried.
//
// X-Forwarded-For (matched case-insensitively) is handled specially:
//   - The rightmost trustedProxyCnt comma-separated entries are skipped. If the
//     header has no more entries than that, the next header is tried.
//   - The remaining entries are walked right-to-left by ipFromXForwardedFoR,
//     which also skips addresses inside trustedProxies.
//   - Its result is returned as-is, even when it is noIP.
//
// For any other header, the rightmost comma-separated entry is used if it
// parses as an IP; otherwise the next header is tried.
func ipFromHeaders(ctx context.Context, headers []string, trustedProxies []netip.Prefix, trustedProxyCnt uint) netip.Addr {
	for _, header := range headers {
		value := getHeader(ctx, header)
		if value == "" {
			// An absent header must not short-circuit the search. Previously an
			// empty X-Forwarded-For returned noIP here when trustedProxyCnt was 0.
			continue
		}
		parts := strings.Split(value, ",")

		// getHeader resolves keys case-insensitively, so the X-Forwarded-For
		// check must be case-insensitive too. Otherwise a lower-cased
		// configuration would bypass the trusted-proxy validation below.
		if strings.EqualFold(header, XForwardedFor) {
			// Compare in uint space so a very large count cannot overflow int
			// and produce an out-of-range index.
			if trustedProxyCnt >= uint(len(parts)) {
				continue
			}
			idx := len(parts) - 1 - int(trustedProxyCnt)
			return ipFromXForwardedFoR(trustedProxies, parts, idx)
		}

		if ip, err := netip.ParseAddr(strings.TrimSpace(parts[len(parts)-1])); err == nil {
			return ip
		}
	}
	return noIP
}

// getRemoteIP determines the client IP for the call carried by ctx.
//
// It starts from the direct peer's address. If there is no peer, or its
// String form does not parse as "ip:port", it returns noIP. Headers are
// consulted (via ipFromHeaders) only when the peer IP lies inside
// trustedPeers. The peer IP is returned when the peer is untrusted or when
// the headers yield nothing.
func getRemoteIP(ctx context.Context, trustedPeers, trustedProxies []netip.Prefix, headers []string, proxyCnt uint) netip.Addr {
	addr := remotePeer(ctx)
	if addr == nil {
		return noIP
	}

	peerAddrPort, err := netip.ParseAddrPort(addr.String())
	if err != nil {
		return noIP
	}
	peerIP := peerAddrPort.Addr()

	// ipInNets is false for an empty trustedPeers, so with no trusted peers
	// configured the headers are never consulted.
	if !ipInNets(peerIP, trustedPeers) {
		return peerIP
	}

	headerIP := ipFromHeaders(ctx, headers, trustedProxies, proxyCnt)
	if headerIP == noIP {
		return peerIP
	}
	return headerIP
}

// serverStream wraps a grpc.ServerStream so that the stream handler sees a
// context carrying the extracted client IP. Every method other than Context
// is provided by the embedded stream.
type serverStream struct {
	grpc.ServerStream
	// ctx replaces the embedded stream's context.
	ctx context.Context
}

// Context returns the overriding context instead of the wrapped stream's.
func (w *serverStream) Context() context.Context {
	return w.ctx
}

// UnaryServerInterceptor returns a new unary server interceptor that extracts
// the real client IP and adds it to the request context (read it back with
// FromContext). The configured headers are consulted only when the request
// comes from a trusted peer.
//
// It is shorthand for UnaryServerInterceptorOpts with WithTrustedPeers and
// WithHeaders. Use UnaryServerInterceptorOpts directly to also configure a
// trusted proxy list and count, which should work better with Google LB.
func UnaryServerInterceptor(trustedPeers []netip.Prefix, headers []string) grpc.UnaryServerInterceptor {
	peersOpt := WithTrustedPeers(trustedPeers)
	headersOpt := WithHeaders(headers)
	return UnaryServerInterceptorOpts(peersOpt, headersOpt)
}

// StreamServerInterceptor returns a new stream server interceptor that
// extracts the real client IP and exposes it through the stream's context
// (read it back with FromContext). The configured headers are consulted only
// when the stream comes from a trusted peer.
//
// It is shorthand for StreamServerInterceptorOpts with WithTrustedPeers and
// WithHeaders. Use StreamServerInterceptorOpts directly to also configure a
// trusted proxy list and count, which should work better with Google LB.
func StreamServerInterceptor(trustedPeers []netip.Prefix, headers []string) grpc.StreamServerInterceptor {
	return StreamServerInterceptorOpts(WithTrustedPeers(trustedPeers), WithHeaders(headers))
}

// UnaryServerInterceptorOpts returns a new unary server interceptor configured
// by opts. The options are evaluated once, when the interceptor is built.
//
// For every call it determines the client IP. The configured headers are
// consulted only when the direct peer is trusted, and X-Forwarded-For is
// validated against the trusted proxy list and count. When an IP is found, it
// is stored in the context passed to the handler; read it with FromContext.
// Otherwise the handler receives the original context.
func UnaryServerInterceptorOpts(opts ...Option) grpc.UnaryServerInterceptor {
	cfg := evaluateOpts(opts)
	return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		callCtx := ctx
		if ip := getRemoteIP(ctx, cfg.trustedPeers, cfg.trustedProxies, cfg.headers, cfg.trustedProxiesCount); ip != noIP {
			callCtx = context.WithValue(ctx, realipKey{}, ip)
		}
		return handler(callCtx, req)
	}
}

// StreamServerInterceptorOpts returns a new stream server interceptor
// configured by opts. The options are evaluated once, when the interceptor is
// built.
//
// For every stream it determines the client IP. The configured headers are
// consulted only when the direct peer is trusted, and X-Forwarded-For is
// validated against the trusted proxy list and count. When an IP is found, the
// handler receives a wrapped stream whose Context() carries it; read it with
// FromContext. Otherwise the original stream is passed through unchanged.
func StreamServerInterceptorOpts(opts ...Option) grpc.StreamServerInterceptor {
	cfg := evaluateOpts(opts)
	return func(srv any, stream grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		ip := getRemoteIP(stream.Context(), cfg.trustedPeers, cfg.trustedProxies, cfg.headers, cfg.trustedProxiesCount)
		if ip == noIP {
			// Nothing to record: hand the original stream through untouched.
			return handler(srv, stream)
		}
		wrapped := &serverStream{
			ServerStream: stream,
			ctx:          context.WithValue(stream.Context(), realipKey{}, ip),
		}
		return handler(srv, wrapped)
	}
}