File: common.py

package info (click to toggle)
python-jwcrypto 1.1.0-1%2Bdeb12u1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 552 kB
  • sloc: python: 5,386; makefile: 177
file content (190 lines) | stat: -rw-r--r-- 5,721 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
# Copyright (C) 2015 JWCrypto Project Contributors - see LICENSE file

import copy
import json
from base64 import urlsafe_b64decode, urlsafe_b64encode
from collections import namedtuple
from collections.abc import MutableMapping

# Padding stripping versions as described in
# RFC 7515 Appendix C


def base64url_encode(payload):
    """Encode *payload* as base64url text with the '=' padding stripped.

    :param payload: bytes, or a str that is first encoded as UTF-8.
    :returns: str using the URL-safe base64 alphabet, without trailing '='.
    """
    raw = payload if isinstance(payload, bytes) else payload.encode('utf-8')
    return urlsafe_b64encode(raw).decode('utf-8').rstrip('=')


def base64url_decode(payload):
    """Decode unpadded base64url text produced by ``base64url_encode``.

    The '=' padding stripped during encoding is restored before decoding.

    :param payload: str (or, for symmetry with ``base64url_encode``, bytes
        which are first decoded as UTF-8) in the URL-safe base64 alphabet.
    :returns: the decoded bytes.
    :raises ValueError: if the length can never be valid base64
        (``len % 4 == 1``), or if the underlying decoder rejects the data.
    """
    if isinstance(payload, bytes):
        # Previously bytes input crashed on the str concatenation below.
        payload = payload.decode('utf-8')
    remainder = len(payload) % 4
    if remainder == 1:
        # One leftover character carries only 6 bits: never valid base64.
        raise ValueError('Invalid base64 string')
    if remainder:
        # Remainder 2 needs '==', remainder 3 needs '='.
        payload += '=' * (4 - remainder)
    # NOTE: urlsafe_b64decode does not validate its input; characters
    # outside the base64url alphabet are discarded rather than rejected.
    return urlsafe_b64decode(payload.encode('utf-8'))


# JSON encoding/decoding helpers with good defaults

def json_encode(string):
    """Serialize *string* (any JSON-compatible value) to compact JSON.

    Output uses no whitespace between tokens and sorted object keys. A
    bytes argument is decoded as UTF-8 and serialized as a JSON string.
    """
    value = string.decode('utf-8') if isinstance(string, bytes) else string
    return json.dumps(value, sort_keys=True, separators=(',', ':'))


def json_decode(string):
    """Parse JSON text from a str, or from bytes decoded as UTF-8."""
    text = string if not isinstance(string, bytes) else string.decode('utf-8')
    return json.loads(text)


class JWException(Exception):
    """Common base class for the exceptions defined in this module."""


class InvalidJWAAlgorithm(JWException):
    """Invalid JWA Algorithm name.

    The message is 'Invalid JWA Algorithm name'; when *message* is given
    (and truthy) it is appended in parentheses as extra detail.
    """

    def __init__(self, message=None):
        detail = ' (%s)' % message if message else ''
        super().__init__('Invalid JWA Algorithm name' + detail)


class InvalidCEKeyLength(JWException):
    """Invalid CEK Key Length.

    Raised when a Content Encryption Key does not match the required
    length. *expected* and *obtained* are formatted as integers.
    """

    def __init__(self, expected, obtained):
        template = 'Expected key of length %d bits, got %d'
        super().__init__(template % (expected, obtained))


class InvalidJWEOperation(JWException):
    """Invalid JWE Operation.

    This exception is raised when a requested operation cannot
    be executed due to unsatisfied conditions.

    :param message: Description of the failure. When empty or omitted,
        'Unknown Operation Failure' is used instead.
    :param exception: Optional underlying exception; when given, its
        repr() is appended to the message inside braces.
    """

    def __init__(self, message=None, exception=None):
        msg = message if message else 'Unknown Operation Failure'
        if exception:
            msg += ' {%s}' % repr(exception)
        super(InvalidJWEOperation, self).__init__(msg)


class InvalidJWEKeyType(JWException):
    """Invalid JWE Key Type.

    Raised when the provided JWK Key does not match the type required by
    the specified algorithm; both types are rendered with ``%s``.
    """

    def __init__(self, expected, obtained):
        text = 'Expected key type %s, got %s' % (expected, obtained)
        super().__init__(text)


