#
#  Part.py

import k

from Krank  import *
from Math   import *
from Sprite import *
from Effect import *
from Tools  import *

# Default drag coefficient used by Particle when its dict has no 'drag' key.
# Particle.calcForces applies a drag force of -vel * |vel| * drag.
KDRAG  = 0.001

#-----------------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------------

class Particle:
    """Point mass with position, velocity, an accumulated force, drag and a sprite.

    Built from an option dict; the keys read are 'imass', 'pos', 'vel',
    'force', 'drag', 'radius', 'player', 'image', 'color' and 'sprites'.
    """

    #-------------------------------------------------------------------------------------------
    def __init__ (self, dict={}):
        options = dict

        # inverse mass: 0 means forces and collision impulses leave the velocity unchanged
        self.imass = options.get('imass', 1.0)

        defaultPos = (k.world.rect.centerx/2, k.world.rect.centery/2)
        self.pos   = vector(options.get('pos', defaultPos))
        self.vel   = vector(options.get('vel', (0,0)))
        self.force = vector(options.get('force', (0,0)))

        self.drag   = options.get('drag', KDRAG)
        self.radius = options.get('radius', 13.0)

        self.player   = options.get('player', 0)
        self.captured = 0

        image      = options.get('image')
        self.color = options.get('color')
        # no explicit image: derive the default dot image from the color, if any
        if not image and self.color is not None:
            image = 'levels/images/dot28_%s.png' % self.color

        self.sprite = Sprite(self.pos, image, options.get('sprites'))
        k.particle_sprites.add(self.sprite)
        # back-reference so sprite collisions can find the particle
        self.sprite.part = self

    #-------------------------------------------------------------------------------------------
    def resetForce (self):
        """Clear the force accumulated during the current step."""
        self.force.zero()

    #-------------------------------------------------------------------------------------------
    def applyForce (self, force):
        """Add *force* to the accumulated force."""
        self.force += force

    #-------------------------------------------------------------------------------------------
    def calcForces (self):
        """Apply quadratic drag: -vel * |vel| * drag."""
        dragForce = -self.vel*self.vel.length()*self.drag
        self.applyForce(dragForce)

    #-------------------------------------------------------------------------------------------
    def solveForce (self, seconds):
        """Integrate the accumulated force over *seconds* and move the sprite.

        Speed is capped at 500 for player particles and at the world height otherwise.
        """
        self.vel += self.force * (self.imass * seconds)

        if self.player:
            speedLimit = 500
        else:
            speedLimit = k.world.rect.height
        if self.vel.length() > speedLimit:
            self.vel = self.vel.norm()*speedLimit

        self.pos += self.vel * seconds
        self.sprite.setPos(self.pos)
        
#-----------------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------------
class Stone (Particle):
    """Immovable particle drawn with the 'dot28_s_<color>' image.

    Note: writes the 'image' key into the dict it is given before delegating
    to Particle.
    """

    def __init__ (self, dict):
        stoneImage = 'levels/images/dot28_s_%s.png' % dict.get('color')
        dict['image'] = stoneImage
        Particle.__init__(self, dict)
        # zero inverse mass: forces and collision impulses never change its velocity
        self.imass = 0

#-----------------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------------

