#  fsh - fast remote execution
#  Copyright (C) 1999-2001 by Per Cederqvist.
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 2 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with this program; if not, write to the Free Software
#  Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

import errno
import getopt
import getpass
import os
import select
import socket
import stat
import string
import sys
import time

import fshlib
import fshconfig

class remote:
    """Tunnel to the in.fshd server on a remote host.

    A forked child runs the tunnel command (``method``, e.g. ``ssh``) with
    its stdin and stdout connected to pipes.  Outgoing data is queued with
    send() and written, and incoming data is read into receive_queue, when
    select_action() is called with the fds that select_set() asked for.
    A pipe fd of -1 means that direction has been closed.
    """

    def __init__(self, method, login, server, background_wanted):
        # method: tunnel command line; it is split on whitespace, so it may
        #   carry options (e.g. "ssh -x").
        # login: user name passed as "-l login"; 0 (what the caller passes
        #   when no -l flag was given) or None omits the flag.
        # background_wanted: call background() once in.fshd has answered
        #   (see poll_response).
        self.__background_wanted = background_wanted
        in_fd, self.w = os.pipe()     # we write self.w -> child's stdin
        self.r, out_fd = os.pipe()    # child's stdout -> we read self.r
        self.child = os.fork()
        if self.child == 0:
            self.__exec_tunnel(method, login, server, in_fd, out_fd)
        # parent
        os.close(in_fd)
        os.close(out_fd)
        fshlib.set_nonblocking(self.w)
        fshlib.set_nonblocking(self.r)
        self.send_queue = []
        self.receive_queue = []
        # 0 until the "fsh 1" greeting has been received, then 1.
        self.state = 0

    def __exec_tunnel(self, method, login, server, in_fd, out_fd):
        # Runs in the forked child and never returns: replaces the process
        # with "<method words> [-l login] server in.fshd", with stdin/stdout
        # on the pipes and stderr closed.
        try:
            os.close(self.w)
            os.close(self.r)
            cmd = method.split()
            if login not in (0, None):
                cmd = cmd + ["-l", login]
            cmd = cmd + [server, 'in.fshd']
            os.dup2(in_fd, 0)
            os.dup2(out_fd, 1)
            os.close(2)
            os.close(in_fd)
            os.close(out_fd)
            # The program is the first word of the command line; passing the
            # whole method string fails when it contains options.
            os.execvp(cmd[0], cmd)
        finally:
            # Only reached if something above (typically execvp) failed.
            # os._exit keeps the child from unwinding into the parent's
            # exception handlers and exit-time cleanup.
            os._exit(1)

    def send(self, data):
        # Queue data for the remote end; everything is discarded once the
        # write side has been closed.
        if data != "":
            self.send_queue.append(data)
        if self.w == -1:
            self.send_queue = []

    def select_action(self, r, w, e):
        # Write queued output and/or read available input for the fds that
        # select() reported ready; close a direction on EOF or error.
        if self.w in w:
            if fshlib.write(self.w, self.send_queue) == -1:
                os.close(self.w)
                self.w = -1
                print("End of file writing to remote system.")

        if self.r in r:
            if fshlib.read(self.r, self.receive_queue, 4096) == -1:
                print("Tunnel closed.")
                os.close(self.r)
                self.r = -1

    def select_set(self):
        # Return a tuple of three lists of file descriptors: the read,
        # write and exception fd sets.  The write fd is only included
        # while output is queued.
        r = []
        w = []
        if self.w != -1 and self.send_queue != []:
            w.append(self.w)
        if self.r != -1:
            r.append(self.r)
        return (r, w, [])

    def __readline(self):
        # Remove and return one complete line (without its "\n") from the
        # receive queue, or None if no complete line is buffered yet.
        if not self.receive_queue:
            return None
        ix = self.receive_queue[0].find("\n")
        if ix == -1 and len(self.receive_queue) > 1:
            # The line may span several received chunks: merge them.
            self.receive_queue = ["".join(self.receive_queue)]
            ix = self.receive_queue[0].find("\n")
        if ix == -1:
            return None
        res = self.receive_queue[0][:ix]
        self.receive_queue[0] = self.receive_queue[0][ix+1:]
        return res

    def poll_response(self):
        # Return the next parsed response from in.fshd as (cmd, session,
        # data).  Until the "fsh 1" greeting arrives, preceding lines are
        # discarded and [None, None, None] is returned when no more lines
        # are buffered.  After that, parsing is delegated to
        # fshlib.parse_line (callers treat cmd == None as "nothing yet").
        while self.state == 0:
            s = self.__readline()
            if s is None:
                return [None, None, None]
            if s == "fsh 1":
                if self.__background_wanted:
                    background()
                self.state = 1
                print("Connection established")
                break
        return fshlib.parse_line(self.receive_queue, 1)

    def is_closed(self):
        # True once either direction of the tunnel has been closed.
        return self.r == -1 or self.w == -1

