#
# Copyright (C) 2015 Red Hat
# Copyright (C) 2014 pytest-multihost contributors
# See COPYING for license
#
"""Objects for communicating with remote hosts
This class defines "SSHTransport" as ParamikoTransport (by default), or as
OpenSSHTransport (if Paramiko is not importable, or the
PYTESTMULTIHOST_SSH_TRANSPORT environment variable is set to "openssh").
"""
import os
import socket
import threading
import subprocess
from contextlib import contextmanager
import errno
import logging
import io
import sys
from pytest_multihost import util
# Paramiko is an optional dependency: remember whether it could be imported
# so the default transport can be chosen at the bottom of this module.
try:
    import paramiko
except ImportError:
    have_paramiko = False
else:
    have_paramiko = True

# Unique sentinel meaning "argument not given"; Command.wait() compares its
# ``raiseonerr`` argument against it by identity.
DEFAULT = object()
class Transport(object):
    """Mechanism for communicating with remote hosts

    The Transport can manipulate files on a remote host, and open a Command.

    The base class defines an interface that specific subclasses implement.
    Methods that raise NotImplementedError here must be provided by a
    subclass; ``mkdir_recursive``, ``get_file`` and ``put_file`` are built on
    top of those primitives.
    """
    def __init__(self, host):
        """Remember ``host`` and set up a logger named after the host

        :param host: object providing ``logger_name`` and
            ``config.get_logger(name)``
        """
        self.host = host
        self.logger_name = '%s.%s' % (host.logger_name, type(self).__name__)
        self.log = host.config.get_logger(self.logger_name)
        # Counter used to give every started command its own logger name
        self._command_index = 0

    def get_file_contents(self, filename, encoding=None):
        """Read the named remote file and return the contents

        The string will be decoded using the given encoding;
        if encoding is None (default), it will be returned as a bytestring.
        """
        raise NotImplementedError('Transport.get_file_contents')

    def put_file_contents(self, filename, contents, encoding='utf-8'):
        """Write the given string (or bytestring) to the named remote file

        The contents string will be encoded using the given encoding
        (default: ``'utf-8'``), unless already a bytestring.
        """
        raise NotImplementedError('Transport.put_file_contents')

    def file_exists(self, filename):
        """Return true if the named remote file exists"""
        raise NotImplementedError('Transport.file_exists')

    def mkdir(self, path):
        """Make the named directory"""
        raise NotImplementedError('Transport.mkdir')

    def start_shell(self, argv, log_stdout=True, encoding=None):
        """Start a Shell

        :param argv: The command this shell is intended to run (used for
                     logging only)
        :param log_stdout: If false, the stdout will not be logged (useful when
                           binary output is expected)
        :param encoding: Encoding for the resulting Command's ``stdout_text``
                         and ``stderr_text``.

        Given a `shell` from this method, the caller can then use
        ``shell.stdin.write()`` to input any command(s), call ``shell.wait()``
        to let the command run, and then inspect ``returncode``,
        ``stdout_text`` or ``stderr_text``.

        Note that ``shell.stdin`` uses bytes I/O.
        """
        raise NotImplementedError('Transport.start_shell')

    def mkdir_recursive(self, path):
        """`mkdir -p` on the remote host

        Creates every missing ancestor of ``path`` (outermost first) and then
        ``path`` itself. Does nothing if ``path`` already exists.
        """
        if self.file_exists(path):
            return
        parent_path = os.path.dirname(path)
        # dirname() of a bare relative name ('x') is ''; that stands for the
        # current directory and must not be passed on to mkdir().
        # dirname('/') is '/' itself, which ends the recursion for absolute
        # paths.
        if parent_path and parent_path != path:
            self.mkdir_recursive(parent_path)
        self.mkdir(path)

    def get_file(self, remotepath, localpath):
        """Copy a file from the remote host to a local file"""
        contents = self.get_file_contents(remotepath, encoding=None)
        with open(localpath, 'wb') as local_file:
            local_file.write(contents)

    def put_file(self, localpath, remotepath):
        """Copy a local file to the remote host"""
        with open(localpath, 'rb') as local_file:
            contents = local_file.read()
        self.put_file_contents(remotepath, contents, encoding=None)

    def get_next_command_logger_name(self):
        """Return a fresh logger name of the form ``<host logger>.cmd<N>``

        ``N`` is incremented on every call, starting at 1.
        """
        self._command_index += 1
        return '%s.cmd%s' % (self.host.logger_name, self._command_index)

    def rmdir(self, path):
        """Remove directory"""
        raise NotImplementedError('Transport.rmdir')

    def rename_file(self, oldpath, newpath):
        """Rename file"""
        raise NotImplementedError('Transport.rename_file')

    def remove_file(self, filepath):
        """Removes files"""
        raise NotImplementedError('Transport.remove_file')
