1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464
|
import sys
import json
import traceback
from copy import copy
from collections import OrderedDict
_LOG_LEVELS = {
level: idx
for idx, level in enumerate(
("critical", "error", "warning", "notice", "info", "verbose", "debug")
)
}
def _skip_until_empty_line(file):
while True:
line = file.readline().strip()
if not line:
break
def _get_request(file, record_file=None):
line = file.readline()
blank_line = file.readline()
if record_file is not None:
record_file.write("< " + line)
record_file.write("< " + blank_line)
return json.loads(line.strip())
def _put_response(data, file, record_file=None):
data = json.dumps(data)
file.write(data + "\n\n")
file.flush()
if record_file is not None:
record_file.write("> " + data + "\n")
record_file.write("> \n")
def _would_log(level_set, msg_level):
if msg_level not in _LOG_LEVELS:
# uknown level, assume it would be logged
return True
return _LOG_LEVELS[msg_level] <= _LOG_LEVELS[level_set]
def _cfengine_type(typing):
if typing is str:
return "string"
if typing is int:
return "int"
if typing in (list, tuple):
return "slist"
if typing is dict:
return "data container"
if typing is bool:
return "true/false"
return "Error in promise module"
class AttributeObject(object):
def __init__(self, d):
for key, value in d.items():
setattr(self, key, value)
def __repr__(self):
return "{}({})".format(
self.__class__.__qualname__,
", ".join("{}={!r}".format(k, v) for k, v in self.__dict__.items())
)
class ValidationError(Exception):
def __init__(self, message):
self.message = message
class ProtocolError(Exception):
def __init__(self, message):
self.message = message
class Result:
# Promise evaluation outcomes, can reveal "real" problems with system:
KEPT = "kept" # Satisfied already, no change
REPAIRED = "repaired" # Not satisfied before , but fixed
NOT_KEPT = "not_kept" # Not satisfied before , not fixed
# Validation only, can reveal problems in CFEngine policy:
VALID = "valid" # Validation successful
INVALID = "invalid" # Validation failed, error in cfengine policy
# Generic succes / fail for init / terminate requests:
SUCCESS = "success"
FAILURE = "failure"
# Unexpected, can reveal problems in promise module:
ERROR = "error" # Something went wrong in module / protocol
class PromiseModule:
def __init__(
self, name="default_module_name", version="0.0.1", record_file_path=None
):
self.name = name
self.version = version
# Note: The class doesn't expose any way to set protocol version
# or flags, because that should be abstracted away from the
# user (module author).
self._validator_attributes = OrderedDict()
self._result_classes = None
# File to record all the incoming and outgoing communication
self._record_file = open(record_file_path, "a") if record_file_path else None
def start(self, in_file=None, out_file=None):
self._in = in_file or sys.stdin
self._out = out_file or sys.stdout
first_line = self._in.readline()
if self._record_file is not None:
self._record_file.write("< " + first_line)
header = first_line.strip().split(" ")
name = header[0]
version = header[1]
protocol_version = header[2]
# flags = header[3:] -- unused for now
assert len(name) > 0 # cf-agent
assert version.startswith("3.") # 3.18.0
assert protocol_version[0] == "v" # v1
_skip_until_empty_line(self._in)
header_reply = "{name} {version} v1 json_based\n\n".format(
name=self.name, version=self.version
)
self._out.write(header_reply)
self._out.flush()
if self._record_file is not None:
self._record_file.write("> " + header_reply.strip() + "\n")
self._record_file.write(">\n")
while True:
self._response = {}
self._result = None
request = _get_request(self._in, self._record_file)
self._handle_request(request)
def _convert_types(self, promiser, attributes):
# Will only convert types if module has typing information:
if not self._has_validation_attributes:
return promiser, attributes
replacements = {}
for name, value in attributes.items():
if type(value) is not str:
# If something is not string, assume it is correct type
continue
if name not in self._validator_attributes:
# Unknown attribute, this will cause a validation error later
continue
# "true"/"false" -> True/False
if self._validator_attributes[name]["typing"] is bool:
if value == "true":
replacements[name] = True
elif value == "false":
replacements[name] = False
# "int" -> int()
elif self._validator_attributes[name]["typing"] is int:
try:
replacements[name] = int(value)
except ValueError:
pass
# Don't edit dict while iterating over it, after instead:
attributes.update(replacements)
return (promiser, attributes)
def _handle_request(self, request):
if not request:
sys.exit("Error: Empty/invalid request or EOF reached")
operation = request["operation"]
self._log_level = request.get("log_level", "info")
self._response["operation"] = operation
# Agent will never request log level critical
assert self._log_level in [
"error",
"warning",
"notice",
"info",
"verbose",
"debug",
]
if operation in ["validate_promise", "evaluate_promise"]:
promiser = request["promiser"]
attributes = request.get("attributes", {})
promiser, attributes = self._convert_types(promiser, attributes)
promiser, attributes = self.prepare_promiser_and_attributes(
promiser, attributes
)
self._response["promiser"] = promiser
self._response["attributes"] = attributes
if operation == "init":
self._handle_init()
elif operation == "validate_promise":
self._handle_validate(promiser, attributes, request)
elif operation == "evaluate_promise":
self._handle_evaluate(promiser, attributes, request)
elif operation == "terminate":
self._handle_terminate()
else:
self._log_level = None
raise ProtocolError(
"Unknown operation: '{operation}'".format(operation=operation)
)
self._log_level = None
def _add_result(self):
self._response["result"] = self._result
def _add_result_classes(self):
if self._result_classes:
self._response["result_classes"] = self._result_classes
def _add_traceback_to_response(self):
if self._log_level != "debug":
return
trace = traceback.format_exc()
logs = self._response.get("log", [])
logs.append({"level": "debug", "message": trace})
self._response["log"] = logs
def add_attribute(
self,
name,
typing,
default=None,
required=False,
default_to_promiser=False,
validator=None,
):
attribute = OrderedDict()
attribute["name"] = name
attribute["typing"] = typing
attribute["default"] = default
attribute["required"] = required
attribute["default_to_promiser"] = default_to_promiser
attribute["validator"] = validator
self._validator_attributes[name] = attribute
@property
def _has_validation_attributes(self):
return bool(self._validator_attributes)
def create_attribute_dict(self, promiser, attributes):
# Check for missing required attributes:
for name, attribute in self._validator_attributes.items():
if attribute["required"] and name not in attributes:
raise ValidationError(
"Missing required attribute '{name}'".format(name=name)
)
# Check for unknown attributes:
for name in attributes:
if name not in self._validator_attributes:
raise ValidationError("Unknown attribute '{name}'".format(name=name))
# Check typings and run custom validator callbacks:
for name, value in attributes.items():
expected = _cfengine_type(self._validator_attributes[name]["typing"])
found = _cfengine_type(type(value))
if found != expected:
raise ValidationError(
"Wrong type for attribute '{name}', requires '{expected}', not '{value}'({found})".format(
name=name, expected=expected, value=value, found=found
)
)
if self._validator_attributes[name]["validator"]:
# Can raise ValidationError:
self._validator_attributes[name]["validator"](value)
attribute_dict = OrderedDict()
# Copy attributes specified by user policy:
for key, value in attributes.items():
attribute_dict[key] = value
# Set defaults based on promise module validation hints:
for name, value in self._validator_attributes.items():
if value.get("default_to_promiser", False):
attribute_dict.setdefault(name, promiser)
elif value.get("default", None) is not None:
attribute_dict.setdefault(name, copy(value["default"]))
else:
attribute_dict.setdefault(name, None)
return attribute_dict
def create_attribute_object(self, promiser, attributes):
attribute_dict = self.create_attribute_dict(promiser, attributes)
return AttributeObject(attribute_dict)
def _validate_attributes(self, promiser, attributes):
if not self._has_validation_attributes:
# Can only validate attributes if module
# provided typings for attributes
return
self.create_attribute_object(promiser, attributes)
return # Only interested in exceptions, return None
def _handle_init(self):
self._result = self.protocol_init(None)
self._add_result()
_put_response(self._response, self._out, self._record_file)
def _handle_validate(self, promiser, attributes, request):
meta = {"promise_type": request.get("promise_type")}
try:
self.validate_attributes(promiser, attributes, meta)
returned = self.validate_promise(promiser, attributes, meta)
if returned is None:
# Good, expected
self._result = Result.VALID
else:
# Bad, validate method shouldn't return anything else
self.log_critical(
"Bug in promise module {name} - validate_promise() should not return anything".format(
name=self.name
)
)
self._result = Result.ERROR
except ValidationError as e:
message = str(e)
if "promise_type" in request:
message += " for {request_promise_type} promise with promiser '{promiser}'".format(
request_promise_type=request["promise_type"], promiser=promiser
)
else:
message += " for promise with promiser '{promiser}'".format(
promiser=promiser
)
if "filename" in request and "line_number" in request:
message += " ({request_filename}:{request_line_number})".format(
request_filename=request["filename"],
request_line_number=request["line_number"],
)
self.log_error(message)
self._result = Result.INVALID
except Exception as e:
self.log_critical(
"{error_type}: {error}".format(error_type=type(e).__name__, error=e)
)
self._add_traceback_to_response()
self._result = Result.ERROR
self._add_result()
_put_response(self._response, self._out, self._record_file)
def _handle_evaluate(self, promiser, attributes, request):
self._result_classes = None
meta = {"promise_type": request.get("promise_type")}
try:
results = self.evaluate_promise(promiser, attributes, meta)
# evaluate_promise should return either a result or a (result, result_classes) pair
if type(results) == str:
self._result = results
else:
assert len(results) == 2
self._result = results[0]
self._result_classes = results[1]
except Exception as e:
self.log_critical(
"{error_type}: {error}".format(error_type=type(e).__name__, error=e)
)
self._add_traceback_to_response()
self._result = Result.ERROR
self._add_result()
self._add_result_classes()
_put_response(self._response, self._out, self._record_file)
def _handle_terminate(self):
self._result = self.protocol_terminate()
self._add_result()
_put_response(self._response, self._out, self._record_file)
sys.exit(0)
def _log(self, level, message):
if self._log_level is not None and not _would_log(self._log_level, level):
return
# Message can be str or an object which implements __str__()
# for example an exception:
message = str(message).replace("\n", r"\n")
assert "\n" not in message
self._out.write("log_{level}={message}\n".format(level=level, message=message))
self._out.flush()
if self._record_file is not None:
self._record_file.write(
"log_{level}={message}\n".format(level=level, message=message)
)
def log_critical(self, message):
self._log("critical", message)
def log_error(self, message):
self._log("error", message)
def log_warning(self, message):
self._log("warning", message)
def log_notice(self, message):
self._log("notice", message)
def log_info(self, message):
self._log("info", message)
def log_verbose(self, message):
self._log("verbose", message)
def log_debug(self, message):
self._log("debug", message)
def _log_traceback(self):
trace = traceback.format_exc().split("\n")
for line in trace:
self.log_debug(line)
# Functions to override in subclass:
def protocol_init(self, version):
return Result.SUCCESS
def prepare_promiser_and_attributes(self, promiser, attributes):
"""Override if you want to modify promiser or attributes before validate or evaluate"""
return (promiser, attributes)
def validate_attributes(self, promiser, attributes, meta):
"""Override this if you want to prevent automatic validation"""
return self._validate_attributes(promiser, attributes)
def validate_promise(self, promiser, attributes, meta):
"""Must override this or use validation through self.add_attribute()"""
if not self._has_validation_attributes:
raise NotImplementedError("Promise module must implement validate_promise")
def evaluate_promise(self, promiser, attributes, meta):
raise NotImplementedError("Promise module must implement evaluate_promise")
def protocol_terminate(self):
return Result.SUCCESS
|