File: interface.py

package info (click to toggle)
python-scipy 0.6.0-12
  • links: PTS, VCS
  • area: main
  • in suites: lenny
  • size: 32,016 kB
  • ctags: 46,675
  • sloc: cpp: 124,854; ansic: 110,614; python: 108,664; fortran: 76,260; objc: 424; makefile: 384; sh: 10
file content (519 lines) | stat: -rw-r--r-- 14,857 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
from numpy.core.umath import *
from scipy import *
from numpy import limits, display_test
import sys

# Decide whether the wxPython backend can be used.  wxPython is imported
# unless display_test reports X11 in use but no display can be opened
# (NOTE(review): meaning of have_x11()/try_XOpenDisplay() assumed from their
# names -- confirm).  An ImportError is printed to stdout and leaves
# _have_wx at 0.
_have_wx = 0
if not display_test.have_x11() or display_test.try_XOpenDisplay():
    try:
        import wxPython
        _have_wx = 1
    except ImportError,msg:
        print __file__,msg

# Wire up the wx backend: plot_module.line_object is what lines_from_group
# uses to build lines, and plot_class (wxplt.plot_frame registered through
# gui_thread) is what figure() instantiates for each new window.
# NOTE(review): when _have_wx is 0 none of these names are defined, so
# figure(), plot(), etc. fail with NameError instead of a clear message.
if _have_wx:
    import wxplt
    import plot_objects
    import gui_thread
    plot_module = wxplt
    plot_class = gui_thread.register(plot_module.plot_frame)

# Figure registry: every figure created or adopted by figure(), and the one
# the plotting commands currently draw into (None until one exists).
_figure = []
_active = None


def figure(which_one = None):
    """Create a new figure or make an existing one the active figure.

    which_one may be:
      * None -- create a new plot window titled 'Figure <n>' and make it
        active;
      * an int or float -- select (and raise) the figure at that index of
        the figure list; IndexError if there is no such figure;
      * a figure already in the list -- make it active and raise it;
      * an object whose __type_hack__ is "plot_canvas" -- add it to the
        figure list, make it active and raise it.
    An object whose __type_hack__ has any other value raises ValueError.
    Objects without a __type_hack__ attribute are ignored, and the active
    figure is left unchanged.

    Returns the active figure (the value of current()).
    """
    global _figure; global _active
    if which_one is None:
        caption = 'Figure %d' % len(_figure)
        _figure.append(plot_class(title=caption))
        _active = _figure[-1]
    elif isinstance(which_one, (int, float)):
        try:
            selected = _figure[int(which_one)]
        except IndexError:
            msg = "There are currently only %d active figures" % len(_figure)
            raise IndexError(msg)
        _active = selected
        _active.Raise()
    elif which_one in _figure:
        _active = which_one
        _active.Raise()
    else:
        # Only the attribute lookup is guarded.  The previous handler named
        # a misspelled exception class, so every error here became a
        # NameError.
        try:
            type_tag = which_one.__type_hack__
        except AttributeError:
            # Not a plot object: leave the active figure unchanged.
            pass
        else:
            if type_tag != "plot_canvas":
                raise ValueError("The specified figure or index is not known")
            _active = which_one
            _figure.append(_active)
            _active.Raise()
    return current()


def validate_active():
    """Make sure there is a live active figure, creating one when needed.

    A new figure is created when there is no active figure, or when the
    active figure's proxy_object_alive flag is false.  If the flag cannot
    be read, the figure is assumed to be alive.  Errors raised while
    creating the replacement figure are no longer swallowed; previously
    they left _active as None.
    """
    global _active
    if _active is None:
        figure()
    try:
        alive = _active.proxy_object_alive
    except Exception:
        # Best effort: not every figure object exposes (or can report)
        # this flag, so treat it as still usable.
        alive = 1
    if not alive:
        _active = None
        figure()

def current():
    """Return the active figure, or None if no figure is active."""
    active_figure = _active
    return active_figure

def redraw():
    """Redraw the active figure, creating one first if necessary."""
    validate_active()
    fig = _active
    fig.redraw()

