#!/usr/bin/python
#
# Upstream - log file aggregator and report tool for *nix systems.
# Copyright (C) 2006  Ryan Zeigler (zeiglerr@gmail.com)
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

# This module provides support for our module import facilities.
# There seems to be a more "pythonic" way of doing this via
# import hooks.  This is an area of future development, i.e. determining
# whether import hooks are viable for what we want to do.  Currently we use
# unbound import hooks, which means we never have to deal with complex
# implementation details.


# TODO some of the __repr__ seem to try to concatenate strings with non-strings

import glob, imp, sys, os, threading

#  This module was built with the goal of maximum fault tolerance for
#  any imported modules.  If you discover a case in which fault tolerance
#  breaks down, treat it as a bug.

# Error codes stored in ModuleLoaderInitException.err_type
MLOAD_NOT_LIST, MLOAD_EMPTY_LIST, MLOAD_HAS_NONSTR = 0, 1, 2

# Verbosity levels for the debug_output settings below; any value
# >= DEBUG_ALL turns on the progress/diagnostic prints.
DEBUG_NONE, DEBUG_ALL = 0, 1

class ModuleLoaderInitException(Exception):
	"""Raised by ModuleLoader when its package list is unusable.

	err_type is one of MLOAD_NOT_LIST, MLOAD_EMPTY_LIST or MLOAD_HAS_NONSTR
	and selects the message returned by str().
	"""
	def __init__(self, err_type):
		# Forward err_type to Exception so .args is populated; without it the
		# exception cannot be re-created (e.g. unpickled) as cls(*args).
		Exception.__init__(self, err_type)
		self.err_type = err_type
	def __repr__(self):
		# Use %-formatting: err_type is normally an int, and str + int
		# raises TypeError.
		return "raised ModuleLoaderInitException(%s)" % (self.err_type,)
	def __str__(self):
		if self.err_type == MLOAD_NOT_LIST:
			return "Error: Package list was not a list"
		elif self.err_type == MLOAD_EMPTY_LIST:
			return "Error: Package list was empty"
		elif self.err_type == MLOAD_HAS_NONSTR:
			return "Error: Package list contained non-strings"
		else:
			return "Error: unknown?"
			
			
class IncorrectModuleReturnType(Exception):
	"""Describes a mismatch between a found type and an expected type.

	NOTE(review): presumably raised when a module hook returns a value of
	the wrong type; nothing in this file raises it.  found_type and
	expected_type may be strings or type objects: both are formatted with
	%s, so neither needs to be a str.
	"""
	def __init__(self, found_type, expected_type):
		# Populate .args so the exception can be re-created as cls(*args)
		Exception.__init__(self, found_type, expected_type)
		self.found_type = found_type
		self.expected_type = expected_type
	def __repr__(self):
		# %s instead of "+": concatenating a type object to a str raises TypeError
		return "raised IncorrectModuleReturnType(%s,%s)" % (self.found_type, self.expected_type)
	def __str__(self):
		return "Found type: %s Expected type: %s" % (self.found_type, self.expected_type)

# LoadedModule now extends from Thread type
class LoadedModule(threading.Thread):
	"""Thread-derived wrapper around one validated plugin module.

	The wrapped module is kept in .module, and its module_name and
	module_description are copied onto the wrapper for direct access.
	"""
	# Class-level defaults; __init__ always overrides both per instance.
	fault_tolerance = True
	debug_output = DEBUG_NONE

	# This expects the module fields to already exist.
	# Note: any new module wrappers must have this exact argument list
	# (ModuleLoader constructs them as ModuleWrapper(module, ft, dbg)).
	def __init__(self, module, fault_tolerance, debug_output):
		threading.Thread.__init__(self)
		self.module = module
		self.fault_tolerance = fault_tolerance
		self.debug_output = debug_output
		# Mirror the two required metadata fields onto the wrapper
		wrapped = self.module
		self.module_name = wrapped.module_name
		self.module_description = wrapped.module_description

	def __repr__(self):
		template = "< loaded module with name : %s >"
		return template % self.module_name

