"""
 Copyright (C) 2000, 2001, 2002 RiskMap srl

 This file is part of QuantLib, a free-software/open-source library
 for financial quantitative analysts and developers - http://quantlib.org/

 QuantLib is free software: you can redistribute it and/or modify it under the
 terms of the QuantLib license.  You should have received a copy of the
 license along with this program; if not, please email ferdinando@ametrano.net
 The license is also available online at http://quantlib.org/html/license.html

 This program is distributed in the hope that it will be useful, but WITHOUT
 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 FOR A PARTICULAR PURPOSE.  See the license for more details.
"""

# $Id: defaults.py,v 1.30 2002/03/14 17:40:07 lballabio Exp $

from QuantLib import *
import sys
import os
import new
import code
import types

# Array class on alpha seems not to raise the right exception
if sys.platform.startswith('linux') and os.uname()[-1] == 'alpha':
    Array._old___getitem__ = Array.__getitem__
    def Array_new___getitem__(self,i):
        try:
            return self._old___getitem__(i)
        except Exception, e:
            if str(e).endswith('out of range'):
                raise IndexError, e
            else:
                raise
    Array.__getitem__ = Array_new___getitem__

# Calendar class
def Calendar_new_roll(self,d,convention='Following'):
    """Call the original Calendar.roll on date d.

    The convention argument becomes optional and defaults to 'Following'.
    """
    return self._old_roll(d,convention)
Calendar._old_roll = Calendar.roll
Calendar.roll = Calendar_new_roll

def Calendar_new_advance(self,d,n,unit,convention='Following'):
    """Call the original Calendar.advance with d, n and unit.

    The convention argument becomes optional and defaults to 'Following'.
    """
    return self._old_advance(d,n,unit,convention)
Calendar._old_advance = Calendar.advance
Calendar.advance = Calendar_new_advance

# DayCounter class
def DayCounter_new_yearFraction(self,d1,d2,startRef=None,endRef=None):
    """Call the original DayCounter.yearFraction between d1 and d2.

    The reference dates startRef and endRef become optional and are
    passed through as None when omitted.
    """
    return self._old_yearFraction(d1,d2,startRef,endRef)
DayCounter._old_yearFraction = DayCounter.yearFraction
DayCounter.yearFraction = DayCounter_new_yearFraction

# Cash flow classes

CashFlow._old___init__ = CashFlow.__init__
def PyCashFlow_notifyObservers(self):
    """Forward notifyObservers() from a Python cash flow to its wrapper."""
    self._wrapper.notifyObservers()
def CashFlow_new___init__(self,pyCashFlow):
    """Wrap a Python-implemented cash flow.

    After the original constructor runs:
    - pyCashFlow._wrapper points back to this wrapper.
    - pyCashFlow.notifyObservers() forwards to the wrapper's
      notifyObservers().
    - Every attribute of pyCashFlow's class that CashFlow does not already
      have is bound to pyCashFlow and exposed on the wrapper.
    """
    self._old___init__(pyCashFlow)
    pyCashFlow._wrapper = self
    pyCashFlow.notifyObservers = \
        new.instancemethod(PyCashFlow_notifyObservers,pyCashFlow,
            pyCashFlow.__class__)
    # proxy pyCashFlow methods
    c = pyCashFlow.__class__
    # computed once instead of on every iteration
    known = dir(CashFlow)
    for method in dir(c):
        if method in known:
            continue
        try:
            # plain setattr/getattr instead of eval'ing generated source
            setattr(self, method,
                    new.instancemethod(getattr(c, method), pyCashFlow, c))
        except Exception:
            # Best effort: attributes that cannot be bound as methods
            # (e.g. non-callable class data) are simply not proxied.
            pass
CashFlow.__init__ = CashFlow_new___init__


def FixedRateCoupon_new___init__(self,nominal,rate,calendar,convention,
                                 dayCounter,startDate,endDate,
                                 startRef=None,endRef=None):
    """Construct a FixedRateCoupon via the original constructor.

    The reference-period dates startRef and endRef become optional and
    are passed through as None when omitted.
    """
    self._old___init__(nominal,rate,calendar,convention,dayCounter,
                       startDate,endDate,
                       startRef,endRef)