class Switch (Particle):
    """Immovable menu particle with an optional text label and collision action.

    Dict keys read here (in addition to Particle's): 'color' (default 'white'),
    'size' ('large' or anything else for small), 'radius', 'action' (a string
    passed to eval on collision), 'text', 'group', 'offset', 'textsize',
    'align', 'valign', 'textcolor'.
    """
    #-------------------------------------------------------------------------------------------
    def __init__ (self, dict={}):
        # Work on a copy: filling in 'image' used to mutate both the shared
        # default dict and the caller's dict.
        dict = dict.copy()

        self.color = dict.get('color', 'white')
        # Particle.__init__ re-assigns self.color from dict.get('color'); store the
        # default so it is not reset to None.
        dict['color'] = self.color

        size = dict.get('size', 'large')
        if 'image' not in dict:
            if size == "large":
                icon = "menu"
            else:
                icon = "menu_small"
            dict['image'] = pygame.image.load('levels/images/%s_%s.png' % (icon, self.color))

        Particle.__init__(self, dict)

        if size == "large":
            self.radius = dict.get('radius', 32.0)
        else:
            self.radius = dict.get('radius', 16.0)
        self.action = dict.get('action')
        self.text   = dict.get('text')
        self.group  = dict.get('group')
        self.offset = dict.get('offset', self.radius*1.5)
        # immovable: forces and collision impulses do not change its velocity
        self.imass  = 0

        textsize = dict.get('textsize', size)

        if self.text:
            # place the label beside, above or below the switch
            align = dict.get('align', 'left')
            if align == 'right':
                pos = self.pos+vector((-self.offset, 0))
            elif align == 'bottom':
                pos = self.pos+vector((0, self.offset*1.3))
                align = 'center'
            elif align == 'top':
                pos = self.pos+vector((0, -self.offset*1.3))
                align = 'center'
            else:
                pos = self.pos+vector((self.offset, 0))
            drawText(self.text, pos,
                     align=align,
                     size=textsize,
                     valign=dict.get('valign', 'center'),
                     color=dict.get('textcolor', (255, 255, 255)))

    #-------------------------------------------------------------------------------------------
    def collision_action (self):
        """Run the switch's action and update its radio-group images, unless the level is exiting."""
        if k.level.exit:
            return
        if self.action:
            # NOTE(review): eval() executes arbitrary code taken from the level
            # data's 'action' string; only acceptable while level files are trusted.
            eval(self.action)
        if self.group:
            # mark this switch as selected and every other member of the group as not
            self.sprite.image = pygame.image.load('levels/images/menu_orange.png')
            for part in k.particles.parts:
                if part is not self and hasattr(part, 'group') and part.group == self.group:
                    part.sprite.image = pygame.image.load('levels/images/menu_white.png')
     
#-----------------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------------

