# Copyright (C) 2018 The Regents of the University of California, Michael Ludkovski and Aditya Maheshwari
# All Rights Reserved
# This code is published under the GNU Lesser General Public License (GNU LGPL)

from __future__ import division
import numpy as np
import math
import matplotlib.pyplot as plt
import time

import StOptGrids
import StOptReg
import StOptGeners

import microGrid.simulateState as simX
import microGrid.valueNext as fv


# Returns the value function at the final time step. The size of the returned matrix depends on the type of RMC method.
def finalStepValue(param,demand,inventory):
	"""Return the value function at the last time step.

	This is a direct delegation to microGrid.valueNext.finalValue with the
	same three arguments. The result is whatever that function returns for
	the given param, demand and inventory.
	"""
	terminalValue = fv.finalValue(param, demand, inventory)
	return terminalValue


def createXMatrix(param,asset, inventory=None):
	"""Stack the regression variables into the (nbVariables, nbSimulations) matrix expected by StOptReg.

	Args:
		param: parameter object; only param.rmctype is read.
		asset: 1D sequence of simulated values (one per simulation), stored in row 0.
		inventory: 1D sequence of inventory values, stored in row 1. It is
			required for 'regress now 2D' and ignored for 'gd'.

	Returns:
		np.ndarray of shape (2, len(asset)) for 'regress now 2D', or
		(1, len(asset)) for 'gd'.

	Raises:
		ValueError: if param.rmctype is not one of the supported methods, or if
			inventory is missing for 'regress now 2D'.
	"""
	# The default was previously a mutable list literal. None avoids the shared-default pitfall.
	if param.rmctype == 'regress now 2D':
		if inventory is None:
			raise ValueError("createXMatrix: 'regress now 2D' requires an inventory vector")
		xMatrix = np.zeros((2, len(asset)))
		xMatrix[0, :] = asset
		xMatrix[1, :] = inventory
		return xMatrix

	if param.rmctype == 'gd':
		xMatrix = np.zeros((1, len(asset)))
		xMatrix[0, :] = asset
		return xMatrix

	# Previously this fell through and returned None, which only failed later inside StOptReg.
	raise ValueError("createXMatrix: unsupported rmctype %r" % (param.rmctype,))



# regType values accepted by the 'regress now 2D' RMC method (see _buildRegressor).
_REGRESS_NOW_2D_TYPES = ('piecewiseLinear', 'globalPolynomial', 'kernel')


def _buildGrid(param):
	"""Build the StOpt regular grid on which continuation values are dumped.

	'gd': 2D grid. Axis 0 is inventory, starting at I_minMax[0] with meshI steps
	of size (I_minMax[1] - I_minMax[0]) / meshI. Axis 1 starts at 0 with H unit steps.
	'regress now 2D': 1D grid (0 with H unit steps). Here inventory is a
	regression variable (see createXMatrix) rather than a grid axis.

	Raises:
		ValueError: if param.rmctype is not 'gd' or 'regress now 2D'.
	"""
	if param.rmctype == 'gd':
		iMin = param.I_minMax[0]
		iMax = param.I_minMax[1]
		# The step covers (iMax - iMin), so the grid ends at iMax. The former
		# iMax / meshI overshot to iMin + iMax whenever iMin != 0.
		stepI = (iMax - iMin) / param.meshI
		lowValues = np.array([iMin, 0.], dtype=np.float64)
		step = np.array([stepI, 1], dtype=np.float64)
		nbStep = np.array([param.meshI, param.H], dtype=np.int32)
		return StOptGrids.RegularSpaceGrid(lowValues, step, nbStep)

	if param.rmctype == 'regress now 2D':
		lowValues = np.array([0.], dtype=np.float64)		# low value for the mesh
		step = np.array([1], dtype=np.float64)			# size of the mesh
		nbStep = np.array([param.H], dtype=np.int32)		# number of steps
		return StOptGrids.RegularSpaceGrid(lowValues, step, nbStep)

	raise ValueError("storageCalculation: unsupported rmctype %r" % (param.rmctype,))


def _buildRegressor(param, demand, inventory):
	"""Return the StOpt regressor for one time step.

	'gd' regresses on residual demand only, using param.meshX local-linear meshes.
	'regress now 2D' regresses on (demand, inventory) with the method named by
	param.regType.

	Raises:
		ValueError: on an unsupported rmctype or regType.
	"""
	if param.rmctype == 'gd':
		nbMesh = np.array([param.meshX], dtype=np.int32)
		return StOptReg.LocalLinearRegression(False, createXMatrix(param, demand), nbMesh)

	if param.rmctype == 'regress now 2D':
		xMatrix = createXMatrix(param, demand, inventory)
		if param.regType == 'piecewiseLinear':
			nbMesh = np.array([param.meshX, param.meshI], dtype=np.int32)
			return StOptReg.LocalLinearRegression(False, xMatrix, nbMesh)
		if param.regType == 'globalPolynomial':
			return StOptReg.GlobalCanonicalRegression(False, xMatrix, param.degree)
		if param.regType == 'kernel':
			return StOptReg.LocalGridKernelRegression(False, xMatrix, param.bandwidth, param.factPoint, True)
		raise ValueError("storageCalculation: unsupported regType %r for 'regress now 2D'" % (param.regType,))

	raise ValueError("storageCalculation: unsupported rmctype %r" % (param.rmctype,))


