File: noise_protocol.py

package info (click to toggle)
python-noiseprotocol 0.3.1-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,656 kB
  • sloc: python: 1,259; makefile: 25
file content (139 lines) | stat: -rw-r--r-- 5,547 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import warnings
from functools import partial
from typing import Tuple

from noiseprotocol.exceptions import NoiseProtocolNameError, NoisePSKError, NoiseValidationError
from noiseprotocol.state import HandshakeState
from noiseprotocol.constants import MAX_PROTOCOL_NAME_LEN, Empty


class NoiseProtocol(object):
    """
    Configuration and state of a single Noise protocol instance.

    The protocol name is parsed with UnpackedName, and the backend maps its components to the
    concrete pattern, DH, hash, cipher and keypair implementations. The handshake/cipher state
    attributes start out as ``Empty()`` placeholders. initialise_handshake_state() sets
    handshake_state and symmetric_state. The cipher_state_* attributes are presumably filled in by
    the handshake/connection code outside this class (TODO confirm). handshake_done() drops the
    state that is only needed during the handshake.
    """
    def __init__(self, protocol_name: bytes, backend: 'NoiseBackend'):
        """
        :param protocol_name: full protocol name, e.g. b'Noise_XX_25519_ChaChaPoly_SHA256'
        :param backend: object providing map_protocol_name_to_crypto(), hmac() and hkdf()
        :raises NoiseProtocolNameError: if protocol_name cannot be parsed by UnpackedName
        """
        self.name = protocol_name
        self.backend = backend
        unpacked_name = UnpackedName.from_protocol_name(self.name)
        mappings = self.backend.map_protocol_name_to_crypto(unpacked_name)

        # A valid Pattern instance (see Section 7 of specification (rev 32))
        self.pattern = mappings['pattern']()
        self.pattern_modifiers = unpacked_name.pattern_modifiers
        if self.pattern_modifiers:
            self.pattern.apply_pattern_modifiers(self.pattern_modifiers)

        # Handle PSK handshake options. The PSKs themselves are expected to be assigned to
        # self.psks by the caller before validate() is run.
        self.psks = None
        self.is_psk_handshake = any(modifier.startswith('psk') for modifier in self.pattern_modifiers)

        # Preinitialized
        self.dh_fn = mappings['dh']()
        self.hash_fn = mappings['hash']()
        # HMAC and HKDF bound to the hash function selected by the protocol name
        self.hmac = partial(backend.hmac, algorithm=self.hash_fn.fn)
        self.hkdf = partial(backend.hkdf, hmac_hash_fn=self.hmac)

        # Initialized where needed
        self.cipher_class = mappings['cipher']
        self.keypair_class = mappings['keypair']

        self.prologue = None
        self.initiator = None
        self.handshake_hash = None

        self.handshake_state = Empty()
        self.symmetric_state = Empty()
        self.cipher_state_handshake = Empty()
        self.cipher_state_encrypt = Empty()
        self.cipher_state_decrypt = Empty()

        # Local static/ephemeral ('s'/'e') and remote static/ephemeral ('rs'/'re') keypairs
        self.keypairs = {'s': None, 'e': None, 'rs': None, 're': None}

    def handshake_done(self):
        """
        Finalise the handshake: record the handshake hash and release handshake-only state.

        For one-way patterns, the cipher state of the unused direction is set to None: decrypt for
        the initiator, encrypt for the responder. Attributes only needed while the handshake runs
        are then deleted from the instance.
        """
        if self.pattern.one_way:
            if self.initiator:
                self.cipher_state_decrypt = None
            else:
                self.cipher_state_encrypt = None
        self.handshake_hash = self.symmetric_state.get_handshake_hash()
        del self.handshake_state
        del self.symmetric_state
        del self.cipher_state_handshake
        del self.prologue
        del self.initiator
        del self.dh_fn
        del self.hash_fn
        del self.keypair_class

    def validate(self):
        """
        Check that the protocol is fully configured before the handshake starts.

        Emits a warning (but does not fail) if an ephemeral keypair ('e' or 're') was preset.

        :raises NoisePSKError: for a PSK handshake, if no PSKs were set, any PSK is not 32 bytes
            long, or the number of PSKs differs from the pattern's psk_count.
        :raises NoiseValidationError: if the role (initiator/responder) was not set, or a keypair
            required by the pattern for this role is missing.
        """
        if self.is_psk_handshake:
            if self.psks is None:
                # Without this check, iterating None below would raise a bare TypeError.
                raise NoisePSKError('No PSKs provided to this protocol! {} are required'.format(
                    self.pattern.psk_count))
            if any(len(psk) != 32 for psk in self.psks):
                raise NoisePSKError('Invalid psk length! Has to be 32 bytes long')
            if len(self.psks) != self.pattern.psk_count:
                raise NoisePSKError('Bad number of PSKs provided to this protocol! {} are required, '
                                    'given {}'.format(self.pattern.psk_count, len(self.psks)))

        if self.initiator is None:
            raise NoiseValidationError('You need to set role with NoiseConnection.set_as_initiator '
                                       'or NoiseConnection.set_as_responder')

        for keypair in self.pattern.get_required_keypairs(self.initiator):
            if self.keypairs[keypair] is None:
                raise NoiseValidationError('Keypair {} has to be set for chosen handshake pattern'.format(keypair))

        if self.keypairs['e'] is not None or self.keypairs['re'] is not None:
            warnings.warn('One of ephemeral keypairs is already set. '
                          'This is OK for testing, but should NEVER happen in production!')

    def initialise_handshake_state(self):
        """
        Create the HandshakeState for this protocol and keep a reference to its SymmetricState.

        Only options that are actually set are passed to HandshakeState.initialize: the prologue
        (if truthy) and every keypair in self.keypairs that is not None. The None check matches
        the one used in validate().
        """
        kwargs = {'initiator': self.initiator}
        if self.prologue:
            kwargs['prologue'] = self.prologue
        kwargs.update({name: value for name, value in self.keypairs.items() if value is not None})
        self.handshake_state = HandshakeState.initialize(self, **kwargs)
        self.symmetric_state = self.handshake_state.symmetric_state


