File: server.py

package info (click to toggle)
fabric 1.14.0-1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye, buster, sid
  • size: 1,240 kB
  • sloc: python: 7,363; makefile: 10
file content (489 lines) | stat: -rw-r--r-- 16,063 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
from __future__ import with_statement

import os
import re
import socket
import threading
import time
import types
from functools import wraps
from Python26SocketServer import BaseRequestHandler, ThreadingMixIn, TCPServer

from fabric.operations import _sudo_prefix
from fabric.api import env, hide
from fabric.thread_handling import ThreadHandler
from fabric.network import disconnect_all, ssh

from fake_filesystem import FakeFilesystem, FakeFile

#
# Debugging
#

import logging
# Send every record at DEBUG level or above to a fixed file so the fake
# server's activity can be inspected after a test run.
logging.basicConfig(level=logging.DEBUG, filename='/tmp/fab.log')
logger = logging.getLogger('server.py')


#
# Constants
#

HOST, PORT = '127.0.0.1', 2200
USER, HOME = 'username', '/'
# Canned command -> response table. A plain string is stdout only; a sequence
# is read positionally as (stdout, stderr, exit status, sleep seconds), with
# trailing items optional.
RESPONSES = {
    "ls /simple": "some output",
    "ls /": """AUTHORS
FAQ
Fabric.egg-info
INSTALL
LICENSE
MANIFEST
README
build
docs
fabfile.py
fabfile.pyc
fabric
requirements.txt
setup.py
tests""",
    "both_streams": ["stdout", "stderr"],
}
# Default fake remote filesystem. String values are file contents; a value of
# None marks a directory.
FILES = FakeFilesystem(dict([
    ('/file.txt', 'contents'),
    ('/file2.txt', 'contents2'),
    ('/folder/file3.txt', 'contents3'),
    ('/empty_folder', None),
    ('/tree/file1.txt', 'x'),
    ('/tree/file2.txt', 'y'),
    ('/tree/subfolder/file3.txt', 'z'),
    ('/etc/apache2/apache2.conf', 'Include other.conf'),
    (HOME, None),  # So $HOME is a directory
]))
# Default username -> password table accepted by the fake server.
PASSWORDS = {'root': 'root', USER: 'password'}


def _local_file(filename):
    """Return the path of ``filename`` inside this module's directory."""
    here = os.path.dirname(__file__)
    return os.path.join(here, filename)


# Key material for the fake server and its clients, stored next to this module.
SERVER_PRIVKEY = _local_file('private.key')
CLIENT_PUBKEY, CLIENT_PRIVKEY = (
    _local_file('client.key.pub'),
    _local_file('client.key'),
)
CLIENT_PRIVKEY_PASSPHRASE = "passphrase"


def _equalize(lists, fillval=None):
    """
    Pad all given list items in ``lists`` to be the same length.

    Each item of ``lists`` (any iterable, e.g. a string) is copied into a new
    list; shorter ones are right-padded with ``fillval`` up to the length of
    the longest. Returns the new list of lists (empty if ``lists`` is empty).

    _equalize(("ab", "")) => [['a', 'b'], [None, None]]
    """
    # A real list rather than a lazy ``map``: it is walked twice below, and a
    # one-shot iterator would be exhausted by ``max()`` before the padding loop.
    lists = [list(x) for x in lists]
    if not lists:
        # ``max()`` of an empty sequence would raise ValueError.
        return lists
    upper = max(len(x) for x in lists)
    for lst in lists:
        lst.extend([fillval] * (upper - len(lst)))
    return lists


