from __future__ import with_statement
import sys

try:
    import inspect
    import numpy
    import os
    import re
    import reflex
    from pipeline_product import PipelineProduct
    import pipeline_display
    import reflex_plot_widgets
    import matplotlib.gridspec as gridspec
    from matplotlib.text import Text
    from matplotlib.pyplot import cm
    from pylab import *

    import_success = True
except ImportError:
    import_success = False
    print "Error importing modules pyfits, wx, matplotlib, numpy"

def paragraph(text, width=None):
    """ turn an indented, multi-line string into a paragraph
       text:  text to format; common leading indentation is removed first
       width: None joins all lines with single spaces and trims both ends;
              otherwise the text is wrapped at that many columns (not
              recommended for tooltips as they are wrapped by wxWidgets
              by default)
    """
    from textwrap import dedent, fill

    dedented = dedent(text)
    if width is not None:
        return fill(dedented, width=width)
    return dedented.replace('\n', ' ').strip()

class DataPlotterManager(object):
    """Plot manager of the interactive window for the kmos_combine recipe.

    readFitsData() collects the combined cubes, the reconstructed input cubes
    and the optional OH spectrum; the plotting methods then draw the combined
    spectrum, the collapsed image and a stack of the input IFU spectra, with
    radio buttons to pick the combined file and the view.
    """
    # static members
    # Recipe the exposed parameters belong to; also the prefix of the window title
    recipe_name  = "kmos_combine"
    # Frame categories treated as reconstructed (input) cubes
    recons_cat   = "SCI_RECONSTRUCTED"
    recons_cat2  = "SINGLE_CUBES"
    # Frame category of the combined (output) cubes
    combined_cat = "COMBINED_CUBE"

    # Frame category of the optional OH spectrum over-plotted on the spectrum
    oh_spec_cat = "OH_SPEC"
    # Colours cycled through when stacking the input spectra
    cores = ['r', 'g', '0.50', 'c', 'm', 'y', 'b']
    # Labels of the view selector; the position of a label is the matching vs_state
    labels_view=['Display input and output spectra', 'Display output spectra', 'Display input spectra']
    # Index in labels_view of the view currently shown
    vs_state = 0
    # Index in self.comb_files of the combined file currently shown
    selected_file_idx = 0

    def setCurrentParameterHelper(self, helper):
        self.param_helper = helper

    def setWindowTitle(self):
        return self.recipe_name+"_interactive"

    def setInteractiveParameters(self):
        """Return the recipe parameters shown in the interactive window.

        One reflex.RecipeParameter is created per entry of the table below,
        in table order, all attached to self.recipe_name.
        """
        # (displayName, group, description)
        param_specs = (
            ("name", "Comb.",
             "Name of the object to combine. Empty[default] means all."),
            ("fmethod", "Comb.",
             "Fitting Method (gauss, moffat)"),
            ("flux", "Comb.",
             "Apply flux conservation"),
            ("edge_nan", "Comb.",
             "Set borders of cubes to NaN before combining them"),
            ("cmethod", "Comb.",
             "Combination Method (average, median, sum, min_max, ksigma)"),
            ("cpos_rej", "Comb.",
             "The positive rejection threshold"),
            ("cneg_rej", "Comb.",
             "The negative rejection threshold"),
            ("citer", "Comb.",
             "The number of iterations"),
            ("cmax", "Comb.",
             "The number of maximum pixel values"),
            ("cmin", "Comb.",
             "The number of minimum pixel values"),
            ("skipped_frames", "Comb.",
             "Comma-separated list of R:I duplets/'labels' indexing the frames to skip. R:I means the Rth RAW frame from the Ith IFU will be ignored. Empty[default] skips none."),
            ("ifus", "Other",
             "The indices of the IFUs to combine"),
            ("method", "Other",
             "The shifting method"),
            ("filename", "Other",
             "The path to the file with the shift vectors"),
        )
        return [
            reflex.RecipeParameter(recipe=self.recipe_name, displayName=display_name,
                                   group=group, description=description)
            for display_name, group, description in param_specs
        ]

    def readFitsData(self, fitsFiles):
        # Initialise
        self.comb_files = []    # List of Dict (1 per comb. file) - Keys : FILENAME, USEDIFUS, SKIPPEDIFUS, CRPIX3, CRVAL3, CDELT3, UNIT, Spectrum, Collapsed, in_ifus
                                # in_fus: List of Dict (1 per used IFUS) - Keys : REC_FILENAME, rec_im_nums, ifu_nb, CRPIX3, CRVAL3, CDELT3, UNIT, Spectrum
        self.oh_spec = dict()   # Dict - Keys : Found, CRVAL1, CDELT1, Spectrum 

        # Loop on all FITS files 
        self.oh_spec["Found"] = False
        for f in fitsFiles:
            # Use OH_SPEC if found
            if f.category == self.oh_spec_cat :
                oh_spec_file = PipelineProduct(f)
                self.oh_spec["Found"] = True
                self.oh_spec["CRVAL1"] = oh_spec_file.all_hdu[1].header["CRVAL1"]
                self.oh_spec["CDELT1"] = oh_spec_file.all_hdu[1].header["CDELT1"]
                self.oh_spec["Spectrum"] = oh_spec_file.all_hdu[1].data

            # For each COMBINED image
            if f.category == self.combined_cat :
                combined_file = PipelineProduct(f)
                filename = os.path.basename(f.name)

                # Create a Dictionary per file
                my_file_dict = dict()
                
                # STORE the combined file name
                my_file_dict["FILENAME"] = filename
                
                # STORE the Used reconstructed files IFUS : "1:3,2:5,..."
                combined_primary = combined_file.all_hdu[0]
                
                # Support case where keyword is missing
                if 'ESO PRO USEDIFUS' in combined_primary.header :
                    my_file_dict["USEDIFUS"] = combined_primary.header['ESO PRO USEDIFUS']
                else :
                    my_file_dict["USEDIFUS"] = ""
                if 'ESO PRO SKIPPEDIFUS' in combined_primary.header :
                    my_file_dict["SKIPPEDIFUS"] = combined_primary.header['ESO PRO SKIPPEDIFUS']
                else :
                    my_file_dict["SKIPPEDIFUS"] = ""
               
                # Store the COMBINE Spectrum and Image
                combined_ext = combined_file.all_hdu[1]            
                naxis = combined_ext.header['NAXIS']
                if (naxis == 3):
                    # Get Keyword infos  
                    my_file_dict["CRPIX3"] = combined_ext.header['CRPIX3']
                    my_file_dict["CRVAL3"] = combined_ext.header['CRVAL3']
                    my_file_dict["CDELT3"] = combined_ext.header['CDELT3']
                    my_file_dict["UNIT"]   = combined_ext.header['ESO QC CUBE_UNIT']
                    
                    # Fill Spectrum
                    my_file_dict["Spectrum"] = []
                    for cube_plane in combined_ext.data:
                        cube_plane_nan_free = cube_plane[~numpy.isnan(cube_plane)]
                        if (len(cube_plane_nan_free) > 0):
                            mean = cube_plane_nan_free.mean()
                        else:
                            mean = numpy.nan
                        my_file_dict["Spectrum"].append(mean)

                    # Fill Collapsed Image
                    collapsed_frame_ext = self._get_collapsed_ext(fitsFiles, f)
                    if collapsed_frame_ext:
                        my_file_dict["Collapsed"] = collapsed_frame_ext.data
 
                ### Parse R:I from CSV string ##################
                rec_im_nums = dict()
                ifu_nums = dict()
                skipped = dict()
                # Store used IFUs from the FITS keyword
                duplet_seq = my_file_dict["USEDIFUS"]
                duplet_seq.strip()
                
                if len(duplet_seq) > 1 :
                    duplet_seq = duplet_seq.split(",")
                    for kindx in range(0, len(duplet_seq)):
                        xpto = duplet_seq[kindx]
                        aux = xpto.split(":")
                        rec_im_nums[kindx] = int( aux[0] )
                        ifu_nums[kindx] = int( aux[1] )
                        skipped[kindx] = 0
                    
                # Store Skipped IFUs
                next_idx = len(duplet_seq)
                duplet_seq = my_file_dict["SKIPPEDIFUS"]
                duplet_seq.strip()
                if len(duplet_seq) > 1 :
                    duplet_seq = duplet_seq.split(",")
                    for skipped_idx in range(0, len(duplet_seq)):
                        xpto = duplet_seq[skipped_idx]
                        aux = xpto.split(":")
                        rec_im_nums[next_idx+skipped_idx] = int( aux[0] )
                        ifu_nums[next_idx+skipped_idx] = int( aux[1] )
                        skipped[next_idx+skipped_idx] = 1
                ### end ########################################
                
                # Get the Input IFUs for this combined image
                my_file_dict["in_ifus"] = []

                # Loop on all input fitsfiles
                countr = 0  
                for frec in fitsFiles:
                    if frec.category == self.recons_cat or frec.category == self.recons_cat2 :
                        reconstructed_file = PipelineProduct(frec)
                        reconstructed_filename = os.path.basename(frec.name)
                        
                        # countr is the id of the current REC file amongst the REC files
                        countr += 1
                        
                        # Loop on the used IFUS to get the ones from the current REC file
                        for duplet_num in range(0, len(rec_im_nums) ) :   
                            r = rec_im_nums[duplet_num]

                            # Check if the current REC file is specified here
                            if r == countr :
                                # Create a dictionary per used IFU
                                my_used_ifus_dict = dict()

                                # STORE RECONSTRUCTED filename
                                my_used_ifus_dict["REC_FILENAME"] = reconstructed_filename
                                
                                i = ifu_nums[duplet_num]

                                # STORE the ifu nb / rec file idx / skipped info
                                my_used_ifus_dict["ifu_nb"] = i
                                my_used_ifus_dict["rec_im_nums"] = r
                                my_used_ifus_dict["skipped"] = skipped[duplet_num]
                                
                                # Store the spectrum in a dictionary
