# Licensed under a 3-clause BSD style license - see LICENSE.rst

# "core.py" is auto-generated by erfa_generator.py from the template
# "core.py.templ". Do *not* edit "core.py" directly, instead edit
# "core.py.templ" and run erfa_generator.py from the source directory to
# update it.

"""
This module uses the Python/C API to wrap the ERFA library in
numpy-vectorized equivalents.

.. warning::
    This is currently *not* part of the public Astropy API, and may change in
    the future.


The key idea is that any function can be called with inputs that are arrays,
and the wrappers will automatically vectorize and call the ERFA functions for
each item using broadcasting rules for numpy.  So the return values are always
numpy arrays of some sort.

For ERFA functions that take/return vectors or matrices, the vector/matrix
dimension(s) are always the *last* dimension(s).  For example, if you
want to give ten matrices (i.e., the ERFA input type is double[3][3]),
you would pass in a (10, 3, 3) numpy array.  If the output of the ERFA
function is scalar, you'll get back a length-10 1D array.

Note that the C part of these functions is implemented in a separate
module (compiled as ``_core``), derived from the ``core.c`` file.
Splitting the wrappers into separate pure-python and C portions
dramatically reduces compilation time without notably impacting
performance. (See issue [#3063] on the GitHub repository for more
about this.)
"""
from __future__ import absolute_import, division, print_function

import warnings

from ..utils.compat import NUMPY_LT_1_8
from ..utils.exceptions import AstropyUserWarning

import numpy
from . import _core

# TODO: remove the ``NUMPY_LT_1_8`` import above, the code using it, and the
# ``make_outputs_scalar`` handling when numpy < 1.8 is no longer supported

# Public names of the rendered module: the two error/warning classes, the
# Python name of every wrapped function and the name of every constant (both
# lists are filled in by the template), and the two structured dtypes.
__all__ = ['ErfaError', 'ErfaWarning',
           {{ funcs|map(attribute='pyname')|surround("'","'")|join(", ") }},
           {{ constants|map(attribute='name')|surround("'","'")|join(", ") }},
           'dt_eraASTROM', 'dt_eraLDBODY']


#<---------------------------------Error-handling----------------------------->

class ErfaError(ValueError):
    """Raised when an ERFA function reports a failure (a status code < 0)."""


class ErfaWarning(AstropyUserWarning):
    """Issued when an ERFA function reports a warning (a status code > 0)."""


# Per-function tables of status-code messages, keyed by the wrapper's Python
# name.  An entry is added further down, alongside each wrapper whose status
# codes are documented; each entry maps a code (or the key 'else') to a
# message.
STATUS_CODES = dict()

# Hard-coded status codes that must be rewritten before they are reported,
# keyed by wrapper name and then by the raw code.  Mapping a negative code to
# a positive one downgrades an error to a warning.
STATUS_CODES_REMAP = {'cal2jd': {-3: 3}}


def _status_summary(statcodes, func_codes):
    """
    Build the ``'<count> of "<message>"'`` summary for a set of status codes.

    Parameters
    ----------
    statcodes : `~numpy.ndarray`
        The status codes to report (already restricted to errors or warnings).
    func_codes : dict
        Maps a status code to its message; the optional ``'else'`` key gives
        the message used for codes that have no entry of their own.

    Returns
    -------
    summary : str
        One ``'<count> of "<message>"'`` item per distinct code, joined
        with ``', '``.
    """
    fallback = func_codes.get('else', None)
    parts = []
    for code in numpy.unique(statcodes):
        if fallback is None:
            default = 'Return code ' + str(code)
        else:
            default = fallback
        count = numpy.sum(statcodes == code)
        parts.append('{0} of "{1}"'.format(count, func_codes.get(code, default)))
    return ', '.join(parts)