FixedRateCoupon._old___init__ = FixedRateCoupon.__init__
FixedRateCoupon.__init__ = FixedRateCoupon_new___init__

FloatingRateCoupon._old___init__ = FloatingRateCoupon.__init__
def FloatingRateCoupon_new___init__(self,nominal,index,termStructure,
  startDate,endDate,fixingDays=0, spread=0.0,startRef=None,endRef=None):
    """Construct a FloatingRateCoupon via the original constructor.

    fixingDays defaults to 0 and spread to 0.0. The reference-period dates
    startRef and endRef are passed through as None when omitted.
    """
    # fixingDays used to be accepted here but silently dropped, which
    # shifted spread and the reference dates one position to the left.
    # NOTE(review): assumes the wrapped constructor takes fixingDays between
    # endDate and spread, matching this wrapper's own signature; confirm
    # against the SWIG interface.
    self._old___init__(nominal,index,termStructure,startDate,endDate,
        fixingDays,spread,startRef,endRef)
FloatingRateCoupon.__init__ = FloatingRateCoupon_new___init__

def FixedRateCouponVector_new___init__(self,nominals,couponRates,
                                       startDate,endDate,frequency,calendar,
                                       convention,isAdjusted,dayCount,
                                       stubDate=None,firstPeriodDayCount=None):
    """Construct a FixedRateCouponVector via the original constructor.

    When firstPeriodDayCount is not given (or is false), dayCount is used
    for the first period as well. Note that the wrapped constructor
    receives firstPeriodDayCount before stubDate, the reverse of this
    wrapper's keyword order.
    """
    if not firstPeriodDayCount:
        firstPeriodDayCount = dayCount
    self._old___init__(nominals,couponRates,startDate,endDate,frequency,
                       calendar,convention,isAdjusted,
                       dayCount,firstPeriodDayCount,stubDate)
FixedRateCouponVector._old___init__ = FixedRateCouponVector.__init__
FixedRateCouponVector.__init__ = FixedRateCouponVector_new___init__

FloatingRateCouponVector._old___init__ = FloatingRateCouponVector.__init__
def FloatingRateCouponVector_new___init__(self,nominals,startDate,endDate,
  frequency,calendar,convention,termStructure,index, fixingDays=0,spreads=None,
  stubDate=None):
    """Construct a FloatingRateCouponVector via the original constructor.

    fixingDays defaults to 0, spreads to an empty list and stubDate to None.
    """
    # A fresh list per call instead of a shared mutable default argument.
    if spreads is None:
        spreads = []
    self._old___init__(nominals,startDate,endDate,frequency,calendar,
        convention,termStructure,index,fixingDays,spreads,stubDate)
FloatingRateCouponVector.__init__ = FloatingRateCouponVector_new___init__


# History

def HistoryIterator_next_21(self):
    """Return the current entry and advance the iterator.

    This is the pre-2.2 variant: once the iterator equals its end
    marker, None is returned instead of raising.
    """
    if not (self == self.end):
        entry = self._value()
        self._advance()
        return entry
    return None
def HistoryIterator_next_22(self):
    """Iterator-protocol next(): return the current entry and advance.

    Raises StopIteration once the iterator equals its end marker.
    """
    if not (self == self.end):
        entry = self._value()
        self._advance()
        return entry
    raise StopIteration
# Python 2.2+ gets the StopIteration-raising next(); older interpreters
# get the variant that returns None at the end.
if sys.hexversion > 0x02020000:
    _historyNext = HistoryIterator_next_22
else:
    _historyNext = HistoryIterator_next_21
for _historyIterClass in (HistoryIterator, HistoryValidIterator,
                          HistoryDataIterator, HistoryValidDataIterator):
    _historyIterClass.next = _historyNext
del _historyIterClass, _historyNext

class HistoryValidEntries:
    """Iterable view over a History's valid entries.

    Iteration uses the history's _vbegin()/_vend() pair; iterator() starts
    from a given date through _valid_iterator().
    """
    def __init__(self,history):
        self.history = history
    def __iter__(self):
        h = self.history
        cursor = h._vbegin()
        cursor.end = h._vend()
        return cursor
    def iterator(self,date=None):
        """Return a valid-entry iterator starting at date.

        If date is omitted (or false), the history's firstDate() is used.
        """
        h = self.history
        if not date:
            date = h.firstDate()
        cursor = h._valid_iterator(date)
        cursor.end = h._vend()
        return cursor

