File: test_client_server.py

package info (click to toggle)
python-nss 0.15.0-1
  • links: PTS, VCS
  • area: main
  • in suites: jessie, jessie-kfreebsd
  • size: 1,452 kB
  • ctags: 1,755
  • sloc: ansic: 27,607; python: 2,688; makefile: 2
file content (346 lines) | stat: -rwxr-xr-x 11,931 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
#!/usr/bin/python

import os
import sys
import errno
import signal
import time

import unittest

from nss.error import NSPRError
import nss.io as io
import nss.nss as nss
import nss.ssl as ssl

# -----------------------------------------------------------------------------
# Client-certificate policies understood by server() via client_cert_action.
# Any value >= REQUEST_CLIENT_CERT_ONCE makes the server request a client
# certificate; REQUIRE_CLIENT_CERT_ONCE additionally makes it mandatory.
NO_CLIENT_CERT             = 0
REQUEST_CLIENT_CERT_ONCE   = 1
REQUIRE_CLIENT_CERT_ONCE   = 2
REQUEST_CLIENT_CERT_ALWAYS = 3
REQUIRE_CLIENT_CERT_ALWAYS = 4

# Trace connection-level events (handshake, connect, accept, cert checks).
verbose = False
# Trace the request/reply payloads exchanged by client() and server().
info = False
# Password handed to the NSS certificate/key lookups and used as the
# server socket's PKCS#11 pin argument.
password = 'db_passwd'
# When True both ends use ssl.SSLSocket, otherwise a plain io.Socket.
use_ssl = True
# One of the *_CLIENT_CERT_* policies above; consulted by server().
client_cert_action = NO_CLIENT_CERT
# NSS database specification passed to nss.nss_init().
db_name = 'sql:pki'
# Node name of this machine; the client resolves and connects to it and,
# under SSL, sets it as the hostname on the socket.
hostname = os.uname()[1]
# Certificate nicknames looked up in the NSS database for each side.
server_nickname = 'test_server'
client_nickname = 'test_user'
# TCP port the server listens on and the client connects to.
port = 1234
# Client connect timeout, in seconds.
timeout_secs = 10
# Seconds the server waits after replying, and the parent waits after
# forking the server, before continuing.
sleep_time = 5


# -----------------------------------------------------------------------------
# Callback Functions
# -----------------------------------------------------------------------------

def password_callback(slot, retry, password):
    """Return the password NSS should use to unlock its database.

    Registered with nss.set_password_callback() by server().

    :param slot: passed in by NSS; not used here.
    :param retry: passed in by NSS; not used here.
    :param password: password supplied by the caller; when it is empty or
        None the user is prompted on the terminal instead.
    :returns: the password string.
    """
    if password:
        return password

    # The module never imports getpass at file level, so the prompt path used
    # to fail with NameError.  Import it here, where it is actually needed.
    import getpass
    return getpass.getpass("Enter password: ")

def handshake_callback(sock):
    """Trace completion of the SSL handshake on ``sock``.

    Installed on SSL sockets with set_handshake_callback().  Prints the value
    of sock.get_peer_name() when the module-level ``verbose`` flag is set and
    does nothing otherwise.
    """
    if not verbose:
        return
    peer = sock.get_peer_name()
    print("handshake complete, peer = %s" % (peer))