def check_errwarn(statcodes, func_name):
    """
    Raise or warn according to the status codes returned by an ERFA function.

    Parameters
    ----------
    statcodes : `~numpy.ndarray`
        Integer status codes, one per call of the ERFA function.  Codes listed
        in ``STATUS_CODES_REMAP`` for ``func_name`` are rewritten *in place*.
    func_name : str
        Python name of the wrapper; used to look up messages in
        ``STATUS_CODES`` and in the text of the error/warning.

    Raises
    ------
    ErfaError
        If any status code is negative.  Only the negative codes are reported,
        even if positive codes are present as well.

    Warns
    -----
    ErfaWarning
        If no status code is negative but at least one is positive.
    """
    # A wrapper whose status codes were not documented has no STATUS_CODES
    # entry; fall back to an empty table so the generic 'Return code N'
    # message is used instead of failing with a KeyError.
    func_codes = STATUS_CODES.get(func_name, {})

    # Remap any errors into warnings in the STATUS_CODES_REMAP dict.
    for before, after in STATUS_CODES_REMAP.get(func_name, {}).items():
        statcodes[statcodes == before] = after
        if before in func_codes:
            # the remapped code keeps the message of the original code
            func_codes[after] = func_codes[before]

    if numpy.any(statcodes < 0):
        # errors present - only report the errors.
        if statcodes.shape:
            statcodes = statcodes[statcodes < 0]
        emsg = _status_summary(statcodes, func_codes)
        raise ErfaError('ERFA function "' + func_name + '" yielded ' + emsg)

    elif numpy.any(statcodes > 0):
        # only warnings present
        if statcodes.shape:
            statcodes = statcodes[statcodes > 0]
        wmsg = _status_summary(statcodes, func_codes)
        warnings.warn('ERFA function "' + func_name + '" yielded ' + wmsg, ErfaWarning)


#<-------------------------trailing shape verification------------------------>

def check_trailing_shape(arr, shape, name):
    """
    Verify that the last dimensions of ``arr`` are exactly ``shape``.

    Parameters
    ----------
    arr : object with a ``shape`` attribute
        The array to check, typically a `~numpy.ndarray`.
    shape : tuple of int
        The required trailing dimensions, e.g. ``(3, 3)`` for a matrix.
    name : str
        Name of the argument, used in the error message.

    Raises
    ------
    ValueError
        If ``arr`` has no ``shape`` attribute, has too few dimensions, or its
        trailing dimensions differ from ``shape``.
    """
    expected = tuple(shape)
    actual = getattr(arr, 'shape', None)

    if actual is not None:
        # An empty requirement is met by any array.  It is special-cased
        # because ``actual[-0:]`` would select *all* dimensions, not none.
        if not expected or tuple(actual[-len(expected):]) == expected:
            return

    raise ValueError("{0} must be of trailing dimensions {1}".format(name, expected))

#<--------------------------Actual ERFA-wrapping code------------------------->

# Structured dtypes exported by this module.  Every field is a C double
# ('d'), either scalar or with the fixed sub-array shape given as the third
# tuple element; ``align=True`` requests C-struct-like padding and alignment.
# NOTE(review): presumably these mirror the ERFA C structs ``eraASTROM`` and
# ``eraLDBODY``, so field order and shapes must stay in sync with the C side
# -- verify against erfa.h before changing anything here.
dt_eraASTROM = numpy.dtype(
    [
        ('pmt', 'd'),
        ('eb', 'd', (3,)),
        ('eh', 'd', (3,)),
        ('em', 'd'),
        ('v', 'd', (3,)),
        ('bm1', 'd'),
        ('bpn', 'd', (3, 3)),
        ('along', 'd'),
        ('phi', 'd'),
        ('xpl', 'd'),
        ('ypl', 'd'),
        ('sphi', 'd'),
        ('cphi', 'd'),
        ('diurab', 'd'),
        ('eral', 'd'),
        ('refa', 'd'),
        ('refb', 'd'),
    ],
    align=True)

dt_eraLDBODY = numpy.dtype(
    [
        ('bm', 'd'),
        ('dl', 'd'),
        ('pv', 'd', (2, 3)),
    ],
    align=True)


# Module-level constants: the template emits one assignment per entry of
# ``constants``, followed by a string literal holding that constant's
# documentation lines joined with single spaces.
{% for constant in constants %}
{{ constant.name }} = {{ constant.value }}
"""{{ constant.doc|join(' ') }}"""
{%- endfor %}

