File: test_BaseCalculator.py

package info (click to toggle)
python-libpyvinyl 1.2.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,020 kB
  • sloc: python: 3,213; makefile: 11
file content (315 lines) | stat: -rw-r--r-- 11,303 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
import unittest
import pytest
import os
import shutil
from typing import Union
from pathlib import Path

from libpyvinyl.BaseCalculator import BaseCalculator
from libpyvinyl.BaseData import BaseData, DataCollection
from libpyvinyl.Parameters import CalculatorParameters
from libpyvinyl.AbstractBaseClass import AbstractBaseClass


class NumberData(BaseData):
    """Example dict-mapping data class holding a single ``"number"`` entry.

    Used as the input and output data type of the test calculator in this
    module. No file formats are registered for it.
    """

    def __init__(
        self,
        key,
        data_dict=None,
        filename=None,
        file_format_class=None,
        file_format_kwargs=None,
    ):
        """Declare the expected data keys and delegate to ``BaseData``.

        :param key: identifier of this data object (e.g. ``"input1"``).
        :param data_dict: optional dict holding the data; forwarded unchanged
            to ``BaseData``.
        :param filename: optional file name; forwarded unchanged to ``BaseData``.
        :param file_format_class: format class; forwarded unchanged to
            ``BaseData``.
        :param file_format_kwargs: format keyword arguments; forwarded
            unchanged to ``BaseData``.
        """
        expected_data = {}

        # DataClass developer's job start
        # "number" is the only key this data class declares; it starts as None.
        expected_data["number"] = None
        # DataClass developer's job end

        super().__init__(
            key,
            expected_data,
            data_dict,
            filename,
            file_format_class,
            file_format_kwargs,
        )

    @classmethod
    def supported_formats(cls):
        """Return the supported file formats: none, so an empty dict."""
        return {}

    @classmethod
    def from_file(cls, filename: str, format_class, key, **kwargs):
        """Not supported: this example data class cannot be read from file.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError("NumberData does not support reading from file.")

    @classmethod
    def from_dict(cls, data_dict, key):
        """Create the data class from a python dictionary.

        :param data_dict: dict providing the data, e.g. ``{"number": 1}``.
        :param key: identifier of the new data object.
        :returns: a new ``NumberData`` wrapping ``data_dict``.
        """
        return cls(key, data_dict=data_dict)


class PlusCalculator(BaseCalculator):
    """Example calculator that adds the ``"number"`` values of two inputs.

    With the ``plus_times`` parameter set to ``n > 1`` the second input is
    added ``n`` times in total, so the result is ``in0 + n * in1``. Any
    ``plus_times <= 1`` yields ``in0 + in1``.
    """

    def __init__(
        self,
        name: str,
        input: Union[DataCollection, list, NumberData],
        output_keys: Union[list, str, None] = None,
        output_data_types=None,
        output_filenames: Union[list, str, None] = None,
        instrument_base_dir="./",
        calculator_base_dir="PlusCalculator",
        parameters=None,
    ):
        """A python object calculator example.

        :param name: name of the calculator.
        :param input: the input data; its first two entries are summed.
        :param output_keys: output key(s); defaults to ``["plus_result"]``.
        :param output_data_types: output data classes; defaults to
            ``[NumberData]``.
        :param output_filenames: output file name(s); defaults to ``[]``.
        :param instrument_base_dir: forwarded to ``BaseCalculator``.
        :param calculator_base_dir: forwarded to ``BaseCalculator``.
        :param parameters: optional ``CalculatorParameters``; forwarded to
            ``BaseCalculator``.
        """
        # Build fresh containers per call instead of using mutable default
        # arguments, which would be shared by every instance.
        if output_keys is None:
            output_keys = ["plus_result"]
        if output_data_types is None:
            output_data_types = [NumberData]
        if output_filenames is None:
            output_filenames = []

        super().__init__(
            name,
            input,
            output_keys,
            output_data_types=output_data_types,
            output_filenames=output_filenames,
            instrument_base_dir=instrument_base_dir,
            calculator_base_dir=calculator_base_dir,
            parameters=parameters,
        )

    def init_parameters(self):
        """Create the ``plus_times`` parameter with default value 1."""
        parameters = CalculatorParameters()
        times = parameters.new_parameter(
            "plus_times", comment="How many times to do the plus"
        )
        # Set defaults
        times.value = 1

        self.parameters = parameters

    def backengine(self):
        """Sum the first two inputs and store the result in the first output.

        Creates ``self.base_dir`` if it does not exist.

        :returns: ``self.output``, whose entry for ``self.output_keys[0]`` now
            holds ``{"number": result}`` (a float).
        """
        Path(self.base_dir).mkdir(parents=True, exist_ok=True)
        inputs = self.input.to_list()
        # Convert both operands once, so the repeated additions below use the
        # same float value as the first one.
        input_num0 = float(inputs[0].get_data()["number"])
        input_num1 = float(inputs[1].get_data()["number"])
        output_num = input_num0 + input_num1

        times = self.parameters["plus_times"].value
        if times > 1:
            # The first addition of input_num1 happened above.
            for _ in range(times - 1):
                output_num += input_num1

        output_data = self.output[self.output_keys[0]]
        output_data.set_dict({"number": output_num})
        return self.output


class BaseCalculatorTest(unittest.TestCase):
    """
    Test class for the BaseCalculator class.
    """

    def setUp(self):
        """Setting up a test.

        A fresh calculator and input list are built for every test. Several
        tests mutate the calculator (parameters, ``output_filenames``), so a
        single class-wide instance would make results depend on test order.
        """
        input1 = NumberData.from_dict({"number": 1}, "input1")
        input2 = NumberData.from_dict({"number": 1}, "input2")
        self.__default_input = [input1, input2]
        self.__default_calculator = PlusCalculator("plus", self.__default_input)

        # Paths registered here are deleted in tearDown.
        self.__files_to_remove = []
        self.__dirs_to_remove = []

    def tearDown(self):
        """Tearing down a test: remove every registered file and directory."""

        for f in self.__files_to_remove:
            if os.path.isfile(f):
                os.remove(f)
        for d in self.__dirs_to_remove:
            if os.path.isdir(d):
                shutil.rmtree(d)

    def test_base_class_constructor_raises(self):
        """Test that we cannot construct instances of the base class."""

        self.assertRaises(TypeError, BaseCalculator, "name")

    def test_default_construction(self):
        """Testing the default construction of the class."""

        # Test positional arguments
        calculator = PlusCalculator("test", self.__default_input)

        self.assertIsInstance(calculator, PlusCalculator)
        self.assertIsInstance(calculator, BaseCalculator)
        self.assertIsInstance(calculator, AbstractBaseClass)

    def test_deep_copy(self):
        """Test the copy constructor behaves as expected."""
        # ``.parameters`` returns a reference: edits made through it are
        # visible on the calculator itself.
        calculator_copy = self.__default_calculator()
        self.assertEqual(calculator_copy.parameters["plus_times"].value, 1)
        new_parameters = calculator_copy.parameters
        new_parameters["plus_times"] = 5
        self.assertEqual(new_parameters["plus_times"].value, 5)
        self.assertEqual(calculator_copy.parameters["plus_times"].value, 5)

        # Calling the calculator deep-copies it: parameters and inputs of the
        # copy are independent of the original.
        calculator_copy = self.__default_calculator()
        self.assertEqual(calculator_copy.parameters["plus_times"].value, 1)
        calculator_copy.parameters["plus_times"] = 10
        self.assertEqual(calculator_copy.parameters["plus_times"].value, 10)
        self.assertEqual(self.__default_calculator.parameters["plus_times"].value, 1)
        calculator_copy.input["input1"] = NumberData.from_dict({"number": 5}, "input1")
        self.assertEqual(calculator_copy.input["input1"].get_data()["number"], 5)
        self.assertEqual(
            self.__default_calculator.input["input1"].get_data()["number"], 1
        )

        # Plain assignment only creates another reference to the same object.
        self.assertEqual(calculator_copy.parameters["plus_times"].value, 10)
        calculator_reference = calculator_copy
        self.assertEqual(calculator_reference.parameters["plus_times"].value, 10)
        calculator_reference.parameters["plus_times"] = 3
        self.assertEqual(calculator_reference.parameters["plus_times"].value, 3)
        self.assertEqual(calculator_copy.parameters["plus_times"].value, 3)

        # New parameters can be passed when deep-copying the calculator.
        new_parameters = CalculatorParameters()
        times = new_parameters.new_parameter(
            "plus_times", comment="How many times to do the plus"
        )
        times.value = 1
        new_parameters["plus_times"].value = 5
        new_calculator = self.__default_calculator(parameters=new_parameters)
        self.assertIsInstance(new_calculator, PlusCalculator)
        self.assertIsInstance(new_calculator, BaseCalculator)
        self.assertIsInstance(new_calculator, AbstractBaseClass)
        self.assertEqual(new_calculator.parameters["plus_times"].value, 5)
        self.assertEqual(self.__default_calculator.parameters["plus_times"].value, 1)

    def test_dump(self):
        """Test dumping to file."""
        calculator = self.__default_calculator

        self.__files_to_remove.append(calculator.dump())
        self.__files_to_remove.append(calculator.dump("dump.dill"))

    def test_parameters_in_copied_calculator(self):
        """Test that assigning the calculator shares (not copies) its parameters."""

        calculator = self.__default_calculator
        self.assertEqual(calculator.parameters["plus_times"].value, 1)
        calculator.parameters["plus_times"] = 5
        self.assertEqual(self.__default_calculator.parameters["plus_times"].value, 5)
        calculator.parameters["plus_times"] = 1
        self.assertEqual(self.__default_calculator.parameters["plus_times"].value, 1)

    def test_resurrect_from_dump(self):
        """Test loading from dumpfile."""

        calculator = self.__default_calculator()

        # Register the output directory before backengine() creates it, so it
        # is cleaned up even if an assertion below fails.
        self.__dirs_to_remove.append("PlusCalculator")
        self.assertEqual(calculator.parameters["plus_times"].value, 1)
        output = calculator.backengine()
        self.assertEqual(output.get_data()["number"], 2)

        # dump
        dump = calculator.dump()
        self.__files_to_remove.append(dump)

        del calculator

        calculator = PlusCalculator.from_dump(dump)

        self.assertEqual(
            calculator.input.get_data(),
            self.__default_calculator.input.get_data(),
        )

        self.assertEqual(
            calculator.parameters.to_dict(),
            self.__default_calculator.parameters.to_dict(),
        )

        calculator.parameters["plus_times"] = 5
        self.assertNotEqual(
            calculator.parameters.to_dict(),
            self.__default_calculator.parameters.to_dict(),
        )

        self.assertIsNotNone(calculator.data)

    def test_attributes(self):
        """Test that all required attributes are present."""

        calculator = self.__default_calculator

        self.assertTrue(hasattr(calculator, "name"))
        self.assertTrue(hasattr(calculator, "input"))
        self.assertTrue(hasattr(calculator, "output"))
        self.assertTrue(hasattr(calculator, "parameters"))
        self.assertTrue(hasattr(calculator, "instrument_base_dir"))
        self.assertTrue(hasattr(calculator, "calculator_base_dir"))
        self.assertTrue(hasattr(calculator, "base_dir"))
        self.assertTrue(hasattr(calculator, "backengine"))
        self.assertTrue(hasattr(calculator, "data"))
        self.assertTrue(hasattr(calculator, "dump"))
        self.assertTrue(hasattr(calculator, "from_dump"))

    def test_set_param_values(self):
        """Test setting a parameter value by item assignment."""
        calculator = self.__default_calculator

        calculator.parameters["plus_times"] = 5
        self.assertEqual(calculator.parameters["plus_times"].value, 5)

    def test_set_param_values_with_set_parameters(self):
        """Test setting a parameter value via set_parameters keyword arguments."""
        calculator = self.__default_calculator

        calculator.set_parameters(plus_times=7)
        self.assertEqual(calculator.parameters["plus_times"].value, 7)

    def test_set_param_values_with_set_parameters_with_dict(self):
        """Test setting a parameter value via set_parameters with a dict."""
        calculator = self.__default_calculator

        calculator.set_parameters({"plus_times": 9})
        self.assertEqual(calculator.parameters["plus_times"].value, 9)

    def test_collection_get_data(self):
        """Test that the input collection returns the data of every entry."""
        calculator = self.__default_calculator
        # Exercise the string conversion of the input collection.
        print(calculator.input)
        input_dict = calculator.input.get_data()
        self.assertEqual(input_dict["input1"]["number"], 1)
        self.assertEqual(input_dict["input2"]["number"], 1)

    def test_output_file_paths(self):
        """Test output_file_paths without and with output_filenames set."""
        calculator = self.__default_calculator
        # Register cleanup up front so it also runs if an assertion fails.
        self.__dirs_to_remove.append("PlusCalculator")
        with self.assertRaises(ValueError):
            calculator.output_file_paths

        calculator.output_filenames = "bingo.txt"
        self.assertEqual(calculator.output_file_paths[0], "PlusCalculator/bingo.txt")

    def test_calculator_output_set_inconsistent(self):
        """Test that one output key with zero output data types raises ValueError."""
        input1 = NumberData.from_dict({"number": 1}, "input1")
        with self.assertRaises(ValueError):
            PlusCalculator(
                "test", input1, output_keys=["result"], output_data_types=[]
            )


if __name__ == "__main__":
    # Allow running this file directly as a script; unittest.main() loads
    # and runs the TestCase classes defined in this module.
    unittest.main()