def auth_certificate_callback(sock, check_sig, is_server, certdb):
    """Decide whether the peer certificate presented on ``sock`` is acceptable.

    The certificate is first verified for the usage the *peer* must satisfy;
    on the client side the certificate must also match the socket's hostname.

    :param sock: SSL socket whose peer certificate is being checked.
    :param check_sig: forwarded unchanged to cert.verify_now().
    :param is_server: true when this end is the server (so the peer is a
        client and its certificate is checked for SSL-client usage).
    :param certdb: certificate database forwarded to cert.verify_now().
    :returns: False when verification fails or raises; otherwise True on the
        server side, or the result of cert.verify_hostname() on the client side.
    """
    def _verdict(valid):
        # Common exit: trace the decision when verbose, then hand it back.
        if verbose: print("Returning cert_is_valid = %s" % valid)
        return valid

    if verbose:
        print("auth_certificate_callback: check_sig=%s is_server=%s" % (check_sig, is_server))

    cert = sock.get_peer_certificate()
    pin_args = sock.get_pkcs11_pin_arg()
    if pin_args is None:
        pin_args = ()

    # The usage is chosen from the peer's point of view.  This may look
    # backwards but is not: a server validates a client certificate, and a
    # client validates a server certificate.
    intended_usage = (nss.certificateUsageSSLClient if is_server
                      else nss.certificateUsageSSLServer)

    try:
        # verify_now() raises when validation fails; the exception text
        # describes the reason and is reported on stderr.
        approved_usage = cert.verify_now(certdb, check_sig, intended_usage, *pin_args)
    except Exception as e:
        sys.stderr.write("auth_certificate_callback: %s" % e + "\n")
        return _verdict(False)

    if verbose:
        print("approved_usage = %s" % ', '.join(nss.cert_usage_flags(approved_usage)))

    # The approved usages must include the one we asked for.
    cert_is_valid = bool(approved_usage & intended_usage)

    # A server is finished here, and so is anyone holding an invalid cert.
    if is_server or not cert_is_valid:
        return _verdict(cert_is_valid)

    # Client side with an otherwise valid certificate: the name in the
    # certificate must match the host we meant to reach.  This is the defense
    # against man-in-the-middle attacks.
    hostname = sock.get_hostname()
    if verbose:
        print("verifying socket hostname (%s) matches cert subject (%s)" % (hostname, cert.subject))
    try:
        # verify_hostname() raises when the name does not match.
        cert_is_valid = cert.verify_hostname(hostname)
    except Exception as e:
        sys.stderr.write("auth_certificate_callback: %s" % e + "\n")
        return _verdict(False)

    return _verdict(cert_is_valid)

def client_auth_data_callback(ca_names, chosen_nickname, password, certdb):
    """Pick the client certificate and private key to present to the server.

    :param ca_names: CA names a candidate certificate is checked against with
        has_signer_in_ca_names() when no nickname was chosen.
    :param chosen_nickname: nickname of the certificate to use; when empty,
        the user certificates in ``certdb`` are searched instead.
    :param password: password forwarded to the NSS certificate/key lookups.
    :param certdb: certificate database whose user nicknames are enumerated.
    :returns: a ``(cert, priv_key)`` tuple, or False when no usable
        certificate could be found.
    """
    if chosen_nickname:
        try:
            cert = nss.find_cert_from_nickname(chosen_nickname, password)
            priv_key = nss.find_key_by_any_cert(cert, password)
            if verbose: print("client cert:\n%s" % cert)
            return cert, priv_key
        except NSPRError as e:
            sys.stderr.write("client_auth_data_callback: %s\n" % e)
            return False

    # No nickname given: walk every user certificate in the database.
    # The SEC_CERT_NICKNAMES_* constants live on the nss module; the previous
    # code read the constant from ``cert``, which was still None at this
    # point, so this branch always failed with AttributeError.
    nicknames = nss.get_cert_nicknames(certdb, nss.SEC_CERT_NICKNAMES_USER)
    for nickname in nicknames:
        try:
            cert = nss.find_cert_from_nickname(nickname, password)
            if verbose: print("client cert:\n%s" % cert)
            # Use the first certificate that is within its validity period
            # and was issued by one of the CAs in ca_names.
            if cert.check_valid_times() and cert.has_signer_in_ca_names(ca_names):
                priv_key = nss.find_key_by_any_cert(cert, password)
                return cert, priv_key
        except NSPRError as e:
            # A failing candidate is reported and skipped; keep searching.
            sys.stderr.write("client_auth_data_callback: %s\n" % e)
    return False

# -----------------------------------------------------------------------------
# Client Implementation
# -----------------------------------------------------------------------------

def _new_client_socket(family):
    """Create an unconnected client socket for address ``family``.

    Returns a plain io.Socket when ``use_ssl`` is false, otherwise an
    ssl.SSLSocket configured as an SSL client with this module's callbacks.
    """
    if not use_ssl:
        return io.Socket(family)

    sock = ssl.SSLSocket(family)

    # Set client SSL socket options
    sock.set_ssl_option(ssl.SSL_SECURITY, True)
    sock.set_ssl_option(ssl.SSL_HANDSHAKE_AS_CLIENT, True)
    sock.set_hostname(hostname)

    # Provide a callback which notifies us when the SSL handshake is complete
    sock.set_handshake_callback(handshake_callback)

    # Provide a callback to supply our client certificate info
    sock.set_client_auth_data_callback(client_auth_data_callback, client_nickname,
                                       password, nss.get_default_certdb())

    # Provide a callback to verify the server's certificate
    sock.set_auth_certificate_callback(auth_certificate_callback,
                                       nss.get_default_certdb())
    return sock


