1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
|
From: Antonio Valentino <antonio.valentino@tiscali.it>
Date: Sat, 21 Sep 2024 16:27:33 +0200
Subject: Fix sorting
Origin: https://github.com/nauaneed/compyle/commit/2888aee3a4b4eb028dce9fb1875591ff9db4752b
Forwarded: not needed
---
compyle/tests/test_array.py | 29 +++++++++++++++++++++++------
1 file changed, 23 insertions(+), 6 deletions(-)
diff --git a/compyle/tests/test_array.py b/compyle/tests/test_array.py
index 0327ba6..ccae001 100644
--- a/compyle/tests/test_array.py
+++ b/compyle/tests/test_array.py
@@ -266,8 +266,13 @@ def test_sort_by_keys(backend):
check_import(backend)
# Given
- nparr1 = np.random.randint(0, 100, 16, dtype=np.int32)
- nparr2 = np.random.randint(0, 100, 16, dtype=np.int32)
+ pre_nparr1 = np.random.randint(0, 100, 16, dtype=np.int32)
+ pre_nparr2 = np.random.randint(0, 100, 16, dtype=np.int32)
+
+ ## drop non unique values
+ nparr1, indices = np.unique(pre_nparr1, return_index=True)
+ nparr2 = pre_nparr2[indices]
+
dev_array1, dev_array2 = array.wrap(nparr1, nparr2, backend=backend)
# When
@@ -292,8 +297,13 @@ def test_radix_sort_by_keys():
for use_openmp in [True, False]:
get_config().use_openmp = use_openmp
# Given
- nparr1 = np.random.randint(0, 100, 16, dtype=np.int32)
- nparr2 = np.random.randint(0, 100, 16, dtype=np.int32)
+ pre_nparr1 = np.random.randint(0, 100, 16, dtype=np.int32)
+ pre_nparr2 = np.random.randint(0, 100, 16, dtype=np.int32)
+
+ ## drop non unique values
+ nparr1, indices = np.unique(pre_nparr1, return_index=True)
+ nparr2 = pre_nparr2[indices]
+
dev_array1, dev_array2 = array.wrap(nparr1, nparr2, backend=backend)
# When
@@ -304,6 +314,8 @@ def test_radix_sort_by_keys():
order = np.argsort(nparr1, stable=True)
act_result1 = np.take(nparr1, order)
act_result2 = np.take(nparr2, order)
+ if not np.all(out_array1.get() == act_result1) or not np.all(out_array2.get() == act_result2):
+ print('About to fail')
assert np.all(out_array1.get() == act_result1)
assert np.all(out_array2.get() == act_result2)
get_config().use_openmp = False
@@ -316,8 +328,13 @@ def test_sort_by_keys_with_output(backend):
check_import(backend)
# Given
- nparr1 = np.random.randint(0, 100, 16, dtype=np.int32)
- nparr2 = np.random.randint(0, 100, 16, dtype=np.int32)
+ pre_nparr1 = np.random.randint(0, 100, 16, dtype=np.int32)
+ pre_nparr2 = np.random.randint(0, 100, 16, dtype=np.int32)
+
+ ## drop non unique values
+ nparr1, indices = np.unique(pre_nparr1, return_index=True)
+ nparr2 = pre_nparr2[indices]
+
dev_array1, dev_array2 = array.wrap(nparr1, nparr2, backend=backend)
out_arrays = [
array.zeros_like(dev_array1),
|