File: keyset_asm.go

package info (click to toggle)
golang-github-segmentio-asm 1.2.0%2Bgit20231107.1cfacc8-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 932 kB
  • sloc: asm: 6,093; makefile: 32
file content (178 lines) | stat: -rw-r--r-- 4,905 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
//go:build ignore
// +build ignore

package main

import (
	"fmt"
	"math"

	. "github.com/mmcloughlin/avo/build"
	. "github.com/mmcloughlin/avo/operand"
	. "github.com/mmcloughlin/avo/reg"
	. "github.com/segmentio/asm/build/internal/asm"
)

const (
	unroll    = 4
	pageSize  = 4096
	maxLength = 16
)

func init() {
	ConstraintExpr("!purego")
}

func main() {
	Lookup()

	Generate()
}

// Lookup searches for a key in a set of keys.
//
// Each key in the set of keys should be padded to 16 bytes and concatenated
// into a single buffer. The routine searches for the input key in the set of
// keys and returns its index if found. If not found, the routine returns the
// number of keys (len(keyset)/16).
func Lookup() {
	TEXT("Lookup", NOSPLIT, "func(keyset []byte, key []byte) int")
	Doc("Lookup searches for a key in a set of keys, returning its index if ",
		"found. If the key cannot be found, the number of keys is returned.")

	// Load inputs.
	keyset := Load(Param("keyset").Base(), GP64())
	count := Load(Param("keyset").Len(), GP64())
	SHRQ(Imm(4), count)
	keyPtr := Load(Param("key").Base(), GP64())
	keyLen := Load(Param("key").Len(), GP64())
	keyCap := Load(Param("key").Cap(), GP64())

	// None of the keys are larger than maxLength.
	CMPQ(keyLen, Imm(maxLength))
	JA(LabelRef("not_found"))

	// We're going to be unconditionally loading 16 bytes from the input key
	// so first check if it's safe to do so (cap >= 16). If not, defer to
	// safe_load for additional checks.
	CMPQ(keyCap, Imm(maxLength))
	JB(LabelRef("safe_load"))

	// Load the input key and pad with zeroes to 16 bytes.
	Label("load")
	key := XMM()
	VMOVUPS(Mem{Base: keyPtr}, key)
	Label("prepare")
	zeroes := XMM()
	VPXOR(zeroes, zeroes, zeroes)
	ones := XMM()
	VPCMPEQB(ones, ones, ones)
	var blendBytes [maxLength * 2]byte
	for j := 0; j < maxLength; j++ {
		blendBytes[j] = 0xFF
	}
	blendMasks := ConstBytes("blend_masks", blendBytes[:])
	blendMasksPtr := GP64()
	LEAQ(blendMasks.Offset(maxLength), blendMasksPtr)
	SUBQ(keyLen, blendMasksPtr)
	blend := XMM()
	VMOVUPS(Mem{Base: blendMasksPtr}, blend)
	VPBLENDVB(blend, key, zeroes, key)

	// Zero out i so we can use it as the loop increment.
	i := GP64()
	XORQ(i, i)

	// Round the key count down to the nearest multiple of unroll to determine
	// how many iterations of the big loop we'll need.
	truncatedCount := GP64()
	MOVQ(count, truncatedCount)
	shift := uint64(math.Log2(float64(unroll)))
	SHRQ(Imm(shift), truncatedCount)
	SHLQ(Imm(shift), truncatedCount)

	// Loop over multiple keys in the big loop.
	Label("bigloop")
	CMPQ(i, truncatedCount)
	JE(LabelRef("loop"))

	x := []VecPhysical{X8, X9, X10, X11, X12, X13, X14, X15}
	for n := 0; n < unroll; n++ {
		VPCMPEQB(Mem{Base: keyset, Disp: maxLength * n}, key, x[n])
		VPTEST(ones, x[n])
		var target string
		if n == 0 {
			target = "done"
		} else {
			target = fmt.Sprintf("found%d", n)
		}
		JCS(LabelRef(target))
	}

	// Advance and loop again.
	ADDQ(Imm(unroll), i)
	ADDQ(Imm(unroll*maxLength), keyset)
	JMP(LabelRef("bigloop"))

	// Loop over the remaining keys.
	Label("loop")
	CMPQ(i, count)
	JE(LabelRef("done"))

	// Try to match against the input key.
	match := XMM()
	VPCMPEQB(Mem{Base: keyset}, key, match)
	VPTEST(ones, match)
	JCS(LabelRef("done"))

	// Advance and loop again.
	Label("next")
	INCQ(i)
	ADDQ(Imm(maxLength), keyset)
	JMP(LabelRef("loop"))
	JMP(LabelRef("done"))

	// Return the loop increment, or the count if the key wasn't found. If we're
	// here from a jump within the big loop, the loop increment needs
	// correcting first.
	for j := unroll - 1; j > 0; j-- {
		Label(fmt.Sprintf("found%d", j))
		INCQ(i)
	}
	Label("done")
	Store(i, ReturnIndex(0))
	RET()
	Label("not_found")
	Store(count, ReturnIndex(0))
	RET()

	// If the input key is near a page boundary, we must change the way we load
	// it to avoid a fault. We instead want to load the 16 bytes up to and
	// including the key, then shuffle the key forward in the register. E.g. for
	// key "foo" we would load the 13 bytes prior to the key along with "foo"
	// and then move the last 3 bytes forward so the first 3 bytes are equal
	// to "foo".
	Label("safe_load")
	pageOffset := GP64()
	MOVQ(keyPtr, pageOffset)
	ANDQ(U32(pageSize-1), pageOffset)
	CMPQ(pageOffset, U32(pageSize-maxLength))
	JBE(LabelRef("load")) // Not near a page boundary.
	offset := GP64()
	MOVQ(^U64(0)-maxLength+1, offset)
	ADDQ(keyLen, offset)
	VMOVUPS(Mem{Base: keyPtr, Index: offset, Scale: 1}, key)
	var shuffleBytes [maxLength * 2]byte
	for j := 0; j < maxLength; j++ {
		shuffleBytes[j] = byte(j)
		shuffleBytes[j+maxLength] = byte(j)
	}
	shuffleMasks := ConstBytes("shuffle_masks", shuffleBytes[:])
	shuffleMasksPtr := GP64()
	LEAQ(shuffleMasks.Offset(maxLength), shuffleMasksPtr)
	SUBQ(keyLen, shuffleMasksPtr)
	shuffle := XMM()
	VMOVUPS(Mem{Base: shuffleMasksPtr}, shuffle)
	VPSHUFB(shuffle, key, key)
	JMP(LabelRef("prepare"))
}