File: test_dependency.py

package info (click to toggle)
ipython 2.3.0-2
  • links: PTS, VCS
  • area: main
  • in suites: jessie, jessie-kfreebsd
  • size: 28,032 kB
  • ctags: 15,433
  • sloc: python: 73,792; makefile: 428; sh: 297
file content (136 lines) | stat: -rw-r--r-- 4,136 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""Tests for dependency.py

Authors:

* Min RK
"""

__docformat__ = "restructuredtext en"

#-------------------------------------------------------------------------------
#  Copyright (C) 2011  The IPython Development Team
#
#  Distributed under the terms of the BSD License.  The full license is in
#  the file COPYING, distributed as part of this software.
#-------------------------------------------------------------------------------

#-------------------------------------------------------------------------------
# Imports
#-------------------------------------------------------------------------------

# import
import os

from IPython.utils.pickleutil import can, uncan

import IPython.parallel as pmod
from IPython.parallel.util import interactive

from IPython.parallel.tests import add_engines
from .clienttest import ClusterTestCase

def setup():
    """Request engines from the shared test helper for this module's tests.

    NOTE(review): ``total=True`` presumably means "make sure one engine
    exists in total" rather than "start one more" -- confirm against
    ``add_engines``.
    """
    add_engines(1, total=True)

@pmod.require('time')
def wait(n):
    """Sleep for ``n`` seconds and hand ``n`` back to the caller.

    ``time`` is not imported by this module; the function relies on the
    ``@pmod.require('time')`` decorator to make that name available when
    it runs.
    """
    delay = n
    time.sleep(delay)
    return delay

@pmod.interactive
def func(x):
    """Return ``x`` multiplied by itself.

    Used by the ``@require`` tests below as the object that has to be
    shipped along with a task.
    """
    squared = x * x
    return squared

# Id fixtures for the Dependency tests: the ids are the decimal strings
# '0'..'9'.  Even ids appear in DependencyTest's ``succeeded`` set and odd
# ids in its ``failed`` set, so ``mixed`` contains both outcomes.
mixed = list(map(str, range(10)))
completed = mixed[0::2]  # '0', '2', '4', '6', '8'
failed = mixed[1::2]     # '1', '3', '5', '7', '9'

class DependencyTest(ClusterTestCase):
    """Tests for ``Dependency`` objects and the ``@require`` decorator.

    The ``Dependency`` tests evaluate ``check``/``unreachable`` against the
    fixed ``succeeded``/``failed`` id sets built in ``setUp``.  The
    ``@require`` tests submit functions through views of ``self.client``.
    """

    def setUp(self):
        """Build the id fixtures, the canning namespace and the views."""
        ClusterTestCase.setUp(self)
        # Even ids 0..24 count as succeeded, odd ids 1..23 as failed.
        self.succeeded = {str(i) for i in range(0, 25, 2)}
        self.failed = {str(i) for i in range(1, 25, 2)}
        # Namespace handed to uncan() by cancan().
        self.user_ns = {'__builtins__': __builtins__}
        self.view = self.client.load_balanced_view()
        self.dview = self.client[-1]

    # -- assertion helpers, all evaluated against the setUp fixtures -------

    def assertMet(self, dep):
        """Fail unless ``dep.check`` is true for the fixture id sets."""
        is_met = dep.check(self.succeeded, self.failed)
        self.assertTrue(is_met, "Dependency should be met")

    def assertUnmet(self, dep):
        """Fail unless ``dep.check`` is false for the fixture id sets."""
        is_met = dep.check(self.succeeded, self.failed)
        self.assertFalse(is_met, "Dependency should not be met")

    def assertUnreachable(self, dep):
        """Fail unless ``dep.unreachable`` is true for the fixture id sets."""
        is_dead = dep.unreachable(self.succeeded, self.failed)
        self.assertTrue(is_dead, "Dependency should be unreachable")

    def assertReachable(self, dep):
        """Fail unless ``dep.unreachable`` is false for the fixture id sets."""
        is_dead = dep.unreachable(self.succeeded, self.failed)
        self.assertFalse(is_dead, "Dependency should be reachable")

    def cancan(self, f):
        """Decorator: round-trip ``f`` through can/uncan using self.user_ns."""
        canned = can(f)
        return uncan(canned, self.user_ns)

    # -- tests -------------------------------------------------------------

    def test_require_imports(self):
        """test that @require imports names"""
        # base64 is never imported here; the call below only works if
        # @require supplies it.  The function has to go through canning so
        # that it is connected to self.user_ns.
        @self.cancan
        @pmod.require('base64')
        @interactive
        def encode(arg):
            return base64.b64encode(arg)

        encoded = encode(b'foo')
        self.assertEqual(encoded, b'Zm9v')

    def test_success_only(self):
        # Depends on ids with both outcomes, counting successes only.
        on_mixed = pmod.Dependency(mixed, success=True, failure=False)
        self.assertUnmet(on_mixed)
        self.assertUnreachable(on_mixed)
        # Same dependency with all=False.
        on_mixed.all = False
        self.assertMet(on_mixed)
        self.assertReachable(on_mixed)

        # Depends only on ids that succeeded.
        on_completed = pmod.Dependency(completed, success=True, failure=False)
        self.assertMet(on_completed)
        self.assertReachable(on_completed)
        on_completed.all = False
        self.assertMet(on_completed)
        self.assertReachable(on_completed)

    def test_failure_only(self):
        # Depends on ids with both outcomes, counting failures only.
        on_mixed = pmod.Dependency(mixed, success=False, failure=True)
        self.assertUnmet(on_mixed)
        self.assertUnreachable(on_mixed)
        # Same dependency with all=False.
        on_mixed.all = False
        self.assertMet(on_mixed)
        self.assertReachable(on_mixed)

        # Depends only on ids that succeeded, so a failure-only dependency
        # is expected to stay unmet and unreachable in both modes.
        on_completed = pmod.Dependency(completed, success=False, failure=True)
        self.assertUnmet(on_completed)
        self.assertUnreachable(on_completed)
        on_completed.all = False
        self.assertUnmet(on_completed)
        self.assertUnreachable(on_completed)

    def test_require_function(self):
        # Without @require the remote call to func is expected to fail ...
        @pmod.interactive
        def bar(a):
            return func(a)

        # ... while @require(func) is expected to make it resolvable.
        @pmod.require(func)
        @pmod.interactive
        def bar2(a):
            return func(a)

        everyone = self.client[:]
        everyone.clear()
        self.assertRaisesRemote(NameError, self.view.apply_sync, bar, 5)
        pending = self.view.apply_async(bar2, 5)
        remote_value = pending.get(5)
        self.assertEqual(remote_value, func(5))

    def test_require_object(self):
        # A keyword passed to @require names the object for the task:
        # func is called as ``foo`` inside bar.
        @pmod.require(foo=func)
        @pmod.interactive
        def bar(a):
            return foo(a)

        pending = self.view.apply_async(bar, 5)
        remote_value = pending.get(5)
        self.assertEqual(remote_value, func(5))