1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373
|
import errno
import functools
import getpass
import itertools
import logging
import os
import sys
import uuid
from collections import namedtuple
from socket import AF_INET, AF_INET6
import pytest
from utils import allocate_network, free_network
from pyroute2 import netns
from pyroute2.common import basestring, uifname
from pyroute2.iproute.linux import IPRoute
from pyroute2.ndb.main import NDB
from pyroute2.netlink.exceptions import NetlinkError
from pyroute2.netlink.generic.wireguard import WireGuard
from pyroute2.nslink.nslink import NetNS
def skip_if_not_implemented(func):
    '''
    Decorator: skip the wrapped test when the feature is not implemented.

    `AttributeError` and `NotImplementedError` raised by the wrapped
    function are turned into `pytest.skip('feature not implemented')`;
    any other exception propagates unchanged.

    All positional and keyword arguments are forwarded to `func`, so the
    decorated test may request any set of fixtures (pytest passes them
    as keyword arguments). `functools.wraps` keeps the original name and
    signature visible to pytest.
    '''

    @functools.wraps(func)
    def test_wrapper(*argv, **kwarg):
        try:
            return func(*argv, **kwarg)
        except (AttributeError, NotImplementedError):
            pytest.skip('feature not implemented')

    return test_wrapper
def skip_if_not_supported(func):
    '''
    Decorator: turn "not supported on this platform" failures into skips.

    * `NetlinkError` with code EOPNOTSUPP or ENOENT -> test is skipped
    * any other `NetlinkError` -> re-raised unchanged
    * any `RuntimeError` -> test is skipped, the exception args are
      passed to `pytest.skip()`
    '''
    unsupported_codes = {errno.EOPNOTSUPP, errno.ENOENT}

    @functools.wraps(func)
    def test_wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except NetlinkError as err:
            if err.code not in unsupported_codes:
                raise
            pytest.skip('feature not supported by platform')
        except RuntimeError as err:
            pytest.skip(*err.args)

    return test_wrapper
def make_test_matrix(targets=None, tables=None, dbs=None, types=None):
    '''
    Build a list of `pytest.param(ContextParams(...))` objects, one per
    combination of DB, target, table and kind.

    :param targets: NDB targets, default `['local']`
    :param tables: table names, default `[None]`
    :param dbs: DB specs as `provider/spec`, default
        `['sqlite3/:memory:']`; only the first `/` separates the
        provider from the spec, so the spec may itself contain `/`
    :param types: kinds, default `[None]`
    :returns: list of pytest params with ids like
        `db=<db> target=<target>[ table=<table>][ kind=<kind>]`

    DB providers starting with any of the colon-separated prefixes in
    the `SKIPDB` environment variable are left out. For providers other
    than sqlite3 the spec becomes a connection dict `{'dbname': spec}`,
    extended from `PGUSER`, `PGHOST` and `PGPORT`. When only `PGHOST`
    is set, the port defaults to 5432. When only `PGPORT` is set, the
    host defaults to 'localhost'.
    '''
    targets = targets or ['local']
    tables = tables or [None]
    types = types or [None]
    dbs = dbs or ['sqlite3/:memory:']
    skipdb = [x for x in os.environ.get('SKIPDB', '').split(':') if x]
    ret = []
    for db in dbs:
        # split only once: the spec part (e.g. a file path) may contain '/'
        db_provider, db_spec = db.split('/', 1)
        if any(map(db_provider.startswith, skipdb)):
            continue
        if db_provider != 'sqlite3':
            db_spec = {'dbname': db_spec}
            user = os.environ.get('PGUSER')
            port = os.environ.get('PGPORT')
            host = os.environ.get('PGHOST')
            if user:
                db_spec['user'] = user
            if host:
                if not port:
                    db_spec['port'] = 5432
                db_spec['host'] = host
            if port:
                if not host:
                    db_spec['host'] = 'localhost'
                db_spec['port'] = port
        # same iteration order as nested loops: targets > tables > types
        for target, table, kind in itertools.product(targets, tables, types):
            param_id = f'db={db} target={target}'
            if table is not None:
                param_id += f' table={table}'
            if kind is not None:
                param_id += f' kind={kind}'
            ret.append(
                pytest.param(
                    ContextParams(db_provider, db_spec, target, table, kind),
                    id=param_id,
                )
            )
    return ret
