# Copyright (C) 2018 The Regents of the University of California, Michael Ludkovski and Aditya Maheshwari
# All Rights Reserved
# This code is published under the GNU Lesser General Public License (GNU LGPL)

from __future__ import division
import numpy as np
import math
import matplotlib.pyplot as plt
import microgridDEA.parameters as bv
import time

def penalty(param,inventory):
    """Return the terminal inventory penalty (currently switched off).

    The squared deviation of ``inventory`` from ``param.I0`` is formed and then
    multiplied by zero, so every entry of the result is zero while the result
    keeps whatever scalar/array form the arithmetic on ``inventory`` produces.

    Parameters
    ----------
    param : object exposing the attribute ``I0``.
    inventory : scalar or array of inventory levels.
    """
    deviation = inventory - param.I0
    squared_deviation = deviation**2
    # The factor 0 disables the quadratic penalty without changing the shape.
    return squared_deviation*0

def finalValue(param,regParams,demand,inventory):
    """Build the terminal value matrix from ``penalty``.

    ``demand`` is only used for its length (the number of rows).

    * ``regParams.rmctype == 'regress now 2D'``: returns an array of shape
      ``(len(demand), param.H + 1)`` in which every column equals
      ``penalty(param, inventory)``.
    * ``regParams.rmctype == 'gd'``: returns an array of shape
      ``(len(demand), len(inventory) * (param.H + 1))`` in which column
      ``q * len(inventory) + i`` holds ``penalty(param, inventory[i])``.
    * any other ``rmctype``: returns ``None``.
    """
    scheme = regParams.rmctype

    if scheme == 'regress now 2D':
        n_regimes = param.H + 1
        terminal = np.zeros((len(demand), n_regimes))
        column = penalty(param, inventory)
        for regime_idx in range(n_regimes):
            terminal[:, regime_idx] = column
        return terminal

    if scheme == 'gd':
        n_regimes = param.H + 1
        n_levels = len(inventory)
        terminal = np.zeros((len(demand), n_levels * n_regimes))
        for level_idx, level in enumerate(inventory):
            level_penalty = penalty(param, level)
            # Columns are laid out regime-major: one run of grid levels per regime.
            for regime_idx in range(n_regimes):
                terminal[:, regime_idx * n_levels + level_idx] = level_penalty
        return terminal

    return None

def continuationVal(contValObject,x0,i0,q,nextInventory,nextq, type):
    """Look up the continuation value through ``contValObject.getValue``.

    The two argument lists handed to ``getValue`` depend on ``type``:

    * ``'regress now 2D'`` -> ``getValue([nextq], [x0, nextInventory])``
    * ``'gd'``             -> ``getValue([nextInventory, nextq], [x0])``

    For any other ``type`` the function returns ``None``. The parameters
    ``i0`` and ``q`` are accepted but not used here.
    """
    if type == 'regress now 2D':
        first_group = [nextq]
        second_group = [x0, nextInventory]
    elif type == 'gd':
        first_group = [nextInventory, nextq]
        second_group = [x0]
    else:
        return None

    return contValObject.getValue(first_group, second_group)


def currentCost(param,currentSt, control, switchCost):
    """Return the one-step running cost for a scalar ``currentSt``.

    The cost is

        c1 * control**a * dt + (c2 * max-part + c3 * min-part) * dt

    where ``max-part`` is ``currentSt`` when it is positive (else 0) and
    ``min-part`` is ``-currentSt`` when it is negative (else 0). All
    coefficients (``c1``, ``a``, ``c2``, ``c3``, ``dt``, ``K``) are read from
    ``param``. ``param.K`` is added unless ``switchCost`` equals ``"No"``.
    """
    if currentSt > 0:
        positive_part = currentSt
    else:
        positive_part = 0

    if currentSt < 0:
        negative_part = -currentSt
    else:
        negative_part = 0

    control_term = param.c1*(control**param.a)*param.dt
    imbalance_term = (param.c2*positive_part + param.c3*negative_part)*param.dt
    running_cost = control_term + imbalance_term

    if switchCost == "No":
        return running_cost
    return running_cost + param.K



def calculateCost(param,currentSt, contValObject, x0,i0,q,control,i1,q1,ohc,switchCost="No"):
    """Return the running cost plus the continuation value floored at zero.

    The running cost comes from ``currentCost``; the continuation value is
    always looked up with the ``'gd'`` scheme of ``continuationVal`` at the
    next state ``(i1, q1)`` and negative values are replaced by 0. The
    parameter ``ohc`` is accepted but not used.
    """
    running = currentCost(param, currentSt, control, switchCost)
    future = continuationVal(contValObject, x0, i0, q, i1, q1, 'gd')
    return running + max(future, 0)




