
from Krank import *

#import sys, pygame
#pygame.init()
# 
#size = width, height = 800, 600
#speed = [2, 2]
#black = 0, 0, 0
#
#screen = pygame.display.set_mode(size)
#
#ball = pygame.image.load("ball.bmp")
#ballrect = ball.get_rect()
#
#while 1:
#    for event in pygame.event.get():
#        if event.type == pygame.QUIT: sys.exit()
#    ballrect = ballrect.move(speed)
#    if ballrect.left < 0 or ballrect.right > width:
#        speed[0] = -speed[0]
#    if ballrect.top < 0 or ballrect.bottom > height:
#        speed[1] = -speed[1]
#
#    screen.fill(black)
#    screen.blit(ball, ballrect)
#    pygame.display.flip()

# Module-wide registry of Spring objects: Chain.link() appends to it and
# Chain.unlink() removes from it.
springs = list()
# Module-wide list of Chain objects: filled by the demo at the bottom of the
# file and dumped to the log by Chain.unlink().
chains = list()

#-----------------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------------
class Spring:
    """Pairing of two endpoint objects.

    The endpoints are stored unchanged as ``p1`` and ``p2``; only
    ``__str__`` looks inside them (it reads their ``index`` attribute).
    """
    #-------------------------------------------------------------------------------------------
    def __init__ (self, p1, p2):
        """Remember both endpoints in the order they were passed."""
        self.p1, self.p2 = p1, p2

    #-------------------------------------------------------------------------------------------
    def __str__ (self):
        """Return ``<NN NN>`` built from the two endpoint indices."""
        endpoints = (self.p1, self.p2)
        return "<%02d %02d>" % tuple(end.index for end in endpoints)
                
#-----------------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------------
class Chain:
    """Node of a linear chain: up to two neighbours, joined by Spring objects.

    Attributes:
        links:   neighbouring Chain objects; link() never adds a third one.
        springs: Spring objects that have this node as one endpoint.
        chain:   all nodes of the group this node belongs to, the node
                 itself included (a standalone node has ``chain == [self]``).
        index:   value stored under 'index' in the constructor dict, or None.
    """
    #-------------------------------------------------------------------------------------------
    def __init__(self, dict=None):
        """Create a standalone node.

        Args:
            dict: optional mapping; only its 'index' entry is read. The
                parameter name is kept for callers that pass it by keyword.
        """
        # None instead of a shared ``{}`` default avoids the mutable-default
        # pitfall while accepting exactly the same calls as before.
        self.links = []
        self.springs = []
        self.chain = [self]
        self.index = dict.get('index') if dict is not None else None

    #-------------------------------------------------------------------------------------------
    def __str__(self):
        """Return a one-line dump of index, neighbours, group and springs."""
        return "[%02d links: %s chain: %s springs; %s]" % (
            self.index,
            ["%02d" % neighbour.index for neighbour in self.links],
            ["%02d" % member.index for member in self.chain],
            [str(spring) for spring in self.springs],
        )

    #-------------------------------------------------------------------------------------------
    def hasFreeLink(self):
        """Return how many more links this node accepts (0, 1 or 2).

        The result is used as a truth value by link(): zero means "full".
        """
        return max(0, 2 - len(self.links))

    #-------------------------------------------------------------------------------------------
    def linkedChains(self):
        """Return every node reachable through links, starting with this one.

        Each neighbour's side is walked in turn, so a node sitting in the
        middle of a chain reports both halves, not just the ``links[0]`` side.
        """
        reachable = [self]
        for neighbour in self.links:
            neighbour.traverseChain(self, reachable)
        return reachable

    #-------------------------------------------------------------------------------------------
    def traverseChain(self, source, chains):
        """Append this node and all nodes beyond it, walking away from *source*.

        Args:
            source: the node the walk arrived from; it is not walked back into.
            chains: list that collects the visited nodes (modified in place).
        """
        previous, node = source, self
        while node is not None:
            if node in chains:
                # Already collected: the links form a loop, so stop here
                # instead of walking it forever.
                break
            chains.append(node)
            following = None
            for candidate in node.links[:2]:
                if candidate is not previous:
                    following = candidate
                    break
            previous, node = node, following

    #-------------------------------------------------------------------------------------------
    def link(self, c):
        """Join this node and *c* with a new Spring.

        Nothing is changed when either node already has two links or when
        *c* is already a member of this node's group.

        Returns:
            True if the link was created, False otherwise.
        """
        if not self.hasFreeLink() or not c.hasFreeLink():
            return False
        if c in self.chain:
            return False

        log()
        log("linking %02d -> %02d" % (self.index, c.index))

        self.links.append(c)
        c.links.append(self)

        # The two groups become one; every member gets its own copy of the
        # merged membership list.
        merged = self.chain + c.chain
        self.chain = merged
        for member in merged:
            member.chain = list(merged)

        spring = Spring(self, c)
        springs.append(spring)
        self.springs.append(spring)
        c.springs.append(spring)

        for member in self.chain:
            log(member)
        return True

    #-------------------------------------------------------------------------------------------
    def unlink(self):
        """Detach this node from its neighbours and make it standalone again.

        The springs attached to this node are removed from the neighbours and
        from the module-level ``springs`` list, and each neighbour's group is
        recomputed from the links that remain.
        """
        log()
        log("unlinking" , self)

        for member in self.chain:
            if member is not self and self in member.chain:
                member.chain.remove(self)

        for neighbour in self.links:
            # Build the filtered list first: removing entries from a list
            # while looping over that same list skips elements.
            neighbour.springs[:] = [
                spring for spring in neighbour.springs
                if spring.p1 is not self and spring.p2 is not self
            ]

            if self in neighbour.links:
                neighbour.links.remove(self)

            group = neighbour.linkedChains()
            for member in group:
                member.chain = list(group)

        for spring in self.springs:
            springs.remove(spring)

        self.links   = []
        self.springs = []
        # Same state as a freshly constructed node. An empty list here would
        # leave the node out of every group built by a later link(), which
        # lets the same two nodes be linked twice.
        self.chain   = [self]

        for entry in chains:
            log(entry)
        
#-----------------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------------
        
if __name__ == "__main__":

    # Small demo run: create five nodes, each indexed by its position in
    # `chains` at creation time.
    for _ in range(5):
        node = Chain({'index': len(chains)})
        chains.append(node)

    # Link the nodes pairwise (1-0, 1-2, 2-3, 3-4); with `chains` starting
    # out empty this forms the line 0 - 1 - 2 - 3 - 4.
    for first, second in ((1, 0), (1, 2), (2, 3), (3, 4)):
        chains[first].link(chains[second])

    # Pull the middle node out again; unlink() logs the resulting state.
    chains[2].unlink()
    