class TestServer(ssh.ServerInterface):
    """
    Test server implementing the 'ssh' lib's server interface parent class.

    Mostly just handles the bare minimum necessary to handle SSH-level things
    such as honoring authentication types and exec/shell/etc requests.

    The bulk of the actual server side logic is handled in the
    ``serve_responses`` function and its ``SSHHandler`` class.
    """
    def __init__(self, passwords, home, pubkeys, files):
        # Set whenever an exec or shell request arrives.
        self.event = threading.Event()
        # username -> password mapping consulted for password auth.
        self.passwords = passwords
        # Truthy to accept public key auth (for any key), falsy to reject it.
        self.pubkeys = pubkeys
        self.files = FakeFilesystem(files)
        self.home = home
        # Command line from the latest exec request; None until one arrives.
        self.command = None
        # Username from the latest auth attempt. Initialized here so the
        # attribute always exists, even before any authentication happens.
        self.username = None

    def check_channel_request(self, kind, chanid):
        """Allow 'session' channels only."""
        if kind == 'session':
            return ssh.OPEN_SUCCEEDED
        return ssh.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED

    def check_channel_exec_request(self, channel, command):
        """Record ``command``, signal ``self.event`` and accept."""
        self.command = command
        self.event.set()
        return True

    def check_channel_pty_request(self, *args):
        """Grant every PTY request."""
        return True

    def check_channel_shell_request(self, channel):
        """Signal ``self.event`` (no command is recorded) and accept."""
        self.event.set()
        return True

    def check_auth_password(self, username, password):
        """Succeed iff ``password`` equals ``self.passwords.get(username)``."""
        self.username = username
        passed = self.passwords.get(username) == password
        return ssh.AUTH_SUCCESSFUL if passed else ssh.AUTH_FAILED

    def check_auth_publickey(self, username, key):
        """Succeed for any user and key whenever pubkey auth is enabled."""
        self.username = username
        return ssh.AUTH_SUCCESSFUL if self.pubkeys else ssh.AUTH_FAILED

    def get_allowed_auths(self, username):
        """Advertise both password and public key authentication."""
        return 'password,publickey'


class SSHServer(ThreadingMixIn, TCPServer):
    """
    Threading TCPServer subclass.

    Binds using the address family ``getaddrinfo`` reports for the requested
    host, and allows address reuse so back-to-back test runs can rebind.
    """
    def _socket_info(self, addr_tup):
        """
        Clone of the very top of Paramiko 1.7.6 SSHClient.connect().

        We must use this in order to make sure that our address family matches
        up with the client side (which we cannot control, and which varies
        depending on individual computers and their network settings).

        Returns an ``(address_family, sockaddr)`` tuple for ``addr_tup``,
        which is a ``(hostname, port)`` pair.
        """
        hostname, port = addr_tup
        addr_info = socket.getaddrinfo(hostname, port, socket.AF_UNSPEC,
            socket.SOCK_STREAM)
        for (family, socktype, proto, canonname, sockaddr) in addr_info:
            if socktype == socket.SOCK_STREAM:
                return family, sockaddr
        # some OS like AIX don't indicate SOCK_STREAM support, so just
        # guess with the first result. :( getaddrinfo() returns a *list* of
        # 5-tuples, so take one entry instead of unpacking the list itself
        # (which only "worked" when exactly five results came back).
        family, _, _, _, sockaddr = addr_info[0]
        return family, sockaddr

    def __init__(
        self, server_address, RequestHandlerClass, bind_and_activate=True
    ):
        # Prevent "address already in use" errors when running tests 2x in a
        # row.
        self.allow_reuse_address = True

        # Handle network family/host addr (see docstring for _socket_info)
        family, addr = self._socket_info(server_address)
        self.address_family = family
        TCPServer.__init__(self, addr, RequestHandlerClass,
            bind_and_activate)


class FakeSFTPHandle(ssh.SFTPHandle):
    """
    Extremely basic way to get SFTPHandle working with our fake setup.

    Whoever creates the handle is expected to point ``readfile`` (and
    ``writefile``) at a fake file object exposing an ``attributes`` object.
    """
    def chattr(self, attr):
        """
        Apply the explicitly-set fields of ``attr`` to the open file.

        Mirrors the server-level ``chattr``: fields left as None on ``attr``
        are ignored instead of wiping existing values, so a partial update
        (e.g. mode only) keeps the file's size and ownership.
        """
        current = self.readfile.attributes
        for which in ("size", "uid", "gid", "mode", "atime", "mtime"):
            attname = "st_" + which
            incoming = getattr(attr, attname, None)
            if incoming is not None:
                setattr(current, attname, incoming)
        return ssh.SFTP_OK

    def stat(self):
        """Return the open file's attributes object."""
        return self.readfile.attributes


class PrependList(list):
    """A ``list`` that can also push items onto its front."""

    def prepend(self, val):
        """Insert ``val`` at index 0, shifting existing items right."""
        self[:0] = [val]


