#!/usr/bin/env python

# This file is part of Window-Switch.
# Copyright (c) 2009-2013 Antoine Martin <antoine@nagafix.co.uk>
# Window-Switch is released under the terms of the GNU GPL v3

from winswitch.util.simple_logger import Logger, msig, set_loggers_debug
logger=Logger("process_util", log_colour=Logger.CYAN)
debug_import = logger.get_debug_import()

debug_import("os/sys/time/signal/subprocess/threading/collections")
import os
import sys
import time
import signal
import subprocess
import threading
from collections import deque

debug_import("twisted.internet.error")
from twisted.internet.error import ProcessDone, ProcessTerminated
debug_import("twisted.python.failure")
from twisted.python.failure import Failure
debug_import("winswitch.twisted.HiddenSpawnedProcess")
from winswitch.twisted.HiddenSpawnedProcess import HiddenSpawnedProcess

debug_import("globals")
from winswitch.globals import SUBPROCESS_CREATION_FLAGS, WIN32
debug_import("util.common")
from winswitch.util.common import visible_command, escape_newlines, csv_list, is_valid_file, is_valid_dir
debug_import("util.main_loop")
from winswitch.util.main_loop import callLater
debug_import("util.consts")
from winswitch.consts import NOTIFY_ERROR

#record which process protocol implementation was imported above
logger.slog("using twisted ProcessProtocol=%s" % HiddenSpawnedProcess)

#default value for LineProcessProtocolWrapper.buffer_previous_lines
#(limits the fifo of recently handled lines)
BUFFER_PREVIOUS_LINES = 64
#default value for LineProcessProtocolWrapper.max_buffer_size
DEFAULT_MAX_BUFFER_SIZE = 64*1024	#we're not expecting lines longer than this

#True unless "--no-log-waiting-process" is on the command line, used by exec_nopipe()
LOG_WAITING_PROCESS = not "--no-log-waiting-process" in sys.argv




class DaemonPidWrapper:
	""" Handle on a daemon process for which all we have is the command and the pid.
		Provides stop() so the daemon can be signalled via send_kill_signal() """

	def __init__(self, cmd, pid):
		Logger(self)
		#the command used to start the daemon (only used for display):
		self.cmd = cmd
		#the pid the daemon reported:
		self.pid = pid

	def __str__(self):
		cmd_str = csv_list(self.cmd)
		return "DaemonPidWrapper(%s : %s)" % (cmd_str, self.pid)

	def stop(self):
		""" Sends the default kill signal to the daemon, pids 1 and below are never signalled """
		self.slog("wrapper=%s" % self)
		if self.pid<=1:
			return
		send_kill_signal(self.pid)


def get_output(cmd, cwd=None, env=None, shell=False, setsid=False):
	"""
	Runs cmd to completion with stdin, stdout and stderr connected to pipes.
	If setsid is set (posix only), the child calls os.setsid() before exec.
	Returns the tuple: (returncode, stdout data, stderr data)
	"""
	logger.sdebug(None, cmd, cwd, env, shell, setsid)
	popen_args = {
		"stdin"		: subprocess.PIPE,
		"stdout"	: subprocess.PIPE,
		"stderr"	: subprocess.PIPE,
		"cwd"		: cwd,
		"env"		: env,
		"shell"		: shell,
		"close_fds"	: True,
		}
	if setsid and os.name=="posix":
		#run the child in its own session:
		popen_args["preexec_fn"] = os.setsid
	child = subprocess.Popen(cmd, **popen_args)
	stdout_data, stderr_data = child.communicate()
	return child.returncode, stdout_data, stderr_data

def exec_daemonpid_wrapper(cmd, env=None, shell=False, cwd=None):
	""" Starts a daemon that prints its pid and returns a DaemonPidWrapper
		which can be stopped with a call to stop()
	"""
	try:
		code, stdoutdata, stderrdata = get_output(cmd, shell=shell, env=env, cwd=cwd)
		logger.sdebug("code=%s, stdout=%s, stderr=%s" % (code, visible_command(stdoutdata), visible_command(stderrdata)))
		if not stdoutdata:
			logger.serror("no stdout, no pid found!", cmd, env, shell, cwd)
			return	None
	except Exception, e:
		logger.serr(None, e, cmd, env, shell, cwd)
		return	None
	try:
		pid = int(stdoutdata)
	except Exception, e:
		logger.serror("failed to parse stdout as a pid: %s" % visible_command(stdoutdata))
		return	None
	return DaemonPidWrapper(cmd, pid)


