File: sums_asm.go

package info (click to toggle)
golang-github-segmentio-asm 1.2.0+git20231107.1cfacc8-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 932 kB
  • sloc: asm: 6,093; makefile: 32
file content (156 lines) | stat: -rw-r--r-- 3,458 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
// +build ignore

package main

import (
	"fmt"

	. "github.com/mmcloughlin/avo/build"
	. "github.com/mmcloughlin/avo/operand"
	. "github.com/segmentio/asm/build/internal/x86"

	"github.com/mmcloughlin/avo/reg"
	"github.com/segmentio/asm/cpu"
)

// unroll is the number of YMM registers allocated per AVX2 loop iteration.
// They are loaded in (x, y) pairs, so each iteration processes unroll/2
// 32-byte vectors from each slice; it must therefore be even.
const (
	unroll = 8
)

// Processor describes how to generate the sum function for one element type.
type Processor struct {
	// name is the generated function's symbol; typ is the Go element type
	// used in its signature func(x, y []typ).
	name, typ string
	// scale is the element size in bytes, used as the index scale when
	// addressing x[idx] and y[idx].
	scale uint8
	// avxOffset is multiplied by unroll to get the number of elements the
	// AVX2 loop advances per iteration.
	avxOffset uint64
	// avxAdd is the packed-add instruction used on the YMM register pairs.
	avxAdd func(...Op)
	// x86Mov loads y[idx] into x86Reg and x86Add adds x86Reg into x[idx] in
	// the scalar loop.
	x86Mov func(imr, mr Op)
	x86Add func(imr, amr Op)
	x86Reg reg.GPVirtual
}

// init registers the "!purego" build constraint with avo, so the generated
// assembly is excluded from builds that set the purego tag.
func init() {
	const notPureGo = "!purego"
	ConstraintExpr(notPureGo)
}

func main() {
	generate(Processor{
		name:      "sumUint64",
		typ:       "uint64",
		scale:     8,
		avxOffset: 2,
		avxAdd:    VPADDQ,
		x86Mov:    MOVQ,
		x86Add:    ADDQ,
		x86Reg:    GP64(),
	})

	generate(Processor{
		name:      "sumUint32",
		typ:       "uint32",
		scale:     4,
		avxOffset: 4,
		avxAdd:    VPADDD,
		x86Mov:    MOVL,
		x86Add:    ADDL,
		x86Reg:    GP32(),
	})

	generate(Processor{
		name:      "sumUint16",
		typ:       "uint16",
		scale:     2,
		avxOffset: 8,
		avxAdd:    VPADDW,
		x86Mov:    MOVW,
		x86Add:    ADDW,
		x86Reg:    GP16(),
	})

	generate(Processor{
		name:      "sumUint8",
		typ:       "uint8",
		scale:     1,
		avxOffset: 16,
		avxAdd:    VPADDB,
		x86Mov:    MOVB,
		x86Add:    ADDB,
		x86Reg:    GP8(),
	})

	Generate()
}

// generate emits the assembly for one sum function described by p. The
// function has the Go signature func(x, y []T), with T = p.typ, and adds
// y[i] into x[i] for every i < min(len(x), len(y)).
//
// When the CPU supports AVX2, an unrolled vector loop handles as many whole
// chunks of unroll*p.avxOffset elements as fit. A scalar loop then handles
// the remaining elements, or all of them when AVX2 is unavailable.
func generate(p Processor) {
	TEXT(p.name, NOSPLIT, fmt.Sprintf("func(x, y []%s)", p.typ))
	Doc(fmt.Sprintf("Sum %ss using avx2 instructions, results stored in x", p.typ))

	// idx is the element index shared by both slices; xPtr/yPtr address
	// x[idx] and y[idx] by scaling idx by the element size.
	idx := GP64()
	XORQ(idx, idx)
	xPtr := Mem{Base: Load(Param("x").Base(), GP64()), Index: idx, Scale: p.scale}
	yPtr := Mem{Base: Load(Param("y").Base(), GP64()), Index: idx, Scale: p.scale}

	// n = min(len(x), len(y)); named n to avoid shadowing the builtin len.
	n := Load(Param("x").Len(), GP64())
	yLen := Load(Param("y").Len(), GP64())
	CMPQ(yLen, n)
	CMOVQLT(yLen, n)

	JumpUnlessFeature("x86_loop", cpu.AVX2)

	// AVX2 loop: each iteration handles unroll*p.avxOffset elements, i.e.
	// unroll/2 32-byte vectors from each slice.
	Label("avx2_loop")
	next := GP64()
	MOVQ(idx, next)
	ADDQ(Imm(unroll*p.avxOffset), next)
	// The chunk [idx, next) is in bounds whenever next <= n, so only bail out
	// when next > n. Using JAE here would needlessly push a final chunk that
	// ends exactly at n onto the scalar loop.
	CMPQ(next, n)
	JA(LabelRef("avx2_done"))

	// Allocate unroll vector registers and use them in (x, y) pairs:
	// vectors[2k] holds x's k-th 32-byte vector and vectors[2k+1] holds y's.
	// For uint64 that looks like
	//   YMM0 [ x0, x1, x2, x3 ]
	//   YMM1 [ y0, y1, y2, y3 ]
	//   ...
	// p.avxAdd adds each pair into the even register, which is then stored
	// back to x.
	var vectors [unroll]reg.VecVirtual
	for i := range vectors {
		vectors[i] = YMM()
	}
	for i := 0; i < unroll/2; i++ {
		VMOVDQU(xPtr.Offset(i*32), vectors[2*i])
		VMOVDQU(yPtr.Offset(i*32), vectors[2*i+1])
	}
	for i := 0; i < unroll/2; i++ {
		p.avxAdd(vectors[2*i], vectors[2*i+1], vectors[2*i])
	}
	for i := 0; i < unroll/2; i++ {
		VMOVDQU(vectors[2*i], xPtr.Offset(i*32))
	}

	// Advance to the next chunk and loop.
	MOVQ(next, idx)
	JMP(LabelRef("avx2_loop"))

	// Leaving the AVX2 path: clear the upper YMM halves so that legacy-SSE
	// code executed afterwards does not incur AVX-SSE transition penalties.
	// This label is only reachable after the AVX2 feature check, so the VEX
	// instruction is safe here.
	Label("avx2_done")
	VZEROUPPER()

	// Scalar loop: add the remaining elements (if any) one at a time.
	Label("x86_loop")
	CMPQ(idx, n)
	JAE(LabelRef("return"))

	// Load y[idx] into the scratch register, then add it into x[idx].
	p.x86Mov(yPtr, p.x86Reg)
	p.x86Add(p.x86Reg, xPtr)

	ADDQ(Imm(1), idx)
	JMP(LabelRef("x86_loop"))

	Label("return")
	RET()
}