# Copyright (C) 2016 EDF, 2017, 2018  EDF
# All Rights Reserved
# This code is published under the GNU Lesser General Public License (GNU LGPL)
import os.path as osp
import sys
sys.path.append(osp.abspath(osp.dirname(osp.dirname(__file__))))
import numpy as np
import math
import random 
import unittest


# Ornstein-Uhlenbeck simulator
class MeanRevertingSimulator:
    # Multi-factor Ornstein-Uhlenbeck simulator driving a forward curve.
    #
    # Each factor i follows dX_i = -mr_i X_i dt + sigma_i dW_i, with independent
    # Gaussian draws per factor.  Particles are held in a matrix of shape
    # (number of factors, number of simulations) and are mapped to spot prices by
    #     spot(t) = curve(t) * exp(-trend(t) + sum_i X_i(t)),
    # where trend(t) = 0.5 * sum_i Var[X_i(t)] compensates the lognormal
    # convexity so that E[spot(t)] = curve(t) when X(0) = 0.
    # A forward simulator moves t from 0 to T with exact OU transitions; a
    # backward simulator starts from the exact law of X(T) and moves t from T
    # down to 0 with an Ornstein-Uhlenbeck bridge.

    # Actualize trend at the current date:
    # trend = 0.5 * sum_i sigma_i^2 / (2 mr_i) * (1 - exp(-2 mr_i t))
    def actualizeTrend(self):
        variances = (np.square(self.m_sigma) / (2. * self.m_mr)
                     * (1. - np.exp(-2. * self.m_mr * self.m_currentStep)))
        self.m_trend = 0.5 * np.sum(variances)

    # Constructor
    # p_curve    Initial forward curve (object providing get(t))
    # p_sigma    Volatility of each factor (one entry per factor)
    # p_mr       Mean reverting per factor (same length as p_sigma)
    # p_r        Interest rate (used by getActu / getActuStep)
    # p_T        Maturity
    # p_nbStep   Number of time step for simulation
    # p_nbSimul  Number of simulations for the Monte Carlo
    # p_bForward true if the simulator is forward, false if the simulation is backward
    # Note: the global NumPy random generator is reseeded with 0 here so that
    # simulations are reproducible.
    def __init__(self, p_curve, p_sigma, p_mr, p_r, p_T, p_nbStep, p_nbSimul, p_bForward):
        self.m_curve = p_curve
        # float arrays: every per-factor formula below is vectorized over factors
        self.m_sigma = np.asarray(p_sigma, dtype=float)
        self.m_mr = np.asarray(p_mr, dtype=float)
        self.m_r = p_r
        self.m_T = p_T
        self.m_step = p_T / p_nbStep
        self.m_nbStep = p_nbStep
        self.m_nbSimul = p_nbSimul
        self.m_bForward = p_bForward
        self.m_currentStep = 0. if p_bForward else p_T

        np.random.seed(0)

        if self.m_bForward:
            self.m_OUProcess = self._zeroParticles()
        else:
            self.m_OUProcess = self._sampleAtMaturity()

        self.actualizeTrend()

    # Zero particle matrix (factors x simulations): the starting point X(0) = 0
    def _zeroParticles(self):
        return np.zeros((len(self.m_sigma), self.m_nbSimul))

    # Sample particles from the exact law of X(T): centred Gaussian with
    # per-factor variance sigma^2 / (2 mr) * (1 - exp(-2 mr T))
    def _sampleAtMaturity(self):
        stDev = self.m_sigma * np.sqrt((1. - np.exp(-2. * self.m_mr * self.m_T)) / (2. * self.m_mr))
        # column vector so each factor (row) gets its own standard deviation
        return stDev[:, np.newaxis] * np.random.randn(len(self.m_sigma), self.m_nbSimul)

    # A step forward for the OU process with the exact one-step transition
    # X(t + dt) = exp(-mr dt) X(t) + sigma * sqrt((1 - exp(-2 mr dt)) / (2 mr)) * N(0, 1)
    def forwardStepForOU(self):
        mr = self.m_mr
        dt = self.m_step
        stDev = self.m_sigma * np.sqrt((1. - np.exp(-2. * mr * dt)) / (2. * mr))
        expActu = np.exp(-mr * dt)
        normalSample = np.random.randn(len(self.m_sigma), self.m_nbSimul)
        # per-factor coefficients become columns so they scale rows (factors),
        # not columns (simulations)
        self.m_OUProcess = (self.m_OUProcess * expActu[:, np.newaxis]
                            + stDev[:, np.newaxis] * normalSample)

    # A step backward for the OU process: the current particles hold X(t + dt),
    # and X(t) is sampled (t being the already decremented current date) with the
    # OU bridge X(t) = util * X(t + dt) + sqrt(variance) * N(0, 1)
    def backwardStepForOU(self):
        if self.m_currentStep <= 0.:
            # the process starts from 0
            self.m_OUProcess = np.zeros((len(self.m_sigma), self.m_nbSimul))
        else:
            mr = self.m_mr
            t = self.m_currentStep
            dt = self.m_step
            # regression coefficient Cov(X(t), X(t + dt)) / Var(X(t + dt))
            util = np.sinh(mr * t) / np.sinh(mr * (t + dt))
            variance = (np.square(self.m_sigma) / (2. * mr)
                        * ((1. - np.exp(-2. * mr * t)) * np.square(1. - np.exp(-mr * dt) * util)
                           + (1. - np.exp(-2. * mr * dt)) * np.square(util)))
            stdDev = np.sqrt(variance)
            self.m_OUProcess = (self.m_OUProcess * util[:, np.newaxis]
                                + stdDev[:, np.newaxis]
                                * np.random.randn(len(self.m_sigma), self.m_nbSimul))

        self.actualizeTrend()

    # get current markov state (factors x simulations)
    def getParticles(self):
        return self.m_OUProcess

    # get one simulation
    # p_isim  simulation number
    # return the particle (one value per factor) associated to p_isim
    def getOneParticle(self, p_isim):
        return self.m_OUProcess[:, p_isim]

    # a step forward for simulations (no-op for a backward simulator)
    def stepForward(self):
        if not self.m_bForward:
            return
        self.m_currentStep += self.m_step
        self.actualizeTrend()
        self.forwardStepForOU()

    # a step backward for simulations (no-op for a forward simulator)
    def stepBackward(self):
        if self.m_bForward:
            return
        self.m_currentStep -= self.m_step
        self.actualizeTrend()
        self.backwardStepForOU()

    # a step forward for simulations
    # return the particles (factors x simulations), None for a backward simulator
    def stepForwardAndGetParticles(self):
        if not self.m_bForward:
            return None
        self.stepForward()
        return self.m_OUProcess

    # a step backward for simulations
    # return the particles (factors x simulations), None for a forward simulator
    def stepBackwardAndGetParticles(self):
        if self.m_bForward:
            return None
        self.stepBackward()
        return self.m_OUProcess

    # From particles simulation for an OU process, get spot price
    # p_particles  (dimension of the problem by number of simulations)
    # return spot price for all simulations
    def fromParticlesToSpot(self, p_particles):
        curveCurrent = self.m_curve.get(self.m_currentStep)
        # factors are summed along axis 0 before exponentiation
        return curveCurrent * np.exp(-self.m_trend + np.sum(p_particles, axis=0))

    # From one particle simulation for an OU process, get spot price
    # p_oneParticle  One particle (one value per factor)
    # return spot value
    def fromOneParticleToSpot(self, p_oneParticle):
        curveCurrent = self.m_curve.get(self.m_currentStep)
        return curveCurrent * math.exp(-self.m_trend + np.sum(p_oneParticle))

    # get back asset spot value for all simulations
    def getAssetValues(self):
        return self.fromParticlesToSpot(self.m_OUProcess)

    # get back asset spot value
    # p_isim  simulation particle number
    # return spot value for this particle
    def getAssetValues2(self, p_isim):
        return self.m_curve.get(self.m_currentStep) * math.exp(-self.m_trend + sum(self.m_OUProcess[:, p_isim]))

    # Get back attribute

    def getCurrentStep(self):
        return self.m_currentStep

    def getT(self):
        return self.m_T

    def getStep(self):
        return self.m_step

    def getSigma(self):
        return self.m_sigma

    def getMr(self):
        return self.m_mr

    def getNbSimul(self):
        return self.m_nbSimul

    def getNbSample(self):
        return 1

    def getNbStep(self):
        return self.m_nbStep

    def getDimension(self):
        return len(self.m_sigma)

    # actualize at date t=0
    def getActu(self):
        return math.exp(- self.m_r * self.m_currentStep)

    # actualize on one step
    def getActuStep(self):
        return math.exp(- self.m_r * self.m_step)

    # forward or backward update
    # p_date  current date in simulator (a forward simulator does not move at date 0)
    def updateDates(self, p_date):
        if self.m_bForward:
            if p_date > 0.:
                self.stepForward()
        else:
            self.stepBackward()

    # Reset the simulator to its starting date: t = 0 with X = 0 for a forward
    # simulator, t = T with a fresh sample of X(T) for a backward one
    def resetTime(self):
        if self.m_bForward:
            self.m_currentStep = 0.
            self.m_OUProcess = self._zeroParticles()
        else:
            self.m_currentStep = self.m_T
            self.m_OUProcess = self._sampleAtMaturity()

        self.actualizeTrend()

    # update the number of simulations and reset time (forward only, no-op backward)
    # p_nbSimul  Number of simulations to update
    # p_nbSample Number of sample to update, unused here
    def updateSimulationNumberAndResetTime(self, p_nbSimul, p_nbSample):
        if not self.m_bForward:
            return
        self.m_nbSimul = p_nbSimul
        self.m_currentStep = 0.
        # allocate a fresh zero matrix with the new number of simulations
        self.m_OUProcess = self._zeroParticles()
        self.actualizeTrend()


