import sys

from ga_util import shallow_clone


class AddError(Exception):
    """Raised when a node cannot be added to the tree.

    In this module it is raised by ``add_child`` when the receiving node
    already holds ``child_count`` children.
    """

class ParentError(Exception):
    """Raised when a child's parent link does not point at its parent.

    In this module it is raised by ``parent_test`` when a child's
    ``get_parent()`` is not the node that lists it as a child.
    """

class base_tree_node(object):
    """ Tree node holding at most ``child_count`` children.

    Every node stores ``node_type``, ``derive_type``, ``label`` (set to
    ``node_type`` on construction), ``child_count``, the list of children
    and a reference to its parent.  Nodes built without a parent also get
    a ``symbol_table`` dictionary (see ``generate_symbol_table``).

    The class attributes below are bookkeeping counters shared by all
    nodes; they are always updated on ``base_tree_node`` itself, so
    subclasses contribute to the same totals.
    """
    objects_ever = 0    # total number of inc() calls
    objects = 0         # inc() calls minus dec() calls
    circular = 0        # nodes visited by delete_circulars()

    def inc(self):
        """ Record that one more node exists.
        """
        base_tree_node.objects = base_tree_node.objects + 1
        base_tree_node.objects_ever = base_tree_node.objects_ever + 1

    def dec(self):
        """ Record that one node went away.
        """
        base_tree_node.objects = base_tree_node.objects - 1

    def __init__(self, child_count, node_type='', derive_type='', parent=None):
        """
        Arguments:

        child_count -- maximum number of children accepted by add_child;
                 0 marks a leaf.
        node_type -- value matched by node_count; also the initial label.
        derive_type -- key under which the node is filed in the symbol
                 table.
        parent -- parent node, or None for a root.
        """
        self.inc()
        # Only roots start out with a symbol table.
        if parent is None:
            self.symbol_table = {}
        self.node_type = node_type
        self.derive_type = derive_type
        self.child_count = child_count
        self._children = []
        self.set_parent(parent)
        self.label = node_type

    #so that we don't overwrite the general creation function in a derived class
    #and screw up clone
    def _create(self, parent=None):
        """
        make lean and mean cause its called a ton

        Returns shallow_clone(self) with its parent set to `parent`.
        The copy still carries whatever shallow_clone copied from self
        (presumably including the _children list -- confirm against
        ga_util); callers are expected to replace _children and the
        symbol table themselves.
        """
        # Counters live on the class, so calling inc() on self accounts
        # for the node created on the next line.
        self.inc()
        new = shallow_clone(self)
        new.set_parent(parent)
        return new

    def create(self, parent=None):
        """ Return a childless copy of this node attached to `parent`.

        A copy created without a parent gets a fresh, empty symbol table.
        """
        new = self._create(parent)
        new._children = []
        if parent is None:
            new.symbol_table = {}
        return new

    def clone(self, parent=None):
        """
        make lean and mean cause its called a ton

        Return a copy of the subtree rooted at this node, attached to
        `parent`.  When `parent` is None the copy's symbol table is
        rebuilt from the copied nodes.
        """
        new = self._create(parent)
        # Each copied child must point at the copy of this node, not at
        # the parent the copy itself is being attached to.
        new._children = [child.clone(new) for child in self._children]
        if parent is None:
            new.generate_symbol_table()
        return new

    def set_parent(self, parent):
        """ Store `parent` as this node's parent.
        """
        self.parent = parent

    def get_parent(self):
        """ Return this node's parent (None for a root).
        """
        return self.parent

    def generate_symbol_table(self):
        """ Rebuild self.symbol_table from this node's subtree.
        """
        self.symbol_table = {}
        self._generate_symbol_table(self.symbol_table)

    def _generate_symbol_table(self, symbol_table):
        """
        Fill `symbol_table` in place (nothing is returned).

        symbol_table -- a dictionary with derive_type
                 symbols as its keys.  Each value is a
                 list of all the nodes in this tree with
                 that derive_type.

        Children are recorded before the node itself.
        """
        for child in self._children:
            child._generate_symbol_table(symbol_table)
        # setdefault replaces the Python-2-only dict.has_key() test.
        symbol_table.setdefault(self.derive_type, []).append(self)

    def parent_test(self):
        """ Check the parent links of the whole subtree.

        Raises ParentError if some child's get_parent() is not the node
        that holds it.
        """
        for child in self._children:
            child.parent_test()
        for child in self._children:
            if child.get_parent() is not self:
                raise ParentError

    def node_count(self, type=None):
        """ Count the nodes of this subtree whose node_type equals `type`.

        With type=None every node is counted.
        """
        cnt = 0
        for child in self._children:
            cnt = cnt + child.node_count(type)
        if type is None or self.node_type == type:
            cnt = cnt + 1
        return cnt

    def apply_node(self, func, args):
        """ Call func(node, *args) on every node, children first.

        args -- sequence of extra positional arguments for `func`.
        """
        for child in self._children:
            child.apply_node(func, args)
        # Direct call instead of the Python-2-only apply() builtin.
        func(self, *args)

    def add_child(self, node):
        """ Append `node` as a child and make this node its parent.

        Raises AddError when the node already has child_count children.
        """
        if len(self._children) >= self.child_count:
            raise AddError('too many children')
        node.set_parent(self)
        self._children.append(node)

    def filled(self):
        """ Return true when no further child can be added.
        """
        return len(self._children) >= self.child_count

    def children(self):
        """ Return the list of children itself (not a copy).
        """
        return self._children

    def leaves(self):
        """ Return the number of nodes with child_count == 0 in this subtree.
        """
        if self.child_count == 0:
            return 1
        leaves = 0
        for child in self._children:
            leaves = leaves + child.leaves()
        return leaves

    def depth(self):
        """ Return the number of levels in this subtree (a lone node is 1).

        A node that has no children yet counts as depth 1 even if its
        child_count is non-zero, so partially built trees can be measured.
        """
        if self.child_count == 0 or not self._children:
            return 1
        # Dispatch through each child so subclass overrides are honoured.
        return max([child.depth() for child in self._children]) + 1

    def ancestors(self):
        """ Return the number of nodes from this node up to the root,
        counting this node itself (a root returns 1).
        """
        parent = self.get_parent()
        if parent is not None:
            return parent.ancestors() + 1
        return 1

    def root(self):
        """ Return the topmost node reachable through get_parent().
        """
        parent = self.get_parent()
        if parent is not None:
            return parent.root()
        return self

    # I needed this to output string chromosomes for antennas
    # rethink this later
    def file_output(self, file):
        """ Forward file_output(file) to every child; writes nothing itself.
        """
        for child in self._children:
            child.file_output(file)

    def __repr__(self):
        """ Return 'label derive_type', followed for nodes with children
        by a newline, one tab per ancestors() level and the children list.
        """
        res = '%s %s' % (self.label, self.derive_type)
        if self._children:
            res = res + '\n' + '\t' * self.ancestors()
            res = res + '%s' % self._children
        return res

    def delete_circulars(self):
        """ Break the parent/child reference cycles below this node.

        Sets this node's symbol_table to None and every descendant's
        parent to None, bumping base_tree_node.circular once per node
        visited.
        """
        base_tree_node.circular = base_tree_node.circular + 1
        self.symbol_table = None
        for child in self._children:
            if child._children:
                child.delete_circulars()
            else:
                # Leaf children are not recursed into; count them here.
                base_tree_node.circular = base_tree_node.circular + 1
            child.parent = None

    def __del__(self):
        """ Keep the live-object counter in step when a node is destroyed.
        """
        self.dec()

    def __cmp__(self, other):
        """ Compare by instance dictionary (Python 2 comparison protocol).

        Objects without a __dict__ make this return 1.
        """
        #like ga_list compare...
        try:
            return cmp(self.__dict__, other.__dict__)
        except AttributeError:
            return 1

    def __setstate__(self, state):
        """ Restore attributes from `state` and count the restored node.
        """
        for key, value in state.items():
            setattr(self, key, value)
        self.inc()


