# Copyright (C) 2016 EDF
# All Rights Reserved
# This code is published under the GNU Lesser General Public License (GNU LGPL)
import StOptGrids 
import StOptReg
import numpy as np



# Defines a simple gas storage for optimization and simulation
# No constraint on the storage at the end of optimization period (so the storage will be empty)
class OptimizeGasStorage :
    
    # Constructor
    # p_injectionRate     injection rate per time step
    # p_withdrawalRate    withdrawal rate between two time steps
    # p_injectionCost     injection cost
    # p_withdrawalCost    withdrawal cost
    # The size of the storage is not a constructor argument: the stock bounds are read
    # from the grid given to each step (p_grid.getExtremeValues()).
    # A simulator has to be stored with setSimulator before any step* method is called,
    # because these methods read self.m_simulator.
    def __init__(self, p_injectionRate, p_withdrawalRate, p_injectionCost, p_withdrawalCost) :
        
        self.m_injectionRate = p_injectionRate
        self.m_withdrawalRate = p_withdrawalRate
        self.m_injectionCost = p_injectionCost
        self.m_withdrawalCost = p_withdrawalCost
        
    # define the diffusion cone for parallelism
    # p_regionByProcessor         region (min max) treated by the processor for the different regimes treated
    # returns the (min, max) stock values that can be reached from the region: the min is lowered
    # by the withdrawal rate and the max is raised by the injection rate
    # NOTE(review): p_regionByProcessor is indexed as a flat (min, max) pair here - confirm against the caller
    def getCone(self, p_regionByProcessor) :
        
        # two entries are required (min and max): with a single entry the write
        # to index 1 below raised an IndexError
        extrGrid = np.zeros(2)
        extrGrid[0] = p_regionByProcessor[0] - self.m_withdrawalRate
        extrGrid[1] = p_regionByProcessor[1] + self.m_injectionRate
        
        return extrGrid
    
        
    # defines a step in optimization
    # p_grid      grid at arrival step after command
    # p_stock     coordinate of the stock point to treat
    # p_condEsp   conditional expectation operator
    # p_phiIn     for each regime  gives the solution calculated at the previous step ( next time step by Dynamic Programming resolution)
    # for each regimes (column) gives the solution for each particle (row)
    # returns a list of two arrays of shape (nbSimul, 1):
    #   [0] the value obtained for each simulation, [1] the control retained for each simulation
    #   (positive: injection, negative: withdrawal, 0: nothing done)
    def stepOptimize(self, p_grid, p_stock, p_condEsp, p_phiIn) :
        
        nbSimul = self.m_simulator.getNbSimul()
        # factor applied to the values at the arrival step (presumably the discount factor over one step - TODO confirm)
        actuStep = self.m_simulator.getActuStep()
        solutionAndControl = [np.zeros((nbSimul, 1)), np.zeros((nbSimul, 1))]
        # Spot price
        spotPrice = self.m_simulator.fromParticlesToSpot(self.m_simulator.getParticles())
        # level if injection
        # size of the stock
        maxStorage = p_grid.getExtremeValues()[0][1]
        injectionMax = min(maxStorage - p_stock[0], self.m_injectionRate)
        injectionStock = p_stock + injectionMax
        # level if withdrawal
        # level min of the stock
        minStorage = p_grid.getExtremeValues()[0][0]
        withdrawalMax = min(p_stock[0] - minStorage, self.m_withdrawalRate)
        withdrawalStock = p_stock - withdrawalMax
        
        # the stock point is above the maximal level (up to a tolerance)
        if injectionMax < -1e-10:
            solutionAndControl[0].fill(-1.e30)
            solutionAndControl[1].fill(0.)
            
            # not an admissible point
            return solutionAndControl
        
        # stays empty when the current stock point is outside the grid
        condExpSameStock = []
        
        # Suppose that non injection and no withdrawal
        #############################################
        # create interpolator at current stock point
        if p_grid.isInside(p_stock):
            interpolatorCurrentStock = p_grid.createInterpolator(p_stock)
            # cash flow at current stock and previous step
            cashSameStock = interpolatorCurrentStock.applyVec(p_phiIn[0])
            # conditional expectation at current stock point
            condExpSameStock = actuStep * p_condEsp[0].getAllSimulations(interpolatorCurrentStock).squeeze()
            
        # injection
        ###########
        # stays empty when no injection is possible
        gainInjection = []
        
        if injectionMax > 0:
            # interpolator for stock level if injection
            interpolatorInjectionStock = p_grid.createInterpolator(injectionStock)
            # cash flow  at previous step at injection level
            cashInjectionStock = interpolatorInjectionStock.applyVec(p_phiIn[0]).squeeze()
            # conditional expectation at injection stock level
            condExpInjectionStock = actuStep * p_condEsp[0].getAllSimulations(interpolatorInjectionStock).squeeze()
            # instantaneous gain if injection
            gainInjection = -injectionMax * (spotPrice + self.m_injectionCost)
        
        # withdrawal
        ############
        # stays empty when no withdrawal is possible
        gainWithdrawal = []
        
        if withdrawalMax > 0:
            # interpolator for stock level if withdrawal
            interpolatorWithdrawalStock = p_grid.createInterpolator(withdrawalStock)
            # cash flow  at previous step at withdrawal level
            cashWithdrawalStock = interpolatorWithdrawalStock.applyVec(p_phiIn[0]).squeeze()
            # conditional expectation at withdrawal stock level
            condExpWithdrawalStock = actuStep * p_condEsp[0].getAllSimulations(interpolatorWithdrawalStock).squeeze()
            # instantaneous gain if withdrawal
            gainWithdrawal = withdrawalMax * (spotPrice - self.m_withdrawalCost)
        
        # do the arbitrage
        ##################
        # the decision is taken by comparing conditional expectations, the value stored
        # is the instantaneous gain plus the actualized cash flow of the arrival stock
        if len(gainWithdrawal) > 0 and len(gainInjection) > 0:
            # all point admissible
            solutionAndControl[0][:,0] = np.multiply(actuStep, cashSameStock).transpose()
            solutionAndControl[1][:,0] = 0.
            espCondInjection = gainInjection + condExpInjectionStock
            espCondInjectionSuperiorToCondExpSameStock = espCondInjection > condExpSameStock
            
            solutionAndControl[0][:,0] = np.where(espCondInjectionSuperiorToCondExpSameStock, gainInjection + actuStep * cashInjectionStock, solutionAndControl[0][:,0])
            solutionAndControl[1][:,0] = np.where(espCondInjectionSuperiorToCondExpSameStock, injectionMax,  solutionAndControl[1][:,0])
            # best conditional expectation so far, used as reference for the withdrawal test
            condExpSameStock = np.where(espCondInjectionSuperiorToCondExpSameStock, espCondInjection, condExpSameStock)
            
            espCondWithdrawal = gainWithdrawal + condExpWithdrawalStock
            
            solutionAndControl[0][:,0] = np.where(espCondWithdrawal > condExpSameStock, gainWithdrawal + actuStep * cashWithdrawalStock, solutionAndControl[0][:,0])
            solutionAndControl[1][:,0] = np.where(espCondWithdrawal > condExpSameStock, - withdrawalMax,  solutionAndControl[1][:,0])
            
        elif len(gainWithdrawal) > 0:
            
            if len(condExpSameStock) > 0:
                solutionAndControl[0][:,0] = np.multiply(actuStep, cashSameStock).transpose()
                solutionAndControl[1][:,0] = 0.
                espCondWithdrawal = gainWithdrawal + condExpWithdrawalStock
                
                solutionAndControl[0][:,0] = np.where(espCondWithdrawal > condExpSameStock, gainWithdrawal + actuStep * cashWithdrawalStock, solutionAndControl[0][:,0])
                solutionAndControl[1][:,0] = np.where(espCondWithdrawal > condExpSameStock, - withdrawalMax, solutionAndControl[1][:,0])
                    
            else:
                # current stock point outside the grid : withdrawal is the only choice
                solutionAndControl[0][:,0] = gainWithdrawal + actuStep * cashWithdrawalStock
                solutionAndControl[1][:,0] = - withdrawalMax
                
        elif len(gainInjection) > 0:
            
            if len(condExpSameStock) > 0:
                solutionAndControl[0][:,0] = np.multiply(actuStep, cashSameStock).transpose()
                solutionAndControl[1][:,0] = 0.
                espCondInjection = gainInjection + condExpInjectionStock
                espCondInjectionSuperiorToCondExpSameStock = espCondInjection > condExpSameStock
                
                solutionAndControl[0][:,0] = np.where(espCondInjectionSuperiorToCondExpSameStock, gainInjection + actuStep * cashInjectionStock, solutionAndControl[0][:,0])
                solutionAndControl[1][:,0] = np.where(espCondInjectionSuperiorToCondExpSameStock, injectionMax, solutionAndControl[1][:,0])
                
            else:
                # current stock point outside the grid : injection is the only choice
                solutionAndControl[0][:,0] = gainInjection + actuStep * cashInjectionStock
                solutionAndControl[1][:,0] = injectionMax
                
        # if neither injection nor withdrawal is possible the arrays are returned filled with zeros
        return solutionAndControl
        
    # get number of regimes
    def getNbRegime(self) :
  
        return 1
    
    # number of controls
    def getNbControl(self):
        
        return 1
    
    # defines a step in simulation
    # Notice that this implementation is not optimal. In fact no interpolation is necessary for this asset.
    # This implementation is for test and example purpose
    # p_grid          grid at arrival step after command
    # p_continuation  defines the continuation operator for each regime
    # p_state         defines the state value (modified)
    # p_phiInOut      defines the value function (modified)
    # returns nothing : the stock of p_state is updated with the control retained and the
    # gain multiplied by getActu() is added to p_phiInOut
    def stepSimulate(self, p_grid, p_continuation, p_state, p_phiInOut) :
        actuStep = self.m_simulator.getActuStep()
        actu = self.m_simulator.getActu()
        # optimal stock attained
        ptStockCur = p_state.getPtStock()
        control = np.zeros_like(ptStockCur)
        # spot price
        spotPrice = self.m_simulator.fromOneParticleToSpot(p_state.getStochasticRealization())
        # very low value so that any admissible decision is retained
        espCondMax = -1.e30
        
        # doing nothing is only a candidate if the current stock is inside the grid
        if p_grid.isInside(ptStockCur):
            continuationDoNothing = actuStep * p_continuation[0].getValue(p_state.getPtStock(), p_state.getStochasticRealization())
            espCondMax = continuationDoNothing
            
        # gain to add at current point
        phiAdd = 0
        # size of the stock
        maxStorage = p_grid.getExtremeValues()[0][1]
        # if injection
        injectionMax = min(maxStorage - p_state.getPtStock()[0], self.m_injectionRate)
        
        if injectionMax > 0:
            continuationInjection = actuStep * p_continuation[0].getValue(p_state.getPtStock() + injectionMax, p_state.getStochasticRealization())
            gainInjection = - injectionMax * (spotPrice + self.m_injectionCost)
            espCondInjection = gainInjection + continuationInjection
            
            if espCondInjection > espCondMax :
                espCondMax = espCondInjection
                phiAdd = gainInjection
                control[0] = injectionMax
                
        # if withdrawal
        # level min of the stock
        minStorage = p_grid.getExtremeValues()[0][0]
        withdrawalMax = min(p_state.getPtStock()[0] - minStorage, self.m_withdrawalRate)
        
        if withdrawalMax > 0:
            gainWithdrawal = withdrawalMax * (spotPrice - self.m_withdrawalCost)
            continuationWithdrawal = actuStep * p_continuation[0].getValue(p_state.getPtStock() - withdrawalMax, p_state.getStochasticRealization())
            espCondWithdrawal = gainWithdrawal + continuationWithdrawal
            
            if espCondWithdrawal > espCondMax :
                espCondMax = espCondWithdrawal
                phiAdd = gainWithdrawal
                control[0] = -withdrawalMax
                
        # for return
        p_state.setPtStock(ptStockCur+control)
        # in place addition : the caller sees the update only if p_phiInOut is a mutable array
        p_phiInOut += phiAdd * actu
         
    # Defines a step in simulation using interpolation in controls
    # p_grid          grid at arrival step after command
    # p_control       defines the controls
    # p_state         defines the state value (modified)
    # p_phiInOut      defines the value function (modified): size number of functions to follow
    # returns nothing : the stock of p_state and p_phiInOut[0] are updated
    def stepSimulateControl(self, p_grid, p_control, p_state, p_phiInOut):
        actu = self.m_simulator.getActu()
        ptStock = p_state.getPtStock()
        # spot price
        spotPrice = self.m_simulator.fromOneParticleToSpot(p_state.getStochasticRealization())
        # optimal control
        control = p_control[0].getValue(p_state.getPtStock(), p_state.getStochasticRealization())
        maxStorage = p_grid.getExtremeValues()[0][1]
        minStorage = p_grid.getExtremeValues()[0][0]
        # truncate the control so that the stock stays between the min and max levels of the grid
        control = max(min(maxStorage - ptStock[0], control), minStorage - ptStock[0])
        
        if control > 0:
            # injection : the gas bought and the injection cost are paid
            p_phiInOut[0] -= control * (spotPrice + self.m_injectionCost) * actu
        
        elif control < 0:
            # withdrawal : control is negative so the subtraction adds the gain of the sale
            p_phiInOut[0] -= control * (spotPrice - self.m_withdrawalCost) * actu
            
        ptStock[0] += control
        p_state.setPtStock(ptStock)
        
    # get size of the  function to follow in simulation
    def getSimuFuncSize(self):
        
        return 1
        
    # store the simulator
    def setSimulator(self, p_simulator):
        
        self.m_simulator = p_simulator
        
    # get the simulator back
    def getSimulator(self):
        
        return self.m_simulator