class HistoryData:
    """Iterable view over a History's data.

    Iteration uses the history's _dbegin()/_dend() pair; iterator() starts
    from a given date through _data_iterator().
    """
    def __init__(self,history):
        self.history = history
    def __iter__(self):
        h = self.history
        cursor = h._dbegin()
        cursor.end = h._dend()
        return cursor
    def iterator(self,date=None):
        """Return a data iterator starting at date.

        If date is omitted (or false), the history's firstDate() is used.
        """
        h = self.history
        if not date:
            date = h.firstDate()
        cursor = h._data_iterator(date)
        cursor.end = h._dend()
        return cursor

class HistoryValidData:
    """Iterable view over a History's valid data.

    Iteration uses the history's _vdbegin()/_vdend() pair; iterator()
    starts from a given date through _valid_data_iterator().
    """
    def __init__(self,history):
        self.history = history
    def __iter__(self):
        h = self.history
        cursor = h._vdbegin()
        cursor.end = h._vdend()
        return cursor
    def iterator(self,date=None):
        """Return a valid-data iterator starting at date.

        If date is omitted (or false), the history's firstDate() is used.
        """
        h = self.history
        if not date:
            date = h.firstDate()
        cursor = h._valid_data_iterator(date)
        cursor.end = h._vdend()
        return cursor

def History_new_iterator(self,date=None):
    """Return an iterator over all entries starting at date.

    If date is omitted (or false), firstDate() is used. The iterator's end
    marker is taken from _end().
    """
    if not date:
        date = self.firstDate()
    cursor = self._iterator(date)
    cursor.end = self._end()
    return cursor
def History_new___iter__(self):
    """Return an iterator over all entries, from _begin() to _end()."""
    cursor = self._begin()
    cursor.end = self._end()
    return cursor
def History_new_valid(self):
    """Return a HistoryValidEntries view of this history."""
    return HistoryValidEntries(self)
def History_new_data(self):
    """Return a HistoryData view of this history."""
    return HistoryData(self)
def History_new_validData(self):
    """Return a HistoryValidData view of this history."""
    return HistoryValidData(self)
History.__iter__ = History_new___iter__
History.iterator = History_new_iterator
History.valid = History_new_valid
# The data and validData views are only installed on Python 2.2 and later.
if sys.hexversion > 0x02020000:
    History.validData = History_new_validData
    History.data = History_new_data


# Instruments

def PlainOption_new___init__(self,type,underlying,strike,dividendYield,
                             riskFreeRate,exerciseDate,volatility,engine,
                             isinCode="",description=""):
    """Construct a PlainOption via the original constructor.

    isinCode and description become optional and default to "".
    """
    self._old___init__(type,underlying,strike,
                       dividendYield,riskFreeRate,
                       exerciseDate,volatility,engine,
                       isinCode,description)
PlainOption._old___init__ = PlainOption.__init__
PlainOption.__init__ = PlainOption_new___init__

def Stock_new___init__(self,price,isinCode="",description=""):
    """Construct a Stock via the original constructor.

    isinCode and description become optional and default to "".
    """
    self._old___init__(price,
                       isinCode,description)
Stock._old___init__ = Stock.__init__
Stock.__init__ = Stock_new___init__

def Swap_new___init__(self,firstLeg,secondLeg,termStructure,
                      isinCode="",description="interest rate swap"):
    """Construct a Swap via the original constructor.

    isinCode defaults to "" and description to "interest rate swap".
    """
    self._old___init__(firstLeg,secondLeg,termStructure,
                       isinCode,description)
Swap._old___init__ = Swap.__init__
Swap.__init__ = Swap_new___init__

def SimpleSwap_new___init__(self,payFixedRate,startDate,n,unit,
                            calendar,rollingConvention,nominal,
                            fixedFrequency,fixedRate,fixedIsAdjusted,
                            fixedDayCount,floatingFrequency,index,
                            indexFixingDays,spread,termStructure,
                            isinCode="",description="interest rate swap"):
    """Construct a SimpleSwap via the original constructor.

    isinCode defaults to "" and description to "interest rate swap"; all
    other arguments are forwarded positionally in the same order.
    """
    self._old___init__(payFixedRate,startDate,n,unit,
                       calendar,rollingConvention,nominal,
                       fixedFrequency,fixedRate,fixedIsAdjusted,
                       fixedDayCount,floatingFrequency,index,
                       indexFixingDays,spread,termStructure,
                       isinCode,description)
