1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259
|
//! An example of sorting an array along an axis.
//!
//! This file is mostly of interest to advanced users; think of it as a
//! "feature preview" of sorting support that may later be added to the main crate.
//!
use ndarray::prelude::*;
use ndarray::{Data, RemoveAxis, Zip};
use rawpointer::PointerExt;
use std::cmp::Ordering;
use std::ptr::copy_nonoverlapping;
/// A reordering of the positions `0..n` along one array axis.
///
/// Type invariant: each index in `0..indices.len()` appears exactly once.
#[derive(Clone, Debug)]
pub struct Permutation
{
    /// `indices[k]` is the source position that is placed at position `k`.
    indices: Vec<usize>,
}

impl Permutation
{
    /// Builds a `Permutation` from `v`, validating it first.
    ///
    /// Returns `Err(())` if any entry is `>= v.len()` or occurs more than once.
    pub fn from_indices(v: Vec<usize>) -> Result<Self, ()>
    {
        let candidate = Permutation { indices: v };
        if !candidate.correct() {
            return Err(());
        }
        Ok(candidate)
    }

    /// Returns `true` when `indices` contains every value of `0..len` exactly once.
    fn correct(&self) -> bool
    {
        let mut visited = vec![false; self.indices.len()];
        self.indices.iter().all(|&idx| match visited.get_mut(idx) {
            // In range and not seen before: mark it and continue.
            Some(flag) if !*flag => {
                *flag = true;
                true
            }
            // Out of range, or a repeated index.
            _ => false,
        })
    }
}
/// Computes permutations of the positions along an axis of an array.
pub trait SortArray
{
    /// Returns the identity permutation for `axis`.
    ///
    /// ***Panics*** if `axis` is out of bounds.
    fn identity(&self, axis: Axis) -> Permutation;

    /// Returns a permutation of the positions along `axis`, ordered so that
    /// `less_than(i, j)` reports whether position `i` should come before `j`.
    fn sort_axis_by<F: FnMut(usize, usize) -> bool>(
        &self,
        axis: Axis,
        less_than: F,
    ) -> Permutation;
}
/// Reorders the subviews of an array along one axis according to a [`Permutation`].
pub trait PermuteArray
{
    /// Element type of the array being permuted.
    type Elem;
    /// Dimension type of the array being permuted.
    type Dim;