class _decoded_output_property(object):
"""Descriptor for on-demand decoding of a Command's output stream
"""
def __init__(self, name):
self.name = name
def __set_name__(self, cls, name):
# Sanity check (called only on Python 3.6+).
# This property expects to handle attributes named '<foo>_text'.
assert name == self.name + '_text'
def __get__(self, instance, cls=None):
if instance is None:
return self
else:
bytestring = getattr(instance, self.name + '_bytes')
decoded = bytestring.decode(instance.encoding)
setattr(instance, self.name + '_text', decoded)
return decoded
class Command(object):
    """A Popen-style object representing a remote command

    Instances should only be created through a method of a concrete
    Transport, such as ``start_shell``.

    Standard output and standard error are handled by this class: they are
    not available for file-like reading, and are logged by default. To make
    sure reading doesn't stall after one buffer fills up, they are read in
    parallel using threads.

    After ``wait()`` has been called, the ``stdout_bytes`` and
    ``stderr_bytes`` attributes hold the output as bytestrings and
    ``returncode`` holds the exit code.

    ``stdout_text`` and ``stderr_text`` are the corresponding outputs decoded
    with the given ``encoding`` (default: ``'utf-8'``). They are decoded
    on demand; do not access them if a command produces binary output.

    A Command may be used as a context manager (in the ``with`` statement).
    Leaving the context calls ``wait()``, which raises an exception if the
    exit code is not 0 unless the ``raiseonerr`` attribute was set to false
    before the context was left.
    """
    stdout_text = _decoded_output_property('stdout')
    stderr_text = _decoded_output_property('stderr')

    def __init__(self, argv, logger_name=None, log_stdout=True,
                 get_logger=None, encoding='utf-8'):
        """Record the command line and set up logging

        :param argv: the command being run; stored and reported by ``wait()``
        :param logger_name: logger name; defaults to ``<module>.<class>``
        :param log_stdout: accepted for interface compatibility; not used
            by this base class
        :param get_logger: callable mapping a name to a logger
            (default: ``logging.getLogger``)
        :param encoding: encoding used by ``stdout_text``/``stderr_text``
        """
        self.argv = argv
        self.returncode = None
        self._done = False
        self.logger_name = logger_name or (
            '%s.%s' % (self.__module__, type(self).__name__))
        logger_factory = logging.getLogger if get_logger is None else get_logger
        self.get_logger = logger_factory
        self.log = logger_factory(self.logger_name)
        self.encoding = encoding
        self.raiseonerr = True

    def wait(self, raiseonerr=DEFAULT):
        """Wait for the remote process to exit and return its exit code

        Raises ``subprocess.CalledProcessError`` if the exit code is not 0,
        unless ``raiseonerr`` is false.

        When ``raiseonerr`` is not given, the ``raiseonerr`` attribute is
        used instead.

        Once the process has been waited for, further calls simply return
        the stored exit code.
        """
        if raiseonerr is DEFAULT:
            raiseonerr = self.raiseonerr

        if self._done:
            return self.returncode

        self._end_process()
        self._done = True

        failed = raiseonerr and self.returncode
        if not failed:
            self.log.debug('Exit code: %s', self.returncode)
            return self.returncode

        self.log.error('Exit code: %s', self.returncode)
        raise subprocess.CalledProcessError(self.returncode, self.argv)

    def _end_process(self):
        """Wait until the process exits and output is received, close channel

        Called from wait(); must be implemented by subclasses.
        """
        raise NotImplementedError()

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # wait() runs whether or not the body of the ``with`` block raised.
        self.wait(raiseonerr=self.raiseonerr)
