package tim.prune.function.estimate;

import java.awt.BorderLayout;
import java.awt.Component;
import java.awt.FlowLayout;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.AdjustmentEvent;
import java.awt.event.AdjustmentListener;
import java.awt.event.KeyAdapter;
import java.awt.event.KeyEvent;
import java.util.ArrayList;

import javax.swing.BorderFactory;
import javax.swing.Box;
import javax.swing.BoxLayout;
import javax.swing.JButton;
import javax.swing.JDialog;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JScrollBar;
import javax.swing.SwingUtilities;

import tim.prune.App;
import tim.prune.GenericFunction;
import tim.prune.I18nManager;
import tim.prune.config.Config;
import tim.prune.data.DataPoint;
import tim.prune.data.Distance;
import tim.prune.data.RangeStats;
import tim.prune.data.Track;
import tim.prune.data.Unit;
import tim.prune.data.UnitSetLibrary;
import tim.prune.function.estimate.jama.Matrix;
import tim.prune.gui.ProgressDialog;

/**
 * Function to learn the estimation parameters from the current track.
 * The track is split into samples. For each usable sample, the moving distance (km) and the
 * gentle/steep climb and descent (m) form one row of a linear system whose observations are
 * the moving times in minutes. The system is solved repeatedly, removing the sample with the
 * largest relative error each time. The result is then offered to the user for combination
 * with the currently configured parameters.
 */
public class LearnParameters extends GenericFunction implements Runnable
{
	/** Number of samples the track is split into */
	private static final int NUM_SAMPLES = 30;
	/** Minimum number of usable samples required before attempting a solution */
	private static final int MIN_USABLE_SAMPLES = 10;
	/** Sample elimination stops once the sample list has shrunk to this size */
	private static final int TARGET_SAMPLES = 25;

	/** Progress dialog */
	ProgressDialog _progress = null;
	/** Results dialog */
	JDialog _dialog = null;
	/** Calculated parameters */
	private ParametersPanel _calculatedParamPanel = null;
	private EstimationParameters _calculatedParams = null;
	/** Slider for weighted average */
	private JScrollBar _weightSlider = null;
	/** Label to describe position of slider */
	private JLabel _sliderDescLabel = null;
	/** Combined parameters */
	private ParametersPanel _combinedParamPanel = null;
	/** Combine button */
	private JButton _combineButton = null;


	/**
	 * Inner class used to hold the results of the matrix solving
	 */
	static class MatrixResults
	{
		/** Estimation parameters derived from the solved matrix */
		public EstimationParameters _parameters = null;
		/**
		 * Average error of the estimates, taken from the relative errors (estimate - observed) / observed.
		 * NOTE(review): named as a percentage, but reduceSamples stores the value straight from
		 * Matrix.getAverageAbsValue() without scaling - confirm how ParametersPanel displays it
		 */
		public double _averageErrorPc = 0.0;
	}


	/**
	 * Constructor
	 * @param inApp App object
	 */
	public LearnParameters(App inApp)
	{
		super(inApp);
	}

	/** @return key for function name */
	public String getNameKey() {
		return "function.learnestimationparams";
	}

	/**
	 * Begin the function by showing the progress dialog and starting the calculation thread
	 */
	public void begin()
	{
		// Show progress bar
		if (_progress == null) {
			_progress = new ProgressDialog(_parentFrame, getNameKey());
		}
		_progress.show();
		// Start new thread for the calculations
		new Thread(this).start();
	}

