File: utils.go

package info (click to toggle)
go-mmproxy 2.2.1-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 248 kB
  • sloc: makefile: 12; sh: 7
file content (135 lines) | stat: -rw-r--r-- 3,647 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
// Copyright 2019 Path Network, Inc. All rights reserved.
// Copyright 2024 Konrad Zemek <konrad.zemek@gmail.com>
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package utils

import (
	"fmt"
	"net"
	"net/netip"
	"strconv"
	"syscall"
	"time"
)

// Protocol enumerates the transport protocols distinguished by this package.
type Protocol int

// Known Protocol values. The numeric values are spelled out explicitly:
// TCP is the zero value of Protocol and UDP follows it.
const (
	// TCP selects the TCP transport.
	TCP Protocol = 0
	// UDP selects the UDP transport.
	UDP Protocol = 1
)

// Options groups the configuration values shared across the program.
//
// NOTE(review): the consumers of these fields are not visible in this file,
// so the descriptions below are inferred from names and types only and
// should be confirmed against the callers:
//   - Protocol: which transport (TCP or UDP) is in use.
//   - ListenAddr: presumably the address and port to accept traffic on.
//   - TargetAddr4 / TargetAddr6: presumably the IPv4 and IPv6 upstream
//     targets.
//   - DynamicDestination: meaning not determinable from this file.
//   - Mark: presumably the socket mark handed to DialUpstreamControl.
//   - Verbose: presumably a logging verbosity level.
//   - AllowedSubnets: presumably the prefix list handed to
//     CheckOriginAllowed.
//   - UDPCloseAfter: a duration; presumably an idle timeout for UDP state.
type Options struct {
	Protocol           Protocol
	ListenAddr         netip.AddrPort
	TargetAddr4        netip.AddrPort
	TargetAddr6        netip.AddrPort
	DynamicDestination bool
	Mark               int
	Verbose            int
	AllowedSubnets     []netip.Prefix
	UDPCloseAfter      time.Duration
}

// CheckOriginAllowed reports whether remoteIP may be accepted given
// allowedSubnets.
//
// An empty (or nil) allowedSubnets means "no restriction" and every address
// is accepted. Otherwise the address is accepted if at least one prefix
// contains it.
//
// NOTE(review): matching is delegated to netip.Prefix.Contains as-is, so an
// IPv4-mapped IPv6 address is not matched against IPv4 prefixes; confirm
// that callers pass unmapped addresses.
func CheckOriginAllowed(remoteIP netip.Addr, allowedSubnets []netip.Prefix) bool {
	// Start out permissive only when there is nothing to check against.
	allowed := len(allowedSubnets) == 0

	// Stop scanning as soon as one prefix matches.
	for i := 0; !allowed && i < len(allowedSubnets); i++ {
		allowed = allowedSubnets[i].Contains(remoteIP)
	}
	return allowed
}

// ParseHostPort splits hostport into host and port, resolves the host and
// returns the first resolved address that matches ipVersion, combined with
// the port.
//
// ipVersion selects the address family: 4 keeps only IPv4 addresses, 6 keeps
// only IPv6 addresses and 0 accepts either. Any other value matches nothing.
//
// The port must be a decimal number that fits in 16 bits. It is validated
// before the host is resolved, so a malformed port fails immediately without
// a lookup.
//
// An error is returned when hostport cannot be split, the port is invalid,
// the lookup fails, or no resolved address matches ipVersion. On error the
// returned AddrPort is the zero value.
func ParseHostPort(hostport string, ipVersion int) (netip.AddrPort, error) {
	host, portStr, err := net.SplitHostPort(hostport)
	if err != nil {
		return netip.AddrPort{}, fmt.Errorf("failed to parse host and port: %w", err)
	}

	port, err := strconv.ParseUint(portStr, 10, 16)
	if err != nil {
		return netip.AddrPort{}, fmt.Errorf("failed to parse port: %w", err)
	}

	ips, err := net.LookupIP(host)
	if err != nil {
		return netip.AddrPort{}, fmt.Errorf("failed to lookup IP addresses: %w", err)
	}

	for _, stdip := range ips {
		// Convert without a string round trip and without a panic path;
		// a slice of unexpected length is skipped.
		ip, ok := netip.AddrFromSlice(stdip)
		if !ok {
			continue
		}
		// net.IP may carry an IPv4 address in 16-byte form; Unmap makes
		// such addresses report Is4 rather than Is6.
		ip = ip.Unmap()

		if ipVersion == 0 || (ipVersion == 4 && ip.Is4()) || (ipVersion == 6 && ip.Is6()) {
			return netip.AddrPortFrom(ip, uint16(port)), nil
		}
	}

	return netip.AddrPort{}, fmt.Errorf("no IP addresses found")
}

// DialUpstreamControl builds a socket control callback whose signature
// matches net.Dialer.Control. When invoked, the callback configures the raw
// socket with the following options, in order, stopping at the first failure:
//
//   - TCP_SYNCNT = 2, only when protocol is TCP;
//   - IP_TRANSPARENT = 1;
//   - SO_REUSEADDR = 1;
//   - IP_BIND_ADDRESS_NO_PORT = 1, only when sport is 0;
//   - SO_MARK = mark, only when mark is non-zero;
//   - IPV6_V6ONLY = 0, only when network is "tcp6" or "udp6".
//
// The callback returns the error from RawConn.Control if that call itself
// fails, otherwise the first setsockopt error wrapped with the name of the
// option that failed, or nil when every option was applied.
func DialUpstreamControl(sport uint16, protocol Protocol, mark int) func(string, string, syscall.RawConn) error {
	// Linux value of IP_BIND_ADDRESS_NO_PORT, which package syscall does not
	// export as a named constant.
	const ipBindAddressNoPort = 24

	return func(network, _ string, c syscall.RawConn) error {
		var syscallErr error
		err := c.Control(func(fd uintptr) {
			// set applies a single integer socket option. On failure it
			// records the wrapped error in syscallErr and returns false.
			set := func(level, opt, value int, name string) bool {
				if e := syscall.SetsockoptInt(int(fd), level, opt, value); e != nil {
					syscallErr = fmt.Errorf("setsockopt(%s, %d): %w", name, value, e)
					return false
				}
				return true
			}

			if protocol == TCP && !set(syscall.IPPROTO_TCP, syscall.TCP_SYNCNT, 2, "IPPROTO_TCP, TCP_SYNCNT") {
				return
			}

			if !set(syscall.IPPROTO_IP, syscall.IP_TRANSPARENT, 1, "IPPROTO_IP, IP_TRANSPARENT") {
				return
			}

			if !set(syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1, "SOL_SOCKET, SO_REUSEADDR") {
				return
			}

			if sport == 0 && !set(syscall.IPPROTO_IP, ipBindAddressNoPort, 1, "IPPROTO_IP, IP_BIND_ADDRESS_NO_PORT") {
				return
			}

			if mark != 0 && !set(syscall.SOL_SOCKET, syscall.SO_MARK, mark, "SOL_SOCKET, SO_MARK") {
				return
			}

			if network == "tcp6" || network == "udp6" {
				set(syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0, "IPPROTO_IPV6, IPV6_V6ONLY")
			}
		})

		if err != nil {
			return err
		}
		return syscallErr
	}
}