class ParamikoTransport(Transport):
    """Transport that uses the Paramiko SSH2 library"""
    def __init__(self, host):
        """Connect to ``host`` and authenticate

        A private key (``host.ssh_key_filename``) is preferred over a
        password (``host.ssh_password``).

        :raises RuntimeError: if neither a key file nor a password is
            configured

        If connecting or authenticating fails for any reason, the
        connection opened so far is closed before the error propagates.
        """
        super(ParamikoTransport, self).__init__(host)
        sock = socket.create_connection((host.external_hostname,
                                         host.ssh_port))
        try:
            self._transport = transport = paramiko.Transport(sock)
        except Exception:
            # No Transport owns the socket yet; close it ourselves.
            sock.close()
            raise
        try:
            transport.connect(hostkey=host.host_key)
            if host.ssh_key_filename:
                filename = os.path.expanduser(host.ssh_key_filename)
                key = paramiko.RSAKey.from_private_key_file(filename)
                self.log.debug(
                    'Authenticating with private RSA key using user %s',
                    host.ssh_username)
                transport.auth_publickey(username=host.ssh_username, key=key)
            elif host.ssh_password:
                self.log.debug('Authenticating with password using user %s',
                               host.ssh_username)
                transport.auth_password(username=host.ssh_username,
                                        password=host.ssh_password)
            else:
                self.log.critical('No SSH credentials configured')
                raise RuntimeError('No SSH credentials configured')
        except Exception:
            # Don't leave a half-set-up connection behind.
            transport.close()
            raise

    @contextmanager
    def sftp_open(self, filename, mode='r'):
        """Context manager that provides a file-like object over a SFTP channel

        This provides compatibility with older Paramiko versions.
        (In Paramiko 1.10+, file objects from `sftp.open` are directly usable
        as context managers).
        """
        file = self.sftp.open(filename, mode)
        try:
            yield file
        finally:
            file.close()

    @property
    def sftp(self):
        """Paramiko SFTPClient connected to this host

        Created on first access and reused afterwards.
        """
        try:
            return self._sftp
        except AttributeError:
            transport = self._transport
            self._sftp = paramiko.SFTPClient.from_transport(transport)
            return self._sftp

    def get_file_contents(self, filename, encoding=None):
        """Read the named remote file and return the contents

        The contents are decoded with ``encoding`` if one is given;
        otherwise the result of the SFTP read is returned unchanged.
        """
        self.log.debug('READ %s', filename)
        with self.sftp_open(filename, 'rb') as f:
            result = f.read()
        if encoding:
            result = result.decode(encoding)
        return result

    def put_file_contents(self, filename, contents, encoding=None):
        """Write the given string (or bytestring) to the named remote file

        If ``encoding`` is given and ``contents`` is not a bytestring, it is
        encoded first; otherwise ``contents`` is handed to the SFTP file
        object as is.
        """
        self.log.info('WRITE %s', filename)
        if encoding and not isinstance(contents, bytes):
            contents = contents.encode(encoding)
        with self.sftp_open(filename, 'wb') as f:
            f.write(contents)

    def file_exists(self, filename):
        """Return true if the named remote file exists

        Only a "no such file" error (``ENOENT``) from ``stat`` counts as
        "does not exist"; any other IOError is re-raised.
        """
        self.log.debug('STAT %s', filename)
        try:
            self.sftp.stat(filename)
        except IOError as e:
            if e.errno == errno.ENOENT:
                return False
            else:
                raise
        return True

    def mkdir(self, path):
        """Make the named directory"""
        self.log.info('MKDIR %s', path)
        self.sftp.mkdir(path)

    def start_shell(self, argv, log_stdout=True, encoding='utf-8'):
        """Open a new session channel and return an SSHCommand wrapping it

        ``argv`` is used for logging only; see ``Transport.start_shell``.
        """
        logger_name = self.get_next_command_logger_name()
        ssh = self._transport.open_channel('session')
        self.log.info('RUN %s', argv)
        return SSHCommand(ssh, argv, logger_name=logger_name,
                          log_stdout=log_stdout,
                          get_logger=self.host.config.get_logger,
                          encoding=encoding)

    def get_file(self, remotepath, localpath):
        """Copy a file from the remote host to a local file"""
        self.log.debug('GET %s', remotepath)
        self.sftp.get(remotepath, localpath)

    def put_file(self, localpath, remotepath):
        """Copy a local file to the remote host"""
        self.log.info('PUT %s', remotepath)
        self.sftp.put(localpath, remotepath)

    def rmdir(self, path):
        """Remove directory"""
        self.log.info('RMDIR %s', path)
        self.sftp.rmdir(path)

    def remove_file(self, filepath):
        """Remove a file"""
        self.log.info('REMOVE FILE %s', filepath)
        self.sftp.remove(filepath)

    def rename_file(self, oldpath, newpath):
        """Rename file"""
        self.log.info('RENAME %s to %s', oldpath, newpath)
        self.sftp.rename(oldpath, newpath)