class Magnet:
    """Rotating attractor that pulls in nearby particles.

    When at least `num` non-player particles are held for more than 3000 (ms of
    tick delta), it blasts them away. If all captured particles share the
    magnet's color and at least `num` of them are plain Particles, the magnet
    removes itself.
    """
    #-------------------------------------------------------------------------------------------
    def __init__ (self, dict={}):

        center = (k.world.rect.centerx, k.world.rect.centery)
        self.pos   = vector(dict.get('pos', center))
        self.num   = dict.get('num', 6)
        self.color = dict.get('color', 'white')
        self.angle = -math.pi/2
        # alternate rotation direction between consecutive magnets
        if len(k.particles.magnets) % 2:
            self.anglefac = 1
        else:
            self.anglefac = -1

        body = pygame.Surface((64,64), pygame.SRCALPHA, 32)
        body.blit(pygame.image.load('levels/images/dot32_%s.png' % (self.color,)), (16, 16, 32, 32))
        self.sprite = Sprite(self.pos, body, k.magnet_sprites)
        self.sprite.magnet = self

        # ring of small orbiting dots, one per particle the magnet needs
        dotPath = 'levels/images/dot20_%s.png' % (self.color,)
        self.dots = []
        for index in range(self.num):
            offset = vector.withAngle(index*2*math.pi/(self.num)+self.angle, 16+10-1)
            dotImage = pygame.image.load(dotPath)
            self.dots.append(Sprite(self.pos+offset, dotImage, k.magnet_sprites))

        self.captured    = sets.Set()
        self.oldCaptured = sets.Set()

        self.actionCounter = 0

    #-------------------------------------------------------------------------------------------
    def remove (self):
        """Emit sparks, kill the magnet's sprites and unregister it from k.particles."""
        SparkGroup(self.pos)
        self.sprite.kill()
        for dot in self.dots:
            dot.kill()
        k.particles.remove(self)

    #-------------------------------------------------------------------------------------------
    def applyAttractionForce(self, part):
        """Pull *part* toward the magnet and mark it captured when it is close enough."""
        pull = vector(part.pos.to(self.pos))
        if not pull.length() < 32 + part.radius:
            return
        # player particles are pulled more weakly
        part.applyForce(pull*(40-part.player*15))
        self.captured.add(part)
        part.captured = 1

        # capturing a chain particle breaks its links
        if issubclass(part.__class__, Chain):
            part.unlink()

    #-------------------------------------------------------------------------------------------
    def _explode (self):
        """Blast every captured particle away; remove the magnet if the capture was a color match."""
        self.exploding = True
        self.actionCounter = 0
        colors = sets.Set()
        for held in self.captured:
            held.vel += self.pos.to(held.pos).norm()*1000*k.world.forceFactor
            held.captured = 0
            colors.add(held.color)
        plainCount = len([q for q in self.captured if (q.__class__==Particle)])
        if len(colors) == 1 and self.color in colors and plainCount >= self.num:
            self.remove()
            if not k.level.checkExit():
                k.sound.play('magnet_action', force=1)
        else:
            self.captured = sets.Set()
            k.sound.play('magnet_action', force=1)

    #-------------------------------------------------------------------------------------------
    def onTick(self, delta):
        """Per physics step: play capture sounds, advance the action timer, roll the capture sets."""
        released = self.oldCaptured.difference(self.captured)
        if released:
            k.sound.play('magnet_off', 1.2-len(self.captured)/6.0, 1)
            for part in released:
                part.captured = 0
        elif self.captured.difference(self.oldCaptured):
            k.sound.play('magnet_on', 0.2+len(self.captured)/6.0, 1)

        if len(self.captured):
            nonplayer = len([q for q in self.captured if not q.player])
            if nonplayer < self.num:
                self.actionCounter = 0
            else:
                self.actionCounter += delta
                if self.actionCounter > 3000:
                    self._explode()
                elif int((self.actionCounter)/1000) > int((self.actionCounter-delta)/1000):
                    # tick sound once per elapsed second of the countdown
                    k.sound.play('magnet_start', force=1)

        self.oldCaptured = self.captured
        self.captured = sets.Set()

    #-------------------------------------------------------------------------------------------
    def onFrame (self, delta):
        """Rotate the dot ring and draw lines to the particles captured last tick."""
        self.angle += self.anglefac*math.pi*delta/8000
        for index in range(self.num):
            dot = self.dots[index]
            offset = vector.withAngle(self.angle+index*2*math.pi/(self.num), 16+10)
            dotSize = dot.rect.size
            dot.rect = pygame.Rect(self.pos+offset-vector((dotSize))/2, dotSize)

        if len(self.oldCaptured):
            # clamp line endpoints to the magnet's 64x64 body
            points = []
            for q in self.oldCaptured:
                points.append((clamp(q.pos.x, self.pos.x-31, self.pos.x+31),
                               clamp(q.pos.y, self.pos.y-31, self.pos.y+31)))
            points.append(self.pos)
            pygame.draw.aalines(k.screen, (255,255,255), 1, points, 1)
        