def close(which_one = None):
    """Close the active figure, or every figure.

    close() closes the active figure (if any) and removes it from the
    figure list.  The most recently created remaining figure then becomes
    active, or None if no figures remain.  close('all') closes every
    registered figure, empties the figure list and clears the active
    figure.  Closing any other figure is not supported yet and raises
    NotImplementedError.
    """
    global _figure; global _active
    if which_one is None:
        if _active is not None:
            _active.Close()
            try:
                _figure.remove(_active)
            except ValueError:
                # the active figure was never registered in the list
                pass
        # Fall back to the newest remaining figure; the global is assigned
        # here directly so the closed window does not stay active.
        # TODO: verify the chosen window still exists.
        if _figure:
            _active = _figure[-1]
        else:
            _active = None
    elif which_one == 'all':
        for fig in _figure:
            fig.Close()
        # Forget the closed windows so figure(n) cannot hand them back.
        _figure[:] = []
        _active = None
    else:
        raise NotImplementedError("currently close only works with"
                                  " _active window or 'all'")

def set_new_active():
    """Make the most recently created figure active (None if there are none).

    Does not check that the chosen figure's window still exists.
    """
    # Without this declaration the assignments below only bound a local
    # name and the function had no effect.
    global _active
    try:
        _active = _figure[-1]
    except IndexError:
        _active = None


def _auto_all():
    """Return both axes of the active figure to automatic bounds and ticks."""
    validate_active()
    both_axes = (_active.x_axis, _active.y_axis)
    # bounds first for both axes, then tick spacing (same order as before)
    for ax in both_axes:
        ax.bounds = ['auto','auto']
    for ax in both_axes:
        ax.tick_interval = 'auto'
def autoscale():
    """Switch the active figure to automatic axes and refresh it."""
    validate_active()
    _auto_all()
    fig = _active
    fig.update()

def _an_axis(ax,setting):
    ticks = ax.ticks
    interval = ax.ticks[1]- ax.ticks[0]
    if setting in ['normal','auto']:
        ax.bounds = ['auto','auto']
    elif setting == 'freeze':
        ax.bounds = [axes[0],axes[1]]
        ax.tick_interval = interval
    elif setting in ['tight','fit']:
        ax.bounds = ['fit','fit']
        ax.tick_interval = 'auto'
    else:
        ax.bounds = [setting[0],setting[1]]
        if len(setting) > 2:
            ax.tick_interval = setting[2]

def xaxis(rng):
    """Apply range setting rng (see _an_axis) to the active figure's x axis."""
    validate_active()
    fig = _active
    _an_axis(fig.x_axis, rng)
    fig.update()

def yaxis(rng):
    """Apply range setting rng (see _an_axis) to the active figure's y axis."""
    validate_active()
    fig = _active
    _an_axis(fig.y_axis, rng)
    fig.update()

def title(name):
    """Set the main title text of the active figure and refresh it."""
    validate_active()
    fig = _active
    fig.title.text = name
    fig.update()

def xtitle(name):
    """Set the x-axis title text of the active figure and refresh it."""
    validate_active()
    fig = _active
    fig.x_title.text = name
    fig.update()

def ytitle(name):
    """Set the y-axis title text of the active figure and refresh it."""
    validate_active()
    fig = _active
    fig.y_title.text = name
    fig.update()

# Convenience names so callers can write grid(on) / grid(off).
on = 'on'
off = 'off'
def grid(state=None):
    """Show or hide the grid lines of the active figure.

    With no argument the grid is toggled: it is hidden if the x axis grid
    is currently 'on'/'yes', otherwise shown.  Otherwise state must be
    'on', 'off', 'yes' or 'no' and is applied to both axes.  The argument
    is checked before the active figure is looked up, so a bad value
    raises ValueError without opening a window.
    """
    if state is not None and state not in ['on','off','yes','no']:
        # str() so that non-string arguments produce the intended
        # ValueError rather than a TypeError from string concatenation
        raise ValueError('grid argument can be "on","off",'
                         '"yes","no". Not ' + str(state))
    validate_active()
    if state is None:
        if _active.x_axis.grid_visible in ['on','yes']:
            state = 'off'
        else:
            state = 'on'
    _active.x_axis.grid_visible = state
    _active.y_axis.grid_visible = state
    _active.update()