class client:
    """One local fsh client connected to the fshd listening socket.

    Outgoing data is queued with send() and flushed by select_action();
    incoming data is buffered in receive_queue for poll_command().
    A closed connection is represented by ``self.socket == -1``.
    """

    def __init__(self, socket):
        # socket: a connected socket object.  (The parameter name shadows
        # the socket module, but only inside this method.)
        self.socket = socket.fileno()
        self.socket_object = socket
        self.send_queue = []
        self.receive_queue = []
        # Set by close(): close the socket once send_queue has been flushed.
        self.pending_close = 0

    def __close_socket(self):
        # Close the connection.  Anything still queued can never be
        # delivered, so drop it; this keeps send() and drained()
        # consistent for dead connections.
        self.socket_object.close()
        self.socket = -1
        self.send_queue = []

    def select_set(self):
        # Return (read fds, write fds, exception fds, live), where live is
        # 1 while the connection is open and 0 after it has been closed.
        r = []
        w = []
        e = []
        live = 0
        if self.socket != -1:
            r.append(self.socket)
            if self.send_queue != []:
                w.append(self.socket)
            live = 1
        return (r, w, e, live)

    def select_action(self, r, w, e):
        # Handle any possible read/write.
        # Return 0 normally, 1 if the connection was unexpectedly lost.
        if self.socket == -1:
            return 0
        if self.socket in r:
            if fshlib.read(self.socket, self.receive_queue, 4096) == -1:
                self.__close_socket()
                return 1

        if self.socket in w:
            if fshlib.write(self.socket, self.send_queue) == -1:
                self.__close_socket()
                return 1
            if self.pending_close and self.send_queue == []:
                self.__close_socket()

        return 0

    def poll_command(self):
        # return a parsed command, if one is available
        return fshlib.parse_line(self.receive_queue, 0)

    def send(self, data):
        # Queue data for the client; dropped if the connection is closed.
        if data != "":
            self.send_queue.append(data)
        if self.socket == -1:
            self.send_queue = []

    def close(self):
        # Close once all queued output has been written.  If nothing is
        # queued, close right away: select_set() only asks for writability
        # while data is queued, so the deferred close in select_action()
        # would otherwise never run.
        self.pending_close = 1
        if self.socket != -1 and self.send_queue == []:
            self.__close_socket()

    def is_closed(self):
        return self.socket == -1

    def drained(self):
        # True when no output is waiting (always true once closed, since a
        # closed connection's queue is discarded).
        return self.send_queue == []

