1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
|
from __future__ import absolute_import, print_function
import os
import sys
import tempfile
import numpy
from numpy.testing import TestCase, assert_, run_module_suite
from scipy.weave import inline_tools, ext_tools
from scipy.weave.build_tools import msvc_exists, gcc_exists
from scipy.weave.catalog import unique_file
from scipy.weave.numpy_scalar_spec import numpy_complex_scalar_converter
from weave_test_utils import dec
def unique_mod(d,file_name):
f = os.path.basename(unique_file(d,file_name))
m = os.path.splitext(f)[0]
return m
#----------------------------------------------------------------------------
# Scalar conversion test classes
# int, float, complex
#----------------------------------------------------------------------------
class NumpyComplexScalarConverter(TestCase):
compiler = ''
def setUp(self):
self.converter = numpy_complex_scalar_converter()
@dec.slow
def test_type_match_string(self):
assert_(not self.converter.type_match('string'))
@dec.slow
def test_type_match_int(self):
assert_(not self.converter.type_match(5))
@dec.slow
def test_type_match_float(self):
assert_(not self.converter.type_match(5.))
@dec.slow
def test_type_match_complex128(self):
assert_(self.converter.type_match(numpy.complex128(5.+1j)))
@dec.slow
def test_complex_var_in(self):
mod_name = sys._getframe().f_code.co_name + self.compiler
mod_name = unique_mod(test_dir,mod_name)
mod = ext_tools.ext_module(mod_name)
a = numpy.complex(1.+1j)
code = "a=std::complex<double>(2.,2.);"
test = ext_tools.ext_function('test',code,['a'])
mod.add_function(test)
mod.compile(location=test_dir, compiler=self.compiler)
exec('from ' + mod_name + ' import test')
b = numpy.complex128(1.+1j)
test(b)
try:
b = 1.
test(b)
except TypeError:
pass
try:
b = 'abc'
test(b)
except TypeError:
pass
@dec.slow
def test_complex_return(self):
mod_name = sys._getframe().f_code.co_name + self.compiler
mod_name = unique_mod(test_dir,mod_name)
mod = ext_tools.ext_module(mod_name)
a = 1.+1j
code = """
a= a + std::complex<double>(2.,2.);
return_val = PyComplex_FromDoubles(a.real(),a.imag());
"""
test = ext_tools.ext_function('test',code,['a'])
mod.add_function(test)
mod.compile(location=test_dir, compiler=self.compiler)
exec('from ' + mod_name + ' import test')
b = 1.+1j
c = test(b)
assert_(c == 3.+3j)
@dec.slow
def test_inline(self):
a = numpy.complex128(1+1j)
result = inline_tools.inline("return_val=1.0/a;",['a'])
assert_(result == .5-.5j)
for _n in dir():
if _n[-9:] == 'Converter':
if msvc_exists():
exec("class Test%sMsvc(%s):\n compiler = 'msvc'" % (_n,_n))
else:
exec("class Test%sUnix(%s):\n compiler = ''" % (_n,_n))
if gcc_exists():
exec("class Test%sGcc(%s):\n compiler = 'gcc'" % (_n,_n))
def setup_test_location():
test_dir = tempfile.mkdtemp()
sys.path.insert(0,test_dir)
return test_dir
test_dir = setup_test_location()
def teardown_test_location():
import tempfile
test_dir = os.path.join(tempfile.gettempdir(),'test_files')
if sys.path[0] == test_dir:
sys.path = sys.path[1:]
return test_dir
if not msvc_exists():
for _n in dir():
if _n[:8] == 'TestMsvc':
exec('del '+_n)
else:
for _n in dir():
if _n[:8] == 'TestUnix':
exec('del '+_n)
if not (gcc_exists() and msvc_exists() and sys.platform == 'win32'):
for _n in dir():
if _n[:7] == 'TestGcc':
exec('del '+_n)
if __name__ == "__main__":
run_module_suite()
|