def client(request):
    """Send ``request`` to the test server and return its one-line reply.

    The server host is resolved from the module-level ``hostname`` and each
    resolved address is tried in turn on ``port`` until one connects.

    :param request: text to send; a newline record separator is appended.
    :returns: the reply with trailing whitespace stripped, or None when the
        host cannot be resolved, no address accepts the connection, the
        connection is lost, or the exchange raises.
    """
    if use_ssl:
        if info: print("client: using SSL")
        ssl.set_domestic_policy()

    # Get the IP address(es) of our server
    try:
        addr_info = io.AddrInfo(hostname)
    except Exception:
        sys.stderr.write("client: could not resolve host address \"%s\"\n" % hostname)
        return None

    # Try each resolved address until one accepts the connection.
    for net_addr in addr_info:
        net_addr.port = port
        sock = _new_client_socket(net_addr.family)

        try:
            if verbose: print("client trying connection to: %s" % (net_addr))
            sock.connect(net_addr, timeout=io.seconds_to_interval(timeout_secs))
            if verbose: print("client connected to: %s" % (net_addr))
            break
        except Exception as e:
            sock.close()
            sys.stderr.write("client: connection to: %s failed (%s)\n" % (net_addr, e))
    else:
        # Every attempt failed (or no address was returned).  Stop here
        # instead of falling through and talking on a closed or undefined
        # socket, which is what used to happen.
        sys.stderr.write("client: could not connect to \"%s\"\n" % hostname)
        return None

    # Talk to the server
    try:
        if info: print("client: sending \"%s\"" % (request))
        sock.send(request + '\n') # newline is protocol record separator
        buf = sock.readline()
        if not buf:
            sys.stderr.write("client: lost connection\n")
            sock.close()
            return None
        buf = buf.rstrip()        # remove newline record separator
        if info: print("client: received \"%s\"" % (buf))
    except Exception as e:
        sys.stderr.write("client: %s\n" % e)
        try:
            sock.close()
        except Exception:
            # Best effort: the exchange already failed and was reported.
            pass
        return None

    # Orderly teardown; failures are reported but do not discard the reply.
    try:
        sock.shutdown()
    except Exception as e:
        sys.stderr.write("client: %s\n" % e)

    try:
        sock.close()
        if use_ssl:
            ssl.clear_session_cache()
    except Exception as e:
        sys.stderr.write("client: %s\n" % e)

    return buf

# -----------------------------------------------------------------------------
# Server Implementation
# -----------------------------------------------------------------------------

