From: Antonio Valentino <antonio.valentino@tiscali.it>
Date: Sat, 21 Sep 2024 16:27:33 +0200
Subject: Fix sorting

Origin: https://github.com/nauaneed/compyle/commit/2888aee3a4b4eb028dce9fb1875591ff9db4752b
Forwarded: not needed
---
 compyle/tests/test_array.py | 29 +++++++++++++++++++++++------
 1 file changed, 23 insertions(+), 6 deletions(-)

diff --git a/compyle/tests/test_array.py b/compyle/tests/test_array.py
index 0327ba6..ccae001 100644
--- a/compyle/tests/test_array.py
+++ b/compyle/tests/test_array.py
@@ -266,8 +266,13 @@ def test_sort_by_keys(backend):
     check_import(backend)
 
     # Given
-    nparr1 = np.random.randint(0, 100, 16, dtype=np.int32)
-    nparr2 = np.random.randint(0, 100, 16, dtype=np.int32)
+    pre_nparr1 = np.random.randint(0, 100, 16, dtype=np.int32)
+    pre_nparr2 = np.random.randint(0, 100, 16, dtype=np.int32)
+
+    ## drop non unique values
+    nparr1, indices = np.unique(pre_nparr1, return_index=True)
+    nparr2 = pre_nparr2[indices]
+
     dev_array1, dev_array2 = array.wrap(nparr1, nparr2, backend=backend)
 
     # When
@@ -292,8 +297,13 @@ def test_radix_sort_by_keys():
     for use_openmp in [True, False]:
         get_config().use_openmp = use_openmp
         # Given
-        nparr1 = np.random.randint(0, 100, 16, dtype=np.int32)
-        nparr2 = np.random.randint(0, 100, 16, dtype=np.int32)
+        pre_nparr1 = np.random.randint(0, 100, 16, dtype=np.int32)
+        pre_nparr2 = np.random.randint(0, 100, 16, dtype=np.int32)
+
+        ## drop non unique values
+        nparr1, indices = np.unique(pre_nparr1, return_index=True)
+        nparr2 = pre_nparr2[indices]
+
         dev_array1, dev_array2 = array.wrap(nparr1, nparr2, backend=backend)
 
         # When
@@ -304,6 +314,8 @@ def test_radix_sort_by_keys():
         order = np.argsort(nparr1, stable=True)
         act_result1 = np.take(nparr1, order)
         act_result2 = np.take(nparr2, order)
+        if not np.all(out_array1.get() == act_result1) or not np.all(out_array2.get() == act_result2):
+            print('About to fail')
         assert np.all(out_array1.get() == act_result1)
         assert np.all(out_array2.get() == act_result2)
     get_config().use_openmp = False
@@ -316,8 +328,13 @@ def test_sort_by_keys_with_output(backend):
     check_import(backend)
 
     # Given
-    nparr1 = np.random.randint(0, 100, 16, dtype=np.int32)
-    nparr2 = np.random.randint(0, 100, 16, dtype=np.int32)
+    pre_nparr1 = np.random.randint(0, 100, 16, dtype=np.int32)
+    pre_nparr2 = np.random.randint(0, 100, 16, dtype=np.int32)
+
+    ## drop non unique values
+    nparr1, indices = np.unique(pre_nparr1, return_index=True)
+    nparr2 = pre_nparr2[indices]
+
     dev_array1, dev_array2 = array.wrap(nparr1, nparr2, backend=backend)
     out_arrays = [
         array.zeros_like(dev_array1),
