#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2009, 2010, 2011, 2012, 2013, 2014.

# SMHI,
# Folkborgsvägen 1,
# Norrköping,
# Sweden

# Author(s):

#   Martin Raspaud <martin.raspaud@smhi.se>
#   Adam Dybbroe <adam.dybbroe@smhi.se>

# This file is part of mpop.

# mpop is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# mpop is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with mpop.  If not, see <http://www.gnu.org/licenses/>.

"""This module handles coverage objects. Such objects are used to
transform area projected data by changing either the area or the
projection or both. A typical usage is to transform one large area in
satellite projection to an area of interest in polar projection for
example.
"""
import os
import ConfigParser
import logging

import numpy as np
from pyresample import image, utils, geometry, kd_tree
from pyresample.bilinear import get_sample_from_bil_info, get_bil_info

from mpop import CONFIG_PATH

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Path of the area definition file. It stays None until get_area_file()
# has been called; that function stores its result here and returns the
# stored value on later calls as long as it is non-empty.
area_file = None


def get_area_file():
    """Return the path of the area definition file.

    The path is built from the ``area_directory`` and ``area_file`` options
    of the ``[projector]`` section of ``mpop.cfg``, which is read from
    CONFIG_PATH. When ``area_directory`` is missing or empty, CONFIG_PATH is
    used as the directory instead.

    A non-empty result is remembered in the module-level *area_file*
    variable and returned directly on subsequent calls.

    If the configuration has no ``[projector]`` section (typically because
    ``mpop.cfg`` could not be read), a warning is logged and an empty string
    is returned. A ``[projector]`` section without an ``area_file`` option
    still raises ``ConfigParser.NoOptionError``.
    """
    global area_file
    if area_file:
        return area_file

    conf = ConfigParser.ConfigParser()
    conf.read(os.path.join(CONFIG_PATH, "mpop.cfg"))

    try:
        try:
            area_directory = conf.get("projector", "area_directory")
        except ConfigParser.NoOptionError:
            # The option is optional: an absent entry must behave like an
            # empty one and fall back to CONFIG_PATH below.
            area_directory = None
        area_file = os.path.join(area_directory or CONFIG_PATH,
                                 conf.get("projector", "area_file"))
    except ConfigParser.NoSectionError:
        area_file = ""
        logger.warning("Couldn't find the mpop.cfg file. "
                       "Do you have one ? is it in $PPP_CONFIG_DIR ?")
    return area_file


def get_area_def(area_name):
    """Return the definition of the area called *area_name*.

    The area is looked up in the area definition file whose path is given
    by :func:`get_area_file` (configured in mpop's configuration file). The
    first definition returned by the parser is used.
    """
    area_filename = get_area_file()
    definitions = utils.parse_area_file(area_filename, area_name)
    return definitions[0]


def _get_area_hash(area):
    """Calculate a (close to) unique hash value for a given area.

    Three kinds of input are handled, tried in this order:

    - objects with *lons* and *lats* arrays: the raw bytes of both arrays
      are concatenated and hashed;
    - array-like objects: their raw bytes are hashed;
    - anything else: the string representation is hashed.
    """
    def _raw_bytes(array):
        # ``tostring`` is a deprecated alias of ``tobytes`` and yields the
        # same bytes. Prefer ``tobytes`` so array contents keep being hashed
        # on numpy versions without the alias; keep the alias as a fallback
        # for old versions. Non-array objects raise AttributeError here,
        # which the caller below relies on.
        try:
            return array.tobytes()
        except AttributeError:
            return array.tostring()

    try:
        return hash(_raw_bytes(area.lons) + _raw_bytes(area.lats))
    except AttributeError:
        try:
            return hash(_raw_bytes(area))
        except AttributeError:
            return hash(str(area))