def expand(path):
    """
    Split ``path`` into its components, root-first.

    Absolute paths start with the separator; relative paths start with an
    empty string (harmless for ``os.path.join``). The two base cases, the
    empty string and the bare separator, come back as a one-item list.

    '/foo/bar/biz' => ['/', 'foo', 'bar', 'biz']
    'relative/path' => ['', 'relative', 'path']
    """
    if path in ('', os.path.sep):
        return [path]
    parts = PrependList()
    head, tail = os.path.split(path)
    # Peel components off the end until only the root (or nothing) is left.
    while head and head != os.path.sep:
        parts.prepend(tail)
        head, tail = os.path.split(head)
    parts.prepend(tail)
    if head == os.path.sep:
        parts.prepend(head)
    else:
        parts.prepend('')
    return parts


def contains(folder, path):
    """
    Return True if ``path`` is a strict prefix of ``folder``.

    Both arguments are sequences of path components; a sequence never
    contains itself.

    contains(('a', 'b', 'c'), ('a', 'b')) => True
    contains(('a', 'b', 'c'), ('f',)) => False
    """
    depth = len(path)
    if depth >= len(folder):
        return False
    return folder[:depth] == path


def missing_folders(paths):
    """
    Return the ancestor folders of ``paths`` that are not in ``paths``.

    Each folder is reported once. For each input path, ancestors are listed
    from the deepest one up to the root. Relative paths do not yield an
    empty-string entry.

    missing_folders(['/a/b/c']) => ['/a/b', '/a', '/']
    missing_folders(['a/b/c']) => ['a/b', 'a']
    missing_folders(['a/b', 'a/b/c']) => ['a']
    """
    # Materialize once: ``paths`` feeds both the membership pool and the
    # loop below, and a one-shot iterator would leave the loop with nothing.
    paths = list(paths)
    ret = []
    pool = set(paths)
    for path in paths:
        expanded = expand(path)
        # Rebuild progressively shorter prefixes: full path first, root last.
        for end in range(len(expanded), 0, -1):
            folder = os.path.join(*expanded[:end])
            if folder and folder not in pool:
                pool.add(folder)
                ret.append(folder)
    return ret


def canonicalize(path, home):
    """
    Return a normalized form of ``path``.

    Relative paths are resolved against ``home`` (expected to be absolute).
    Absolute paths stay absolute. In both cases ``.``/``..`` segments and
    redundant separators are collapsed with ``os.path.normpath``, so that
    absolute and relative inputs get consistent treatment.
    """
    if not os.path.isabs(path):
        path = os.path.join(home, path)
    return os.path.normpath(path)


class FakeSFTPServer(ssh.SFTPServerInterface):
    """
    SFTP subsystem handler backed by an in-memory fake filesystem.

    ``server`` must expose ``files`` (a FakeFilesystem-like mapping of path to
    fake file) and ``home``. ``files`` is used directly, not copied, so changes
    made over SFTP are visible through ``server.files``.
    """
    def __init__(self, server, *args, **kwargs):
        self.server = server
        files = self.server.files
        # Expand such that omitted, implied folders get added explicitly
        # (assigning None creates a directory entry). Mutates server.files.
        for folder in missing_folders(files.keys()):
            files[folder] = None
        self.files = files

    def canonicalize(self, path):
        """
        Make non-absolute paths relative to $HOME.
        """
        return canonicalize(path, self.server.home)

    def list_folder(self, path):
        """
        Return the attributes of every direct child of ``path``.

        Returns SFTP_NO_SUCH_FILE if no children are found or if any child
        fails to stat.
        """
        # NOTE(review): an existing but empty directory also has no children
        # and is therefore reported as SFTP_NO_SUCH_FILE.
        path = self.files.normalize(path)
        expanded_files = [expand(x) for x in self.files]
        expanded_path = expand(path)
        candidates = [x for x in expanded_files if contains(x, expanded_path)]
        children = []
        for candidate in candidates:
            # Keep only the component directly below ``path``.
            cut = candidate[:len(expanded_path) + 1]
            if cut not in children:
                children.append(cut)
        results = [self.stat(os.path.join(*x)) for x in children]
        bad = not results or any(x == ssh.SFTP_NO_SUCH_FILE for x in results)
        return ssh.SFTP_NO_SUCH_FILE if bad else results

    def open(self, path, flags, attr):
        """
        Return a FakeSFTPHandle for ``path``, creating an empty file when a
        write-capable open targets a missing file in an existing directory.
        """
        path = self.files.normalize(path)
        try:
            fobj = self.files[path]
        except KeyError:
            # Both write-only and read-write opens may create the file;
            # checking O_WRONLY alone rejected read-write ('w+'/'r+') opens.
            if flags & (os.O_WRONLY | os.O_RDWR):
                # Only allow writes to files in existing directories.
                if os.path.dirname(path) not in self.files:
                    return ssh.SFTP_NO_SUCH_FILE
                self.files[path] = fobj = FakeFile("", path)
            # No write flag means a read, which means they tried to read a
            # nonexistent file.
            else:
                return ssh.SFTP_NO_SUCH_FILE
        f = FakeSFTPHandle()
        f.readfile = f.writefile = fobj
        return f

    def stat(self, path):
        """Return the attributes of ``path``, or SFTP_NO_SUCH_FILE."""
        path = self.files.normalize(path)
        try:
            fobj = self.files[path]
        except KeyError:
            return ssh.SFTP_NO_SUCH_FILE
        return fobj.attributes

    # Don't care about links right now
    lstat = stat

    def chattr(self, path, attr):
        """Merge the non-None st_* fields of ``attr`` into ``path``'s attributes."""
        path = self.files.normalize(path)
        if path not in self.files:
            return ssh.SFTP_NO_SUCH_FILE
        # Attempt to gracefully update instead of overwrite, since things like
        # chmod will call us with an SFTPAttributes object that only exhibits
        # e.g. st_mode, and we don't want to lose our filename or size...
        for which in "size uid gid mode atime mtime".split():
            attname = "st_" + which
            incoming = getattr(attr, attname)
            if incoming is not None:
                setattr(self.files[path].attributes, attname, incoming)
        return ssh.SFTP_OK

    def mkdir(self, path, attr):
        """Create a directory entry at ``path``."""
        # Normalize like every other entry point so the key matches later
        # normalized lookups.
        self.files[self.files.normalize(path)] = None
        return ssh.SFTP_OK


