1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
|
// Package nftest contains utility functions for nftables testing.
package nftest
import (
"bytes"
"fmt"
"strings"
"testing"
"github.com/google/nftables"
"github.com/mdlayher/netlink"
)
// Recorder is a stand-in for the Linux kernel: connections obtained from Conn
// do not transmit netlink messages but append them to the Recorder instead.
// Retrieve them afterwards with Requests and compare them using Diff.
type Recorder struct {
	requests []netlink.Message
}

// Conn returns an nftables connection whose outgoing netlink messages are
// captured by the Recorder rather than sent to the kernel. Every captured
// message that carries the Acknowledge flag is answered with a synthetic
// netlink.Error reply that echoes the request's Sequence and PID and carries a
// zero error code (netlink's encoding of an ACK).
func (r *Recorder) Conn() (*nftables.Conn, error) {
	record := func(req []netlink.Message) ([]netlink.Message, error) {
		r.requests = append(r.requests, req...)

		replies := make([]netlink.Message, 0, len(req))
		for i := range req {
			hdr := req[i].Header
			if hdr.Flags&netlink.Acknowledge == 0 {
				// No ACK requested, so no reply is synthesized.
				continue
			}
			replies = append(replies, netlink.Message{
				Header: netlink.Header{
					Length:   4,
					Type:     netlink.Error,
					Sequence: hdr.Sequence,
					PID:      hdr.PID,
				},
				Data: []byte{0, 0, 0, 0},
			})
		}
		return replies, nil
	}
	return nftables.New(nftables.WithTestDial(record))
}

// Requests returns every netlink message recorded so far (typically nftables
// requests), in the order they were handed to the test dialer. The returned
// slice is the Recorder's own storage, not a copy.
func (r *Recorder) Requests() []netlink.Message {
	return r.requests
}

// NewRecorder returns an empty Recorder that is ready for use with Conn.
func NewRecorder() *Recorder {
	return new(Recorder)
}
// Diff returns the first difference between the specified netlink messages and
// the expected netlink message payloads, or "" if they match.
//
// Each message in got is marshaled and its 16-byte netlink header is dropped,
// so only the payload is compared against the corresponding entry in want.
// A difference is reported when:
//   - a message cannot be marshaled,
//   - got contains more messages than want has entries,
//   - a payload differs from its want entry (shown as a side-by-side hexdump),
//   - want has entries left over after every message in got has matched.
func Diff(got []netlink.Message, want [][]byte) string {
	for idx, msg := range got {
		b, err := msg.MarshalBinary()
		if err != nil {
			return fmt.Sprintf("msg.MarshalBinary: %v", err)
		}
		// Messages too short to carry a full header have no payload to compare.
		if len(b) < 16 {
			continue
		}
		// Strip the netlink header (length, type, flags, sequence, PID); only
		// the payload is compared.
		payload := b[16:]
		if len(want) == 0 {
			return fmt.Sprintf("no want entry for message %d: %x", idx, payload)
		}
		if !bytes.Equal(payload, want[0]) {
			return fmt.Sprintf("message %d: %s", idx, linediff(nfdump(payload), nfdump(want[0])))
		}
		want = want[1:]
	}
	// Every recorded message matched, but expected payloads remain: the code
	// under test sent fewer messages than expected. Without this check such a
	// truncated request stream would be reported as a match.
	if len(want) > 0 {
		return fmt.Sprintf("got fewer messages than wanted: %d want entries unmatched, first: %x", len(want), want[0])
	}
	return ""
}
// MatchRulesetBytes is a test helper that checks the netlink messages produced
// by flushing the ruleset and then running fillRuleset against the expected
// payloads in want. It stops the test via t.Fatal if the recording connection
// cannot be created or flushed, and reports a mismatch via t.Errorf.
func MatchRulesetBytes(t *testing.T, fillRuleset func(c *nftables.Conn), want [][]byte) {
	t.Helper()

	recorder := NewRecorder()
	conn, err := recorder.Conn()
	if err != nil {
		t.Fatal(err)
	}

	// The ruleset flush is queued first, so its messages are part of the
	// recording that is compared against want.
	conn.FlushRuleset()
	fillRuleset(conn)

	if err = conn.Flush(); err != nil {
		t.Fatal(err)
	}

	diff := Diff(recorder.Requests(), want)
	if diff != "" {
		t.Errorf("unexpected netlink messages: diff: %s", diff)
	}
}
// nfdump returns a hexdump of b with 4 bytes per line (like nft --debug=all),
// allowing users to make sense of large byte literals more easily. Each full
// 4-byte row ends with "\n". Any trailing bytes that do not fill a whole row
// are written on a final line, each followed by a space, with no terminating
// newline. An empty input yields "".
func nfdump(b []byte) string {
	var buf bytes.Buffer
	i := 0
	// Emit full rows only while at least 4 bytes remain; otherwise b[i+3]
	// would be out of range for inputs whose length is not a multiple of 4.
	for ; i+4 <= len(b); i += 4 {
		// TODO: show printable characters as ASCII
		fmt.Fprintf(&buf, "%02x %02x %02x %02x\n", b[i], b[i+1], b[i+2], b[i+3])
	}
	// Remaining 1-3 bytes, if any.
	for ; i < len(b); i++ {
		fmt.Fprintf(&buf, "%02x ", b[i])
	}
	return buf.String()
}
// linediff returns a side-by-side diff of two nfdump() return values, flagging
// lines which are not equal with an exclamation point prefix. Equal lines get
// a two-space prefix so both kinds of row stay aligned.
//
// Rows present in only one input are shown against an empty counterpart and
// flagged. This keeps a length mismatch (e.g. a truncated payload) visible
// instead of silently cutting the output at the shorter input.
func linediff(a, b string) string {
	var buf bytes.Buffer
	fmt.Fprintf(&buf, "got -- want\n")
	// nfdump terminates each full 4-byte row with "\n"; drop the final
	// terminator so it does not produce a spurious empty trailing row.
	linesA := strings.Split(strings.TrimSuffix(a, "\n"), "\n")
	linesB := strings.Split(strings.TrimSuffix(b, "\n"), "\n")
	rows := len(linesA)
	if len(linesB) > rows {
		rows = len(linesB)
	}
	for idx := 0; idx < rows; idx++ {
		var lineA, lineB string
		inA := idx < len(linesA)
		inB := idx < len(linesB)
		if inA {
			lineA = linesA[idx]
		}
		if inB {
			lineB = linesB[idx]
		}
		prefix := "! "
		if inA && inB && lineA == lineB {
			prefix = "  "
		}
		fmt.Fprintf(&buf, "%s%s -- %s\n", prefix, lineA, lineB)
	}
	return buf.String()
}
|