"""A couple of helper descriptors to allow to use the same code as query
criterion creators and as instance code. As this doesn't do advanced
magic recompiling, you can only use basic expression-like code."""

import new

class MethodDescriptor(object):
    def __init__(self, func):
        self.func = func
    def __get__(self, instance, owner):
        if instance is None:
            return new.instancemethod(self.func, owner, owner.__class__)
        else:
            return new.instancemethod(self.func, instance, owner)

class PropertyDescriptor(object):
    def __init__(self, fget, fset, fdel):
        self.fget = fget
        self.fset = fset
        self.fdel = fdel
    def __get__(self, instance, owner):
        if instance is None:
            return self.fget(owner)
        else:
            return self.fget(instance)
    def __set__(self, instance, value):
        self.fset(instance, value)
    def __delete__(self, instance):
        self.fdel(instance)
        
def hybrid(func):
    return MethodDescriptor(func)

def hybrid_property(fget, fset=None, fdel=None):
    return PropertyDescriptor(fget, fset, fdel)

### Example code

from sqlalchemy import *
from sqlalchemy.orm import *

metadata = MetaData('sqlite://')
metadata.bind.echo = True

print "Set up database metadata"

interval_table1 = Table('interval1', metadata,
    Column('id', Integer, primary_key=True),
    Column('start', Integer, nullable=False),
    Column('end', Integer, nullable=False))

interval_table2 = Table('interval2', metadata,
    Column('id', Integer, primary_key=True),
    Column('start', Integer, nullable=False),
    Column('length', Integer, nullable=False))

metadata.create_all()

# A base class for intervals

class BaseInterval(object):
    @hybrid
    def contains(self,point):
        return (self.start <= point) & (point < self.end)
    
    @hybrid
    def intersects(self, other):
        return (self.start < other.end) & (self.end > other.start)

    def __repr__(self):
        return "%s(%s..%s)" % (self.__class__.__name__, self.start, self.end)

# Interval stored as endpoints

class Interval1(BaseInterval):
    def __init__(self, start, end):
        self.start = start
        self.end = end
    
    length = hybrid_property(lambda s: s.end - s.start)

mapper(Interval1, interval_table1)

# Interval stored as start and length

class Interval2(BaseInterval):
    def __init__(self, start, length):
        self.start = start
        self.length = length
    
    end = hybrid_property(lambda s: s.start + s.length)

mapper(Interval2, interval_table2)

print "Create the data"

session = create_session()

intervals = [Interval1(1,4), Interval1(3,15), Interval1(11,16)]

for interval in intervals:
    session.save(interval)
    session.save(Interval2(interval.start, interval.length))

session.flush()

print "Clear the cache and do some queries"

session.clear()

for Interval in (Interval1, Interval2):
    print "Querying using interval class %s" % Interval.__name__
    
    print
    print '-- length less than 10'
    print [(i, i.length) for i in session.query(Interval).filter(Interval.length < 10).all()]
    
    print
    print '-- contains 12'
    print session.query(Interval).filter(Interval.contains(12)).all()
    
    print
    print '-- intersects 2..10'
    other = Interval1(2,10)
    result = session.query(Interval).filter(Interval.intersects(other)).order_by(Interval.length).all()
    print [(interval, interval.intersects(other)) for interval in result]

