#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2014-2022 Pytroll developers
#
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

"""Read the Terra/Aqua MODIS relative spectral response functions."""

import logging
import os

import numpy as np

from pyspectral.config import get_config
from pyspectral.utils import get_central_wave, sort_data

LOG = logging.getLogger(__name__)

# MODIS band identifiers, kept as strings: '1' ... '36'.
MODIS_BAND_NAMES = list(map(str, range(1, 37)))
# Bands treated as short-wave: bands 1-19 plus band 26. ModisRSR reads the
# RSR files of these bands with a 0.001 wavelength scale factor.
SHORTWAVE_BANDS = [str(num) for num in range(1, 20)] + ['26']


class ModisRSR:
    """Relative spectral responses of one Terra/Aqua MODIS band.

    On construction, the expected RSR file path of every MODIS band is
    derived from the ``<platform_name>-modis`` section of the pyspectral
    configuration. The file of the requested band is then read into
    ``self.rsr`` with :func:`read_modis_response`.

    Args:
        bandname: MODIS band identifier as a string, ``'1'`` to ``'36'``.
        platform_name: Platform name, e.g. ``'EOS-Terra'`` or ``'EOS-Aqua'``.
        sort: If true, call :meth:`sort` on the loaded data.

    Raises:
        IOError: If the RSR file for the requested band does not exist.
    """

    def __init__(self, bandname, platform_name, sort=True):
        """Resolve the per-band file names and load the requested band."""
        self.platform_name = platform_name
        self.bandname = bandname
        # One entry per MODIS band, filled in by _get_bandfilenames().
        self.filenames = dict.fromkeys(MODIS_BAND_NAMES)
        self.requested_band_filename = None
        self.is_sw = bandname in SHORTWAVE_BANDS
        self.scales = {}
        self.rsr = None
        self._sort = sort

        config = get_config()
        self.path = config[platform_name + '-modis'].get('path')
        self.output_dir = config.get('rsr_dir', './')

        self._get_bandfilenames()
        LOG.debug("Filenames: %s", str(self.filenames))

        band_file = self.filenames[bandname]
        if not os.path.exists(band_file):
            raise IOError("Couldn't find an existing file for this band: " +
                          str(self.bandname))
        self.requested_band_filename = band_file
        self._load()

    def _get_bandfilenames(self):
        """Fill ``self.filenames`` with the RSR file path of every MODIS band.

        Terra files are named ``rsr.<n>.inb.final``. For any other platform,
        bands 5-7 and 20-36 use ``<nn>.tv.1pct.det`` and the remaining bands
        use ``<nn>.amb.1pct.det``.
        """
        aqua_tv_bands = set([5, 6, 7] + list(range(20, 37)))
        for band in MODIS_BAND_NAMES:
            band_number = int(band)
            LOG.debug("Band = %s", str(band))
            if self.platform_name == 'EOS-Terra':
                basename = "rsr.{0:d}.inb.final".format(band_number)
            elif band_number in aqua_tv_bands:
                basename = "{0:>02d}.tv.1pct.det".format(band_number)
            else:
                basename = "{0:>02d}.amb.1pct.det".format(band_number)
            self.filenames[band] = os.path.join(self.path, basename)

    def _load(self):
        """Read the requested band's file into ``self.rsr``.

        Wavelengths are scaled by 0.001 for short-wave bands and for every
        band of ``'EOS-Aqua'``; otherwise they are left unscaled.
        """
        use_milli = self.is_sw or self.platform_name == 'EOS-Aqua'
        scale = 0.001 if use_milli else 1.0
        self.rsr = read_modis_response(self.requested_band_filename, scale)
        if self._sort:
            self.sort()

    def sort(self):
        """Sort the data so that x is monotonically increasing and contains no duplicates."""
        # self.rsr is either a single {'wavelength', 'response'} record or a
        # mapping of detector name -> such a record; treat both uniformly.
        if 'wavelength' in self.rsr:
            records = [self.rsr]
        else:
            records = [self.rsr[name] for name in self.rsr]
        for record in records:
            record['wavelength'], record['response'] = sort_data(
                record['wavelength'], record['response'])


