File: sankey_demo_old.py

package info (click to toggle)
matplotlib 1.1.1~rc2-1
  • links: PTS, VCS
  • area: main
  • in suites: wheezy
  • size: 66,076 kB
  • sloc: python: 90,600; cpp: 69,891; objc: 5,231; ansic: 1,723; makefile: 171; sh: 7
file content (189 lines) | stat: -rw-r--r-- 6,904 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
#!/usr/bin/env python

__author__ = "Yannick Copin <ycopin@ipnl.in2p3.fr>"
__version__ = "Time-stamp: <10/02/2010 16:49 ycopin@lyopc548.in2p3.fr>"

import numpy as N

def sankey(ax,
           outputs=(100.,), outlabels=None,
           inputs=(100.,), inlabels='',
           dx=40, dy=10, outangle=45, w=3, inangle=30, offset=2,
           debug=False, **kwargs):
    """Draw a Sankey diagram on the axes *ax*.

    outputs: sequence of outputs in percent; their absolute values must sum
        up to 100 (within floating-point tolerance). For every output but
        the last, the sign selects the arrow direction (>0: upward arrow,
        <0: downward arrow); the last output is a horizontal arrow.
    outlabels: output labels (same length as outputs), or None (default
        '%2d%%' labels) or '' (no labels)
    inputs and inlabels: similar for inputs; here it is the *first* input
        that enters horizontally, the others enter vertically on the upper
        (>0) or lower (<0) side.
    dx: horizontal elongation
    dy: vertical elongation
    outangle: output arrow angle [deg]
    w: output arrow shoulder
    inangle: input dip angle [deg]
    offset: text offset
    debug: if True, print the four sub-paths and overplot the path vertices
    **kwargs: propagated to Patch (e.g. fill=False)

    Return (patch,[intexts,outtexts]).

    Raise ValueError if outputs or inputs do not sum up to 100%, or if
    explicit labels do not have the same length as their values.
    """

    # Work in floats: with integer values, loss/2 and gain/2 below would
    # be truncated by Python 2 integer division.
    outputs = N.asarray(outputs, dtype=float)
    inputs = N.asarray(inputs, dtype=float)
    outs = N.absolute(outputs)
    ins = N.absolute(inputs)

    # Validate first: empty sequences sum to 0 and are rejected here instead
    # of failing on outsigns[-1] below. Compare with a tolerance, because
    # summing float percentages can miss 100 by a rounding error.
    if not N.allclose(outs.sum(), 100.):
        raise ValueError("Outputs don't sum up to 100%")
    if not N.allclose(ins.sum(), 100.):
        raise ValueError("Inputs don't sum up to 100%")

    import matplotlib.patches as mpatches
    from matplotlib.path import Path

    outsigns = N.sign(outputs)
    outsigns[-1] = 0 # Last output: horizontal arrow

    insigns = N.sign(inputs)
    insigns[0] = 0 # First input: horizontal

    def add_output(path, loss, sign=1):
        """Append an output arrow of width *loss* to *path* (in place) and
        record its tip in outtips."""
        h = (loss/2+w)*N.tan(outangle/180.*N.pi) # Arrow tip height
        move,(x,y) = path[-1] # Use last point as reference
        if sign==0: # Final loss (horizontal)
            path.extend([(Path.LINETO,[x+dx,y]),
                         (Path.LINETO,[x+dx,y+w]),
                         (Path.LINETO,[x+dx+h,y-loss/2]), # Tip
                         (Path.LINETO,[x+dx,y-loss-w]),
                         (Path.LINETO,[x+dx,y-loss])])
            outtips.append((sign,path[-3][1]))
        else: # Intermediate loss (vertical)
            path.extend([(Path.CURVE4,[x+dx/2.,y]),
                         (Path.CURVE4,[x+dx,y]),
                         (Path.CURVE4,[x+dx,y+sign*dy]),
                         (Path.LINETO,[x+dx-w,y+sign*dy]),
                         (Path.LINETO,[x+dx+loss/2,y+sign*(dy+h)]), # Tip
                         (Path.LINETO,[x+dx+loss+w,y+sign*dy]),
                         (Path.LINETO,[x+dx+loss,y+sign*dy]),
                         (Path.CURVE3,[x+dx+loss,y-sign*loss]),
                         (Path.CURVE3,[x+dx/2.+loss,y-sign*loss])])
            outtips.append((sign,path[-5][1]))

    def add_input(path, gain, sign=1):
        """Append an input inlet of width *gain* to *path* (in place) and
        record its dip in indips."""
        h = (gain/2)*N.tan(inangle/180.*N.pi) # Dip depth
        move,(x,y) = path[-1] # Use last point as reference
        if sign==0: # First gain (horizontal)
            path.extend([(Path.LINETO,[x-dx,y]),
                         (Path.LINETO,[x-dx+h,y+gain/2]), # Dip
                         (Path.LINETO,[x-dx,y+gain])])
            xd,yd = path[-2][1] # Dip position
            indips.append((sign,[xd-h,yd]))
        else: # Intermediate gain (vertical)
            path.extend([(Path.CURVE4,[x-dx/2.,y]),
                         (Path.CURVE4,[x-dx,y]),
                         (Path.CURVE4,[x-dx,y+sign*dy]),
                         (Path.LINETO,[x-dx-gain/2,y+sign*(dy-h)]), # Dip
                         (Path.LINETO,[x-dx-gain,y+sign*dy]),
                         (Path.CURVE3,[x-dx-gain,y-sign*gain]),
                         (Path.CURVE3,[x-dx/2.-gain,y-sign*gain])])
            xd,yd = path[-4][1] # Dip position
            indips.append((sign,[xd,yd+sign*h]))

    outtips = [] # Output arrow tip directions and positions
    urpath = [(Path.MOVETO,[0,100])] # 1st point of upper right path
    lrpath = [(Path.LINETO,[0,0])] # 1st point of lower right path
    for loss,sign in zip(outs,outsigns):
        # Upward arrows and the final horizontal one hang off the upper
        # right path; downward arrows off the lower right one.
        add_output(urpath if sign>=0 else lrpath, loss, sign=sign)

    indips = [] # Input dip directions and positions
    llpath = [(Path.LINETO,[0,0])] # 1st point of lower left path
    ulpath = [(Path.MOVETO,[0,100])] # 1st point of upper left path
    # Inputs are processed last-to-first, so indips ends up in reverse input
    # order (list() keeps the reversal working where zip is an iterator).
    for gain,sign in list(zip(ins,insigns))[::-1]:
        add_input(llpath if sign<=0 else ulpath, gain, sign=sign)

    def revert(path):
        """Return *path* traversed backwards.

        A path cannot simply be reversed with path[::-1] because each code
        describes how to reach its own vertex (Bezier control points), so
        codes are shifted by one while reversing.
        """
        rpath = []
        nextmove = Path.LINETO
        for move,pos in path[::-1]:
            rpath.append((nextmove,pos))
            nextmove = move
        return rpath

    # Concatenate subpaths in correct order
    path = urpath + revert(lrpath) + llpath + revert(ulpath)

    codes,verts = zip(*path)
    verts = N.array(verts)

    # Path patch
    path = Path(verts,codes)
    patch = mpatches.PathPatch(path, **kwargs)
    ax.add_patch(patch)

    if debug:
        # Parenthesized single-argument prints behave the same on Py2 and Py3
        print("urpath %s" % urpath)
        print("lrpath %s" % revert(lrpath))
        print("llpath %s" % llpath)
        print("ulpath %s" % revert(ulpath))

        xs,ys = zip(*verts)
        ax.plot(xs,ys,'go-')

    # Labels

    def set_labels(labels,values):
        """Return the labels to draw for *values*: '' (no labels), default
        percentage labels when None, or the given labels after a length
        check."""
        if labels=='': # No labels
            return labels
        elif labels is None: # Default labels
            return [ '%2d%%' % val for val in values ]
        if len(labels)!=len(values):
            raise ValueError("Expected %d labels, got %d"
                             % (len(values), len(labels)))
        return labels

    def put_labels(labels,positions,output=True):
        """Draw each label next to its (direction, position) entry and
        return the created Text objects."""
        texts = []
        # Input dips were collected in reverse input order (see the input
        # loop above), so input labels are reversed to line up with them.
        lbls = labels if output else labels[::-1]
        for label,(s,(x,y)) in zip(lbls,positions):
            if s==0: # Horizontal arrow: label beside it
                t = ax.text(x+offset,y,label,
                            ha='left' if output else 'right', va='center')
            elif s>0: # Upward: label above
                t = ax.text(x,y+offset,label, ha='center', va='bottom')
            else: # Downward: label below
                t = ax.text(x,y-offset,label, ha='center', va='top')
            texts.append(t)
        return texts

    outlabels = set_labels(outlabels, outs)
    outtexts = put_labels(outlabels, outtips, output=True)

    inlabels = set_labels(inlabels, ins)
    intexts = put_labels(inlabels, indips, output=False)

    # Axes management
    ax.set_xlim(verts[:,0].min()-dx, verts[:,0].max()+dx)
    ax.set_ylim(verts[:,1].min()-dy, verts[:,1].max()+dy)
    ax.set_aspect('equal', adjustable='datalim')

    return patch,[intexts,outtexts]

if __name__=='__main__':

    import matplotlib.pyplot as P

    # Output percentages; negative values produce downward arrows.
    outputs = [10.,-20.,5.,15.,-10.,40.]
    names = ['First','Second','Third','Fourth','Fifth','Hurray!']

    # Each label is the output's name followed by its magnitude in percent.
    outlabels = []
    for name, value in zip(names, outputs):
        outlabels.append(name + '\n%d%%' % abs(value))

    inputs = [60.,-25.,15.]

    fig = P.figure()
    ax = fig.add_subplot(1, 1, 1, xticks=[], yticks=[],
                         title="Sankey diagram")

    # Inputs get the default percentage labels (inlabels=None).
    result = sankey(ax, outputs=outputs, outlabels=outlabels,
                    inputs=inputs, inlabels=None,
                    fc='g', alpha=0.2)
    patch, (intexts, outtexts) = result

    # Emphasize the second output in red and the final one in bold.
    second, final = outtexts[1], outtexts[-1]
    second.set_color('r')
    final.set_fontweight('bold')

    P.show()