	/**
	 * Run method in separate thread: collect samples, solve the matrices and
	 * hand the results to the event dispatch thread for display
	 */
	public void run()
	{
		_progress.setMaximum(100);
		// Go through the track and collect the range stats for each sample
		ArrayList<RangeStats> statsList = new ArrayList<RangeStats>(NUM_SAMPLES);
		Track track = _app.getTrackInfo().getTrack();
		final int numPoints = track.getNumPoints();
		final int sampleSize = numPoints / NUM_SAMPLES;
		int prevStartIndex = -1;
		for (int i=0; i<NUM_SAMPLES; i++)
		{
			int startIndex = i * sampleSize;
			RangeStats stats = getRangeStats(track, startIndex, startIndex + sampleSize, prevStartIndex);
			if (stats != null && stats.getMovingDistanceKilometres() > 1.0
				&& !stats.getTimestampsIncomplete() && !stats.getTimestampsOutOfSequence()
				&& stats.getTotalDurationInSeconds() > 100
				&& stats.getStartIndex() > prevStartIndex)
			{
				statsList.add(stats);
				prevStartIndex = stats.getStartIndex();
			}
			// Sample collection fills the first 20% of the progress bar (reduction continues from 20)
			_progress.setValue(20 * (i + 1) / NUM_SAMPLES);
		}

		// Check if we've got enough samples
		if (statsList.size() < MIN_USABLE_SAMPLES)
		{
			_progress.dispose();
			// Show error message, not enough samples
			_app.showErrorMessage(getNameKey(), "error.learnestimationparams.failed");
			return;
		}
		// Loop around, solving the matrices and removing the highest-error sample
		final MatrixResults results = reduceSamples(statsList);
		_progress.dispose();
		if (results == null)
		{
			_app.showErrorMessage(getNameKey(), "error.learnestimationparams.failed");
			return;
		}

		// Swing components must be created and accessed on the event dispatch thread,
		// not on this worker thread
		SwingUtilities.invokeLater(new Runnable() {
			public void run() {
				showResultsDialog(results);
			}
		});
	}

	/**
	 * Create the results dialog if necessary, populate it with the given results and show it.
	 * Intended to be called on the event dispatch thread.
	 * @param inResults results of the matrix solving
	 */
	private void showResultsDialog(MatrixResults inResults)
	{
		if (_dialog == null)
		{
			_dialog = new JDialog(_parentFrame, I18nManager.getText(getNameKey()), true);
			// Create Gui, then pack before positioning so that the real size is used for centring
			_dialog.getContentPane().add(makeDialogComponents());
			_dialog.pack();
			_dialog.setLocationRelativeTo(_parentFrame);
		}

		// Populate the values in the dialog
		populateCalculatedValues(inResults);
		updateCombinedLabels(calculateCombinedParameters());
		_dialog.setVisible(true);
	}


	/**
	 * Make the dialog components
	 * @return the GUI components for the dialog
	 */
	private Component makeDialogComponents()
	{
		JPanel dialogPanel = new JPanel();
		dialogPanel.setLayout(new BorderLayout());

		// main panel with a box layout
		JPanel mainPanel = new JPanel();
		mainPanel.setLayout(new BoxLayout(mainPanel, BoxLayout.Y_AXIS));
		// Label at top
		JLabel introLabel = new JLabel(I18nManager.getText("dialog.learnestimationparams.intro") + ":");
		introLabel.setBorder(BorderFactory.createEmptyBorder(5, 5, 5, 5));
		introLabel.setAlignmentX(Component.LEFT_ALIGNMENT);
		mainPanel.add(introLabel);

		// Panel for the calculated results
		_calculatedParamPanel = new ParametersPanel("dialog.estimatetime.results", true);
		_calculatedParamPanel.setAlignmentX(Component.LEFT_ALIGNMENT);
		mainPanel.add(_calculatedParamPanel);
		mainPanel.add(Box.createVerticalStrut(14));

		mainPanel.add(new JLabel(I18nManager.getText("dialog.learnestimationparams.combine") + ":"));
		mainPanel.add(Box.createVerticalStrut(4));
		// Slider values run from 0 (keep current) to 10 (use calculated): max 11 minus extent 1
		_weightSlider = new JScrollBar(JScrollBar.HORIZONTAL, 5, 1, 0, 11);
		_weightSlider.addAdjustmentListener(new AdjustmentListener() {
			public void adjustmentValueChanged(AdjustmentEvent inEvent)
			{
				// Only recalculate once the user has finished dragging
				if (!inEvent.getValueIsAdjusting()) {
					updateCombinedLabels(calculateCombinedParameters());
				}
			}
		});
		mainPanel.add(_weightSlider);
		_sliderDescLabel = new JLabel(" ");
		_sliderDescLabel.setAlignmentX(Component.LEFT_ALIGNMENT);
		mainPanel.add(_sliderDescLabel);
		mainPanel.add(Box.createVerticalStrut(12));

		// Results panel
		_combinedParamPanel = new ParametersPanel("dialog.learnestimationparams.combinedresults");
		_combinedParamPanel.setAlignmentX(Component.LEFT_ALIGNMENT);
		mainPanel.add(_combinedParamPanel);

		dialogPanel.add(mainPanel, BorderLayout.NORTH);

		// button panel at bottom
		JPanel buttonPanel = new JPanel();
		buttonPanel.setLayout(new FlowLayout(FlowLayout.RIGHT));

		// Combine
		_combineButton = new JButton(I18nManager.getText("button.combine"));
		_combineButton.addActionListener(new ActionListener() {
			public void actionPerformed(ActionEvent arg0) {
				combineAndFinish();
			}
		});
		buttonPanel.add(_combineButton);

		// Cancel
		JButton cancelButton = new JButton(I18nManager.getText("button.cancel"));
		cancelButton.addActionListener(new ActionListener() {
			public void actionPerformed(ActionEvent e) {
				_dialog.dispose();
			}
		});
		KeyAdapter escapeListener = new KeyAdapter() {
			public void keyPressed(KeyEvent inE) {
				if (inE.getKeyCode() == KeyEvent.VK_ESCAPE) {_dialog.dispose();}
			}
		};
		_combineButton.addKeyListener(escapeListener);
		cancelButton.addKeyListener(escapeListener);
		buttonPanel.add(cancelButton);
		dialogPanel.add(buttonPanel, BorderLayout.SOUTH);
		return dialogPanel;
	}