class InvalidJWEKeyLength(JWException):
    """Invalid JWE Key Length.

    Raised when the provided JWK Key does not match the length required
    by the specified algorithm; both lengths are formatted as integers.
    """

    def __init__(self, expected, obtained):
        lengths = (expected, obtained)
        super().__init__('Expected key of length %d, got %d' % lengths)


class InvalidJWSERegOperation(JWException):
    """Invalid JWSE Header Registry Operation.

    Raised when there is an error in trying to add (or remove) a JW
    Signature or Encryption header in the Registry.

    :param message: Failure description; a falsy value is replaced by
        'Unknown Operation Failure'.
    :param exception: Optional cause whose repr() is appended in braces.
    """

    def __init__(self, message=None, exception=None):
        msg = message or 'Unknown Operation Failure'
        if exception:
            msg += ' {%s}' % repr(exception)
        super().__init__(msg)


# JWSE Header Registry definitions

# RFC 7515 - 9.1: JSON Web Signature and Encryption Header Parameters Registry
# HeaderParameters are for both JWS and JWE
# Fields: description - human-readable text for the header;
#         mustprotect - entry is guarded by JWSEHeaderRegistry (cannot be
#                       deleted, only replaced by another mustprotect entry);
#         supported   - entry is guarded (cannot be deleted or replaced);
#         check_fn    - optional validator called by
#                       JWSEHeaderRegistry.check_header, or None.
JWSEHeaderParameter = namedtuple(
    'Parameter', ['description', 'mustprotect', 'supported', 'check_fn'])


class JWSEHeaderRegistry(MutableMapping):
    """Mutable mapping of header names to JWSEHeaderParameter entries.

    Entries flagged ``supported`` or ``mustprotect`` are guarded: they
    cannot be deleted, a ``supported`` entry cannot be replaced, and a
    ``mustprotect`` entry can only be replaced by one that is also
    ``mustprotect``.
    """

    def __init__(self, init_registry=None):
        """Create a registry, optionally seeded from a dict.

        :param init_registry: dict of header name -> parameter entry. It is
            deep-copied so later changes to the caller's dict do not leak
            into this registry. A falsy value yields an empty registry.
        :raises InvalidJWSERegOperation: if a truthy non-dict is given.
        """
        if init_registry:
            if not isinstance(init_registry, dict):
                raise InvalidJWSERegOperation('Unknown input type')
            self._registry = copy.deepcopy(init_registry)
        else:
            self._registry = {}

        MutableMapping.__init__(self)

    def check_header(self, h, value):
        """Validate *value* for header *h* with the entry's ``check_fn``.

        :returns: True when the entry has no ``check_fn``; otherwise
            whatever ``check_fn(value)`` returns.
        :raises InvalidJWSERegOperation: if *h* is not registered.
        """
        if h not in self._registry:
            raise InvalidJWSERegOperation('No header "%s" found in registry'
                                          % h)

        param = self._registry[h]
        if param.check_fn is None:
            return True
        return param.check_fn(value)

    def __getitem__(self, key):
        return self._registry[key]

    def __iter__(self):
        return iter(self._registry)

    def __delitem__(self, key):
        """Remove an unguarded entry.

        :raises KeyError: if *key* is not registered.
        :raises InvalidJWSERegOperation: if the entry is ``mustprotect``
            or ``supported``.
        """
        param = self._registry[key]
        if param.mustprotect or param.supported:
            raise InvalidJWSERegOperation('Unable to delete protected or '
                                          'supported field')
        del self._registry[key]

    def __setitem__(self, h, jwse_header_param):
        """Add or replace the entry for header *h*.

        :raises InvalidJWSERegOperation: if an existing entry for *h* is
            ``supported``, or is ``mustprotect`` while the new entry is not.
        """
        if h in self._registry:
            p = self._registry[h]
            if p.supported:
                raise InvalidJWSERegOperation('Supported header already exists'
                                              ' in registry')
            elif p.mustprotect and not jwse_header_param.mustprotect:
                # The trailing space was previously missing, producing
                # "...should bea protected header".
                raise InvalidJWSERegOperation('Header specified should be '
                                              'a protected header')
            else:
                # Deleting first moves the re-added key to the end of the
                # iteration order, as before.
                del self._registry[h]

        self._registry[h] = jwse_header_param

    def __len__(self):
        return len(self._registry)