# Copyright (c) 2011, Peter A. Bigot, licensed under New BSD (see COPYING)
# This file is part of msp430mcu (http://sourceforge.net/projects/mspgcc/)
#

import csv
import re
import os

# Root of the msp430 tree; may be overridden through $MSP430_ROOT.
msp430_root = os.environ.get('MSP430_ROOT', '/msp430')
# The msp430mcu tree; defaults to <msp430_root>/msp430mcu, overridable
# through $MSP430MCU_ROOT.
msp430mcu_root = os.environ.get('MSP430MCU_ROOT', os.path.join(msp430_root, 'msp430mcu'))

# Well-known subdirectories of the msp430mcu tree.
analysis_dir, upstream_dir = (os.path.join(msp430mcu_root, _subdir)
                              for _subdir in ('analysis', 'upstream'))

def analysis_path (filename):
    """Return filename joined onto the analysis directory."""
    return os.path.join(analysis_dir, filename)

def upstream_path (filename):
    """Return filename joined onto the upstream directory."""
    return os.path.join(upstream_dir, filename)

class Region (object):
    """A named linker MEMORY region.

    A region has an origin, a length, optional segmentation and an
    optional linker attribute string.  A region set with fixed=True can
    no longer be changed: set() asserts and reset() leaves it alone.
    """
    name = None
    attributes = None
    address_width = 4
    __is_fixed = False
    origin = None
    length = None
    segments = None
    segment_size = None

    # Regions are mutable and are kept in sets, so they hash by identity.
    # Defining __eq__ alone would make the class unhashable under Python 3;
    # this keeps the identity hash Python 2 supplied implicitly.
    __hash__ = object.__hash__

    def __eq__ (self, other):
        if not isinstance(other, Region):
            return NotImplemented
        return (self.name == other.name
                and self.origin == other.origin
                and self.length == other.length
                and self.segment_size == other.segment_size)

    def __ne__ (self, other):
        # Python 2 does not derive != from __eq__.  Without this method it
        # would fall back to __cmp__, which ignores name and segment size
        # for non-empty regions.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal

    def __init__ (self, name, attributes=None, address_width=4):
        """Create an empty region.

        name -- region name emitted in the MEMORY block
        attributes -- optional attribute string, shown in parentheses
        address_width -- number of hex digits used when printing addresses
        """
        self.name = name
        self.attributes = attributes
        self.address_width = address_width
        self.reset()

    def set (self, origin, length, segment_size=None, segments=None, fixed=False):
        """Assign the region's extent and return self, so calls can be chained.

        Unless fixed is true, the assignment is skipped when either origin
        or length is zero, which leaves the region empty.  When
        segment_size is given, the segment count is derived from it, and
        length must be an exact multiple of it.  The segments argument is
        accepted but not used.
        """
        assert not self.__is_fixed
        if fixed or ((0 != origin) and (0 != length)):
            self.origin = origin
            self.length = length
            if segment_size is not None:
                self.segment_size = segment_size
                # Floor division keeps the count an int on Python 3 as well.
                self.segments = length // segment_size
                assert self.length == self.segments * self.segment_size, 'Length %d expect %d * %d' % (self.length, self.segments, self.segment_size)
        self.__is_fixed = fixed
        return self

    def reset (self):
        """Empty the region (origin and length 0), unless it is fixed."""
        if not self.__is_fixed:
            self.origin = self.length = 0

    @staticmethod
    def _cmpValues (a, b):
        """Portable three-way compare returning -1, 0 or 1, like Python 2 cmp()."""
        return (a > b) - (a < b)

    def _compareTo (self, other):
        """Three-way ordering used for the MEMORY listing."""
        cmp3 = self._cmpValues
        # Sort infomem before other info sections
        if (self.name.startswith('info') and other.name.startswith('info')
                and (len(self.name) != len(other.name))):
            return - cmp3(len(self.name), len(other.name))
        # If both lengths zero, sort by name within near/far
        if (0 == self.length) and (0 == other.length):
            if self.address_width != other.address_width:
                return cmp3(self.address_width, other.address_width)
            return cmp3(self.name, other.name)
        # If one length zero, sort it after
        if (0 == self.length) or (0 == other.length):
            return - cmp3(self.length, other.length)
        return cmp3(self.origin, other.origin) or cmp3(self.length, other.length)

    def __cmp__ (self, other):
        # Retained for Python 2 callers of cmp().
        return self._compareTo(other)

    def __lt__ (self, other):
        # Used by sorted() on both Python 2 and Python 3.
        return self._compareTo(other) < 0

    __M = 1024 * 1024
    __K = 1024

    def _integerToLinkConstants (self, b):
        """Format b as 'NM' or 'NK' when it is an exact multiple, else as decimal."""
        m, m_rem = divmod(b, self.__M)
        if (1 <= m) and (0 == m_rem):
            return '%uM' % (m,)
        k, k_rem = divmod(b, self.__K)
        if (1 <= k) and (0 == k_rem):
            return '%uK' % (k,)
        return '%u' % (b,)

    def _formatAsString (self, name_width):
        """Return one MEMORY line for this region, with the name padded to name_width."""
        attr_str = ''
        if self.attributes:
            attr_str = ' (%s)' % (self.attributes,)
        rv = []
        rv.append('%-*s : ' % (name_width, self.name + attr_str))
        rv.append('ORIGIN = 0x%0*x, ' % (self.address_width, self.origin))
        rv.append('LENGTH = 0x%0*x' % (self.address_width, self.length))
        if 0 < self.length:
            segment_info = ''
            if self.segment_size:
                segment_info = ' as %d %d-byte segments' % (self.segments, self.segment_size)
            rv.append(' /* END=0x%0*x, size %s%s */' % (self.address_width, self.origin + self.length, self._integerToLinkConstants(self.length), segment_info))
        return ''.join(rv)

    @classmethod
    def Memory (cls, regions):
        """Return a linker 'MEMORY { ... }' block listing regions in sorted order.

        Empty regions sort last, and a marker comment is inserted before
        the first of them.
        """
        regions = sorted(regions)
        name_width = 0
        valid_bank = True
        for r in regions:
            n = len(r.name)
            if r.attributes:
                # Account for the ' (attrs)' suffix added by _formatAsString.
                n += 3 + len(r.attributes)
            name_width = max(name_width, n)
        rv = [ 'MEMORY {']
        for r in regions:
            if valid_bank and (0 == r.length):
                rv.append('  /* Remaining banks are absent */')
                valid_bank = False
            rv.append('  ' + r._formatAsString(name_width))
        rv.append('}')
        return "\n".join(rv)

    @classmethod
    def Reset (cls, regions):
        """Reset every region in regions; fixed regions are left unchanged."""
        for region in regions:
            region.reset()

    def __str__ (self):
        return '%s@0x%0*x (%s)' % (self.name, self.address_width, self.origin, self._integerToLinkConstants(self.length))