class OpenSSHTransport(Transport):
    """Transport that uses the `ssh` binary"""
    def __init__(self, host):
        super(OpenSSHTransport, self).__init__(host)
        # Temporary directory holding the control socket, the known_hosts
        # file and (when password authentication is used) the password file
        self.control_dir = util.TempDir()

        self.ssh_argv = self._get_ssh_argv()

        # Run a "control master" process. This serves two purposes:
        # - Establishes a control socket; other SSHs will connect to it
        #   and reuse the same connection. This way the slow handshake
        #   only needs to be done once
        # - Writes the host to known_hosts so stderr of "real" connections
        #   doesn't contain the "unknown host" warning
        # Popen closes the stdin pipe when it's garbage-collected, so
        # this process will exit when it's no longer needed
        command = ['-o', 'ControlMaster=yes', '/usr/bin/cat']
        self.control_master = self._run(command, collect_output=False)

    def _get_ssh_argv(self):
        """Return the path to SSH and options needed for every call

        :raises RuntimeError: if the host has neither a key file nor a
            password configured
        """
        control_file = os.path.join(self.control_dir.path, 'control')
        known_hosts_file = os.path.join(self.control_dir.path, 'known_hosts')

        argv = ['ssh',
                '-l', self.host.ssh_username,
                '-o', 'ControlPath=%s' % control_file,
                '-o', 'StrictHostKeyChecking=no',
                '-o', 'UserKnownHostsFile=%s' % known_hosts_file]

        if self.host.ssh_key_filename:
            key_filename = os.path.expanduser(self.host.ssh_key_filename)
            argv.extend(['-i', key_filename])
            # disable password prompt
            argv.extend(['-o', 'BatchMode=yes'])
        elif self.host.ssh_password:
            password_file = os.path.join(self.control_dir.path, 'password')
            # Create the file with owner-only permissions from the start, so
            # the password is never stored in a file with wider permissions.
            fd = os.open(password_file,
                         os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            with os.fdopen(fd, 'w') as f:
                # Also covers the case of a pre-existing file, whose mode
                # os.open() leaves untouched.
                os.fchmod(f.fileno(), 0o600)
                f.write(self.host.ssh_password)
                f.write('\n')
            argv = ['sshpass', '-f', password_file] + argv
        else:
            self.log.critical('No SSH credentials configured')
            raise RuntimeError('No SSH credentials configured')

        argv.append(self.host.external_hostname)
        self.log.debug('SSH invocation: %s', argv)

        return argv

    def start_shell(self, argv, log_stdout=True, encoding='utf-8'):
        """Start ``bash`` on the remote host and return its Command

        ``argv`` is used for logging only; see ``Transport.start_shell``.
        """
        self.log.info('RUN %s', argv)
        command = self._run(['bash'], argv=argv, log_stdout=log_stdout,
                            encoding=encoding)
        return command

    def _run(self, command, log_stdout=True, argv=None, collect_output=True,
             encoding='utf-8'):
        """Run the given command on the remote host

        :param command: Command to run (appended to the common SSH invocation)
        :param log_stdout: If false, stdout will not be logged
        :param argv: Command to log (if different from ``command``)
        :param collect_output: If false, no output will be collected
        :param encoding: Encoding passed on to the resulting SSHCommand
        :return: the started SSHCommand
        """
        if argv is None:
            argv = command
        logger_name = self.get_next_command_logger_name()
        ssh = SSHCallWrapper(self.ssh_argv + list(command))
        return SSHCommand(ssh, argv, logger_name, log_stdout=log_stdout,
                          collect_output=collect_output,
                          get_logger=self.host.config.get_logger,
                          encoding=encoding)

    def file_exists(self, path):
        """Return true if remote ``ls <path>`` exits with status 0"""
        self.log.info('STAT %s', path)
        cmd = self._run(['ls', path], log_stdout=False)
        cmd.wait(raiseonerr=False)
        return cmd.returncode == 0

    def mkdir(self, path):
        """Make the named directory using remote ``mkdir``"""
        self.log.info('MKDIR %s', path)
        cmd = self._run(['mkdir', path])
        cmd.wait()

    def put_file_contents(self, filename, contents, encoding='utf-8'):
        """Write the given string (or bytestring) to the named remote file

        The data is piped into remote ``tee``, whose echoed output is
        compared against what was sent.
        """
        self.log.info('PUT %s', filename)
        if encoding and not isinstance(contents, bytes):
            contents = contents.encode(encoding)
        cmd = self._run(['tee', filename], log_stdout=False)
        cmd.stdin.write(contents)
        cmd.wait()
        # NOTE(review): this check disappears when Python runs with -O.
        assert cmd.stdout_bytes == contents

    def get_file_contents(self, filename, encoding=None):
        """Read the named remote file using remote ``cat``

        :raises IOError: if ``cat`` exits with a non-zero status
        """
        self.log.info('GET %s', filename)
        cmd = self._run(['cat', filename], log_stdout=False)
        cmd.wait(raiseonerr=False)
        if cmd.returncode == 0:
            result = cmd.stdout_bytes
            if encoding:
                result = result.decode(encoding)
            return result
        else:
            raise IOError('File %r could not be read' % filename)

    def rmdir(self, path):
        """Remove directory using remote ``rmdir``"""
        self.log.info('RMDIR %s', path)
        cmd = self._run(['rmdir', path])
        cmd.wait()

    def remove_file(self, filepath):
        """Remove a file using remote ``rm``"""
        self.log.info('REMOVE FILE %s', filepath)
        cmd = self._run(['rm', filepath])
        # NOTE(review): wait() uses the default raiseonerr here, so a
        # non-zero exit already raises CalledProcessError and the IOError
        # below looks effectively unreachable - confirm which is intended.
        cmd.wait()
        if cmd.returncode != 0:
            raise IOError('File %r could not be deleted' % filepath)

    def rename_file(self, oldpath, newpath):
        """Rename file using remote ``mv``"""
        self.log.info('RENAME %s TO %s', oldpath, newpath)
        cmd = self._run(['mv', oldpath, newpath])
        # NOTE(review): as in remove_file(), wait() raises on a non-zero
        # exit before the check below can run - confirm which is intended.
        cmd.wait()
        if cmd.returncode != 0:
            raise IOError('File %r could not be renamed to %r '
                          % (oldpath, newpath))
class SSHCallWrapper(object):
    """Adapts a /usr/bin/ssh call to the paramiko.Channel interface

    Only the methods SSHCommand needs are provided.

    ``command`` starts out as the argument list to execute; once
    ``invoke_shell()`` has run it holds the resulting ``subprocess.Popen``
    object instead.
    """
    def __init__(self, command):
        self.command = command

    def invoke_shell(self):
        """Start the process with pipes for stdin, stdout and stderr"""
        argv = self.command
        pipe = subprocess.PIPE
        self.command = subprocess.Popen(argv, stdin=pipe, stdout=pipe,
                                        stderr=pipe)

    def makefile(self, mode):
        """Return the process's stdin (``'wb'``) or stdout (``'rb'``) pipe"""
        process = self.command
        streams = dict(wb=process.stdin, rb=process.stdout)
        return streams[mode]

    def makefile_stderr(self, mode):
        """Return the process's stderr pipe; ``mode`` must be ``'rb'``"""
        assert mode == 'rb'
        return self.command.stderr

    def recv_exit_status(self):
        """Wait for the process to finish and return its exit status"""
        return self.command.wait()

    def close(self):
        """Wait for the process to finish and return its exit status"""
        return self.command.wait()
class SSHCommand(Command):
    """Command implementation for ParamikoTransport and OpenSSHTransport"""
    def __init__(self, ssh, argv, logger_name, log_stdout=True,
                 collect_output=True, encoding='utf-8', get_logger=None):
        """Start the command on ``ssh`` and begin collecting its output

        :param ssh: channel-like object providing ``invoke_shell()``,
            ``makefile()``, ``makefile_stderr()``, ``recv_exit_status()``
            and ``close()``
        :param argv: the command, used for logging and error reporting
        :param logger_name: name of the logger output lines are sent to
        :param log_stdout: if false, stdout lines are collected but not
            logged (stderr is always logged)
        :param collect_output: if false, no reader threads are started and
            both outputs end up empty
        :param encoding: encoding for ``stdout_text`` / ``stderr_text``
        :param get_logger: callable mapping a name to a logger
        """
        super(SSHCommand, self).__init__(argv, logger_name,
                                         log_stdout=log_stdout,
                                         get_logger=get_logger,
                                         encoding=encoding)
        self._stdout_lines = []
        self._stderr_lines = []
        self.running_threads = set()

        self._ssh = ssh

        self.log.debug('RUN %s', argv)

        self._ssh.invoke_shell()

        self._use_bytes = (encoding is None)

        # stdin is exposed to the caller and uses bytes I/O
        self.stdin = self._ssh.makefile('wb')
        stdout = self._ssh.makefile('rb')
        stderr = self._ssh.makefile_stderr('rb')

        if collect_output:
            self._start_pipe_thread(self._stdout_lines, stdout, 'out',
                                    log_stdout)
            self._start_pipe_thread(self._stderr_lines, stderr, 'err', True)

    def _end_process(self):
        """Close stdin, gather all output, fetch the exit status, close

        Sets ``stdout_bytes``, ``stderr_bytes`` and ``returncode``.
        Called from wait().
        """
        self.stdin.close()

        # The reader threads finish once their stream is exhausted
        while self.running_threads:
            self.running_threads.pop().join()

        self.stdout_bytes = b''.join(self._stdout_lines)
        self.stderr_bytes = b''.join(self._stderr_lines)

        self.returncode = self._ssh.recv_exit_status()
        self._ssh.close()

    def _start_pipe_thread(self, result_list, stream, name, do_log=True):
        """Start a thread that copies lines from ``stream`` to ``result_list``

        If do_log is true, each line is also logged at debug level, decoded
        as UTF-8 with undecodable bytes replaced.

        The thread is added to ``self.running_threads`` and returned.
        """
        log = self.get_logger(self.logger_name)

        def read_stream():
            for line in stream:
                if do_log:
                    log.debug(line.rstrip(b'\n').decode('utf-8',
                                                        errors='replace'))
                result_list.append(line)

        thread = threading.Thread(target=read_stream)
        self.running_threads.add(thread)
        thread.start()
        return thread
# Default transport: Paramiko when it is importable and OpenSSH has not been
# requested through the PYTESTMULTIHOST_SSH_TRANSPORT environment variable;
# the `ssh` binary otherwise.
SSHTransport = (
    ParamikoTransport
    if have_paramiko
    and os.environ.get('PYTESTMULTIHOST_SSH_TRANSPORT') != 'openssh'
    else OpenSSHTransport
)