def _add_process_stop(proc):
	if proc is not None:
		def stop_server_process():
			subprocess_terminate(proc)
		proc.stop = stop_server_process


def exec_nopipe(cmd, env=None, extra_env=None, wait=False, shell=False, cwd=None, log_args=True, setsid=False):
	if log_args:
		scmd = str(cmd)
	else:
		scmd = "%s ..." % cmd[0]
	sig = msig(scmd, env, extra_env, wait, shell, cwd, log_args)
	process = None
	try:
		pipe = subprocess.PIPE
		if env is None:
			env = os.environ.copy()
		if extra_env:
			for key, value in extra_env.items():
				env[key] = value
		if WIN32:
			process = subprocess.Popen(cmd, stdin=pipe, stdout=pipe, stderr=pipe, shell=shell, env=env, cwd=cwd,
									creationflags=SUBPROCESS_CREATION_FLAGS)
		else:
			kwargs = {}
			if os.name=="posix" and setsid:
				def _setsid():
					os.setsid()
				kwargs["preexec_fn"] = _setsid
			process = subprocess.Popen(cmd, stdin=None, stdout=None, stderr=None, close_fds=True, shell=shell, env=env, cwd=cwd, **kwargs)
		_add_process_stop(process)
		if wait:
			if LOG_WAITING_PROCESS:
				logger.log(sig+" waiting on process=%s" % process)
			process.wait()
		else:
			logger.log(sig+"=%s" % process)
	except Exception, e:
		logger.error(sig, e)
	return process

def exec_to_log(log_filename, args, env=None, cwd=None, shell=False):
	"""
	Simply launches the process with stdout and stderr sent to log_filename.
	On win32, we can't close_fds, and we need to specify stdin... (tried os.devnull which did not work)
	See: http://www.py2exe.org/index.cgi/Py2ExeSubprocessInteractions
	The Popen object returned is augmented with a stop() method.
	"""
	try:
		logfile = open(log_filename, "wb")
		log_fd = os.fdopen(logfile.fileno(), 'w', 0)
		stdin = None
		if WIN32:
			stdin = subprocess.PIPE
		proc = subprocess.Popen(args, bufsize=1, universal_newlines=True, stdin=stdin, stdout=log_fd, stderr=log_fd,
							env=env, cwd=cwd, shell=shell, close_fds=not WIN32,
							creationflags=SUBPROCESS_CREATION_FLAGS)
		_add_process_stop(proc)
		return proc
	except Exception, e:
		logger.serr(None, e, log_filename, args, env, cwd, shell)
		return	None