#-----------------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------------
class Spring:
    """Damped spring between two particles (Hooke term plus relative-velocity damping)."""
    #-------------------------------------------------------------------------------------------
    def __init__ (self, p1, p2, length=32, oneWay=1, spring=50.0, damp=5.0):
        self.p1, self.p2 = p1, p2
        self.rl = length        # rest length
        self.ks = spring        # spring constant
        self.kd = damp          # damping constant
        self.oneWay = oneWay    # when true only p2 receives the force

    #-------------------------------------------------------------------------------------------
    def __str__ (self):
        return "<%02d %02d>" % (self.p1.index, self.p2.index)

    #-------------------------------------------------------------------------------------------
    def calcForces(self):
        """Apply the spring force to p2 and, unless oneWay, the opposite force to p1."""
        span = self.p1.pos - self.p2.pos
        stretch = span.length()
        if not abs(stretch) > 0:
            return  # coincident endpoints: direction undefined
        closing = (self.p1.vel-self.p2.vel).dot(span)
        scale = (self.ks*(stretch-self.rl)+self.kd*(closing)/stretch)/stretch
        pushed = scale * span
        self.p2.applyForce(pushed)
        if not self.oneWay:
            self.p1.applyForce(-pushed)

    #-------------------------------------------------------------------------------------------
    def onFrame (self, delta):
        """Draw the link line, except for springs whose p1 is a player particle."""
        if self.p1.player:
            return
        pygame.draw.aalines(k.screen, k.level.linkColor, 1, [self.p1.pos, self.p2.pos], 1)

    #-------------------------------------------------------------------------------------------
    def getRect (self):
        """Return the screen rect covered by the drawn link (empty for player springs)."""
        if self.p1.player:
            return pygame.Rect(0,0,0,0)
        bounds = pygame.Rect(self.p1.pos, self.p2.pos-self.p1.pos)
        bounds.normalize()
        return bounds.inflate(8, 8)

#-----------------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------------

class Chain (Particle):
    """Particle that can link to other chain particles and anchors via springs.

    `links` holds directly linked particles (at most `maxLinks`, default 2),
    `springs` the Spring objects for those links, and `chain` the list of
    particles reachable through the links (traversal stops at anchors).
    """
    #-------------------------------------------------------------------------------------------
    def __init__ (self, dict={}):
        # Work on a copy so neither the shared default dict nor the caller's dict
        # is mutated when 'image'/'color' are filled in.
        dict = dict.copy()

        self.links = []
        self.springs = []
        self.chain = [self]
        self.maxLinks = dict.get('maxLinks', 2)
        self.isAnchor = False

        if 'image' not in dict:
            self.color = dict.get('color', 'white')
            # Particle.__init__ re-reads 'color' from the dict; store the default
            # so the particle's color matches the image chosen here (was None).
            dict['color'] = self.color
            dict['image'] = pygame.image.load('levels/images/dot28_d_%s.png' % (self.color,))

        Particle.__init__(self, dict)

    #-------------------------------------------------------------------------------------------
    def __str__ (self):
        return "[%02d links: %s chain: %s springs; %s]" % (self.index,
                                                        ["%02d" % l.index for l in self.links],
                                                        ["%02d" % c.index for c in self.chain],
                                                        [str(s) for s in self.springs])

    #-------------------------------------------------------------------------------------------
    def hasFreeLink (self):
        """Return True while fewer than maxLinks links are in use."""
        return len(self.links) < self.maxLinks

    #-------------------------------------------------------------------------------------------
    def linkedChains (self):
        """Return self plus the particles reached by walking from the first link."""
        chains = [self]
        if self.links:
            self.links[0].traverseChain(self, chains)
        return chains

    #-------------------------------------------------------------------------------------------
    def traverseChain (self, source, chains):
        """Append self and continue away from *source*; anchors end the walk."""
        chains.append(self)
        if self.isAnchor:
            return
        if self.links:
            if self.links[0] is not source:
                self.links[0].traverseChain(self, chains)
            elif len(self.links) > 1 and self.links[1] is not source:
                self.links[1].traverseChain(self, chains)

    #-------------------------------------------------------------------------------------------
    def link (self, other):
        """Link self to *other* with a spring; return True if the link was made.

        Refused when either side has no free link, *other* is already in this
        chain, or this chain's first anchor is already in *other*'s chain.
        """
        if not self.hasFreeLink():  return False
        if not other.hasFreeLink(): return False
        if other in self.chain: return False

        anchors = [p for p in self.chain if p.isAnchor]
        if anchors and anchors[0] in other.chain: return False

        k.sound.play('link')

        self.links.append(other)
        other.links.append(self)

        # an anchor contributes only itself; a chain particle brings its whole chain
        if other.isAnchor:
            self.chain.append(other)
        else:
            self.chain.extend(list(other.chain))
        for cp in self.chain:
            if not cp.isAnchor:
                cp.chain = list(self.chain)
        spring = Spring(self, other, length=3*self.radius, oneWay=0, spring=50.0, damp=5.0)
        k.particles.add(spring)
        self.springs.append(spring)
        other.springs.append(spring)

        # a chain spanning two or more anchors may trigger an anchor explosion
        anchors = [p for p in self.chain if p.isAnchor]
        if len(anchors) >= 2:
            for a in anchors:
                if a.checkExplode():
                    break
        return True

    #-------------------------------------------------------------------------------------------
    def unlink (self):
        """Detach self from all linked particles, drop its springs and rebuild neighbours' chains."""
        if not self.links:
            return

        k.sound.play('unlink')

        for cp in self.chain:
            if cp is not self and self in cp.chain:
                cp.chain.remove(self)

        for cp in self.links:
            # drop cp's springs to self; rebuilt in place rather than removing
            # from the list while iterating over it
            cp.springs[:] = [sp for sp in cp.springs
                             if not (sp.p1 == self or sp.p2 == self)]

            if self in cp.links:
                cp.links.remove(self)

            # recompute cp's chain and share it with its non-anchor members,
            # mirroring link() (which tests each member, not cp)
            cp.chain = cp.linkedChains()
            for c in cp.chain:
                if not c.isAnchor:
                    c.chain = list(cp.chain)

        for sp in self.springs:
            k.particles.remove(sp)

        self.links   = []
        self.springs = []
        self.chain   = [self]