class fshd:
    """The fshd daemon: multiplexes local fsh clients over one tunnel.

    Local clients connect to a UNIX socket whose path comes from
    fshlib.fshd_socket(server, method, login).  Their commands are
    forwarded to the remote in.fshd tagged with a session number, and
    in.fshd's responses are routed back to the matching client.
    """

    def __init__(self, server, method, login, use_l_flag, background_wanted,
                 timeout):
        # Start the tunnel to ``server`` using ``method`` (the remote object
        # skips leading noise until it receives the "fsh 1" greeting) and
        # create the listening socket inside a private directory.
        # timeout: idle timeout in seconds; None or 0 disables it.
        sockname, sockdir = fshlib.fshd_socket(server, method, login)
        self.__make_private_dir(sockdir)
        self.r = remote(method, use_l_flag and login, server,
                        background_wanted)
        self.sessions = {}
        self.listen = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.listen.setblocking(0)
        self.listen.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            os.unlink(sockname)
        except os.error:
            if sys.exc_info()[1].errno != errno.ENOENT:
                raise
        self.listen.bind(sockname)
        self.listen.listen(3)
        self.__timeout = timeout
        self.__expire_time = None
        self.setup_timer()

    def __make_private_dir(self, sockdir):
        # Create sockdir with mode 0700.  An existing entry is accepted (and
        # reset to 0700) only if it is a directory owned by us; otherwise
        # the original mkdir error is re-raised.
        try:
            os.mkdir(sockdir, stat.S_IRWXU)
        except os.error:
            if sys.exc_info()[1].errno != errno.EEXIST:
                raise
            status = os.lstat(sockdir)
            if status[stat.ST_UID] != os.getuid():
                raise
            if not stat.S_ISDIR(status[stat.ST_MODE]):
                raise
            os.chmod(sockdir, stat.S_IRWXU)

    def setup_timer(self):
        # Arm the idle timer unless it is already armed or disabled
        # (timeout None or 0, as the -T option documents).
        if self.__expire_time is None and self.__timeout:
            self.__expire_time = time.time() + self.__timeout

    def cancel_timer(self):
        self.__expire_time = None

    def timeout(self):
        # select() timeout: 10 seconds when no timer is armed, otherwise
        # the time left until expiry clamped to [0, 60] seconds.
        if self.__expire_time is None:
            return 10
        left = self.__expire_time - time.time()
        return max(0, min(left, 60))

    def expired(self):
        return (self.__expire_time is not None
                and self.__expire_time <= time.time())

    def toploop(self):
        # Main event loop; it only ends through sys.exit().
        try:
            while 1:
                r, w, e = self.__select_sets()
                r, w, e = select.select(r, w, e, self.timeout())
                self.r.select_action(r, w, e)
                if self.listen.fileno() in r:
                    self.__accept()
                self.__dispatch_responses()
                self.__service_sessions(r, w, e)
        except KeyboardInterrupt:
            # NOTE(review): these "eos" messages are only queued; the
            # process exits before they are written to the clients.
            for s in self.sessions.keys():
                self.sessions[s].send("eos\n")
            sys.exit(0)

    def __select_sets(self):
        # Build the fd sets for select().  Exit when there is nothing left
        # to do: the tunnel is closed and every session has flushed its
        # output, or no session has been live for the idle timeout.
        drained = self.r.is_closed()
        r, w, e = self.r.select_set()
        r.append(self.listen.fileno())
        live = 0
        for s in self.sessions.keys():
            sess = self.sessions[s]
            if drained and not sess.drained():
                drained = 0
            r1, w1, e1, live1 = sess.select_set()
            r = r + r1
            w = w + w1
            e = e + e1
            live = live + live1
        if drained:
            sys.exit(0)
        if live > 0:
            self.cancel_timer()
        else:
            self.setup_timer()
            if self.expired():
                sys.exit(0)
        return r, w, e

    def __accept(self):
        # Accept a new client and give it the lowest session number that
        # is unused or belongs to a closed session.
        try:
            sock, addr = self.listen.accept()
        except socket.error:
            # The listener is non-blocking: the client may have vanished
            # between select() and accept().
            args = sys.exc_info()[1].args
            if args and args[0] in (errno.EAGAIN, errno.EWOULDBLOCK,
                                    errno.ECONNABORTED):
                return
            raise
        sock.setblocking(0)
        s = 1
        while s in self.sessions and not self.sessions[s].is_closed():
            s = s + 1
        print("%s start" % s)
        self.sessions[s] = client(sock)

    def __dispatch_responses(self):
        # Route every complete response from in.fshd to its session.
        while 1:
            cmd, s, data = self.r.poll_response()
            if cmd is None:
                break
            if cmd in ["stdout", "stderr", "exit", "signal-exit",
                       "stdin-flow"]:
                msg = "%s %s\n" % (cmd, fshlib.hollerith(data))
            elif cmd in ["eof-stdin", "eof-stdout", "eof-stderr"]:
                msg = "%s\n" % cmd
            elif cmd == "eos":
                msg = "eos\n"
            else:
                print("Unknown command from in.fshd: %s" % cmd)
                continue
            session = self.sessions.get(s)
            if session is None:
                # Don't let one bad session number (KeyError) take down
                # the daemon and every other session.
                print("Unknown session from in.fshd: %s %s" % (cmd, s))
                continue
            session.send(msg)
            if cmd == "eos":
                session.close()
                print("%s EOS" % s)

    def __service_sessions(self, r, w, e):
        # Do client I/O and forward every complete client command to
        # in.fshd, tagged with its session number.
        for s in self.sessions.keys():
            sess = self.sessions[s]
            if sess.select_action(r, w, e) == 1:
                # fsh died, so kill this session.
                self.r.send("eos %d\n" % s)
            while 1:
                cmd, data = sess.poll_command()
                if cmd is None:
                    break
                if data is None:
                    self.r.send("%s %d\n" % (cmd, s))
                else:
                    self.r.send("%s %d %s\n" % (
                        cmd, s, fshlib.hollerith(data)))
                    if cmd == "new":
                        print("%s $ %s" % (s, data))

