import collections
import functools
import operator
import os
import re
from time import time

from ase.atoms import Atoms, symbols2numbers
from ase.calculators.calculator import all_properties, all_changes
from ase.data import atomic_numbers
from ase.parallel import world, DummyMPI, parallel_function, parallel_generator
from ase.utils import Lock, basestring


# Reference epoch as a Unix timestamp.  946681200 s is one hour short of
# 2000-01-01 00:00 UTC (946684800 s), i.e. midnight of January 1. 2000 at
# UTC+1.
T2000 = 946681200.0
# Number of seconds in a year of 365.25 days.
YEAR = 365.25 * 86400


def now():
    """Return the time elapsed since T2000 (January 1. 2000) in years."""
    elapsed_seconds = time() - T2000
    return elapsed_seconds / YEAR


# Length of each supported time unit in seconds, keyed by its one-letter code
# (note that 'm' is a minute while 'M' is a month).
seconds = dict(s=1,
               m=60,
               h=60 * 60,
               d=24 * 3600,
               w=7 * 86400,
               M=2629800,  # one twelfth of YEAR
               y=YEAR)

# Singular English name of each one-letter time unit code.
longwords = dict(zip('smhdwMy',
                     ['second', 'minute', 'hour', 'day',
                      'week', 'month', 'year']))

# Comparison operators accepted in selection expressions, mapped to the
# functions that implement them.
ops = dict([('<', operator.lt),
            ('<=', operator.le),
            ('=', operator.eq),
            ('>=', operator.ge),
            ('>', operator.gt),
            ('!=', operator.ne)])

# Logical negation of each operator (e.g. "not a < b" is "a >= b").
invop = dict([('<', '>='),
              ('<=', '>'),
              ('>=', '<'),
              ('>', '<='),
              ('=', '!='),
              ('!=', '=')])

# Pattern for a valid key: an identifier-like word made of letters, digits
# and underscores that does not start with a digit.  The pattern is anchored
# with \Z instead of $ because $ also matches just before a trailing newline,
# which would let a key such as 'abc\n' slip through validation.
word = re.compile(r'[_a-zA-Z][_0-9a-zA-Z]*\Z')

# Names that may not be used as keys in user-supplied key-value pairs:
# everything listed in all_properties and all_changes plus the names below.
reserved_keys = set(all_properties + all_changes)
reserved_keys.update([
    'id', 'unique_id', 'ctime', 'mtime', 'user',
    'momenta', 'constraints', 'natoms', 'formula', 'age',
    'calculator', 'calculator_parameters',
    'key_value_pairs', 'data'])

# Keys whose comparison value must be a number (see parse_selection()).
numeric_keys = set(('id', 'energy', 'magmom', 'charge', 'natoms'))


def check(key_value_pairs):
    """Validate user-supplied key-value pairs.

    Raises ValueError if a key is not a plain word (see ``word``) or is one
    of the reserved keys, if a value is not an int, float or string, or if
    a string value could be parsed as an int or a float.
    """
    explanation = (' is put in as string but can be interpreted as {0}! '
                   'Please convert to {0} using {0}(value) before '
                   'writing to the database OR change '
                   'to a different string.')

    for key, value in key_value_pairs.items():
        if not word.match(key) or key in reserved_keys:
            raise ValueError('Bad key: {0}'.format(key))
        if not isinstance(value, (int, float, basestring)):
            raise ValueError('Bad value: {0}'.format(value))
        if not isinstance(value, basestring):
            continue
        # Strings that look like numbers are rejected.
        for number_type in (int, float):
            if str_represents(value, number_type):
                raise ValueError(
                    'Value ' + value +
                    explanation.format(number_type.__name__))


def str_represents(value, t=int):
    """Return True if ``t(value)`` succeeds, False if it raises ValueError.

    value: str
        Text to test.
    t: callable
        Converter to try (default: int).
    """
    try:
        t(value)
    except ValueError:
        convertible = False
    else:
        convertible = True
    return convertible


