# Copyright (C) 2016 EDF
# All Rights Reserved
# This code is published under the GNU Lesser General Public License (GNU LGPL)
import numpy as np


# Defines a simple gas storage with regime switching costs, for optimization and simulation
# No constraint is imposed on the storage level at the end of the optimization period (so the storage ends up empty)
class OptimizeGasStorageSwitchingCost:
    # Gas storage with three regimes and a cost paid at each regime change:
    #   regime 0 : do nothing (stock unchanged)
    #   regime 1 : injection  (stock increased by at most m_injectionRate)
    #   regime 2 : withdrawal (stock decreased by at most m_withdrawalRate)
    # Moving from the current regime to a different one costs m_switchCost.
    # A simulator must be attached with setSimulator() before running any
    # optimization or simulation step.

    # tolerance used to decide whether an injection / withdrawal volume is significant
    _tolerance = 1e3 * np.finfo(float).eps

    # Constructor
    # p_injectionRate     maximal injection per time step
    # p_withdrawalRate    maximal withdrawal per time step
    # p_injectionCost     cost per unit injected (added to the spot price)
    # p_withdrawalCost    cost per unit withdrawn (subtracted from the spot price)
    # p_switchCost        cost paid each time the regime changes
    # p_regime            number of regimes: either an integer (constant in time) or an
    #                     object whose get(date) method returns the number of regimes at that date
    def __init__(self, p_injectionRate, p_withdrawalRate, p_injectionCost, p_withdrawalCost, p_switchCost, p_regime):
        self.m_injectionRate = p_injectionRate
        self.m_withdrawalRate = p_withdrawalRate
        self.m_injectionCost = p_injectionCost
        self.m_withdrawalCost = p_withdrawalCost
        self.m_switchCost = p_switchCost
        self.m_regime = p_regime
        # set later through setSimulator()
        self.m_simulator = None

    # Define the diffusion cone for parallelism.
    # p_regionByProcessor   region (min, max) of the stock treated by the processor;
    #                       p_regionByProcessor[0] holds (min, max) for the single stock dimension
    # Returns a (1, 2) array: the lowest stock reachable in one step (min minus the withdrawal
    # rate) and the highest one (max plus the injection rate).
    def getCone(self, p_regionByProcessor):
        extrGrid = np.zeros((1, 2))
        extrGrid[0, 0] = p_regionByProcessor[0][0] - self.m_withdrawalRate
        extrGrid[0, 1] = p_regionByProcessor[0][1] + self.m_injectionRate
        return extrGrid

    # True when the number of regimes does not depend on time (integer given at construction)
    def _isConstantRegime(self):
        return isinstance(self.m_regime, (int, np.integer))

    # Number of regimes allowed at the current date
    def getNbRegime(self):
        if self._isConstantRegime():
            return self.m_regime
        return self.m_regime.get(self.m_simulator.getCurrentStep())

    # Number of regimes that can be reached at the end of the current time step
    def getNbRegimeReached(self):
        if self._isConstantRegime():
            return self.m_regime
        return self.m_regime.get(self.m_simulator.getCurrentStep() + self.m_simulator.getStep())

    # Number of controls: one command (volume) per regime
    def getNbControl(self):
        return self.getNbRegime()

    # Interpolate, at stock level p_stock, the conditional expectation (multiplied by p_actuStep)
    # and the cash flow of the value function of regime p_regime.
    # Returns the pair (conditional expectation, cash flow), both flattened to 1-D arrays.
    def _continuation(self, p_grid, p_stock, p_condEsp, p_phiIn, p_regime, p_actuStep):
        interpolator = p_grid.createInterpolator(p_stock)
        condExp = p_actuStep * np.ravel(p_condEsp[p_regime].getAllSimulations(interpolator))
        cash = np.ravel(interpolator.applyVec(p_phiIn[p_regime]))
        return condExp, cash

    # Defines a step in optimization
    # p_grid      grid at arrival step after command
    # p_stock     coordinate of the stock point to treat
    # p_condEsp   conditional expectation operator for each regime reached
    # p_phiIn     for each regime reached, the solution calculated at the previous step
    #             (next time step in the Dynamic Programming resolution)
    # Returns a list of two (nbSimul, nbReg) arrays: [0] the value for each particle (row) and each
    # starting regime (column), [1] the optimal command (> 0 injection, < 0 withdrawal, 0 nothing).
    # A non-admissible point, or a point where no decision is possible, gets the value -1e30.
    def stepOptimize(self, p_grid, p_stock, p_condEsp, p_phiIn):
        # number of regimes allowed at the beginning of the time step
        nbReg = self.getNbRegime()
        # number of regimes that can be reached at the end of the time step
        nbRegReached = self.getNbRegimeReached()
        actuStep = self.m_simulator.getActuStep()
        nbSimul = self.m_simulator.getNbSimul()
        solution = np.zeros((nbSimul, nbReg))
        command = np.zeros((nbSimul, nbReg))
        solutionAndControl = [solution, command]

        # admissible volumes given the grid bounds and the rates
        extremes = p_grid.getExtremeValues()
        minStorage = extremes[0][0]
        maxStorage = extremes[0][1]
        injectionMax = min(maxStorage - p_stock[0], self.m_injectionRate)
        withdrawalMax = min(p_stock[0] - minStorage, self.m_withdrawalRate)

        if injectionMax < -self._tolerance and withdrawalMax < -self._tolerance:
            # not an admissible point
            solution.fill(-1.e30)
            command.fill(0.)
            return solutionAndControl

        # Spot price for each particle
        spotPrice = np.ravel(self.m_simulator.fromParticlesToSpot(self.m_simulator.getParticles()))

        # Candidate decisions, in tie-breaking order (first one wins on equality):
        # (regime reached, spot gain before switching cost, command, conditional expectation, cash flow)
        candidates = []
        if p_grid.isInside(p_stock):
            # no injection and no withdrawal: reach regime 0
            condExp, cash = self._continuation(p_grid, p_stock, p_condEsp, p_phiIn, 0, actuStep)
            candidates.append((0, np.zeros(len(spotPrice)), 0., condExp, cash))
        if injectionMax > self._tolerance and nbRegReached >= 2:
            condExp, cash = self._continuation(p_grid, p_stock + injectionMax, p_condEsp, p_phiIn, 1, actuStep)
            candidates.append((1, -injectionMax * (spotPrice + self.m_injectionCost), injectionMax, condExp, cash))
        if withdrawalMax > self._tolerance and nbRegReached == 3:
            condExp, cash = self._continuation(p_grid, p_stock - withdrawalMax, p_condEsp, p_phiIn, 2, actuStep)
            candidates.append((2, withdrawalMax * (spotPrice - self.m_withdrawalCost), -withdrawalMax, condExp, cash))

        # now arbitrage in each starting regime
        for iReg in range(nbReg):
            if not candidates:
                solution[:, iReg] = -1.e30
                command[:, iReg] = 0.
                continue
            bestEsp = None
            for regime, gain, volume, condExp, cash in candidates:
                # the switching cost is paid only when the regime reached differs from the current one
                gainReg = gain if regime == iReg else (gain - self.m_switchCost)
                # decision uses the conditional expectation, value uses the realized cash flow
                esp = gainReg + condExp
                value = gainReg + actuStep * cash
                if bestEsp is None:
                    bestEsp = esp
                    solution[:, iReg] = value
                    command[:, iReg] = volume
                else:
                    better = esp > bestEsp
                    solution[:, iReg] = np.where(better, value, solution[:, iReg])
                    command[:, iReg] = np.where(better, volume, command[:, iReg])
                    bestEsp = np.where(better, esp, bestEsp)

        return solutionAndControl

    # Defines a step in simulation
    # Notice that this implementation is not optimal. In fact no interpolation is necessary for this asset.
    # This implementation is for test and example purpose
    # p_grid          grid at arrival step after command
    # p_continuation  continuation operator for each regime reached (indices 0, 1, 2 are used)
    # p_state         state (stock, regime, stochastic realization); modified in place
    # p_phiInOut      value function to follow (size getSimuFuncSize()); element 0 is modified in place
    def stepSimulate(self, p_grid, p_continuation, p_state, p_phiInOut):
        actuStep = self.m_simulator.getActuStep()
        actu = self.m_simulator.getActu()
        stock = p_state.getPtStock()
        realization = p_state.getStochasticRealization()
        spotPrice = self.m_simulator.fromOneParticleToSpot(realization)

        extremes = p_grid.getExtremeValues()
        injectionMax = min(extremes[0][1] - stock[0], self.m_injectionRate)
        withdrawalMax = min(stock[0] - extremes[0][0], self.m_withdrawalRate)

        # For each regime reached (0 nothing, 1 injection, 2 withdrawal):
        # (stock variation, spot gain before switching cost, actualized continuation value)
        decisions = [
            (0., 0., actuStep * p_continuation[0].getValue(stock, realization)),
            (injectionMax,
             -injectionMax * (spotPrice + self.m_injectionCost),
             actuStep * p_continuation[1].getValue(stock + injectionMax, realization)),
            (-withdrawalMax,
             withdrawalMax * (spotPrice - self.m_withdrawalCost),
             actuStep * p_continuation[2].getValue(stock - withdrawalMax, realization)),
        ]

        # any regime other than 0 and 1 is treated as the withdrawal regime
        currentRegime = p_state.getRegime()
        if currentRegime != 0 and currentRegime != 1:
            currentRegime = 2

        # pick the best regime; strict comparison keeps the first one on ties
        bestRegime = 0
        bestEsp = None
        for regime, (_, gain, continuation) in enumerate(decisions):
            esp = gain + continuation
            if regime != currentRegime:
                esp = esp - self.m_switchCost
            if bestEsp is None or esp > bestEsp:
                bestRegime = regime
                bestEsp = esp

        variation, gain, _ = decisions[bestRegime]
        phiAdd = gain if bestRegime == currentRegime else gain - self.m_switchCost

        # new stock computed from the initial level (copy, so the state is only changed through setPtStock)
        ptStockMax = np.array(stock, dtype=float)
        ptStockMax[0] += variation
        p_state.setPtStock(ptStockMax)
        p_state.setRegime(bestRegime)
        p_phiInOut[0] += phiAdd * actu

    # Defines a step in simulation using interpolation in controls
    # p_grid          grid at arrival step after command
    # p_control       controls, one per regime
    # p_state         state (stock, regime, stochastic realization); modified in place
    # p_phiInOut      value function to follow (size getSimuFuncSize()); element 0 is modified in place
    def stepSimulateControl(self, p_grid, p_control, p_state, p_phiInOut):
        actu = self.m_simulator.getActu()
        ptStock = np.array(p_state.getPtStock(), dtype=float)
        iReg = p_state.getRegime()
        realization = p_state.getStochasticRealization()
        spotPrice = self.m_simulator.fromOneParticleToSpot(realization)

        # interpolated control, clipped so that the stock stays inside the grid
        control = p_control[iReg].getValue(ptStock, realization)
        extremes = p_grid.getExtremeValues()
        control = max(min(extremes[0][1] - ptStock[0], control), extremes[0][0] - ptStock[0])

        if control > 0:
            # injection
            newRegime = 1
            p_phiInOut[0] -= control * (spotPrice + self.m_injectionCost) * actu
        elif control < 0:
            # withdrawal (control is negative, so this adds the withdrawal gain)
            newRegime = 2
            p_phiInOut[0] -= control * (spotPrice - self.m_withdrawalCost) * actu
        else:
            # do nothing
            newRegime = 0

        if newRegime != iReg:
            # switching cost actualized like the other cash flows of the step
            p_phiInOut[0] -= self.m_switchCost * actu
        p_state.setRegime(newRegime)

        ptStock[0] += control
        p_state.setPtStock(ptStock)

    # store the simulator
    def setSimulator(self, p_simulator):
        self.m_simulator = p_simulator

    # get the simulator back
    def getSimulator(self):
        return self.m_simulator

    # get size of the function to follow in simulation
    def getSimuFuncSize(self):
        return 1
