"""
 Copyright (C) 2000, 2001, 2002, 2003 RiskMap srl
 Copyright (C) 2007 StatPro Italia srl
 Copyright (C) 2020 Marcin Rybacki

 This file is part of QuantLib, a free-software/open-source library
 for financial quantitative analysts and developers - http://quantlib.org/

 QuantLib is free software: you can redistribute it and/or modify it
 under the terms of the QuantLib license.  You should have received a
 copy of the license along with this program; if not, please email
 <quantlib-dev@lists.sf.net>. The license is also available online at
 <https://www.quantlib.org/license.shtml>.

 This program is distributed in the hope that it will be useful, but WITHOUT
 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 FOR A PARTICULAR PURPOSE.  See the license for more details.
"""

import QuantLib as ql
import unittest
import math

# Module-level notification marker for the observability tests: it is reset
# to ``None`` before an action and set to 1 by ``raiseFlag`` when an observer
# registered with ``raiseFlag`` as its callback is notified.
flag = None


def raiseFlag():
    """Observer callback: record that a notification arrived.

    Takes no arguments and returns nothing; its only effect is to set the
    module-level ``flag`` to 1 so a test can check it afterwards.
    """
    # ``global`` is required: without it the assignment would only create
    # a local name and the module-level marker would stay untouched.
    global flag
    flag = 1


def binaryFunction(x, y):
    """Combine two rates as ``2.0 * x + y``.

    Passed to ``ql.CompositeZeroYieldStructure`` as the combiner, and also
    used directly to compute the expected composite zero rate.
    """
    doubledFirst = 2.0 * x
    return doubledFirst + y


def extrapolatedForwardRate(
        firstSmoothingPoint,
        lastLiquidForwardRate,
        ultimateForwardRate,
        alpha):
    """Build a reference function for UFR-extrapolated forward rates.

    The returned callable evaluates

        f(t) = UFR + (LLFR - UFR) * beta
        beta = (1 - exp(-alpha * dt)) / (alpha * dt),  dt = t - firstSmoothingPoint

    and is compared in the tests against
    ``ql.UltimateForwardTermStructure.forwardRate(firstSmoothingPoint, t, ...)``.

    :param firstSmoothingPoint: time of the first smoothing point, on the same
        time scale as ``t`` (the tests use ``timeFromReference`` values).
    :param lastLiquidForwardRate: last liquid forward rate (LLFR).
    :param ultimateForwardRate: ultimate forward rate (UFR).
    :param alpha: exponential decay rate used in ``beta``.
    :return: a function of ``t`` returning the extrapolated forward rate.
    """

    def calculate(t):
        deltaT = t - firstSmoothingPoint
        x = alpha * deltaT
        if x == 0.0:
            # (1 - e^{-x}) / x -> 1 as x -> 0: at t == firstSmoothingPoint
            # (or alpha == 0) the forward equals the LLFR instead of raising
            # ZeroDivisionError.
            beta = 1.0
        else:
            # -expm1(-x) equals 1 - exp(-x) but keeps full precision for
            # small x, where the direct subtraction cancels catastrophically.
            beta = -math.expm1(-x) / x
        return ultimateForwardRate + (
            lastLiquidForwardRate - ultimateForwardRate) * beta

    return calculate