class LineProcessProtocolWrapper(HiddenSpawnedProcess):
	"""
	Abstract superclass for interacting with the input and output of a process using a line protocol
	similar to twisted's LineReceiver.
	"""

	def __init__(self):
		self.pid = None
		HiddenSpawnedProcess.__init__(self)
		Logger(self)

		self.command = None				#the command we are running
		self.cwd = None					#the directory to run the command from
		self.process = None				#the IProcessTransport - see: http://twistedmatrix.com/documents/current/api/twisted.internet.interfaces.IProcessTransport.html
		self.returncode = None
		self.start_time = None
		self.log_full_command = True	#can be turned off to prevent information disclosure (ie: SSH passwords with putty)

		self.stopping = False			#flag indicating we know we're in the process of closing down (may be used to ignore data, errors, etc)
		self.terminated = False			#the process has terminated
		self.merge_stderr = False		#should stderr be added to the buffer?
		self.buffer = ''
		self.line_count = 0
		self.previous_lines = deque([])
		self.buffer_previous_lines = BUFFER_PREVIOUS_LINES
		self.max_buffer_size = DEFAULT_MAX_BUFFER_SIZE
		self.DEBUG = False
		self.LOG_ALL_DATA = False
		self.log_lifecycle = True
		self.usePTY = False
		self.notify = None
		self.env = self.make_env()

	def make_env(self):
		return	os.environ.copy()

	def start(self):
		self.command = self.get_command()
		if not self.command:
			self.slog("no command to run!")
			self.terminated = True
			return
		if self.log_full_command:
			cmd_str = str(self.command)
		else:
			cmd_str = "%s ..." % self.command[0]
		self.sdebug("usePTY=%s, env=%s, cmd=%s" % (self.usePTY, self.env, cmd_str))
		self.start_time = time.time()
		self.process = None
		try:
			self.process = self.startProcess(self.command, env=self.env, path=self.cwd, usePTY=self.usePTY)
			if self.process and hasattr(self.process, "pid"):
				self.pid = self.process.pid
			if self.log_lifecycle:
				self.slog("process(%s)=%s" % (cmd_str, self.process))
		except Exception, e:
			self.serr("failed to start %s" % cmd_str, e)

	def get_command(self):
		raise Exception("Subclass must implement this!")
		#ie: return ["echo", "echo", "hello world"]

	def terminate(self):
		"""
		Emulates a subprocess.Popen.terminate()
		"""
		self.stop("called terminate()")

	def poll(self):
		if self.terminated:
			return	0
		return	None

	def stop(self, reason=None):
		self.slog("process=%s" % self.process, reason)
		self.stopping = True
		if self.process and not self.terminated:
			self.process.signalProcess("INT")


	def kill(self):
		self.sdebug("process=%s" % self.process)
		self.stopping = True
		if self.process and not self.terminated:
			self.process.signalProcess("KILL")
			self.process = None


	#We override the ProcessProtocol functions:
	#see: http://twistedmatrix.com/documents/current/api/twisted.internet.protocol.ProcessProtocol.html
	def connectionMade(self):
		self.sdebug(None)
		#self.sdebug("pid=%s" % self.process.pid)
	def outReceived(self, data):
		if self.LOG_ALL_DATA:
			self.sdebug(None, escape_newlines(data))
		self.dataReceived(data)
	def	errReceived(self, data):
		if self.LOG_ALL_DATA:
			self.sdebug(None, escape_newlines(str(data)))
		if self.merge_stderr and type(data)==str:
			self.dataReceived(data)
			return
		if self.stopping:
			self.slog(None, visible_command(data))
		else:
			self.serror(None, visible_command(data, 255))
	def childConnectionLost(self, fd):
		if self.DEBUG:
			self.sdebug(None, fd)
	def processExited(self, reason):
		if not isinstance(reason, Failure):
			self.serror("unknown reason type: %s" % type(reason), reason)
		else:
			process_info = reason.value
			if isinstance(process_info, ProcessDone):
				if self.log_lifecycle:
					self.slog("OK: ProcessDone (command=%s)" % str(self.command), reason)
				else:
					self.sdebug("OK: ProcessDone (command=%s)" % str(self.command), reason)
				self.returncode = 0
			elif isinstance(process_info, ProcessTerminated):
				self.returncode = process_info.exitCode
				if self.log_lifecycle:
					self.serror("Failure (command=%s), exitCode=%s" % (str(self.command), self.returncode), reason)
				else:
					self.sdebug("Failure (command=%s), exitCode=%s" % (str(self.command), self.returncode), reason)
		self.exit()

	def processEnded(self, reason):
		self.sdebug(None, reason)
		self.exit()

	def exit(self):
		self.terminated = True

	def dataReceived(self, data):
		"""
		We buffer things until we get at least one full line of text to handle.
		"""
		if self.DEBUG:
			self.sdebug(None, visible_command(escape_newlines(data)))
		if self.buffer:
			data = "%s%s" % (self.buffer, data)
		lines = data.splitlines(True)
		last = lines[len(lines)-1]
		if self.LOG_ALL_DATA:
			self.sdebug("last=%s" % escape_newlines(last), escape_newlines(data))
			self.sdebug("lines=%s" % csv_list(lines, '"'), escape_newlines(data))
		if not (last.endswith('\n') or last.endswith('\r')):
			#save incomplete line in buffer
			self.buffer = last
			lines = lines[:len(lines)-1]
		else:
			self.buffer = ''

		for line in lines:
			self.line_count += 1
			while line.endswith("\n") or line.endswith("\r"):
				line = line[:len(line)-1]
			self.handle(line)
			#keep line in previous_lines fifo list
			if self.buffer_previous_lines>0:
				self.previous_lines.append(line)
				while len(self.previous_lines)>=self.buffer_previous_lines:
					self.previous_lines.popleft()

		if len(self.buffer)>self.max_buffer_size:
			self.stop("buffer too large: %s" % len(self.buffer))

	def exec_error(self):
		self.serror("failed to start '%s' with command '%s'" % (self.name, self.command))
		if self.notify:
			self.notify("Failed to launch", "Process error",
					notification_type=NOTIFY_ERROR)


	def handle(self, line):
		"""
		This is the method that subclasses will want to implement.
		It is called on every line of input as soon as it is received.
		"""
		self.slog(None, line)

	def callLater(self, timeout, *args):
		"""
		Wrapper for reactor.callLater(..) so subclasses can call this without needing to import or know about reactor or main_loop.
		"""
		callLater(timeout, *args)