class _DeviceTypeBase (object):
    key = None
    tag = None
    enum_value = None
    
    __ClsKeyMap = { }
    __ClsTagMap = { }

    @classmethod
    def __KeyMap (cls):
        return cls.__ClsKeyMap.setdefault(cls, {})

    @classmethod
    def __TagMap (cls):
        return cls.__ClsTagMap.setdefault(cls, {})

    @classmethod
    def LookupByKey (cls, key):
        return cls.__KeyMap().get(key)
    
    @classmethod
    def LookupByTag (cls, tag):
        return cls.__TagMap().get(tag)
    
    def __init__ (self, key, tag, enum_value):
        self.key = key
        assert key not in self.__KeyMap(), '%s duplicated key in %s' % (key, self)
        self.__KeyMap()[key] = self
        self.tag = tag
        assert tag not in self.__TagMap(), '%s duplicated tag in %s' % (tag, self)
        self.__TagMap()[tag] = self
        self.enum_value = enum_value

    def __str__ (self):
        return self.tag


class CPU (_DeviceTypeBase):
    """MSP430 core variant; MCU.CreateFromRow looks it up by the CPU_TYPE code."""

# Known cores.  Keys are the CPU_TYPE codes read from devices.csv.
# MSP430: 16-bit MCU
CPU_MSP430 = CPU(key='0', tag='430', enum_value=0x00)
# MSP430X: 16-bit MCU with 20-bit addresses
CPU_MSP430X = CPU(key='1', tag='430x', enum_value=0x02)
# MSP430X with altered timing characteristics
CPU_MSP430XV2 = CPU(key='2', tag='430xv2', enum_value=0x03)