# one step optimization
def findOptimalControl(x0, i0, q, contValObject, param,regParams):
    """Evaluate the two candidate controls for one state and pick the cheaper.

    Candidate 0 is a control of 0 (next regime 0); candidate 1 is
    ``x0*(x0>0) + |max battery input|`` (next regime 1). For each candidate
    the function derives the part of ``x0 - control`` the battery cannot
    absorb (``St``), the battery flow (``Bt``), the next inventory, and the
    total cost including the continuation value from ``continuationVal``.

    Parameters
    ----------
    x0 : scalar demand value for this state.
    i0 : scalar current inventory.
    q : current regime; when it equals 0, ``param.K`` is added to candidate 1.
    contValObject : forwarded to ``continuationVal``.
    param : provides ``B_minMax``, ``I_minMax``, ``dt``, ``c1``, ``a``, ``c2``,
        ``c3`` and ``K``.
    regParams : its ``rmctype`` selects the ``continuationVal`` scheme.

    Returns
    -------
    tuple of five length-1 arrays for the cheapest candidate:
    ``(cost, control, St, next inventory, Bt)``.
    """
    power_lo = param.B_minMax[0]
    power_hi = param.B_minMax[1]
    inv_lo = param.I_minMax[0]
    inv_hi = param.I_minMax[1]

    # Battery output/input rates reachable within one step from inventory i0,
    # each then capped by the corresponding entry of B_minMax.
    discharge_cap = ((i0 - inv_lo)/param.dt)
    charge_cap = - (inv_hi - i0)/param.dt
    if discharge_cap > power_hi:
        discharge_cap = power_hi
    if charge_cap < power_lo:
        charge_cap = power_lo

    # Row vector of the two candidate controls; column 0 stays at zero.
    controls = np.zeros((1,2))
    controls[:,1] = x0*(x0>0) + np.abs(charge_cap)

    net_demand = x0 - controls  # shape (1, 2)

    # Part of the net demand outside [charge_cap, discharge_cap]. The checks
    # are nested so that the "below" test has the final say, then "above",
    # then the in-range test, matching sequential overwrites in that order.
    in_range = (net_demand<=discharge_cap) & (net_demand>=charge_cap)
    St = np.where(
        net_demand<charge_cap,
        net_demand - charge_cap,
        np.where(
            net_demand>discharge_cap,
            net_demand - discharge_cap,
            np.where(in_range, 0, net_demand),
        ),
    )

    Bt = net_demand - St
    nextInventory = i0 - Bt*param.dt

    # Next regime: 0 for candidate 0, 1 for candidate 1.
    nextq = np.zeros_like(net_demand)
    nextq[:,1] = 1

    contVal = np.zeros_like(net_demand)
    for col in range(2):
        contVal[:,col] = continuationVal(contValObject,x0,i0,q,nextInventory[0,col],nextq[0,col], regParams.rmctype)

    control_cost = param.c1*(controls**param.a)*param.dt
    positive_cost = param.c2*np.where(St>0, St, 0)*param.dt
    negative_cost = param.c3*np.where(St<0, -St, 0)*param.dt
    cost = control_cost + positive_cost + negative_cost + contVal

    if q==0:
        # Coming from regime 0, only candidate 1 pays the switching cost K.
        switching = np.zeros_like(net_demand) + param.K
        switching[:,0] = 0
        cost = cost+switching

    # A (numerically) zero candidate-1 control is ruled out with a huge cost.
    if controls[0,1] < 0.000001:
        cost[0,1] = 10**11

    best = np.argmin(cost,1)

    return cost[0, best], controls[0,best], St[0,best], nextInventory[0,best], Bt[0,best]


def optimization(contValObject, demand, inventory, param, regParamsTminus1, regParamsT):
    """Run ``findOptimalControl`` for every sample, regime and inventory level.

    The column layout is chosen by ``regParamsTminus1.rmctype``:

    * ``'gd'``: every demand sample is paired with every entry of
      ``inventory``; the result for regime ``q`` and grid index ``j`` is stored
      in column ``q * len(inventory) + j``.
    * ``'regress now 2D'``: sample ``i`` is paired with ``inventory[i]`` and the
      result for regime ``q`` is stored in column ``q``.

    ``regParamsT`` is forwarded unchanged to ``findOptimalControl``.

    Returns
    -------
    (value, policy_d, StOuput) : three float arrays of shape
        ``(len(demand), (param.H + 1) * gridI)`` holding, respectively, the
        first three outputs of ``findOptimalControl``.

    Raises
    ------
    ValueError
        If ``regParamsTminus1.rmctype`` is neither ``'gd'`` nor
        ``'regress now 2D'``.
    """
    scheme = regParamsTminus1.rmctype
    if scheme == 'gd':
        gridI = len(inventory)
    elif scheme == 'regress now 2D':
        gridI = 1
    else:
        # Fail with a clear message instead of an unbound-variable error below.
        raise ValueError(
            "unsupported rmctype %r; expected 'gd' or 'regress now 2D'" % (scheme,)
        )

    regime = np.arange(0,param.H+1,1)
    nsim = len(demand)
    n_cols = len(regime)*gridI

    value = np.zeros((nsim,n_cols))
    policy_d = np.zeros((nsim,n_cols))
    StOuput = np.full((nsim,n_cols), np.nan)

    # iteration over each sample of demand
    for i in range(nsim):

        # iteration over each regime
        for q in regime:

            # With 'regress now 2D' gridI is 1, so j is 0 and the column is q.
            for j in range(gridI):
                i0 = inventory[j] if scheme == 'gd' else inventory[i]
                best = findOptimalControl(demand[i], i0, q, contValObject, param, regParamsT)

                # findOptimalControl hands back length-1 arrays; unwrap them to
                # scalars before storing into a single array element, since
                # newer NumPy deprecates that implicit conversion.
                col = q*gridI + j
                value[i,col], policy_d[i,col], StOuput[i,col] = (
                    np.ravel(out)[0] for out in best[:3]
                )

    return value,policy_d, StOuput