{% for func in funcs %}
def {{ func.pyname }}({{ func.args_by_inout('in|inout')|map(attribute='name')|join(', ') }}):
    """
    Wrapper for ERFA function ``{{ func.name }}``.

    Parameters
    ----------
    {%- for arg in func.args_by_inout('in|inout') %}
    {{ arg.name }} : {{ arg.ctype }} array
    {%- endfor %}

    Returns
    -------
    {%- for arg in func.args_by_inout('inout|out|ret') %}
    {{ arg.name }} : {{ arg.ctype }} array
    {%- endfor %}

    Notes
    -----
    The ERFA documentation is below.

{{ func.doc }}
    """

    # Turn all inputs into C-ordered arrays with the dtype declared for each
    # argument (copy=False asks numpy not to copy when the input already
    # qualifies), then verify the trailing dimensions of every argument that
    # is not a scalar on the ERFA side.
    {%- for arg in func.args_by_inout('in|inout') %}
    {{ arg.name }}_in = numpy.array({{ arg.name }}, dtype={{ arg.dtype }}, order="C", copy=False, subok=True)
    {%- endfor %}
    {%- for arg in func.args_by_inout('in|inout') %}
    {%- if arg.ndim > 0 %}
    check_trailing_shape({{ arg.name }}_in, {{ arg.shape }}, "{{arg.name}}")
    {%- endif %}
    {%- endfor %}

    {%- if func.args_by_inout('in|inout') %}
    make_outputs_scalar = False
    if NUMPY_LT_1_8:
        # in numpy < 1.8, the iterator used below doesn't work with 0d/scalar arrays
        # so we replace all scalars with 1d arrays
        # make_outputs_scalar stays True only if *every* input was a scalar;
        # in that case the added length-1 axis is removed again before returning.
        make_outputs_scalar = True
        {%- for arg in func.args_by_inout('in|inout') %}
        if {{ arg.name_in_broadcast }}.shape == tuple():
            {{ arg.name }}_in = {{ arg.name }}_in.reshape((1,) + {{ arg.name }}_in.shape)
        else:
            make_outputs_scalar = False
        {%- endfor %}
    {%- endif %}

    # Create the output arrays: each one has the broadcast shape of the inputs
    # followed by that output's own trailing dimensions.  In/out arguments are
    # seeded with a copy of the corresponding input.
    # NOTE(review): the two int32 scalars are presumably dummy operands so that
    # numpy.broadcast always receives at least two arguments -- confirm.
    broadcast = numpy.broadcast(numpy.int32(0.0), numpy.int32(0.0), {{ func.args_by_inout('in|inout')|map(attribute='name_in_broadcast')|join(', ') }})
    {%- for arg in func.args_by_inout('inout|out|ret|stat') %}
    {{ arg.name }}_out = numpy.empty(broadcast.shape + {{ arg.shape }}, dtype={{ arg.dtype }})
    {%- endfor %}
    {%- for arg in func.args_by_inout('inout') %}
    numpy.copyto({{ arg.name }}_out, {{ arg.name }}_in)
    {%- endfor %}

    # Create the iterator, broadcasting on all but the consumed dimensions.
    # op_axes left-pads every operand with -1 entries (new broadcast axes) up
    # to broadcast.nd dimensions; pure inputs are read-only, while in/out,
    # output, return-value and status operands are read-write.
    arrs = [{{ (func.args_by_inout('in')|map(attribute='name_in_broadcast')|list + func.args_by_inout('inout|out|ret|stat')|map(attribute='name_out_broadcast')|list)|join(', ') }}]
    op_axes = [[-1]*(broadcast.nd-arr.ndim) + list(range(arr.ndim)) for arr in arrs]
    op_flags = [['readonly']]*{{ func.args_by_inout('in')|count }} + [['readwrite']]*{{ func.args_by_inout('inout|out|ret|stat')|count }}
    it = numpy.nditer(arrs, op_axes=op_axes, op_flags=op_flags)

    # Iterate: hand the iterator to the compiled loop in ``_core``
    stat_ok = _core._{{ func.pyname }}(it)

    {%- for arg in func.args_by_inout('stat') %}

    # NOTE(review): a false stat_ok presumably means at least one status code
    # was nonzero -- confirm against the C side.  check_errwarn then raises
    # ErfaError for negative codes or emits ErfaWarning for positive ones.
    if not stat_ok:
        check_errwarn({{ arg.name }}_out, '{{ func.pyname }}')
    {%- endfor %}

    {%- if func.args_by_inout('in|inout') %}
    # need to convert the outputs back to scalars if all the inputs were
    # scalars but we made them 1d
    if make_outputs_scalar:
        {%- for arg in func.args_by_inout('inout|out|ret') %}
        assert len({{ arg.name }}_out.shape) > 0 and {{ arg.name }}_out.shape[0] == 1
        {{ arg.name }}_out = {{ arg.name }}_out.reshape({{ arg.name }}_out.shape[1:])
        {%- else %}
        pass
        {%- endfor %}
    {%- endif %}

    return {{ func.args_by_inout('inout|out|ret')|map(attribute='name')|postfix('_out')|join(', ') }}

{%- for stat in func.args_by_inout('stat') %}
{%- if stat.doc_info.statuscodes %}
# status-code messages for the wrapper above, looked up by check_errwarn
STATUS_CODES['{{ func.pyname }}'] = {{ stat.doc_info.statuscodes|string }}
{% endif %}
{%- endfor %}

{% endfor %}
