File: matmul-opt.py

package info (click to toggle)
python-peachpy 0.0~git20211013.257881e-1.1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, forky, sid, trixie
  • size: 2,452 kB
  • sloc: python: 29,286; ansic: 54; makefile: 44; cpp: 31
file content (59 lines) | stat: -rw-r--r-- 1,583 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# This file is part of PeachPy package and is licensed under the Simplified BSD license.
#    See license.rst for the full text of the license.

from peachpy.x86_64 import *
from peachpy import *

a = Argument(ptr(const_float_))
b = Argument(ptr(const_float_))
c = Argument(ptr(float_))

# matmul(a, b, c): writes c = a x b for 4x4 single-precision matrices, each
# stored as 16 contiguous floats in row-major order (one row = 4 floats =
# 16 bytes). Only unaligned moves (MOVUPS / MOVSS) touch memory, so the
# pointers need no particular alignment.
#
# Output row i is formed as sum_j a[i][j] * b[j], where b[j] is row j of b.
# The four products are added pairwise, (p0 + p1) + (p2 + p3), which fixes
# the floating-point summation order.
with Function("matmul", (a, b, c)) as function:
    # Move the three pointer arguments into 64-bit GPRs, in a, b, c order.
    pointer_regs = []
    for argument in (a, b, c):
        register = GeneralPurposeRegister64()
        LOAD.ARGUMENT(register, argument)
        pointer_regs.append(register)
    reg_a, reg_b, reg_c = pointer_regs

    # All four rows of b stay resident in XMM registers for the whole kernel.
    b_rows = []
    for row in range(4):
        xmm_row = XMMRegister()
        MOVUPS(xmm_row, [reg_b + row * 16])
        b_rows.append(xmm_row)

    def scaled_b_row(byte_offset, xmm_b_row):
        """Emit code for broadcast(a-element at byte_offset) * xmm_b_row.

        Returns the freshly allocated XMM register that holds the product.
        """
        xmm_product = XMMRegister()
        MOVSS(xmm_product, [reg_a + byte_offset])
        # Shuffle selector 0x00 replicates lane 0 into all four lanes.
        SHUFPS(xmm_product, xmm_product, 0x00)
        MULPS(xmm_product, xmm_b_row)
        return xmm_product

    for row in range(4):
        row_offset = row * 16

        # (a[row][0] * b[0]) + (a[row][1] * b[1])
        xmm_sum01 = scaled_b_row(row_offset + 0, b_rows[0])
        xmm_term = scaled_b_row(row_offset + 4, b_rows[1])
        ADDPS(xmm_sum01, xmm_term)

        # (a[row][2] * b[2]) + (a[row][3] * b[3])
        xmm_sum23 = scaled_b_row(row_offset + 8, b_rows[2])
        xmm_term = scaled_b_row(row_offset + 12, b_rows[3])
        ADDPS(xmm_sum23, xmm_term)

        # Combine the two halves and store the finished output row.
        ADDPS(xmm_sum01, xmm_sum23)
        MOVUPS([reg_c + row_offset], xmm_sum01)

    RETURN()