#  fsh - fast remote execution
#  Copyright (C) 1999 by Per Cederqvist.
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 2 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with this program; if not, write to the Free Software
#  Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

import FCNTL
import errno
import fcntl
import getopt
import os
import select
import signal
import string
import sys

import fshlib

def set_cloexec_flag(fd):
    """Mark file descriptor FD close-on-exec.

    The current descriptor flags are read first and FD_CLOEXEC is OR-ed
    in, so any other descriptor flags are preserved.
    """
    current = fcntl.fcntl(fd, FCNTL.F_GETFD)
    fcntl.fcntl(fd, FCNTL.F_SETFD, current | FCNTL.FD_CLOEXEC)

def sig_to_name(sig):
    """Return the short name (without the "SIG" prefix) of signal number SIG.

    Only the signals listed below are named.  Any other number yields
    str(SIG).  Names are tried in the listed order, so if two listed
    names share a number the earlier one is returned.

    Not every platform defines every one of these signals.  A signal the
    running platform lacks is skipped instead of raising AttributeError.
    """
    for name in ("HUP", "INT", "QUIT", "ILL", "TRAP", "IOT", "BUS", "FPE",
                 "KILL", "USR1", "SEGV", "USR2", "PIPE", "ALRM", "TERM",
                 "CHLD", "CONT", "STOP", "TSTP", "TTIN", "TTOU", "URG",
                 "XCPU", "XFSZ", "VTALRM", "PROF", "WINCH", "IO", "PWR"):
        number = getattr(signal, "SIG" + name, None)
        if number is not None and number == sig:
            return name
    return str(sig)

def name_to_sig(sig):
    """Return the signal number for the short signal name SIG (e.g. "TERM").

    Only the names listed below are recognised, and matching is
    case-sensitive.  Any other value maps to SIGINT, and so does a listed
    name whose signal the running platform does not define (instead of
    raising AttributeError).
    """
    if sig in ("HUP", "INT", "QUIT", "ILL", "TRAP", "IOT", "BUS", "FPE",
               "KILL", "USR1", "SEGV", "USR2", "PIPE", "ALRM", "TERM",
               "CHLD", "CONT", "STOP", "TSTP", "TTIN", "TTOU", "URG",
               "XCPU", "XFSZ", "VTALRM", "PROF", "WINCH", "IO", "PWR"):
        number = getattr(signal, "SIG" + sig, None)
        if number is not None:
            return number
    return signal.SIGINT

