File: ssl_example.py

package info (click to toggle)
python-nss 0.15.0-1
  • links: PTS, VCS
  • area: main
  • in suites: jessie, jessie-kfreebsd
  • size: 1,452 kB
  • ctags: 1,755
  • sloc: ansic: 27,607; python: 2,688; makefile: 2
file content (428 lines) | stat: -rwxr-xr-x 15,198 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
#!/usr/bin/python

# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

import warnings
warnings.simplefilter( "always", DeprecationWarning)

import argparse
import getpass
import os
import sys

from nss.error import NSPRError
import nss.io as io
import nss.nss as nss
import nss.ssl as ssl

# -----------------------------------------------------------------------------
# Client-certificate policies selected by the --request-cert-* and
# --require-cert-* options and consumed by Server().  The values are ordered:
# Server() asks for a client certificate whenever the action is
# >= REQUEST_CLIENT_CERT_ONCE.
(NO_CLIENT_CERT,
 REQUEST_CLIENT_CERT_ONCE,
 REQUIRE_CLIENT_CERT_ONCE,
 REQUEST_CLIENT_CERT_ALWAYS,
 REQUIRE_CLIENT_CERT_ALWAYS) = range(5)

# Seconds the client allows each connection attempt (see Client()).
timeout_secs = 3

# -----------------------------------------------------------------------------
# Utility Functions
# -----------------------------------------------------------------------------


# -----------------------------------------------------------------------------
# Callback Functions
# -----------------------------------------------------------------------------

def password_callback(slot, retry, password):
    """Supply the NSS database password when NSS asks for one.

    Registered with ``nss.set_password_callback``.  *password* is the value
    the caller attached to the NSS operation (``options.password`` in this
    script).

    It is returned only on the first attempt.  When *retry* is true, NSS is
    calling back because the previously returned password was rejected.
    Handing back the same string again would just fail again, so in that case
    (or when no password was supplied) the user is prompted on the terminal.

    :param slot: the PK11 slot that needs authentication (unused)
    :param retry: True if a password previously returned was rejected
    :param password: password supplied by the caller; may be None or empty
    :returns: the password to try
    """
    if password and not retry:
        return password
    return getpass.getpass("Enter password: ")

def handshake_callback(sock):
    """Announce that the SSL handshake on *sock* has completed.

    Installed with ``set_handshake_callback`` on SSL sockets; it prints the
    peer's network address as reported by ``sock.get_peer_name()``.
    """
    peer = sock.get_peer_name()
    print("handshake complete, peer = %s" % (peer))

