File: generate_psa_wrappers.py

package info (click to toggle)
modsecurity 3.0.14-1
  • links: PTS
  • area: main
  • in suites: sid, trixie
  • size: 88,920 kB
  • sloc: ansic: 174,512; sh: 43,569; cpp: 26,214; python: 15,734; makefile: 3,864; yacc: 2,947; lex: 1,359; perl: 1,243; php: 42; tcl: 4
file content (257 lines) | stat: -rwxr-xr-x 10,846 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
#!/usr/bin/env python3
"""Generate wrapper functions for PSA function calls.
"""

# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later

### WARNING: the code in this file has not been extensively reviewed yet.
### We do not think it is harmful, but it may be below our normal standards
### for robustness and maintainability.

import argparse
import itertools
import os
from typing import Iterator, List, Optional, Tuple

import scripts_path #pylint: disable=unused-import
from mbedtls_dev import build_tree
from mbedtls_dev import c_parsing_helper
from mbedtls_dev import c_wrapper_generator
from mbedtls_dev import typing_util


class BufferParameter:
    """A (pointer, size) argument pair that designates a buffer of a PSA function.

    The buffer may be either an input or an output of the function.
    """
    #pylint: disable=too-few-public-methods

    def __init__(self, i: int, is_output: bool,
                 buffer_name: str, size_name: str) -> None:
        """Record the position, direction and argument names of a buffer.

        i is the position, in the function's argument list, of the pointer
        to the buffer. The buffer size is the argument at position i+1.
        If the buffer is a variable-size output, the actual length is
        returned through the argument at position i+2.

        is_output is true for an output buffer and false for an input buffer.

        buffer_name and size_name are the names of the arguments at
        positions i and i+1 respectively. The output length argument is
        not tracked by this class yet.
        """
        self.index = i
        self.buffer_name, self.size_name = buffer_name, size_name
        self.is_output = is_output


