File: vpnconfiguration.py

package info (click to toggle)
python-proton-vpn-api-core 0.39.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 892 kB
  • sloc: python: 6,582; makefile: 8
file content (187 lines) | stat: -rw-r--r-- 6,297 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""
This module defines the classes holding the necessary configuration to establish
a VPN connection.


Copyright (c) 2023 Proton AG

This file is part of Proton VPN.

Proton VPN is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Proton VPN is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with ProtonVPN.  If not, see <https://www.gnu.org/licenses/>.
"""
import ipaddress
import tempfile
import os

from jinja2 import Environment, BaseLoader
from proton.utils.environment import ExecutionEnvironment

from proton.vpn.connection.constants import \
    CA_CERT, OPENVPN_V2_TEMPLATE, WIREGUARD_TEMPLATE


class VPNConfiguration:
    """Base VPN configuration.

    Subclasses set ``PROTOCOL`` and ``EXTENSION`` and implement
    :meth:`generate`. Instances are re-entrant context managers: entering
    writes the rendered configuration to a temporary file in the runtime
    directory and returns its path; leaving the outermost ``with`` deletes it.
    """
    PROTOCOL = None
    EXTENSION = None
    # Prefix of the temporary configuration files created by __enter__.
    # NOTE: we should try to keep filename length below 15 characters,
    # including the prefix.
    _FILENAME_PREFIX = "pvpn"

    def __init__(self, vpnserver, vpncredentials, settings, use_certificate=False):
        self._configfile = None
        # Nesting depth of active ``with`` blocks; None until first entered.
        self._configfile_enter_level = None
        self._vpnserver = vpnserver
        self._vpncredentials = vpncredentials
        self._settings = settings
        self.use_certificate = use_certificate

    @classmethod
    def from_factory(cls, protocol):
        """Returns the configuration class based on the specified protocol.

        Raises:
            KeyError: if ``protocol`` is not "openvpn-tcp", "openvpn-udp"
                or "wireguard".
        """
        protocols = {
            "openvpn-tcp": OpenVPNTCPConfig,
            "openvpn-udp": OpenVPNUDPConfig,
            "wireguard": WireguardConfig,
        }

        return protocols[protocol]

    def __enter__(self):
        """Write the configuration to a temporary file and return its path.

        The file is created on the first (outermost) entry only; nested
        entries reuse it and return the same path.
        """
        # We create the configuration file when we enter,
        # and delete it when we exit.
        # This is a race free way of having temporary files.
        if self._configfile is None:
            self._delete_existing_configuration()
            # Render before touching the filesystem so that a failing
            # generate() neither leaves an orphan file behind nor leaves
            # this object in a half-initialised state.
            content = self.generate()
            configfile = tempfile.NamedTemporaryFile(
                dir=self.__base_path, delete=False,
                prefix=self._FILENAME_PREFIX, suffix=self.EXTENSION, mode='w'
            )
            try:
                with configfile:
                    configfile.write(content)
            except BaseException:
                # Best-effort removal of the partially written file; the
                # original error is the one the caller needs to see.
                try:
                    os.unlink(configfile.name)
                except OSError:
                    pass
                raise
            self._configfile = configfile
            self._configfile_enter_level = 0

        self._configfile_enter_level += 1

        return self._configfile.name

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Delete the configuration file when the outermost ``with`` exits."""
        if self._configfile is None:
            return

        self._configfile_enter_level -= 1
        if self._configfile_enter_level == 0:
            path = self._configfile.name
            # Reset state first so the object stays usable even if the
            # file has already been removed by someone else.
            self._configfile = None
            try:
                os.unlink(path)
            except FileNotFoundError:
                pass

    def _delete_existing_configuration(self):
        """Remove stale configuration files left behind by previous runs.

        Only files in the runtime directory that carry this class's filename
        prefix and ``EXTENSION`` suffix are removed.

        NOTE(review): this also removes files of other instances that are
        currently entered; confirm only one configuration file is in use at
        a time.
        """
        if not self.EXTENSION:
            return

        base_path = self.__base_path
        try:
            filenames = os.listdir(base_path)
        except FileNotFoundError:
            # Nothing can be stale in a directory that does not exist.
            return

        for filename in filenames:
            if not (filename.startswith(self._FILENAME_PREFIX)
                    and filename.endswith(self.EXTENSION)):
                continue
            try:
                os.remove(os.path.join(base_path, filename))
            except FileNotFoundError:
                # Already gone (e.g. removed concurrently): desired outcome.
                pass

    def generate(self) -> str:
        """Generates the configuration file content."""
        raise NotImplementedError

    @property
    def __base_path(self):
        """Directory where temporary configuration files are written."""
        return ExecutionEnvironment().path_runtime

    @staticmethod
    def cidr_to_netmask(cidr) -> str:
        """Returns the dotted IPv4 subnet netmask for a CIDR prefix length.

        Example: ``24`` -> ``"255.255.255.0"``.
        """
        subnet = ipaddress.IPv4Network(f"0.0.0.0/{cidr}")
        return str(subnet.netmask)

    @staticmethod
    def is_valid_ipv4(ip_address) -> bool:
        """Returns True if the specified ip address is a valid IPv4 address,
        and False otherwise (including for valid IPv6 addresses)."""
        try:
            # IPv4Address rejects IPv6 input, unlike ipaddress.ip_address.
            # AddressValueError is a subclass of ValueError.
            ipaddress.IPv4Address(ip_address)
        except ValueError:
            return False

        return True


