from __future__ import with_statement
import sys

try:
    import numpy
    import os
    import re
    import reflex
    from pipeline_product import PipelineProduct
    import pipeline_display
    import reflex_plot_widgets
    import matplotlib.gridspec as gridspec
    from matplotlib.text import Text
    from pylab import *

    import_success = True
except ImportError:
    import_success = False
    print "Error importing modules pyfits, wx, matplotlib, numpy"

def paragraph(text, width=None):
    """Format a (possibly indented, multi-line) string as a paragraph.

    text:  string to format; common leading indentation is removed first
    width: when None, newlines are turned into spaces and the result is
           stripped (suited to tooltips, which wxWidgets wraps itself);
           otherwise the text is re-wrapped to lines of at most `width`
           characters with textwrap.fill
    """
    from textwrap import dedent, fill

    dedented = dedent(text)
    if width is not None:
        return fill(dedented, width=width)
    return dedented.replace('\n', ' ').strip()

class DataPlotterManager(object):
    """Plot manager for the kmos_sci_red Reflex interactive window.

    The window shows a file selector, a clickable IFU selector colour-coded
    by IFU status, the mean spectrum of the selected IFU cube (with the OH
    reference spectrum overplotted when an OH_SPEC product is supplied) and
    the matching collapsed image when one is available.
    """
    # static members
    recipe_name = "kmos_sci_red"
    reconstructed_cat = "SCI_RECONSTRUCTED"
    oh_spec_cat = "OH_SPEC"

    # Matplotlib colour map used to paint an IFU box, keyed by IFU status
    IFU_stat_color = {'Active':'summer', 'Locked':'spring', 'NotInPAF':'winter', 'NotInPAF & Locked':'autumn', 'Coll.':'cool', 'Empty':'copper'}
    # Colour map for statuses missing from IFU_stat_color, so an unexpected
    # status string in a header does not abort the plot with a KeyError
    IFU_unknown_color = 'gray'
    # IFU selector legend: (label, status key, x origin, y origin); every
    # swatch is 8 units wide and 2 units high, its label 1 unit inside
    _IFU_LEGEND = (
        ('Active', 'Active', 1, 8),
        ('Locked', 'Locked', 11, 8),
        ('NotInPAF', 'NotInPAF', 21, 8),
        ('NotInPAF & Locked', 'NotInPAF & Locked', 31, 8),
        ('Collision', 'Coll.', 41, 8),
        ('Empty', 'Empty', 1, 4),
    )

    def setWindowTitle(self):
        """Return the title of the interactive window."""
        return self.recipe_name+"_interactive"

    def setInteractiveParameters(self):
        """Return the recipe parameters editable from the GUI, grouped by tab."""
        return [
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="imethod",
                    group="Recons.", description="Interpolation Method (NN, lwNN, swNN, MS, CS)"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="oscan",
                    group="Recons.", description="Apply Overscan Correction"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="xcal_interpolation",
                    group="Recons.", description="Interpolate xcal between rotator angles"),

            reflex.RecipeParameter(recipe=self.recipe_name, displayName="neighborhoodRange",
                    group="Extr.", description="Range for Neighbors"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="flux",
                    group="Extr.", description="Apply flux conservation"),

            reflex.RecipeParameter(recipe=self.recipe_name, displayName="obj_sky_table",
                    group="Sky Sub.", description="The path to the file with the modified obj/sky associations."),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="background",
                    group="Sky Sub.", description="Apply background removal"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="no_subtract",
                    group="Sky Sub.", description="Don't sky subtract object and references"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="sky_tweak",
                    group="Sky Sub.", description="Use modified sky cube for sky subtraction (TRUE or FALSE)"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="skip_sky_oh_align",
                    group="Sky Sub.", description="Do not align the sky on OH lines - Only if stretch is on (0-no, 1-yes)"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="stretch",
                    group="Sky Sub.", description="Stretch the sky for sky subtraction (0-no, 1-yes)"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="stretch_degree",
                    group="Sky Sub.", description="Stretching polynomial degree."),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="stretch_resampling",
                    group="Sky Sub.", description="Stretching interpolation method (linear / spline)"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="velocity_offset",
                    group="Sky Sub.", description="Specify velocity offset correction in km/s for lambda scale."),
        ]

    def readFitsData(self, fitsFiles):
        """Load the OH reference spectrum and every reconstructed cube.

        Fills ``self.oh_spec`` ("Found", plus "CRVAL1", "CDELT1", "CRPIX1"
        and "Spectrum" when an OH_SPEC file is given) and ``self.files``,
        mapping each reconstructed file's basename to
        ``{EXTNAME: extension dict}`` (see _read_reconstructed_file).
        Then selects the data or no-data plotting callbacks.
        """
        self.files = dict()
        self.oh_spec = {"Found": False}

        for f in fitsFiles:
            if f.category == self.oh_spec_cat:
                self._read_oh_spec(f)
            if f.category == self.reconstructed_cat:
                filename = os.path.basename(f.name)
                self.files[filename] = self._read_reconstructed_file(fitsFiles, f)

        # Bind explicitly in both branches: a self-assignment would keep a
        # previously bound no-data layout if this method runs more than once.
        if self.files:
            self._add_subplots = self._add_data_subplots
            self._plot = self._data_plot
        else:
            self._add_subplots = self._add_nodata_subplots
            self._plot = self._nodata_plot

    def _read_oh_spec(self, fits_file):
        """Store the OH spectrum and its wavelength keywords from extension 1."""
        oh_hdu = PipelineProduct(fits_file).all_hdu[1]
        self.oh_spec["Found"] = True
        self.oh_spec["CRVAL1"] = oh_hdu.header["CRVAL1"]
        self.oh_spec["CDELT1"] = oh_hdu.header["CDELT1"]
        # Reference pixel 1 reproduces the historical behaviour when absent
        self.oh_spec["CRPIX1"] = oh_hdu.header.get("CRPIX1", 1.0)
        self.oh_spec["Spectrum"] = oh_hdu.data

    def _read_reconstructed_file(self, fitsFiles, fits_file):
        """Return ``{EXTNAME: info}`` for one reconstructed product.

        Every named extension whose EXTNAME contains an IFU number yields
        "IFU_NUMBER" and "IFU_STATUS". 3-D extensions also get "CRPIX3",
        "CRVAL3", "CDELT3", "UNIT", "NAME", "Spectrum" (NaN-ignoring mean of
        each cube plane) and, when a matching make_image_ product exists,
        "Collapsed".
        """
        recons_file = PipelineProduct(fits_file)
        primary_header = recons_file.all_hdu[0].header
        collapsed_exts = None  # looked up once per file, and only if needed
        extensions = dict()

        for recons_ext in recons_file.all_hdu:
            header = recons_ext.header
            # EXTNAME is missing in the primary header - Skip it anyway
            try:
                extname = header['EXTNAME']
            except KeyError:
                continue

            # The IFU number (first digit run in EXTNAME) keys the status
            m = re.search(r"\d+", extname)
            if m is None:
                continue
            ifu_number = m.group()
            naxis = header['NAXIS']

            ext_info = {"IFU_NUMBER": int(ifu_number)}
            extensions[extname] = ext_info

            # Status from the primary header if present, else from data shape
            arm_key = 'ESO OCS ARM' + ifu_number + ' NOTUSED'
            if arm_key in primary_header:
                ext_info["IFU_STATUS"] = primary_header[arm_key]
            elif naxis == 3:
                ext_info["IFU_STATUS"] = 'Active'
            else:
                ext_info["IFU_STATUS"] = 'Empty'

            if naxis != 3:
                continue

            ext_info["CRPIX3"] = header['CRPIX3']
            ext_info["CRVAL3"] = header['CRVAL3']
            ext_info["CDELT3"] = header['CDELT3']
            ext_info["UNIT"] = header['ESO QC CUBE_UNIT']
            ext_info["NAME"] = header['ESO OCS ARM' + ifu_number + ' NAME']
            ext_info["Spectrum"] = [self._plane_mean(plane) for plane in recons_ext.data]

            if collapsed_exts is None:
                collapsed_exts = self._get_collapsed_exts(fitsFiles, fits_file)
            collapsed_ext = collapsed_exts.get(extname)
            if collapsed_ext is not None:
                ext_info["Collapsed"] = collapsed_ext.data

        return extensions

    @staticmethod
    def _plane_mean(plane):
        """Return the mean of the non-NaN values of `plane`, NaN if there are none."""
        valid = plane[~numpy.isnan(plane)]
        if valid.size > 0:
            return valid.mean()
        return numpy.nan

    def _get_collapsed_exts(self, fitsFiles, ref_file):
        """Map EXTNAME -> HDU for the collapsed product of `ref_file`.

        The collapsed product is the input whose basename equals
        "make_image_" + basename(ref_file.name). The first HDU seen for a
        given EXTNAME wins.
        """
        target_name = "make_image_" + os.path.basename(ref_file.name)
        by_extname = dict()
        for f in fitsFiles:
            if os.path.basename(f.name) != target_name:
                continue
            for collapsed_frame_ext in PipelineProduct(f).all_hdu:
                # EXTNAME is missing in the primary header - Skip it anyway
                try:
                    coll_extname = collapsed_frame_ext.header['EXTNAME']
                except KeyError:
                    continue
                by_extname.setdefault(coll_extname, collapsed_frame_ext)
        return by_extname

    def _get_collapsed_ext(self, fitsFiles, ref_file, extname):
        """Return the collapsed-product HDU matching `extname`, or None."""
        return self._get_collapsed_exts(fitsFiles, ref_file).get(extname)

    def addSubplots(self, figure):
        """Create the axes of the window (data or no-data layout)."""
        self._add_subplots(figure)

    def plotProductsGraphics(self):
        """Draw the window content (data or no-data plot)."""
        self._plot()

    def plotWidgets(self) :
        """Create the file radio buttons and the clickable IFU selector."""
        widgets = list()

        # Files Selector radiobutton; index 0 of the sorted names is preselected
        self.radiobutton = reflex_plot_widgets.InteractiveRadioButtons(self.files_selector, self.setFSCallback, sorted(self.files.keys()), 0, 
                title='Files Selection (Left Mouse Button)')
        widgets.append(self.radiobutton)

        self.clickable_ifus =reflex_plot_widgets.InteractiveClickableSubplot(self.ifus_selector, self.setIFUSCallback)
        widgets.append(self.clickable_ifus)

        return widgets

    def extension_has_spectrum(self, filename, extname):
        """Return True if extension `extname` of `filename` holds a spectrum.

        Unknown files or extensions (e.g. a click on an IFU slot the file
        does not contain, or an empty selection "") yield False.
        """
        return "Spectrum" in self.files.get(filename, {}).get(extname, {})

    def setIFUSCallback(self, point) :
        """Select the IFU clicked in the IFU selector and redraw the plots."""
        if point.xdata is None or point.ydata is None:
            return
        # Only the row of IFU boxes (y in ]1, 3[) is clickable
        if not (1 < point.ydata < 3):
            return
        # IFU n is drawn on x in [2n-1, 2n+0.5]
        extname = "IFU." + str(int(point.xdata / 2.0 + 0.5)) + ".DATA"
        if not self.extension_has_spectrum(self.selected_file, extname):
            return
        self.selected_extension = extname
        self._plot_ifus_selector(self.selected_file)
        self._plot_spectrum()
        self._disp_image()

    def setFSCallback(self, filename) :
        """Switch to `filename`, keeping the current IFU if it is valid there."""
        self.selected_file = filename
        if not self.extension_has_spectrum(filename, self.selected_extension):
            self.selected_extension = self._get_first_valid_extname(filename)
        # IFU statuses and extensions are per file: always redraw the selector
        self._plot_ifus_selector(filename)
        self._plot_spectrum()
        self._disp_image()

    def _add_data_subplots(self, figure):
        """Lay out file selector, IFU selector, spectrum and image axes."""
        gs = gridspec.GridSpec(3, 2)
        self.files_selector = figure.add_subplot(gs[0,:])
        self.ifus_selector = figure.add_subplot(gs[1,:])
        self.spec_plot = figure.add_subplot(gs[2,0])
        self.collapsed_img = figure.add_subplot(gs[2,1])

    # Default layout, used if addSubplots runs before readFitsData
    _add_subplots = _add_data_subplots

    def _data_plot_get_tooltip(self):
        """Tooltip text: selected file, extension and object name."""
        return self.selected_file+" ["+self.selected_extension+"];" +" Object NAME = "+str(  self.files[self.selected_file][self.selected_extension]["NAME"]  )

    def _selected_extension_dict(self):
        """Return the selected extension's dict, or None if it has no spectrum."""
        if not self.extension_has_spectrum(self.selected_file, self.selected_extension):
            return None
        return self.files[self.selected_file][self.selected_extension]

    @staticmethod
    def _wave_axis(crval, cdelt, crpix, npix):
        """Linear FITS wavelength axis for `npix` pixels.

        FITS pixels are 1-based, so 0-based index i maps to
        crval + (i + 1 - crpix) * cdelt.
        """
        return crval + (numpy.arange(npix) + 1 - crpix) * cdelt

    def _disp_image(self):
        """Show the collapsed image of the selected extension, if any."""
        self.collapsed_img.clear()
        extension_dict = self._selected_extension_dict()
        if extension_dict is None or 'Collapsed' not in extension_dict:
            return
        imgdisp = pipeline_display.ImageDisplay()
        imgdisp.setAspect('equal')
        imgdisp.display(self.collapsed_img, "Collapsed image", self._data_plot_get_tooltip(), extension_dict["Collapsed"])
        self.collapsed_img.set_xlabel("pixels")   # item 6f
        self.collapsed_img.set_ylabel("pixels")   # item 6f

    def _plot_spectrum(self):
        """Plot the selected spectrum, with the OH spectrum filled behind it."""
        self.spec_plot.clear()
        extension_dict = self._selected_extension_dict()
        if extension_dict is None:
            return

        specdisp = pipeline_display.SpectrumDisplay()
        specdisp.setLabels(r"$\lambda$["+"$\mu$m]" + " (blue: Observed; red: OH[arb. units])", 
                            self._process_label(extension_dict["UNIT"]) )
        spectrum = extension_dict["Spectrum"]

        # Drawn before the spectrum so it stays in the background
        if self.oh_spec["Found"]:
            oh_spectrum = self.oh_spec["Spectrum"]
            oh_wave = self._wave_axis(self.oh_spec["CRVAL1"], self.oh_spec["CDELT1"],
                                      self.oh_spec.get("CRPIX1", 1.0), len(oh_spectrum))
            # Normalised OH, scaled relative to the observed spectrum maximum
            oh_flux = oh_spectrum / numpy.nanmax(oh_spectrum) * 200*numpy.nanmax(spectrum)
            specdisp.overplot(self.spec_plot, oh_wave, oh_flux, '#ED5D5D')
            self.spec_plot.fill_between(oh_wave, 0, oh_flux, color='#ED5D5D')

        wave = self._wave_axis(extension_dict["CRVAL3"], extension_dict["CDELT3"],
                               extension_dict["CRPIX3"], len(spectrum))
        specdisp.display(self.spec_plot, "Spectrum", self._data_plot_get_tooltip(), wave, spectrum)

    def _process_label(self, in_label):
        """Return a LaTeX-formatted unit label for known units, else `in_label`."""
        if (in_label == "erg.s**(-1).cm**(-2).angstrom**(-1)"):
            return "erg sec" + r"$^{-1}$"+"cm" + r"$^{-2}$" + r"$\AA^{-1}$"
        else:
            return in_label

    def _get_ifu_status_from_file(self, filename, extname):
        """Return the IFU status, mapping "NotInPAF/Locked" to its legend key."""
        cur_status = self.files[filename][extname]["IFU_STATUS"]
        if cur_status == "NotInPAF/Locked" :
            return 'NotInPAF & Locked'
        return cur_status

    def _plot_ifus_selector(self, filename):
        """Draw the status legend and one coloured box per IFU of `filename`.

        IFU n occupies x in [2n-1, 2n+0.5], y in [1, 3]; the selected IFU is
        additionally underlined at y in [0.5, 0.7].
        """
        self.ifus_selector.clear()

        # Legend: one coloured swatch per IFU status
        for label, status, x0, y0 in self._IFU_LEGEND:
            self.ifus_selector.imshow(numpy.zeros((1,1)), extent=(x0, x0 + 8, y0, y0 + 2),
                    cmap=self.IFU_stat_color[status])
            self.ifus_selector.text(x0 + 1, y0 + 0.5, label, fontsize=11,color='white')

        # IFU selection boxes
        box_y_start = 1
        box_y_stop  = 3
        box_xwidth  = 1.5
        for extname, ext_info in self.files[filename].items():
            ifu_number = ext_info['IFU_NUMBER']
            box_xstart = 2 * ifu_number - 1
            box_xstop = box_xstart + box_xwidth
            ifu_status = self._get_ifu_status_from_file(filename, extname)
            cmap = self.IFU_stat_color.get(ifu_status, self.IFU_unknown_color)
            self.ifus_selector.imshow(numpy.zeros((1,1)), extent=(box_xstart,box_xstop,box_y_start,box_y_stop), 
                    cmap=cmap)
            # Write the IFU number in the box
            self.ifus_selector.text(2 * (ifu_number-1) + 1.2, 1.5, str(ifu_number), fontsize=13,color='white')
            # Mark the selected IFU
            if extname == self.selected_extension:
                self.ifus_selector.imshow(numpy.zeros((1,1)), extent=(box_xstart,box_xstop,0.5,0.7), 
                        cmap=cmap)

        self.ifus_selector.axis([0,50,0,11])
        self.ifus_selector.set_title("IFU Selection (Mouse middle button)")
        self.ifus_selector.get_xaxis().set_ticks([])
        self.ifus_selector.get_yaxis().set_ticks([])

    # Get the first valid extension name (ie containing a spectrum) in a file - "" if None 
    def _get_first_valid_extname(self, filename):
        for extname in sorted(self.files[filename].keys()):
            if (self.extension_has_spectrum(filename, extname)) :
                return extname
        return ""

    def _data_plot(self):
        """Initial plot: first file (as preselected by the radio button)."""
        # Same order as the radio button, which preselects sorted index 0
        self.selected_file = sorted(self.files.keys())[0]
        self.selected_extension = self._get_first_valid_extname(self.selected_file)

        self._plot_ifus_selector(self.selected_file)
        self._plot_spectrum()
        self._disp_image()

    def _add_nodata_subplots(self, figure):
        """Single axes used for the 'no data' message."""
        self.img_plot = figure.add_subplot(1,1,1)

    def _nodata_plot(self):
        """Display a message naming the expected input category."""
        # could be moved to reflex library?
        self.img_plot.set_axis_off()
        text_nodata = "Data not found. Input files should contain this" \
                       " type:\n%s" % self.reconstructed_cat
        self.img_plot.text(0.1, 0.6, text_nodata, color='#11557c',
                      fontsize=18, ha='left', va='center', alpha=1.0)
        self.img_plot.tooltip = 'No data found'

# Script entry point, run by the Reflex python actor
if __name__ == '__main__':
    from reflex_interactive_app import PipelineInteractiveApp

    # Build the interactive application and read its command-line inputs
    app = PipelineInteractiveApp(enable_init_sop=True)
    app.parse_args()

    # Without the plotting modules the GUI cannot be shown
    if not import_success:
        app.setEnableGUI(False)

    if not app.isGUIEnabled():
        app.set_continue_mode()
    else:
        # Hand the window-specific plotting logic to the application
        app.setPlotManager(DataPlotterManager())
        app.showGUI()

    # Print outputs. This is parsed by the Reflex python actor to
    # get the results. Do not remove
    app.print_outputs()
    sys.exit()