class MeanRevertingSimulatorTest(unittest.TestCase):

    # Price an at-the-money call with a forward and a backward simulator and
    # compare both Monte Carlo estimates with the closed-form value at T / 2.
    def test_callOption(self):

        # Flat forward curve: the simulator only calls get(t) on its curve
        class FlatCurve:
            def __init__(self, p_value):
                self.m_value = p_value

            def get(self, p_t):
                return self.m_value

        level = 100.
        futureGrid = FlatCurve(level)
        sigma = np.zeros(1) + 0.25
        mr = np.zeros(1) + 1.
        r = 0.
        T = 2.
        nbStep = 10
        nbSimul = 1000000
        K = 100.

        mrs = MeanRevertingSimulator(futureGrid, sigma, mr, r, T, nbStep, nbSimul, True)
        b = MeanRevertingSimulator(futureGrid, sigma, mr, r, T, nbStep, nbSimul, False)
        nbHalf = nbStep // 2
        SF = np.zeros(nbHalf)
        SB = np.zeros(nbHalf)

        for i in range(nbHalf):
            spotF = mrs.fromParticlesToSpot(mrs.stepForwardAndGetParticles())
            SF[i] = np.mean(np.maximum(spotF - K, 0.))
            spotB = b.fromParticlesToSpot(b.stepBackwardAndGetParticles())
            SB[i] = np.mean(np.maximum(spotB - K, 0.))

        # Both simulators are now at t = T / 2.  The spot is lognormal with mean
        # `level` and log-variance Var(X(t)); with K == level the undiscounted
        # call is level * (2 N(s / 2) - 1) = level * erf(s / (2 sqrt(2))).
        tMid = nbHalf * T / nbStep
        variance = np.sum(sigma ** 2 / (2. * mr) * (1. - np.exp(-2. * mr * tMid)))
        halfStd = 0.5 * math.sqrt(variance)
        reference = level * math.erf(halfStd / math.sqrt(2.))

        self.assertAlmostEqual(SF[-1], reference, delta=0.1)
        self.assertAlmostEqual(SB[-1], SF[-1], delta=0.2)
        
if __name__ == "__main__":
    # Run the unit tests of this module when it is executed as a script.
    unittest.main()