SimpleSwap._old___init__ = SimpleSwap.__init__
SimpleSwap.__init__ = SimpleSwap_new___init__


# Market elements
MarketElement._old___init__ = MarketElement.__init__
def PyMarketElement_notifyObservers(self):
    """Forward notifyObservers() from a Python market element to its
    wrapper."""
    self._wrapper.notifyObservers()
def MarketElement_new___init__(self,pyMarketElement):
    """Wrap a Python-implemented market element.

    After the original constructor runs:
    - pyMarketElement._wrapper points back to this wrapper.
    - pyMarketElement.notifyObservers() forwards to the wrapper's
      notifyObservers().
    - Every attribute of pyMarketElement's class that MarketElement does
      not already have is bound to pyMarketElement and exposed on the
      wrapper.
    """
    self._old___init__(pyMarketElement)
    pyMarketElement._wrapper = self
    pyMarketElement.notifyObservers = \
        new.instancemethod(PyMarketElement_notifyObservers,pyMarketElement,
            pyMarketElement.__class__)
    # proxy pyMarketElement methods
    c = pyMarketElement.__class__
    # computed once instead of on every iteration
    known = dir(MarketElement)
    for method in dir(c):
        if method in known:
            continue
        try:
            # plain setattr/getattr instead of eval'ing generated source
            setattr(self, method,
                    new.instancemethod(getattr(c, method), pyMarketElement, c))
        except Exception:
            # Best effort: attributes that cannot be bound as methods
            # (e.g. non-callable class data) are simply not proxied.
            pass
MarketElement.__init__ = MarketElement_new___init__


MarketElementHandle._old___init__ = MarketElementHandle.__init__
MarketElementHandle._old_linkTo = MarketElementHandle.linkTo
def MarketElementHandle_new___init__(self,h=None):
    """Construct the handle (optionally linked to h).

    The link is also remembered in currentLink so that attribute lookups
    can be forwarded to it.
    """
    self._old___init__(h)
    self.currentLink = h
def MarketElementHandle_new_linkTo(self,h):
    """Relink the handle to h and update currentLink accordingly."""
    self._old_linkTo(h)
    self.currentLink = h
def MarketElementHandle___getattr__(self,attr):
    """Forward lookups of attributes the handle lacks to currentLink.

    __getattr__ only runs for names that normal lookup did not find. So if
    'currentLink' itself is missing (e.g. before __init__ has stored it),
    looking it up through self would re-enter this method without bound.
    Report it as a plain AttributeError instead.
    """
    if attr == 'currentLink':
        raise AttributeError(attr)
    return getattr(self.currentLink,attr)
MarketElementHandle.__init__ = MarketElementHandle_new___init__
MarketElementHandle.linkTo = MarketElementHandle_new_linkTo
MarketElementHandle.__getattr__ = MarketElementHandle___getattr__

# Scheduler

def Scheduler_new___init__(self,calendar,startDate,endDate,frequency,
                           convention,isAdjusted,stubDate=None):
    """Construct a Scheduler via the original constructor.

    stubDate becomes optional and is passed through as None when omitted.
    """
    self._old___init__(calendar,startDate,endDate,
                       frequency,convention,isAdjusted,
                       stubDate)
Scheduler._old___init__ = Scheduler.__init__
Scheduler.__init__ = Scheduler_new___init__

# PiecewiseFlatForward

def PiecewiseFlatForward_new___init__(self,currency,dayCounter,todaysDate,
                                      calendar,settlementDays,instruments,
                                      accuracy=1.0e-12):
    """Construct a PiecewiseFlatForward via the original constructor.

    accuracy becomes optional and defaults to 1.0e-12.
    """
    self._old___init__(currency,dayCounter,todaysDate,
                       calendar,settlementDays,instruments,
                       accuracy)
PiecewiseFlatForward._old___init__ = PiecewiseFlatForward.__init__
PiecewiseFlatForward.__init__ = PiecewiseFlatForward_new___init__


