1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318
|
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
"""DNS Versioned Zones."""
import collections
import threading
from typing import Callable, Deque, Optional, Set, Union
import dns.exception
import dns.immutable
import dns.name
import dns.node
import dns.rdataclass
import dns.rdataset
import dns.rdatatype
import dns.rdtypes.ANY.SOA
import dns.zone
# Raised by the Zone methods below that would mutate the zone directly
# (find_node/find_rdataset/get_rdataset with create=True, delete_node,
# delete_rdataset, replace_rdataset); changes must go through a transaction.
# NOTE(review): dns.exception.DNSException appears to use the class docstring
# as the default exception message, so the docstring text is user-visible —
# confirm before editing it.
class UseTransaction(dns.exception.DNSException):
    """To alter a versioned zone, use a transaction."""
# Aliases kept for backwards compatibility; the implementations are in
# dns.zone and these names simply re-export them from this module.
Node, ImmutableNode = dns.zone.VersionedNode, dns.zone.ImmutableVersionedNode
Version, WritableVersion, ImmutableVersion = (
    dns.zone.Version,
    dns.zone.WritableVersion,
    dns.zone.ImmutableVersion,
)
Transaction = dns.zone.Transaction
class Zone(dns.zone.Zone):  # lgtm[py/missing-equals]
    """A DNS zone that retains a history of versions.

    Reads go through ``reader()`` and writes through ``writer()``; both
    return a ``Transaction``.  The inherited ``dns.zone.Zone`` methods that
    would modify the zone directly raise ``UseTransaction`` instead.  Old
    versions are discarded according to the zone's pruning policy (see
    ``set_pruning_policy()``), but never while an open reader still needs
    them.
    """

    __slots__ = [
        "_versions",
        # Must match the attribute assigned in __init__ (self._version_lock).
        "_version_lock",
        "_write_txn",
        "_write_waiters",
        "_write_event",
        "_pruning_policy",
        "_readers",
    ]

    node_factory = Node

    def __init__(
        self,
        origin: Optional[Union[dns.name.Name, str]],
        rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
        relativize: bool = True,
        pruning_policy: Optional[Callable[["Zone", Version], Optional[bool]]] = None,
    ):
        """Initialize a versioned zone object.

        *origin* is the origin of the zone.  It may be a ``dns.name.Name``,
        a ``str``, or ``None``.  If ``None``, then the zone's origin will
        be set by the first ``$ORIGIN`` line in a zone file.

        *rdclass*, an ``int``, the zone's rdata class; the default is class IN.

        *relativize*, a ``bool``, determines whether domain names are
        relativized to the zone's origin.  The default is ``True``.

        *pruning_policy*, a function taking a ``Zone`` and a ``Version`` and
        returning a ``bool`` (or ``None``), or ``None``.  The function is
        asked whether the given version should be pruned.  If ``None``, the
        default policy is used, which retains only the latest version (plus
        any versions that open readers still require).
        """
        super().__init__(origin, rdclass, relativize)
        # Retained versions, oldest first; the last one is the current zone.
        self._versions: Deque[Version] = collections.deque()
        # Protects the version list, the reader set, and the writer
        # bookkeeping below.
        self._version_lock = threading.Lock()
        if pruning_policy is None:
            self._pruning_policy = self._default_pruning_policy
        else:
            self._pruning_policy = pruning_policy
        # The single active write transaction, if any.
        self._write_txn: Optional[Transaction] = None
        # Event of the waiter that has been granted the right to create the
        # next write transaction; None when no such grant is outstanding.
        self._write_event: Optional[threading.Event] = None
        # Events of writers blocked behind the current writer, in FIFO order.
        self._write_waiters: Deque[threading.Event] = collections.deque()
        # Open read transactions; their versions are exempt from pruning.
        self._readers: Set[Transaction] = set()
        # Commit an initial empty version so there is always one to read.
        self._commit_version_unlocked(
            None, WritableVersion(self, replacement=True), origin
        )

    def reader(
        self, id: Optional[int] = None, serial: Optional[int] = None
    ) -> Transaction:  # pylint: disable=arguments-differ
        """Begin a read-only transaction on a version of the zone.

        *id*, an ``int`` or ``None``; if given, read the retained version
        with this id.

        *serial*, an ``int`` or ``None``; if given, read the newest retained
        version whose origin SOA has this serial.

        If neither is given, the latest version is read.

        While the returned transaction is registered as a reader, its
        version (and every later version) is exempt from pruning.

        Raises ``ValueError`` if both *id* and *serial* are specified, and
        ``KeyError`` if no retained version matches.
        """
        if id is not None and serial is not None:
            raise ValueError("cannot specify both id and serial")
        with self._version_lock:
            if id is not None:
                version = next(
                    (v for v in reversed(self._versions) if v.id == id), None
                )
                if version is None:
                    raise KeyError("version not found")
            elif serial is not None:
                version = self._find_version_by_serial_unlocked(serial)
                if version is None:
                    raise KeyError("serial not found")
            else:
                version = self._versions[-1]
            txn = Transaction(self, False, version)
            self._readers.add(txn)
            return txn

    def _find_version_by_serial_unlocked(self, serial: int) -> Optional[Version]:
        """Return the newest retained version whose origin SOA has *serial*.

        Returns ``None`` if there is no such version.  The caller must hold
        ``_version_lock``.
        """
        # The origin node is keyed by the empty name when names are relative.
        if self.relativize:
            oname = dns.name.empty
        else:
            assert self.origin is not None
            oname = self.origin
        for v in reversed(self._versions):
            n = v.nodes.get(oname)
            if n:
                rds = n.get_rdataset(self.rdclass, dns.rdatatype.SOA)
                if rds and rds[0].serial == serial:
                    return v
        return None

    def writer(self, replacement: bool = False) -> Transaction:
        """Begin a read-write transaction.

        Only one write transaction exists at a time.  If another writer is
        active, this call blocks until it is this caller's turn; waiting
        writers are served in the order they started waiting.

        *replacement*, a ``bool``, is passed through to ``Transaction``.
        NOTE(review): presumably it makes the transaction start from an
        empty zone instead of the current contents — confirm against
        ``dns.zone.Transaction``.
        """
        event = None
        while True:
            with self._version_lock:
                # Checking event == self._write_event ensures that either
                # no one was waiting before we got lucky and found no write
                # txn, or we were the one who was waiting and got woken up.
                # This prevents "taking cuts" when creating a write txn.
                if self._write_txn is None and event == self._write_event:
                    # Creating the transaction defers version setup
                    # (i.e. copying the nodes dictionary) until we
                    # give up the lock, so that we hold the lock as
                    # short a time as possible. This is why we call
                    # _setup_version() below.
                    self._write_txn = Transaction(
                        self, replacement, make_immutable=True
                    )
                    # give up our exclusive right to make a Transaction
                    self._write_event = None
                    break
                # Someone else is writing already, so we will have to
                # wait, but we want to do the actual wait outside the
                # lock.
                event = threading.Event()
                self._write_waiters.append(event)
            # wait (note we gave up the lock!)
            #
            # We only wake one sleeper at a time, so it's important
            # that no event waiter can exit this method (e.g. via
            # cancellation) without returning a transaction or waking
            # someone else up.
            #
            # This is not a problem with Threading module threads as
            # they cannot be canceled, but could be an issue with trio
            # tasks when we do the async version of writer().
            # I.e. we'd need to do something like:
            #
            # try:
            #     event.wait()
            # except trio.Cancelled:
            #     with self._version_lock:
            #         self._maybe_wakeup_one_waiter_unlocked()
            #     raise
            #
            event.wait()
        # Do the deferred version setup.
        self._write_txn._setup_version()
        return self._write_txn

    def _maybe_wakeup_one_waiter_unlocked(self):
        """Grant the next waiting writer (if any) the right to proceed.

        The caller must hold ``_version_lock``.
        """
        if self._write_waiters:
            self._write_event = self._write_waiters.popleft()
            self._write_event.set()

    # pylint: disable=unused-argument
    def _default_pruning_policy(self, zone, version):
        """Prune every version that pruning is allowed to remove."""
        return True

    # pylint: enable=unused-argument

    def _prune_versions_unlocked(self):
        """Drop old versions, oldest first, as the pruning policy permits.

        The caller must hold ``_version_lock``.
        """
        assert len(self._versions) > 0
        # Don't ever prune a version greater than or equal to one that
        # a reader has open. This pins versions in memory while the
        # reader is open, and importantly lets the reader open a txn on
        # a successor version (e.g. if generating an IXFR).
        #
        # Note our definition of least_kept also ensures we do not try to
        # delete the greatest version.
        if len(self._readers) > 0:
            least_kept = min(txn.version.id for txn in self._readers)
        else:
            least_kept = self._versions[-1].id
        while self._versions[0].id < least_kept and self._pruning_policy(
            self, self._versions[0]
        ):
            self._versions.popleft()

    def set_max_versions(self, max_versions: Optional[int]) -> None:
        """Set a pruning policy that retains up to *max_versions* versions.

        If *max_versions* is ``None``, no version is ever pruned.  Versions
        required by open readers are retained even if that exceeds the limit.

        Raises ``ValueError`` if *max_versions* is less than 1.
        """
        if max_versions is not None and max_versions < 1:
            raise ValueError("max versions must be at least 1")
        if max_versions is None:

            def policy(zone, _):  # pylint: disable=unused-argument
                return False

        else:

            def policy(zone, _):
                return len(zone._versions) > max_versions

        self.set_pruning_policy(policy)

    def set_pruning_policy(
        self, policy: Optional[Callable[["Zone", Version], Optional[bool]]]
    ) -> None:
        """Set the pruning policy for the zone.

        The *policy* function takes a ``Zone`` and a ``Version`` and returns
        ``True`` if the version should be pruned, and ``False`` (or ``None``)
        otherwise.  ``None`` may also be specified for *policy*, in which
        case the default policy is used.

        Pruning checking proceeds from the least version, and the first
        time the function returns a false value, the checking stops.  I.e.
        the retained versions are always a consecutive sequence.

        The new policy is applied immediately.
        """
        if policy is None:
            policy = self._default_pruning_policy
        with self._version_lock:
            self._pruning_policy = policy
            self._prune_versions_unlocked()

    def _end_read(self, txn):
        """Unregister the read transaction *txn* and prune what it pinned."""
        with self._version_lock:
            self._readers.remove(txn)
            self._prune_versions_unlocked()

    def _end_write_unlocked(self, txn):
        """Release the writer slot held by *txn* and wake the next writer.

        The caller must hold ``_version_lock``.
        """
        assert self._write_txn == txn
        self._write_txn = None
        self._maybe_wakeup_one_waiter_unlocked()

    def _end_write(self, txn):
        """Release the writer slot held by *txn*, taking the lock."""
        with self._version_lock:
            self._end_write_unlocked(txn)

    def _commit_version_unlocked(self, txn, version, origin):
        """Make *version* the current version of the zone.

        Appends *version*, prunes, publishes its nodes as ``self.nodes``,
        adopts *origin* if the zone has none yet, and, when *txn* is given,
        releases its writer slot.  The caller must hold ``_version_lock``
        (or be ``__init__``).
        """
        self._versions.append(version)
        self._prune_versions_unlocked()
        self.nodes = version.nodes
        if self.origin is None:
            self.origin = origin
        # txn can be None in __init__ when we make the empty version.
        if txn is not None:
            self._end_write_unlocked(txn)

    def _commit_version(self, txn, version, origin):
        """Locked wrapper around ``_commit_version_unlocked()``."""
        with self._version_lock:
            self._commit_version_unlocked(txn, version, origin)

    def _get_next_version_id(self):
        """Return one more than the latest version's id, or 1 if none exist."""
        if self._versions:
            return self._versions[-1].id + 1
        return 1

    def find_node(
        self, name: Union[dns.name.Name, str], create: bool = False
    ) -> dns.node.Node:
        """Find the node for *name* in the current version.

        Raises ``UseTransaction`` if *create* is true.
        """
        if create:
            raise UseTransaction
        return super().find_node(name)

    def delete_node(self, name: Union[dns.name.Name, str]) -> None:
        """Always raises ``UseTransaction``; delete via a transaction."""
        raise UseTransaction

    def find_rdataset(
        self,
        name: Union[dns.name.Name, str],
        rdtype: Union[dns.rdatatype.RdataType, str],
        covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
        create: bool = False,
    ) -> dns.rdataset.Rdataset:
        """Find an rdataset in the current version, returned as immutable.

        Raises ``UseTransaction`` if *create* is true.
        """
        if create:
            raise UseTransaction
        rdataset = super().find_rdataset(name, rdtype, covers)
        return dns.rdataset.ImmutableRdataset(rdataset)

    def get_rdataset(
        self,
        name: Union[dns.name.Name, str],
        rdtype: Union[dns.rdatatype.RdataType, str],
        covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
        create: bool = False,
    ) -> Optional[dns.rdataset.Rdataset]:
        """Get an rdataset from the current version, returned as immutable.

        Returns ``None`` if the base class finds no such rdataset.  Raises
        ``UseTransaction`` if *create* is true.
        """
        if create:
            raise UseTransaction
        rdataset = super().get_rdataset(name, rdtype, covers)
        if rdataset is not None:
            return dns.rdataset.ImmutableRdataset(rdataset)
        else:
            return None

    def delete_rdataset(
        self,
        name: Union[dns.name.Name, str],
        rdtype: Union[dns.rdatatype.RdataType, str],
        covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
    ) -> None:
        """Always raises ``UseTransaction``; delete via a transaction."""
        raise UseTransaction

    def replace_rdataset(
        self, name: Union[dns.name.Name, str], replacement: dns.rdataset.Rdataset
    ) -> None:
        """Always raises ``UseTransaction``; replace via a transaction."""
        raise UseTransaction
|