class MPY (_DeviceTypeBase):
    """Hardware multiplier variant; enum_value combines the flags below."""
    # Multiplier width flags.
    TYPE_16 = 0x10
    TYPE_32 = 0x20
    # Capability flags: sign extension and delayed write.
    HAS_SE = 0x01
    HAS_DW = 0x02

# No multiplier
MPY_NONE = MPY('0', 'none', 0)
# 16-bit
MPY_16 = MPY('1', '16', MPY.TYPE_16)
# 16-bit with sign extension
MPY_16SE = MPY('2', '16se', MPY_16.enum_value | MPY.HAS_SE)
# 32-bit (also supports 16-bit operation and sign extension)
MPY_32 = MPY('4', '32', MPY.TYPE_16 | MPY.TYPE_32 | MPY.HAS_SE)
# 32-bit with delayed write
MPY_32DW = MPY('8', '32dw', MPY_32.enum_value | MPY.HAS_DW)

# Every MCU instance is appended here by MCU.__init__, in construction order.
KnownDevices = list()

# Device Name,CPU_TYPE,CPU_Bugs,MPY_TYPE,STACKSIZE,RAMStart,RAMEnd,RAMStart2,RAMEnd2,USBRAMStart,USBRAMEnd,MirrowedRAMSource,MirrowedRAMStart,MirrowRAMEnd,BSLStart,BSLSize,BSLEnd,INFOStart,INFOSize,INFOEnd,INFOA,INFOB,INFOC,INFOD,FStart,FEnd,FStart2,FEnd2,INTStart,INTEnd
class MCU (object):
    """Description of one MSP430 device, normally built from a devices.csv row.

    Construction derives the device's memory regions from its address
    fields. It registers the device for Lookup(), in KnownDevices and in
    GenericsMap.
    """
    __HexFields = ( 'STACKSIZE',
                    'RAMStart',
                    'RAMEnd',
                    'RAMStart2',
                    'RAMEnd2',
                    'USBRAMStart',
                    'USBRAMEnd',
                    'MirroredRAMSource',
                    'MirroredRAMStart',
                    'MirroredRAMEnd',
                    'BSLStart',
                    'BSLSize',
                    'BSLEnd',
                    'INFOStart',
                    'INFOSize',
                    'INFOEnd',
                    'INFOA',
                    'INFOB',
                    'INFOC',
                    'INFOD',
                    'FStart',
                    'FEnd',
                    'FStart2',
                    'FEnd2',
                    'INTStart',
                    'INTEnd' )
    _DeviceFileFields = ( 'Device_Name',
                          'CPU_TYPE',
                          'CPU_Bugs',
                          'MPY_TYPE', ) + __HexFields

    # Part-number regex.
    __PartNumber_re = re.compile('''
^(?P<processorFamily>[a-z]*)
(?P<platform>430)
(?P<deviceType>[a-z]+)
(?P<generation>[0-9])
(?P<family>[0-9])
(?P<series>[0-9]+)
(?P<revision>[a-z]?)
$''', re.VERBOSE | re.IGNORECASE)

    __PartNumberFields = ( 'processorFamily', 'platform', 'deviceType', 'generation',
                           'family', 'series', 'revision' )

    # Regions common to every device.
    sfr = Region('sfr').set(0x0000, 0x0010, fixed=True)
    peripheral_8bit = Region('peripheral_8bit').set(0x0010, 0x00F0, fixed=True)
    peripheral_16bit = Region('peripheral_16bit').set(0x0100, 0x0100, fixed=True)
    bsl = None
    infoa = None
    infob = None
    infoc = None
    infod = None
    infomem = None
    ram = None
    rom = None
    far_rom = None
    vectors = None

    __Regions = set()
    __Regions.add(sfr)
    __Regions.add(peripheral_8bit)
    __Regions.add(peripheral_16bit)
    _Regions = None

    def checkCompatible (self, other):
        """Compare selected regions and properties with those of other.

        Returns a list of (self_value, other_value) pairs, one for each
        compared attribute that differs.  An empty list means compatible.
        """
        mismatches = []
        for attr in ( 'bsl', 'infoa', 'infob', 'infoc', 'infod', 'infomem', 'ram', 'rom', 'far_rom', 'vectors', 'mpy', 'cpu', 'cpu_bugs' ):
            srg = self.__dict__.get(attr)
            org = other.__dict__.get(attr)
            # Test with == so the result follows Region.__eq__ even on
            # Python 2, where != does not fall back to __eq__.
            if not (srg == org):
                mismatches.append( (srg, org) )
        return mismatches

    def _addRegion (self, region):
        """Add region to this device's region set and return it."""
        if self._Regions is None:
            self._Regions = self.__Regions.copy()
        self._Regions.add(region)
        return region

    __ExtractRegionKW = ( 'attributes', 'address_width' )
    def _extractRegionData (self, name, tag=None, suffix=None, segment_size=None, **kw):
        """Create a region called name and fill it from the device fields in kw.

        The extent comes from <tag>Start<suffix> and <tag>End<suffix>; the
        end address is inclusive.  <tag>Size<suffix>, when present,
        overrides segment_size.  tag defaults to name.upper() and suffix to
        ''.  'attributes' and 'address_width' in kw are passed to the
        Region constructor.  The region is added to this device and
        returned.  It is left empty when start and end are both zero.
        """
        ctor_kw = dict((k, kw[k]) for k in self.__ExtractRegionKW if k in kw)
        region = self._addRegion(Region(name, **ctor_kw))

        if tag is None:
            tag = name.upper()
        if suffix is None:
            suffix = ''
        start = kw.get('%sStart%s' % (tag, suffix), 0)
        segment_size = kw.get('%sSize%s' % (tag, suffix), segment_size)
        end = kw.get('%sEnd%s' % (tag, suffix), 0)
        if 0 == (start + end):
            return region
        length = 1 + end - start
        if segment_size is not None:
            assert 0 == (length % segment_size), 'Length %d not a multiple of %d' % (length, segment_size)
        return region.set(start, length, segment_size)

    # Map from genericized part numbers to specific part numbers that
    # correspond to them.
    GenericsMap = { }

    __MCUMap = { }

    @classmethod
    def Lookup (cls, mcu):
        """Return the MCU registered under part number mcu, or None."""
        return cls.__MCUMap.get(mcu)

    def __init__ (self, mcu, cpu='430', mpy='none', **kw):
        """Describe and register device mcu.

        mcu -- part number, e.g. 'msp430g2231'
        cpu, mpy -- core and multiplier descriptors (CreateFromRow passes
                    CPU and MPY instances)
        kw -- 'CPU_Bugs' plus the address fields named in __HexFields
        Raises Exception if mcu is not a parsable part number.  Asserts
        that mcu has not already been registered.
        """
        assert mcu not in self.__MCUMap
        # Parse the part number before touching any registry, so that an
        # unparsable name leaves no half-registered device behind.
        m = self.__PartNumber_re.match(mcu)
        if m is None:
            raise Exception('Unparsable part number: %s' % (mcu,))
        mcu_pndict = m.groupdict()

        self._Regions = self.__Regions.copy()
        self.mcu = mcu
        self.cpu = cpu
        self.cpu_bugs = kw.get('CPU_Bugs')
        self.mpy = mpy

        self.bsl = self._extractRegionData('bsl', **kw)
        self.ram = self._extractRegionData('ram', attributes='wx', **kw)
        self.ram2 = self._extractRegionData('ram2', suffix='2', attributes='wx', **kw)
        self.ram_mirror = self._extractRegionData('ram_mirror', tag='MirroredRAM', attributes='wx', **kw)
        self.usbram = self._extractRegionData('usbram', tag='USBRAM', attributes='wx', **kw)
        self.rom = self._extractRegionData('rom', tag='F', attributes='rx', **kw)
        self.far_rom = self._extractRegionData('far_rom', tag='F', suffix='2', address_width=8, **kw)
        self.vectors = self._extractRegionData('vectors', tag='INT', segment_size=2, **kw)
        self.infomem = self._extractRegionData('infomem', tag='INFO', **kw)
        self.infoa = self._addRegion(Region('infoa'))
        self.infob = self._addRegion(Region('infob'))
        self.infoc = self._addRegion(Region('infoc'))
        self.infod = self._addRegion(Region('infod'))

        # Each info segment is one infomem segment long.  segment_size is
        # None when the device has no info memory, so test truthiness
        # rather than comparing None with 0 (a TypeError on Python 3).
        length = self.infomem.segment_size
        if length:
            for region, field in ( (self.infoa, 'INFOA'),
                                   (self.infob, 'INFOB'),
                                   (self.infoc, 'INFOC'),
                                   (self.infod, 'INFOD') ):
                start = kw.get(field, 0)
                if 0 < start:
                    region.set(start, length)

        # Create a copy that uses a genericized memory type, as with legacy mspgcc
        xmcu_pndict = mcu_pndict.copy()
        xmcu_pndict['deviceType'] = '_'
        xmcu_pndict['revision'] = ''
        self.generic_mcu = ''.join([ xmcu_pndict.get(_field) for _field in self.__PartNumberFields ])

        # Register only once the instance is fully built.
        self.__MCUMap[mcu] = self
        KnownDevices.append(self)
        self.GenericsMap.setdefault(self.generic_mcu, set()).add(self)

    @classmethod
    def CreateFromRow (cls, row):
        """Create an instance from a devices.csv line.

        row is a sequence of column strings in _DeviceFileFields order.  It
        is not modified.  Returns C{None} if the mcu is not a chip target."""
        fields = list(row)
        mcu = fields.pop(0)
        if mcu.endswith('generic'):
            return None
        cfg = { }
        cfg['cpu'] = CPU.LookupByKey(fields.pop(0))
        cfg['CPU_Bugs'] = fields.pop(0)
        cfg['mpy'] = MPY.LookupByKey(fields.pop(0))
        # The remaining columns are hexadecimal, in __HexFields order.
        cfg.update(zip(cls.__HexFields, [ int(_x, 16) for _x in fields ]))
        return cls(mcu, **cfg)

    def regionSection (self):
        """Return the linker MEMORY block for this device's regions."""
        return Region.Memory(self._Regions)

    def mergeRegions (self, r1, r2):
        """Attempt to merge r1 into r2.

        If r1 is nonempty and ends exactly where r2 begins, extend r2
        downward to cover both, reset r1, and return C{True}.  Otherwise
        return C{False}."""
        # NOTE(review): r2.segments is not recomputed here because set() is
        # called without a segment size.
        if (0 == r1.length) or (r1.origin + r1.length != r2.origin):
            return False
        r2.set(r1.origin, r1.length + r2.length)
        r1.reset()
        return True

def load_devices (dev_path=None):
    """Create an MCU for every device row in a devices.csv file.

    dev_path defaults to devices.csv in the upstream directory.  Empty
    rows and rows whose first field starts with '#' are skipped.
    """
    if dev_path is None:
        dev_path = upstream_path('devices.csv')
    # Close the file once parsing finishes instead of leaking the handle.
    with open(dev_path) as dev_file:
        for row in csv.reader(dev_file):
            if (not row) or row[0].startswith('#'):
                continue
            MCU.CreateFromRow(row)

if __name__ == '__main__':
    import sys

    load_devices()
    args = sys.argv[1:]
    if not args:
        # Default to a representative device when none is named.
        args.append('msp430g2231')
    for a in args:
        mcu = MCU.Lookup(a)
        if mcu is None:
            sys.stderr.write('Unknown MCU: %s\n' % (a,))
            continue
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print('Chip %s is a %s with a %s cpu and %s mpy' % (mcu.mcu, mcu.generic_mcu, mcu.cpu, mcu.mpy))
        print(mcu.regionSection())
        print(mcu.vectors.segments)
