File: test_data_processing.py

package info (click to toggle)
baler 1.4.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 207,900 kB
  • sloc: python: 2,468; sh: 98; makefile: 7
file content (121 lines) | stat: -rw-r--r-- 3,815 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Copyright 2022 Baler Contributors

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile

import numpy as np
import pytest
import torch
from sklearn.preprocessing import MinMaxScaler

from baler.modules import data_processing
from baler.modules import helper


def test_import_config_success():
    """Check that attributes assigned on ``helper.Config`` can be read back.

    NOTE(review): despite the test name, no ``import_config`` function is
    called here; the test only exercises attribute assignment on
    ``helper.Config``.
    """
    # helper.Config is referenced, not called, so the assignments below land
    # on the shared object itself. Snapshot any pre-existing values so they
    # can be restored and the attributes do not leak into other tests.
    config = helper.Config
    missing = object()
    names = ("Foo", "Baz")
    saved = {name: vars(config).get(name, missing) for name in names}

    try:
        config.Foo = "Bar"
        config.Baz = 10

        assert config.Foo == "Bar"
        assert config.Baz == 10
    finally:
        for name, value in saved.items():
            if value is not missing:
                setattr(config, name, value)
            elif name in vars(config):
                # Only delete what this test actually added.
                delattr(config, name)


def test_save_model():
    """Check that ``data_processing.save_model`` writes a non-empty file."""
    model = torch.nn.Linear(3, 2)

    # Write into a throwaway directory instead of the current working
    # directory: the file is cleaned up even when an assertion fails, and
    # concurrent test runs cannot collide on the same filename.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "test_model.pt")

        data_processing.save_model(model, model_path)

        assert os.path.isfile(model_path)
        assert os.path.getsize(model_path) > 0


@pytest.fixture
def minmax_test_data():
    """Provide ``(data, expected)`` pairs for ``find_minmax``.

    In every pair, the first row of ``expected`` holds the column-wise
    minimum of ``data``. The second row holds the column-wise range
    (max - min).
    """
    ascending = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    negated = np.array([[-1, -2, -3], [-4, -5, -6], [-7, -8, -9]])
    constant_rows = np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2]])

    cases = [
        (ascending, np.array([[1, 2, 3], [6, 6, 6]])),
        (negated, np.array([[-7, -8, -9], [6, 6, 6]])),
        (constant_rows, np.array([[0, 0, 0], [2, 2, 2]])),
    ]
    return cases


def test_find_minmax_success(minmax_test_data):
    """Check ``find_minmax`` against every ``(data, expected)`` fixture case."""
    # Guard against the loop passing vacuously if the fixture is emptied.
    assert minmax_test_data, "minmax_test_data provided no cases"

    for case_index, (data, expected_result) in enumerate(minmax_test_data):
        result = data_processing.find_minmax(data)
        # assert_array_equal reports the mismatching shapes/elements, and
        # err_msg identifies which fixture case failed.
        np.testing.assert_array_equal(
            result, expected_result, err_msg=f"minmax case {case_index} failed"
        )


def test_normalize():
    """Exercise ``normalize`` with custom normalisation disabled and enabled.

    With ``custom_norm=False``, the values are expected to come back min-max
    scaled into [0, 1]. With ``custom_norm=True``, they are expected back
    unchanged.
    """
    values = [1, 2, 3, 4, 5]

    # (custom_norm flag, expected output) pairs, checked in order.
    scenarios = (
        (False, np.array([0.0, 0.25, 0.5, 0.75, 1.0])),
        (True, np.array([1, 2, 3, 4, 5])),
    )

    for use_custom_norm, expected in scenarios:
        outcome = data_processing.normalize(values, use_custom_norm)
        np.testing.assert_almost_equal(outcome, expected)


def test_renormalize_std():
    """Check that ``renormalize_std`` maps values back to the original scale.

    The expected values below equal ``data * feature_range + true_min``.
    """
    data = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
    true_min = 1
    feature_range = 2

    renormalized_data = data_processing.renormalize_std(data, true_min, feature_range)

    # The inputs are not exactly representable in binary floating point, so
    # compare with a tolerance instead of exact equality.
    expected_renormalized_data = np.array([1.2, 1.4, 1.6, 1.8, 2.0])
    np.testing.assert_allclose(renormalized_data, expected_renormalized_data)


def test_renormalize_func():
    """Check that ``renormalize_func`` undoes sklearn's ``MinMaxScaler``.

    The data is scaled with a default ``MinMaxScaler``. It is then mapped back
    using the column-wise minimum and range of the original data. The round
    trip is expected to reproduce the original values.
    """
    data = [[-1, 2], [-0.5, 6], [0, 10], [1, 18]]
    scaler = MinMaxScaler()
    scaler.fit(data)
    norm_data = scaler.transform(data)

    # Column-wise minimum and range (max - min) of ``data``.
    true_min = [-1, 2]
    feature_range = [2, 16]

    renormalized_data = data_processing.renormalize_func(
        norm_data, true_min, feature_range
    )

    # The round trip goes through floating-point scaling, so compare with a
    # tolerance. atol is needed because the expected data contains 0, where a
    # purely relative tolerance would demand exact equality.
    expected_renormalized_data = np.array(data)
    np.testing.assert_allclose(
        renormalized_data, expected_renormalized_data, atol=1e-12
    )