#-----------------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------------
   
class Anchor (Chain):
    """Immovable chain endpoint with up to `maxLinks` (default 6) links.

    When all anchors of a color are fully linked by single-color chains
    (checkComplete), or one of its links belongs to a multi-color chain
    between two anchors, an action counter runs. After 3000 (frame delta
    units) the involved chains are blown apart, and in the complete-color
    case all anchors of that color are removed.
    """
    #-------------------------------------------------------------------------------------------
    def __init__ (self, dict={}):
        # Work on a copy so the shared default dict / caller's dict is not mutated.
        dict = dict.copy()

        self.angle = -math.pi/2
        # alternate rotation direction between consecutive anchors
        if len(k.particles.anchors) % 2:
            self.anglefac = 1
        else:
            self.anglefac = -1

        dict['maxLinks'] = dict.get('maxLinks', 6)

        self.color = dict.get('color', 'white')
        # Particle.__init__ re-reads 'color' from the dict; without this the
        # default would become None and the dot images below would not load.
        dict['color'] = self.color
        dict['image'] = pygame.image.load('levels/images/dot32_d_%s.png' % (self.color,))

        dict['radius'] = 16.0
        Chain.__init__(self, dict)

        self.dots = []
        for i in range(self.maxLinks):
            dir = vector.withAngle(i*2*math.pi/(self.maxLinks)+self.angle, 16+10-1)
            image = pygame.image.load('levels/images/dot20_d_%s.png' % (self.color,))
            self.dots.append(Sprite(self.pos+dir, image, k.magnet_sprites))

        self.isAnchor = True
        self.imass = 0

        # keys: self.color (whole-color completion) or a linked Chain (mixed-color chain)
        self.actionCounters = {}
        k.framed.append(self)

    #-------------------------------------------------------------------------------------------
    def onFrame (self, delta):
        """Advance/expire action counters, trigger explosions, and rotate the dot ring."""
        for key in list(self.actionCounters.keys()):
            # an earlier key's explosion may have called remove(), which resets
            # actionCounters; skip keys that are gone instead of raising KeyError
            if key not in self.actionCounters:
                continue

            if key == self.color:
                if not self.checkComplete():
                    self.actionCounters.pop(key, None)
                    continue
            else:
                if key not in self.links:
                    self.actionCounters.pop(key, None)
                    continue

                anchors = [p for p in key.chain if p.isAnchor]
                if len(anchors) < 2:
                    self.actionCounters.pop(key, None)
                    continue

            self.actionCounters[key] += delta

            if self.actionCounters[key] > 3000:
                chains = []
                if key == self.color:
                    allanchors = [a for a in k.particles.anchors if a.color == self.color]
                    for a in allanchors:
                        for l in a.links:
                            if l.chain not in chains:
                                chains.append(l.chain)
                else:
                    chains.append(key.chain)

                for chain in chains:
                    anchors = [p for p in chain if p.isAnchor]
                    center = anchors[0].pos+anchors[0].pos.to(anchors[1].pos)*0.5

                    for p in chain:
                        if not p.isAnchor:
                            away = center.to(p.pos)+anchors[0].pos.to(p.pos)+anchors[1].pos.to(p.pos)
                            # multiply by forceFactor (was '500**forceFactor'), as Magnet does
                            p.vel += away.norm()*500*k.world.forceFactor

                    for p in list(chain):
                        if not p.isAnchor:
                            p.unlink()

                if key == self.color:
                    for a in allanchors:
                        a.remove()

                if not k.level.checkExit():
                    k.sound.play('anchor_action')

            elif int((self.actionCounters[key])/1000) > int((self.actionCounters[key]-delta)/1000):
                # tick sound once per elapsed second of the countdown
                k.sound.play('magnet_start')

        # rotation
        self.angle += self.anglefac*math.pi*delta/8000
        for i in range(self.maxLinks):
            dir = vector.withAngle(self.angle+i*2*math.pi/(self.maxLinks), 16+10)
            self.dots[i].rect = pygame.Rect(self.pos+dir-vector((self.dots[i].rect.size))/2, self.dots[i].rect.size)

    #-------------------------------------------------------------------------------------------
    def checkComplete (self):
        """True when every anchor of this color is full and each link's chain is single-colored and spans two anchors."""
        anchors = [a for a in k.particles.anchors if a.color == self.color]
        for a in anchors:
            if a.hasFreeLink():
                return False
        for a in anchors:
            for link in a.links:
                if self.colorsInChain(link.chain) > 1:
                    return False
                # distinct name: a Python 2 list comprehension would rebind 'a'
                if len([p for p in link.chain if p.isAnchor]) < 2:
                    return False
        return True

    #-------------------------------------------------------------------------------------------
    def checkExplode (self):
        """Start an action counter if an explosion condition holds; return True if one was started."""
        if len(self.links) == self.maxLinks:
            if self.checkComplete():
                self.actionCounters = {}
                self.actionCounters[self.color] = 1
                return True
        for link in self.links:
            if len([p for p in link.chain if p.isAnchor]) == 2:
                if self.colorsInChain(link.chain) > 1:
                    if link not in self.actionCounters:
                        self.actionCounters[link] = 1
                        return True
        return False

    #-------------------------------------------------------------------------------------------
    def colorsInChain (self, chain):
        """Return the number of distinct colors among the particles of *chain*."""
        colorSet = sets.Set()
        for p in chain:
            colorSet.add(p.color)
        return len(colorSet)

    #-------------------------------------------------------------------------------------------
    def remove (self):
        """Clear counters, emit sparks, kill sprites and unregister from k.particles and k.framed."""
        self.actionCounters = {}
        SparkGroup(self.pos)
        self.sprite.kill()
        for dot in self.dots: dot.kill()
        k.particles.remove(self)
        k.framed.remove(self)
                            
