from __future__ import with_statement
import sys

try:
    import numpy
    import re
    import reflex
    from pipeline_product import PipelineProduct
    import pipeline_display
    import reflex_plot_widgets
    import matplotlib.gridspec as gridspec
    import_success = True
except ImportError:
    import_success = False
    print "Error importing modules pyfits, wx, matplotlib, numpy"

def paragraph(text, width=None):
    """ format a (possibly indented, multi-line) text string as one paragraph
       text:  text to format; common leading indentation is removed
       width: None  -> newlines become spaces and the result is stripped,
                       giving a single line (preferred for tooltips, which
                       wxWidgets wraps by itself)
              int   -> the dedented text is wrapped to lines of at most
                       `width` characters
    """
    import textwrap
    dedented = textwrap.dedent(text)
    if width is not None:
        return textwrap.fill(dedented, width=width)
    single_line = dedented.replace('\n', ' ')
    return single_line.strip()


class DataPlotterManager(object):
    """Plot manager for the kmos_std_star interactive Reflex window.

    An instance is handed to the Reflex interactive application through
    setPlotManager() (see the __main__ section). The application presumably
    calls setWindowTitle, setInteractiveParameters, readFitsData,
    addSubplots, plotProductsGraphics and plotWidgets (not visible here).

    With STD_IMAGE, STAR_SPEC and TELLURIC products present and at least one
    standard star image in STD_IMAGE, the window shows a radio-button star
    selector, the collapsed image of the selected star and its extracted
    spectrum (plus noise when available). Otherwise a "Data not found"
    message is shown.
    """
    # static members
    recipe_name = "kmos_std_star"
    star_spec_cat = "STAR_SPEC"
    telluric_cat = "TELLURIC"
    std_image_cat = "STD_IMAGE"
    # spectra (flux and noise) are divided by this factor before plotting
    y_scalefactor = 1000

    def setWindowTitle(self):
        """Return the title of the interactive window."""
        return self.recipe_name+"_interactive"

    def setInteractiveParameters(self):
        """Return the recipe parameters editable from the window, grouped
        by processing step (reconstruction, extraction, combination, star)."""
        return [
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="imethod",
                    group="Recons.", description="Interpolation Method (NN, lwNN, swNN, MS, CS)"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="xcal_interpolation",
                    group="Recons.", description="Interpolate xcal between rotator angles"),

            reflex.RecipeParameter(recipe=self.recipe_name, displayName="mask_method",
                    group="Extr.", description="Extraction Method (optimal/integrated)"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="centre",
                    group="Extr.", description="Centre (integrated only)"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="radius",
                    group="Extr.", description="Radius (integrated only)"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="fmethod",
                    group="Extr.", description="Fitting Method (gauss, moffat)"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="neighborhoodRange",
                    group="Extr.", description="Range for Neighbors"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="flux",
                    group="Extr.", description="Apply flux conservation"),

            reflex.RecipeParameter(recipe=self.recipe_name, displayName="cmethod",
                    group="Comb.", description="Combination Method (average, median, sum, min_max, ksigma)"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="cpos_rej",
                    group="Comb.", description="The positive rejection threshold"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="cneg_rej",
                    group="Comb.", description="The negative rejection threshold"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="citer",
                    group="Comb.", description="The number of iterations"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="cmax",
                    group="Comb.", description="The number of maximum pixel values"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="cmin",
                    group="Comb.", description="The number of minimum pixel values"),

            reflex.RecipeParameter(recipe=self.recipe_name, displayName="startype",
                    group="Star", description="The star spectral type (O, B, A, F, G)"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="magnitude",
                    group="Star", description="Star Magnitude"),
        ]

    def readFitsData(self, fitsFiles):
        """Load the recipe products and select the plotting functions.

        fitsFiles: iterable of input files, each with a `category`
                   attribute; each one is wrapped in a PipelineProduct.

        The data display is selected only when STD_IMAGE, STAR_SPEC and
        TELLURIC are all present AND at least one standard star image was
        found in STD_IMAGE; otherwise the "Data not found" display is used.
        """
        self.frames = dict()
        for f in fitsFiles:
            self.frames[f.category] = PipelineProduct(f)

        # Always (re)initialised so plotWidgets() can tell whether any star
        # was found, also when the required products are missing.
        self.star_data_extnames = []

        required = (self.std_image_cat, self.star_spec_cat, self.telluric_cat)
        if all(cat in self.frames for cat in required):
            # TELLURIC must be present, but its content is not displayed.
            self._read_star_data(self.frames[self.std_image_cat],
                                 self.frames[self.star_spec_cat])

        # Assign explicitly in both branches: a plain self-assignment would
        # keep a previously set no-data function if this method ran twice.
        if self.star_data_extnames:
            self._add_subplots = self._add_data_subplots
            self._plot = self._data_plot
        else:
            self._add_subplots = self._add_nodata_subplots
            self._plot = self._nodata_plot

    def _read_star_data(self, std_image, star_spec):
        """Read QC values, images and spectra of every standard star.

        A standard star is any STD_IMAGE extension with NAXIS == 2. Its
        EXTNAME (like IFU.3.DATA) keys every per-star dictionary and selects
        the matching STAR_SPEC extension; the noise spectrum is read from
        the extension with DATA replaced by NOISE (like IFU.3.NOISE).
        """
        # Per-star containers, keyed by extension name
        self.star_data_extnames = []
        self.qc_std_trace = dict()
        self.qc_spat_res = dict()
        self.image_std = dict()
        self.image_std_avg = dict()
        self.image_std_stdev = dict()
        self.qc_thruput = dict()
        self.qc_zpoint = dict()
        self.crpix1 = dict()
        self.crval1 = dict()
        self.cdelt1 = dict()
        self.spec_data = dict()
        self.spec_noise = dict()

        # Global QC values from the STAR_SPEC primary header
        primary_header = star_spec.all_hdu[0].header
        self.qc_nr_std_stars = primary_header['ESO QC NR STD STARS']
        self.qc_thruput_mean = primary_header['ESO QC THRUPUT MEAN']
        self.qc_thruput_sdv = primary_header['ESO QC THRUPUT SDV']

        for std_image_ext in std_image.all_hdu:
            # NAXIS is 2 if there is an image, 0 otherwise
            if std_image_ext.header['NAXIS'] != 2:
                continue

            # extname is like IFU.3.DATA
            extname = std_image_ext.header['EXTNAME']
            image = std_image_ext.data
            self.star_data_extnames.append(extname)
            self.qc_std_trace[extname] = std_image_ext.header['ESO QC STD TRACE']
            self.qc_spat_res[extname] = std_image_ext.header['ESO QC SPAT RES']
            self.image_std[extname] = image
            self.image_std_avg[extname] = numpy.average(image)
            self.image_std_stdev[extname] = numpy.std(image)

            # Matching spectrum in STAR_SPEC
            spec_ext = star_spec.all_hdu[extname]
            self.qc_thruput[extname] = spec_ext.header['ESO QC THRUPUT']
            self.qc_zpoint[extname] = spec_ext.header['ESO QC ZPOINT']
            self.crpix1[extname] = spec_ext.header['CRPIX1']
            self.crval1[extname] = spec_ext.header['CRVAL1']
            self.cdelt1[extname] = spec_ext.header['CDELT1']
            self.spec_data[extname] = spec_ext.data

            # The noise extension may be absent. Assuming all_hdu is a
            # pyfits/astropy HDUList, a missing name raises KeyError; store
            # None so the spectrum is plotted without noise.
            noise_extname = re.sub("DATA", "NOISE", extname)
            try:
                self.spec_noise[extname] = star_spec.all_hdu[noise_extname].data
            except KeyError:
                self.spec_noise[extname] = None

    def addSubplots(self, figure):
        """Create the axes of the layout selected by readFitsData."""
        self._add_subplots(figure)

    def plotProductsGraphics(self):
        """Draw the products with the function selected by readFitsData."""
        self._plot()

    def plotWidgets(self) :
        """Create the standard-star selection radio buttons.

        Returns the list of created widgets. The list is empty when no
        standard star was loaded: the no-data layout has no radio-button
        axes.
        """
        widgets = list()
        if not getattr(self, 'star_data_extnames', None):
            return widgets
        self.radiobutton = reflex_plot_widgets.InteractiveRadioButtons(
                self.axradiobutton, self.setRadioCallback,
                self.star_data_extnames, 0, title='Standard star selection')
        widgets.append(self.radiobutton)
        return widgets

    def setRadioCallback(self, label) :
        """Redraw image and spectrum for the star selected in the radio
        buttons (label is its STD_IMAGE extension name)."""
        self._plot_star(label, clear_spectrum=True)

    def _add_data_subplots(self, figure):
        """2x2 grid: radio buttons top-left, image top-right, spectrum
        across the bottom row."""
        gs = gridspec.GridSpec(2, 2)
        self.axradiobutton =    figure.add_subplot(gs[0,0])
        self.img_plot =         figure.add_subplot(gs[0,1])
        self.spec_plot =        figure.add_subplot(gs[1,:])

    # Default layout before readFitsData has chosen one (kept for
    # compatibility with callers of the original _add_subplots method).
    _add_subplots = _add_data_subplots

    def _data_plot_get_tooltip(self, extname):
        """Return the QC tooltip text for the star `extname`."""
        tooltip = " \
            ESO QC NR STD STARS : %f \n \
            ESO QC THRUPUT MEAN : %f \n \
            ESO QC THRUPUT SDV  : %f \n \
            ESO QC STD TRACE    : %f \n \
            ESO QC SPAT RES     : %f \n \
            ESO QC THRUPUT      : %f \n \
            ESO QC ZPOINT       : %f \
            " % (self.qc_nr_std_stars, self.qc_thruput_mean, self.qc_thruput_sdv, self.qc_std_trace[extname], self.qc_spat_res[extname], self.qc_thruput[extname], self.qc_zpoint[extname])
        return tooltip

    def _wavelength_axis(self, extname):
        """Return the wavelength of each pixel of the STAR_SPEC spectrum.

        Uses the FITS linear relation
        CRVAL1 + (p - CRPIX1) * CDELT1 with 1-based pixel p.
        """
        pix = numpy.arange(len(self.spec_data[extname]))
        return (self.crval1[extname]
                + (pix + 1 - self.crpix1[extname]) * self.cdelt1[extname])

    def _plot_star(self, extname, clear_spectrum=False):
        """Draw the image and spectrum panels for one standard star.

        extname:        STD_IMAGE extension name of the star
        clear_spectrum: clear the spectrum axes first (used when redrawing
                        after a new radio-button selection)
        """
        tooltip = self._data_plot_get_tooltip(extname)

        # Image, displayed between mean - 1 sigma and mean + 2 sigma
        imgdisp = pipeline_display.ImageDisplay()
        imgdisp.setAspect('equal')
        avg = self.image_std_avg[extname]
        stdev = self.image_std_stdev[extname]
        imgdisp.z_lim = (avg - stdev, avg + 2 * stdev)
        # NOTE: please alert if it should be LaTeX "\approx" instead of "\simeq".
        title = ("Median collapsed cube (STD_IMAGE)\nFWHM " + r"$\simeq$ "
                 + str(numpy.round(self.qc_spat_res[extname], 2)) + "  [arcsec]")
        imgdisp.display(self.img_plot, title, tooltip, self.image_std[extname])
        self.img_plot.set_xlabel("pixels")
        self.img_plot.set_ylabel("pixels")

        # Spectrum; float() guarantees true division under Python 2
        wave = self._wavelength_axis(extname)
        scale = float(self.y_scalefactor)
        specdisp = pipeline_display.SpectrumDisplay()
        if clear_spectrum:
            self.spec_plot.clear()
        specdisp.setLabels(r"$\lambda$[" + r"$\mu$m]",
                           "Flux (ADU)   [x" + str(self.y_scalefactor) + "]")
        specdisp.display(self.spec_plot, "Extracted Standard Star Spectrum (STAR_SPEC)",
                         tooltip, wave, self.spec_data[extname] / scale)

        noise = self.spec_noise[extname]
        if noise is not None:
            # Overplot the noise spectrum
            specdisp.overplot(self.spec_plot, wave, noise / scale, 'red')
            self.spec_plot.legend(('Flux', 'Noise'))

    def _data_plot(self):
        """Initial data display: the first standard star found."""
        self._plot_star(self.star_data_extnames[0])

    def _add_nodata_subplots(self, figure):
        """Single axes used to show the "Data not found" message."""
        self.img_plot = figure.add_subplot(1,1,1)

    def _nodata_plot(self):
        """Show the "Data not found" message."""
        # could be moved to reflex library?
        self.img_plot.set_axis_off()
        text_nodata = "Data not found"
        self.img_plot.text(0.1, 0.6, text_nodata, color='#11557c', fontsize=18, ha='left', va='center', alpha=1.0)
        self.img_plot.tooltip = text_nodata

#This is the 'main' function
if __name__ == '__main__':
    # Script entry point, run by the Reflex python actor.
    from reflex_interactive_app import PipelineInteractiveApp

    interactive_app = PipelineInteractiveApp(enable_init_sop=True)
    interactive_app.parse_args()  # inputs come from the command line

    # Without the plotting modules the window cannot be shown.
    if not import_success:
        interactive_app.setEnableGUI(False)

    if not interactive_app.isGUIEnabled():
        interactive_app.set_continue_mode()
    else:
        # Install this recipe's plot manager and open the window.
        interactive_app.setPlotManager(DataPlotterManager())
        interactive_app.showGUI()

    # The Reflex python actor parses this output to get the results.
    # Do not remove.
    interactive_app.print_outputs()
    sys.exit()