class PSAWrapperGenerator(c_wrapper_generator.Base):
    """Produce C source code that wraps calls to PSA Crypto API functions.

    On top of what the base class emits, each wrapper poisons the buffer
    arguments of the wrapped function before the call and unpoisons them
    after the call.
    """

    # Preprocessor condition placed around the generated code: it is used
    # in the `#if` written by the prologue and repeated in the `#endif`
    # comment written by the epilogue.
    _CPP_GUARDS = ('defined(MBEDTLS_PSA_CRYPTO_C) && '
                   'defined(MBEDTLS_TEST_HOOKS) && \\\n    '
                   '!defined(RECORD_PSA_STATUS_COVERAGE_LOG)')
    _WRAPPER_NAME_PREFIX = 'mbedtls_test_wrap_'
    _WRAPPER_NAME_SUFFIX = ''

    def gather_data(self) -> None:
        """Collect the function declarations from the PSA crypto headers.

        The declarations found in include/psa/crypto.h and
        include/psa/crypto_extra.h under the Mbed TLS root are read
        into self.functions.
        """
        psa_include_dir = os.path.join(build_tree.guess_mbedtls_root(),
                                       'include', 'psa')
        for header_name in ('crypto.h', 'crypto_extra.h'):
            c_parsing_helper.read_function_declarations(
                self.functions,
                os.path.join(psa_include_dir, header_name))

    # Functions that are skipped even though they return psa_status_t.
    _SKIP_FUNCTIONS = frozenset((
        'mbedtls_psa_external_get_random', # not a library function
        'psa_get_key_domain_parameters', # client-side function
        'psa_get_key_slot_number', # client-side function
        'psa_key_derivation_verify_bytes', # not implemented yet
        'psa_key_derivation_verify_key', # not implemented yet
        'psa_set_key_domain_parameters', # client-side function
    ))

    def _skip_function(self, function: c_wrapper_generator.FunctionInfo) -> bool:
        """Whether the given function is to be skipped.

        A function is skipped if it does not return psa_status_t or if
        its name is listed in _SKIP_FUNCTIONS.
        """
        return (function.return_type != 'psa_status_t' or
                function.name in self._SKIP_FUNCTIONS)

    # PAKE-related pointer types: not implemented yet
    _PAKE_STUFF = frozenset((
        'psa_crypto_driver_pake_inputs_t *',
        'psa_pake_cipher_suite_t *',
    ))

    def _return_variable_name(self,
                              function: c_wrapper_generator.FunctionInfo) -> str:
        """The name of the variable that holds the wrapped function's return value.

        This is 'status' for functions returning psa_status_t; other
        return types get whatever name the base class chooses.
        """
        if function.return_type != 'psa_status_t':
            return super()._return_variable_name(function)
        return 'status'

    # Per-function preprocessor conditions: the base class's entries plus
    # the PSA-specific ones below (function name -> condition).
    _FUNCTION_GUARDS = (
        c_wrapper_generator.Base._FUNCTION_GUARDS.copy() #pylint: disable=protected-access
    )
    _FUNCTION_GUARDS.update({
        'mbedtls_psa_register_se_key': 'defined(MBEDTLS_PSA_CRYPTO_SE_C)',
        'mbedtls_psa_inject_entropy': 'defined(MBEDTLS_PSA_INJECT_ENTROPY)',
        'mbedtls_psa_external_get_random': 'defined(MBEDTLS_PSA_CRYPTO_EXTERNAL_RNG)',
        'mbedtls_psa_platform_get_builtin_key': 'defined(MBEDTLS_PSA_CRYPTO_BUILTIN_KEYS)',
    })

    @staticmethod
    def _detect_buffer_parameters(arguments: List[c_parsing_helper.ArgumentInfo],
                                  argument_names: List[str]) -> Iterator[BufferParameter]:
        """Find the arguments that form a buffer (pointer, size [,length]).

        Yield one BufferParameter for each argument of type `uint8_t *` or
        `const uint8_t *` that is immediately followed by an argument of
        type `size_t`. A buffer is considered an output unless its pointer
        type is const-qualified.
        """
        # An argument with a suffix (i.e. an array) can never match,
        # so blank out its type.
        scalar_types = [arg.type if not arg.suffix else '' for arg in arguments]
        # Pair each argument type with the type of the next argument.
        # The last argument is paired with the empty string.
        lookahead = itertools.zip_longest(scalar_types, scalar_types[1:],
                                          fillvalue='')
        for pos, (this_type, next_type) in enumerate(lookahead):
            if next_type != 'size_t':
                continue
            if this_type not in ('const uint8_t *', 'uint8_t *'):
                continue
            yield BufferParameter(pos, this_type == 'uint8_t *',
                                  argument_names[pos], argument_names[pos + 1])

    @staticmethod
    def _write_poison_buffer_parameter(out: typing_util.Writable,
                                       param: BufferParameter,
                                       poison: bool) -> None:
        """Write the statement that poisons or unpoisons one buffer.

        The statement invokes MBEDTLS_TEST_MEMORY_POISON if poison is true
        and MBEDTLS_TEST_MEMORY_UNPOISON otherwise, on the buffer and size
        arguments named in param.
        """
        action = 'POISON' if poison else 'UNPOISON'
        statement = '    MBEDTLS_TEST_MEMORY_{}({}, {});\n'.format(
            action, param.buffer_name, param.size_name
        )
        out.write(statement)

    def _write_poison_buffer_parameters(self, out: typing_util.Writable,
                                        buffer_parameters: List[BufferParameter],
                                        poison: bool) -> None:
        """Write the code that poisons or unpoisons all the given buffers.

        The code poisons the buffers if poison is true and unpoisons them
        otherwise. It is compiled out when
        MBEDTLS_PSA_ASSUME_EXCLUSIVE_BUFFERS is defined. Nothing is
        written if there are no buffers.
        """
        if buffer_parameters:
            out.write('#if !defined(MBEDTLS_PSA_ASSUME_EXCLUSIVE_BUFFERS)\n')
            for buffer_parameter in buffer_parameters:
                self._write_poison_buffer_parameter(out, buffer_parameter, poison)
            out.write('#endif /* !defined(MBEDTLS_PSA_ASSUME_EXCLUSIVE_BUFFERS) */\n')

    @staticmethod
    def _parameter_should_be_copied(function_name: str,
                                    _buffer_name: Optional[str]) -> bool:
        """Whether a buffer argument of the given PSA function should be copied.

        The decision currently depends only on the function name:
        _buffer_name is ignored.
        """
        # These functions have arguments that look like buffers but
        # do not need buffer copying (false positives).
        return function_name not in ('mbedtls_psa_inject_entropy',
                                     'psa_crypto_driver_pake_get_password',
                                     'psa_crypto_driver_pake_get_user',
                                     'psa_crypto_driver_pake_get_peer')

    def _write_function_call(self, out: typing_util.Writable,
                             function: c_wrapper_generator.FunctionInfo,
                             argument_names: List[str]) -> None:
        """Write the function call surrounded by buffer poisoning/unpoisoning.

        The buffers concerned are the detected buffer arguments for which
        _parameter_should_be_copied() returns true. The call itself is
        written by the base class.
        """
        detected = self._detect_buffer_parameters(function.arguments,
                                                  argument_names)
        buffer_parameters = [
            param for param in detected
            if self._parameter_should_be_copied(
                function.name, function.arguments[param.index].name)
        ]
        self._write_poison_buffer_parameters(out, buffer_parameters, True)
        super()._write_function_call(out, function, argument_names)
        self._write_poison_buffer_parameters(out, buffer_parameters, False)

    def _write_prologue(self, out: typing_util.Writable, header: bool) -> None:
        """Write the base class prologue, then open the guard and include headers."""
        super()._write_prologue(out, header)
        guarded_includes = """
#if {}

#include <psa/crypto.h>

#include <test/memory.h>
#include <test/psa_crypto_helpers.h>
#include <test/psa_test_wrappers.h>
"""
        out.write(guarded_includes.format(self._CPP_GUARDS))

    def _write_epilogue(self, out: typing_util.Writable, header: bool) -> None:
        """Close the guard opened by the prologue, then write the base class epilogue."""
        guard_end = """
#endif /* {} */
"""
        out.write(guard_end.format(self._CPP_GUARDS))
        super()._write_epilogue(out, header)


