File: pkctx.hpp

package info (click to toggle)
openvpn3-client 25+dfsg-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 19,276 kB
  • sloc: cpp: 190,085; python: 7,218; ansic: 1,866; sh: 1,361; java: 402; lisp: 81; makefile: 17
file content (167 lines) | stat: -rw-r--r-- 4,486 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
//    OpenVPN -- An application to securely tunnel IP networks
//               over a single port, with support for SSL/TLS-based
//               session authentication and key exchange,
//               packet encryption, packet authentication, and
//               packet compression.
//
//    Copyright (C) 2012- OpenVPN Inc.
//
//    SPDX-License-Identifier: MPL-2.0 OR AGPL-3.0-only WITH openvpn3-openssl-exception
//

// Wrap an mbed TLS pk_context object.

#ifndef OPENVPN_MBEDTLS_PKI_PKCTX_H
#define OPENVPN_MBEDTLS_PKI_PKCTX_H

#include <string>
#include <sstream>
#include <cstring>

#include <mbedtls/pk.h>

#include <openvpn/common/size.hpp>
#include <openvpn/common/exception.hpp>
#include <openvpn/common/rc.hpp>
#include <openvpn/mbedtls/util/error.hpp>
#include <openvpn/mbedtls/util/rand.hpp>

namespace openvpn::MbedTLSPKI {

/**
 * Reference-counted owner of a heap-allocated mbedtls_pk_context.
 *
 * The context is created lazily by parse() or epki_enable() and released
 * by the destructor. An object for which defined() returns false holds no
 * context at all.
 */
class PKContext : public RC<thread_unsafe_refcount>
{
  public:
    typedef RCPtr<PKContext> Ptr;

    /**
     * @brief Construct an empty wrapper; no mbedtls_pk_context is allocated
     *        until parse() or epki_enable() is called.
     */
    PKContext() = default;

    /**
     * @brief Construct and immediately parse a private key.
     * @param key_txt       Private key text; its NUL terminator is passed to
     *                      mbed TLS together with the text (see parse()).
     * @param title         Label inserted into the error message on failure.
     * @param priv_key_pwd  Password for an encrypted key (may be empty).
     * @param rand          RNG source; only forwarded with mbed TLS >= 3.0.
     * @throws MbedTLSException if the key cannot be parsed.
     */
    PKContext(const std::string &key_txt, const std::string &title, const std::string &priv_key_pwd, MbedTLSRandom &rand)
    {
        try
        {
            parse(key_txt, title, priv_key_pwd, rand);
        }
        catch (...)
        {
            // parse() already releases the context on failure; this is a
            // defensive guard so a failed construction can never leak it.
            dealloc();
            throw;
        }
    }

    // The object owns a raw heap pointer: a copy would free it twice.
    PKContext(const PKContext &) = delete;
    PKContext &operator=(const PKContext &) = delete;

    /**
     * @return true if a context is currently allocated.
     */
    bool defined() const
    {
        return ctx != nullptr;
    }

    /**
     * @brief Map the mbed TLS key type onto the portable PKType enumeration.
     * @return PK_RSA for the RSA variants, PK_EC / PK_ECDSA for EC keys,
     *         PK_NONE for MBEDTLS_PK_NONE, PK_UNKNOWN for anything else.
     */
    PKType::Type key_type() const
    {
        switch (mbedtls_pk_get_type(ctx))
        {
        case MBEDTLS_PK_RSA:
        case MBEDTLS_PK_RSA_ALT:
        case MBEDTLS_PK_RSASSA_PSS:
            return PKType::PK_RSA;
        case MBEDTLS_PK_ECKEY:
        case MBEDTLS_PK_ECKEY_DH:
            return PKType::PK_EC;
        case MBEDTLS_PK_ECDSA:
            return PKType::PK_ECDSA;
        case MBEDTLS_PK_NONE:
            return PKType::PK_NONE;
        default:
            return PKType::PK_UNKNOWN;
        }
    }

    /**
     * @return the key size in bits, as reported by mbedtls_pk_get_bitlen().
     */
    size_t key_length() const
    {
        return mbedtls_pk_get_bitlen(ctx);
    }