def connect(name, type='extract_from_name', create_indices=True,
            use_lock_file=True, append=True, serial=False):
    """Create connection to database.

    name: str
        Filename or address of database.  Anything that is not a string
        (for example a file object) is treated as a JSON database when the
        type is guessed.  None gives a plain Database object.
    type: str
        One of 'json', 'db', 'postgresql',
        (JSON, SQLite, PostgreSQL).
        Default is 'extract_from_name', which will guess the type
        from the name.
    create_indices: bool
        Passed on to the SQLite backend; not used for the other types.
    use_lock_file: bool
        You can turn this off if you know what you are doing ...
    append: bool
        Use append=False to start a new database.
    serial: bool
        Passed on to the JSON and SQLite backends.

    Raises ValueError for an unknown database type.
    """

    if type == 'extract_from_name':
        if name is None:
            type = None
        elif not isinstance(name, basestring):
            type = 'json'
        elif name.startswith('pg://'):
            type = 'postgresql'
        else:
            # Use the file extension without the leading dot.
            type = os.path.splitext(name)[1][1:]

    if type is None:
        return Database()

    if not append and world.rank == 0:
        # Start from scratch by removing an existing file (master only).
        # A name that is not path-like (e.g. an open file object) makes
        # os.path.isfile() raise TypeError; there is nothing to remove then.
        try:
            exists = os.path.isfile(name)
        except TypeError:
            exists = False
        if exists:
            os.remove(name)

    if type == 'json':
        from ase.db.jsondb import JSONDatabase
        return JSONDatabase(name, use_lock_file=use_lock_file, serial=serial)
    if type == 'db':
        from ase.db.sqlite import SQLite3Database
        return SQLite3Database(name, create_indices, use_lock_file,
                               serial=serial)
    if type == 'postgresql':
        from ase.db.postgresql import PostgreSQLDatabase
        # Strip the 'pg://' prefix.
        return PostgreSQLDatabase(name[5:])
    raise ValueError('Unknown database type: ' + type)


def lock(method):
    """Decorator for using a lock-file.

    The wrapped method runs inside ``with self.lock:`` unless ``self.lock``
    is None, in which case it is called directly.
    """
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if self.lock is not None:
            with self.lock:
                return method(self, *args, **kwargs)
        return method(self, *args, **kwargs)
    return wrapper


def convert_str_to_int_float_or_str(value):
    """Safe eval()

    Return int(value) if that works, otherwise float(value) if that works.
    Failing both, 'True' and 'False' become 1.0 and 0.0 and anything else
    is returned unchanged.
    """
    for convert in (int, float):
        try:
            return convert(value)
        except ValueError:
            pass
    booleans = {'True': 1.0, 'False': 0.0}
    return booleans.get(value, value)