def serve_responses(responses, files, passwords, home, pubkeys, port):
    """
    Return a threading TCP based SocketServer listening on ``port``.

    Used as a fake SSH server which will respond to commands given in
    ``responses`` and allow connections for users listed in ``passwords``.
    ``home`` is used as the remote $HOME (mostly for SFTP purposes).

    ``pubkeys`` is a Boolean value determining whether the server will allow
    pubkey auth or not.

    Each ``responses`` value is either a string (stdout only) or a sequence
    of 1-4 items read as ``(stdout, stderr, exit_status, sleep_seconds)``.
    Omitted trailing items default to ``""``, ``0`` and ``0``.

    Handlers loop until an ``all_done`` ``threading.Event`` attached to the
    returned server is set, so callers must attach one before serving.
    """
    # Define handler class inline so it can access serve_responses' args
    class SSHHandler(BaseRequestHandler):
        def handle(self):
            """
            Serve one client connection until ``self.server.all_done`` is
            set, answering each exec'd command from ``responses``.
            """
            # Pre-seed so the ``finally`` clause can't raise AttributeError
            # (masking the real error) if transport setup fails early.
            self.transport = None
            try:
                self.init_transport()
                self.waiting_for_command = False
                while not self.server.all_done.isSet():
                    # Don't overwrite channel if we're waiting for a command.
                    if not self.waiting_for_command:
                        self.channel = self.transport.accept(1)
                        if not self.channel:
                            continue
                    # NOTE(review): ``event`` is never cleared in this file,
                    # so after the first exec/shell request this wait returns
                    # immediately.
                    self.ssh_server.event.wait(10)
                    if self.ssh_server.command:
                        self.command = self.ssh_server.command
                        # Set self.sudo_prompt, update self.command
                        self.split_sudo_prompt()
                        if self.command in responses:
                            self.stdout, self.stderr, self.status = \
                                self.response()
                            if self.sudo_prompt and not self.sudo_password():
                                self.channel.send(
                                    "sudo: 3 incorrect password attempts\n"
                                )
                                break
                            self.respond()
                        else:
                            self.channel.send_stderr(
                                "Sorry, I don't recognize that command.\n"
                            )
                            self.channel.send_exit_status(1)
                        # Close up shop
                        self.command = self.ssh_server.command = None
                        self.waiting_for_command = False
                        time.sleep(0.5)
                        self.channel.close()
                    else:
                        # If we're here, self.command was False or None,
                        # but we do have a valid Channel object. Thus we're
                        # waiting for the command to show up.
                        self.waiting_for_command = True

            finally:
                if self.transport is not None:
                    self.transport.close()

        def init_transport(self):
            """Wrap the client socket in a server-mode SSH transport."""
            transport = ssh.Transport(self.request)
            # Publish immediately so handle() can still close the transport
            # if key loading or server negotiation below raises.
            self.transport = transport
            transport.add_server_key(ssh.RSAKey(filename=SERVER_PRIVKEY))
            transport.set_subsystem_handler('sftp', ssh.SFTPServer,
                sftp_si=FakeSFTPServer)
            server = TestServer(passwords, home, pubkeys, files)
            transport.start_server(server=server)
            self.ssh_server = server

        def split_sudo_prompt(self):
            """
            Strip a leading sudo prefix from ``self.command``.

            Sets ``self.sudo_prompt`` to the matched prefix ('' if absent)
            and ``self.command`` to the remainder.
            """
            prefix = re.escape(_sudo_prefix(None, None).rstrip()) + ' +'
            result = re.findall(r'^(%s)?(.*)$' % prefix, self.command)[0]
            self.sudo_prompt, self.command = result

        def response(self):
            """
            Return ``(stdout, stderr, status)`` for ``self.command``.

            stdout and stderr come back as equal-length lists (see
            ``_equalize``); any configured sleep happens before returning.

            Raises ValueError if a sequence response does not have 1-4 items.
            """
            result = responses[self.command]
            stderr = ""
            status = 0
            sleep = 0
            if isinstance(result, types.StringTypes):
                stdout = result
            else:
                size = len(result)
                if size == 1:
                    stdout = result[0]
                elif size == 2:
                    stdout, stderr = result
                elif size == 3:
                    stdout, stderr, status = result
                elif size == 4:
                    stdout, stderr, status, sleep = result
                else:
                    # Any other size would leave ``stdout`` unbound; fail
                    # with a clear message instead of UnboundLocalError.
                    raise ValueError(
                        "Response for %r must have 1-4 items, got %d"
                        % (self.command, size)
                    )
            stdout, stderr = _equalize((stdout, stderr))
            time.sleep(sleep)
            return stdout, stderr, status

        def sudo_password(self):
            """Prompt for the sudo password up to 3 times; True on a match."""
            # Users absent from ``passwords`` (possible with pubkey auth)
            # can never match, rather than crashing with a KeyError.
            expected = passwords.get(self.ssh_server.username)
            # Give user 3 tries, as is typical
            passed = False
            for x in range(3):
                self.channel.send(env.sudo_prompt)
                password = self.channel.recv(65535).strip()
                # Spit back newline to fake the echo of user's
                # newline
                self.channel.send('\n')
                # Test password
                if password == expected:
                    passed = True
                    break
                # If here, password was bad.
                self.channel.send("Sorry, try again.\n")
            return passed

        def respond(self):
            """Interleave stdout/stderr chunks, then send the exit status."""
            for out, err in zip(self.stdout, self.stderr):
                # None entries are _equalize padding for the shorter stream.
                if out is not None:
                    self.channel.send(out)
                if err is not None:
                    self.channel.send_stderr(err)
            self.channel.send_exit_status(self.status)

    return SSHServer((HOST, port), SSHHandler)


def server(
        responses=RESPONSES,
        files=FILES,
        passwords=PASSWORDS,
        home=HOME,
        pubkeys=False,
        port=PORT
    ):
    """
    Returns a decorator that runs an SSH server during function execution.

    Direct passthrough to ``serve_responses``.
    """
    def run_server(func):
        @wraps(func)
        def inner(*args, **kwargs):
            # Start server
            _server = serve_responses(responses, files, passwords, home,
                pubkeys, port)
            _server.all_done = threading.Event()
            worker = ThreadHandler('server', _server.serve_forever)
            # Execute function
            try:
                return func(*args, **kwargs)
            finally:
                # Clean up client side connections
                with hide('status'):
                    disconnect_all()
                # Stop server
                _server.all_done.set()
                _server.shutdown()
                # Why this is not called in shutdown() is beyond me.
                _server.server_close()
                worker.thread.join()
                # Handle subthread exceptions
                e = worker.exception
                if e:
                    raise e[0], e[1], e[2]
        return inner
    return run_server