File: ssh_client.py

package info (click to toggle)
waagent 2.15.0.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 9,820 kB
  • sloc: python: 60,164; xml: 4,126; sh: 1,354; makefile: 22
file content (101 lines) | stat: -rw-r--r-- 4,404 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/usr/bin/env python3

# Microsoft Azure Linux Agent
#
# Copyright 2018 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import re

from pathlib import Path

from azurelinuxagent.common.future import UTC

from tests_e2e.tests.lib import shell
from tests_e2e.tests.lib.retry import retry_ssh_run

# Default attempt count handed to retry_ssh_run by run_command, copy_to_node and copy_from_node.
ATTEMPTS: int = 3
# Default delay between attempts handed to retry_ssh_run by the same methods; presumably in seconds — TODO confirm
# against retry_ssh_run.
ATTEMPT_DELAY: int = 30


class SshClient(object):
    """
    Runs commands on, and copies files to/from, a remote node by shelling out to the OpenSSH tools
    (ssh, scp, ssh-keygen). Remote operations are wrapped in retry_ssh_run.
    """

    def __init__(self, ip_address: str, username: str, identity_file: Path, port: int = 22):
        """
        :param ip_address: address of the remote node
        :param username: user to log in as on the remote node
        :param identity_file: private key passed to ssh/scp via '-i'
        :param port: SSH port of the remote node; used both by run_command and by the copy methods
        """
        self.ip_address: str = ip_address
        self.username: str = username
        self.identity_file: Path = identity_file
        self.port: int = port

    def run_command(self, command: str, use_sudo: bool = False, attempts: int = ATTEMPTS, attempt_delay: int = ATTEMPT_DELAY) -> str:
        """
        Executes the given command over SSH and returns its stdout. If the command returns a non-zero exit code,
        the function raises a CommandError.

        The command must not start with 'sudo'; pass use_sudo=True instead so that PATH and PYTHONPATH are
        carried over into the sudo session. The ssh invocation is retried through retry_ssh_run using
        'attempts' and 'attempt_delay'.
        """
        if re.match(r"^\s*sudo\s*", command):
            raise Exception("Do not include 'sudo' in the 'command' argument, use the 'use_sudo' parameter instead")

        destination = f"ssh://{self.username}@{self.ip_address}:{self.port}"

        # Note that we add ~/bin to the remote PATH, since Python (Pypy) and other test tools are installed there.
        # Note, too, that when using sudo we need to carry over the value of PATH to the sudo session
        sudo = "sudo env PATH=$PATH PYTHONPATH=$PYTHONPATH" if use_sudo else ''
        # Use a separate name for the argv list instead of rebinding the 'command' (str) parameter.
        ssh_command = [
            "ssh", "-o", "StrictHostKeyChecking=no", "-i", str(self.identity_file),
            destination,
            f"if [[ -e ~/bin/set-agent-env ]]; then source ~/bin/set-agent-env; fi; {sudo} {command}"
        ]
        return retry_ssh_run(lambda: shell.run_command(ssh_command), attempts, attempt_delay)

    @staticmethod
    def generate_ssh_key(private_key_file: Path) -> None:
        """
        Generates an SSH key on the given Path
        """
        shell.run_command(
            ["ssh-keygen", "-m", "PEM", "-t", "rsa", "-b", "4096", "-q", "-N", "", "-f", str(private_key_file)])

    def get_architecture(self) -> str:
        """
        Returns the output of 'uname -m' on the remote node, with trailing whitespace removed.
        """
        return self.run_command("uname -m").rstrip()

    def get_distro(self) -> str:
        """
        Returns the output of the remote 'get_distro.py' tool, with trailing whitespace removed.
        """
        return self.run_command("get_distro.py").rstrip()

    def get_time(self) -> datetime.datetime:
        """
        Returns the remote node's current UTC time as a timezone-aware datetime (microsecond precision).
        """
        time_string = self.run_command("date --utc '+%Y-%m-%dT%T.%6NZ'").rstrip()
        return datetime.datetime.strptime(time_string, '%Y-%m-%dT%H:%M:%S.%fZ').replace(tzinfo=UTC)

    def copy_to_node(self, local_path: Path, remote_path: Path, recursive: bool = False, attempts: int = ATTEMPTS, attempt_delay: int = ATTEMPT_DELAY) -> None:
        """
        File copy to a remote node
        """
        self._copy(local_path, remote_path, remote_source=False, remote_target=True, recursive=recursive, attempts=attempts, attempt_delay=attempt_delay)

    def copy_from_node(self, remote_path: Path, local_path: Path, recursive: bool = False, attempts: int = ATTEMPTS, attempt_delay: int = ATTEMPT_DELAY) -> None:
        """
        File copy from a remote node
        """
        self._copy(remote_path, local_path, remote_source=True, remote_target=False, recursive=recursive, attempts=attempts, attempt_delay=attempt_delay)

    def _copy(self, source: Path, target: Path, remote_source: bool, remote_target: bool, recursive: bool, attempts: int, attempt_delay: int) -> None:
        """
        Copies 'source' to 'target' with scp, retrying through retry_ssh_run.

        remote_source / remote_target mark which side lives on the remote node; that side is prefixed with
        'username@ip_address:'. 'recursive' adds scp's '-r' flag.
        """
        def _remote(path: Path) -> str:
            # scp addresses remote files as user@host:path
            return f"{self.username}@{self.ip_address}:{path}"

        source_arg = _remote(source) if remote_source else str(source)
        target_arg = _remote(target) if remote_target else str(target)

        # scp takes the port with '-P' (capital P); without it scp always uses port 22, which would disagree with
        # run_command whenever the node listens on a non-default port.
        scp_command = [
            "scp", "-o", "StrictHostKeyChecking=no", "-i", str(self.identity_file), "-P", str(self.port)
        ]
        if recursive:
            scp_command.append("-r")
        scp_command.extend([source_arg, target_arg])

        # The scp output is not needed; this method is declared to return None.
        retry_ssh_run(lambda: shell.run_command(scp_command), attempts, attempt_delay)