File: ssh_utils.py

package info (click to toggle)
mistral-tempest-plugin 2.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 416 kB
  • sloc: python: 2,503; makefile: 21
file content (102 lines) | stat: -rwxr-xr-x 3,006 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# Copyright 2014 - Mirantis, Inc.
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

from os import path
from oslo_log import log as logging
import paramiko

# Directory in which private key files are looked up. The trailing slash
# matters: _to_paramiko_private_key builds the full key path by plain string
# concatenation (KEY_PATH + filename), not with path.join.
KEY_PATH = path.expanduser("~/.ssh/")
# Module-level logger used for connection/command debug messages.
LOG = logging.getLogger(__name__)


def _read_paramimko_stream(recv_func):
    result = ''
    buf = recv_func(1024)
    while buf != '':
        result += buf
        buf = recv_func(1024)

    return result


def _to_paramiko_private_key(private_key_filename, password=None):
    if '../' in private_key_filename or '..\\' in private_key_filename:
        raise OSError(
            "Private key filename must not contain '..'. "
            "Actual: %s" % private_key_filename
        )

    private_key_path = KEY_PATH + private_key_filename

    return paramiko.RSAKey(
        filename=private_key_path,
        password=password
    )


def _connect(host, username, password=None, pkey=None, proxy=None):
    """Open an SSH connection to *host* and return the connected client.

    :param host: Host name or address to connect to.
    :param username: Login user name.
    :param password: Login password. When *pkey* is a filename, it is also
        used as the passphrase for loading that key.
    :param pkey: Either a paramiko key object, or a filename (``str``)
        relative to ``KEY_PATH`` that is loaded as an RSA key.
    :param proxy: Optional socket-like object handed to paramiko as ``sock``.
    :returns: A connected ``paramiko.SSHClient``; the caller must close it.
    :raises: Whatever ``paramiko.SSHClient.connect`` raises. The client is
        closed before the exception propagates.
    """
    if isinstance(pkey, str):
        pkey = _to_paramiko_private_key(pkey, password)

    LOG.debug('Creating SSH connection to %s', host)

    ssh_client = paramiko.SSHClient()
    # NOTE(review): unknown host keys are accepted automatically, so the
    # server's identity is not verified.
    ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

    try:
        ssh_client.connect(
            host,
            username=username,
            password=password,
            pkey=pkey,
            sock=proxy
        )
    except Exception:
        # A failed attempt can leave a transport/socket open; release it
        # instead of leaking it, then let the caller see the original error.
        ssh_client.close()
        raise

    return ssh_client


def _cleanup(ssh_client):
    ssh_client.close()


def _execute_command(ssh_client, cmd, get_stderr=False,
                     raise_when_error=True):
    """Run *cmd* on a new session of *ssh_client*, then close the client.

    The client is always closed (via ``_cleanup``), whether the command
    succeeds, fails, or raises.

    :param ssh_client: Connected client whose transport opens the session.
    :param cmd: Command string to execute remotely.
    :param get_stderr: If true, also return the collected stderr.
    :param raise_when_error: If true, a non-zero exit status raises.
    :returns: ``(ret_code, stdout)``, or ``(ret_code, stdout, stderr)`` when
        *get_stderr* is true.
    :raises RuntimeError: On non-zero exit status if *raise_when_error*.
    """
    try:
        session = ssh_client.get_transport().open_session()
        session.exec_command(cmd)

        # TODO(nmakhotkin): that could hang if stderr buffer overflows
        # (stdout is drained to EOF before stderr is read at all).
        out = _read_paramimko_stream(session.recv)
        err = _read_paramimko_stream(session.recv_stderr)

        exit_code = session.recv_exit_status()

        if raise_when_error and exit_code:
            raise RuntimeError(
                "Cmd: %s\nReturn code: %s\nstdout: %s" % (cmd, exit_code, out)
            )

        return (exit_code, out, err) if get_stderr else (exit_code, out)
    finally:
        _cleanup(ssh_client)


def execute_command(cmd, host, username, password=None,
                    private_key_filename=None, get_stderr=False,
                    raise_when_error=True):
    """Connect to *host*, run *cmd*, and close the connection afterwards.

    :param cmd: Command string to execute remotely.
    :param host: Host name or address to connect to.
    :param username: Login user name.
    :param password: Login password (also the key passphrase when a key
        file is given).
    :param private_key_filename: Optional key file name relative to
        ``KEY_PATH``.
    :param get_stderr: If true, also return the collected stderr.
    :param raise_when_error: If true, a non-zero exit status raises
        ``RuntimeError``.
    :returns: ``(ret_code, stdout)``, or ``(ret_code, stdout, stderr)`` when
        *get_stderr* is true.
    """
    client = _connect(host, username, password=password,
                      pkey=private_key_filename)

    LOG.debug("Executing command %s", cmd)

    return _execute_command(client, cmd, get_stderr=get_stderr,
                            raise_when_error=raise_when_error)