class UnpackedName:
    """
    The components of a Noise protocol name such as b'Noise_XXpsk3_25519_ChaChaPoly_BLAKE2s'.

    Attributes:
        pattern: handshake pattern name (e.g. 'XX')
        dh: DH function name (third section of the name)
        cipher: cipher function name (fourth section)
        hash: hash function name (fifth section)
        keypair: keypair implementation name; taken from the DH section, same value as dh
        pattern_modifiers: list of modifier strings (e.g. ['psk3']), empty if there are none
    """
    def __init__(self, pattern, dh, cipher, hash, keypair, pattern_modifiers):
        self.pattern = pattern
        self.dh = dh
        self.cipher = cipher
        self.hash = hash
        self.keypair = keypair
        self.pattern_modifiers = pattern_modifiers

    @classmethod
    def from_protocol_name(cls, name):
        """
        Parse a Noise protocol name into its components.

        Expected form: ``Noise_<pattern><modifiers>_<dh>_<cipher>_<hash>``.
        <pattern> is a run of uppercase letters and digits (digits occur in deferred patterns
        such as 'X1N'). <modifiers> is the rest of that section, split on '+' (e.g. 'psk0+psk2').

        :param name: protocol name as bytes
        :return: an UnpackedName instance
        :raises NoiseProtocolNameError: if name is not bytes, is longer than MAX_PROTOCOL_NAME_LEN,
            is not ASCII, does not start with 'Noise', does not have exactly five '_'-separated
            sections, or has an empty handshake pattern name.
        """
        if not isinstance(name, bytes):
            raise NoiseProtocolNameError('Protocol name has to be of type "bytes" not {}'.format(type(name)))
        if len(name) > MAX_PROTOCOL_NAME_LEN:
            raise NoiseProtocolNameError('Protocol name too long, has to be at most '
                                         '{} chars long'.format(MAX_PROTOCOL_NAME_LEN))

        try:
            decoded = name.decode('ascii')
        except UnicodeDecodeError:
            raise NoiseProtocolNameError('Protocol name has to be ASCII! Provided: {}'.format(name)) from None

        sections = decoded.split('_')
        if sections[0] != 'Noise':
            raise NoiseProtocolNameError('Noise Protocol name shall begin with Noise! Provided: {}'.format(name))
        if len(sections) != 5:
            # Previously a short name crashed with IndexError and extra sections were silently dropped.
            raise NoiseProtocolNameError('Noise Protocol name has to consist of 5 sections separated by "_" '
                                         '(Noise, pattern, DH, cipher, hash)! Provided: {}'.format(name))

        # Split the pattern section at the first character that is neither an uppercase letter nor
        # a digit: the prefix is the pattern name, the remainder holds the modifiers.
        pattern_section = sections[1]
        split_at = len(pattern_section)
        for i, char in enumerate(pattern_section):
            if not (char.isupper() or char.isdigit()):
                split_at = i
                break
        pattern = pattern_section[:split_at]
        modifiers_str = pattern_section[split_at:]
        if not pattern:
            raise NoiseProtocolNameError('Noise Protocol name has no handshake pattern! Provided: {}'.format(name))
        modifiers = modifiers_str.split('+') if modifiers_str else []

        return cls(
            pattern=pattern,
            dh=sections[2],
            cipher=sections[3],
            hash=sections[4],
            keypair=sections[2],
            pattern_modifiers=modifiers
        )