1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
|
#!/usr/bin/env python
# Copyright (c) 2017 Facebook, Inc.
# Licensed under the Apache License, Version 2.0 (the "License")
import ctypes as ct
import distutils.version
import os
from unittest import main, skipUnless, TestCase
from bcc import BPF
from netaddr import IPAddress
def kernel_version_ge(major, minor):
    """Return True if the running kernel release is at least ``major.minor``.

    The release string is taken from ``os.uname()[2]`` (for example
    ``"4.15.0-213-generic"``); only its leading ``MAJOR[.MINOR]`` digits
    are considered.

    Args:
        major: Minimum required major version.
        minor: Minimum required minor version. A falsy value (``0`` or
            ``None``) means any minor version is acceptable.

    Returns:
        True when the running kernel is >= ``major.minor``. False when it
        is older, or when the release string does not start with a number.
    """
    # Parsed by hand instead of with distutils.version.LooseVersion:
    # distutils was removed from the standard library in Python 3.12, and
    # LooseVersion can yield non-integer components for odd release strings.
    import re

    found = re.match(r"(\d+)(?:\.(\d+))?", os.uname()[2])
    if found is None:
        # Cannot prove the kernel is new enough, so report "too old".
        return False
    running = (int(found.group(1)), int(found.group(2) or 0))
    return running >= (major, minor or 0)
class KeyV4(ct.Structure):
    """ctypes key for the IPv4 trie test: a prefix length plus four bytes."""

    _fields_ = [
        # Prefix length; the tests below use 24, 28 and 32.
        ("prefixlen", ct.c_uint),
        # Address bytes, one octet per element, e.g. (192, 168, 0, 0).
        ("data", ct.c_ubyte * 4),
    ]
class KeyV6(ct.Structure):
    """ctypes key for the IPv6 trie test: a prefix length plus eight 16-bit words."""

    _fields_ = [
        # Prefix length; the tests below use 64, 96 and 128.
        ("prefixlen", ct.c_uint),
        # Address as eight unsigned shorts, filled from IPAddress(...).words.
        ("data", ct.c_ushort * 8),
    ]
@skipUnless(kernel_version_ge(4, 11), "requires kernel >= 4.11")
class TestLpmTrie(TestCase):
    """Longest-prefix-match lookups on a BPF_LPM_TRIE map, IPv4 and IPv6."""

    def test_lpm_trie_v4(self):
        """The most specific stored IPv4 prefix wins; other addresses miss."""
        # `data` is u8[4] so that sizeof(struct key_v4) is 8 bytes, the same
        # size as the ctypes KeyV4 handed to the map below. Declared as
        # `u32 data[4]` the map key was 20 bytes, so the kernel copied 12
        # bytes from beyond the end of the 8-byte Python key on every
        # update and lookup.
        test_prog1 = """
        struct key_v4 {
            u32 prefixlen;
            u8 data[4];
        };
        BPF_LPM_TRIE(trie, struct key_v4, int, 16);
        """
        b = BPF(text=test_prog1)
        t = b["trie"]

        # Two overlapping prefixes; each stores its own length as the value.
        t[KeyV4(24, (192, 168, 0, 0))] = ct.c_int(24)
        t[KeyV4(28, (192, 168, 0, 0))] = ct.c_int(28)

        # .15 lies inside the /28, so the longer prefix must be returned.
        k = KeyV4(32, (192, 168, 0, 15))
        self.assertEqual(t[k].value, 28)

        # .127 is outside the /28 but still inside the /24.
        k = KeyV4(32, (192, 168, 0, 127))
        self.assertEqual(t[k].value, 24)

        # An address under neither prefix has no entry.
        k = KeyV4(32, (172, 16, 1, 127))
        with self.assertRaises(KeyError):
            t[k]

    def test_lpm_trie_v6(self):
        """The most specific stored IPv6 prefix wins; other addresses miss."""
        test_prog1 = """
        struct key_v6 {
            u32 prefixlen;
            u32 data[4];
        };
        BPF_LPM_TRIE(trie, struct key_v6, int, 16);
        """
        b = BPF(text=test_prog1)
        t = b["trie"]

        # Two overlapping prefixes; each stores its own length as the value.
        k1 = KeyV6(64, IPAddress('2a00:1450:4001:814:200e::').words)
        t[k1] = ct.c_int(64)
        k2 = KeyV6(96, IPAddress('2a00:1450:4001:814::200e').words)
        t[k2] = ct.c_int(96)

        # Shares the first 96 bits with k2, so the longer prefix is returned.
        k = KeyV6(128, IPAddress('2a00:1450:4001:814::1024').words)
        self.assertEqual(t[k].value, 96)

        # Differs from k2 after bit 64, so only the /64 matches.
        k = KeyV6(128, IPAddress('2a00:1450:4001:814:2046::').words)
        self.assertEqual(t[k].value, 64)

        # An address under neither prefix has no entry.
        k = KeyV6(128, IPAddress('2a00:ffff::').words)
        with self.assertRaises(KeyError):
            t[k]
# When executed as a script, hand control to unittest's command-line runner.
if __name__ == "__main__":
    main()
|