def server():
    """Run a one-shot echo server on ``port``.

    Listens on all interfaces, accepts a single connection, reads one
    newline-terminated line, replies with that line wrapped in braces, and
    then shuts everything down.  Uses SSL when ``use_ssl`` is set.

    Errors while talking to the client are reported on stderr; errors during
    setup (certificate lookup, bind, listen, accept) propagate to the caller.
    """
    if verbose: print("starting server:")

    # Initialize
    # Setup an IP Address to listen on any of our interfaces
    net_addr = io.NetworkAddress(io.PR_IpAddrAny, port)

    if use_ssl:
        if info: print("server: using SSL")
        ssl.set_domestic_policy()
        nss.set_password_callback(password_callback)

        # Perform basic SSL server configuration
        ssl.set_default_cipher_pref(ssl.SSL_RSA_WITH_NULL_MD5, True)
        ssl.config_server_session_id_cache()

        # Get our certificate and private key
        server_cert = nss.find_cert_from_nickname(server_nickname, password)
        priv_key = nss.find_key_by_any_cert(server_cert, password)
        server_cert_kea = server_cert.find_kea_type()

        sock = ssl.SSLSocket(net_addr.family)

        # Set server SSL socket options
        sock.set_pkcs11_pin_arg(password)
        sock.set_ssl_option(ssl.SSL_SECURITY, True)
        sock.set_ssl_option(ssl.SSL_HANDSHAKE_AS_SERVER, True)

        # If we're doing client authentication then set it up
        if client_cert_action >= REQUEST_CLIENT_CERT_ONCE:
            sock.set_ssl_option(ssl.SSL_REQUEST_CERTIFICATE, True)
        if client_cert_action == REQUIRE_CLIENT_CERT_ONCE:
            sock.set_ssl_option(ssl.SSL_REQUIRE_CERTIFICATE, True)
        sock.set_auth_certificate_callback(auth_certificate_callback, nss.get_default_certdb())

        # Configure the server SSL socket
        sock.config_secure_server(server_cert, priv_key, server_cert_kea)

    else:
        sock = io.Socket(net_addr.family)

    # Bind to our network address and listen for clients
    sock.bind(net_addr)
    if verbose: print("listening on: %s" % (net_addr))
    sock.listen()

    # Accept a connection from a client (exactly one is served)
    client_sock, client_addr = sock.accept()
    if use_ssl:
        client_sock.set_handshake_callback(handshake_callback)

    if verbose: print("client connect from: %s" % (client_addr))

    try:
        # Handle the client connection
        buf = client_sock.readline()   # newline is protocol record separator
        if not buf:
            sys.stderr.write("server: lost connection to %s\n" % (client_addr))
        else:
            buf = buf.rstrip()             # remove newline record separator

            if info: print("server: received \"%s\"" % (buf))
            reply = "{%s}" % buf           # echo embedded inside braces
            if info: print("server: sending \"%s\"" % (reply))
            client_sock.send(reply + '\n') # send echo with record separator

            # NOTE(review): presumably this keeps the connection (and this
            # process) alive while the client reads the reply -- confirm.
            time.sleep(sleep_time)
            client_sock.shutdown()
    except Exception as e:
        sys.stderr.write("server: %s\n" % e)
    finally:
        # Always release the client socket; it used to be left open when the
        # connection was lost or the exchange raised.
        try:
            client_sock.close()
        except Exception as e:
            sys.stderr.write("server: %s\n" % e)

    # Clean up
    sock.shutdown()
    sock.close()
    if use_ssl:
        ssl.shutdown_server_session_id_cache()

# -----------------------------------------------------------------------------

def run_server():
    """Fork a child process that runs server() and return the child's pid.

    The parent sleeps ``sleep_time`` seconds before returning (presumably to
    give the child time to start listening -- confirm).  The child never
    returns from this function: it exits with status 0 when server()
    completes and 1 when it raises.

    :returns: pid of the forked server process (parent only).
    """
    # Flush before forking so buffered output is not emitted twice.
    sys.stdout.flush()
    sys.stderr.flush()

    pid = os.fork()
    if pid == 0:
        status = 1
        try:
            nss.nss_init(db_name)
            server()
            nss.nss_shutdown()
            status = 0
        except Exception as e:
            sys.stderr.write("run_server: %s\n" % e)
        finally:
            # The child must not fall through into the caller: it is a copy
            # of the test runner and would go on to run the tests itself.
            # os._exit() skips interpreter cleanup, so flush by hand first.
            try:
                sys.stdout.flush()
                sys.stderr.flush()
            finally:
                os._exit(status)
    time.sleep(sleep_time)
    return pid

def cleanup_server(pid):
    """Make sure the server child process ``pid`` is gone and reaped.

    If the child has already exited it is simply collected.  If it is still
    running it is killed with SIGKILL and then waited for, so that no zombie
    process is left behind.

    :param pid: process id returned by run_server().
    """
    try:
        wait_pid, wait_status = os.waitpid(pid, os.WNOHANG)
        if wait_pid == 0:
            # Still running: kill it, then collect its exit status.
            os.kill(pid, signal.SIGKILL)
            os.waitpid(pid, 0)
    except OSError as e:
        # ECHILD means the child already exited and was reaped; anything
        # else is unexpected and worth reporting.
        if e.errno != errno.ECHILD:
            sys.stderr.write("cleanup_server: %s\n" % e)

class TestSSL(unittest.TestCase):
    """Round-trip test: a forked echo server and an in-process client."""

    def setUp(self):
        # Emit a blank line before the test's own output, then start the
        # server in a child process.
        print("")
        self.server_pid = run_server()

    def tearDown(self):
        # Stop the server child if it is still running.
        cleanup_server(self.server_pid)

    def test_ssl(self):
        """The server must echo the request wrapped in braces."""
        request = "foo"
        nss.nss_init(db_name)
        try:
            reply = client(request)
        finally:
            # Shut NSS down even when client() raises, so a failure here
            # does not leave the library initialised.
            nss.nss_shutdown()
        self.assertEqual("{%s}" % request, reply)


# Run the test case(s) in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()