1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457
|
#!/usr/bin/python
from __future__ import with_statement
import threading
import random
import Queue
import traceback
import socket
import select
import sys
import time
import paramiko
SSH_PORT = 22
REMOTE_PORT = 3306
# timeout for closing an unused tunnel
TUNNEL_TIMEOUT = 60
# paramiko 1.6 didn't have this class
if hasattr(paramiko, "WarningPolicy"):
WarningPolicy = paramiko.WarningPolicy
else:
class WarningPolicy(paramiko.MissingHostKeyPolicy):
def missing_host_key(self, client, hostname, key):
print 'WARNING: Unknown %s host key for %s: %s' % (key.get_name(), hostname, hexlify(key.get_fingerprint()))
class Tunnel(threading.Thread):
"""This class is a threaded implementation of an SSH tunnel.
You should not access the attributes that starts with an underscore outside this thread
of execution (e.g. self._server) for this could run into race conditions. Even when
accessing its public attributes (those that don't start with an underscore) you should
be careful of acquiring the self.lock reentrant lock (and releasing it once done):
with tunnel.lock:
if tunnel.connecting:
... whatever...
"""
def __init__(self, q, server, username, target, password, keyfile):
super(Tunnel, self).__init__()
self.daemon = True
self._server = server
self._username = username
self._target = target
self._password = password
self._keyfile = keyfile
# Acquire and release this lock while accessing the attributes of objects
# of this class from the main thread:
self.lock = threading.RLock()
# This event marks when a random port is selected to be used:
self.port_is_set = threading.Event()
self.local_port = None
self._listen_sock = None
self.q = q
self._shutdown = False
self.connecting = False
self._client = None
self._connections = []
def is_connecting(self):
with self.lock:
return self.connecting
def run(self):
sys.stdout.write('Thread started\n') # sys.stdout.write is thread safe while print isn't
# Create a socket and pick a random port number for it:
self._listen_sock = socket.socket()
while True:
local_port = random.randint(1024, 65535)
try:
self._listen_sock.bind(('127.0.0.1', local_port))
self._listen_sock.listen(2)
with self.lock:
self.local_port = local_port
break
except socket.error, exc:
sys.stdout.write('Socket error: %s for port %d\n' % (exc, local_port) )
err, msg = exc.args
if err == 22:
continue # retry
self.notify('ERROR', sys.exc_info())
raise exc
finally:
self.port_is_set.set()
with self.lock:
self.connecting = True
if self._keyfile:
self.notify('INFO', 'Connecting to SSH server at %s:%s using key %s...' % (self._server[0], self._server[1], self._keyfile) )
else:
self.notify('INFO', 'Connecting to SSH server at %s:%s...' % (self._server[0], self._server[1]) )
if not self._connect_ssh():
self._listen_sock.close()
self._shutdown = True
else:
self.notify('INFO', 'Connection opened')
with self.lock:
self.connecting = False
del self._password
while not self._shutdown:
try:
socks = [self._listen_sock]
for sock, chan in self._connections:
socks.append(sock)
socks.append(chan)
r, w, x = select.select(socks, [], [], TUNNEL_TIMEOUT)
except Exception, e:
if not self._shutdown:
self.notify('ERROR', 'Error while forwarding data: %r' % e)
break
if not r and len(socks) <= 1:
self.notify('INFO', 'Closing tunnel to %s:%s for inactivity...\n' % (self._server[0], self._server[1]) )
break
if self._listen_sock in r:
self.notify('INFO', 'New client connection\n')
self.accept_client()
closed = []
for sock, chan in self._connections:
if sock in r:
data = sock.recv(1024)
if not data:
closed.append((sock, chan))
else:
chan.send(data)
if chan in r:
data = chan.recv(1024)
if not data:
closed.append((sock, chan))
else:
sock.send(data)
for item in set(closed): # FIXME: Why a set is needed here?
sock, chan = item
try:
sock.close()
except:
pass
try:
chan.close()
except:
pass
self.notify('INFO', 'Client for %s disconnected\n' % local_port)
self._connections.remove(item)
# Time to shutdown:
for sock, chan in self._connections:
try:
sock.close()
except:
pass
try:
chan.close()
except:
pass
self._listen_sock.close()
self._client.close()
def notify(self, msg_type, msg_object):
self.q.put((msg_type, msg_object))
def match(self, server, username, target):
with self.lock:
return self._server == server and self._username == username and self._target == target
def _connect_ssh(self):
"""Create the SSH client and set up the connection.
Any exception coming from paramiko will be notified as an error
that would cause the failure of the connection. Some of these are:
paramiko.AuthenticationException --- raised when authentication failed for some reason
paramiko.PasswordRequiredException --- raised when a password is needed to unlock a private key file;
this is a subclass of paramiko.AuthenticationException
paramiko.PasswordRequiredException --- raised when a password is needed to unlock a private key file;
also a subclass of paramiko.AuthenticationException
"""
try:
self._client = paramiko.SSHClient()
self._client.load_system_host_keys()
self._client.set_missing_host_key_policy(WarningPolicy())
self._client.connect(self._server[0], self._server[1], username=self._username, key_filename=self._keyfile, password=self._password)
except:
self.notify('ERROR', sys.exc_info())
return False
else:
return True
def close(self):
self.notify('INFO', 'Closing tunnel')
self._listen_sock.close()
self._shutdown = True
def accept_client(self):
try:
local_sock, peeraddr = self._listen_sock.accept()
except Exception, e:
self.notify('ERROR', 'Error accepting new tunnel client: %r' % e)
return
self.notify('INFO', 'Client connection established')
transport = self._client.get_transport()
try:
sshchan = transport.open_channel('direct-tcpip', self._target, local_sock.getpeername())
except Exception, e:
self.notify('ERROR', 'Remote connection to %s:%d failed: %r' % (self._target[0], self._target[1], e) )
local_sock.close()
return
if sshchan is None:
self.notify('ERROR', 'Remote connection to %s:%d was rejected by the SSH server.' % (self._target[0], self._target[1]))
local_sock.close()
return
self.notify('INFO', 'Tunnel now open %r -> %r -> %r' % (local_sock.getpeername(), sshchan.getpeername(), self._target))
self._connections.append((local_sock, sshchan))
class TunnelManager:
def __init__(self):
self.tunnel_by_port = {}
self.inpipe = sys.stdin
self.outpipe = sys.stdout
def _address_port_tuple(self, raw_address, default_port):
if type(raw_address) is str:
if ':' in raw_address:
address, port = raw_address.split(':', 1)
try:
port = int(port)
except:
port = default_port
return (address, port)
else:
return (raw_address, default_port)
else:
return raw_address
def lookup_tunnel(self, server, username, target):
server = self._address_port_tuple(server, default_port=SSH_PORT)
target = self._address_port_tuple(target, default_port=REMOTE_PORT)
for port, tunnel in self.tunnel_by_port.iteritems():
if tunnel.match(server, username, target) and tunnel.isAlive():
with tunnel.lock:
return tunnel.local_port
return None
def open_tunnel(self, server, username, password, keyfile, target):
try:
port = self.open_ssh(server, username, password, keyfile, target)
except Exception, exc:
import traceback
traceback.print_exc()
return (False, str(exc))
return (True, port)
def open_ssh(self, server, username, password, keyfile, target):
server = self._address_port_tuple(server, default_port=SSH_PORT)
target = self._address_port_tuple(target, default_port=REMOTE_PORT)
password = password or ''
keyfile = keyfile or None
found = None
for tunnel in self.tunnel_by_port.itervalues():
if tunnel.match(server, username, target) and tunnel.isAlive():
found = tunnel
break
if found:
with tunnel.lock:
print 'Reusing tunnel at port %d' % tunnel.local_port
return tunnel.local_port
else:
tunnel = Tunnel(Queue.Queue(), server, username, target, password, keyfile)
tunnel.start()
tunnel.port_is_set.wait()
with tunnel.lock:
port = tunnel.local_port
self.tunnel_by_port[port] = tunnel
return port
def wait_connection(self, port):
tunnel = self.tunnel_by_port.get(port)
if not tunnel:
return 'Could not find a tunnel for port %d' % port
error = None
if tunnel.isAlive():
while True:
# Process any message in queue. Every retrieved message is printed.
# If an error is detected in the queue, exit returning its message:
try:
msg_type, msg = tunnel.q.get_nowait()
except Queue.Empty:
pass
else:
_msg = msg
if type(msg) is tuple:
msg = '\n' + ''.join(traceback.format_exception(*msg))
_msg = str(_msg[1])
print '%s: %s' % (msg_type, msg)
if msg_type == 'ERROR':
error = _msg
break # Exit returning the error message
if not tunnel.is_connecting() or not tunnel.isAlive():
break
time.sleep(0.3)
return error
def get_message(self, port):
tunnel = self.tunnel_by_port[port]
return tunnel.q.get(block=False)
def close(self, port):
pass
# tunnels auto-close when inactive
#tunnel = self.tunnel_by_port.get(port, None)
#if tunnel:
# tunnel.num_clients -= 1
# if tunnel.num_clients == 0:
# tunnel.close()
# del self.tunnel_by_port[port]
def send(self, code, arg=''):
if arg:
self.outpipe.write(code + ' ' + arg + '\n')
else:
self.outpipe.write(code + '\n')
self.outpipe.flush()
def shutdown(self):
for tunnel in self.tunnel_by_port.itervalues():
tunnel.close()
# FIXME: It seems that this function is never called. Should we remove it?
def wait_requests(self):
#print "SSH Tunnel Manager started, waiting for requests..."
self.send("READY")
while True:
request = self.inpipe.readline()
if not request:
#print "Exiting tunnel manager..."
break
try:
cmd, args = eval(request, {}, {})
except:
print "Ignoring invalid request"
self.send("ERROR", "Invalid request")
continue
if cmd == "LOOKUP":
try:
port = self.lookup_tunnel(*args)
if port is not None:
self.send("OK", str(port))
else:
self.send("ERROR", "not found")
except Exception, exc:
self.send("ERROR", str(exc))
elif cmd == "OPENSSH":
try:
port = self.open_ssh(*args)
self.send("OK", str(port))
except Exception, exc:
self.send("ERROR", str(exc))
elif cmd == "CLOSE":
#self.close(args[0])
self.send("OK")
elif cmd == "WAIT":
# wait for the SSH connection to be established
error = self.wait_connection(args)
if not error:
self.send("OK")
else:
self.send("ERROR "+error)
elif cmd == "MESSAGE":
msg = self.get_message(args)
if msg:
self.send(msg)
else:
self.send("NONE")
else:
print "Invalid request %s" % request
self.send("ERROR", "Invalid request")
"""
if "--single" in sys.argv:
target = sys.argv[2]
if "-pw" in sys.argv:
password = sys.argv[sys.argv.index("-pw")+1]
else:
password = None
if "-i" in sys.argv:
keyfile = sys.argv[sys.argv.index("-i")+1]
else:
keyfile = None
server = sys.argv[-1]
tunnel = Tunnel(None, False)
if "@" in server:
username, server = server.split("@", 1)
else:
username = ""
print "Starting tunnel..."
if type(server) == str:
if ':' in server:
server = server.split(":", 1)
server = (server[0], int(server[1]))
else:
server = (server, SSH_PORT)
if type(target) == str:
if ':' in target:
target = target.split(":", 1)
target = (target[0], int(target[1]))
else:
target = (target, REMOTE_PORT)
tunnel.start(server, username, password or "", keyfile, target)
else:
tm = TunnelManager()
sys.stdout = sys.stderr
try:
tm.wait_requests()
except KeyboardInterrupt, e:
pass
"""
|