class PSALoggingWrapperGenerator(PSAWrapperGenerator, c_wrapper_generator.Logging):
    """Produce C source code for PSA Crypto API wrappers that log the calls."""

    def __init__(self, stream: str) -> None:
        """Set up the generator; stream is handed over to set_stream()."""
        super().__init__()
        self.set_stream(stream)

    # Types to cast PSA-specific types to when printing them, in addition
    # to the casts known to the base logging class.
    _PRINTF_TYPE_CAST = c_wrapper_generator.Logging._PRINTF_TYPE_CAST.copy()
    _PRINTF_TYPE_CAST.update({
        'mbedtls_svc_key_id_t': 'unsigned',
        'psa_algorithm_t': 'unsigned',
        'psa_drv_slot_number_t': 'unsigned long long',
        'psa_key_derivation_step_t': 'int',
        'psa_key_id_t': 'unsigned',
        'psa_key_slot_number_t': 'unsigned long long',
        'psa_key_lifetime_t': 'unsigned',
        'psa_key_type_t': 'unsigned',
        'psa_key_usage_flags_t': 'unsigned',
        'psa_pake_role_t': 'int',
        'psa_pake_step_t': 'int',
        'psa_status_t': 'int',
    })

    def _printf_parameters(self, typ: str, var: str) -> Tuple[str, List[str]]:
        """Return the printf format and arguments used to log var, of type typ.

        A leading `const ` qualifier in typ is ignored. Byte buffers,
        pointers to operation objects and the PAKE types listed in
        _PAKE_STUFF are not logged: an empty format and no arguments are
        returned for them. Key attributes are logged field by field
        through their psa_get_key_xxx() accessors. Any other type is left
        to the base class.
        """
        qualifier = 'const '
        if typ.startswith(qualifier):
            typ = typ[len(qualifier):]
        not_logged = (typ == 'uint8_t *' or # skip buffers
                      typ.endswith('operation_t *') or
                      typ in self._PAKE_STUFF)
        if not_logged:
            return '', []
        if typ != 'psa_key_attributes_t *':
            return super()._printf_parameters(typ, var)
        attribute_fields = ['id', 'lifetime', 'type', 'bits', 'algorithm', 'usage_flags']
        fmt = var + '={id=%u, lifetime=0x%08x, type=0x%08x, bits=%u, alg=%08x, usage=%08x}'
        values = ['(unsigned) psa_get_key_{}({})'.format(field, var)
                  for field in attribute_fields]
        return fmt, values


# Output file paths used when --output-c / --output-h are not passed on
# the command line.
DEFAULT_C_OUTPUT_FILE_NAME = 'tests/src/psa_test_wrappers.c'
DEFAULT_H_OUTPUT_FILE_NAME = 'tests/include/test/psa_test_wrappers.h'

def main() -> None:
    """Command line entry point.

    Parse the command line options, gather the function declarations and
    write the .h file, then the .c file. An output file whose path option
    is empty is not written. If --log is given, the generated wrappers
    include logging code.
    """
    c_help = ('Output .c file path (default: {}; skip .c output if empty)'
              .format(DEFAULT_C_OUTPUT_FILE_NAME))
    h_help = ('Output .h file path (default: {}; skip .h output if empty)'
              .format(DEFAULT_H_OUTPUT_FILE_NAME))
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--log',
                        help='Stream to log to (default: no logging code)')
    parser.add_argument('--output-c',
                        metavar='FILENAME',
                        default=DEFAULT_C_OUTPUT_FILE_NAME,
                        help=c_help)
    parser.add_argument('--output-h',
                        metavar='FILENAME',
                        default=DEFAULT_H_OUTPUT_FILE_NAME,
                        help=h_help)
    options = parser.parse_args()

    if not options.log:
        generator = PSAWrapperGenerator() #type: PSAWrapperGenerator
    else:
        generator = PSALoggingWrapperGenerator(options.log)
    generator.gather_data()

    # The header is written before the C file; an empty path disables an output.
    outputs = (
        (options.output_h, generator.write_h_file),
        (options.output_c, generator.write_c_file),
    )
    for output_path, write_file in outputs:
        if output_path:
            write_file(output_path)

# Generate the wrappers only when this file is run as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()