class OVPNConfig(VPNConfiguration):
    """Configuration shared by the OpenVPN transports.

    Concrete subclasses choose the transport by setting ``PROTOCOL``
    to ``"tcp"`` or ``"udp"``.
    """
    PROTOCOL = None
    EXTENSION = ".ovpn"

    def generate(self) -> str:
        """Render the OpenVPN configuration for the selected server.

        Returns:
            str: the configuration file content.
        """
        server = self._vpnserver
        port_sets = server.openvpn_ports
        # Any PROTOCOL other than "tcp" falls back to the UDP ports.
        if self.PROTOCOL == "tcp":
            ports = port_sets.tcp
        else:
            ports = port_sets.udp

        template_values = {
            # IPv6 is only enabled when the server supports it AND the
            # user's settings enable it.
            "enable_ipv6_support": server.has_ipv6_support and self._settings.ipv6,
            "openvpn_protocol": self.PROTOCOL,
            "serverlist": [server.server_ip],
            "openvpn_ports": ports,
            "ca_certificate": CA_CERT,
            "certificate_based": self.use_certificate,
        }

        if self.use_certificate:
            # Certificate-based connections additionally pass the client
            # certificate and OpenVPN private key to the template.
            template_values["cert"] = \
                self._vpncredentials.pubkey_credentials.certificate_pem
            template_values["priv_key"] = \
                self._vpncredentials.pubkey_credentials.openvpn_private_key

        template =\
            (Environment(loader=BaseLoader, autoescape=True)  # noqa: E501 # pylint: disable=line-too-long # nosemgrep: python.flask.security.xss.audit.direct-use-of-jinja2.direct-use-of-jinja2
                .from_string(OPENVPN_V2_TEMPLATE))

        return template.render(template_values)


class OpenVPNTCPConfig(OVPNConfig):
    """OpenVPN configuration that selects the server's TCP ports."""
    PROTOCOL = "tcp"


class OpenVPNUDPConfig(OVPNConfig):
    """OpenVPN configuration that selects the server's UDP ports."""
    PROTOCOL = "udp"


class WireguardConfig(VPNConfiguration):
    """WireGuard-specific configuration."""
    PROTOCOL = "wireguard"
    EXTENSION = ".conf"

    def generate(self) -> str:
        """Method that generates a WireGuard VPN configuration.

        The template is filled with the client's WireGuard private key taken
        from the credentials' ``pubkey_credentials``, so this configuration
        can only be produced when ``use_certificate`` is set.

        Returns:
            str: the rendered WireGuard configuration.

        Raises:
            RuntimeError: if ``use_certificate`` is False.
        """
        if not self.use_certificate:
            raise RuntimeError("WireGuard expects certificate configuration")

        j2_values = {
            "wg_client_secret_key": self._vpncredentials.pubkey_credentials.wg_private_key,
            "wg_ip": self._vpnserver.server_ip,
            # First advertised WireGuard UDP port; assumes the server lists
            # at least one (IndexError otherwise) — TODO confirm upstream.
            "wg_port": self._vpnserver.wireguard_ports.udp[0],
            "wg_server_pk": self._vpnserver.x25519pk,
        }

        template =\
            (Environment(loader=BaseLoader, autoescape=True)  # noqa: E501 # pylint: disable=line-too-long # nosemgrep: python.flask.security.xss.audit.direct-use-of-jinja2.direct-use-of-jinja2
                .from_string(WIREGUARD_TEMPLATE))
        return template.render(j2_values)