def hold(state):
    """Set the hold state of the active figure.

    With 'on'/'yes', later plot()/image() calls add to the figure.  With
    'off'/'no', they clear its lines and images first.  Any other value
    raises ValueError; this is checked before a figure is looked up or
    created.
    """
    if state not in ['on','off','yes','no']:
        # str() so non-string arguments still produce a ValueError
        raise ValueError('hold argument can be "on","off",'
                         '"yes","no". Not ' + str(state))
    validate_active()
    _active.hold = state

def axis(setting):
    """Control the bounds, ticks and aspect ratio of the active figure.

    setting may be:
      'normal'        -- aspect ratio 'normal', automatic bounds and ticks;
      'equal'         -- aspect ratio 'equal', bounds left unchanged;
      'freeze'        -- pin both axes' bounds and tick spacing to the
                         ticks currently displayed;
      'tight' / 'fit' -- bounds set to 'fit', automatic tick spacing on
                         both axes;
      [xmin, xmax, ymin, ymax] -- explicit bounds.
    Any other string raises ValueError.  The figure is refreshed afterwards.
    """
    validate_active()
    # strings name a mode; anything else is an explicit bounds sequence
    is_name = isinstance(setting, str)
    # had to use client below cause of __setattr__ troubles in plot_frame
    if is_name and setting == 'normal':
        _active.client.aspect_ratio = setting
        _auto_all()
    elif is_name and setting == 'equal':
        _active.client.aspect_ratio = setting
    elif is_name and setting == 'freeze':
        # Ticks are only read for 'freeze'; they may not exist yet for
        # a freshly created figure.
        x_ticks = _active.x_axis.ticks
        y_ticks = _active.y_axis.ticks
        _active.x_axis.bounds = [float(x_ticks[0]), float(x_ticks[-1])]
        _active.y_axis.bounds = [float(y_ticks[0]), float(y_ticks[-1])]
        _active.x_axis.tick_interval = x_ticks[1] - x_ticks[0]
        _active.y_axis.tick_interval = y_ticks[1] - y_ticks[0]
    elif is_name and setting in ('tight', 'fit'):
        _active.x_axis.bounds = ['fit','fit']
        _active.y_axis.bounds = ['fit','fit']
        _active.x_axis.tick_interval = 'auto'
        _active.y_axis.tick_interval = 'auto'
    elif is_name:
        raise ValueError("unknown axis setting: %r" % setting)
    else:
        _active.x_axis.bounds = [setting[0],setting[1]]
        _active.y_axis.bounds = [setting[2],setting[3]]
    _active.update()

def save(file_name,format='png'):
    """Save the active figure to file_name in the given image format.

    NOTE(review): unlike the other commands this does not call
    validate_active(), so it fails if no figure has been created yet.
    """
    fig = _active
    fig.save(file_name, format)


##########################################################
#----------------- plotting machinery -------------------#
##########################################################

#---- array utilities ------------

def is1D(a):
    """Return 1 if a is effectively one-dimensional, else 0.

    Scalars and 1-D sequences count as 1-D, as do arrays whose first or
    second dimension is 1 (a single row or a single column).  Only the
    first two dimensions are inspected.
    """
    # The old local name 'as' is a reserved word from Python 2.6 on.
    dims = shape(a)
    if len(dims) <= 1:
        # 0-d (scalar) input used to fail with IndexError below
        return 1
    if dims[0] == 1 or dims[1] == 1:
        return 1
    return 0

def row(a):
    """Return a as a 2-D array with a single row."""
    arr = asarray(a)
    return arr.reshape((1, -1))

def col(a):
    """Return a as a 2-D array with a single column."""
    arr = asarray(a)
    return arr.reshape((-1, 1))

class SizeMismatch(ValueError):
    """Raised when plot data have incompatible sizes.

    The exception arguments describe the mismatch, e.g.
    ('rows', x_shape, y_shape) or ('lengths', len_n, len_x).  This is a
    class rather than a string: raising strings is deprecated in Python 2.5
    and a TypeError from 2.6 on.  The existing
    "raise SizeMismatch, (...)" call sites keep working.
    """

