1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162
|
"""
PC-BASIC tests.test_extension
unit tests for extensions
(c) 2020--2023 Rob Hagemans
This file is released under the GNU GPL version 3 or later.
"""
import os
from pcbasic import Session
from tests.unit.utils import TestCase, run_tests
class ExtensionTest(TestCase):
    """Unit tests for calling Python extensions from BASIC via `_NAME` syntax."""

    tag = u'extensions'

    def test_extension(self):
        """Test calling extension functions with and without arguments."""

        class Extension(object):
            @staticmethod
            def add(x, y):
                return '%s plus %s equals %s' % (repr(x), repr(y), repr(x+y))

            @staticmethod
            def one():
                return 1

        with Session(extension=Extension) as s:
            # The expected b$ below shows that a (single precision) reaches
            # _add as a Python float while the literal 1 arrives as an int.
            # _one takes no arguments and is called without parentheses.
            s.execute('''
10 a=5
run
b$ = _add(a, 1)
c% = _one
''')
            assert s.get_variable("a!") == 5
            assert s.get_variable("c%") == 1
            assert s.get_variable("b$") == b'5.0 plus 1 equals 6.0'

    def test_extension_statement(self):
        """Test extension statements."""
        outfile = self.output_path('python-output.txt')
        # The extension below opens the file in append mode, so remove any
        # output left by an earlier run; otherwise the exact-content
        # assertion at the end would fail on reruns.
        if os.path.exists(outfile):
            os.remove(outfile)

        class Extension(object):
            @staticmethod
            def output(*args):
                with open(outfile, 'ab') as g:
                    for arg in args:
                        if isinstance(arg, bytes):
                            g.write(arg)
                        else:
                            g.write(b'%d' % (arg,))
                        g.write(b' ')

        with Session(extension=Extension) as s:
            # Statement names are matched case-insensitively (_OUTPUT/_output).
            s.execute(b'''
_OUTPUT "one", 2, 3!, 4#
_output "!\x9c$"
''')
        with open(outfile, 'rb') as f:
            assert f.read() == b'one 2 3 4 !\x9c$ '

    def test_extended_session(self):
        """Test extensions accessing the session."""

        class ExtendedSession(Session):
            def __init__(self):
                # The session object itself serves as the extension, so its
                # methods (such as adda) are callable from BASIC.
                Session.__init__(self, extension=self)

            def adda(self, x):
                return x + self.get_variable("a!")

        with ExtendedSession() as s:
            s.execute('a=4')
            s.execute('b=_adda(1)')
            assert s.evaluate('b') == 5.

    def test_extension_module(self):
        """Test using a module as extension."""
        import random
        with Session(extension=random) as s:
            s.execute('''
_seed(42)
b = _uniform(a, 25.6)
''')
            self.assertAlmostEqual(s.evaluate('b'), 16.3693256378, places=10)

    def test_extension_module_string(self):
        """Test using a module name as extension."""
        with Session(extension='random') as s:
            s.execute('''
_seed(42)
b = _uniform(a, 25.6)
''')
            self.assertAlmostEqual(s.evaluate('b'), 16.3693256378, places=10)

    def test_extension_module_not_found(self):
        """Test using a non-existent module name as extension."""
        with Session(extension='no-sirree') as s:
            s.execute('_test')
            assert self.get_text_stripped(s)[0] == b'Internal error\xff'

    def test_no_extension(self):
        """Test calling an extension statement when no extension is set."""
        with Session() as s:
            s.execute(b'''
_NOPE "one", 2, 3!, 4#
''')
            assert self.get_text_stripped(s)[0] == b'Syntax error\xff'

    def test_no_statement(self):
        """Test calling a statement the extension object does not provide."""
        empty_ext = object()
        with Session(extension=empty_ext) as s:
            s.execute(b'''
_NOPE "one", 2, 3!, 4#
''')
            assert self.get_text_stripped(s)[0] == b'Internal error\xff'

    def test_extension_function(self):
        """Test conversion of extension function return values to BASIC."""

        class Extension(object):
            @staticmethod
            def boolfunc():
                return True

            @staticmethod
            def unicodefunc():
                return u'test'

            @staticmethod
            def bytesfunc():
                return b'test'

            @staticmethod
            def intfunc():
                return 1

            @staticmethod
            def floatfunc():
                # A non-integral float, so that this actually exercises the
                # float return path (returning 1 only re-tested ints).
                return 1.5

        with Session(extension=Extension) as s:
            # Python True is converted to BASIC true (-1).
            assert s.evaluate('_BOOLFUNC') == -1
            assert s.evaluate('_INTFUNC') == 1.0
            assert s.evaluate('_FLOATFUNC') == 1.5
            # Both text and bytes results come back as BASIC (bytes) strings.
            assert s.evaluate('_UNICODEFUNC') == b'test'
            assert s.evaluate('_BYTESFUNC') == b'test'

    def test_extension_function_none(self):
        """Test extension functions with disallowed return type."""

        class Extension(object):
            @staticmethod
            def nonefunc():
                return None

        with Session(extension=Extension) as s:
            s.evaluate('_NONEFUNC')
            assert self.get_text_stripped(s)[0] == b'Type mismatch\xff'
if __name__ == "__main__":
    # Executed directly as a script: hand over to the shared test runner.
    run_tests()
|