# TODO : ith IFU in ith extension ??
                                rec_ext = reconstructed_file.all_hdu[i]  
                                rec_naxis = rec_ext.header['NAXIS']
                                if (rec_naxis == 3):
                                    # Get Keyword infos  
                                    my_used_ifus_dict["CRPIX3"] = rec_ext.header['CRPIX3']
                                    my_used_ifus_dict["CRVAL3"] = rec_ext.header['CRVAL3']
                                    my_used_ifus_dict["CDELT3"] = rec_ext.header['CDELT3']
                                    my_used_ifus_dict["UNIT"]   = rec_ext.header['ESO QC CUBE_UNIT']
                                    
                                    # Fill Spectrum
                                    my_used_ifus_dict["Spectrum"] = []
                                    for cube_plane in rec_ext.data:
                                        cube_plane_nan_free = cube_plane[~numpy.isnan(cube_plane)]
                                        if (len(cube_plane_nan_free) > 0):
                                            mean = cube_plane_nan_free.mean()
                                        else:
                                            mean = numpy.nan
                                        my_used_ifus_dict["Spectrum"].append(mean)
                                        
                                # add this dictionary my_file_dict["in_ifus"]
                                my_file_dict["in_ifus"].append(my_used_ifus_dict)       

                # Append the dictionary to the files list
                self.comb_files.append(my_file_dict)
               
        # If proper files are there...
        if (len(self.comb_files) > 0):
            # Set the plotting functions
            self._add_subplots = self._add_subplots
            self._plot = self._data_plot
        else:
            self._add_subplots = self._add_nodata_subplots
            self._plot = self._nodata_plot

    # Returns extension 1 of the collapsed image that corresponds to comb_file (nothing if there is none).
    # The corresponding file must have the same base name as comb_file, prefixed with make_image_
    def _get_collapsed_ext(self, fitsFiles, comb_file):
        comb_file_name = os.path.basename(comb_file.name)
        # Loop on all FITS files 
        for f in fitsFiles:
            make_image_filename = os.path.basename(f.name)
            if (make_image_filename == "make_image_"+comb_file_name):
                collapsed_frame = PipelineProduct(f)
                return collapsed_frame.all_hdu[1]

    def addSubplots(self, figure):
        self._add_subplots(figure)

    #def addSubplots_SS(self, figure):
        #self._add_subplots_SS(figure)

    def plotProductsGraphics(self):