class Database:
    """Base class for all databases."""
    def __init__(self, filename=None, create_indices=True,
                 use_lock_file=False, serial=False):
        """Database object.

        filename: str or other
            A string gets '~' expanded; anything else is stored as given.
        create_indices: bool
            Stored on the instance as create_indices.
        use_lock_file: bool
            Create a lock named filename + '.lock'.  Only honoured when
            filename is a string.
        serial: bool
            Let someone else handle parallelization.  Default behavior is
            to interact with the database on the master only and then
            distribute results to all slaves.
        """
        has_path = isinstance(filename, str)
        if has_path:
            filename = os.path.expanduser(filename)
        self.filename = filename
        self.create_indices = create_indices
        # No lock without a file name to derive the lock-file name from.
        self.lock = (Lock(filename + '.lock', world=DummyMPI())
                     if use_lock_file and has_path else None)
        self.serial = serial

    @parallel_function
    @lock
    def write(self, atoms, key_value_pairs=None, data=None, **kwargs):
        """Write atoms to database with key-value pairs.

        atoms: Atoms object
            Write atomic numbers, positions, unit cell and boundary
            conditions.  If a calculator is attached, write also already
            calculated properties such as the energy and forces.
            None is replaced by an empty Atoms object.
        key_value_pairs: dict
            Dictionary of key-value pairs.  Values must be strings or numbers.
            The dictionary itself is not modified.
        data: dict
            Extra stuff (not for searching).

        Key-value pairs can also be set using keyword arguments::

            connection.write(atoms, name='ABC', frequency=42.0)

        Returns integer id of the new row.
        """

        if atoms is None:
            atoms = Atoms()

        # Work on a copy so the caller's dictionary is left untouched.
        kvp = {} if key_value_pairs is None else dict(key_value_pairs)
        kvp.update(kwargs)

        # A fresh dict per call: a shared mutable default would be handed
        # to the backend and could carry state from one call to the next.
        if data is None:
            data = {}

        return self._write(atoms, kvp, data)

    def _write(self, atoms, key_value_pairs, data):
        """Validate the key-value pairs and return a dummy row id.

        This base-class version stores nothing: it only runs check() on
        key_value_pairs (which may raise ValueError) and returns 1.
        """
        check(key_value_pairs)
        dummy_id = 1
        return dummy_id

    @parallel_function
    @lock
    def reserve(self, **key_value_pairs):
        """Write empty row if not already present.

        Usage::

            id = conn.reserve(key1=value1, key2=value2, ...)

        Write an empty row with the given key-value pairs and
        return the integer id.  If such a row already exists, don't write
        anything and return None.

        A 'calculator' keyword is allowed; its value must be a lower-case
        name and is attached to the empty row through a stand-in calculator
        object instead of being written as a key-value pair.
        """

        # Any row matching all the pairs exactly means it is already taken.
        comparisons = [(key, '=', value)
                       for key, value in key_value_pairs.items()]
        for _ in self._select([], comparisons):
            return None

        empty = Atoms()

        calc_name = key_value_pairs.pop('calculator', None)

        if calc_name:
            # Allow use of calculator key
            assert calc_name.lower() == calc_name

            # Fake calculator class:
            class Fake:
                name = calc_name

                def todict(self):
                    return {}

                def check_state(self, atoms):
                    return ['positions']

            empty.calc = Fake()

        return self._write(empty, key_value_pairs, {})

    def __delitem__(self, id):
        """Delete the row with the given id (``del db[id]``)."""
        single_row = [id]
        self.delete(single_row)

    def get_atoms(self, selection=None, attach_calculator=False,
                  add_additional_information=False, **kwargs):
        """Get Atoms object.

        selection: int, str or list
            See the select() method.
        attach_calculator: bool
            Attach calculator object to Atoms object (default value is
            False).
        add_additional_information: bool
            Put key-value pairs and data into Atoms.info dictionary.

        In addition, one can use keyword arguments to select specific
        key-value pairs.

        The row is looked up with get(), so exactly one row must match.
        """

        return self.get(selection, **kwargs).toatoms(
            attach_calculator, add_additional_information)

    def __getitem__(self, selection):
        """Return the single row matching selection (``db[selection]``)."""
        row = self.get(selection)
        return row

    def get(self, selection=None, **kwargs):
        """Select a single row and return it.

        selection: int, str or list
            See the select() method.

        Raises KeyError if nothing matches.  Matching more than one row
        trips an assertion.
        """

        # Two rows are enough to tell whether the match is unique.
        matches = list(self.select(selection, limit=2, **kwargs))
        if matches:
            assert len(matches) == 1, 'more than one row matched'
            return matches[0]
        raise KeyError('no match')

    def parse_selection(self, selection, **kwargs):
        """Translate a selection into keys and comparisons.

        selection: None, int, str or list
            None or '' selects everything, an int is a row id, a list is
            taken as a list of expressions (strings or (key, op, value)
            tuples) and anything else is split on commas as a string.

        Keyword arguments are added as equality comparisons.

        Returns a tuple (keys, cmps): keys is a list of bare key names and
        cmps is a list of (key, op, value) tuples.  Raises ValueError when
        a key in numeric_keys is compared to a non-number.
        """
        if selection is None or selection == '':
            expressions = []
        elif isinstance(selection, int):
            expressions = [('id', '=', selection)]
        elif isinstance(selection, list):
            expressions = selection
        else:
            expressions = [w.strip() for w in selection.split(',')]
        keys = []
        comparisons = []
        for expression in expressions:
            # Ready-made (key, op, value) comparisons need no parsing.
            if isinstance(expression, (list, tuple)):
                comparisons.append(expression)
                continue
            # Chained form 'value<key<...' or 'value<=key<...': record the
            # lower bound here and let the remaining 'key<...' part be
            # parsed by the operator scan below.
            if expression.count('<') == 2:
                value, expression = expression.split('<', 1)
                if expression[0] == '=':
                    op = '>='
                    expression = expression[1:]
                else:
                    op = '>'
                key = expression.split('<', 1)[0]
                comparisons.append((key, op, value))
            # Two-character operators are tried before their one-character
            # substrings.
            for op in ['!=', '<=', '>=', '<', '>', '=']:
                if op in expression:
                    break
            else:
                # No operator found: a chemical symbol means "contains this
                # element", anything else is a bare key.
                if expression in atomic_numbers:
                    comparisons.append((expression, '>', 0))
                else:
                    keys.append(expression)
                continue
            key, value = expression.split(op)
            comparisons.append((key, op, value))

        cmps = []
        for key, value in kwargs.items():
            comparisons.append((key, '=', value))

        # Normalize special keys and convert values to their final types.
        for key, op, value in comparisons:
            if key == 'age':
                # An age becomes a threshold on the creation time (now minus
                # age), hence the inverted operator.
                key = 'ctime'
                op = invop[op]
                value = now() - time_string_to_float(value)
            elif key == 'formula':
                # One equality comparison per atomic number, plus the total
                # number of atoms.
                assert op == '='
                numbers = symbols2numbers(value)
                count = collections.defaultdict(int)
                for Z in numbers:
                    count[Z] += 1
                cmps.extend((Z, '=', count[Z]) for Z in count)
                key = 'natoms'
                value = len(numbers)
            elif key in atomic_numbers:
                # Chemical symbols are replaced by their atomic number.
                key = atomic_numbers[key]
                value = int(value)
            elif isinstance(value, basestring):
                value = convert_str_to_int_float_or_str(value)
            if key in numeric_keys and not isinstance(value, (int, float)):
                msg = 'Wrong type for "{0}{1}{2}" - must be a number'
                raise ValueError(msg.format(key, op, value))
            cmps.append((key, op, value))

        return keys, cmps

    @parallel_generator
    def select(self, selection=None, filter=None, explain=False,
               verbosity=1, limit=None, offset=0, sort=None, **kwargs):
        """Select rows.

        Return AtomsRow iterator with results.  Selection is done
        using key-value pairs and the special keys:

            formula, age, user, calculator, natoms, energy, magmom
            and/or charge.

        selection: int, str or list
            Can be:

            * an integer id
            * a string like 'key=value', where '=' can also be one of
              '<=', '<', '>', '>=' or '!='.
            * a string like 'key'
            * comma separated strings like 'key1<value1,key2=value2,key'
            * list of strings or tuples: [('charge', '=', 1)].
        filter: function
            A function that takes as input a row and returns True or False.
        explain: bool
            Explain query plan.
        verbosity: int
            Possible values: 0, 1 or 2.
        limit: int or None
            Limit selection.
        offset: int
            Passed on to the backend together with limit.
        sort: str
            Passed on to the backend after translating 'age' and '-age' to
            '-ctime' and 'ctime', and 'user' and '-user' to 'username' and
            '-username'.

        Keyword arguments are treated as key=value comparisons.
        """

        if sort:
            if sort in ('age', '-age'):
                # Sorting by age is sorting by creation time, reversed.
                sort = 'ctime' if sort == '-age' else '-ctime'
            elif sort.lstrip('-') == 'user':
                sort += 'name'

        keys, cmps = self.parse_selection(selection, **kwargs)
        rows = self._select(keys, cmps, explain=explain,
                            verbosity=verbosity,
                            limit=limit, offset=offset, sort=sort)
        for row in rows:
            if filter is not None and not filter(row):
                continue
            yield row

    def count(self, selection=None, **kwargs):
        """Count rows.

        See the select() method for the selection syntax.  Use db.count() or
        len(db) to count all rows.
        """
        return sum(1 for _ in self.select(selection, **kwargs))

    def __len__(self):
        """Return the total number of rows (``len(db)``)."""
        nrows = self.count()
        return nrows

    @parallel_function
    @lock
    def update(self, ids, delete_keys=None, block_size=1000,
               **add_key_value_pairs):
        """Update and/or delete key-value pairs of row(s).

        ids: int or list of int
            ID's of rows to update.
        delete_keys: list of str
            Keys to remove (default: none).
        block_size: int
            Block-size for each transaction.

        Use keyword arguments to add new key-value pairs.

        Returns number of key-value pairs added and removed, summed over
        all blocks.  Raises ValueError if the new key-value pairs do not
        pass check().
        """
        check(add_key_value_pairs)

        # A fresh list per call: a shared mutable default would be handed
        # to the backend and could carry state from one call to the next.
        if delete_keys is None:
            delete_keys = []

        if isinstance(ids, int):
            ids = [ids]

        # Hand the ids to the backend one block at a time and add up the
        # two counts each block reports.
        total_m = 0
        total_n = 0
        for start in range(0, len(ids), block_size):
            m, n = self._update(ids[start:start + block_size], delete_keys,
                                add_key_value_pairs)
            total_m += m
            total_n += n
        return total_m, total_n

    def delete(self, ids):
        """Delete rows.

        ids: list of int
            ID's of the rows to delete (__delitem__ passes a one-element
            list).

        Not implemented in this base class: always raises
        NotImplementedError.
        """
        raise NotImplementedError