    /**
     * @brief Load a private key, replacing any key previously held.
     * @param key_txt       Private key text.
     * @param title         Label inserted into the error message on failure.
     * @param priv_key_pwd  Password for an encrypted key (may be empty).
     * @param rand          RNG source; only forwarded with mbed TLS >= 3.0.
     * @throws MbedTLSException on parse failure. In that case the object is
     *         left undefined (defined() returns false).
     */
    void parse(const std::string &key_txt, const std::string &title, const std::string &priv_key_pwd, MbedTLSRandom &rand)
    {
        // mbedtls_pk_parse_key() expects a context that is initialized but
        // not yet set up, so start from a fresh one on every call.
        reset();
        // key_txt.length() is increased by 1 as it does not include the NULL-terminator
        // which mbedtls_pk_parse_key() expects to see.
        const int status = mbedtls_pk_parse_key(ctx,
                                                reinterpret_cast<const unsigned char *>(key_txt.c_str()),
                                                key_txt.length() + 1,
                                                reinterpret_cast<const unsigned char *>(priv_key_pwd.c_str()),
                                                priv_key_pwd.length()
// mbed TLS 3.0.0 (0x03000000) itself introduced the f_rng/p_rng parameters,
// so the comparison must be inclusive.
#if MBEDTLS_VERSION_NUMBER >= 0x03000000
                                                    ,
                                                mbedtls_ctr_drbg_random,
                                                rand.get_ctr_drbg_ctx()
#endif
        );
        if (status < 0)
        {
            // Don't keep an empty context around: defined() should only be
            // true when a key was actually loaded.
            dealloc();
            throw MbedTLSException("error parsing " + title + " private key", status);
        }
    }

    /**
     * @brief Serialize the held private key as PEM.
     * @return the PEM text.
     * @throws MbedTLSException if mbed TLS cannot write the key into the buffer.
     */
    std::string extract() const
    {
        // maximum size of the PEM data is not available at this point
        BufferAllocated buff(16000, 0);

        const int ret = mbedtls_pk_write_key_pem(ctx, buff.data(), buff.max_size());
        if (ret < 0)
            throw MbedTLSException("extract priv_key: can't write to buffer", ret);

        // The PEM output is read back as a NUL-terminated C string.
        return std::string(reinterpret_cast<const char *>(buff.data()));
    }

    /**
     * @return the private key in PEM form (same as extract()).
     */
    std::string render_pem() const
    {
        return extract();
    }

    /**
     * @brief Turn this object into an RSA-alt (external PKI) key backed by
     *        the given callbacks, replacing any key previously held.
     * @param arg           Opaque pointer handed to the callbacks by mbed TLS.
     * @param epki_decrypt  Decryption callback.
     * @param epki_sign     Signing callback.
     * @param epki_key_len  Key-length callback.
     * @throws MbedTLSException if mbedtls_pk_setup_rsa_alt() fails. In that
     *         case the object is left undefined.
     */
    void epki_enable(void *arg,
                     mbedtls_pk_rsa_alt_decrypt_func epki_decrypt,
                     mbedtls_pk_rsa_alt_sign_func epki_sign,
                     mbedtls_pk_rsa_alt_key_len_func epki_key_len)
    {
        // As with parsing, the setup call needs a context that is not set up yet.
        reset();
        const int status = mbedtls_pk_setup_rsa_alt(ctx, arg, epki_decrypt, epki_sign, epki_key_len);
        if (status < 0)
        {
            dealloc();
            throw MbedTLSException("error in mbedtls_pk_setup_rsa_alt", status);
        }
    }

    /**
     * @return the raw context pointer (nullptr when !defined()); ownership
     *         stays with this object.
     */
    mbedtls_pk_context *get() const
    {
        return ctx;
    }

    ~PKContext()
    {
        dealloc();
    }

  private:
    // Allocate and mbedtls_pk_init() a context if none exists yet.
    void alloc()
    {
        if (!ctx)
        {
            ctx = new mbedtls_pk_context;
            mbedtls_pk_init(ctx);
        }
    }

    // mbedtls_pk_free() and delete the context, leaving the object undefined.
    void dealloc()
    {
        if (ctx)
        {
            mbedtls_pk_free(ctx);
            delete ctx;
            ctx = nullptr;
        }
    }

    // Replace any existing context with a freshly initialized one.
    void reset()
    {
        dealloc();
        alloc();
    }

    mbedtls_pk_context *ctx = nullptr;
};

} // namespace openvpn::MbedTLSPKI
#endif