#        print self.param_helper("skipped_frames")
        self._plot()

    def plotWidgets(self) :
        #self.radiobutton.clear()
        
        widgets = list()

        labels=[]
        for idx in range(0, len(self.comb_files)):
            labels.append(self.comb_files[idx]["FILENAME"])
        #labels_view=['Combined', 'Classic', 'Spectra Stack']
        
        # Files Selector radiobutton
        self.radiobutton = reflex_plot_widgets.InteractiveRadioButtons(self.comb_files_selector, self.setFSCallback, labels, self.selected_file_idx, title='Files Selection (double-click Left Mouse Button)')
        widgets.append(self.radiobutton)
        
        self.plotVS_Widget()
        
        return widgets
        
    def plotVS_Widget(self) :
        """Rebuild the view-selection radio buttons on self.view_select.

        The axes is cleared, the radio buttons are created with the current
        view (self.vs_state) selected and a caption is written at the bottom.
        Returns a one-element list holding the new radio buttons.
        """
        selector_axes = self.view_select
        selector_axes.clear()

        # View Selector radiobutton
        self.radiobutton_VS = reflex_plot_widgets.InteractiveRadioButtons(
            selector_axes, self.setVSCallback, self.labels_view, self.vs_state)
        selector_axes.text(0.5, 0.0, 'View Selection (Left Mouse Button)',
                           verticalalignment='bottom', horizontalalignment='center',
                           fontsize=12, fontweight='semibold')
        return [self.radiobutton_VS]