#-----------------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------------

class Collision:
    """Queued collision response: impulse factor `f` applied along normal `n` to particle `p`."""
    def __init__(self, part, normal, factor=0):
        self.p, self.n, self.f = part, normal, factor
        # always None here; not read by the code visible in this module
        self.s = None

#-----------------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------------

class Particles:
    """Registry of particles, magnets, springs and anchors; steps the simulation."""
    #-------------------------------------------------------------------------------------------
    def __init__ (self):
        k.particles = self
        self.restdelta = 0
        self.reset()

    #-------------------------------------------------------------------------------------------
    def reset (self):
        """Forget all registered items."""
        self.parts = []
        self.magnets = []
        self.springs = []
        self.anchors = []

    #-------------------------------------------------------------------------------------------
    def add (self, item):
        """Register *item* by type; particles receive a 1-based `index` (their list position at add time)."""
        if isinstance(item, Particle):
            self.parts.append(item)
            item.index = len(self.parts)
            if isinstance(item, Anchor):
                self.anchors.append(item)
        elif item.__class__ == Magnet:
            self.magnets.append(item)
        elif item.__class__ == Spring:
            self.springs.append(item)

    #-------------------------------------------------------------------------------------------
    def stoneCircle (self, pos, color, num, radius, start=0):
        """Add *num* stones evenly spaced on a circle."""
        for i in range(num):
            p = vector(pos) + vector.withAngle(start+i*2*math.pi/num, radius)
            self.add(Stone({'pos': p, 'color': color}))

    #-------------------------------------------------------------------------------------------
    def ballCircle (self, pos, color, num, radius, start=0):
        """Add *num* plain particles evenly spaced on a circle."""
        for i in range(num):
            p = vector(pos) + vector.withAngle(start+i*2*math.pi/num, radius)
            self.add(Particle({'pos': p, 'color': color}))

    #-------------------------------------------------------------------------------------------
    def chainCircle (self, pos, color, num, radius, start=0):
        """Add *num* chain particles evenly spaced on a circle."""
        for i in range(num):
            p = vector(pos) + vector.withAngle(start+i*2*math.pi/num, radius)
            self.add(Chain({'pos': p, 'color': color}))

    #-------------------------------------------------------------------------------------------
    def remove (self, item):
        """Unregister *item* if present (no error when it is not registered)."""
        if isinstance(item, Particle):
            if item in self.parts:
                self.parts.remove(item)
            if isinstance(item, Anchor):
                if item in self.anchors:
                    self.anchors.remove(item)
        elif item.__class__ == Magnet:
            if item in self.magnets:
                self.magnets.remove(item)
        elif item.__class__ == Spring:
            if item in self.springs:
                self.springs.remove(item)

    #-------------------------------------------------------------------------------------------
    def onFrame (self, delta):
        """Advance the simulation by *delta* in 1 to 20 sub-steps of roughly 10 units each."""
        stepSize = clamp(delta, 1, 10)
        # int(): range() needs an integer even if delta is a float
        steps = int(clamp(delta/stepSize, 1, 20))
        stepSize = 1.0*delta/steps
        for i in range(steps):
            k.player.onTick(stepSize)
            self.calcForces()
            self.solveForces(stepSize/1000.0)
            self.checkCollisions()
            self.resolveCollisions()
            self.resetForces()
            # Magnet.onTick may call remove(), which deletes it from self.magnets;
            # iterate a copy so the following magnet is not skipped
            for magnet in list(self.magnets):
                magnet.onTick(stepSize)

    #-------------------------------------------------------------------------------------------
    def resetForces(self):
        for part in self.parts:
            part.resetForce()

    #-------------------------------------------------------------------------------------------
    def calcForces (self):
        """Accumulate drag, magnet attraction and spring forces."""
        for part in self.parts:
            part.calcForces()
            # magnet collisions
            magnets = pygame.sprite.spritecollide(part.sprite, k.magnet_sprites, 0)
            for magnet in magnets:
                if hasattr(magnet, 'magnet'):
                    magnet.magnet.applyAttractionForce(part)

        for spring in self.springs:
            spring.calcForces()

    #-------------------------------------------------------------------------------------------
    def solveForces (self, seconds):
        for part in self.parts:
            part.solveForce(seconds)

    #-------------------------------------------------------------------------------------------
    def checkCollisions(self):
        """Queue wall and particle-particle collision responses and trigger their side effects."""
        self.collisions = []
        for part in self.parts:

            norm = None
            if part.pos.y < k.world.rect.top:
                norm = vector((0,1))
            elif part.pos.y > k.world.rect.bottom:
                norm = vector((0,-1))
            elif part.pos.x < k.world.rect.left:
                norm = vector((1,0))
            elif part.pos.x > k.world.rect.right:
                norm = vector((-1,0))

            if norm is not None:
                part.pos = pos(clamp(part.pos.x, 0, k.world.rect.width), clamp(part.pos.y, 0, k.world.rect.height))
                divisor = norm.dot(norm)*part.imass
                if divisor != 0:
                    factor = 2 * part.vel.dot(norm) / divisor
                else:
                    factor = 0
                self.collisions.append(Collision(part, -norm, -factor))
                k.sound.play('wall', clamp((part.vel.length()-50)/200, 0, 1))
                SparkGroup(part.pos, int(-factor/10))

            for hit in pygame.sprite.spritecollide(part.sprite, k.particle_sprites, 0):
                other = hit.part
                if other is part:
                    continue
                partToOther = part.pos.to(other.pos)
                distance = partToOther.length()
                radius = part.radius + other.radius
                if not distance < radius:
                    continue
                relvel = part.vel - other.vel
                norm = vector(partToOther).norm()
                if not relvel.dot(norm) > 0: # particles not approaching
                    continue

                f = 1.0-distance/radius
                divisor = norm.dot(norm)*(part.imass+other.imass)
                if divisor != 0:
                    factor = min(1000, 2 * (f*16+relvel.dot(norm)) / divisor)
                else:
                    # both particles immovable (imass 0): no impulse instead of ZeroDivisionError
                    factor = 0

                tailtotail = part.player and other.player and (abs(part.player-other.player)==1)

                if not tailtotail:
                    self.collisions.append(Collision(part, -norm, -factor))

                if part.captured:
                    continue

                if not other.captured:
                    self.handleChainCollision(part, other)

                if tailtotail: # ignore tail to tail collisions
                    continue

                # sounds, sparks and score only once per pair (lower index side)
                if part.index < other.index:
                    if other.__class__ != Switch:
                        k.sound.play('part', clamp((relvel.length()-50)/200, 0, 1))
                    SparkGroup(part.pos+partToOther*0.5, int(factor/10))
                    if not part.player and other.player:
                        k.score += factor/10.0
                if hasattr(part, 'collision_action'):
                    part.collision_action()
                    if k.reset: return

    #-------------------------------------------------------------------------------------------
    def handleChainCollision(self, c1, c2):
        """Link colliding chains, or break a chain hit by a plain non-player particle.

        Returns the result of Chain.link when a link is attempted, otherwise False.
        """
        c1IsChain = isinstance(c1, Chain)
        c2IsChain = isinstance(c2, Chain)
        if not c1IsChain:
            if c1.__class__ == Particle and not c1.player and c2IsChain:
                c2.unlink()
            return False
        if not c2IsChain:
            # c1 is a chain here
            if c2.__class__ == Particle and not c2.player:
                c1.unlink()
            return False
        if c1.__class__ == Chain:
            return c1.link(c2)
        elif c2.__class__ == Chain:
            return c2.link(c1)
        return False

    #-------------------------------------------------------------------------------------------
    def resolveCollisions(self):
        """Apply each queued impulse: vel -= f * n * imass."""
        for collision in self.collisions:
            collision.p.vel -= collision.f * collision.n * collision.p.imass


             