class TermStructureTest(unittest.TestCase):
    """Tests for yield term structures derived from a bootstrapped curve.

    ``setUp`` bootstraps a ``ql.PiecewiseFlatForward`` curve on the TARGET
    calendar from deposit and swap quotes; the individual tests build other
    structures on top of it (implied, spreaded, composite, UFR, quanto) and
    check them against independently computed values or notifications.
    """

    def setUp(self):
        self.calendar = ql.TARGET()
        # Evaluation date: today's date adjusted on the TARGET calendar.
        today = self.calendar.adjust(ql.Date.todaysDate())
        ql.Settings.instance().evaluationDate = today
        self.settlementDays = 2
        self.dayCounter = ql.Actual360()
        self.settlement = self.calendar.advance(today, self.settlementDays, ql.Days)
        deposits = [
            ql.DepositRateHelper(
                ql.makeQuoteHandle(rate / 100),
                ql.Period(n, units),
                self.settlementDays,
                self.calendar,
                ql.ModifiedFollowing,
                False,
                self.dayCounter,
            )
            for (n, units, rate) in [
                (1, ql.Months, 4.581),
                (2, ql.Months, 4.573),
                (3, ql.Months, 4.557),
                (6, ql.Months, 4.496),
                (9, ql.Months, 4.490),
            ]
        ]
        swaps = [
            ql.SwapRateHelper(
                ql.makeQuoteHandle(rate / 100),
                ql.Period(years, ql.Years),
                self.calendar,
                ql.Annual,
                ql.Unadjusted,
                ql.Thirty360(ql.Thirty360.BondBasis),
                ql.Euribor6M(),
            )
            for (years, rate) in [(1, 4.54), (5, 4.99), (10, 5.47), (20, 5.89), (30, 5.96)]
        ]
        self.instruments = deposits + swaps

        self.termStructure = ql.PiecewiseFlatForward(
            self.settlement, self.instruments, self.dayCounter)

    def tearDown(self):
        # Reset the global evaluation date so tests do not leak state.
        ql.Settings.instance().evaluationDate = ql.Date()

    def testImpliedObs(self):
        "Testing observability of implied term structure"
        global flag
        flag = None
        h = ql.RelinkableYieldTermStructureHandle()
        settlement = self.termStructure.referenceDate()
        new_settlement = self.calendar.advance(settlement, 3, ql.Years)
        implied = ql.ImpliedTermStructure(h, new_settlement)
        obs = ql.Observer(raiseFlag)
        obs.registerWith(implied)
        # Linking the (initially empty) handle must reach the observer
        # through the implied structure.
        h.linkTo(self.termStructure)
        self.assertTrue(
            flag, "Observer was not notified of term structure change")

    def _checkSpreadedObservability(self, spreadedClass):
        """Check notifications of ``spreadedClass(curveHandle, spreadHandle)``.

        The observer must fire both when the underlying curve handle is
        relinked and when the spread quote changes value.
        """
        global flag
        flag = None
        me = ql.SimpleQuote(0.01)
        mh = ql.QuoteHandle(me)
        h = ql.RelinkableYieldTermStructureHandle()
        spreaded = spreadedClass(h, mh)
        # ``obs`` must stay referenced until the checks below are done.
        obs = ql.Observer(raiseFlag)
        obs.registerWith(spreaded)
        h.linkTo(self.termStructure)
        self.assertTrue(
            flag, "Observer was not notified of term structure change")
        flag = None
        me.setValue(0.005)
        self.assertTrue(flag, "Observer was not notified of spread change")

    def testFSpreadedObs(self):
        "Testing observability of forward-spreaded term structure"
        self._checkSpreadedObservability(ql.ForwardSpreadedTermStructure)

    def testZSpreadedObs(self):
        "Testing observability of zero-spreaded term structure"
        self._checkSpreadedObservability(ql.ZeroSpreadedTermStructure)

    def testCompositeZeroYieldStructure(self):
        """Testing composite zero yield structure"""
        settlement = self.termStructure.referenceDate()
        compounding = ql.Compounded
        freq = ql.Semiannual
        flatTs = ql.FlatForward(
            settlement,
            ql.makeQuoteHandle(0.0085),
            self.dayCounter)
        firstHandle = ql.YieldTermStructureHandle(flatTs)
        secondHandle = ql.YieldTermStructureHandle(self.termStructure)
        compositeTs = ql.CompositeZeroYieldStructure(
            firstHandle, secondHandle, binaryFunction, compounding, freq)
        maturity = settlement + ql.Period(20, ql.Years)
        # The composite rate must equal the combiner applied to the two
        # component zero rates with the same conventions.
        expectedZeroRate = binaryFunction(
            firstHandle.zeroRate(
                maturity, self.dayCounter, compounding, freq).rate(),
            secondHandle.zeroRate(
                maturity, self.dayCounter, compounding, freq).rate())
        actualZeroRate = compositeTs.zeroRate(
            maturity, self.dayCounter, compounding, freq).rate()
        failMsg = """ Composite zero yield structure rate replication failed:
                        expected zero rate: {expected}
                        actual zero rate: {actual}
                  """.format(expected=expectedZeroRate,
                             actual=actualZeroRate)
        self.assertAlmostEqual(
            first=expectedZeroRate,
            second=actualZeroRate,
            delta=1.0e-12,
            msg=failMsg)

    def testUltimateForwardTermStructure(self):
        """Testing ultimate forward term structure"""
        settlement = self.termStructure.referenceDate()
        ufr = ql.makeQuoteHandle(0.06)
        llfr = ql.makeQuoteHandle(0.05)
        fsp = ql.Period(20, ql.Years)
        alpha = 0.05
        baseCrvHandle = ql.YieldTermStructureHandle(self.termStructure)
        ufrCrv = ql.UltimateForwardTermStructure(
            baseCrvHandle, llfr, ufr, fsp, alpha)
        cutOff = ufrCrv.timeFromReference(settlement + fsp)
        forwardCalculator = extrapolatedForwardRate(
            cutOff, llfr.value(), ufr.value(), alpha)
        # Only points beyond the first smoothing point are checked.
        times = [ufrCrv.timeFromReference(settlement + ql.Period(x, ql.Years))
                 for x in [21, 30, 40, 50, 60, 70, 80, 90, 100]]
        for t in times:
            actualForward = ufrCrv.forwardRate(
                cutOff, t, ql.Continuous, ql.NoFrequency, True).rate()
            expectedForward = forwardCalculator(t)
            failMsg = """ UFR term structure forward replication failed for:
                            time to maturity: {timeToMaturity}
                            expected forward rate: {expected}
                            actual forward rate: {actual}
                      """.format(timeToMaturity=t,
                                 expected=expectedForward,
                                 actual=actualForward)
            self.assertAlmostEqual(
                first=expectedForward,
                second=actualForward,
                delta=1.0e-12,
                msg=failMsg)

    def testTermStructureInterpolationSchemes(self):
        """Testing different interpolation schemes and their consistency"""
        args = [self.settlement, self.instruments, self.dayCounter]
        # Each entry: bootstrapped curve class, equivalent interpolated curve
        # class built from the same nodes, and a label for failure messages.
        mapping = [
            [ql.PiecewiseParabolicCubicZero,
             ql.ParabolicCubicZeroCurve, 'Parabolic Zero'],
            [ql.PiecewiseMonotonicParabolicCubicZero,
             ql.MonotonicParabolicCubicZeroCurve, 'Monotone Parabolic Zero'],
            [ql.PiecewiseLogParabolicCubicDiscount,
             ql.LogParabolicCubicDiscountCurve, 'Log Parabolic Discount'],
            [ql.PiecewiseMonotonicLogParabolicCubicDiscount,
             ql.MonotonicLogParabolicCubicDiscountCurve, 'Monotone Log Parabolic Discount'],
        ]

        for bootstrap, interp, name in mapping:
            bootstrap_crv = bootstrap(*args)
            dates, nodes = zip(*bootstrap_crv.nodes())
            equivalent_crv = interp(dates, nodes, self.dayCounter)

            for d in dates:
                expected = equivalent_crv.zeroRate(
                    d, self.dayCounter, ql.Continuous, ql.NoFrequency).rate()
                actual = bootstrap_crv.zeroRate(
                    d, self.dayCounter, ql.Continuous, ql.NoFrequency).rate()

                failMsg = """ Interpolation check failed for:
                            interpolation: {interpolation}
                            expected zero rate: {expected}
                            actual zero rate: {actual}
                      """.format(interpolation=name,
                                 expected=expected,
                                 actual=actual)
                self.assertAlmostEqual(
                    first=expected,
                    second=actual,
                    delta=1.0e-12,
                    msg=failMsg)

    def testInterpolatedPiecewiseZeroSpreadedTermStructure(self):
        """Testing different interpolation schemes for zero spreaded term structure"""
        h = ql.RelinkableYieldTermStructureHandle()
        h.linkTo(self.termStructure)
        spreads = [(1, 0.005), (2, 0.008), (3, 0.0103), (4, 0.0145), (5, 0.025)]
        dates, quotes = zip(*[(h.referenceDate() + ql.Period(t, ql.Years),
                               ql.QuoteHandle(ql.SimpleQuote(s)))
                              for t, s in spreads])
        args = [h, quotes, dates]
        constructors = [ql.SpreadedLinearZeroInterpolatedTermStructure,
                        ql.SpreadedBackwardFlatZeroInterpolatedTermStructure,
                        ql.SpreadedCubicZeroInterpolatedTermStructure,
                        ql.SpreadedKrugerZeroInterpolatedTermStructure,
                        ql.SpreadedSplineCubicZeroInterpolatedTermStructure]
        for constructor in constructors:
            spreadedTs = constructor(*args)
            # At every spread pillar the spreaded zero rate minus the base
            # zero rate must reproduce the quoted spread exactly.
            for d, r in zip(dates, quotes):
                expected = r.value()

                zeroFromSpreadTS = spreadedTs.zeroRate(
                    d, self.dayCounter, ql.Continuous, ql.NoFrequency).rate()
                zeroFromBaseTs = h.zeroRate(
                    d, self.dayCounter, ql.Continuous, ql.NoFrequency).rate()
                actual = zeroFromSpreadTS - zeroFromBaseTs

                failMsg = """ Interpolated piecewise zero spreaded term structure
                              spread replication failed for:
                            maturity: {maturity}
                            expected spread: {expected}
                            actual spread: {actual}
                      """.format(maturity=d,
                                 expected=expected,
                                 actual=actual)
                self.assertAlmostEqual(
                    first=expected,
                    second=actual,
                    delta=1.0e-12,
                    msg=failMsg)

    def testQuantoTermStructure(self):
        """Testing quanto term structure"""
        # Use the evaluation date fixed in setUp rather than the raw
        # wall-clock date, so all curves share the test's reference date.
        today = ql.Settings.instance().evaluationDate

        dividend_ts = ql.YieldTermStructureHandle(
            ql.FlatForward(
                today,
                ql.makeQuoteHandle(0.055),
                self.dayCounter
            )
        )
        r_domestic_ts = ql.YieldTermStructureHandle(
            ql.FlatForward(
                today,
                ql.makeQuoteHandle(-0.01),
                self.dayCounter
            )
        )
        r_foreign_ts = ql.YieldTermStructureHandle(
            ql.FlatForward(
                today,
                ql.makeQuoteHandle(0.02),
                self.dayCounter
            )
        )
        sigma_s = ql.BlackVolTermStructureHandle(
            ql.BlackConstantVol(
                today,
                self.calendar,
                ql.makeQuoteHandle(0.25),
                self.dayCounter
            )
        )
        sigma_fx = ql.BlackVolTermStructureHandle(
            ql.BlackConstantVol(
                today,
                self.calendar,
                ql.makeQuoteHandle(0.05),
                self.dayCounter
            )
        )
        rho = ql.makeQuoteHandle(0.3)
        s_0 = ql.makeQuoteHandle(100.0)

        exercise = ql.EuropeanExercise(self.calendar.advance(today, 6, ql.Months))
        payoff = ql.PlainVanillaPayoff(ql.Option.Call, 95.0)

        # Route 1: plain vanilla option whose dividend curve is the quanto
        # term structure.
        vanilla_option = ql.VanillaOption(payoff, exercise)
        quanto_ts = ql.YieldTermStructureHandle(
            ql.QuantoTermStructure(
                dividend_ts,
                r_domestic_ts,
                r_foreign_ts,
                sigma_s,
                ql.nullDouble(),
                sigma_fx,
                ql.nullDouble(),
                rho.value()
            )
        )
        gbm_quanto = ql.BlackScholesMertonProcess(s_0, quanto_ts, r_domestic_ts, sigma_s)
        vanilla_engine = ql.AnalyticEuropeanEngine(gbm_quanto)
        vanilla_option.setPricingEngine(vanilla_engine)

        # Route 2: quanto option priced by the dedicated quanto engine.
        quanto_option = ql.QuantoVanillaOption(payoff, exercise)
        gbm_vanilla = ql.BlackScholesMertonProcess(s_0, dividend_ts, r_domestic_ts, sigma_s)
        quanto_engine = ql.QuantoEuropeanEngine(gbm_vanilla, r_foreign_ts, sigma_fx, rho)
        quanto_option.setPricingEngine(quanto_engine)

        quanto_option_pv = quanto_option.NPV()
        vanilla_option_pv = vanilla_option.NPV()

        message = """Failed to reproduce QuantoOption / EuropeanQuantoEngine NPV:
                      {quanto_pv}
                      by using the QuantoTermStructure as the dividend together with
                      VanillaOption / AnalyticEuropeanEngine:
                      {vanilla_pv}
                  """.format(
            quanto_pv=quanto_option_pv,
            vanilla_pv=vanilla_option_pv
        )

        self.assertAlmostEqual(
            quanto_option_pv,
            vanilla_option_pv,
            delta=1e-12,
            msg=message
        )

    def testLazyObject(self):
        """Testing freeze/recalculate behavior of a lazy term structure"""
        evaluationDate = ql.Settings.instance().evaluationDate
        nodes = self.termStructure.nodes()
        self.termStructure.freeze()
        # try/finally: restore global state and unfreeze even when an
        # assertion below fails.
        try:
            ql.Settings.instance().evaluationDate = self.calendar.advance(
                evaluationDate, 100, ql.Days)

            # While frozen, dates and rates must be unchanged.
            frozenNodes = self.termStructure.nodes()
            self.assertEqual(len(nodes), len(frozenNodes))
            for (date, rate), (frozenDate, frozenRate) in zip(nodes, frozenNodes):
                self.assertEqual(date, frozenDate)
                self.assertEqual(rate, frozenRate)

            self.termStructure.recalculate()

            # After an explicit recalculation the dates must have changed,
            # except the first node (the reference date, fixed at
            # construction).
            recalculatedNodes = self.termStructure.nodes()
            pairedNodes = list(zip(nodes, recalculatedNodes))
            for (date, _), (newDate, _) in pairedNodes[1:]:
                self.assertNotEqual(date, newDate)
        finally:
            ql.Settings.instance().evaluationDate = evaluationDate
            self.termStructure.unfreeze()


if __name__ == "__main__":
    # Report which QuantLib build is under test, then run the suite;
    # verbosity=2 makes unittest print each test's description line.
    print("testing QuantLib {}".format(ql.__version__))
    unittest.main(verbosity=2)
