File: test_extensions.py

package info (click to toggle)
pcbasic 2.0.7-8
  • links: PTS
  • area: main
  • in suites: forky, sid
  • size: 35,416 kB
  • sloc: python: 28,411; sh: 103; makefile: 10
file content (162 lines) | stat: -rw-r--r-- 5,039 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""
PC-BASIC tests.test_extension
unit tests for extensions

(c) 2020--2023 Rob Hagemans
This file is released under the GNU GPL version 3 or later.
"""

import os

from pcbasic import Session
from tests.unit.utils import TestCase, run_tests


class ExtensionTest(TestCase):
    """Unit tests for the extension mechanism of a PC-BASIC Session.

    Each test hands a Python object, module or module name to ``Session``
    through its ``extension`` argument and then reaches the callables of that
    object from BASIC code through underscore-prefixed names such as ``_add``.
    """

    tag = u'extensions'

    def test_extension(self):
        """Test calling extension functions with and without arguments."""

        class Extension(object):
            @staticmethod
            def add(x, y):
                # repr() exposes the Python types the arguments arrive as
                return '%s plus %s equals %s' % (repr(x), repr(y), repr(x+y))

            @staticmethod
            def one():
                return 1

        with Session(extension=Extension) as s:
            s.execute('''
                10 a=5
                run
                b$ = _add(a, 1)
                c% = _one
            ''')
            assert s.get_variable("a!") == 5
            assert s.get_variable("c%") == 1
            # the value of a is passed as the float 5.0, the literal 1 as an int
            assert s.get_variable("b$") == b'5.0 plus 1 equals 6.0'

    def test_extension_statement(self):
        """Test calling an extension as a statement with mixed argument types."""
        outfile = self.output_path('python-output.txt')

        class Extension(object):
            @staticmethod
            def output(*args):
                # append mode, so that successive statements accumulate in one
                # file; NOTE(review): this assumes output_path() does not point
                # at a stale file left by an earlier run -- confirm
                with open(outfile, 'ab') as g:
                    for arg in args:
                        if isinstance(arg, bytes):
                            g.write(arg)
                        else:
                            # numeric arguments are written as integers
                            g.write(b'%d' % (arg,))
                        g.write(b' ')

        with Session(extension=Extension) as s:
            # the extension name is matched regardless of case
            s.execute(b'''
                _OUTPUT "one", 2, 3!, 4#
                _output "!\x9c$"
            ''')
        with open(outfile, 'rb') as f:
            assert f.read() == b'one 2 3 4 !\x9c$ '

    def test_extended_session(self):
        """Test an extension that accesses the session it is attached to."""

        class ExtendedSession(Session):
            def __init__(self):
                # the session object serves as its own extension
                Session.__init__(self, extension=self)

            def adda(self, x):
                return x + self.get_variable("a!")

        with ExtendedSession() as s:
            s.execute('a=4')
            s.execute('b=_adda(1)')
            assert s.evaluate('b') == 5.

    def test_extension_module(self):
        """Test using a module object as extension."""
        import random
        with Session(extension=random) as s:
            # a is never assigned; presumably it evaluates to zero
            s.execute('''
                _seed(42)
                b = _uniform(a, 25.6)
            ''')
            self.assertAlmostEqual(s.evaluate('b'), 16.3693256378, places=10)

    def test_extension_module_string(self):
        """Test using a module name, given as a string, as extension."""
        with Session(extension='random') as s:
            # a is never assigned; presumably it evaluates to zero
            s.execute('''
                _seed(42)
                b = _uniform(a, 25.6)
            ''')
            self.assertAlmostEqual(s.evaluate('b'), 16.3693256378, places=10)

    def test_extension_module_not_found(self):
        """Test using a nonexistent module name as extension."""
        with Session(extension='no-sirree') as s:
            s.execute('_test')
        assert self.get_text_stripped(s)[0] == b'Internal error\xff'

    def test_no_extension(self):
        """Test calling an extension statement when no extension is set."""
        with Session() as s:
            s.execute(b'''
                _NOPE "one", 2, 3!, 4#
            ''')
        assert self.get_text_stripped(s)[0] == b'Syntax error\xff'

    def test_no_statement(self):
        """Test calling a statement that the extension object does not have."""
        # a bare object has no attribute that could match _NOPE
        empty_ext = object()
        with Session(extension=empty_ext) as s:
            s.execute(b'''
                _NOPE "one", 2, 3!, 4#
            ''')
        assert self.get_text_stripped(s)[0] == b'Internal error\xff'

    def test_extension_function(self):
        """Test the supported return types of extension functions."""

        class Extension(object):
            @staticmethod
            def boolfunc():
                return True

            @staticmethod
            def unicodefunc():
                return u'test'

            @staticmethod
            def bytesfunc():
                return b'test'

            @staticmethod
            def intfunc():
                return 1

            @staticmethod
            def floatfunc():
                # must be a genuine float; returning the int 1 would only
                # repeat the intfunc case
                return 1.0

        with Session(extension=Extension) as s:
            # True comes back as -1
            assert s.evaluate('_BOOLFUNC') == -1
            assert s.evaluate('_INTFUNC') == 1.0
            assert s.evaluate('_FLOATFUNC') == 1.0
            # unicode and bytes results both come back as bytes
            assert s.evaluate('_UNICODEFUNC') == b'test'
            assert s.evaluate('_BYTESFUNC') == b'test'

    def test_extension_function_none(self):
        """Test an extension function with a disallowed return type."""

        class Extension(object):
            @staticmethod
            def nonefunc():
                return None

        with Session(extension=Extension) as s:
            s.evaluate('_NONEFUNC')
        assert self.get_text_stripped(s)[0] == b'Type mismatch\xff'


# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    run_tests()