def auth_certificate_callback(sock, check_sig, is_server, certdb):
    """Decide whether the peer's certificate is acceptable.

    Installed with ``set_auth_certificate_callback(auth_certificate_callback,
    certdb)`` on both the client and the server socket.

    The certificate is verified for the usage that matches our role.  On the
    client side, the server certificate must also match the hostname set on
    the socket.

    :param sock: the SSL socket whose peer certificate is being checked
    :param check_sig: forwarded to ``Certificate.verify_now``
    :param is_server: True when we are the server (the peer is a client),
        False when we are the client (the peer is a server)
    :param certdb: certificate database used for verification
    :returns: True if the certificate is acceptable, False otherwise
    """
    print("auth_certificate_callback: check_sig=%s is_server=%s" % (check_sig, is_server))

    cert = sock.get_peer_certificate()
    pin_args = sock.get_pkcs11_pin_arg()
    if pin_args is None:
        pin_args = ()

    print("peer cert:\n%s" % cert)

    # Define how the cert is being used based upon the is_server flag.  This may
    # seem backwards, but isn't. If we're a server we're trying to validate a
    # client cert. If we're a client we're trying to validate a server cert.
    if is_server:
        intended_usage = nss.certificateUsageSSLClient
    else:
        intended_usage = nss.certificateUsageSSLServer

    try:
        # If the cert fails validation it will raise an exception, the errno attribute
        # will be set to the error code matching the reason why the validation failed
        # and the strerror attribute will contain a string describing the reason.
        approved_usage = cert.verify_now(certdb, check_sig, intended_usage, *pin_args)
    except Exception as e:
        # Plain Python exceptions have no strerror; fall back to the exception
        # itself rather than raising AttributeError out of the callback.
        print(getattr(e, 'strerror', None) or e)
        cert_is_valid = False
        print("Returning cert_is_valid = %s" % cert_is_valid)
        return cert_is_valid

    print("approved_usage = %s" % ', '.join(nss.cert_usage_flags(approved_usage)))

    # Acceptable only if the intended usage bit is among the approved usages.
    cert_is_valid = bool(approved_usage & intended_usage)

    # If this is a server (or the usage check already failed), we're finished
    if is_server or not cert_is_valid:
        print("Returning cert_is_valid = %s" % cert_is_valid)
        return cert_is_valid

    # Certificate is OK.  Since this is the client side of an SSL
    # connection, we need to verify that the name field in the cert
    # matches the desired hostname.  This is our defense against
    # man-in-the-middle attacks.
    hostname = sock.get_hostname()
    print("verifying socket hostname (%s) matches cert subject (%s)" % (hostname, cert.subject))
    try:
        # If the cert fails validation it will raise an exception
        cert_is_valid = cert.verify_hostname(hostname)
    except Exception as e:
        print(getattr(e, 'strerror', None) or e)
        cert_is_valid = False
        print("Returning cert_is_valid = %s" % cert_is_valid)
        return cert_is_valid

    print("Returning cert_is_valid = %s" % cert_is_valid)
    return cert_is_valid

def client_auth_data_callback(ca_names, chosen_nickname, password, certdb):
    """Choose the client certificate and private key to send to the server.

    Installed with ``set_client_auth_data_callback`` on the client socket,
    which supplies *chosen_nickname*, *password* and *certdb* as the extra
    callback arguments.

    :param ca_names: CA names acceptable to the server (passed to
        ``Certificate.has_signer_in_ca_names``)
    :param chosen_nickname: nickname of the certificate to present; when
        empty, the user certificates in *certdb* are searched instead
    :param password: database password passed to the NSS lookup calls
    :param certdb: certificate database searched when no nickname is given
    :returns: a ``(certificate, private_key)`` tuple, or False if no usable
        certificate/key pair was found
    """
    if chosen_nickname:
        try:
            cert = nss.find_cert_from_nickname(chosen_nickname, password)
            priv_key = nss.find_key_by_any_cert(cert, password)
            print("client cert:\n%s" % cert)
            return cert, priv_key
        except NSPRError as e:
            print(e)
            return False

    # No nickname configured: use the first user certificate that passes
    # check_valid_times() and has a signer among the server's ca_names.
    # SEC_CERT_NICKNAMES_USER is a constant of the nss module.
    nicknames = nss.get_cert_nicknames(certdb, nss.SEC_CERT_NICKNAMES_USER)
    for nickname in nicknames:
        try:
            cert = nss.find_cert_from_nickname(nickname, password)
            print("client cert:\n%s" % cert)
            if cert.check_valid_times() and cert.has_signer_in_ca_names(ca_names):
                priv_key = nss.find_key_by_any_cert(cert, password)
                return cert, priv_key
        except NSPRError as e:
            print(e)
    return False

# -----------------------------------------------------------------------------
# Client Implementation
# -----------------------------------------------------------------------------

