require "rexml/functions"

module REXML
	class XPath
		include Functions

		def XPath::first element, path
			match(element, path)[0]
		end

		def XPath::each element, path, &block
			path = "*" unless path
			match(element, path).each( &block )
		end

		def XPath::match element, path
			raise "nil is not a valid xpath" unless path
			results = nil
			case path
			when /^\/([^\/]|$)/u
				# match on root
				path = path[1..-1]
				return [element.root.parent] if path == ''
				results = filter([element.root], path)
			when /^[\w-]*::/u
				results = filter([element], path)
			when /^\*/u
				results = filter(element.to_a, path)
			when /^[[:!\w_]/u
				# match on child
				matches = []
				children = element.to_a
				results = filter(children, path)
			else
				results = filter([element], path)
			end
			return results
		end

		# Given an array of nodes it filters the array based on the path. The
		# result is that when this method returns, the array will contain elements
		# which match the path
		def XPath::filter elements, path
			return elements if path.nil? or path == '' or elements.size == 0
			case path
			when /^\/\//u											# Descendant
				return axe( elements, "descendant-or-self", $' )
			when /^\/?\b(\w[\w-]*)\b::/u							# Axe
				axe_name = $1
				rest = $'
				return axe( elements, $1, $' )
			when /^\/(?=\b([:!\w_][-\.\w_]*:)?[-!\*\.\w_]*\b([^:(]|$)|\*)/u	# Child
				rest = $'
				results = []
				elements.each do |element|
					results |= filter( element.to_a, rest )
				end
				return results
			when /^\/?(\w[\w-]*)\(/u							# / Function
				return function( elements, $1, $' )
			when /^\b((?:[:!\w_][-\.\w_]*:)?[-!\*\.\w_]*)\b/u		# Element name
				name = $1
				rest = $'
				elements.delete_if do |element|
					!(element.kind_of? Element and element.name == name)
				end
				return filter( elements, rest )
			when /^\/\[/u
				matches = []
				elements.each do |element|
					matches |= predicate( element.to_a, path[1..-1] ) if element.kind_of? Element
				end
				return matches
			when /^\[/u												# Predicate
				return predicate( elements, path )
			when /^\/?\.\.\./u										# Ancestor
				return axe( elements, "ancestor", $' )
			when /^\/?\.\./u											# Parent
				return filter( elements.collect{|e|e.parent}, $' )
			when /^\/?\./u												# Self
				return filter( elements, $' )
			when /^\*/u													# Any
				results = []
				elements.each do |element|
					results |= filter( [element], $' ) if element.kind_of? Element
					#if element.kind_of? Element
					#	children = element.to_a
					#	children.delete_if { |child| !child.kind_of?(Element) }
					#	results |= filter( children, $' )
					#end
				end
				return results
			end
			return []
		end

		def XPath::axe( elements, axe_name, rest )
			matches = []
			matches = filter( elements.dup, rest ) if axe_name =~ /-or-self$/u
			case axe_name
			when /^descendant/u
				elements.each do |element|
					matches |= filter( element.to_a, "descendant-or-self::#{rest}" ) if element.kind_of? Element
				end
			when /^ancestor/u
				elements.each do |element|
					while element.parent
						matches << element.parent
						element = element.parent
					end
				end
				matches = filter( matches, rest )
			when "self"
				matches = filter( elements, rest )
			when "child"
				elements.each do |element|
					matches |= filter( element.to_a, rest ) if element.kind_of? Element
				end
			when "attribute"
				elements.each do |element|
					matches << element.attributes[ rest ] if element.kind_of? Element
				end
			when "parent"
				matches = filter(elements.collect{|element| element.parent}.uniq, rest)
			when "following-sibling"
				matches = filter(elements.collect{|element| element.next_sibling}.uniq,
					rest)
			when "previous-sibling"
				matches = filter(elements.collect{|element| 
					element.previous_sibling}.uniq, rest )
			end
			return matches.uniq
		end

		def XPath::predicate( elements, path ) 
			ind = 1
			bcount = 1
			while bcount > 0
				bcount += 1 if path[ind] == ?[
				bcount -= 1 if path[ind] == ?]
				ind += 1
			end
			ind -= 1
			predicate = path[1..ind-1]
			rest = path[ind+1..-1]

			# have to change 'a [=<>] b [=<>] c' into 'a [=<>] b and b [=<>] c'
			predicate.gsub!( /([^\s(and)(or)<>=]+)\s*([<>=])\s*([^\s(and)(or)<>=]+)\s*([<>=])\s*([^\s(and)(or)<>=]+)/u ) { 
				"#$1 #$2 #$3 and #$3 #$4 #$5"
			}
			# Let's do some Ruby trickery to avoid some work:
			predicate.gsub!( /&/u, "&&" )
			predicate.gsub!( /=/u, "==" )
			predicate.gsub!( /@(\w[\w.-]*)/u ) {
				"attribute(\"#$1\")" 
			}
			predicate.gsub!( /\bmod\b/u, "%" )
			predicate.gsub!( /\b(\w[\w.-]*\()/u ) {
				fname = $1
				fname.gsub( /-/u, "_" )
			}
			
			Functions.pair = [ 0, elements.size ]
			results = []
			elements.each do |element|
				Functions.pair[0] += 1
				Functions.node = element
				res = eval( predicate )
				case res
				when true
					results << element
				when Fixnum
					results << element if Functions.pair[0] == res
				when String
					results << element
				end
			end
			return filter( results, rest )
		end

		def XPath::attribute( name )
			return Functions.node.attributes[name] if Functions.node.kind_of? Element
		end

		def XPath::name()
			return Functions.node.name if Functions.node.kind_of? Element
		end

		def XPath::method_missing( id, *args )
			begin
				Functions.send( id.id2name, *args )
			rescue Exception
				raise "METHOD: #{id.id2name}(#{args.join ', '})\n#{$!.message}"
			end
		end

		def XPath::function( elements, fname, rest )
			args = parse_args( elements, rest )
			Functions.pair = [0, elements.size]
			results = []
			elements.each do |element|
				Functions.pair[0] += 1
				Functions.node = element
				res = Functions.send( fname, *args )
				case res
				when true
					results << element
				when Fixnum
					results << element if Functions.pair[0] == res
				end
			end
			return results
		end

		def XPath::parse_args( element, string )
			# /.*?(?:\)|,)/
			arguments = []
			buffer = ""
			while string and string != ""
				c = string[0]
				string.sub!(/^./u, "")
				case c
				when ?,
					# if depth = 1, then we start a new argument
					arguments << evaluate( buffer )
					#arguments << evaluate( string[0..count] )
				when ?(
					# start a new method call
					function( element, buffer, string )
					buffer = ""
				when ?)
					# close the method call and return arguments
					return arguments
				else
					buffer << c
				end
			end
			""
		end
	end
end