	/**
	 * Construct a rangestats object for the selected range
	 * @param inTrack track object
	 * @param inStartIndex start index
	 * @param inEndIndex end index (the range is extended beyond this if necessary to reach 1km)
	 * @param inPreviousStartIndex the previously used start index, or -1
	 * @return range stats object or null if required information missing from this bit of the track
	 */
	private RangeStats getRangeStats(Track inTrack, int inStartIndex, int inEndIndex, int inPreviousStartIndex)
	{
		// Check parameters
		if (inTrack == null || inStartIndex < 0 || inEndIndex <= inStartIndex || inStartIndex >= inTrack.getNumPoints()) {
			return null;
		}
		final int numPoints = inTrack.getNumPoints();
		int start = inStartIndex;

		// Search forward until a decent track point found for the start
		DataPoint p = inTrack.getPoint(start);
		while (start < numPoints && (p == null || p.isWaypoint() || !p.hasTimestamp() || !p.hasAltitude()))
		{
			start++;
			p = inTrack.getPoint(start);
		}
		if ((inPreviousStartIndex >= 0 && start <= (inPreviousStartIndex + 10)) // overlapping too much with previous range
			|| (start >= (numPoints - 10))) // starting too late in the track
		{
			return null;
		}

		// Search forward (counting the radians) until a decent end point found,
		// never going beyond the last valid point index
		double movingRads = 0.0;
		final double minimumRads = Distance.convertDistanceToRadians(1.0, UnitSetLibrary.UNITS_KILOMETRES);
		final int lastIndex = numPoints - 1;
		DataPoint prevTrackPoint = inTrack.getPoint(start);
		int endIndex = start;
		while (endIndex < lastIndex)
		{
			endIndex++;
			p = inTrack.getPoint(endIndex);
			if (p != null && !p.isWaypoint())
			{
				if (!p.hasAltitude() || !p.hasTimestamp()) {return null;} // abort if no time/altitude
				if (prevTrackPoint != null && !p.getSegmentStart()) {
					movingRads += DataPoint.calculateRadiansBetween(prevTrackPoint, p);
				}
				// Only track points take part in the distance, so waypoints must not become the previous point
				prevTrackPoint = p;
			}
			if (movingRads >= minimumRads && endIndex >= inEndIndex) {
				break; // got at least a kilometre
			}
		}

		// Check moving distance
		if (movingRads >= minimumRads) {
			return new RangeStats(inTrack, start, endIndex);
		}
		return null;
	}

