1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277
|
# Copyright (C) 2008 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common SSH management logic."""
import functools
import multiprocessing
import os
import re
import signal
import subprocess
import sys
import tempfile
import time
import platform_utils
from repo_trace import Trace
# Path to the 'git_ssh' file located in the same directory as this module.
PROXY_PATH = os.path.join(os.path.dirname(__file__), 'git_ssh')
def _run_ssh_version():
    """Invoke `ssh -V` and return whatever it printed as text.

    stderr is folded into stdout so the banner is captured regardless of
    which stream ssh writes it to.

    Raises:
        subprocess.CalledProcessError: ssh exited with a non-zero status.
    """
    completed = subprocess.run(
        ['ssh', '-V'],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        check=True,
    )
    banner = completed.stdout
    return banner.decode()
def _parse_ssh_version(ver_str=None):
"""parse a ssh version string into a tuple"""
if ver_str is None:
ver_str = _run_ssh_version()
m = re.match(r'^OpenSSH_([0-9.]+)(p[0-9]+)?\s', ver_str)
if m:
return tuple(int(x) for x in m.group(1).split('.'))
else:
return ()
@functools.lru_cache(maxsize=None)
def version():
    """Return the ssh version as a tuple of ints.

    The result is cached, so ssh is only queried once per process.

    Returns:
        The tuple produced by _parse_ssh_version(); empty when the banner is
        not recognized.

    Raises:
        SystemExit: ssh could not be run or exited with an error.
    """
    try:
        return _parse_ssh_version()
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: ssh ran but exited non-zero.
        # OSError: ssh could not be executed at all (e.g. it is not installed).
        print('fatal: unable to detect ssh version', file=sys.stderr)
        sys.exit(1)