def Client():
    """Run the example client: connect to the server and exchange one record.

    Resolves ``options.hostname`` and tries each returned address on
    ``options.port`` until one connects.  Only addresses of
    ``options.family`` are tried, unless it is ``io.PR_AF_UNSPEC``.  Each
    attempt waits at most ``timeout_secs``.

    With ``options.use_ssl`` the socket is an SSL client socket.  It offers a
    client certificate through client_auth_data_callback and checks the
    server certificate through auth_certificate_callback.

    Protocol: send "Hello" plus a newline, read one newline-terminated reply,
    and shut the connection down if the reply is "Goodbye".
    """
    # Get the IP Address of our server
    try:
        addr_info = io.AddrInfo(options.hostname)
    except Exception:
        print("could not resolve host address \"%s\"" % options.hostname)
        return

    for net_addr in addr_info:
        if options.family != io.PR_AF_UNSPEC and net_addr.family != options.family:
            continue
        net_addr.port = options.port

        if options.use_ssl:
            sock = ssl.SSLSocket(net_addr.family)

            # Set client SSL socket options
            sock.set_ssl_option(ssl.SSL_SECURITY, True)
            sock.set_ssl_option(ssl.SSL_HANDSHAKE_AS_CLIENT, True)
            # auth_certificate_callback checks the server cert against this name
            sock.set_hostname(options.hostname)

            # Provide a callback which notifies us when the SSL handshake is complete
            sock.set_handshake_callback(handshake_callback)

            # Provide a callback to supply our client certificate info
            sock.set_client_auth_data_callback(client_auth_data_callback, options.client_nickname,
                                               options.password, nss.get_default_certdb())

            # Provide a callback to verify the servers certificate
            sock.set_auth_certificate_callback(auth_certificate_callback,
                                               nss.get_default_certdb())
        else:
            sock = io.Socket(net_addr.family)

        try:
            print("client trying connection to: %s" % (net_addr))
            sock.connect(net_addr, timeout=io.seconds_to_interval(timeout_secs))
        except Exception as e:
            sock.close()
            print("client connection to: %s failed (%s)" % (net_addr, e))
            continue
        print("client connected to: %s" % (net_addr))
        break
    else:
        # The loop never hit `break`: no address matched the family or connected.
        print("Could not establish valid address for \"%s\" in family %s" %
              (options.hostname, io.addr_family_name(options.family)))
        return

    # Talk to the server
    try:
        sock.send('Hello' + '\n') # newline is protocol record separator
        buf = sock.readline()
        if not buf:
            print("client lost connection")
            sock.close()
            return
        buf = buf.rstrip()        # remove newline record separator
        print("client received: %s" % (buf))
    except Exception as e:
        # Plain Python exceptions have no strerror; fall back to the exception.
        print(getattr(e, 'strerror', None) or e)
        try:
            sock.close()
        except Exception:
            pass
        return

    # End of (simple) protocol session?
    if buf == 'Goodbye':
        try:
            sock.shutdown()
        except Exception:
            pass

    # Close the socket and clear the SSL session cache independently, so a
    # failing close() does not skip the cache cleanup.
    try:
        sock.close()
    except Exception as e:
        print(e)
    if options.use_ssl:
        try:
            ssl.clear_session_cache()
        except Exception as e:
            print(e)

# -----------------------------------------------------------------------------
# Server Implementation
# -----------------------------------------------------------------------------