# Parameters of one NDB test context, as produced by `make_test_matrix()`
ContextParams = namedtuple(
    'ContextParams', 'db_provider db_spec target table kind'
)
# A test interface: interface index and name
Interface = namedtuple('Interface', 'index ifname')
# A registered network: address family, network address and prefix
# length (the prefix length is stored in the `netmask` field)
Network = namedtuple('Network', 'family network netmask')
class SpecContextManager(object):
    '''
    Hold simple per-test values: a unique id plus log and DB file paths
    derived from `tmpdir` and the current process id.
    '''

    def __init__(self, request, tmpdir):
        # `request` is accepted but not used by this class
        pid = os.getpid()
        self.uid = str(uuid.uuid4())
        self.log_base = f'{tmpdir}/ndb-{pid}'
        prefix = f'{self.log_base}-{self.uid}'
        self.log_spec = (prefix + '.log', logging.DEBUG)
        self.db_spec = prefix + '.sql'

    def teardown(self):
        '''Nothing to clean up.'''
class NDBContextManager(object):
    '''
    This class is used to manage fixture contexts.

    * create log spec
    * create NDB with specified parameters
    * provide methods to register interfaces, namespaces, IP rules
      and networks
    * automatically remove / release the registered resources on
      `teardown()`
    '''

    def __init__(self, request, tmpdir, **kwarg):
        '''
        :param request: pytest request object; if it has a `param`
            attribute, it may be a `ContextParams`, a `(target, table)`
            tuple/list, or a bare target
        :param tmpdir: directory used for the log files
        :param kwarg: extra keyword arguments for `NDB()`; `log` and
            `rtnl_debug` are only defaulted, while `db_provider` and
            `db_spec` are always set here, as is `sources` for the
            'local' and 'netns' targets
        '''
        self.spec = SpecContextManager(request, tmpdir)
        self.netns = None
        #
        # the cleanup registry
        self.interfaces = {}
        self.namespaces = {}
        kwarg.setdefault('log', self.spec.log_spec)
        kwarg.setdefault('rtnl_debug', True)
        target = 'local'
        self.table = None
        self.kind = None
        kwarg['db_provider'] = 'sqlite3'
        kwarg['db_spec'] = ':memory:'
        if hasattr(request, 'param'):
            param = request.param
            # ContextParams is a tuple subclass, so it must be checked
            # before the generic tuple/list case
            if isinstance(param, ContextParams):
                target = param.target
                self.table = param.table
                self.kind = param.kind
                kwarg['db_provider'] = param.db_provider
                kwarg['db_spec'] = param.db_spec
            elif isinstance(param, (tuple, list)):
                target, self.table = param
            else:
                target = param
        if target == 'local':
            kwarg['sources'] = [{'target': 'localhost', 'kind': 'local'}]
        elif target == 'netns':
            # the new netns is registered for removal on teardown()
            self.netns = self.new_nsname
            kwarg['sources'] = [
                {'target': 'localhost', 'kind': 'netns', 'netns': self.netns}
            ]
        #
        # select the DB to work on: a non-empty PYROUTE2_TEST_DBNAME
        # forces the psycopg2 provider
        db_name = os.environ.get('PYROUTE2_TEST_DBNAME')
        if db_name:
            kwarg['db_provider'] = 'psycopg2'
            kwarg['db_spec'] = {'dbname': db_name}
        #
        # this instance is to be tested, so do NOT use it
        # in utility methods
        self.db_provider = kwarg['db_provider']
        self.ndb = NDB(**kwarg)
        self.ipr = self.ndb.sources['localhost'].nl.clone()
        self.wg = WireGuard()
        #
        # IPAM
        self.ipnets = [allocate_network() for _ in range(3)]
        self.ipranges = [[str(x) for x in net] for net in self.ipnets]
        self.ip6net = allocate_network(AF_INET6)
        self.ip6counter = itertools.count(1024)
        self.allocated_networks = {AF_INET: [], AF_INET6: []}
        #
        # RPDB objects for cleanup
        self.rules = []
        #
        # default interface: a fresh dummy interface when running as
        # root (registered for cleanup via `new_ifname`), else `lo`
        if getpass.getuser() == 'root':
            ifname = self.new_ifname
            index = self.ndb.interfaces.create(
                ifname=ifname, kind='dummy', state='up'
            ).commit()['index']
            self.default_interface = Interface(index, ifname)
        else:
            self.default_interface = Interface(1, 'lo')

    def register(self, ifname=None, netns=None):
        '''
        Register an interface in `self.interfaces`. If no interface
        name is specified, create a random one.

        All the saved interfaces will be removed on `teardown()`.

        :returns: the interface name
        '''
        if ifname is None:
            ifname = uifname()
        self.interfaces[ifname] = netns
        return ifname

    def register_netns(self, netns=None):
        '''
        Register netns in `self.namespaces`. If no netns name is
        specified, create a random one.

        All the saved namespaces will be removed on `teardown()`.

        :returns: the netns name
        '''
        if netns is None:
            netns = str(uuid.uuid4())
        self.namespaces[netns] = None
        return netns

    def register_rule(self, spec, netns=None):
        '''
        Register IP rule `spec` (kwargs for `rule('del', ...)`) for
        cleanup on `teardown()`.

        :returns: `spec` unchanged
        '''
        self.rules.append((netns, spec))
        return spec

    def register_network(self, family=AF_INET, network=None):
        '''
        Register or allocate a network.

        All the allocated networks are deallocated on `teardown()`.

        :returns: `Network(family, network address, prefix length)`
        '''
        if network is None:
            network = allocate_network(family)
        # register for cleanup
        self.allocated_networks[family].append(network)
        # return a simple convenient named tuple
        return Network(family, network.network.format(), network.prefixlen)

    def get_ipaddr(self, r=0):
        '''
        Returns an ip address from the specified range; every call
        takes (removes) the last remaining address of `self.ipranges[r]`.
        '''
        return str(self.ipranges[r].pop())

    def get_ip6addr(self, r=0):
        '''
        Returns the next ip6 address from `self.ip6net`, starting at
        offset 1024. `r` is ignored; it is kept for symmetry with
        `get_ipaddr()`.
        '''
        return str(self.ip6net[next(self.ip6counter)])

    @property
    def new_log(self):
        '''
        Returns a new unique log file path based on the spec log base.
        '''
        uid = str(uuid.uuid4())
        return f'{self.spec.log_base}-{uid}.log'

    @property
    def new_ifname(self):
        '''
        Returns a new unique ifname and registers it to be
        cleaned up on `self.teardown()`
        '''
        return self.register()

    @property
    def new_ipaddr(self):
        '''
        Returns a new ipaddr from the configured range
        '''
        return self.get_ipaddr()

    @property
    def new_ip6addr(self):
        '''
        Returns a new ip6addr from the configured range
        '''
        return self.get_ip6addr()

    @property
    def new_ip4net(self):
        '''
        Returns a new IPv4 network
        '''
        return self.register_network(family=AF_INET)

    @property
    def new_ip6net(self):
        '''
        Returns a new IPv6 network
        '''
        return self.register_network(family=AF_INET6)

    @property
    def new_nsname(self):
        '''
        Returns a new unique nsname and registers it to be
        removed on `self.teardown()`
        '''
        return self.register_netns()

    def teardown(self):
        '''
        1. close the test NDB and the helper sockets
        2. remove the registered interfaces, ignore not existing
        3. remove the registered IP rules, ignore not existing
        4. remove the registered namespaces
        5. release all the allocated networks

        Rules are removed before the namespaces, so that a rule
        registered in a namespace is deleted through `NetNS(nsname)`
        while that namespace has not yet been removed by this method.
        Networks are released in `finally`, so a failing cleanup step
        does not leak IPAM allocations.
        '''
        try:
            # save postmortem DB for SQLite3
            if self.db_provider == 'sqlite3' and sys.version_info >= (3, 7):
                self.ndb.backup(f'{self.spec.uid}-post.db')
            self.ndb.close()
            self.ipr.close()
            self.wg.close()
            self._remove_interfaces()
            self._remove_rules()
            self._remove_namespaces()
        finally:
            self._free_networks()

    @staticmethod
    def _open_ipr(nsname):
        # netlink socket in netns `nsname`, or in the current one if None
        if nsname is not None:
            return NetNS(nsname)
        return IPRoute()

    def _remove_interfaces(self):
        # delete every registered interface that still exists
        for ifname, nsname in self.interfaces.items():
            ipr = None
            try:
                ipr = self._open_ipr(nsname)
                indices = list(ipr.link_lookup(ifname=ifname))
                if not indices:
                    # ignore not existing interfaces
                    continue
                ipr.link('del', index=indices[0])
            except NetlinkError as e:
                # ignore if removed (e.g. by another process)
                if e.code != errno.ENODEV:
                    raise
            finally:
                if ipr is not None:
                    ipr.close()

    def _remove_rules(self):
        # delete every registered IP rule, ignoring already removed ones
        for nsname, rule in self.rules:
            ipr = None
            try:
                ipr = self._open_ipr(nsname)
                ipr.rule('del', **rule)
            except NetlinkError as e:
                if e.code != errno.ENOENT:
                    raise
            finally:
                if ipr is not None:
                    ipr.close()

    def _remove_namespaces(self):
        # remove every registered netns, ignoring already removed ones
        for nsname in self.namespaces:
            try:
                netns.remove(nsname)
            except FileNotFoundError:
                pass

    def _free_networks(self):
        # return all IPAM allocations made by this context
        for net in self.ipnets:
            free_network(net)
        # allocated in __init__ with allocate_network(AF_INET6)
        free_network(self.ip6net, AF_INET6)
        for family, networks in self.allocated_networks.items():
            for net in networks:
                free_network(net, family)
|