1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
|
//go:build ignore
// +build ignore
package main
import (
. "github.com/mmcloughlin/avo/build"
. "github.com/mmcloughlin/avo/operand"
. "github.com/mmcloughlin/avo/reg"
. "github.com/segmentio/asm/build/internal/x86"
"github.com/segmentio/asm/cpu"
)
func init() {
ConstraintExpr("!purego")
}
func main() {
TEXT("ValidString", NOSPLIT, "func(s string) bool")
Doc("ValidString returns true if s contains only ASCII characters.")
p := Mem{Base: Load(Param("s").Base(), GP64())}
n := Load(Param("s").Len(), GP64())
ret, _ := ReturnIndex(0).Resolve()
v := GP32()
vl := GP32()
maskG := GP64()
MOVQ(U64(0x8080808080808080), maskG) // maskG = 0x8080808080808080
CMPQ(n, U8(16)) // if n < 16:
JB(LabelRef("cmp8")) // goto cmp8
JumpIfFeature("init_avx", cpu.AVX2)
Label("cmp8")
CMPQ(n, U8(8)) // if n < 8:
JB(LabelRef("cmp4")) // goto cmp4
TESTQ(maskG, p) // if (p[0:8] & 0x8080808080808080) != 0:
JNZ(LabelRef("invalid")) // return false
ADDQ(U8(8), p.Base) // p += 8
SUBQ(U8(8), n) // n -= 8
JMP(LabelRef("cmp8")) // loop cmp8
Label("cmp4")
CMPQ(n, U8(4)) // if n < 4:
JB(LabelRef("cmp3")) // goto cmp3
TESTL(U32(0x80808080), p) // if (p[0:4] & 0x80808080) != 0:
JNZ(LabelRef("invalid")) // return false
ADDQ(U8(4), p.Base) // p += 4
SUBQ(U8(4), n) // n -= 4
Label("cmp3")
CMPQ(n, U8(3)) // if n < 3:
JB(LabelRef("cmp2")) // goto cmp2
MOVWLZX(p, vl) // vl = p[i:i+2]
MOVBLZX(p.Offset(2), v) // v = p[i+2:i+3]
SHLL(U8(16), v) // v <<= 16
ORL(vl, v) // v = vl | v
TESTL(U32(0x80808080), v) // ZF = (v & 0x80808080) == 0
JMP(LabelRef("done")) // return ZF
Label("cmp2")
CMPQ(n, U8(2)) // if n < 2:
JB(LabelRef("cmp1")) // goto cmp1
TESTW(U16(0x8080), p) // ZF = (p[0:2] & 0x8080) == 0
JMP(LabelRef("done")) // return ZF
Label("cmp1")
CMPQ(n, U8(0)) // if n == 0:
JE(LabelRef("done")) // return true
TESTB(U8(0x80), p) // ZF = (p[0:1] & 0x80) == 0
Label("done")
SETEQ(ret.Addr) // return ZF
RET() // ...
Label("invalid")
MOVB(U8(0), ret.Addr)
RET()
Label("init_avx")
maskY := VecBroadcast(maskG, YMM())
maskX := maskY.(Vec).AsX()
vec := NewVectorizer(15, func(l VectorLane) Register {
r := l.Alloc()
VMOVDQU(l.Offset(p), r)
VPOR(l.Offset(p), r, r)
return r
}).Reduce(ReduceOr)
Label("cmp256")
CMPQ(n, U32(256)) // if n < 256:
JB(LabelRef("cmp128")) // goto cmp128
VPTEST(vec.Compile(S256, 4)[0], maskY) // if (OR & maskY) != 0:
JNZ(LabelRef("invalid")) // return false
ADDQ(U32(256), p.Base) // p += 256
SUBQ(U32(256), n) // n -= 256
JMP(LabelRef("cmp256")) // loop cmp256
Label("cmp128")
CMPQ(n, U8(128)) // if n < 128:
JB(LabelRef("cmp64")) // goto cmp64
VPTEST(vec.Compile(S256, 2)[0], maskY) // if (OR & maskY) != 0:
JNZ(LabelRef("invalid")) // return false
ADDQ(U8(128), p.Base) // p += 128
SUBQ(U8(128), n) // n -= 128
JMP(LabelRef("cmp64")) // goto cmp64
Label("cmp64")
CMPQ(n, U8(64)) // if n < 64:
JB(LabelRef("cmp32")) // goto cmp32
VPTEST(vec.Compile(S256, 1)[0], maskY) // if (OR & maskY) != 0:
JNZ(LabelRef("invalid")) // return false
ADDQ(U8(64), p.Base) // p += 64
SUBQ(U8(64), n) // n -= 64
Label("cmp32")
CMPQ(n, U8(32)) // if n < 32:
JB(LabelRef("cmp16")) // goto cmp16
VPTEST(p, maskY) // if (p[0:32] & maskY) != 0:
JNZ(LabelRef("invalid")) // return false
ADDQ(U8(32), p.Base) // p += 32
SUBQ(U8(32), n) // n -= 32
Label("cmp16")
CMPQ(n, U8(16)) // if n <= 16:
JLE(LabelRef("cmp_tail")) // goto cmp_tail
VPTEST(p, maskX) // if (p[0:16] & maskX) != 0:
JNZ(LabelRef("invalid")) // return false
ADDQ(U8(16), p.Base) // p += 16
SUBQ(U8(16), n) // n -= 16
Label("cmp_tail")
// At this point, we have <= 16 bytes to compare, but we know the total input
// is >= 16 bytes. Move the pointer to the *last* 16 bytes of the input so we
// can skip the fallback.
SUBQ(Imm(16), n) // n -= 16
ADDQ(n, p.Base) // p += n
VPTEST(p, maskX) // ZF = (p[0:16] & maskX) == 0
JMP(LabelRef("done")) // return ZF
Generate()
}
|