def read_modis_response(filename, scale=1.0):
    """Read the Terra/Aqua MODIS relative spectral responses.

    Be aware that MODIS has several detectors (more than one) compared to
    e.g. AVHRR which has always only one.

    Each data line of *filename* must hold four whitespace-separated columns:
    an ignored first column, the detector number, the wavelength and the
    relative response. Blank lines are skipped, as are lines whose first
    non-blank character is ``#``.

    Args:
        filename: Path to the ASCII RSR file.
        scale: Factor applied to every wavelength value. The IR channels seem
            to be in microns, whereas the short wave channels are in
            nanometers; for VIS/NIR the scale should be 0.001.

    Returns:
        dict mapping ``'det-<n>'`` to a dict with 1-D float ``numpy`` arrays
        under ``'wavelength'`` and ``'response'``. Samples whose response
        equals the no-data value -99.0 are removed from both arrays.

    Raises:
        ValueError: If a data line does not have exactly four columns or
            contains non-numeric values.
    """
    nodata = -99.0
    detectors = {}
    with open(filename, "r") as fid:
        for line in fid:
            stripped = line.strip()
            # Skip blank lines (they would break the 4-way unpacking below)
            # and comment lines, including indented ones.
            if not stripped or stripped.startswith("#"):
                continue
            _, det_num, wavelength, response = stripped.split()
            detector_name = 'det-{0:d}'.format(int(det_num))
            if detector_name not in detectors:
                detectors[detector_name] = {'wavelength': [], 'response': []}
            detectors[detector_name]['wavelength'].append(
                float(wavelength) * scale)
            detectors[detector_name]['response'].append(float(response))

    for detector in detectors.values():
        response = np.asarray(detector['response'])
        # Drop no-data samples from both arrays so they stay aligned.
        keep = response != nodata
        detector['response'] = response[keep]
        detector['wavelength'] = np.asarray(detector['wavelength'])[keep]

    return detectors


def _wavelengths_identical(rsr, det_names):
    """Return True if every detector in *det_names* shares the first one's wavelengths."""
    reference = rsr[det_names[0]]['wavelength']
    for det in det_names[1:]:
        other = rsr[det]['wavelength']
        # Compare shapes first: np.allclose would broadcast or raise on
        # arrays of different length.
        if reference.shape != other.shape or not np.allclose(reference, other):
            return False
    return True


def _write_wavelength(group, arr):
    """Store *arr* as a float32 'wavelength' dataset (unit 'm', scale 1e-06) in *group*."""
    dset = group.create_dataset('wavelength', arr.shape, dtype='f')
    dset.attrs['unit'] = 'm'
    dset.attrs['scale'] = 1e-06
    dset[...] = arr


def convert2hdf5(platform_name):
    """Retrieve original RSR data and convert to internal hdf5 format.

    Writes ``rsr_modis_<platform_name>.h5`` into the configured output
    directory. The file has one group per MODIS band and one sub-group per
    detector. If all detectors of a band share the same wavelength grid, that
    grid is stored once on the band group. Otherwise each detector group
    gets its own ``wavelength`` dataset.

    Args:
        platform_name: Platform name, e.g. ``'EOS-Terra'`` or ``'EOS-Aqua'``.
    """
    import h5py

    # Instantiated only to obtain the configured output directory.
    # NOTE(review): this also loads band 20, which fails early if that band's
    # file is missing; the object itself is discarded.
    modis = ModisRSR('20', platform_name)
    mfile = os.path.join(modis.output_dir,
                         "rsr_modis_{platform}.h5".format(platform=platform_name))

    with h5py.File(mfile, "w") as h5f:
        h5f.attrs['description'] = 'Relative Spectral Responses for MODIS'
        h5f.attrs['platform_name'] = platform_name
        h5f.attrs['band_names'] = MODIS_BAND_NAMES

        for chname in MODIS_BAND_NAMES:
            modis = ModisRSR(chname, platform_name)
            grp = h5f.create_group(chname)
            det_names = list(modis.rsr)
            grp.attrs['number_of_detectors'] = len(det_names)

            wvl_is_constant = _wavelengths_identical(modis.rsr, det_names)
            if wvl_is_constant:
                _write_wavelength(grp, modis.rsr[det_names[0]]['wavelength'])

            for det in det_names:
                det_grp = grp.create_group(det)
                wavelength = modis.rsr[det]['wavelength']
                response = modis.rsr[det]['response']
                # The central wavelength is computed only from samples with a
                # defined wavelength; the stored arrays keep every sample.
                valid = ~np.isnan(wavelength)
                det_grp.attrs['central_wavelength'] = get_central_wave(
                    wavelength[valid], response[valid])
                if not wvl_is_constant:
                    _write_wavelength(det_grp, wavelength)

                dset = det_grp.create_dataset('response', response.shape,
                                              dtype='f')
                dset[...] = response


if __name__ == "__main__":
    # Regenerate the HDF5 RSR files for both MODIS platforms.
    for platform in ('EOS-Terra', 'EOS-Aqua'):
        convert2hdf5(platform)