def time_string_to_float(s):
    """Convert a time specification to a number of years.

    s: int, float or str
        Numbers are returned unchanged.  A string is a number followed by
        a unit, for example '3d', '12 h', '2 weeks' or '1.5 hours'.  The
        unit is one of the one-letter codes in ``seconds`` or one of the
        names in ``longwords``, optionally with a plural 's'.  Spaces are
        ignored and several terms can be added with '+' ('1d+12h').

    Raises ValueError for a string that is not a number followed by a unit
    and KeyError for an unknown unit.
    """
    if isinstance(s, (float, int)):
        return s
    s = s.replace(' ', '')
    if '+' in s:
        return sum(time_string_to_float(x) for x in s.split('+'))

    # Validate the overall shape first so that malformed input such as
    # '5' or '' gives a clear error instead of an IndexError.
    match = re.match(r'(\d+(?:\.\d+)?)([A-Za-z]+)\Z', s)
    if match is None:
        raise ValueError('Bad time string: {0!r}'.format(s))
    number, unit = match.groups()

    # Drop a plural 's' ('days' -> 'day'); a lone 's' means seconds.
    if len(unit) > 1 and unit.endswith('s'):
        unit = unit[:-1]

    # Translate a long unit name to its one-letter code.
    if unit not in seconds:
        for short, name in longwords.items():
            if unit == name:
                unit = short
                break

    count = float(number) if '.' in number else int(number)
    return seconds[unit] * count / YEAR


def float_to_time_string(t, long=False):
    """Format a time given in years as a short human-readable string.

    The largest unit giving a value above 5 is used, falling back to
    seconds.  With long=False the result looks like '12d' (rounded to an
    integer); with long=True it looks like '12.345 days'.
    """
    total_seconds = t * YEAR
    for unit in 'yMwdhms':
        amount = total_seconds / seconds[unit]
        if amount > 5:
            break
    if not long:
        return '{0:.0f}{1}'.format(round(amount), unit)
    return '{0:.3f} {1}s'.format(amount, longwords[unit])