class SimpleLineProcess(LineProcessProtocolWrapper):

	def __init__(self, args, env, cwd, line_handler, start_callback=None, exit_callback=None, log_full_command=True):
		LineProcessProtocolWrapper.__init__(self)
		self.command = args
		self.cwd = cwd
		self.env = env
		self.line_handler = line_handler
		self.start_callback = start_callback
		self.exit_callback = exit_callback
		self.log_full_command = log_full_command
		self.merge_stderr = True
		self.DEBUG = True
		self.LOG_ALL_DATA = False

	def handle(self, line):
		self.line_handler(line)

	def connectionMade(self):
		LineProcessProtocolWrapper.connectionMade(self)
		if self.start_callback:
			try:
				self.start_callback(self)
			except Exception, e:
				self.serr("error on %s" % self.start_callback, e)

	def exit(self):
		LineProcessProtocolWrapper.exit(self)
		if self.exit_callback:
			try:
				self.exit_callback(self)
			except Exception, e:
				self.serr(None, e)
			self.exit_callback = None

	def get_command(self):
		return	self.command

class ProcessOutput(SimpleLineProcess):

	def __init__(self, args, env, cwd, ok_callback=None, err_callback=True, log_full_command=True):
		self.stdout = []
		self.ok_callback = ok_callback
		self.err_callback = err_callback
		self.merge_stderr = True
		self.callbacks_done = False
		SimpleLineProcess.__init__(self, args, env, cwd, self.handle_line, None, None, log_full_command)
		self.DEBUG = False

	def handle_line(self, line):
		self.stdout.append(line)

	def processEnded(self, reason):
		SimpleLineProcess.processEnded(self, reason)
		if not self.callbacks_done:
			self.callbacks_done = True
			try:
				if self.returncode==0:
					if self.ok_callback:
						out = "\n".join(self.stdout)
						self.ok_callback(out)
				else:
					if self.err_callback:
						out = "\n".join(self.stdout)
						self.err_callback(out)
			except Exception, e:
				self.serr("ok_callback=%s, err_callback=%s" % (self.ok_callback, self.err_callback), e, reason)



def twisted_exec(cmd, ok_callback, err_callback=None, path=None, env=None, log_lifecycle=True):
	"""
	Runs cmd using a ProcessOutput instance, which hides the fact that this is implemented
	using twisted (no imports in caller) and provides an easier way of dealing with
	the double errback issue. (see: http://twistedmatrix.com/trac/ticket/3892).
	The ok_callback and err_callback are simple callbacks receiving the output of the command.
	If no env is specified: on posix we only keep a few variables from our own environment,
	elsewhere we use a full copy of it.
	Returns the ProcessOutput instance, already started.
	"""
	logger.sdebug(None, cmd, ok_callback, err_callback, path, env)
	if env is None:
		if os.name!="posix":
			env = os.environ.copy()
		else:
			#preserve = ["USER", "PATH", "LOGNAME", "TERM", "HOME", "USERNAME", "DBUS_SESSION_BUS_ADDRESS", "XDG_SESSION_COOKIE", "DISPLAY", "HOSTNAME", "PWD"]
			preserve = ("USER", "PATH", "LOGNAME", "TERM", "HOME", "USERNAME", "DISPLAY", "HOSTNAME", "PWD")
			env = dict((name, os.environ.get(name)) for name in preserve if name in os.environ)
	output_process = ProcessOutput(cmd, env, path, ok_callback, err_callback)
	output_process.log_lifecycle = log_lifecycle
	output_process.start()
	return output_process