class session:
    """One remote command: a child shell connected to us by three pipes.

    The parent ends of the child's stdin/stdout/stderr pipes are set
    close-on-exec and non-blocking.  Events for the client ("stdout N ...",
    "eof-stdin N", "exit N ...", ...) are written to our own stdout, one
    space-separated line per event.
    """

    def __init__(self, session_no, cmd):
        """Run CMD under $SHELL -c (default /bin/sh) as session SESSION_NO."""
        in_fd, self.stdin_fd = os.pipe()
        self.stdout_fd, out_fd = os.pipe()
        self.stderr_fd, err_fd = os.pipe()
        set_cloexec_flag(self.stdin_fd)
        set_cloexec_flag(self.stdout_fd)
        set_cloexec_flag(self.stderr_fd)
        self.child = os.fork()
        if self.child == 0:
            # Child.  Control must never return from here into the server
            # code.  A failing dup2() or execv() raises OSError, which would
            # otherwise unwind into the server's main loop and leave a second
            # server reading the same input.  os._exit() also skips flushing
            # stdio buffers inherited from the parent, so pending protocol
            # output is not emitted twice.
            try:
                os.dup2(in_fd, 0)
                os.dup2(out_fd, 1)
                os.dup2(err_fd, 2)
                os.close(in_fd)
                os.close(out_fd)
                os.close(err_fd)
                shell = os.environ.get("SHELL", "/bin/sh")
                os.execv(shell, [os.path.basename(shell), "-c", cmd])
            finally:
                os._exit(1)
        # Parent: the child's ends of the pipes are not needed here.
        os.close(in_fd)
        os.close(out_fd)
        os.close(err_fd)
        fshlib.set_nonblocking(self.stdin_fd)
        fshlib.set_nonblocking(self.stdout_fd)
        fshlib.set_nonblocking(self.stderr_fd)
        self.session_no = session_no
        # Chunks received from the client, not yet written to the child.
        self.stdin_queue = []
        # Set by eof_stdin(); stdin is closed once stdin_queue has drained.
        self.stdin_pending_close = 0
        # Flow control: *_counter counts bytes moved on each stream and
        # *_quota is the total allowed.  We raise the stdin quota ourselves
        # (announced as "stdin-flow"); the stdout/stderr quotas are set
        # through stdout_flow()/stderr_flow().
        self.stdin_quota = fshlib.QUOTA
        self.stdout_quota = fshlib.QUOTA
        self.stderr_quota = fshlib.QUOTA
        self.stdin_counter = 0
        self.stdout_counter = 0
        self.stderr_counter = 0

    def _report(self, *fields):
        """Write FIELDS to stdout as one space-separated protocol line."""
        sys.stdout.write(" ".join([str(f) for f in fields]) + "\n")

    def stdin(self, data):
        """Queue DATA for later transmission to the child's stdin.

        Empty DATA is not queued.  If stdin is already closed, the queue
        is simply discarded.
        """
        assert not self.stdin_pending_close
        if data != "":
            self.stdin_queue.append(data)
        if self.stdin_fd == -1:
            self.stdin_queue = []

    def select_action(self, r, w, e):
        """Service this session's descriptors that select() reported ready.

        R and W are the readable and writable fd lists returned by
        select(); E is accepted for symmetry with select_set() but unused.
        """
        if self.stdin_fd in w:
            # NOTE(review): fshlib.write() presumably consumes the written
            # chunks from stdin_queue and returns -1 when the pipe is dead.
            sz = fshlib.write(self.stdin_fd, self.stdin_queue)
            if sz == -1:
                os.close(self.stdin_fd)
                self.stdin_fd = -1
                self._report("eof-stdin", self.session_no)
            else:
                self.stdin_counter = self.stdin_counter + sz
            if self.stdin_pending_close and self.stdin_queue == []:
                self.eof_stdin()
            # Grant the client more stdin credit before it runs out.
            # "//" keeps the integer threshold of Python 2's "/" (assuming
            # fshlib.QUOTA is an int).
            if (self.stdin_fd != -1 and
                    self.stdin_quota - self.stdin_counter < fshlib.QUOTA // 2):
                self.stdin_quota = self.stdin_quota + fshlib.QUOTA
                self._report("stdin-flow", self.session_no,
                             fshlib.hollerith(self.stdin_quota))

        if self.stdout_fd in r:
            # select_set() only polls stdout while quota remains.
            wanted = min(4096, self.stdout_quota - self.stdout_counter)
            assert wanted > 0
            queue = []
            if fshlib.read(self.stdout_fd, queue, wanted) == -1:
                self._report("eof-stdout", self.session_no)
                os.close(self.stdout_fd)
                self.stdout_fd = -1
            elif len(queue) > 0:
                assert len(queue) == 1
                self.stdout_counter = self.stdout_counter + len(queue[0])
                self._report("stdout", self.session_no,
                             fshlib.hollerith(queue[0]))

        if self.stderr_fd in r:
            # select_set() only polls stderr while quota remains.
            wanted = min(4096, self.stderr_quota - self.stderr_counter)
            assert wanted > 0
            queue = []
            if fshlib.read(self.stderr_fd, queue, wanted) == -1:
                self._report("eof-stderr", self.session_no)
                os.close(self.stderr_fd)
                self.stderr_fd = -1
            elif len(queue) > 0:
                assert len(queue) == 1
                self.stderr_counter = self.stderr_counter + len(queue[0])
                self._report("stderr", self.session_no,
                             fshlib.hollerith(queue[0]))

    def select_set(self):
        """Return (read_fds, write_fds, except_fds) to pass to select().

        stdin is polled for writing only while data is queued; stdout and
        stderr are polled for reading only while their quota is not used up.
        """
        r = []
        w = []
        if self.stdin_fd != -1 and self.stdin_queue != []:
            w.append(self.stdin_fd)
        if self.stdout_fd != -1 and self.stdout_quota > self.stdout_counter:
            r.append(self.stdout_fd)
        if self.stderr_fd != -1 and self.stderr_quota > self.stderr_counter:
            r.append(self.stderr_fd)
        return (r, w, [])

    def wait_poll(self):
        """Reap the child without blocking and report how it ended.

        An interrupted waitpid (EINTR) is ignored; it is retried on the
        next call.
        """
        if self.child == -1:
            return
        try:
            pid, status = os.waitpid(self.child, os.WNOHANG)
        except os.error:
            if sys.exc_info()[1].errno == errno.EINTR:
                return
            raise
        if pid == 0:
            return          # still running
        self.child = -1
        if os.WIFEXITED(status):
            self._report("exit", self.session_no,
                         fshlib.hollerith(os.WEXITSTATUS(status)))
        elif os.WIFSIGNALED(status):
            self._report("signal-exit", self.session_no,
                         fshlib.hollerith(sig_to_name(os.WTERMSIG(status))))
        else:
            self._report("exit", self.session_no, fshlib.hollerith("unknown"))

    def signal(self, signal_name):
        """Send the signal named SIGNAL_NAME (e.g. "TERM") to the child."""
        if self.child == -1:
            return
        os.kill(self.child, name_to_sig(signal_name))

    def eof_stdin(self):
        """Close the child's stdin, deferred until queued data is written."""
        self.stdin_pending_close = 1
        if self.stdin_fd != -1 and self.stdin_queue == []:
            os.close(self.stdin_fd)
            self.stdin_fd = -1

    def eof_stdout(self):
        """Stop reading the child's stdout."""
        if self.stdout_fd != -1:
            os.close(self.stdout_fd)
            self.stdout_fd = -1

    def eof_stderr(self):
        """Stop reading the child's stderr."""
        if self.stderr_fd != -1:
            os.close(self.stderr_fd)
            self.stderr_fd = -1

    def eos(self):
        """Kill the child, close every pipe and report "eos".

        Unlike eof_stdin(), queued stdin data is discarded.
        NOTE(review): the killed child is not waited for here, so it may
        linger as a zombie until this process exits.
        """
        self.signal("KILL")
        if self.stdin_fd != -1:
            os.close(self.stdin_fd)
            self.stdin_fd = -1
        self.eof_stdout()
        self.eof_stderr()
        self._report("eos", self.session_no)

    def spontaneous_eos(self):
        """Perform eos() once both outputs are closed and the child reaped.

        Return 1 if eos() was performed, else 0.
        """
        if (self.stdout_fd == -1 and self.stderr_fd == -1
                and self.child == -1):
            self.eos()
            return 1
        return 0

    def stdout_flow(self, new_quota):
        """Set the total stdout byte quota from the decimal string NEW_QUOTA."""
        self.stdout_quota = int(new_quota)

    def stderr_flow(self, new_quota):
        """Set the total stderr byte quota from the decimal string NEW_QUOTA."""
        self.stderr_quota = int(new_quota)


class infshd:
    """The in.fshd server: runs sessions on behalf of one client.

    Commands are read from file descriptor 0 and dispatched through
    parse_table.  Replies and session output are written to stdout.
    """

    def __init__(self):
        self.sessions = {}   # session number -> session object
        self.commands = []   # raw chunks read from fd 0, not yet parsed
        set_cloexec_flag(0)
        set_cloexec_flag(1)
        set_cloexec_flag(2)
        # Protocol greeting.
        self._report("fsh 1")

    def _report(self, *fields):
        """Write FIELDS to stdout as one space-separated protocol line."""
        sys.stdout.write(" ".join([str(f) for f in fields]) + "\n")

    # Command handlers.  Commands for unknown sessions are silently ignored.

    def new(self, session_no, cmd):
        """Start CMD as session SESSION_NO, ending any session with that number."""
        if session_no in self.sessions:
            self.sessions[session_no].eos()
        self.sessions[session_no] = session(session_no, cmd)

    def stdin(self, session_no, data):
        """Queue DATA for the stdin of session SESSION_NO."""
        if session_no in self.sessions:
            self.sessions[session_no].stdin(data)

    def signal(self, session_no, data):
        """Send the signal named DATA to session SESSION_NO."""
        if session_no in self.sessions:
            self.sessions[session_no].signal(data)

    def eof_stdin(self, session_no):
        """Close stdin of session SESSION_NO once its queue drains."""
        if session_no in self.sessions:
            self.sessions[session_no].eof_stdin()

    def eof_stdout(self, session_no):
        """Stop reading stdout of session SESSION_NO."""
        if session_no in self.sessions:
            self.sessions[session_no].eof_stdout()

    def eof_stderr(self, session_no):
        """Stop reading stderr of session SESSION_NO."""
        if session_no in self.sessions:
            self.sessions[session_no].eof_stderr()

    def eos(self, session_no):
        """Terminate session SESSION_NO and forget it."""
        if session_no in self.sessions:
            self.sessions[session_no].eos()
            del self.sessions[session_no]

    def stdout_flow(self, session_no, data):
        """Set the stdout quota of session SESSION_NO to DATA."""
        if session_no in self.sessions:
            self.sessions[session_no].stdout_flow(data)

    def stderr_flow(self, session_no, data):
        """Set the stderr quota of session SESSION_NO to DATA."""
        if session_no in self.sessions:
            self.sessions[session_no].stderr_flow(data)

    def toploop(self):
        """Serve forever; exit the process when fd 0 reaches end of file."""
        while 1:
            r = [0]
            w = []
            e = []
            for sess in self.sessions.values():
                r1, w1, e1 = sess.select_set()
                r = r + r1
                w = w + w1
                e = e + e1
            sys.stdout.flush()

            # Don't wait too long -- we want to reap the children once
            # in a while.
            try:
                r, w, e = select.select(r, w, e, 5)
            except select.error:
                # select() may fail with EINTR, e.g. after the process was
                # stopped and continued.  Treat that as "nothing ready" so
                # children are still reaped below.
                if sys.exc_info()[1].args[0] != errno.EINTR:
                    raise
                r, w, e = [], [], []

            # Iterate over a snapshot: sessions may be deleted in the loop.
            for s in list(self.sessions.keys()):
                sess = self.sessions[s]
                sess.select_action(r, w, e)
                sess.wait_poll()
                if sess.spontaneous_eos():
                    del self.sessions[s]
            if 0 in r:
                data = os.read(0, 1024)
                if not data:
                    # End of file.  Kill everything.
                    for sess in self.sessions.values():
                        sess.eos()
                    sys.exit(0)
                self.commands.append(data)
                self.parse_commands()

    # (command name, takes a data argument?, handler).  The handlers are
    # the plain functions above and are called with self passed explicitly.
    parse_table = [
        ["new",        1, new],
        ["stdin",      1, stdin],
        ["signal",     1, signal],
        ["eof-stdin",  0, eof_stdin],
        ["eof-stdout", 0, eof_stdout],
        ["eof-stderr", 0, eof_stderr],
        ["eos",        0, eos],
        ["stdout-flow", 1, stdout_flow],
        ["stderr-flow", 1, stderr_flow],
        ]

    def do_one_command(self):
        """Parse and execute one buffered command.

        Return 0 when fshlib.parse_line() yields no command, else 1 (also
        after reporting a syntax error).
        """
        parsed_cmd, parsed_session, data = fshlib.parse_line(self.commands, 1)

        if parsed_cmd is None:
            return 0

        for cmd, need_data, cb in self.parse_table:
            if cmd != parsed_cmd:
                continue
            if need_data:
                if data is None:
                    self._report("syntax-error 3")
                    return 1
                cb(self, parsed_session, data)
            else:
                if data is not None:
                    self._report("syntax-error 4")
                    return 1
                cb(self, parsed_session)
            return 1
        self._report("syntax-error 5")
        return 1

    def parse_commands(self):
        """Execute buffered commands until none is left."""
        while self.do_one_command() == 1:
            pass

def usage(ret):
    """Print the command-line synopsis on stderr and exit with status RET."""
    synopsis = "in.fshd: usage: in.fshd [ -V | --version | -h | --help ]\n"
    sys.stderr.write(synopsis)
    sys.exit(ret)

def main():
    """Parse the command line, then serve the fsh protocol on stdin/stdout."""
    print_version = 0
    try:
        opts, args = getopt.getopt(sys.argv[1:], "hV", ["version", "help"])
    except getopt.error:
        # getopt.error is an exception instance, not a string; convert it
        # with str() before appending the newline.
        sys.stderr.write(str(sys.exc_info()[1]) + "\n")
        sys.exit(1)
    for opt, _ in opts:
        if opt in ("-V", "--version"):
            print_version = 1
        elif opt in ("-h", "--help"):
            usage(0)
    if print_version:
        # NOTE(review): the server still starts after this unless
        # fshlib.print_version() exits by itself -- confirm intended.
        fshlib.print_version("in.fshd")
    if args:
        usage(1)

    i = infshd()
    i.toploop()


if __name__ == '__main__':
    main()
