File: x86_sse_sgemm.rs

package info (click to toggle)
rust-matrixmultiply 0.3.9-1
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 436 kB
  • sloc: python: 86; sh: 11; makefile: 2
file content (84 lines) | stat: -rw-r--r-- 2,758 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84

// 4x4 sse sgemm
macro_rules! mm_transpose4 {
    ($c0:expr, $c1:expr, $c2:expr, $c3:expr) => {{
        // In-place transpose of a 4x4 f32 tile held in four SSE vectors.
        // Same instruction sequence as C's _MM_TRANSPOSE4_PS, but the four
        // arguments are assignable place expressions instead of references.

        // Interleave rows 0 and 1: the low halves give columns 0-1,
        // the high halves give columns 2-3.
        let lo01 = _mm_unpacklo_ps($c0, $c1);
        let hi01 = _mm_unpackhi_ps($c0, $c1);
        // Same interleave for rows 2 and 3.
        let lo23 = _mm_unpacklo_ps($c2, $c3);
        let hi23 = _mm_unpackhi_ps($c2, $c3);

        // Stitch matching 64-bit halves together to form the transposed rows.
        $c0 = _mm_movelh_ps(lo01, lo23);
        $c1 = _mm_movehl_ps(lo23, lo01);
        $c2 = _mm_movelh_ps(hi01, hi23);
        $c3 = _mm_movehl_ps(hi23, hi01);
    }}
}

/// SSE sgemm microkernel: updates one output tile as `C ← α A B + β C`.
///
/// `a` and `b` are walked `k` times, advancing by `MR` and `NR` elements per
/// step; each step adds one rank-1 update to the per-row accumulators.
/// The tile in `c` is addressed through the row stride `rsc` and the column
/// stride `csc` (both in elements). Contiguous rows (`csc == 1`) and
/// contiguous columns (`rsc == 1`) use vector loads/stores; any other layout
/// falls back to element-wise access.
///
/// When `beta` is zero, `c` is never read, only written.
///
/// # Safety
///
/// NOTE(review): the contract below is read off the body; confirm against
/// the callers.
/// - `a` and `b` must be readable for every element touched over `k` steps,
///   and `b` must satisfy the alignment required by `_mm_load_ps` at each
///   step.
/// - `c`, `rsc` and `csc` must describe a writable tile (also readable when
///   `beta != 0`) covering every row/column index used below.
#[inline(always)]
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
unsafe fn kernel_x86_sse(k: usize, alpha: T, a: *const T, b: *const T,
                         beta: T, c: *mut T, rsc: isize, csc: isize)
{
    // One accumulator vector per row of the output tile.
    let mut acc = [_mm_setzero_ps(); MR];

    let mut a_cur = a;
    let mut b_cur = b;

    // Accumulate A B, one step of the shared dimension at a time.
    for _ in 0..k {
        // aligned due to GemmKernel::align_to
        let b_row = _mm_load_ps(b_cur as _);

        loop_m!(i, {
            // acc_i += a_i * [b_0, b_1, b_2, b_3]
            let a_splat = _mm_set1_ps(at(a_cur, i));
            acc[i] = _mm_add_ps(acc[i], _mm_mul_ps(a_splat, b_row));
        });

        a_cur = a_cur.add(MR);
        b_cur = b_cur.add(NR);
    }

    // Scale the product: α (A B)
    let alphav = _mm_set1_ps(alpha);
    loop_m!(i, acc[i] = _mm_mul_ps(alphav, acc[i]));

    // Pointer to element (row, col) of the output tile.
    macro_rules! c_at {
        ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
    }

    // Rows of β C; stays all-zero (and C stays unread) when β is zero.
    let mut tile = [_mm_setzero_ps(); MR];
    if beta != 0. {
        match (csc, rsc) {
            // Rows are contiguous: load each row directly.
            (1, _) => {
                loop_m!(i, tile[i] = _mm_loadu_ps(c_at![i, 0]));
            }
            // Columns are contiguous: load columns, then transpose into rows.
            (_, 1) => {
                loop_m!(i, tile[i] = _mm_loadu_ps(c_at![0, i]));
                mm_transpose4!(tile[0], tile[1], tile[2], tile[3]);
            }
            // General strides: gather each row element by element.
            _ => {
                loop_m!(i, tile[i] = _mm_set_ps(*c_at![i, 3], *c_at![i, 2], *c_at![i, 1], *c_at![i, 0]));
            }
        }
        // Scale: β C
        let betav = _mm_set1_ps(beta);
        loop_m!(i, tile[i] = _mm_mul_ps(tile[i], betav));
    }

    // Combine: (β C) + (α A B)
    loop_m!(i, tile[i] = _mm_add_ps(tile[i], acc[i]));

    // Write the tile back using the same layout dispatch as the read.
    match (csc, rsc) {
        (1, _) => {
            loop_m!(i, _mm_storeu_ps(c_at![i, 0], tile[i]));
        }
        (_, 1) => {
            // Back to column vectors so each store is contiguous.
            mm_transpose4!(tile[0], tile[1], tile[2], tile[3]);
            loop_m!(i, _mm_storeu_ps(c_at![0, i], tile[i]));
        }
        _ => {
            // Lane n is extracted by shuffling it into the lowest position
            // and reading that position with _mm_cvtss_f32. Stores go
            // column by column.
            loop_m!(i, *c_at![i, 0] = _mm_cvtss_f32(tile[i]));
            loop_m!(i, *c_at![i, 1] = _mm_cvtss_f32(_mm_shuffle_ps(tile[i], tile[i], 1)));
            loop_m!(i, *c_at![i, 2] = _mm_cvtss_f32(_mm_shuffle_ps(tile[i], tile[i], 2)));
            loop_m!(i, *c_at![i, 3] = _mm_cvtss_f32(_mm_shuffle_ps(tile[i], tile[i], 3)));
        }
    }
}