# Term structures

def TermStructure_new_zeroYield(self,arg,extrapolate=0):
    """Return the zero yield at arg, which is a time or a Date.

    An int, long or float arg is passed to _zeroYieldVsTime; anything else
    goes to _zeroYieldVsDate. extrapolate is forwarded unchanged.
    """
    # LongType included so Python long times are not sent to the Date
    # overload.
    if type(arg) in (types.FloatType, types.IntType, types.LongType):
        return self._zeroYieldVsTime(arg,extrapolate)
    return self._zeroYieldVsDate(arg,extrapolate)
def TermStructure_new_discount(self,arg,extrapolate=0):
    """Return the discount factor at arg, which is a time or a Date.

    An int, long or float arg is passed to _discountVsTime; anything else
    goes to _discountVsDate. extrapolate is forwarded unchanged.
    """
    if type(arg) in (types.FloatType, types.IntType, types.LongType):
        return self._discountVsTime(arg,extrapolate)
    return self._discountVsDate(arg,extrapolate)
def TermStructure_new_forward(self,arg,extrapolate=0):
    """Return the forward rate at arg, which is a time or a Date.

    An int, long or float arg is passed to _forwardVsTime; anything else
    goes to _forwardVsDate. extrapolate is forwarded unchanged.
    """
    if type(arg) in (types.FloatType, types.IntType, types.LongType):
        return self._forwardVsTime(arg,extrapolate)
    return self._forwardVsDate(arg,extrapolate)
TermStructure.zeroYield = TermStructure_new_zeroYield
TermStructure.discount = TermStructure_new_discount
TermStructure.forward = TermStructure_new_forward


TermStructureHandle._old___init__ = TermStructureHandle.__init__
TermStructureHandle._old_linkTo = TermStructureHandle.linkTo
def TermStructureHandle_new___init__(self,h=None):
    """Construct the handle (optionally linked to h).

    The link is also remembered in currentLink so that attribute lookups
    can be forwarded to it.
    """
    self._old___init__(h)
    self.currentLink = h
def TermStructureHandle_new_linkTo(self,h):
    """Relink the handle to h and update currentLink accordingly."""
    self._old_linkTo(h)
    self.currentLink = h
def TermStructureHandle___getattr__(self,attr):
    """Forward lookups of attributes the handle lacks to currentLink.

    __getattr__ only runs for names that normal lookup did not find. So if
    'currentLink' itself is missing (e.g. before __init__ has stored it),
    looking it up through self would re-enter this method without bound.
    Report it as a plain AttributeError instead.
    """
    if attr == 'currentLink':
        raise AttributeError(attr)
    return getattr(self.currentLink,attr)
TermStructureHandle.__init__ = TermStructureHandle_new___init__
TermStructureHandle.linkTo = TermStructureHandle_new_linkTo
TermStructureHandle.__getattr__ = TermStructureHandle___getattr__


# Volatilities
def SwaptionVolatilityStructure_new_volatility(self,start,length,rate):
    """Return the swaption volatility for the given start, length and rate.

    An int, long or float start is passed to _volatilityVsTime; anything
    else (presumably a Date) goes to _volatilityVsDate.
    """
    # LongType included so Python long times are not sent to the Date
    # overload.
    if type(start) in (types.FloatType, types.IntType, types.LongType):
        return self._volatilityVsTime(start,length,rate)
    return self._volatilityVsDate(start,length,rate)
SwaptionVolatilityStructure.volatility = \
    SwaptionVolatilityStructure_new_volatility


SwaptionVolatilityStructureHandle._old___init__ = \
    SwaptionVolatilityStructureHandle.__init__
SwaptionVolatilityStructureHandle._old_linkTo = \
    SwaptionVolatilityStructureHandle.linkTo
def SwaptionVolatilityStructureHandle_new___init__(self,h=None):
    """Construct the handle (optionally linked to h).

    The link is also remembered in currentLink so that attribute lookups
    can be forwarded to it.
    """
    self._old___init__(h)
    self.currentLink = h
def SwaptionVolatilityStructureHandle_new_linkTo(self,h):
    """Relink the handle to h and update currentLink accordingly."""
    self._old_linkTo(h)
    self.currentLink = h