class SizeError(ValueError):
    """Size-related plotting error; not raised by the code visible in this
    module, kept as part of its public names."""

# NOTE(review): this shadows the builtin NotImplemented for any module that
# does "from interface import *"; kept only because it is an existing
# public name -- consider removing it.
NotImplemented = 'NotImplemented'

#------------ Numerical constants ----------------

# really should do better than this...
BIG = 1e20
SMALL = 1e-20

#------------ plot group parsing -----------------
from types import *

def plot_groups(data):
    """Split plot arguments into consecutive groups via get_plot_group.

    Returns a list of group tuples in argument order.
    """
    groups = []
    rest = data
    while len(rest) > 0:
        grp, rest = get_plot_group(rest)
        groups.append(grp)
    return groups

def get_plot_group(data):
    """Split the first plot group off the front of data.

    A group is an array, optionally followed by a second array, optionally
    followed by a format string: (x), (x, fmt), (x, y) or (x, y, fmt).
    Array entries are converted with asarray.  Complex arrays are replaced
    by their magnitude, and a warning is printed.

    Returns (group, remains): group is a tuple, remains is the unconsumed
    tail of data.
    """
    group = ()
    remains = data
    state = 0       # 0: expect x, 1: expect y or fmt, 2: expect fmt
    finished = 0
    while len(remains) > 0 and not finished:
        el = remains[0]
        if state == 0:
            el = asarray(el)
            state = 1
        elif state == 1:
            if isinstance(el, str):
                finished = 1
            else:
                el = asarray(el)
            state = 2
        elif state == 2:
            finished = 1
            if not isinstance(el, str):
                # start of the next group: leave it in remains
                break
        # The previous check called el.typecode() (a Numeric-era method that
        # numpy arrays do not provide) under a bare except, so complex data
        # was never converted.  dtype.kind 'c' covers every complex type.
        if getattr(el, 'dtype', None) is not None and el.dtype.kind == 'c':
            print('warning plotting magnitude of complex values')
            el = abs(el)
        group = group + (el,)
        remains = remains[1:]
    return group, remains

def hstack(tup):
    """Join 2-D arrays side by side, i.e. concatenate along axis 1.

    NOTE(review): this may shadow a numpy hstack brought in by the star
    imports; unlike numpy's, it does not special-case 1-D input.
    """
    return concatenate(tup, axis=1)

def lines_from_group(group):
    lines = []
    plotinfo = ''
    x = group[0]
    ar_num = 1
    if len(group) > 1:
        if type(group[1]) == StringType:
            plotinfo = group[1]
        else:
            ar_num = 2
            y = group[1]
    if len(group) == 3:
        plotinfo = group[2]
    #force 1D arrays to 2D columns
    if is1D(x):
        x = col(x)
    if ar_num == 2 and is1D(y):
        y = col(y)

    xs = shape(x)
    if ar_num == 2:  ys = shape(y)
    #test that x and y have compatible shapes
    if ar_num == 2:
        #check that each array has the same number of rows
        if(xs[0] != ys[0] ):
            raise SizeMismatch, ('rows', xs, ys)
        #check that x.cols = y.cols
        #no error x has 1 column
        if(xs[1] > 1 and xs[1] != ys[1]):
            raise SizeMismatch, ('cols', xs, ys)

    #plot x against index
    if(ar_num == 1):
        for y_data in transpose(x):
            index = arange(len(y_data))
            pts = hstack(( col(index), col(y_data) ))
            pts = remove_bad_vals(pts)
            line = plot_module.line_object(pts)
            lines.append(line)
    #plot x vs y
    elif(ar_num ==2):
        #x is effectively 1D
        if(xs[1] == 1):
            for y_data in transpose(y):
                pts = hstack(( col(x), col(y_data) ))
                pts = remove_bad_vals(pts)
                line = plot_module.line_object(pts)
                lines.append(line)
        #x is 2D
        else:
            x = transpose(x); y = transpose(y)
            for i in range(len(x)):
                pts = hstack(( col(x[i]), col(y[i]) ))
                pts = remove_bad_vals(pts)
                line = plot_module.line_object(pts)
                lines.append(line)
    color,marker,line_type = process_format(plotinfo)
    #print color,marker,line_type
    for line in lines:
        if color != 'auto':
            line.color = 'custom'
            line.set_color(color)
            #print color
        if not marker:
            line.marker_type = 'custom'
            line.markers.visible = 'no'
            #print marker
        elif marker != 'auto':
            line.marker_type = 'custom'
            line.markers.symbol = marker
            line.markers.visible = 'yes'
            #print marker
        if not line_type:
            line.line_type = 'custom'
            line.line.visible = 'no'
        elif line_type != 'auto':
            line.line_type = 'custom'
            line.line.visible = 'yes'
            line.line.style = line_type
            #print line_type
        #print line.markers.visible,    line.line.visible,
    return lines