class Projector(object):

    """This class defines projector objects. They contain the mapping
    information necessary for projection purposes. For efficiency reasons,
    generated projectors can be saved to disk for later reuse. Use the
    :meth:`save` method for this.

    To define a projector object, one has to specify *in_area* and
    *out_area*, and can also input the *in_latlons* or the *mode*.

    *in_area* and *out_area* are either the name of an area defined in the
    area file, or an area object carrying an *area_id* attribute. If
    *in_area* is neither, *in_latlons* (a pair given in the order
    ``(lons, lats)``) is used to build a swath definition for the input.

    Available modes are:
    - 'quick' (works only if both in- and out-areas are AreaDefinitions)
    - 'bilinear' (out-area needs to be AreaDefinition with proj4_string)
    - 'ewa'
    - 'nearest'.
    When *mode* is None, 'quick' is chosen if both the *in_area* and
    *out_area* arguments have a *proj_dict* attribute, 'nearest' otherwise.

    *radius* defines the radius of influence for nearest neighbour
    search in 'nearest' and 'bilinear' modes. *nprocs* is handed to the
    neighbour search of those two modes.

    On creation, a previously saved projection for the same areas and mode
    is loaded from the projections directory (``projections_directory`` in
    the ``[projector]`` section of ``mpop.cfg``, ``/var/tmp`` by default);
    if that fails the projection is computed.

    Raises ValueError for an unknown *mode* and ``utils.AreaNotFound`` when
    an area cannot be resolved.
    """

    def __init__(self, in_area, out_area,
                 in_latlons=None, mode=None,
                 radius=10000, nprocs=1):

        if (mode is not None and
                mode not in ["quick", "nearest", "ewa", "bilinear"]):
            raise ValueError("Projector mode must be one of 'nearest', "
                             "'quick', 'ewa', 'bilinear'")

        self.area_file = get_area_file()

        self.in_area = None
        self.out_area = None
        self._cache = None
        # Archive object returned by np.load, or None when no saved
        # projection could be loaded.
        self._file_cache = None
        self._filename = None
        self.mode = "quick"
        self.radius = radius
        self.conf = ConfigParser.ConfigParser()
        self.conf.read(os.path.join(CONFIG_PATH, "mpop.cfg"))

        # TODO:
        # - Rework so that in_area and out_area can be lonlats.
        # - Add a recompute flag ?

        # Setting up the input area: first as an area name, then as an area
        # object, and finally from the provided longitudes/latitudes.
        try:
            self.in_area = get_area_def(in_area)
            in_id = in_area
        except (utils.AreaNotFound, AttributeError):
            try:
                in_id = in_area.area_id
                self.in_area = in_area
            except AttributeError:
                try:
                    # TODO: Note that latlons are in order (lons, lats)
                    self.in_area = geometry.SwathDefinition(lons=in_latlons[0],
                                                            lats=in_latlons[1])
                    in_id = in_area
                except TypeError:
                    # in_latlons is None (or otherwise not indexable).
                    raise utils.AreaNotFound("Input area " +
                                             str(in_area) +
                                             " must be defined in " +
                                             self.area_file +
                                             ", be an area object"
                                             " or longitudes/latitudes must be "
                                             "provided.")

        # Setting up the output area: an area name or an area object.
        try:
            self.out_area = get_area_def(out_area)
            out_id = out_area
        except (utils.AreaNotFound, AttributeError):
            try:
                out_id = out_area.area_id
                self.out_area = out_area
            except AttributeError:
                raise utils.AreaNotFound("Output area " +
                                         str(out_area) +
                                         " must be defined in " +
                                         self.area_file + " or "
                                         "be an area object.")

        # choosing the right mode if necessary
        if mode is None:
            # NOTE(review): this inspects the arguments as passed in, not
            # the resolved self.in_area/self.out_area, so areas given by
            # name select 'nearest'. Kept as is because changing it would
            # alter the default resampling for existing callers - confirm
            # whether that is intended.
            try:
                dicts = in_area.proj_dict, out_area.proj_dict
                del dicts
                self.mode = "quick"
            except AttributeError:
                self.mode = "nearest"
        else:
            self.mode = mode

        # String formatting rather than "+": the ids are not guaranteed to
        # be strings (e.g. in_area=None together with in_latlons).
        filename = "%s2%s_%sto%s_%s.npz" % (in_id, out_id,
                                            _get_area_hash(self.in_area),
                                            _get_area_hash(self.out_area),
                                            self.mode)

        projections_directory = "/var/tmp"
        try:
            projections_directory = self.conf.get("projector",
                                                  "projections_directory")
        except (ConfigParser.NoSectionError, ConfigParser.NoOptionError):
            # Neither a missing section nor a missing option is an error:
            # keep the default directory.
            pass

        self._filename = os.path.join(projections_directory, filename)

        self._cache = {}
        try:
            self._file_cache = np.load(self._filename)
        except Exception:
            # Any failure to read a saved projection (missing or unreadable
            # file) means it has to be computed. Exception rather than a
            # bare except, so that KeyboardInterrupt and SystemExit are not
            # swallowed.
            logger.info("Computing projection from %s to %s...",
                        in_id, out_id)
            self._cache = self._compute_cache(nprocs)

    def _compute_cache(self, nprocs):
        """Compute the resampling information for the current mode.

        Returns a dictionary holding the arrays that :meth:`project_array`
        needs for ``self.mode``. *nprocs* is passed on to the neighbour
        search in 'nearest' and 'bilinear' modes.
        """
        cache = {}

        if self.mode == "nearest":
            valid_index, valid_output_index, index_array, distance_array = \
                kd_tree.get_neighbour_info(self.in_area,
                                           self.out_area,
                                           self.radius,
                                           neighbours=1,
                                           nprocs=nprocs)
            # The distances are not used when projecting.
            del distance_array
            cache['valid_index'] = valid_index
            cache['valid_output_index'] = valid_output_index
            cache['index_array'] = index_array

        elif self.mode == "quick":
            ridx, cidx = \
                utils.generate_quick_linesample_arrays(self.in_area,
                                                       self.out_area)
            cache['row_idx'] = ridx
            cache['col_idx'] = cidx

        elif self.mode == "ewa":
            from pyresample.ewa import ll2cr
            # Only the columns and rows are stored; the first returned
            # value is not used by project_array.
            _, cols, rows = ll2cr(self.in_area, self.out_area)
            cache['ewa_cols'] = cols
            cache['ewa_rows'] = rows

        elif self.mode == "bilinear":
            bilinear_t, bilinear_s, input_idxs, idx_arr = \
                get_bil_info(self.in_area, self.out_area,
                             self.radius, neighbours=32,
                             nprocs=nprocs, masked=False)
            cache['bilinear_s'] = bilinear_s
            cache['bilinear_t'] = bilinear_t
            cache['input_idxs'] = input_idxs
            cache['idx_arr'] = idx_arr

        return cache

    def _from_cache(self, *keys):
        """Return the cached arrays for *keys*, in order.

        Arrays missing from the in-memory cache are first copied from the
        archive loaded from disk.
        """
        for key in keys:
            if key not in self._cache:
                self._cache[key] = self._file_cache[key]
        return [self._cache[key] for key in keys]

    def save(self, resave=False):
        """Save the precomputation to disk, and overwrite existing file in case
        *resave* is true.

        Nothing is written while the in-memory cache is empty.
        """
        if not self._cache:
            # The projection was loaded from disk and not used yet: there
            # is nothing new to write, and saving the empty in-memory cache
            # would replace the stored projection with an empty archive.
            return
        if resave or not os.path.exists(self._filename):
            logger.info("Saving projection to %s", self._filename)
            np.savez(self._filename, **self._cache)

    def project_array(self, data):
        """Project an array *data* along the given Projector object.

        Returns the resampled array for the output area, computed according
        to ``self.mode``. Raises ValueError if ``self.mode`` is not one of
        the supported modes.
        """

        if self.mode == "nearest":
            valid_index, valid_output_index, index_array = \
                self._from_cache('valid_index',
                                 'valid_output_index',
                                 'index_array')

            res = kd_tree.get_sample_from_neighbour_info('nn',
                                                         self.out_area.shape,
                                                         data,
                                                         valid_index,
                                                         valid_output_index,
                                                         index_array,
                                                         fill_value=None)

        elif self.mode == "quick":
            row_idx, col_idx = self._from_cache('row_idx', 'col_idx')
            img = image.ImageContainer(data, self.in_area, fill_value=None)
            res = np.ma.array(img.get_array_from_linesample(row_idx, col_idx),
                              dtype=data.dtype)

        elif self.mode == "ewa":
            from pyresample.ewa import fornav
            # TODO: should be user configurable?
            rows_per_scan = None

            cols, rows = self._from_cache('ewa_cols', 'ewa_rows')
            # The first returned value (named num_valid_points in the
            # original code) is not needed here.
            _, res = fornav(cols, rows,
                            self.out_area, data,
                            rows_per_scan=rows_per_scan)

        elif self.mode == "bilinear":
            bilinear_t, bilinear_s, input_idxs, idx_arr = \
                self._from_cache('bilinear_t', 'bilinear_s',
                                 'input_idxs', 'idx_arr')

            res = get_sample_from_bil_info(data.ravel(),
                                           bilinear_t,
                                           bilinear_s,
                                           input_idxs,
                                           idx_arr,
                                           output_shape=self.out_area.shape)
            # Mask out the invalid (NaN/inf) values of the result.
            res = np.ma.masked_invalid(res)

        else:
            # Only reachable if self.mode was changed after construction.
            raise ValueError("Unknown projector mode: " + str(self.mode))

        return res