def SwaptionVolatilityStructureHandle___getattr__(self,attr):
    """Forward lookups of attributes the handle lacks to currentLink.

    __getattr__ only runs for names that normal lookup did not find. So if
    'currentLink' itself is missing (e.g. before __init__ has stored it),
    looking it up through self would re-enter this method without bound.
    Report it as a plain AttributeError instead.
    """
    if attr == 'currentLink':
        raise AttributeError(attr)
    return getattr(self.currentLink,attr)
SwaptionVolatilityStructureHandle.__init__ = \
    SwaptionVolatilityStructureHandle_new___init__
SwaptionVolatilityStructureHandle.linkTo = \
    SwaptionVolatilityStructureHandle_new_linkTo
SwaptionVolatilityStructureHandle.__getattr__ = \
    SwaptionVolatilityStructureHandle___getattr__


def CapFlatVolatilityStructure_new_volatility(self,end,rate):
    """Return the flat cap volatility for the given end and rate.

    An int, long or float end is passed to _volatilityVsTime; anything
    else (presumably a Date) goes to _volatilityVsDate.
    """
    # LongType included so Python long times are not sent to the Date
    # overload.
    if type(end) in (types.FloatType, types.IntType, types.LongType):
        return self._volatilityVsTime(end,rate)
    return self._volatilityVsDate(end,rate)
CapFlatVolatilityStructure.volatility = \
    CapFlatVolatilityStructure_new_volatility


CapFlatVolatilityStructureHandle._old___init__ = \
    CapFlatVolatilityStructureHandle.__init__
CapFlatVolatilityStructureHandle._old_linkTo = \
    CapFlatVolatilityStructureHandle.linkTo
def CapFlatVolatilityStructureHandle_new___init__(self,h=None):
    """Construct the handle (optionally linked to h).

    The link is also remembered in currentLink so that attribute lookups
    can be forwarded to it.
    """
    self._old___init__(h)
    self.currentLink = h
def CapFlatVolatilityStructureHandle_new_linkTo(self,h):
    """Relink the handle to h and update currentLink accordingly."""
    self._old_linkTo(h)
    self.currentLink = h
def CapFlatVolatilityStructureHandle___getattr__(self,attr):
    """Forward lookups of attributes the handle lacks to currentLink.

    __getattr__ only runs for names that normal lookup did not find. So if
    'currentLink' itself is missing (e.g. before __init__ has stored it),
    looking it up through self would re-enter this method without bound.
    Report it as a plain AttributeError instead.
    """
    if attr == 'currentLink':
        raise AttributeError(attr)
    return getattr(self.currentLink,attr)
CapFlatVolatilityStructureHandle.__init__ = \
    CapFlatVolatilityStructureHandle_new___init__
CapFlatVolatilityStructureHandle.linkTo = \
    CapFlatVolatilityStructureHandle_new_linkTo
CapFlatVolatilityStructureHandle.__getattr__ = \
    CapFlatVolatilityStructureHandle___getattr__


def SwaptionVolatilityMatrix_new___init__(self,today,dates,lengths,vols,
                                          dayCounter=DayCounter('30/360')):
    """Construct a SwaptionVolatilityMatrix via the original constructor.

    dayCounter becomes optional. The default is a single DayCounter('30/360')
    instance created when this function is defined and shared by all calls.
    """
    self._old___init__(today,dates,lengths,vols,
                       dayCounter)
SwaptionVolatilityMatrix._old___init__ = SwaptionVolatilityMatrix.__init__
SwaptionVolatilityMatrix.__init__ = SwaptionVolatilityMatrix_new___init__


def CapFlatVolatilityVector_new___init__(self,today,calendar,settlementDays,
                                         lengths,vols,
                                         dayCounter=DayCounter('30/360')):
    """Construct a CapFlatVolatilityVector via the original constructor.

    dayCounter becomes optional. The default is a single DayCounter('30/360')
    instance created when this function is defined and shared by all calls.
    """
    self._old___init__(today,calendar,settlementDays,
                       lengths,vols,
                       dayCounter)
CapFlatVolatilityVector._old___init__ = CapFlatVolatilityVector.__init__
CapFlatVolatilityVector.__init__ = CapFlatVolatilityVector_new___init__