# ModuleLoader now extends from Thread type
class ModuleLoader(threading.Thread):
	"""Import the plugin modules of a list of packages and keep the valid ones.

	For every package name in pack_list the package is imported and each
	name in its __all__ is imported as a submodule.  Every module that
	passes validate_module() is wrapped in self.ModuleWrapper and appended
	to self.valid_modules.  The loader acts as a container of those
	wrappers: index it by position (int) or by module_name (str), delete
	entries the same way, take len() of it, or iterate over it.

	With use_threading the load runs in a background thread started from
	the constructor; otherwise it runs synchronously inside __init__.
	"""
	# Necessary attributes for a generic module.
	# Subclasses may override these to require different attributes.
	# necessary_attr_types[i] is the required type of necessary_attributes[i];
	# a None entry, or a types list shorter than the attributes list, means
	# that attribute is only checked for presence.
	necessary_attributes = ["module_name", "module_description"]
	necessary_attr_types = [str, str]
	# NOTE(review): possibly meant to be "load_status"; nothing in this class
	# updates it.  Kept for compatibility.
	load_stats = -1
	# Fraction (0.0 - 1.0) of imported modules validated so far; stays -1
	# until execute_load() reaches the validation step.
	validation_status = -1
	# New classes should override the ModuleWrapper item; it is called as
	# ModuleWrapper(module, fault_tolerance, debug_output).
	ModuleWrapper = LoadedModule

	def __init__(self, pack_list, fault_tolerance=True, debug_output=DEBUG_NONE, use_threading=False):
		"""Store the settings and load pack_list immediately.

		pack_list -- list of package names (str) whose plugins are loaded.
		fault_tolerance -- if true, skip bad entries and failing imports
			instead of raising.
		debug_output -- DEBUG_NONE, or DEBUG_ALL to print progress.
		use_threading -- if true, load in a background thread (self.start());
			otherwise call execute_load() before returning.
		"""
		# Chain up
		threading.Thread.__init__(self)
		self.threaded = use_threading
		self.pack_list = pack_list
		self.debug_output = debug_output
		self.fault_tolerance = fault_tolerance
		self.valid_modules = []
		if use_threading:
			self.start()
		else:
			self.execute_load()

	def __repr__(self):
		return "ModuleLoader(" + str(self.pack_list) + ", " + str(self.fault_tolerance) + ", " + str(self.debug_output) + ")"

	def __getitem__(self, modid):
		"""Return a loaded module by position (int) or by module_name (str).

		Raises TypeError for any other key type, IndexError for an index
		outside 0..len-1 (negative indices are rejected), and KeyError when
		no loaded module has the given name.
		"""
		if type(modid) is int:
			if modid < 0 or modid >= len(self.valid_modules):
				raise IndexError(modid)
			return self.valid_modules[modid]
		if type(modid) is str:
			for mod in self.valid_modules:
				if mod.module_name == modid:
					return mod
			raise KeyError(modid)
		raise TypeError("module key must be an int index or a str name")

	def __delitem__(self, modid):
		"""Remove a loaded module by position (int) or by module_name (str).

		Raises the same exceptions as __getitem__ for bad keys.
		"""
		if type(modid) is int:
			if modid < 0 or modid >= len(self.valid_modules):
				raise IndexError(modid)
			del self.valid_modules[modid]
		elif type(modid) is str:
			# Indexing raises KeyError if no module has this name
			self.valid_modules.remove(self[modid])
		else:
			raise TypeError("module key must be an int index or a str name")

	def __len__(self):
		return len(self.valid_modules)

	def __iter__(self):
		return ModuleLoaderIterator(self)

	# Methods utilized by the thread class
	def run(self):
		"""Thread entry point: perform the load in the background thread."""
		# Only print when debugging, like every other diagnostic here
		if self.debug_output >= DEBUG_ALL:
			print("Running threaded load")
		self.execute_load()

	def execute_load(self):
		"""Import the plugins of every package in pack_list and keep the valid ones.

		Raises ModuleLoaderInitException if pack_list is not a list, or,
		when fault tolerance is off, if it is empty or contains non-strings.
		With fault tolerance off, import errors are also re-raised.
		"""
		# Perform validation to ensure that we didn't end up with invalid
		# parameters
		if type(self.pack_list) is not list:
			raise ModuleLoaderInitException(MLOAD_NOT_LIST)

		# Only perform an actual import if we have a non-empty list
		if not self.pack_list:
			if not self.fault_tolerance:
				raise ModuleLoaderInitException(MLOAD_EMPTY_LIST)
			return

		self._prune_pack_list()
		loaded_modules = self._import_packages()
		self._validate_loaded(loaded_modules)

	def _prune_pack_list(self):
		"""Replace pack_list with its str entries only (raise instead if fault tolerance is off)."""
		# Build a new list: calling remove() while iterating skips the element
		# after each removal, and would also mutate the caller's list.
		strings = [p for p in self.pack_list if type(p) is str]
		if len(strings) != len(self.pack_list) and not self.fault_tolerance:
			raise ModuleLoaderInitException(MLOAD_HAS_NONSTR)
		self.pack_list = strings

	def _import_packages(self):
		"""Import each package and the plugins named in its __all__; return the plugin modules."""
		loaded_modules = []
		for package_name in self.pack_list:
			if self.debug_output >= DEBUG_ALL:
				print("Importing %s" % package_name)
			try:
				__import__(package_name)
				# __import__("a.b") returns the top-level package "a", so fetch
				# the package that was actually named from sys.modules
				imp_pack = sys.modules[package_name]
				plugin_names = imp_pack.__all__
			except Exception:
				self._report_failure(package_name)
				if not self.fault_tolerance:
					raise
				continue
			for plugin_name in plugin_names:
				try:
					__import__(package_name + "." + plugin_name)
					loaded_modules.append(getattr(imp_pack, plugin_name))
				except Exception:
					self._report_failure(package_name + "." + plugin_name)
					if not self.fault_tolerance:
						raise
		return loaded_modules

	def _validate_loaded(self, loaded_modules):
		"""Wrap and store every module that passes validate_module(), updating validation_status."""
		self.validation_status = 0.0
		total = len(loaded_modules)
		for done, mod in enumerate(loaded_modules):
			if self.validate_module(mod):
				if self.debug_output >= DEBUG_ALL:
					print(" Adding %s to valid modules" % (mod,))
				self.valid_modules.append(self.ModuleWrapper(mod, self.fault_tolerance, self.debug_output))
			elif self.debug_output >= DEBUG_ALL:
				print(" Not adding %s to valid modules" % (mod,))
			self.validation_status = float(done + 1) / total

	def _report_failure(self, name):
		# Debug-print which load failed and the exception type; must be
		# called from inside an except clause so sys.exc_info() is set.
		if self.debug_output >= DEBUG_ALL:
			print("Load failed on module: %s, attempting recovery" % name)
			print(sys.exc_info()[0])

	def scan_directory(self, path):
		"""Return the bare names of the *.py, *.pyc and *.pyo modules in path.

		Source names come first; a .pyc (then .pyo) name is appended only if
		no file with that name was found before it.
		"""
		if self.debug_output >= DEBUG_ALL:
			print("Scanning directory: %s" % path)

		def stripped(ext):
			# Module names (basename minus extension) of every path/*<ext>
			return [os.path.basename(f)[:-len(ext)] for f in glob.glob(os.path.join(path, "*" + ext))]

		found_modules = stripped(".py")
		for mod in stripped(".pyc"):
			if mod not in found_modules:
				if self.debug_output >= DEBUG_ALL:
					print("Module %s does not have a corresponding source file" % mod)
				found_modules.append(mod)
			elif self.debug_output >= DEBUG_ALL:
				print("Module %s has a bytecode file" % mod)

		for mod in stripped(".pyo"):
			if mod not in found_modules:
				if self.debug_output >= DEBUG_ALL:
					print("Module %s does not have a corresponding source file or standard bytecode file" % mod)
				found_modules.append(mod)
			elif self.debug_output >= DEBUG_ALL:
				print("Module %s has a bytecode file" % mod)
		return found_modules

	def load_module(self, modname):
		"""Find modname on sys.path and import it with the imp module.

		Returns the module, or None when it has no file handle (packages,
		builtins) or, with fault tolerance on, when finding/loading fails.
		With fault tolerance off, find/load errors are re-raised.
		"""
		if self.debug_output >= DEBUG_ALL:
			print("Attempting to 'find' module: %s" % modname)
		try:
			file_handle, filename, description = imp.find_module(modname)
		except ImportError:
			self._report_failure(modname)
			if not self.fault_tolerance:
				raise
			return None
		if not file_handle:
			if self.debug_output >= DEBUG_ALL:
				print("Failed 'finding' module (programming error or possible collision with builtin module): %s" % modname)
			return None
		loaded_module = None
		try:
			try:
				# The third argument is the file's path (used for __file__),
				# not the module name
				loaded_module = imp.load_module(modname, file_handle, filename, description)
			except Exception:
				self._report_failure(modname)
				if not self.fault_tolerance:
					raise
		finally:
			# Close the handle exactly once, whether or not the load failed
			file_handle.close()
		return loaded_module

	# Provide a string method
	def __str__(self):
		return "Module loader:\n" + repr(self.valid_modules)

	# This is the bare minimum necessary for one of our modules to be valid
	def validate_module(self, module):
		"""Return True if module passes both validate_fields() and validate_additional()."""
		if self.debug_output >= DEBUG_ALL:
			print("Validating module: %s" % module.__name__)
		valid_fields = self.validate_fields(module)
		valid_additional = self.validate_additional(module)
		return valid_fields and valid_additional

	# Determine if the module has the necessary fields to be a valid module
	# Subclasses should probably not have to override this method, and
	# instead, they should rely on overriding the "necessary_attributes" field
	def validate_fields(self, module):
		"""Return True if module has every necessary attribute with the required type."""
		num_types = len(self.necessary_attr_types)
		for ind, field in enumerate(self.necessary_attributes):
			present = hasattr(module, field)
			if self.debug_output >= DEBUG_ALL:
				print(" Validating fields %s : %s" % (field, present))
			if not present:
				return False
			# Only type-check attributes that have a type specified; the old
			# "ind < self.necessary_attr_types" compared an int to a list
			if ind < num_types:
				expected = self.necessary_attr_types[ind]
				matches = expected is None or type(getattr(module, field)) == expected
				if self.debug_output >= DEBUG_ALL:
					print(" Validating field %s as type %s : %s" % (field, expected, matches))
				if not matches:
					return False
		# If we get to the end, we were successful
		return True

	# Determine if the module has the necessary activation hooks to be
	# a module.  Subclasses will probably have to reimplement this method
	# from scratch, since a default module provides no hooks into the module
	# and simply returns true.
	def validate_additional(self, module):
		return True

	# This is not actually used by default, but is provided as a convenience,
	# so that base classes do not have to reimplement it
	def validate_execution_hook(self, module, name, num_args):
		"""Return True if module.<name> is a Python function taking exactly num_args positional args.

		Returns False (instead of raising) when the attribute is missing or
		is not a plain Python function.
		"""
		hasfunc = hasattr(module, name)
		func = getattr(module, name, None)
		# Python 2 functions expose func_code; Python 2.6+/3 expose __code__
		code = getattr(func, "func_code", None)
		if code is None:
			code = getattr(func, "__code__", None)
		func_is_func = code is not None
		func_has_correct_param = func_is_func and code.co_argcount == num_args
		if self.debug_output >= DEBUG_ALL:
			print(" Module %s has attribute %s: %s " % (module, name, hasfunc))
			print(" Module attribute %s is of type 'func_code': %s" % (name, func_is_func))
			print(" Module function %s has %d args: %s" % (name, num_args, func_has_correct_param))
		return hasfunc and func_is_func and func_has_correct_param

	# Deprecated: Use mappings instead (loader[name])
	def module(self, mod_name):
		"""Return the loaded module whose module_name or __name__ is mod_name, else None."""
		if self.debug_output >= DEBUG_ALL:
			print("Searching for module: %s" % mod_name)
		for x in self.valid_modules:
			if self.debug_output >= DEBUG_ALL:
				print("Looking at module: %s trying to find: %s" % (x, mod_name))
			if x.module_name == mod_name or x.module.__name__ == mod_name:
				return x
		return None