    /// Consumes `self` and returns an array whose subview at position `k` along
    /// `axis` is the subview of `self` at position `perm.indices[k]`.
    fn permute_axis(self, axis: Axis, perm: &Permutation) -> Array<Self::Elem, Self::Dim>
    where
        Self::Dim: RemoveAxis,
        Self::Elem: Clone;
}
impl<A, S, D> SortArray for ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    /// Identity permutation `0, 1, …, n - 1`, where `n` is the length of `axis`.
    fn identity(&self, axis: Axis) -> Permutation
    {
        let len = self.len_of(axis);
        let indices: Vec<usize> = (0..len).collect();
        Permutation { indices }
    }

    /// Stable-sorts the identity permutation of `axis`, using `less_than` as the
    /// "comes before" predicate on axis positions.
    fn sort_axis_by<F>(&self, axis: Axis, mut less_than: F) -> Permutation
    where F: FnMut(usize, usize) -> bool
    {
        let mut perm = self.identity(axis);
        perm.indices.sort_by(|&lhs, &rhs| {
            // Turn the two-way predicate into a three-way ordering. The reverse
            // test only runs when `lhs` does not come before `rhs`.
            if less_than(lhs, rhs) {
                return Ordering::Less;
            }
            if less_than(rhs, lhs) {
                return Ordering::Greater;
            }
            Ordering::Equal
        });
        perm
    }
}
/// Permutes an owned array along an axis by moving (not cloning) its elements
/// into a freshly allocated array. This implementation does not require the
/// `Elem: Clone` bound that the trait declares.
impl<A, D> PermuteArray for Array<A, D>
where D: Dimension
{
    type Elem = A;
    type Dim = D;

    /// Returns a new array whose subview at position `k` along `axis` is the
    /// subview of `self` at position `perm.indices[k]`.
    ///
    /// ***Panics*** if the length of `axis` differs from the permutation length.
    fn permute_axis(self, axis: Axis, perm: &Permutation) -> Array<A, D>
    where D: RemoveAxis
    {
        let axis_len = self.len_of(axis);
        let axis_stride = self.stride_of(axis);
        assert_eq!(axis_len, perm.indices.len());
        debug_assert!(perm.correct());

        // Nothing to move. A non-empty array also has `axis_len > 0`, which keeps
        // `index_axis_move(axis, 0)` below in bounds.
        if self.is_empty() {
            return self;
        }

        // Output with the same shape as `self`; every slot is written below.
        let mut result = Array::uninit(self.dim());

        // SAFETY: `perm` is a valid permutation (type invariant of `Permutation`,
        // re-checked by the debug assertion above) and its length equals
        // `axis_len`. So every source element is bitwise-copied exactly once, and
        // every slot of `result` is written exactly once before `assume_init`.
        // The old storage is then truncated to length 0 before it is dropped, so
        // the moved-out source elements are never dropped a second time.
        unsafe {
            // logically move ownership of all elements from self into result
            // the result realizes this ownership at .assume_init() further down
            let mut moved_elements = 0;

            // the permutation vector is used like this:
            //
            // index:  0 1 2 3   (index in result)
            // permut: 2 3 0 1   (index in the source)
            //
            // move source 2 -> result 0,
            // move source 3 -> result 1,
            // move source 0 -> result 2,
            // move source 1 -> result 3,
            // etc.

            // Raw (pointer) view of source subview 0 along `axis`. The element
            // pointers of any other source subview are derived from it by offset.
            let source_0 = self.raw_view().index_axis_move(axis, 0);
            Zip::from(&perm.indices)
                .and(result.axis_iter_mut(axis))
                .for_each(|&perm_i, result_pane| {
                    // Use a shortcut to avoid bounds checking in `index_axis` for the source.
                    //
                    // It works because for any given element pointer in the array we have the
                    // relationship:
                    //
                    // .index_axis(axis, 0) + .stride_of(axis) * j == .index_axis(axis, j)
                    //
                    // where + is pointer arithmetic on the element pointers.
                    //
                    // Here source_0 and the offset is equivalent to self.index_axis(axis, perm_i)
                    Zip::from(result_pane)
                        .and(source_0.clone())
                        .for_each(|to, from_0| {
                            let from = from_0.stride_offset(axis_stride, perm_i);
                            copy_nonoverlapping(from, to.as_mut_ptr(), 1);
                            moved_elements += 1;
                        });
                });
            // Every slot of `result` must have been written exactly once.
            debug_assert_eq!(result.len(), moved_elements);
            // forget the old elements but not the allocation
            let mut old_storage = self.into_raw_vec_and_offset().0;
            old_storage.set_len(0);
            // transfer ownership of the elements into the result
            result.assume_init()
        }
    }
}
/// Demo: sort the rows of an 8×8 matrix by their first column (descending),
/// then apply that row permutation to the matrix. Because the matrix is square,
/// the same permutation is also applied to the columns of its string rendering.
#[cfg(feature = "std")]
fn main()
{
    let a = Array::linspace(0., 63., 64)
        .into_shape_with_order((8, 8))
        .unwrap();
    let strings = a.map(|x| x.to_string());

    // The sort key `a[[i, 0]]` indexes rows, so the permutation is over Axis(0).
    // It must have the same length as the axis it is later applied to.
    let perm = a.sort_axis_by(Axis(0), |i, j| a[[i, 0]] > a[[j, 0]]);
    println!("{:?}", perm);
    let b = a.permute_axis(Axis(0), &perm);
    println!("{:?}", b);

    println!("{:?}", strings);
    // Reusing a row permutation on Axis(1) only works because `strings` is square.
    let c = strings.permute_axis(Axis(1), &perm);
    println!("{:?}", c);
}
#[cfg(not(feature = "std"))]
fn main()
{
    // Intentionally empty: the demo `main` is compiled only with the `std` feature.
}
#[cfg(test)]
mod tests
{
    use super::*;

    /// Sorting rows by the first column must give the same result for C-order,
    /// F-order and transposed inputs.
    #[test]
    fn test_permute_axis()
    {
        let input = array![
            [107998.96, 1.],
            [107999.08, 2.],
            [107999.20, 3.],
            [108000.33, 4.],
            [107999.45, 5.],
            [107999.57, 6.],
            [108010.69, 7.],
            [107999.81, 8.],
            [107999.94, 9.],
            [75600.09, 10.],
            [75600.21, 11.],
            [75601.33, 12.],
            [75600.45, 13.],
            [75600.58, 14.],
            [109000.70, 15.],
            [75600.82, 16.],
            [75600.94, 17.],
            [75601.06, 18.],
        ];
        let expected = array![
            [75600.09, 10.],
            [75600.21, 11.],
            [75600.45, 13.],
            [75600.58, 14.],
            [75600.82, 16.],
            [75600.94, 17.],
            [75601.06, 18.],
            [75601.33, 12.],
            [107998.96, 1.],
            [107999.08, 2.],
            [107999.20, 3.],
            [107999.45, 5.],
            [107999.57, 6.],
            [107999.81, 8.],
            [107999.94, 9.],
            [108000.33, 4.],
            [108010.69, 7.],
            [109000.70, 15.],
        ];

        // Row order that sorts `input` ascending by its first column.
        let order = input.sort_axis_by(Axis(0), |i, j| input[[i, 0]] < input[[j, 0]]);

        // Same data in Fortran (column-major) layout.
        let mut input_f = Array::zeros(input.dim().f());
        input_f.assign(&input);
        // Owned transposed copy: rows of `input` become columns.
        let input_t = input.t().to_owned();

        // C layout (consumes `input`, so the copies above must be made first).
        let sorted_c = input.permute_axis(Axis(0), &order);
        assert_eq!(sorted_c, expected);
        // F layout.
        let sorted_f = input_f.permute_axis(Axis(0), &order);
        assert_eq!(sorted_f, expected);
        // Transposed: permute columns and compare against the transposed answer.
        let sorted_t = input_t.permute_axis(Axis(1), &order);
        assert_eq!(sorted_t, expected.t());
    }
}
|