	/**
	 * Build an A matrix for the given list of RangeStats objects
	 * @param inStatsList list of (non-null) RangeStats objects
	 * @return A matrix with n rows and 5 columns: moving distance in km,
	 *  gentle climb, steep climb, gentle descent and steep descent in metres
	 */
	private static Matrix buildAMatrix(ArrayList<RangeStats> inStatsList)
	{
		final Unit METRES = UnitSetLibrary.UNITS_METRES;
		Matrix result = new Matrix(inStatsList.size(), 5);
		int row = 0;
		for (RangeStats stats : inStatsList)
		{
			result.setValue(row, 0, stats.getMovingDistanceKilometres());
			result.setValue(row, 1, stats.getGentleAltitudeRange().getClimb(METRES));
			result.setValue(row, 2, stats.getSteepAltitudeRange().getClimb(METRES));
			result.setValue(row, 3, stats.getGentleAltitudeRange().getDescent(METRES));
			result.setValue(row, 4, stats.getSteepAltitudeRange().getDescent(METRES));
			row++;
		}
		return result;
	}

	/**
	 * Build a B matrix containing the observations (moving times)
	 * @param inStatsList list of (non-null) RangeStats objects
	 * @return B matrix with single column of n rows, in minutes
	 */
	private static Matrix buildBMatrix(ArrayList<RangeStats> inStatsList)
	{
		Matrix result = new Matrix(inStatsList.size(), 1);
		int row = 0;
		for (RangeStats stats : inStatsList)
		{
			result.setValue(row, 0, stats.getMovingDurationInSeconds() / 60.0); // convert seconds to minutes
			row++;
		}
		return result;
	}

	/**
	 * Look for the maximum absolute value in the given column matrix
	 * @param inMatrix matrix with only one column
	 * @return row index of cell with greatest absolute value (0 if all values are zero),
	 *  or -1 if the matrix is null or has more than one column
	 */
	private static int getIndexOfMaxValue(Matrix inMatrix)
	{
		if (inMatrix == null || inMatrix.getNumColumns() > 1) {
			return -1;
		}
		int index = 0;
		double maxValue = 0.0;
		// Loop over the first column looking for the maximum absolute value
		for (int i=0; i<inMatrix.getNumRows(); i++)
		{
			final double currValue = Math.abs(inMatrix.get(i, 0));
			if (currValue > maxValue)
			{
				maxValue = currValue;
				index = i;
			}
		}
		return index;
	}

	/**
	 * See if the given set of samples is sufficient for getting a decent solution, that is
	 * whether more than 3 samples have a nonzero value for each of gentle climb, steep climb,
	 * gentle descent and steep descent
	 * @param inRangeSet list of RangeStats objects
	 * @param inRowToIgnore row index to ignore, or -1 to use them all
	 * @return true if the samples look ok
	 */
	private static boolean isRangeSetSufficient(ArrayList<RangeStats> inRangeSet, int inRowToIgnore)
	{
		int numGC = 0, numSC = 0, numGD = 0, numSD = 0; // number of samples with gentle/steep climb/descent values > 0
		final Unit METRES = UnitSetLibrary.UNITS_METRES;
		int i = 0;
		for (RangeStats stats : inRangeSet)
		{
			if (i != inRowToIgnore)
			{
				if (stats.getGentleAltitudeRange().getClimb(METRES) > 0) {numGC++;}
				if (stats.getSteepAltitudeRange().getClimb(METRES) > 0)  {numSC++;}
				if (stats.getGentleAltitudeRange().getDescent(METRES) > 0) {numGD++;}
				if (stats.getSteepAltitudeRange().getDescent(METRES) > 0)  {numSD++;}
			}
			i++;
		}
		return numGC > 3 && numSC > 3 && numGD > 3 && numSD > 3;
	}