def kill_daemon(pid, fd_file):
	"""
	Sends the default kill signal to the daemon,
	but only if is_daemon_alive() confirms that the pid still has fd_file opened.
	"""
	if not is_daemon_alive(pid, fd_file):
		return
	send_kill_signal(pid)

def is_daemon_alive(pid, fd_file, fds=None):
	"""
	Checks to see if a daemon is alive by finding its list of open file descriptors
	and checking that the fd_file we expect is still opened.
	(this ensures we don't kill a random process by pid, without using pidfiles)
	"""
	if not pid or pid<=0:
		logger.serror("invalid pid!", pid, fd_file, fds)
		return	False
	if not is_valid_dir("/proc"):
		logger.serror("cannot access /proc!", pid, fd_file, fds)
		return	False
	path = os.path.join("/proc", str(pid))
	if not is_valid_dir(path):
		logger.serror("cannot find process directory %s" % path, pid, fd_file, fds)
		return	False
	fd_dir = os.path.join(path, "fd")
	if not is_valid_dir(fd_dir):
		logger.serror("cannot find process file descriptor directory %s" % fd_dir, pid, fd_file, fds)
		return	False
	if not fds:
		fds = os.listdir(fd_dir)			#xpra opens the file itself, so it won't be fd1 or fd2
		fds.insert(0, "2")					#try stderr first
		fds.insert(1, "1")					#try stdout next
	for filename in fds:
		test_file = os.path.join(fd_dir, filename)
		if is_valid_file(test_file) and os.path.islink(test_file):
			real_path = os.path.realpath(test_file)
			if real_path==fd_file:
				ok = kill0(pid)
				logger.slog("found process which has the file opened, testing it with 'kill -0': %s" % ok, pid, fd_file, fds)
				return ok
	logger.slog("pid exists but does not have file opened! ignoring this process.", pid, fd_file, fds)
	return	False

def kill0(pid):
	"""
	Probes the pid by sending it a "kill -0"
	Returns True if the pid is alive (received the signal), False otherwise.
	(on win32 we use win32_process_info_handle())
	"""
	if WIN32:
		return	win32_process_info_handle(pid) is not None
	try:
		os.kill(pid, 0)
		logger.sdebug("pid is still alive", pid)
		return	True		#OK
	except OSError, e:
		logger.slog("error polling pid: %s" % e, pid)
		return	False

def win32_process_info_handle(pid):
	"""
	Tests if the process can be opened with win32api.OpenProcess for querying information.
	Despite the name, no handle is returned (it is closed straight away):
	returns True if the process could be opened,
	None if not (or if win32api is not available).
	"""
	PROCESS_QUERY_INFORMATION = 0x0400
	try:
		import win32api		#@UnresolvedImport
		handle = win32api.OpenProcess(PROCESS_QUERY_INFORMATION , False, pid)
		win32api.CloseHandle(handle)
		return True
	except Exception:
		#not a bare except: we don't want to swallow KeyboardInterrupt or SystemExit
		return None

def get_command_for_pid(pid):
	if WIN32:
		return	""
	proc_cmdline = os.path.join("/proc", "%s" % pid, "cmdline")
	if not is_valid_file(proc_cmdline):
		logger.sdebug("process not found in /proc", pid)
		return	""
	f = None
	try:
		f = open(proc_cmdline, "rU")
		data = f.readline()
	finally:
		if f:
			f.close()
	logger.sdebug("=%s" % data, pid)
	if not data:
		return	""
	return	("%s" % data).replace("\x00", "")