#def setCurrentParameterHelper(self, helper) :
#print helper('skipped_frames')

    def setFSCallback(self, filename) :
        for idx in range(0, len(self.comb_files)):
            if (self.comb_files[idx]["FILENAME"] == filename):
                self.selected_file_idx = idx        
        # Redraw spectrum
        self._plot_spectrum()
        # Redisplay image
        self._disp_image()
        # Redraw spectra stack
        self._disp_spectra_stack()
        
    def setVSCallback(self, chosen_view) :
        #print(chosen_view)

        if(chosen_view == self.labels_view[0]):    #if(chosen_view == 'Combined'):
            self.vs_state = 0
            self.addSubplots(self.my_figure)
            
            # Redraw spectrum
            self._plot_spectrum()
            # Redisplay image
            self._disp_image()
            # Redraw spectra stack
            self._disp_spectra_stack()
            
            # Redraw View Selection radiobutton panel
            self.plotWidgets()
            
        if(chosen_view == self.labels_view[1]):    #if(chosen_view == 'Classic'):
            self.vs_state = 1
            self._add_subplots_Clssc(self.my_figure)
            
            # Redraw spectrum
            self._plot_spectrum()
            # Redisplay image
            self._disp_image()
            
            # Redraw View Selection radiobutton panel
            self.plotWidgets()              
                    
        if(chosen_view == self.labels_view[2]):    #if(chosen_view == 'Spectra Stack'):
            self.vs_state = 2
            self._add_subplots_SS(self.my_figure)

            # Redraw spectra stack
            self._disp_spectra_stack()
            
            # Redraw View Selection radiobutton panel
            self.plotVS_Widget()  
    
        self.my_figure.canvas.draw()
            
    def _add_subplots_Clssc(self, figure):
        """Lay out the 'output spectra' view on `figure`.

        On a 4x3 grid: file selector (rows 0-1), spectrum and collapsed image
        (row 2), view selector (row 3). The figure is remembered in
        self.my_figure and emptied first.
        """
        self.my_figure = figure
        # self.my_figure is `figure`: the figure is cleared twice, as before
        self.my_figure.clear()
        figure.clear()

        grid = gridspec.GridSpec(4, 3)
        panels = (
            ("comb_files_selector", grid[0:2, :]),
            ("spec_plot",           grid[2, 0:2]),
            ("collapsed_img",       grid[2, 2]),
            ("view_select",         grid[3, :]),
        )
        for attribute, cell in panels:
            setattr(self, attribute, figure.add_subplot(cell))

    def _add_subplots_SS(self, figure):
        """Lay out the 'input spectra' view on `figure`.

        On a 4x3 grid: spectra stack (rows 0-2) and view selector (row 3).
        The figure is remembered in self.my_figure and emptied first.
        """
        self.my_figure = figure
        # self.my_figure is `figure`: the figure is cleared twice, as before
        self.my_figure.clear()
        figure.clear()

        grid = gridspec.GridSpec(4, 3)
        panels = (
            ("spec_stack",  grid[0:3, :]),
            ("view_select", grid[3, :]),
        )
        for attribute, cell in panels:
            setattr(self, attribute, figure.add_subplot(cell))
        
    def _add_subplots(self, figure):
        """Lay out the 'input and output spectra' view on `figure`.

        On a 4x3 grid: file selector (rows 0-1, columns 0-1), spectrum and
        collapsed image (row 2), spectra stack (rows 0-2, column 2) and view
        selector (row 3). The figure is remembered in self.my_figure and
        emptied first.
        """
        self.my_figure = figure
        # self.my_figure is `figure`: the figure is cleared twice, as before
        self.my_figure.clear()
        figure.clear()

        grid = gridspec.GridSpec(4, 3)
        panels = (
            ("comb_files_selector", grid[0:2, 0:2]),
            ("spec_plot",           grid[2, 0]),
            ("collapsed_img",       grid[2, 1]),
            ("spec_stack",          grid[0:3, 2]),
            ("view_select",         grid[3, :]),
        )
        for attribute, cell in panels:
            setattr(self, attribute, figure.add_subplot(cell))

    def _data_plot_get_tooltip(self):
        return self.comb_files[self.selected_file_idx]["FILENAME"]

    def _disp_image(self):
        """Show the collapsed image of the selected combined file.

        The image axes is always cleared; the image is drawn and the axis
        labels set only when the file has a "Collapsed" entry.
        """
        selected = self.comb_files[self.selected_file_idx]
        axes = self.collapsed_img
        axes.clear()

        display = pipeline_display.ImageDisplay()
        display.setAspect('equal')

        # Nothing to draw without a collapsed image
        if 'Collapsed' not in selected:
            return

        display.display(axes, "Collapsed image", self._data_plot_get_tooltip(),
                        selected["Collapsed"])
        axes.set_xlabel("pixels")   # item 6f
        axes.set_ylabel("pixels")   # item 6f

    def _plot_spectrum(self):
        """Plot the spectrum of the selected combined file, over the OH spectrum if any.

        The OH spectrum is normalised to its maximum, scaled to 200 times the
        maximum of the combined spectrum and drawn first (line plus filled
        area) so that it stays in the background.
        """
        selected = self.comb_files[self.selected_file_idx]
        axes = self.spec_plot
        axes.clear()

        display = pipeline_display.SpectrumDisplay()
        x_label = r"$\lambda$[$\mu$m] (blue: Observed; red: OH[arb. units])"
        display.setLabels(x_label, self._process_label(selected["UNIT"]))

        oh = self.oh_spec
        if oh["Found"]:
            # OH spectrum first, so the observed spectrum is drawn on top of it
            oh_wave = oh["CRVAL1"] + numpy.arange(len(oh["Spectrum"])) * oh["CDELT1"]
            oh_flux = oh["Spectrum"] / numpy.nanmax(oh["Spectrum"]) * 200*numpy.nanmax(selected["Spectrum"])
            display.overplot(axes, oh_wave, oh_flux, '#ED5D5D')
            axes.fill_between(oh_wave, 0, oh_flux, color='#ED5D5D')

        # Wavelength of every plane of the cube.
        # NOTE(review): CRPIX3 is stored but not used here, i.e. the first plane
        # is taken to sit at CRVAL3 -- confirm this holds for all inputs.
        flux = selected["Spectrum"]
        wave = selected["CRVAL3"] + numpy.arange(len(flux)) * selected["CDELT3"]

        display.display(axes, "Spectrum", self._data_plot_get_tooltip(), wave, flux)

    def _disp_spectra_stack(self):
        """Draw the input IFU spectra of the selected combined file, stacked vertically.

        Every spectrum is shifted upwards by a cumulative offset of ten times
        its median absolute deviation and labelled "<rec. image #>:<IFU #>".
        Colours cycle through self.cores, skipped IFUs are drawn in black.
        """
        comb_dict = self.comb_files[self.selected_file_idx]
        axes = self.spec_stack

        # Collect what is needed from every input IFU before drawing anything
        stack = [(ifu["Spectrum"], ifu["ifu_nb"], ifu["rec_im_nums"], ifu["skipped"])
                 for ifu in comb_dict["in_ifus"]]

        # Wavelength axis, taken from the combined cube
        wave = comb_dict["CRVAL3"] + numpy.arange(len(comb_dict["Spectrum"])) * comb_dict["CDELT3"]
        lower_xbound = np.nanmin(wave)
        upper_xbound = np.nanmax(wave)

        axes.clear()
        sstackdisp = pipeline_display.SpectrumDisplay()
        sstackdisp.setLabels(r"$\lambda$ [$\mu$m] ", self._process_label(comb_dict["UNIT"]))

        stack_only_view = (self.vs_state == 2)
        # Abscissa of the duplet labels, slightly right of the last wavelength
        if stack_only_view:
            label_x = upper_xbound*1.01211366512
        else:
            label_x = upper_xbound*1.04411725927

        n_colours = len(self.cores)
        v_offset = 0.
        lower_ybound = 0.
        upper_ybound = 0.
        for position, (spectrum, ifu_nb, rec_idx, is_skipped) in enumerate(stack):
            finite = [value for value in spectrum if not np.isnan(value)]
            centre = np.median(finite)
            if np.isnan(centre):
                # No usable median: fall back to fixed values
                med = 0
                mad = 0.1
            else:
                med = centre
                mad = np.median(np.abs(finite - med))

            v_offset += mad*10  # cumulative version
            if lower_ybound == 0:
                lower_ybound = v_offset

            colour = self.cores[position % n_colours]
            if is_skipped:
                # Skipped Spectra in Black
                colour = 'black'

            duplet_label = str(rec_idx) + ':' + str(ifu_nb)
            sstackdisp.overplot(axes, wave, np.add(spectrum, v_offset), colour)
            axes.text(label_x, v_offset + med, duplet_label,
                      verticalalignment='top', horizontalalignment='right',
                      color=colour, fontsize=10)
            upper_ybound = v_offset + mad*10

        if lower_ybound != 0 and upper_ybound != 0:
            axes.set_ybound(lower=lower_ybound*1.15, upper=upper_ybound*1.02)
            axes.set_xbound(lower=lower_xbound*1.00259742743, upper=upper_xbound)

        if stack_only_view:
            axes.set_title("IFUs for '" + comb_dict["FILENAME"] + "' (skipped in black)",
                           fontsize=12, fontweight='semibold')
            axes.set_xlabel(r"$\lambda$ [$\mu$m]   (duplet label syntax, <Recons. Image #> : <IFU #>)")
        else:
            axes.set_title("IFUs (skipped in black)\n", fontsize=12, fontweight='semibold')
            axes.set_xlabel(r"$\lambda$ [$\mu$m]")
        axes.set_ylabel(self._process_label(comb_dict["UNIT"]))

        if not stack:
            axes.set_title("Missing USEDIFUS", fontsize=12, fontweight='semibold')

    def _process_label(self, in_label):
        # If known, 'pretty print' the label
        if (in_label == "erg.s**(-1).cm**(-2).angstrom**(-1)"):
            return "erg sec" + r"$^{-1}$"+"cm" + r"$^{-2}$" + r"$\AA^{-1}$"
        else:
            return in_label
    
    def _data_plot(self):
        #~ # Initial file is the first one
        #~ self.selected_file_idx = 0
        
        # Draw Spectrum
        self._plot_spectrum()

        # Display image
        self._disp_image()
        
        # Display spectra stack
        self._disp_spectra_stack()

    def _add_nodata_subplots(self, figure):
        self.img_plot = figure.add_subplot(1,1,1)

    def _nodata_plot(self):
        # could be moved to reflex library?
        self.img_plot.set_axis_off()
        text_nodata = "Data not found. Input files should contain this" \
                       " type:\n%s" % self.combined_cat
        self.img_plot.text(0.1, 0.6, text_nodata, color='#11557c',
                      fontsize=18, ha='left', va='center', alpha=1.0)
        self.img_plot.tooltip = 'No data found'

#This is the 'main' function
if __name__ == '__main__':
    from reflex_interactive_app import PipelineInteractiveApp

    # Build the interactive application and let it read the command line
    app = PipelineInteractiveApp(enable_init_sop=True)
    app.parse_args()

    # The GUI cannot be shown when the module imports failed
    if not import_success:
        app.setEnableGUI(False)

    if not app.isGUIEnabled():
        app.set_continue_mode()
    else:
        # Hand the plotting functions of this window to the application
        plot_manager = DataPlotterManager()
        app.setPlotManager(plot_manager)
        app.showGUI()

    # Print outputs. This is parsed by the Reflex python actor to
    # get the results. Do not remove
    app.print_outputs()
    sys.exit()