def Server():
    """Run the example server: accept one client, answer it, and exit.

    Listens on the wildcard ("any") address of ``options.family`` at
    ``options.port``; IPv4 is used when the family is ``io.PR_AF_UNSPEC``.

    With ``options.use_ssl`` the listener is an SSL server socket using the
    certificate and key named by ``options.server_nickname``.  Depending on
    ``options.client_cert_action``, it also requests and/or requires a client
    certificate, which auth_certificate_callback then verifies.

    Protocol: read one newline-terminated record from the client, reply
    "Goodbye", then close the client and listening sockets.
    """
    # Setup an IP Address to listen on any of our interfaces
    if options.family == io.PR_AF_UNSPEC:
        options.family = io.PR_AF_INET
    net_addr = io.NetworkAddress(io.PR_IpAddrAny, options.port, options.family)

    if options.use_ssl:
        # Perform basic SSL server configuration.
        # NOTE(review): SSL_RSA_WITH_NULL_MD5 is a NULL-encryption cipher suite
        # (no confidentiality); it should only ever be enabled for testing.
        ssl.set_default_cipher_pref(ssl.SSL_RSA_WITH_NULL_MD5, True)
        ssl.config_server_session_id_cache()

        # Get our certificate and private key
        server_cert = nss.find_cert_from_nickname(options.server_nickname, options.password)
        priv_key = nss.find_key_by_any_cert(server_cert, options.password)
        server_cert_kea = server_cert.find_kea_type()

        print("server cert:\n%s" % server_cert)

        sock = ssl.SSLSocket(net_addr.family)

        # Set server SSL socket options
        sock.set_pkcs11_pin_arg(options.password)
        sock.set_ssl_option(ssl.SSL_SECURITY, True)
        sock.set_ssl_option(ssl.SSL_HANDSHAKE_AS_SERVER, True)

        # If we're doing client authentication then set it up.  Every
        # request/require action sets SSL_REQUEST_CERTIFICATE.  Both require
        # actions (once and always) also set SSL_REQUIRE_CERTIFICATE.
        if options.client_cert_action >= REQUEST_CLIENT_CERT_ONCE:
            sock.set_ssl_option(ssl.SSL_REQUEST_CERTIFICATE, True)
        if options.client_cert_action in (REQUIRE_CLIENT_CERT_ONCE,
                                          REQUIRE_CLIENT_CERT_ALWAYS):
            sock.set_ssl_option(ssl.SSL_REQUIRE_CERTIFICATE, True)
        sock.set_auth_certificate_callback(auth_certificate_callback, nss.get_default_certdb())

        # Configure the server SSL socket
        sock.config_secure_server(server_cert, priv_key, server_cert_kea)
    else:
        sock = io.Socket(net_addr.family)

    # Bind to our network address and listen for clients
    sock.bind(net_addr)
    print("listening on: %s" % (net_addr))
    sock.listen()

    # This example serves exactly one client, then shuts down.
    client_sock, client_addr = sock.accept()
    if options.use_ssl:
        client_sock.set_handshake_callback(handshake_callback)

    print("client connect from: %s" % (client_addr))

    try:
        # Handle the client connection
        buf = client_sock.readline()
        if not buf:
            print("server lost connection to %s" % (client_addr))
        else:
            buf = buf.rstrip()                 # remove newline record separator
            print("server received: %s" % (buf))

            client_sock.send('Goodbye' + '\n') # newline is protocol record separator
            try:
                client_sock.shutdown(io.PR_SHUTDOWN_RCV)
            except Exception:
                pass
    except Exception as e:
        # Plain Python exceptions have no strerror; fall back to the exception.
        print(getattr(e, 'strerror', None) or e)
    finally:
        # Release the client socket on every path, not just the success path.
        try:
            client_sock.close()
        except Exception:
            pass

    # Tear down the listener.  Each step is attempted independently, so a
    # failing shutdown() does not leave the socket open or the session-id
    # cache configured.
    try:
        sock.shutdown()
    except Exception as e:
        print(e)
    try:
        sock.close()
    except Exception as e:
        print(e)
    if options.use_ssl:
        try:
            ssl.shutdown_server_session_id_cache()
        except Exception as e:
            print(e)

# -----------------------------------------------------------------------------

class FamilyArgAction(argparse.Action):
    """argparse action mapping an address-family name to its NSPR constant.

    Used by ``-f/--family`` with ``nargs=1``, so *values* is a one-element
    list holding 'inet', 'inet6' or 'unspec'.  The matching ``io.PR_AF_*``
    value is stored on the namespace under ``self.dest``.  Any other name
    raises ``argparse.ArgumentError``.
    """
    def __call__(self, parser, namespace, values, option_string=None):
        requested = values[0]
        if requested not in ('inet', 'inet6', 'unspec'):
            raise argparse.ArgumentError(self, "unknown address family (%s)" % (requested))
        family = {
            'inet': io.PR_AF_INET,
            'inet6': io.PR_AF_INET6,
            'unspec': io.PR_AF_UNSPEC,
        }[requested]
        setattr(namespace, self.dest, family)

