from __future__ import with_statement
from __future__ import absolute_import
from __future__ import print_function
import sys


try:
    import numpy
    import reflex
    from pipeline_product import PipelineProduct
    import pipeline_display
    import reflex_plot_widgets
    from matplotlib import gridspec, pylab, pyplot
    import pdb  # for debugging
    import_success = True

except ImportError:
    import_success = False
    print("Error importing modules pyfits, wx, matplotlib, numpy")

# Median absolute deviation function; used to scale the images
def MAD(x):
    """Return the median absolute deviation of the values in x.

    x is converted to a numpy array first, so any array-like input is
    accepted; the statistic is computed over all elements.
    """
    values = numpy.array(x)
    centre = numpy.median(values)
    deviations = numpy.abs(values - centre)
    return numpy.median(deviations)


def paragraph(text, width=None):
    """ turn an indented text block into a paragraph
       text:  text to format; common leading whitespace is removed
       width: None joins all lines into a single stripped line (suited to
              tooltips, which wxWidgets wraps by itself); otherwise the
              text is re-wrapped to this many columns
    """
    import textwrap
    dedented = textwrap.dedent(text)
    if width is not None:
        return textwrap.fill(dedented, width=width)
    single_line = dedented.replace('\n', ' ')
    return single_line.strip()


