from __future__ import with_statement
import sys

try:
    import numpy
    import reflex
    from pipeline_product import PipelineProduct
    import pipeline_display
    import reflex_plot_widgets
    import matplotlib.gridspec as gridspec
    import_success = True
except ImportError:
    import_success = False
    print "Error importing modules pyfits, wx, matplotlib, numpy"

def paragraph(text, width=None):
    """ Reflow a text string into a single paragraph.

       text:  string to format; its common leading indentation is removed
       width: None (default) joins all lines with spaces and strips the
              result -- best for tooltips, which wxWidgets wraps by itself;
              otherwise the dedented text is wrapped to at most `width`
              characters per line
    """
    import textwrap
    dedented = textwrap.dedent(text)
    if width is not None:
        return textwrap.fill(dedented, width=width)
    return dedented.replace('\n', ' ').strip()


class DataPlotterManager(object):
    """Plot manager for the kmos_wave_cal interactive window.

    The public hooks (setWindowTitle, setInteractiveParameters, readFitsData,
    addSubplots, plotProductsGraphics, plotWidgets) are invoked by the Reflex
    PipelineInteractiveApp this manager is registered with; the call order is
    defined by that application, not here.

    When a DET_IMG_WAVE product is present, its extension images are shown
    three at a time (the extensions sharing one rotator angle, selected with
    radio buttons), together with scatter plots of the argon and neon line QC
    values. Otherwise a "data not found" message is drawn.
    """
    # static members
    recipe_name = "kmos_wave_cal"
    det_img_wave_cat = "DET_IMG_WAVE"
    lcal_cat = "LCAL"
    # Highest DET_IMG_WAVE extension read (extension 0, the primary HDU, is skipped)
    max_extensions = 18

    def setWindowTitle(self):
        """Return the title of the interactive window."""
        return self.recipe_name + "_interactive"

    def setInteractiveParameters(self):
        """Return the recipe parameters editable from the interactive window."""
        return [
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="order",
                    group="Wavelength Calibration", description="The wavelength polynomial order [0,7]"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="dev_flip",
                    group="Wavelength Calibration", description="TRUE if the wavelengths are ascending on the detector from top to bottom"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="dev_disp",
                    group="Wavelength Calibration", description="The expected dispersion of the wavelength in microns/pixel"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="suppress_extension",
                    group="Wavelength Calibration", description="Suppress arbitrary filename extension"),
        ]

    def readFitsData(self, fitsFiles):
        """Load the recipe products and select the plotting functions.

        fitsFiles: input files; each is wrapped in a PipelineProduct and
                   stored in self.frames under its ``category`` attribute.

        If a DET_IMG_WAVE product is present, extensions 1..max_extensions
        are read, stopping early at the file's last HDU. For each one, the
        method reads the rotator angle, the argon/neon QC values and the
        image. The data plotting functions are then selected. Otherwise the
        no-data plotting functions are selected.
        """
        self.frames = dict()
        for f in fitsFiles:
            self.frames[f.category] = PipelineProduct(f)

        if self.det_img_wave_cat not in self.frames:
            # Set the plotting functions to NODATA ones
            self._has_data = False
            self._add_subplots = self._add_nodata_subplots
            self._plot = self._nodata_plot
            return

        det_img_wave = self.frames[self.det_img_wave_cat]
        self.det_img_wave_name = det_img_wave.fits_file.name

        # Extensions 1..max_extensions, but never beyond the last HDU
        extensions = list(range(1, min(self.max_extensions + 1,
                                       len(det_img_wave.all_hdu))))
        read = det_img_wave.readKeyword

        # Group the extensions by rotator angle
        angles = [read('ESO PRO ROT NAANGLE', ext) for ext in extensions]
        self.angles_list, self.sorted_angles = self._group_by_angle(
            zip(extensions, angles))

        # QC values for the scatter plots; list index == extension - 1
        # because the extensions are contiguous and start at 1
        self.argon_pos_data = []
        self.argon_fwhm_data = []
        self.neon_pos_data = []
        self.neon_fwhm_data = []
        for ext in extensions:
            self.argon_pos_data.append(read('ESO QC ARC AR POS MEAN', ext))
            self.neon_pos_data.append(read('ESO QC ARC NE POS MEAN', ext))
            # Correct the FWHM units by dividing by VSCALE; float() prevents
            # integer division under Python 2
            self.argon_fwhm_data.append(
                float(read('ESO QC ARC AR FWHM MEAN', ext)) /
                read('ESO QC ARC AR VSCALE', ext))
            self.neon_fwhm_data.append(
                float(read('ESO QC ARC NE FWHM MEAN', ext)) /
                read('ESO QC ARC NE VSCALE', ext))

        # Read the images; list index == extension - 1
        self.images = []
        for ext in extensions:
            det_img_wave.readImage(ext)
            self.images.append(det_img_wave.image)

        # Set the plotting functions explicitly, so that a previous no-data
        # selection does not stick
        self._has_data = True
        self._add_subplots = self._add_data_subplots
        self._plot = self._data_plot

    @staticmethod
    def _group_by_angle(ext_angle_pairs):
        """Group extension numbers by rotator angle.

        ext_angle_pairs: iterable of (extension, angle) pairs.
        Returns (dict angle -> list of extensions in input order,
                 list of the distinct angles in ascending order).
        """
        groups = dict()
        for ext, angle in ext_angle_pairs:
            groups.setdefault(angle, []).append(ext)
        return groups, sorted(groups)

    def addSubplots(self, figure):
        """Create the axes on *figure* using the function chosen in readFitsData."""
        self._add_subplots(figure)

    def plotProductsGraphics(self):
        """Draw the plots using the function chosen in readFitsData."""
        self._plot()

    def plotWidgets(self):
        """Return the interactive widgets: the angle-selection radio buttons.

        Returns an empty list when there is nothing to select: in the no-data
        case (where the radio-button axes do not exist) or when no extension
        was read.
        """
        widgets = list()
        if not getattr(self, '_has_data', False) or not self.sorted_angles:
            return widgets

        # Radio button
        self.radiobutton = reflex_plot_widgets.InteractiveRadioButtons(
            self.axradiobutton, self.setRadioCallback, self.sorted_angles, 0,
            title='Angle selection')
        widgets.append(self.radiobutton)
        return widgets

    def setRadioCallback(self, label):
        """Radio-button callback: show the images of the angle named by *label*."""
        self._display_extensions(self.angles_list[float(label)])

    def _display_extensions(self, extensions):
        """Show the images of *extensions* (1-based) in the three image axes.

        Only a group of exactly three extensions is displayed; any other
        count prints a warning and leaves the axes unchanged.
        """
        if len(extensions) != 3:
            print("Wrong number of extensions for this angle")
            return
        for axes, ext in zip(self.img_plot, extensions):
            imgdisp = pipeline_display.ImageDisplay()
            imgdisp.setAspect('equal')
            imgdisp.display(axes, "Extension {0}".format(ext),
                            self._data_plot_get_tooltip(ext),
                            self.images[ext - 1])

    def _add_data_subplots(self, figure):
        """Lay out the data-case axes on a 5x3 grid.

        Rows 0-1 hold three image axes. Rows 3-4 hold the radio buttons
        (column 0) and the argon (row 3) and neon (row 4) plots (columns 1-2).
        """
        gs = gridspec.GridSpec(5, 3)
        self.img_plot = []
        for i in range(3):
            self.img_plot.append(figure.add_subplot(gs[0:2, i]))
        self.axradiobutton = figure.add_subplot(gs[3:, 0])
        self.argon_plot = figure.add_subplot(gs[3, 1:])
        self.neon_plot = figure.add_subplot(gs[4, 1:])

    # Kept for backward compatibility: the data-case layout under its old name
    _add_subplots = _add_data_subplots

    def _data_plot_get_tooltip(self, extension):
        """Return the QC tooltip text of *extension* (1-based), one value per line.

        The FWHM values are those stored by readFitsData, i.e. already divided
        by the matching VSCALE keyword.
        """
        i = extension - 1
        return "\n".join([
            "ESO QC ARC AR POS MEAN : %f" % self.argon_pos_data[i],
            "ESO QC ARC AR FWHM MEAN : %f" % self.argon_fwhm_data[i],
            "ESO QC ARC NE POS MEAN : %f" % self.neon_pos_data[i],
            "ESO QC ARC NE FWHM MEAN : %f" % self.neon_fwhm_data[i],
        ])

    def _data_plot(self):
        """Show the images of the smallest angle and the argon/neon scatter plots."""
        if self.sorted_angles:
            self._display_extensions(self.angles_list[self.sorted_angles[0]])
        else:
            print("No extensions found in the DET_IMG_WAVE product")

        # x = extension number (1 .. number of extensions read)
        size = len(self.argon_pos_data)
        x = numpy.linspace(1, size, num=size)

        # Plot Argon and Neon plots
        scadsp = pipeline_display.ScatterDisplay()
        scadsp.display(self.argon_plot, "Argon Lines",
                       "Mean Argon line positional offset, with error bars equal to the FWHM",
                       x, self.argon_pos_data, self.argon_fwhm_data)

        scadsp = pipeline_display.ScatterDisplay()
        scadsp.setLabels('File Extensions 1 -> 18', 'Line Pos. (pix)')
        scadsp.display(self.neon_plot, "Neon Lines",
                       "Mean Neon line positional offset, with error bars equal to the FWHM",
                       x, self.neon_pos_data, self.neon_fwhm_data)

    def _add_nodata_subplots(self, figure):
        """Create the single axes used for the no-data message."""
        self.img_plot = figure.add_subplot(1, 1, 1)

    def _nodata_plot(self):
        """Draw a message saying the DET_IMG_WAVE product was not found."""
        # could be moved to reflex library?
        self.img_plot.set_axis_off()
        text_nodata = "Data not found. Input files should contain this" \
                      " type:\n%s" % self.det_img_wave_cat
        self.img_plot.text(0.1, 0.6, text_nodata, color='#11557c',
                           fontsize=18, ha='left', va='center', alpha=1.0)
        self.img_plot.tooltip = 'No data found'

# Script entry point
if __name__ == '__main__':
    from reflex_interactive_app import PipelineInteractiveApp

    # Build the interactive application and read its command-line inputs
    interactive_app = PipelineInteractiveApp(enable_init_sop=True)
    interactive_app.parse_args()

    # Without the plotting modules imported at the top there can be no GUI
    if not import_success:
        interactive_app.setEnableGUI(False)

    if not interactive_app.isGUIEnabled():
        # No window: switch the application to continue mode
        interactive_app.set_continue_mode()
    else:
        # Register this recipe's plot manager, then open the window
        interactive_app.setPlotManager(DataPlotterManager())
        interactive_app.showGUI()

    # The Reflex python actor parses these outputs to get the results.
    # Do not remove.
    interactive_app.print_outputs()
    sys.exit()