# Matches the leading "[user@]host:" of a scp-style address; group 1 is the
# "[user@]host" part before the first colon.
URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):')
# Matches the leading "scheme://authority/" of a URL; group 1 is the scheme and
# group 2 is everything between "://" and the next "/" (user, host and port).
URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')
class ProxyManager:
    """Manage various ssh clients & masters that we spawn.

    This will take care of sharing state between multiprocessing children, and
    make sure that if we crash, we don't leak any of the ssh sessions.

    The code should work with a single-process scenario too, and not add too
    much overhead due to the manager.
    """

    # Path to the ssh program to run which will pass our master settings along.
    # Set here more as a convenience API.
    proxy = PROXY_PATH

    def __init__(self, manager):
        """Create the shared bookkeeping structures.

        Args:
            manager: Object providing list(), dict() and Value() factories for
                shared state (presumably a multiprocessing manager -- confirm
                against callers).
        """
        # Protect access to the list of active masters.
        self._lock = multiprocessing.Lock()
        # List of active masters (pid).  These will be spawned on demand, and
        # we are responsible for shutting them all down at the end.
        self._masters = manager.list()
        # Set of active masters indexed by "host:port" information.
        # The value isn't used, but multiprocessing doesn't provide a set
        # class.
        self._master_keys = manager.dict()
        # Whether ssh masters are known to be broken, so we give up entirely.
        self._master_broken = manager.Value('b', False)
        # List of active ssh sessions.  Clients will be added & removed as
        # connections finish, so this list is just for safety & cleanup if we
        # crash.
        self._clients = manager.list()
        # Path to directory for holding master sockets.
        self._sock_path = None

    def __enter__(self):
        """Enter a new context."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Exit a context & clean up all resources."""
        self.close()

    def add_client(self, proc):
        """Track a new ssh session (only the pid of |proc| is recorded)."""
        self._clients.append(proc.pid)

    def remove_client(self, proc):
        """Remove a completed ssh session; unknown pids are ignored."""
        try:
            self._clients.remove(proc.pid)
        except ValueError:
            pass

    def add_master(self, proc):
        """Track a new master connection (only the pid is recorded)."""
        self._masters.append(proc.pid)

    def _terminate(self, procs):
        """Kill all |procs| (a list of pids) and empty the list."""
        for pid in procs:
            try:
                os.kill(pid, signal.SIGTERM)
                os.waitpid(pid, 0)
            except OSError:
                # Already gone, or not our child; nothing more we can do.
                pass

        # The multiprocessing.list() API doesn't provide many standard list()
        # methods, so we have to manually clear the list.
        while True:
            try:
                procs.pop(0)
            except Exception:
                # Normally IndexError once the list is empty.  Other errors
                # are tolerated too since this is best-effort cleanup, but
                # KeyboardInterrupt/SystemExit are allowed to propagate.
                break

    def close(self):
        """Close this active ssh session.

        Kill all ssh clients & masters we created, and nuke the socket dir.
        """
        self._terminate(self._clients)
        self._terminate(self._masters)

        # Only clean up the socket dir if one was ever created.
        d = self.sock(create=False)
        if d:
            try:
                platform_utils.rmdir(os.path.dirname(d))
            except OSError:
                pass

    def _open_unlocked(self, host, port=None):
        """Make sure a ssh master session exists for |host| & |port|.

        If one doesn't exist already, we'll create it.

        We won't grab any locks, so the caller has to do that.  This helps keep
        the business logic of actually creating the master separate from
        grabbing locks.

        Returns:
            True if a master is believed to be running, False otherwise.
        """
        # Check to see whether we already think that the master is running; if
        # we think it's already running, return right away.
        if port is not None:
            key = '%s:%s' % (host, port)
        else:
            key = host

        if key in self._master_keys:
            return True

        if self._master_broken.value or 'GIT_SSH' in os.environ:
            # Failed earlier, so don't retry.
            return False

        # We will make two calls to ssh; this is the common part of both calls.
        command_base = ['ssh', '-o', 'ControlPath %s' % self.sock(), host]
        if port is not None:
            command_base[1:1] = ['-p', str(port)]

        # Since the key wasn't in _master_keys, we think that master isn't
        # running... but before actually starting a master, we'll double-check.
        # This can be important because we can't tell that 'git@myhost.com' is
        # the same as 'myhost.com' where "User git" is setup in the user's
        # ~/.ssh/config file.
        check_command = command_base + ['-O', 'check']
        try:
            Trace(': %s', ' '.join(check_command))
            check_process = subprocess.Popen(check_command,
                                             stdout=subprocess.PIPE,
                                             stderr=subprocess.PIPE)
            check_process.communicate()  # Read output, but ignore it...
            isnt_running = check_process.wait()

            if not isnt_running:
                # Our double-check found that the master _was_ in fact
                # running.  Add to the list of keys.
                self._master_keys[key] = True
                return True
        except Exception:
            # Ignore exceptions.  We will fall back to the normal command and
            # print to the log there.
            pass

        # Insert the master flags right after the 'ssh' program name.
        command = command_base[:1] + ['-M', '-N'] + command_base[1:]
        try:
            Trace(': %s', ' '.join(command))
            p = subprocess.Popen(command)
        except Exception as e:
            self._master_broken.value = True
            print('\nwarn: cannot enable ssh control master for %s:%s\n%s'
                  % (host, port, str(e)), file=sys.stderr)
            return False

        # Give ssh a moment to start; if it has already exited by then, the
        # master could not be established.
        time.sleep(1)
        ssh_died = (p.poll() is not None)
        if ssh_died:
            return False

        self.add_master(p)
        self._master_keys[key] = True
        return True

    def _open(self, host, port=None):
        """Make sure a ssh master session exists for |host| & |port|.

        If one doesn't exist already, we'll create it.

        This will obtain any necessary locks to avoid inter-process races.

        Returns:
            True if a master is believed to be running, False otherwise.
        """
        # Bail before grabbing the lock if we already know that we aren't going
        # to try creating new masters below.
        if sys.platform in ('win32', 'cygwin'):
            return False

        # Acquire the lock.  This is needed to prevent opening multiple masters
        # for the same host when we're running "repo sync -jN" (for N > 1)
        # _and_ the manifest <remote fetch="ssh://xyz"> specifies a different
        # host from the one that was passed to repo init.
        with self._lock:
            return self._open_unlocked(host, port)

    @staticmethod
    def _split_host_port(authority):
        """Split a URL authority into ("[user@]host", port-or-None).

        Only a colon after the last '@' is treated as the port separator, so
        colons in the user info do not confuse the split, and a colon inside
        a bracketed IPv6 literal without a port ("[::1]") is left alone.
        """
        userinfo, at, hostport = authority.rpartition('@')
        name, sep, port = hostport.rpartition(':')
        if not sep or port.endswith(']'):
            return authority, None
        return userinfo + at + name, port

    def preconnect(self, url):
        """If |url| will create a ssh connection, setup the ssh master for it.

        Args:
            url: Either a "scheme://[user@]host[:port]/..." URL or a scp-style
                "[user@]host:path" address.

        Returns:
            True if a ssh master is believed to be running for the host,
            False otherwise (including when |url| does not use ssh).
        """
        m = URI_ALL.match(url)
        if m:
            # Check the scheme before looking at the authority so that
            # non-ssh URLs never reach the host/port parsing.
            scheme = m.group(1)
            if scheme not in ('ssh', 'git+ssh', 'ssh+git'):
                return False
            host, port = self._split_host_port(m.group(2))
            return self._open(host, port)

        m = URI_SCP.match(url)
        if m:
            host = m.group(1)
            return self._open(host)

        return False

    def sock(self, create=True):
        """Return the path to the ssh socket dir.

        This has all the master sockets so clients can talk to them.

        Args:
            create: Whether to create the socket directory if it does not
                exist yet.

        Returns:
            The ControlPath pattern inside the socket directory, or None when
            no directory exists and |create| is False.
        """
        if self._sock_path is None:
            if not create:
                return None
            tmp_dir = '/tmp'
            if not os.path.exists(tmp_dir):
                tmp_dir = tempfile.gettempdir()
            if version() < (6, 7):
                tokens = '%r@%h:%p'
            else:
                tokens = '%C'  # hash of %l%h%p%r
            self._sock_path = os.path.join(
                tempfile.mkdtemp('', 'ssh-', tmp_dir),
                'master-' + tokens)
        return self._sock_path
|