# Command-line interface.  Exactly one of --client / --server must be given;
# the remaining options configure the NSS database, networking and SSL.
parser = argparse.ArgumentParser(description='SSL example')

parser.add_argument('-C', '--client', action='store_true',
                    help='run as the client')

parser.add_argument('-S', '--server', action='store_true',
                    help='run as the server')

parser.add_argument('-d', '--db-name',
                    help='NSS database name (e.g. "sql:pki")')

parser.add_argument('-H', '--hostname',
                    help='host to connect to')

parser.add_argument('-f', '--family',
                    choices=['unspec', 'inet', 'inet6'],
                    dest='family', action=FamilyArgAction, nargs=1,
                    help='''
                      If unspec client tries all addresses returned by AddrInfo,
                      server binds to IPv4 "any" wildcard address.

                      If inet client tries IPv4 addresses returned by AddrInfo,
                      server binds to IPv4 "any" wildcard address.

                      If inet6 client tries IPv6 addresses returned by AddrInfo,
                      server binds to IPv6 "any" wildcard address''')

parser.add_argument('-4', '--inet',
                    dest='family', action='store_const', const=io.PR_AF_INET,
                    help='set family to inet (see family)')

parser.add_argument('-6', '--inet6',
                    dest='family', action='store_const', const=io.PR_AF_INET6,
                    help='set family to inet6 (see family)')

parser.add_argument('-n', '--server-nickname',
                    help='server certificate nickname')

parser.add_argument('-N', '--client-nickname',
                    help='client certificate nickname')

parser.add_argument('-w', '--password',
                    help='certificate database password')

parser.add_argument('-p', '--port', type=int,
                    help='host port')

parser.add_argument('-e', '--encrypt', dest='use_ssl', action='store_true',
                    help='use SSL connection')

parser.add_argument('-E', '--no-encrypt', dest='use_ssl', action='store_false',
                    help='do not use SSL connection')

# Client-certificate policy used by the server; the last one given wins.
parser.add_argument('--require-cert-once', dest='client_cert_action',
                    action='store_const', const=REQUIRE_CLIENT_CERT_ONCE,
                    help='server: require a client certificate (once)')

parser.add_argument('--require-cert-always', dest='client_cert_action',
                    action='store_const', const=REQUIRE_CLIENT_CERT_ALWAYS,
                    help='server: require a client certificate (always)')

parser.add_argument('--request-cert-once', dest='client_cert_action',
                    action='store_const', const=REQUEST_CLIENT_CERT_ONCE,
                    help='server: request a client certificate (once)')

parser.add_argument('--request-cert-always', dest='client_cert_action',
                    action='store_const', const=REQUEST_CLIENT_CERT_ALWAYS,
                    help='server: request a client certificate (always)')

parser.set_defaults(client = False,
                    server = False,
                    db_name = 'sql:pki',
                    hostname = os.uname()[1],
                    family = io.PR_AF_UNSPEC,
                    server_nickname = 'test_server',
                    client_nickname = 'test_user',
                    password = 'db_passwd',
                    port = 1234,
                    use_ssl = True,
                    client_cert_action = NO_CLIENT_CERT,
                   )

options = parser.parse_args()

if options.client and options.server:
    print("can't be both client and server")
    sys.exit(1)
if not (options.client or options.server):
    print("must be one of client or server")
    sys.exit(1)

# Perform basic configuration and setup
if options.use_ssl:
    nss.nss_init(options.db_name)
else:
    nss.nss_init_nodb()

ssl.set_domestic_policy()
nss.set_password_callback(password_callback)

# Run as a client or as a server.  The finally clause makes sure NSS is shut
# down even when Client() or Server() raises (e.g. an uncaught NSS error).
try:
    if options.client:
        print("starting as client")
        Client()

    if options.server:
        print("starting as server")
        Server()
finally:
    try:
        nss.nss_shutdown()
    except Exception as e:
        print(e)