# Copyright (C) 2016, 2018 EDF
# All Rights Reserved
# This code is published under the GNU Lesser General Public License (GNU LGPL)
import numpy as np
import math
import random


class AR1Simulator :
    """Monte Carlo simulator of a one-dimensional mean-reverting (AR1) process.

    The simulated quantity is

        D_t = max(m + (D0 - m) * exp(-mr * t) + X_t, 0)

    where X is an Ornstein-Uhlenbeck process started at X_0 = 0 whose exact
    one-step transition is X_{t+dt} = X_t * exp(-mr dt) + N(0, Var(X_dt)),
    with Var(X_t) = sigma^2 * (1 - exp(-2 mr t)) / (2 mr).

    In forward mode the simulation starts at t = 0 and X is advanced with the
    exact transition.  In backward mode X_T is sampled from its exact marginal
    law at maturity and the paths are rebuilt backward in time with an OU
    bridge down to X_0 = 0.

    The constructor reseeds numpy's global random generator with 0, so runs are
    reproducible (and any previous global random state is discarded).
    """

    # Constructor
    # p_D0       initial value
    # p_m        average value the process reverts to
    # p_sigma    volatility
    # p_mr       mean reverting coefficient (must be non zero: it is used as a divisor)
    # p_T        maturity
    # p_nbStep   number of time steps for the simulation
    # p_nbSimul  number of simulations for the Monte Carlo
    # p_bForward true if the simulator is forward, false if the simulation is backward
    def __init__(self, p_D0, p_m, p_sigma, p_mr, p_T, p_nbStep, p_nbSimul, p_bForward) :

        self.m_D0 = p_D0
        self.m_m = p_m
        self.m_sigma = p_sigma
        self.m_mr = p_mr
        self.m_T = p_T
        self.m_step = p_T / p_nbStep
        self.m_nbStep = p_nbStep
        self.m_nbSimul = p_nbSimul
        self.m_bForward = p_bForward
        self.m_currentStep = 0. if p_bForward else p_T

        # reseed the global generator so that simulations are reproducible
        np.random.seed(0)

        if self.m_bForward :
            # the OU part starts at 0 for every simulation
            self.m_OUProcess = np.zeros(p_nbSimul)
        else :
            # sample X_T from its exact marginal law N(0, Var(X_T))
            self.m_OUProcess = np.sqrt(self._ouVariance(self.m_T)) * np.random.randn(p_nbSimul)

    # Variance sigma^2 * (1 - exp(-2 mr t)) / (2 mr) of the OU process at date p_t
    # when started from 0; with p_t equal to the time step it is also the
    # variance of the exact one-step transition.
    def _ouVariance(self, p_t) :

        return self.m_sigma * self.m_sigma * (1. - np.exp(-2. * self.m_mr * p_t)) / (2. * self.m_mr)

    # a step forward for OU process (exact transition over one time step)
    def forwardStepForOU(self) :

        # the increment variance depends on the step length only, not on the current date
        stDev = np.sqrt(self._ouVariance(self.m_step))
        expActu = np.exp(-self.m_mr * self.m_step)
        self.m_OUProcess = self.m_OUProcess * expActu + stDev * np.random.randn(self.m_nbSimul)

    # a step backward for OU process.
    # m_currentStep has already been moved to the earlier date t, and m_OUProcess
    # holds X_{t + step}; X_t is sampled conditionally on it (OU bridge from 0 at date 0).
    def backwardStepForOU(self) :

        # m_currentStep is a multiple of the step up to rounding errors accumulated by
        # repeated subtraction: anything below half a step is date 0, where X_0 = 0
        if self.m_currentStep <= 0.5 * self.m_step :
            self.m_OUProcess = np.zeros(self.m_nbSimul)
            return

        t = self.m_currentStep
        # conditional mean coefficient E[X_t | X_{t+dt}] = util * X_{t+dt}
        util = np.sinh(self.m_mr * t) / np.sinh(self.m_mr * (t + self.m_step))
        expActu = np.exp(-self.m_mr * self.m_step)
        # conditional variance, written as a sum of non negative terms for numerical stability
        variance = self._ouVariance(t) * (1. - expActu * util) ** 2 + self._ouVariance(self.m_step) * util ** 2
        self.m_OUProcess = self.m_OUProcess * util + np.sqrt(variance) * np.random.randn(self.m_nbSimul)

    # get current markov state
    # return an array of shape (1, nbSimul) holding
    # max(m + (D0 - m) * exp(-mr * t) + X_t, 0) for each simulation
    def getParticles(self) :

        trend = (self.m_D0 - self.m_m) * np.exp(-self.m_mr * self.m_currentStep) + self.m_m
        return np.maximum(self.m_OUProcess + trend, 0.).reshape(1, -1)

    # a step forward for simulations (does nothing for a backward simulator)
    def stepForward(self) :

        if self.m_bForward :
            self.m_currentStep += self.m_step
            self.forwardStepForOU()

    # a step backward for simulations (does nothing for a forward simulator)
    def stepBackward(self) :

        if not self.m_bForward :
            self.m_currentStep -= self.m_step
            self.backwardStepForOU()

    # a step forward for simulations
    # return the asset values (asset, simulations), or None for a backward simulator
    def stepForwardAndGetParticles(self) :

        if not self.m_bForward :
            return None
        self.m_currentStep += self.m_step
        self.forwardStepForOU()
        return self.getParticles()

    # a step backward for simulations
    # return the asset values (asset, simulations), or None for a forward simulator
    def stepBackwardAndGetParticles(self) :

        if self.m_bForward :
            return None
        self.m_currentStep -= self.m_step
        self.backwardStepForOU()
        return self.getParticles()

    # Get back attribute

    def getCurrentStep(self) :

        return self.m_currentStep

    def getT(self) :

        return self.m_T

    def getStep(self) :

        return self.m_step

    def getSigma(self) :

        return self.m_sigma

    def getMr(self) :

        return self.m_mr

    def getNbSimul(self) :

        return self.m_nbSimul

    def getNbStep(self) :

        return self.m_nbStep

    def getDimension(self) :

        return 1
