1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
|
"""
Unit tests for loading data files using the extension registry
"""
import logging
import unittest
import os
import shutil
import numpy as np
from sasdata.dataloader.loader import Registry as Loader
from sasdata.dataloader.loader import Loader as LoaderMain
import os
import pytest
# Opt-out switch for the tests that download data from the network.
# The raw environment value is converted to a real bool here because
# ``pytest.mark.skipif`` treats a ``str`` condition as a Python expression to
# evaluate; passing the environment string straight through would turn a value
# such as ``SKIP_NETWORK_TESTS=yes`` into an evaluation error instead of a skip.
# Unset, empty, "0", "false", "no" and "off" (any case) leave the tests enabled;
# every other value skips them.
skip_network_tests = os.environ.get("SKIP_NETWORK_TESTS", "").strip().lower() not in {
    "",
    "0",
    "false",
    "no",
    "off",
}
logger = logging.getLogger(__name__)
# Remote location used for the URL counterparts of the local test data files.
BASE_URL = 'https://github.com/SasView/sasdata/raw/master/test/sasdataloader/data/'
def find(filename):
    """Return the path of *filename* inside the ``data`` directory next to this module."""
    data_dir = os.path.join(os.path.dirname(__file__), 'data')
    return os.path.join(data_dir, filename)
class ExtensionRegistryTests(unittest.TestCase):
    """Tests for loading data files through the loader's extension registry.

    ``setUp`` copies a valid canSAS XML file to two extra names ('.txt' and
    '.xyz') inside the test data directory; those copies are removed again
    after every test.
    """

    def setUp(self):
        """Record the local/remote test file locations and create the loader."""
        # Local and remote files to compare loading
        # NXcanSAS
        self.valid_hdf_file = find("MAR07232_rest.h5")
        self.valid_hdf_url = BASE_URL + "MAR07232_rest.h5"
        # canSAS XML
        self.valid_xml_file = find("valid_cansas_xml.xml")
        self.valid_xml_url = BASE_URL + "valid_cansas_xml.xml"
        # ASCII Text
        self.valid_txt_file = find("avg_testdata.txt")
        self.valid_txt_url = BASE_URL + "avg_testdata.txt"
        # ABS Text
        self.valid_abs_file = find("ascii_test_4.abs")
        self.valid_abs_url = BASE_URL + "ascii_test_4.abs"
        # DAT 2D NIST format
        self.valid_dat_file = find("detector_square.dat")
        self.valid_dat_url = BASE_URL + "detector_square.dat"
        # Anton Paar SAXSess PDH format
        self.valid_pdh_file = find("Anton-Paar.pdh")
        self.valid_pdh_url = BASE_URL + "Anton-Paar.pdh"

        self.valid_file_wrong_known_ext = find("valid_cansas_xml.txt")
        self.valid_file_wrong_unknown_ext = find("valid_cansas_xml.xyz")
        # Register the removal *before* copying: unittest does not call
        # tearDown() when setUp() raises, but registered cleanups still run,
        # so a failure in the second copy cannot leave the first copy behind.
        self.addCleanup(self._remove_copied_files)
        shutil.copyfile(self.valid_xml_file, self.valid_file_wrong_known_ext)
        shutil.copyfile(self.valid_xml_file, self.valid_file_wrong_unknown_ext)

        self.invalid_file = find("cansas1d_notitle.xml")
        self.loader = Loader()

    def tearDown(self):
        """Remove the renamed XML copies created by ``setUp``."""
        self._remove_copied_files()

    def _remove_copied_files(self):
        """Delete the renamed XML copies if they exist (safe to call repeatedly)."""
        copies = (self.valid_file_wrong_known_ext, self.valid_file_wrong_unknown_ext)
        for path in copies:
            if os.path.isfile(path):
                os.remove(path)

    def _assert_same_data(self, expected, actual):
        """Assert that the ``x``, ``y`` and ``dy`` values of two data objects match."""
        for attr in ("x", "y", "dy"):
            self.assertTrue(
                np.all(getattr(expected, attr) == getattr(actual, attr)),
                msg="'{}' values differ between the two loaded files".format(attr),
            )

    def _check_renamed_copy(self, renamed_path):
        """Load the XML file and its renamed copy and assert both give the same data."""
        correct = self.loader.load(self.valid_xml_file)
        wrong_ext = self.loader.load(renamed_path)
        self.assertEqual(len(correct), 1)
        self.assertEqual(len(wrong_ext), 1)
        self._assert_same_data(correct[0], wrong_ext[0])

    def test_wrong_known_ext(self):
        """
        Load a valid CanSAS XML file that has the extension '.txt', which is in
        the extension registry. Compare the results to loading the same file
        with the extension '.xml'
        """
        self._check_renamed_copy(self.valid_file_wrong_known_ext)

    def test_wrong_unknown_ext(self):
        """
        Load a valid CanSAS XML file that has the extension '.xyz', which isn't
        in the extension registry. Compare the results to loading the same file
        with the extension '.xml'
        """
        self._check_renamed_copy(self.valid_file_wrong_unknown_ext)

    def test_data_reader_exception(self):
        """
        Load a CanSAS XML file that doesn't meet the schema, and check errors
        are set correctly
        """
        data = self.loader.load(self.invalid_file)
        self.assertEqual(len(data), 1)
        data = data[0]
        self.assertEqual(len(data.errors), 1)
        err_msg = data.errors[0]
        # assertIn reports both strings on failure, unlike assertTrue(a in b).
        self.assertIn("does not fully meet the CanSAS v1.x specification", err_msg)

    @pytest.mark.skipif(skip_network_tests, reason="Requires downloading data from network")
    def test_compare_remote_file_to_local(self):
        """Load the same file from a local directory and a remote URL and compare data objects."""
        file_pairs = (
            ("ASCII text", self.valid_txt_file, self.valid_txt_url),
            ("NXcanSAS", self.valid_hdf_file, self.valid_hdf_url),
            ("canSAS XML", self.valid_xml_file, self.valid_xml_url),
            ("ABS", self.valid_abs_file, self.valid_abs_url),
            ("DAT", self.valid_dat_file, self.valid_dat_url),
            ("PDH", self.valid_pdh_file, self.valid_pdh_url),
        )
        for label, local_path, remote_url in file_pairs:
            # subTest labels the failing format and, on runners that support
            # sub-tests, lets the remaining formats still be checked.
            with self.subTest(file_format=label):
                local_data = self.loader.load(local_path)
                remote_data = self.loader.load(remote_url)
                # Ensure the string representation of the file contents match
                self.assertEqual(
                    str(local_data[0]),
                    str(remote_data[0]),
                    msg="{} file: local and remote data differ".format(label),
                )

    def test_load_simultaneously(self):
        """Load a list of files, not just a single file, and ensure the content matches"""
        loader = LoaderMain()
        # Reference strings obtained by loading every file on its own.
        single_paths = (self.valid_txt_file, self.valid_hdf_file, self.valid_xml_file)
        strings = [str(loader.load(path)[0]) for path in single_paths]
        all_files = loader.load([self.valid_xml_file, self.valid_hdf_file, self.valid_txt_file])
        for loaded in all_files:
            self.assertIn(str(loaded), strings)
|