1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
|
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Optional, List
from pkb_client.client.dns import DNSRecordType, DNS_RECORDS_WITH_PRIORITY
class RecordClass(str, Enum):
IN = "IN"
def __str__(self):
return self.value
@dataclass
class BindRecord:
name: str
ttl: int
record_class: RecordClass
record_type: DNSRecordType
data: str
prio: Optional[int] = None
comment: Optional[str] = None
def __str__(self):
record_string = f"{self.name} {self.ttl} {self.record_class} {self.record_type}"
if self.prio is not None:
record_string += f" {self.prio}"
record_string += f' "{self.data}"'
if self.comment:
record_string += f" ; {self.comment}"
return record_string
class BindFile:
origin: str
ttl: Optional[int] = None
records: List[BindRecord]
def __init__(
self,
origin: str,
ttl: Optional[int] = None,
records: Optional[List[BindRecord]] = None,
) -> None:
self.origin = origin
self.ttl = ttl
self.records = records or []
@staticmethod
def from_file(file_path: str) -> "BindFile":
with open(file_path, "r") as f:
file_data = f.readlines()
# parse the file line by line
origin = None
ttl = None
records = []
for line in file_data:
if line.startswith("$ORIGIN"):
origin = line.split()[1]
elif line.startswith("$TTL"):
ttl = int(line.split()[1])
else:
# parse the records with the two possible formats:
# 1: name ttl record-class record-type record-data
# 2: name record-class ttl record-type record-data
# whereby the ttl is optional
record_parts = line.strip().split()
# skip comments
if not record_parts or record_parts[0].startswith(";"):
continue
prio = None
if record_parts[1].isdigit():
# scheme 1
if record_parts[3] not in DNSRecordType.__members__:
logging.warning(f"Ignoring unsupported record type: {line}")
continue
if record_parts[2] not in RecordClass.__members__:
logging.warning(f"Ignoring unsupported record class: {line}")
continue
record_name = record_parts[0]
record_ttl = int(record_parts[1])
record_class = RecordClass[record_parts[2]]
record_type = DNSRecordType[record_parts[3]]
if record_type in DNS_RECORDS_WITH_PRIORITY:
prio = int(record_parts[4])
record_data = " ".join(record_parts[5:])
else:
record_data = " ".join(record_parts[4:])
elif record_parts[2].isdigit():
# scheme 2
if record_parts[3] not in DNSRecordType.__members__:
logging.warning(f"Ignoring unsupported record type: {line}")
continue
if record_parts[1] not in RecordClass.__members__:
logging.warning(f"Ignoring unsupported record class: {line}")
continue
record_name = record_parts[0]
record_ttl = int(record_parts[2])
record_class = RecordClass[record_parts[1]]
record_type = DNSRecordType[record_parts[3]]
if record_type in DNS_RECORDS_WITH_PRIORITY:
prio = int(record_parts[4])
record_data = " ".join(record_parts[5:])
else:
record_data = " ".join(record_parts[4:])
else:
# no ttl, use default or previous
if record_parts[2] not in DNSRecordType.__members__:
logging.warning(f"Ignoring unsupported record type: {line}")
continue
if record_parts[1] not in RecordClass.__members__:
logging.warning(f"Ignoring unsupported record class: {line}")
continue
record_name = record_parts[0]
if ttl is None and not records:
raise ValueError("No TTL found in file")
record_ttl = ttl or records[-1].ttl
record_class = RecordClass[record_parts[1]]
record_type = DNSRecordType[record_parts[2]]
if record_type in DNS_RECORDS_WITH_PRIORITY:
prio = int(record_parts[3])
record_data = " ".join(record_parts[4:])
else:
record_data = " ".join(record_parts[3:])
# replace @ in record name with origin
record_name = record_name.replace("@", origin)
# handle comments and quoted strings as record data
comment = None
line = record_data.strip()
if line.startswith('"'):
# find rightmost double quote
rindex = line.rfind('"')
if rindex != -1:
# split at the last double quote
line_parts = line.rsplit('"', 1)
record_data = line_parts[0].strip('"')
comment = line_parts[1].strip() if len(line_parts) > 1 else None
# left strip semicolon from comment
if comment and comment.startswith(";"):
comment = comment[1:].strip()
if not comment:
comment = None
else:
record_data = line.strip('"')
else:
# try to split at the first semicolon for comments
if ";" in line:
record_data, comment = line.split(";", 1)
record_data = record_data.strip()
comment = comment.strip()
else:
record_data = line
records.append(
BindRecord(
record_name,
record_ttl,
record_class,
record_type,
record_data,
prio=prio,
comment=comment,
)
)
if origin is None:
raise ValueError("No origin found in file")
return BindFile(origin, ttl, records)
def to_file(self, file_path: str) -> None:
with open(file_path, "w") as f:
f.write(str(self))
def __str__(self) -> str:
bind = f"$ORIGIN {self.origin}\n"
if self.ttl is not None:
bind += f"$TTL {self.ttl}\n"
for record in self.records:
bind += f"{record}\n"
return bind
|