##core dumps on linux
#import weakdict
#class weak_tree_node(base_tree_node,weakdict.WeakValue):
#        def __init__(self,child_count,node_type='',derive_type='', parent=None):
#                weakdict.WeakValue.__init__(self)
#                base_tree_node.__init__(self,child_count,node_type,derive_type, parent)
#        def set_parent(self,parent):
#                print 'in set'
#                if not hasattr(self,'parent'):
#                        self.parent = weakdict.WeakDict()
#                if parent: self.parent[0] = parent
#                elif self.parent.has_key(0): del self.parent[0]
#                print 'out set'
#        def get_parent(self):
#                print 'in get'
#                if self.parent.has_key(0): p =  self.parent[0]
#                else: p = None
#                print 'out get'
#                return p
#        def delete_circulars(self):
#                pass

#import mxProxy
#class proxy_tree_node(base_tree_node):
#        passobj = 2 #could be anything
#        def set_parent(self,parent):
#                self.parent = mxProxy.WeakProxy(parent,None,self.passobj)
#        def get_parent(self):
#                if self.parent: return self.parent.proxy_object(self.passobj)
#                return None
#        def delete_circulars(self):
#                pass

# Alias selecting the node implementation used through the name tree_node
# (test_treenode below builds its tree with it).  The two alternatives refer
# to the weak-reference and proxy variants that are commented out above, so
# they are disabled as well.
tree_node = base_tree_node
#tree_node = weak_tree_node
#tree_node = proxy_tree_node

def ref():
    """ Print the node bookkeeping counters kept on base_tree_node:
    live objects, objects ever counted, and circular-reference deletes.
    """
    # One pre-formatted argument per call prints the same text whether
    # print is the Python 2 statement or the Python 3 function.
    print('current %s' % (base_tree_node.objects,))
    print('ever %s' % (base_tree_node.objects_ever,))
    print('circular deletes %s' % (base_tree_node.circular,))

def test_treenode():
    """ Build a root node with two leaf children, print it and return it.
    """
    a = tree_node(2,'root')
    a.add_child(tree_node(0,'kid1'))
    a.add_child(tree_node(0,'kid2'))
    # Parenthesised single argument: valid as a Python 2 print statement
    # and as a Python 3 print() call, with identical output.
    print(a)
    return a
    # Disabled regression checks, kept for reference.  They were an
    # unreachable string literal after the return and use a `regress`
    # module that is not imported at the top of this file.
    #
    #       regress.test('tree_node: add child',a,notes)
    #    print a
    #
    #    notes = 'tree_node: add child'
    #    a = tree_node(2,'root')
    #    a.add_child(tree_node(0,'kid1'))
    #    a.add_child(tree_node(0,'kid2'))
    #    temp = a.children()[0]
    #    a.children()[0] = a.children()[1]
    #    a.children()[1] = temp
    #    print a
    #    print a.children()[1].root()
    #    regress.test('tree_node: swap kids',a,notes)