def background():
    """Detach the current process from its terminal and continue in a child.

    Ignores the terminal stop signals where the platform has them, forks
    and exits the parent, starts a new session in the child (best effort),
    and points file descriptors 0, 1 and 2 at /dev/null.  Returns only in
    the child.
    """
    # signal is not imported at module level; without this import every
    # signal.signal() call below failed with a NameError that a bare
    # except silently swallowed, so the signals were never ignored.
    import signal

    for name in ("SIGTTOU", "SIGTTIN", "SIGTSTP"):
        signum = getattr(signal, name, None)
        if signum is None:
            continue    # not available on this platform
        try:
            signal.signal(signum, signal.SIG_IGN)
        except (ValueError, RuntimeError, os.error):
            pass        # best effort

    if os.fork() > 0:
        os._exit(0)     # parent

    try:
        os.setsid()
    except os.error:
        pass            # best effort

    for i in range(3):
        try:
            os.close(i)
        except os.error:
            if sys.exc_info()[1].errno != errno.EBADF:
                raise
    # With 0-2 closed, open() normally returns 0; dup2 makes the
    # redirection correct even if it does not.
    fd = os.open("/dev/null", os.O_RDWR)
    for target in (0, 1, 2):
        if fd != target:
            os.dup2(fd, target)
    if fd > 2:
        os.close(fd)

def open_fds(fallback=1024):
    """Return a range of file descriptor numbers that might be open.

    Uses the soft RLIMIT_NOFILE limit when the resource module exists and
    the limit is finite; otherwise returns range(fallback).
    """

    try:
        import resource
    except ImportError:
        # Fall back to using a large value.
        return range(fallback)
    limit = resource.getrlimit(resource.RLIMIT_NOFILE)[0]
    if limit == resource.RLIM_INFINITY or limit < 0:
        # An unlimited soft limit would otherwise give an empty range (on
        # platforms where RLIM_INFINITY is -1) or an absurdly large one.
        return range(fallback)
    # The extra descriptor is harmless: closing it just fails with EBADF.
    return range(limit + 1)

def usage(ret):
    """Write the fshd usage message to stderr and exit with status ``ret``.

    The per-option help is appended only when ``ret == 0``, i.e. when the
    user explicitly asked for help.
    """
    synopsis = [
        "fshd: usage: fshd [options] host",
        "             fshd { -h | --help }",
        "             fshd { -V | --version }",
    ]
    sys.stderr.write("\n".join(synopsis) + "\n")
    if ret == 0:
        details = [
            "Options:",
            "  -b --background       Run in the background.",
            "  -r method             Use ``method'' (e.g. ``rsh'') instead "
            "of ssh.",
            "  -l login              Log in as user ``login''.",
            "  -T --timeout=timeout  Set idle timeout (in seconds); exit when "
            "no session",
            "                        has existed for this long. "
            "0 disables the timeout.",
        ]
        sys.stderr.write("\n".join(details) + "\n")

    sys.exit(ret)

def main():
    """Parse the command line, close inherited descriptors and run fshd."""
    method = "ssh"
    login = None
    use_l_flag = 0
    print_version = 0
    b_flag = 0
    timeout = fshconfig.default_fshd_timeout
    try:
        opts, args = getopt.getopt(sys.argv[1:], "hr:l:bVT:",
                                   ["version", "help", "background",
                                    "timeout="])
    except getopt.error:
        sys.stderr.write(str(sys.exc_info()[1]) + "\n")
        sys.exit(1)
    for opt, val in opts:
        if opt == "-r":
            method = fshlib.shell_unquote(val)
        elif opt == "-l":
            login = val
            use_l_flag = 1
        elif opt in ("-b", "--background"):
            b_flag = 1
        elif opt in ("-V", "--version"):
            print_version = 1
        elif opt in ("-h", "--help"):
            usage(0)
        elif opt in ("-T", "--timeout"):
            try:
                timeout = float(val)
            except ValueError:
                sys.stderr.write("fshd: invalid timeout: %s\n" % val)
                sys.exit(1)
        else:
            assert 0
    # The usage text promises that a timeout of 0 disables the idle timer,
    # while fshd only treats None as "no timeout": a literal 0 made the
    # daemon exit as soon as it started.
    if timeout == 0:
        timeout = None
    if print_version:
        fshlib.print_version("fshd")
    if len(args) != 1:
        usage(1)
    if login is None:
        login = getpass.getuser()

    # Close every inherited descriptor above stderr.
    for i in open_fds():
        if i > 2:
            try:
                os.close(i)
            except os.error:
                if sys.exc_info()[1].errno != errno.EBADF:
                    raise

    s = fshd(args[0], method, login, use_l_flag, b_flag, timeout)
    s.toploop()

if __name__ == "__main__":
    # Run as a script: parse arguments and start the daemon.
    main()