import re
# Single-letter color codes (Matlab style) and their names.
color_re = re.compile('[ymcrgbwk]')
color_trans = {'y':'yellow','m':'magenta','c':'cyan','r':'red','g':'green',
               'b':'blue', 'w':'white','k':'black'}
# Marker codes.  A '.' is a marker unless it is the second character of
# the '-.' (dash-dot) line style.  The negative look-behind says exactly
# that, and also matches a '.' at the very start of the format string,
# which a '[^-][.]' style pattern cannot.
marker_re = re.compile(r'[ox+s^v]|(?<!-)\.')
marker_trans = {'.':'dot','o':'circle','x':'cross','+':'plus','s':'square',
                '^':'triangle','v':'down_triangle'}

# Line-style codes; alternation order lets '--' and '-.' win over '-'.
line_re = re.compile(r'--|-\.|[-:]')
line_trans = {'-':'solid',':':'dot','-.':'dot dash','--':'dash'}
def process_format(format):
    """Decode a Matlab-style format string such as 'r--', 'bo' or 'k.'.

    Returns a (color, marker, line) triple:
      * '' gives ('auto', 'auto', 'auto');
      * color is a name from color_trans, or 'auto' if none is given;
      * marker is a name from marker_trans, or None for no markers;
      * line is a name from line_trans, or None for no line.
    Only the first code of each kind is used.  When the string selects
    neither a marker nor a line style (e.g. just a color, 'r'), a solid
    line is used.  Otherwise both would be hidden by the caller and the
    plot would be invisible.
    """
    if format == '':
        return 'auto','auto','auto'
    color, marker, line = 'auto', None, None
    found = color_re.findall(format)
    if found:
        color = color_trans[found[0]]
    found = marker_re.findall(format)
    if found:
        # [-1] picks the symbol even if a match includes a leading char
        marker = marker_trans[found[0][-1]]
    found = line_re.findall(format)
    if found:
        line = line_trans[found[0]]
    if marker is None and line is None:
        line = 'solid'
    return color, marker, line

def remove_bad_vals(x):
    """Return x with NaN and +/-Inf replaced by finite values.

    nan_to_num maps NaN to 0 and +/-Inf to very large finite numbers.  The
    result is then clipped to [limits.double_min / 10,
    limits.double_max / 10], presumably to keep the plotting code away
    from the extremes of the double range (confirm).
    """
    # NOTE(review): an older note here read "Fix axis order when interface
    # changed"; its meaning is unclear from this module.
    cleaned = nan_to_num(x)
    upper = limits.double_max / 10
    lower = limits.double_min / 10
    return clip(cleaned, lower, upper)