class DataPlotterManager(object):
    """
    This class must be added to the PipelineInteractiveApp with setPlotManager
    It must have following member functions which will be called by the app:
     - setInteractiveParameters(self)
     - readFitsData(self, fitsFiles):
     - addSubplots(self, figure):
     - plotProductsGraphics(self)
    Following members are optional:
     - setWindowHelp(self)
     - setWindowTitle(self)
    """

    # static members
    recipe_name = "vimos_ima_bias"
    img_cat = "MASTER_BIAS"
    diff_img_cat = "DIFFIMG_BIAS"
    diff_stats_cat = "DIFFIMG_STATS_BIAS"

    def setWindowTitle(self):
        """Return the window title: the recipe name followed by "_interactive"."""
        title = self.recipe_name
        title += "_interactive"
        return title

    def setInteractiveParameters(self):
        """
        Return the recipe parameters that can be edited in the window.

        One reflex.RecipeParameter is built per editable parameter, in the
        order listed below.  A parameter must also be present in the in_sop
        port, otherwise it won't appear in the window.  The description of
        each parameter is used as its tooltip.
        """
        # (displayName, tooltip) pairs, in display order
        editable = (
            ("combtype", "Combination algorithm. <median | mean> [median]"),
            ("xrej", "True if using extra rejection cycle. [TRUE]"),
            ("thresh", "Rejection threshold in sigma above background. [5.0]"),
            ("ncells", "Number of cells per data channel to evaluate stats. < 1 | 2 | 4 | 8 | 16 | 32 | 64 > [64]"),
        )
        return [
            reflex.RecipeParameter(recipe=self.recipe_name, displayName=name,
                                   group="vimos_ima_bias", description=tooltip)
            for name, tooltip in editable
        ]

    def readFitsData(self, fitsFiles):
        """
        Read and organise the FITS products written by the recipe.

        fitsFiles: list of reflex.FitsFiles.  Each file is wrapped in a
                   PipelineProduct and stored in self.frames under its
                   category; only one file per category is kept (a later
                   file replaces an earlier one of the same category).

        If a product of category img_cat is present, the image of every
        extension is loaded into self.bias_images and, when available, the
        difference images (self.diff_images) and the statistics tables
        (self.stats_table) are loaded too.  The radio button options and
        the plotting callbacks are then set up for the data display.
        Without such a product the NODATA callbacks are installed instead.
        """
        # organize the files into a dictionary, here we assume we only have
        # one file per category if there are more, one must use a
        # dictionary of lists
        self.frames = dict()
        for f in fitsFiles:
            self.frames[f.category] = PipelineProduct(f)

        # we only have two states, we have data or we don't
        self.img_found = self.img_cat in self.frames
        if not self.img_found:
            # Set the plotting functions to NODATA ones
            self._add_subplots = self._add_nodata_subplots
            self._plot = self._nodata_plot
            return

        bias_img = self.frames[self.img_cat]
        self.n_extn = len(bias_img.hdulist())-1  # number of extensions
        self.bias_img_name = bias_img.fits_file.name
        self.bias_img_hdu = bias_img.all_hdu
        extensions = range(1, self.n_extn + 1)

        # Read the bias images (readImage leaves its result in .image)
        self.bias_images = []
        for extn in extensions:
            bias_img.readImage(extn)
            self.bias_images.append(bias_img.image)

        # Read the difference images
        self.diff_img_found = self.diff_img_cat in self.frames
        if self.diff_img_found:
            diff_img = self.frames[self.diff_img_cat]
            self.diff_img_name = diff_img.fits_file.name
            # Keep the whole HDU list (as for the bias image): the plotting
            # code looks up QC keywords per extension with [i+1]
            self.diff_img_hdu = diff_img.all_hdu
            self.diff_images = []
            for extn in extensions:
                diff_img.readImage(extn)
                self.diff_images.append(diff_img.image)

        # Read in statistics FITS table
        self.stats_found = self.diff_stats_cat in self.frames
        if self.stats_found:
            self.stats = self.frames[self.diff_stats_cat]
            # stats_table is a list of FITS record arrays, one for each extension
            # access data by field name: stats_table[i_ext]['COLNAME']
            # see help at https://pythonhosted.org/pyfits/users_guide/users_table.html
            self.stats_table = [self.stats.all_hdu[extn].data
                                for extn in extensions]

        # Select the data plotting functions.  A NODATA override stored on
        # the instance by an earlier call must be dropped so that the
        # class-level _add_subplots (the data layout) is used again.
        self.__dict__.pop('_add_subplots', None)
        self._plot = self._data_plot

        # Define radio button options as a dict; this is handy for assessing which option the user selected
        if self.diff_img_found and self.stats_found:
            self.radio_button_opts = {'Master BIAS Image':0,'Histogram of BIAS Image':1,'Diff Image (BIAS-REF)':2,'Stats on Diff Image':3}
        else:
            self.radio_button_opts = {'Master BIAS Image':0,'Histogram of BIAS Image':1}

        # Set the initial radio button selection: the label with value 0
        self.radio_button_label = min(self.radio_button_opts,
                                      key=self.radio_button_opts.get)

    def addSubplots(self, figure):
        """Create the subplots on figure using the currently selected layout function."""
        create_layout = self._add_subplots
        create_layout(figure)

    def plotProductsGraphics(self):
        """Draw the products using the currently selected plotting function."""
        draw = self._plot
        draw()

    def plotWidgets(self):
        """
        Create the interactive widgets and return them as a list.

        The only widget is a radio button group, drawn on
        self.axradiobutton, with one entry per key of
        self.radio_button_opts (ordered by option value) and the current
        self.radio_button_label pre-selected.  The created widget is also
        kept in self.radiobutton.

        An empty list is returned when the radio button axes or options
        have not been set up, which is the case when no data were found.
        """
        radio_axes = getattr(self, 'axradiobutton', None)
        options = getattr(self, 'radio_button_opts', None)
        if radio_axes is None or not options:
            # NODATA display: there is nothing to choose between
            return []

        # Labels in the order of their option value (0, 1, 2, ...)
        labels = sorted(options, key=options.get)
        self.radiobutton = reflex_plot_widgets.InteractiveRadioButtons(
            radio_axes,
            self.setRadioCallback,
            labels,
            options.get(self.radio_button_label),
            title='Select item to display')
        return [self.radiobutton]

    def setRadioCallback(self, label):
        """Remember the newly selected radio button label and redraw the plots.

        Nothing happens when label equals the current selection.
        """
        if label == self.radio_button_label:
            return
        self.radio_button_label = label
        self._plot()

    def _add_subplots(self, figure):
        """
        Create the data layout on figure: a 2x2 block of chip panels
        (self.img_plot, a list of four axes) below the radio button axes
        (self.axradiobutton).
        """
        gs = gridspec.GridSpec(7, 2)
        # Row 0 is for the radio buttons, rows 1-3 for the upper pair of
        # panels and rows 4-6 for the lower pair (the grid has 7 rows)
        panel_cells = [gs[1:4, 0], gs[1:4, 1], gs[4:7, 0], gs[4:7, 1]]
        self.img_plot = [figure.add_subplot(cell) for cell in panel_cells]
        self.axradiobutton = figure.add_subplot(gs[0, 0])
            
    def _data_plot(self):

        #clock_pattern = self.bias_img_hdu[0].header['HIERARCH ESO DET READ CLOCK']
        #binx = self.bias_img_hdu[0].header['HIERARCH ESO DET WIN1 BINX']
        #biny = self.bias_img_hdu[0].header['HIERARCH ESO DET WIN1 BINY']

        for i in range(self.n_extn):
            chip_name = self.bias_img_hdu[i+1].header['EXTNAME']

            imgdisp = pipeline_display.ImageDisplay()
            imgdisp.setAspect('equal')
            if ( (self.radio_button_opts[self.radio_button_label] == 0) or
                 (self.radio_button_opts[self.radio_button_label] == 2)):  # show master bias image or diff image
                
                self.img_plot[i].cla()

                if ( (i==0) or (i==1) ): 
                    pylab.setp(self.img_plot[i].get_xticklabels(), visible = False)
                    imgdisp.setLabels('','Y')
                else:
                    pylab.setp(self.img_plot[i].get_xticklabels(), visible = True)
                    imgdisp.setLabels('X','Y')

                if (self.radio_button_opts[self.radio_button_label] == 0):
                    title = "Master BIAS Chip:{}".format(chip_name)
                    temp_image = self.bias_images[i]
                    tool_tip = "Bias"
                    # Set z-limit using 1 iteration of sigma clipping using MED and MAD
                    temp_image = temp_image[numpy.isfinite(temp_image)]
                    try:
                        med = self.bias_img_hdu[i+1].header['HIERARCH ESO QC BIASMED']
                    except:
                        med = numpy.median(temp_image)
                    try:
                        sig = self.bias_img_hdu[i+1].header['HIERARCH ESO QC BIASRMS']
                    except:
                        sig = 1.48 * MAD(temp_image)

                    temp_image = temp_image[numpy.abs(temp_image-med) < 3*sig]
                    new_med = numpy.median(temp_image)
                    new_sig = 1.48 * MAD(temp_image)
                    
                    # Set limits to median-1sigma, median+3sigma for biases
                    imgdisp.z_lim = new_med-new_sig, new_med+3*new_sig
                    imgdisp.display(self.img_plot[i], title, tool_tip, self.bias_images[i])

                elif (self.radio_button_opts[self.radio_button_label] == 2):
                    title = "Diff Image Chip: {}".format(chip_name)
                    temp_image = self.diff_images[i]
                    tool_tip = "BIAS - REF"
                    temp_image = temp_image[numpy.isfinite(temp_image)]
                    try:
                        med = self.diff_img_hdu[i+1].header['HIERARCH ESO QC BIAS_DIFFMED']
                    except:
                        med = numpy.median(temp_image)
                    try:
                        sig = self.diff_img_hdu[i+1].header['HIERARCH ESO QC BIAS_DIFFRMS']
                    except:
                        sig = 1.48 * MAD(temp_image)

                    temp_image = temp_image[numpy.abs(temp_image-med) < 3*sig]
                    new_med = numpy.median(temp_image)
                    new_sig = 1.48 * MAD(temp_image)
                    
                    # Set limits to median-3sigma, median+3sigma for diff image
                    imgdisp.z_lim = new_med-3*new_sig, new_med+3*new_sig
                    imgdisp.display(self.img_plot[i], title, tool_tip, self.diff_images[i])

            elif (self.radio_button_opts[self.radio_button_label] == 1):  # show histogram
                self.img_plot[i].cla()
                self.img_plot[i].set_title("Histogram  Chip: {}".format(chip_name), fontsize=12, fontweight='semibold')
                pylab.setp(self.img_plot[i].get_xticklabels(), visible = True)

                temp_image = self.bias_images[i]
                x = temp_image[numpy.isfinite(temp_image)]

                try:
                    med = self.bias_img_hdu[i+1].header['HIERARCH ESO QC BIASMED']
                except:
                    med = numpy.median(x)
                try:
                    sig = self.bias_img_hdu[i+1].header['HIERARCH ESO QC BIASRMS']
                except:
                    sig = 1.48 * MAD(x)

                n, bins, patches = self.img_plot[i].hist(x,normed=True,range=(med-5.0*sig,med+5.0*sig))
                self.img_plot[i].axis('tight')  # change aspect ratio to show all data, has to be placed after .hist()
                if ( (i==0) or (i==1)):
                    self.img_plot[i].set_xlabel('')
                if ( (i==2) or (i==3)):
                    self.img_plot[i].set_xlabel('Pixel Value [ADU]')
                self.img_plot[i].set_ylabel('Normalised PDF')

                self.img_plot[i].tooltip = '10 bins over a range Median'+u"\u00B1"+'5*sigma'
                self.img_plot[i].text(0.05,0.9,'Med:  {:8.2f}'.format(med), transform=self.img_plot[i].transAxes)
                self.img_plot[i].text(0.05,0.8,'Mean: {:8.2f}'.format(numpy.mean(x)), transform=self.img_plot[i].transAxes)
                self.img_plot[i].text(0.05,0.7,'MAD:  {:8.2f}'.format(sig/1.48), transform=self.img_plot[i].transAxes)
                self.img_plot[i].text(0.05,0.6,'RMS:  {:8.2f}'.format(numpy.std(x)), transform=self.img_plot[i].transAxes)

            elif (self.radio_button_opts[self.radio_button_label] == 3):  # show statistics
                self.img_plot[i].cla()
                x = numpy.linspace(1,len(self.stats_table[0]['xmin']), num = len(self.stats_table[0]['xmin']))
                y = self.stats_table[i]['median']
                err = 1.483*(self.stats_table[i]['mad'])

                title = "Stats {} ".format(chip_name)

                tool_tip = ("X axis: Index of small cell/box on chip \nY axis: Median value of (BIAS-REF) pixels in box \n\t (Err bars = 1.48*Median Abs Deviation)")

                scatter_display = pipeline_display.ScatterDisplay()
                self.img_plot[i].axis('tight')  # change aspect ratio to show all data
                if (min(y) != max(y)):
                    self.img_plot[i].set_ylim(min(y),max(y))
                else:
                    self.img_plot[i].set_ylim(min(y)-2*max(err),min(y)+2*max(err))
                if (max(x) != max(x)):
                    self.img_plot[i].set_xlim(min(x),max(x))
                else:
                     self.img_plot[i].set_xlim(min(x)-1,max(x)+1)

                if ( (i==0) or (i==1) ): 
                    pylab.setp(self.img_plot[i].get_xticklabels(), visible = False)
                    xtitle = ""
                else:
                    pylab.setp(self.img_plot[i].get_xticklabels(), visible = True)
                    xtitle = "Index of Cell on Chip"
                    
                scatter_display.setLabels(xtitle,"BIAS - REF")
                scatter_display.display(self.img_plot[i],
                                        title, tool_tip, 
                                        x, y, err)
                #pdb.set_trace()
                
    def _add_nodata_subplots(self, figure):
        """Create the single subplot used to report that no data were found."""
        message_axes = figure.add_subplot(1, 1, 1)
        self.img_plot = message_axes

    def _nodata_plot(self):
        """Show a message naming the missing product category instead of a plot."""
        # could be moved to reflex library?
        axes = self.img_plot
        axes.set_axis_off()
        message = "Data not found. Input files should contain this type:\n%s" % self.img_cat
        axes.text(0.1, 0.6, message,
                  color='#11557c', fontsize=18,
                  ha='left', va='center', alpha=1.0)
        axes.tooltip = 'No data found'


    def setWindowHelp(self):
        """Return the help text for the interactive window."""
        return (
            "\n"
            "This is an interactive window which help asses the quality of the execution of a recipe."
            "\n"
        )


#This is the 'main' function
if __name__ == '__main__':
    from reflex_interactive_app import PipelineInteractiveApp

    # Build the interactive application and let it read its inputs from
    # the command line
    interactive_app = PipelineInteractiveApp()
    interactive_app.parse_args()

    # The GUI cannot be shown if the plotting modules failed to import
    if not import_success:
        interactive_app.setEnableGUI(False)

    if not interactive_app.isGUIEnabled():
        interactive_app.set_continue_mode()
    else:
        # Open the interactive window with the plot manager of this recipe
        interactive_app.setPlotManager(DataPlotterManager())
        interactive_app.showGUI()

    # Print outputs. This is parsed by the Reflex python actor to
    # get the results. Do not remove
    interactive_app.print_outputs()
    sys.exit()