def storageCalculation(param):
	"""Run the backward regression Monte Carlo recursion and dump continuation values.

	Steps:
	1. Simulate param.nbsimulOpt residual-demand paths via simX.residualDemand,
	   with X0 = linspace(-param.ampl, param.ampl, param.nbsimulOpt).
	2. Set the value at maturity with finalStepValue.
	3. For iSteps = param.nstep-1 down to 0:
	   - regress the next-step value on the current state;
	   - dump it to param.filetoDump under "toStore" with index param.nstep - iSteps;
	   - read it back and call fv.optimization to get the current-step value.

	Returns nothing. The output is the archive written to param.filetoDump.

	Raises:
		ValueError: if param.rmctype / param.regType is unsupported. This is
			checked before simulating or truncating the dump file.
	"""
	# Validate the configuration up front. Previously an unsupported rmctype or
	# regType only surfaced as an unbound-name error inside the loop, after the
	# simulation had run and the dump file had already been truncated.
	grid = _buildGrid(param)
	if param.rmctype == 'regress now 2D' and param.regType not in _REGRESS_NOW_2D_TYPES:
		raise ValueError("storageCalculation: unsupported regType %r for 'regress now 2D'" % (param.regType,))

	# For regress now 2D: a residual demand process and a fixed inventory vector at every time step.
	# For gd: a price process and a fixed inventory grid at every time step.
	X0 = np.linspace(-param.ampl, param.ampl, param.nbsimulOpt)
	residualDemand, inventory = simX.residualDemand(param, param.nstep, param.nbsimulOpt, param.maturity, X0, param.rmctype)
	inventory[0] = 0

	# Value at the final step.
	valueNext = finalStepValue(param, residualDemand[:, -1], inventory)

	archiveToWrite = StOptGeners.BinaryFileArchive(param.filetoDump, "w")

	# Iterate backwards over time steps.
	for iSteps in range(param.nstep - 1, -1, -1):

		print("iSteps", iSteps)
		demand = residualDemand[:, iSteps]

		# The regression method is chosen in the parameter class (rmctype / regType).
		regressor = _buildRegressor(param, demand, inventory)

		# Dump the continuation values in the archive.
		archiveToWrite.dumpGridAndRegressedValue("toStore", param.nstep - iSteps, [valueNext], regressor, grid)

		# Read the regressed values back.
		# NOTE(review): the archive is reopened for reading on every step, presumably
		# so the reader sees the entry just dumped; confirm whether one reader would do.
		archiveToRead = StOptGeners.BinaryFileArchive(param.filetoDump, "r")
		contValues = archiveToRead.readGridAndRegressedValue(param.nstep - iSteps, "toStore")

		valueNext, control_d, StOutput = fv.optimization(contValues[0], demand, inventory, param)

		#######################################################################################################
		# Diagnostics: uncomment below to save the following graphs at iSteps == 0:
		# 1) Cost as a function of residual demand at different levels of the inventory.
		# 2) Cost as a function of inventory at different levels of the residual demand.
		# 3) Control strategy when the generator is ON and when it is OFF.
		# if (iSteps == 0):
		#
		# 	demandCheck = demand  # alternative: np.tile(demand, param.meshI)
		# 	inventoryCheck = inventory  # alternative: np.repeat(inventory, param.meshX)
		#
		# 	for q in range(2):
		# 		plt.scatter(demandCheck, valueNext[:, q], c=inventoryCheck, lw=0, s=10)
		# 		plt.xlabel("demand")
		# 		plt.ylabel("cost")
		# 		name = "q=" + str(q)
		# 		plt.title(name)
		# 		plt.colorbar()
		# 		filename = 'q' + str(q) + '_iSteps' + str(iSteps) + '_costDemand.png'
		# 		plt.savefig(filename)
		# 		plt.close()
		#
		# 	for q in range(2):
		# 		plt.scatter(inventoryCheck, valueNext[:, q], c=demandCheck, lw=0, s=10)
		# 		plt.xlabel("inventory")
		# 		plt.ylabel("cost")
		# 		# plt.ylim(0, 1000)
		# 		name = "q=" + str(q)
		# 		plt.title(name)
		# 		plt.colorbar()
		# 		filename = 'q' + str(q) + '_iSteps' + str(iSteps) + '_costinventory.png'
		# 		plt.savefig(filename)
		# 		plt.close()
		#
		# 	for q in range(2):
		# 		plt.scatter(demandCheck, inventoryCheck, c=control_d[:, q], lw=0, s=10)
		# 		plt.xlabel('demand')
		# 		plt.ylabel('inventory')
		# 		name = "q=" + str(q)
		# 		plt.title(name)
		# 		plt.colorbar()
		# 		filename = 'q' + str(q) + '_iSteps' + str(iSteps) + '_controlMap.png'
		# 		plt.savefig(filename)
		# 		plt.close()