def stem(*data):
    if len(data) == 1:
        n = arange(len(data[0]))
        x = data[0]
        ltype = ['b-','mo']
    if len(data) == 2:
        if type(data[1]) is types.StringType:
            ltype = [data[1],'mo']
            n = arange(len(data[0]))
            x = data[0]
        elif type(data[1]) in [types.ListType, types.TupleType]:
            n = arange(len(data[0]))
            x = data[0]
            ltype = data[1][:2]
        else:
            n = data[0]
            x = data[1]
            ltype = ['b-','mo']
    elif len(data) > 2:
        n = data[0]
        x = data[1]
        ltype = data[2]
        if type(ltype) is types.StringType:
            ltype = [ltype,'mo']
    else:
        raise ValueError, "Invalid input arguments."

    if len(n) != len(x):
        raise SizeMismatch, ('lengths', len(n), len(x))
    # line at zero:
    newdata = []
    newdata.extend([[n[0],n[-1]],[0,0],ltype[0]])

    # stems
    for k in range(len(x)):
        newdata.extend([[n[k],n[k]],[0,x[k]],ltype[0]])

    # circles
    newdata.extend([n,x,ltype[1]])
    keywds = {'fill_style': 'transparent'}
    return plot(*newdata,**keywds)



def plot(*data,**keywds):
    """Plot Matlab-style data groups in the active figure.

    data is parsed into groups such as (y), (x, y), (y, fmt) or
    (x, y, fmt) by plot_groups; each group becomes lines via
    lines_from_group.  Keyword arguments named after attributes of
    plot_objects.poly_marker are applied to the markers of every new
    line; stem() passes fill_style this way.  Other keywords are ignored.
    Unless the figure's hold is 'on'/'yes', its existing lines and images
    are cleared first.  Returns the active figure.
    """
    lines = []
    for group in plot_groups(data):
        lines.extend(lines_from_group(group))
    for name in plot_objects.poly_marker._attributes.keys():
        value = keywds.get(name)
        if value is not None:
            for line in lines:
                # setattr performs the same assignment the old
                # exec('lines[k].markers.%s = value') did, without
                # compiling code from a formatted string.
                setattr(line.markers, name, value)
    validate_active()
    if _active.hold not in ['on','yes']:
        _active.line_list.data = [] # clear it out
        _active.image_list.data = [] # clear it out
    for line in lines:
        _active.line_list.append(line)
    _active.update()
    return _active

def markers(visible=None):
    """Placeholder for toggling marker visibility; currently does nothing."""
    return None

#-------------------------------------------------------------------#
#--------------------------- image ---------------------------------#
#-------------------------------------------------------------------#

def image(img,x=None,y=None,colormap = 'grey',scale='no'):
    """Display img in the active figure.

    img, x, y, colormap and scale are passed straight to
    wxplt.image_object.  Unless hold is 'on'/'yes', the figure's lines and
    images are cleared first and axis('equal') is applied.  With hold on,
    the image is appended to the existing contents and the figure is
    refreshed.  Returns the active figure.

    Colormap should really default to the current colormap...
    """
    validate_active()
    img_obj = wxplt.image_object(img,x,y,colormap=colormap,scale=scale)
    replacing = _active.hold not in ['on','yes']
    if replacing:
        _active.line_list.data = [] # clear it out
        _active.image_list.data = [] # clear it out
    _active.image_list.append(img_obj)
    if not replacing:
        _active.update()
        return _active
    try:
        axis('equal')
    except AttributeError:
        # cluge to handle case where ticks didn't exist when calling
        # axis(): lay the figure out once, then try again
        _active.client.layout_all()
        axis('equal')
    return _active

def imagesc(img,x=None,y=None,colormap = 'grey'):
    """Like image(), but passes scale='yes' to the image object.

    NOTE(review): presumably this scales the data to the full colormap
    range -- confirm against wxplt.image_object.  Returns the active figure,
    as image() does (the result used to be dropped).
    """
    return image(img,x,y,colormap,scale='yes')

# Matlab-compatible aliases for the axis-title commands.
xlabel, ylabel = xtitle, ytitle

def speed_test():
    """Resize stress test: plot [1, 2, 3] as 'r:o', then flip the window
    between 200x200 and 400x400 twenty times.  Nothing is timed or
    returned; presumably meant to be watched or run under a profiler."""
    fig = plot([1,2,3],'r:o')
    small = (200,200)
    large = (400,400)
    fig.SetSize(small)
    for _ in range(20):
        if fig.GetSizeTuple()[0] != 200:
            fig.SetSize(small)
        else:
            fig.SetSize(large)