class ModuleLoaderIterator:
	"""Iterator over a ModuleLoader's valid_modules, in list order."""
	def __init__(self, parent):
		self.parent = parent
		# Index of the last module returned; -1 before the first next()
		self.ind = -1

	def __iter__(self):
		# Iterators must be iterable themselves (for-loops, list(), ...)
		return self

	def next(self):
		"""Return the next module, or raise StopIteration when exhausted."""
		self.ind = self.ind + 1
		# ">=" rather than "==": stop cleanly even if modules were deleted
		# mid-iteration or next() is called again after exhaustion
		if self.ind >= len(self.parent.valid_modules):
			raise StopIteration
		return self.parent.valid_modules[self.ind]

	# Python 3 spelling of the iterator protocol method
	__next__ = next


########################################################
# All classes/functions below this marker are deprecated,
# and will be removed soon
########################################################
class ModuleDirectoryScanner:
	"""Deprecated: import every module found in a single directory.

	The directory is appended to sys.path if not already present, scanned
	for *.py / *.pyc / *.pyo files, and each found module is imported with
	the imp module.  Successfully imported modules are stored in
	self.dir_modules, which iterating over the scanner yields.
	"""
	def __init__(self, path, fault_tolerance, debug_output):
		"""Scan path and load its modules immediately.

		fault_tolerance -- if true, modules that fail to find/load are skipped
			instead of re-raising the error.
		debug_output -- DEBUG_NONE, or DEBUG_ALL to print progress.
		"""
		# True when path was already on sys.path before we got here
		self.duplicate_path = False
		self.fault_tolerance = fault_tolerance
		self.debug_output = debug_output

		self.path = path

		if self.path not in sys.path:
			sys.path.append(self.path)
			if self.debug_output >= DEBUG_ALL:
				print("Adding %s to path list" % self.path)
				print(sys.path)
		else:
			if self.debug_output >= DEBUG_ALL:
				print("Not adding %s to path list" % self.path)
				print(sys.path)
			self.duplicate_path = True

		self.dir_modules = []
		if debug_output >= DEBUG_ALL:
			print("Scanning directory: %s" % self.path)
		self.scan()
		self.load()

	def __iter__(self):
		return ModuleDirectoryScannerIterator(self)

	# Scan for possible modules that can be loaded
	def scan(self):
		"""Set self.found_modules to the bare module names found in self.path.

		Source names come first; a .pyc (then .pyo) name is appended only if
		no file with that name was found before it.
		"""
		def stripped(ext):
			# Module names (basename minus extension) of every self.path/*<ext>.
			# Never reassigns self.path (the old "self.path = ..." typo did).
			return [os.path.basename(f)[:-len(ext)] for f in glob.glob(os.path.join(self.path, "*" + ext))]

		self.found_modules = stripped(".py")
		for mod in stripped(".pyc"):
			if mod not in self.found_modules:
				if self.debug_output >= DEBUG_ALL:
					print("Module %s does not have a corresponding source file" % mod)
				self.found_modules.append(mod)
			elif self.debug_output >= DEBUG_ALL:
				print("Module %s has a bytecode file" % mod)

		for mod in stripped(".pyo"):
			if mod not in self.found_modules:
				if self.debug_output >= DEBUG_ALL:
					print("Module %s does not have a corresponding source file or standard bytecode file" % mod)
				self.found_modules.append(mod)
			elif self.debug_output >= DEBUG_ALL:
				print("Module %s has a bytecode file" % mod)

	# Load modules
	def load(self):
		"""Import every name in self.found_modules and append the results to self.dir_modules."""
		if self.debug_output >= DEBUG_ALL:
			print("Located all modules: %s" % (self.found_modules,))
		for modname in self.found_modules:
			if self.debug_output >= DEBUG_ALL:
				print("Attempting to 'find' module: %s" % modname)
			try:
				# Search only the scanned directory, so a same-named module
				# earlier on sys.path cannot shadow the one we found
				file_handle, filename, description = imp.find_module(modname, [self.path])
			except ImportError:
				if self.debug_output >= DEBUG_ALL:
					print("Failed 'finding' module: %s" % modname)
				if not self.fault_tolerance:
					raise
				continue
			loaded_module = None
			if not file_handle:
				if self.debug_output >= DEBUG_ALL:
					print("Failed 'finding' module (programming error or possible collision with builtin module): %s" % modname)
			else:
				try:
					try:
						# The third argument is the file's path (used for
						# __file__), not the module name
						loaded_module = imp.load_module(modname, file_handle, filename, description)
					except Exception:
						# If we are not using fault tolerance, reraise
						if self.debug_output >= DEBUG_ALL:
							print("Load failed on module: %s, attempting recovery" % modname)
							print(sys.exc_info()[0])
						if not self.fault_tolerance:
							raise
				finally:
					# Close the handle exactly once, whether or not the load failed
					file_handle.close()

			# If loaded module is None, something went wrong
			if loaded_module:
				self.dir_modules.append(loaded_module)
					
class ModuleDirectoryScannerIterator:
	"""Iterator over a ModuleDirectoryScanner's dir_modules, in list order."""
	def __init__(self, parent):
		self.parent = parent
		# Index of the last module returned; -1 before the first next()
		self.ind = -1

	def __iter__(self):
		# Iterators must be iterable themselves (for-loops, list(), ...)
		return self

	def next(self):
		"""Return the next module, or raise StopIteration when exhausted."""
		self.ind = self.ind + 1
		# ">=" rather than "==": stop cleanly even if the list shrank or
		# next() is called again after exhaustion
		if self.ind >= len(self.parent.dir_modules):
			raise StopIteration
		return self.parent.dir_modules[self.ind]

	# Python 3 spelling of the iterator protocol method
	__next__ = next
		
