1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
|
# Sentinel endpoints used for intervals that are unbounded on one side.
INFINITY = float("inf")
NEGATIVE_INFINITY = float("-inf")
class IntervalSet:
    """A collection of intervals supporting set-style operations.

    Attributes:
        intervals: the member intervals. Unless ``disjoint`` is true at
            construction time, adjacent members that overlap are merged
            with ``union_overlapping``.
        size: the sum of the ``size`` attribute of every member.

    ``complement`` walks the members in stored order, so callers passing
    ``disjoint=True`` should supply intervals already ordered by start.
    """

    def __init__(self, intervals, disjoint=False):
        """Store ``intervals``, merging overlapping neighbours unless
        the caller states they are already ``disjoint``."""
        self.intervals = intervals
        if not disjoint:
            self.intervals = union_overlapping(self.intervals)
        self.size = sum(i.size for i in self.intervals)

    def __repr__(self):
        return repr(self.intervals)

    def __iter__(self):
        return iter(self.intervals)

    def __len__(self):
        """Return the number of member intervals (not the covered size)."""
        return len(self.intervals)

    def __getitem__(self, i):
        return self.intervals[i]

    def __nonzero__(self):  # Python 2
        """A set is truthy only when it covers a non-zero size."""
        return self.size != 0

    # Python 3 looks up __bool__; without it truthiness would fall back to
    # __len__, making a set of zero-size intervals truthy.
    __bool__ = __nonzero__

    def __sub__(self, other):
        """Return the part of this set not covered by ``other``."""
        return self.intersect(other.complement())

    def complement(self):
        """Return the set of gaps between negative and positive infinity
        that are not covered by the members of this set."""
        complementary = []
        cursor = NEGATIVE_INFINITY
        for interval in self.intervals:
            # Only emit a gap when there is space before this member.
            if cursor < interval.start:
                complementary.append(Interval(cursor, interval.start))
            cursor = interval.end
        if cursor < INFINITY:
            complementary.append(Interval(cursor, INFINITY))
        return IntervalSet(complementary, disjoint=True)

    def intersect(self, other):
        """Return the set of non-empty pairwise intersections with ``other``.

        Every member of this set is compared with every member of
        ``other``, so the cost grows with ``len(self) * len(other)``.
        """
        if (not self) or (not other):
            return IntervalSet([], disjoint=True)
        intersections = []
        for mine in self.intervals:
            for theirs in other.intervals:
                common = mine.intersect(theirs)
                # Interval.intersect yields a falsy value when there is
                # no overlap; those are dropped.
                if common:
                    intersections.append(common)
        return IntervalSet(intersections, disjoint=True)

    def intersect_interval(self, interval):
        """Return the set of non-empty intersections of each member with
        the single ``interval``."""
        intersections = []
        for member in self.intervals:
            common = member.intersect(interval)
            if common:
                intersections.append(common)
        return IntervalSet(intersections, disjoint=True)

    def union(self, other):
        """Return a new set covering everything in this set or ``other``.

        The combined members are ordered by ``start`` explicitly so the
        merge pass in the constructor sees them in ascending order.
        """
        combined = sorted(self.intervals + other.intervals,
                          key=lambda interval: interval.start)
        return IntervalSet(combined)
class Interval:
    """A span running from ``start`` to ``end``.

    Attributes:
        start: lower endpoint.
        end: upper endpoint; must not be less than ``start``.
        tuple: ``(start, end)``, used for equality and hashing.
        size: ``end - start``.

    Ordering comparisons look only at ``start``; equality and hashing use
    both endpoints.
    """

    def __init__(self, start, end):
        """Create the interval.

        Raises:
            ValueError: if ``end - start`` is negative.
        """
        if end - start < 0:
            raise ValueError("Invalid interval start=%s end=%s" % (start, end))
        self.start = start
        self.end = end
        self.tuple = (start, end)
        self.size = self.end - self.start

    def __eq__(self, other):
        # Defer to the other operand instead of failing on a missing
        # ``tuple`` attribute when compared with a non-Interval.
        if not isinstance(other, Interval):
            return NotImplemented
        return self.tuple == other.tuple

    def __ne__(self, other):
        if not isinstance(other, Interval):
            return NotImplemented
        return self.tuple != other.tuple

    def __hash__(self):
        return hash(self.tuple)

    # Rich comparisons order intervals by their start point.
    def __lt__(self, other):
        return self.start < other.start

    def __le__(self, other):
        return self.start <= other.start

    def __gt__(self, other):
        return self.start > other.start

    def __ge__(self, other):
        return self.start >= other.start

    def __cmp__(self, other):  # Python 2
        return (self.start > other.start) - (self.start < other.start)

    def __len__(self):
        """Always raise; use the ``size`` attribute instead."""
        raise TypeError("len() doesn't support infinite values, use the 'size' attribute instead")

    def __nonzero__(self):  # Python 2
        return self.size != 0

    def __bool__(self):  # Python 3
        return self.size != 0

    def __repr__(self):
        return '<Interval: %s>' % str(self.tuple)

    def intersect(self, other):
        """Return the overlap with ``other`` as a new Interval.

        Returns ``None`` when the two intervals share no span of
        positive length (including when they merely touch).
        """
        start = max(self.start, other.start)
        end = min(self.end, other.end)
        if end > start:
            return Interval(start, end)
        return None

    def overlaps(self, other):
        """Return True when the intervals overlap or touch end-to-start."""
        earlier = self if self.start <= other.start else other
        later = self if earlier is other else other
        return earlier.end >= later.start

    def union(self, other):
        """Return the single interval spanning both operands.

        Raises:
            TypeError: if the intervals do not overlap or touch.
        """
        if not self.overlaps(other):
            raise TypeError("Union of disjoint intervals is not an interval")
        start = min(self.start, other.start)
        end = max(self.end, other.end)
        return Interval(start, end)
def union_overlapping(intervals):
    """Merge each interval into its predecessor when the two overlap.

    Intervals are visited in the order given; each one is compared only
    with the most recently produced result, and is either folded into it
    via ``union`` or appended as a new entry. Returns a new list.
    """
    merged = []
    for candidate in intervals:
        if not merged:
            # Nothing to compare against yet.
            merged.append(candidate)
            continue
        previous = merged[-1]
        if previous.overlaps(candidate):
            merged[-1] = previous.union(candidate)
        else:
            merged.append(candidate)
    return merged
|