	/**
	 * Reduce the number of samples in the given list by eliminating the ones with highest errors.
	 * Note that the given list is modified.
	 * @param inStatsList list of stats
	 * @return results in an object, or null if the matrices could not be solved
	 */
	private MatrixResults reduceSamples(ArrayList<RangeStats> inStatsList)
	{
		int statsIndexToRemove = -1;
		Matrix answer = null;
		boolean finished = false;
		double averageErrorPc = 0.0;
		while (!finished)
		{
			// Remove the marked stats object, if any
			if (statsIndexToRemove >= 0) {
				inStatsList.remove(statsIndexToRemove);
			}

			// Build up the matrices
			Matrix A = buildAMatrix(inStatsList);
			Matrix B = buildBMatrix(inStatsList);

			// Solve (if possible)
			try
			{
				answer = A.solve(B);
				// Work out the relative error for each estimate
				Matrix estimates = A.times(answer);
				Matrix errors = estimates.minus(B).divideEach(B);
				averageErrorPc = errors.getAverageAbsValue();
				// find biggest error, mark it for removal from the list
				statsIndexToRemove = getIndexOfMaxValue(errors);
				if (statsIndexToRemove < 0)
				{
					System.err.println("Something wrong - index is " + statsIndexToRemove);
					return null;
				}
				// Check whether removing this element would make the range set insufficient
				finished = inStatsList.size() <= TARGET_SAMPLES || !isRangeSetSufficient(inStatsList, statsIndexToRemove);
			}
			catch (Exception e)
			{
				// Couldn't solve at all
				System.err.println("Failed to reduce: " + e.getClass().getName() + " - " + e.getMessage());
				return null;
			}
			// Reduction fills the progress bar from 20 to 100 while counting down from NUM_SAMPLES to
			// TARGET_SAMPLES; clamp because the list may have started out smaller than TARGET_SAMPLES
			final int progress = 20 + 80 * (NUM_SAMPLES - inStatsList.size()) / (NUM_SAMPLES - TARGET_SAMPLES);
			_progress.setValue(Math.min(100, progress));
		}
		// Copy results to an EstimationParameters object
		MatrixResults result = new MatrixResults();
		result._parameters = new EstimationParameters();
		result._parameters.populateWithMetrics(answer.get(0, 0) * 5, // convert from 1km to 5km
			answer.get(1, 0) * 100.0, answer.get(2, 0) * 100.0,      // convert from m to 100m
			answer.get(3, 0) * 100.0, answer.get(4, 0) * 100.0);
		result._averageErrorPc = averageErrorPc;
		return result;
	}


	/**
	 * Populate the dialog's labels with the calculated values
	 * @param inResults results of the calculations
	 */
	private void populateCalculatedValues(MatrixResults inResults)
	{
		if (inResults == null || inResults._parameters == null)
		{
			_calculatedParams = null;
			_calculatedParamPanel.updateParameters(null, 0.0);
		}
		else
		{
			_calculatedParams = inResults._parameters;
			_calculatedParamPanel.updateParameters(_calculatedParams, inResults._averageErrorPc);
		}
	}

	/**
	 * Combine the calculated parameters with the existing ones
	 * according to the value of the slider
	 * @return combined parameters
	 */
	private EstimationParameters calculateCombinedParameters()
	{
		final double fraction1 = 1 - 0.1 * _weightSlider.getValue(); // slider left = value 0 = fraction 1 = keep current
		EstimationParameters oldParams = new EstimationParameters(Config.getConfigString(Config.KEY_ESTIMATION_PARAMS));
		return oldParams.combine(_calculatedParams, fraction1);
	}

	/**
	 * Update the labels to show the combined parameters
	 * @param inCombinedParams combined estimation parameters
	 */
	private void updateCombinedLabels(EstimationParameters inCombinedParams)
	{
		// Update the slider description label
		String sliderDesc = null;
		final int sliderVal = _weightSlider.getValue();
		switch (sliderVal)
		{
			case 0:  sliderDesc = I18nManager.getText("dialog.learnestimationparams.weight.100pccurrent"); break;
			case 5:  sliderDesc = I18nManager.getText("dialog.learnestimationparams.weight.50pc"); break;
			case 10: sliderDesc = I18nManager.getText("dialog.learnestimationparams.weight.100pccalculated"); break;
			default:
				final int currTenths = 10 - sliderVal, calcTenths = sliderVal;
				sliderDesc = "" + currTenths + "0% " + I18nManager.getText("dialog.learnestimationparams.weight.current")
					+ " + " + calcTenths + "0% " + I18nManager.getText("dialog.learnestimationparams.weight.calculated");
		}
		_sliderDescLabel.setText(sliderDesc);
		// And update all the combined params labels
		_combinedParamPanel.updateParameters(inCombinedParams);
		// Combining with 0% of the calculated values would change nothing
		_combineButton.setEnabled(sliderVal > 0);
	}

	/**
	 * React to the combine button, by saving the combined parameters in the config
	 */
	private void combineAndFinish()
	{
		EstimationParameters params = calculateCombinedParameters();
		Config.setConfigString(Config.KEY_ESTIMATION_PARAMS, params.toConfigString());
		_dialog.dispose();
	}
}
