File: nftest.go

package info (click to toggle)
golang-github-google-nftables 0.2.0-3
  • links: PTS, VCS
  • area: main
  • in suites: experimental, forky, sid, trixie
  • size: 756 kB
  • sloc: makefile: 8
file content (142 lines) | stat: -rw-r--r-- 3,593 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
// Package nftest contains utility functions for nftables testing.
package nftest

import (
	"bytes"
	"fmt"
	"strings"
	"testing"

	"github.com/google/nftables"
	"github.com/mdlayher/netlink"
)

// Recorder provides an nftables connection that does not send to the Linux
// kernel but instead records netlink messages into the recorder. The recorded
// requests can later be obtained using Requests and compared using Diff.
type Recorder struct {
	// requests accumulates every netlink message handed to the test dial
	// function installed by Conn, in the order the messages were received.
	requests []netlink.Message
}

// Conn returns an nftables connection whose dial function is replaced by a
// test hook: every outgoing batch of netlink messages is appended to the
// Recorder instead of being delivered anywhere else.
//
// For each request that carries the netlink.Acknowledge flag, the hook replies
// with a netlink.Error message echoing the request's sequence number and PID
// and carrying four zero bytes of data (presumably interpreted by the caller
// as a successful acknowledgement). Requests without the flag get no reply.
func (r *Recorder) Conn() (*nftables.Conn, error) {
	record := func(req []netlink.Message) ([]netlink.Message, error) {
		r.requests = append(r.requests, req...)

		// Allocate up front so the reply is a non-nil slice even when no
		// request asked for an acknowledgement.
		replies := make([]netlink.Message, 0, len(req))
		for _, m := range req {
			if m.Header.Flags&netlink.Acknowledge == 0 {
				continue
			}
			ack := netlink.Message{
				Header: netlink.Header{
					Length:   4,
					Type:     netlink.Error,
					Sequence: m.Header.Sequence,
					PID:      m.Header.PID,
				},
				Data: []byte{0, 0, 0, 0},
			}
			replies = append(replies, ack)
		}
		return replies, nil
	}
	return nftables.New(nftables.WithTestDial(record))
}

// Requests returns the recorded netlink messages (typically nftables requests).
// The returned slice is the Recorder's own slice, not a copy.
func (r *Recorder) Requests() []netlink.Message {
	return r.requests
}

// NewRecorder allocates an empty Recorder with no recorded requests.
func NewRecorder() *Recorder {
	return new(Recorder)
}

// Diff compares the recorded netlink messages in got against the expected
// payloads in want and describes the first mismatch it finds. It returns the
// empty string when no mismatch is found.
//
// Each message is marshalled and its leading 16 bytes (the netlink message
// header) are stripped before comparison; messages that marshal to fewer than
// 16 bytes are skipped and do not consume a want entry.
//
// NOTE: entries of want that remain after all of got has been processed are
// not reported.
func Diff(got []netlink.Message, want [][]byte) string {
	const headerLen = 16

	pending := want
	for i, m := range got {
		raw, err := m.MarshalBinary()
		if err != nil {
			return fmt.Sprintf("msg.MarshalBinary: %v", err)
		}
		if len(raw) < headerLen {
			continue
		}
		payload := raw[headerLen:]

		if len(pending) == 0 {
			return fmt.Sprintf("no want entry for message %d: %x", i, payload)
		}
		expected := pending[0]
		pending = pending[1:]

		if !bytes.Equal(payload, expected) {
			return fmt.Sprintf("message %d: %s", i, linediff(nfdump(payload), nfdump(expected)))
		}
	}
	return ""
}

// MatchRulesetBytes is a test helper that ensures the modifications made by
// fillRuleset correspond to the provided want netlink message payloads.
//
// It opens a recording connection, calls FlushRuleset on it, runs fillRuleset,
// flushes the connection, and then compares the recorded requests against
// want using Diff. A connection or flush error aborts the test via t.Fatal; a
// payload mismatch is reported via t.Errorf.
func MatchRulesetBytes(t *testing.T, fillRuleset func(c *nftables.Conn), want [][]byte) {
	t.Helper()

	recorder := NewRecorder()
	conn, err := recorder.Conn()
	if err != nil {
		t.Fatal(err)
	}

	// The ruleset flush is issued before the caller's modifications.
	conn.FlushRuleset()
	fillRuleset(conn)

	err = conn.Flush()
	if err != nil {
		t.Fatal(err)
	}

	diff := Diff(recorder.Requests(), want)
	if diff == "" {
		return
	}
	t.Errorf("unexpected netlink messages: diff: %s", diff)
}

// nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing
// users to make sense of large byte literals more easily.
//
// Every complete group of 4 bytes is written as one newline-terminated line of
// space-separated two-digit hex values. If len(b) is not a multiple of 4, the
// remaining 1-3 bytes are written after the last full line, each followed by a
// space and without a trailing newline. An empty or nil b yields "".
func nfdump(b []byte) string {
	var buf bytes.Buffer
	i := 0
	// Only consume a group when all 4 bytes are present; bounding the loop by
	// i < len(b) would index past the end of inputs whose length is not a
	// multiple of 4.
	for ; i+4 <= len(b); i += 4 {
		// TODO: show printable characters as ASCII
		fmt.Fprintf(&buf, "%02x %02x %02x %02x\n",
			b[i],
			b[i+1],
			b[i+2],
			b[i+3])
	}
	// Emit whatever is left over (fewer than 4 bytes).
	for ; i < len(b); i++ {
		fmt.Fprintf(&buf, "%02x ", b[i])
	}
	return buf.String()
}

// linediff returns a side-by-side diff of two nfdump() return values, flagging
// lines which are not equal with an exclamation point prefix.
//
// a is shown in the "got" column and b in the "want" column. Both inputs are
// split on "\n" and compared line by line. When one input has more lines than
// the other, the missing lines of the shorter input are treated as empty so
// that the surplus lines of the longer input still appear in the output
// instead of being dropped.
func linediff(a, b string) string {
	var buf bytes.Buffer
	fmt.Fprintf(&buf, "got -- want\n")
	linesA := strings.Split(a, "\n")
	linesB := strings.Split(b, "\n")

	// Walk up to the longer of the two inputs so no line is left out.
	total := len(linesA)
	if len(linesB) > total {
		total = len(linesB)
	}
	for idx := 0; idx < total; idx++ {
		var lineA, lineB string
		if idx < len(linesA) {
			lineA = linesA[idx]
		}
		if idx < len(linesB) {
			lineB = linesB[idx]
		}
		prefix := "! "
		if lineA == lineB {
			prefix = "  "
		}
		fmt.Fprintf(&buf, "%s%s -- %s\n", prefix, lineA, lineB)
	}
	return buf.String()
}