def win32_TerminateProcess(pid):
	handle = None
	try:
		import win32api		#@UnresolvedImport
		PROCESS_TERMINATE = 1
		handle = win32api.OpenProcess(PROCESS_TERMINATE, False, pid)
		win32api.TerminateProcess(handle, -1)
		win32api.CloseHandle(handle)
		logger.slog("used win32api.CloseHandle(%s)" % handle, pid)
		return	True
	except Exception, e:
		logger.serr("win32api.CloseHandle(%s) failed" % handle, e, pid)
		return	False


def send_kill_signal(pid, kill_signal=signal.SIGTERM, may_be_missing=False):
	logger.slog(None, pid, kill_signal, may_be_missing)
	if pid is None or pid<1:
		return False

	#only supported on win32 since Python 2.7
	if hasattr(os, "kill"):
		logger.sdebug("using os.kill", pid, kill_signal, may_be_missing)
		e = None
		try:
			os.kill(pid , kill_signal)
			return	True
		except OSError, e:
			if may_be_missing and str(e).find("No such process")>=0:
				logger.sdebug("%s" % e, pid, kill_signal)
				return False
		except Exception, e:
			pass
		logger.serror("%s" % e, pid, kill_signal, may_be_missing)
		return False
	elif WIN32:
		return win32_TerminateProcess(pid)
	else:
		os.system("kill -%s %d" % (kill_signal, pid))
		return True


def subprocess_terminate(proc, check_returncode=True):
	logger.slog(None, proc, check_returncode)
	if check_returncode:
		proc.poll()
		retcode = proc.returncode
		if retcode is not None:
			logger.sdebug("already terminated, return code=%s" % retcode, proc, check_returncode)
			return
	try:
		if hasattr(proc, "terminate"):
			proc.terminate()
			return
		if proc.pid<=1:
			logger.serror("illegal pid: %s" % proc.pid, proc, check_returncode)
			return

		send_kill_signal(proc.pid)
	except Exception, e:
		logger.serr("pid=%s" % proc.pid, e, proc, check_returncode)



def dump_all_frames(*args):
	"""
	Logs the frames found in sys._current_frames()
	and prints the stack of each one using traceback.print_stack.
	The arguments are ignored (so this can be used as a callback).
	"""
	from traceback import print_stack
	all_frames = sys._current_frames()
	logger.sdebug("found %s frames:" % len(all_frames))
	for thread_id, stack_frame in all_frames.items():
		logger.sdebug("%s - %s:" % (thread_id, stack_frame))
		print_stack(stack_frame)

def register_sigusr_debugging():
	"""
	Posix only (does nothing elsewhere): installs signal handlers so that
	SIGUSR1 turns debug logging on and dumps all the frames,
	and SIGUSR2 turns debug logging off again.
	"""
	if os.name=="posix":
		def sigusr1_handler(*args):
			logger.slog(None, *args)
			set_loggers_debug(True)
			dump_all_frames()
		def sigusr2_handler(*args):
			logger.slog(None, *args)
			set_loggers_debug(False)

		handlers = (
			(signal.SIGUSR1, sigusr1_handler),
			(signal.SIGUSR2, sigusr2_handler),
			)
		for signum, handler in handlers:
			signal.signal(signum, handler)
		logger.slog("signal debugging installed, use SIGUSR1/SIGUSR2 to enable/disable debug logging and frames dumping")


def dump_threads():
	try:
		cur = threading.currentThread()
		logger.sdebug("current_thread=%s" % cur)
		count = 1
		for t in threading.enumerate():
			if t!=cur:
				logger.sdebug("found thread: %s, alive=%s" % (t, t.isAlive()))
				count += 1
		logger.sdebug("total number of threads=%s" % count)
	except Exception, e:
		logger.exc(e)

def start_dump_threads():
	"""
	If "--dump-threads" is on the command line, schedules dump_threads()
	to run every 60 seconds (the first run is after 60 seconds).
	"""
	if "--dump-threads" not in sys.argv:
		return
	delay = 60
	def dump_and_reschedule():
		dump_threads()
		callLater(delay, dump_and_reschedule)
	callLater(delay, dump_and_reschedule)
