from copy import deepcopy
from enum import Enum
import numpy as np
from dataclasses import is_dataclass, fields, dataclass
import graphlib
import json
from importlib import import_module
from typing import Dict, List, Literal, Optional, TypedDict, Union
from types import GeneratorType
import logging
import warnings
import inspect
from base64 import b64encode, b64decode
from . import plugin
from .util import SCHEMA_ATTRIBUTE_NAME, get_libraries
# Module-wide debug switch (not consulted anywhere in this chunk).
DEBUG = False


class SCHEMA_VERSIONS(str, Enum):
    """Known serialization schema versions, oldest to newest."""

    # NOTE(review): the "O1" in the first member name is a letter O, not a
    # zero — preserved as-is since it is referenced elsewhere in this module.
    BUMPS_DRAFT_O1 = "bumps-draft-01"
    BUMPS_DRAFT_02 = "bumps-draft-02"
    BUMPS_DRAFT_03 = "bumps-draft-03"


# The schema version written by serialize() and targeted by migrate().
SCHEMA = SCHEMA_VERSIONS.BUMPS_DRAFT_03

# Key under which shared (referenced) objects are stored in the document.
REFERENCES_KEY = "references"
# JSON-schema style ref marker; appears unused in this chunk.
REFERENCE_IDENTIFIER = "$ref"
# Sentinel distinguishing "key absent" from an explicit None value.
MISSING = object()
# Type tag carried by serialized Reference placeholder dicts.
REFERENCE_TYPE_NAME = "Reference"
REFERENCE_TYPE = Literal["Reference"]
# Dict key holding the fully-qualified class name of a serialized object.
TYPE_KEY = "__class__"
@dataclass
class Reference:
    """Placeholder stored in the object tree for a shared object.

    The actual serialized object lives once under
    ``serialized["references"][id]``; the tree holds only this stub.
    """

    id: str  # key into the references table
    type: REFERENCE_TYPE  # always the literal string "Reference"
# Recursive type alias for any JSON-compatible value.
JSON = Union[Dict[str, "JSON"], List["JSON"], str, int, float, bool, None]


class SerializedObject(TypedDict, total=True):
    """Envelope produced by serialize() and consumed by deserialize().

    NOTE(review): serialize() emits the version under the key "$schema",
    while this TypedDict declares "schema" — confirm which spelling the
    type is meant to describe.
    """

    schema: str
    object: JSON
    references: Dict[str, JSON]
def deserialize(serialized: SerializedObject, migration: bool = True):
    """Rehydrate a SerializedObject back into a python object tree.

    All entries in ``serialized['references']`` are rehydrated first, in
    dependency order, then ``serialized['object']`` is rehydrated with every
    Reference placeholder replaced by the matching python object.
    """
    if migration:
        # Built-in migrations first, then any plugin-provided ones.
        _, serialized = migrate(serialized)
        serialized = plugin.migrate_serialized(serialized)

    raw_refs = serialized[REFERENCES_KEY]

    # A reference may itself contain Reference placeholders; build the
    # dependency graph so hydration happens in a safe order.
    graph = {}
    for ref_id, ref_obj in raw_refs.items():
        deps = set()
        _find_ref_dependencies(ref_obj, deps)
        graph[ref_id] = deps

    # Hydrate references bottom-up; each entry may look up earlier ones.
    hydrated_refs = {}
    for ref_id in graphlib.TopologicalSorter(graph).static_order():
        hydrated_refs[ref_id] = _rehydrate(raw_refs[ref_id], hydrated_refs)

    # The reference table is fully hydrated: rehydrate the main tree.
    return _rehydrate(serialized["object"], hydrated_refs)
#### deserializer helpers:
def _rehydrate(obj, references: Dict[str, object]):
    """Recursively convert a JSON tree back into python objects.

    Dicts with a TYPE_KEY are dispatched on that tag: Reference placeholders
    resolve through ``references``, numpy arrays and callables use their
    special decoders, and anything else is instantiated by importing the
    fully-qualified class name.  Lists recurse element-wise; bare scalars
    pass through unchanged.
    """
    if isinstance(obj, dict):
        # Copy so the caller's serialized structure is not mutated by pop().
        obj = obj.copy()
        t: str = obj.pop(TYPE_KEY, MISSING)
        if t == REFERENCE_TYPE_NAME:
            obj_id: str = obj.get("id", MISSING)
            if obj_id is MISSING:
                raise ValueError("object id is required for Reference type")
            # Requires that `references` already holds the hydrated object for
            # this id (guaranteed by the topological ordering in deserialize()).
            return references[obj_id]
        elif t == "bumps.util.NumpyArray":
            # skip processing the values list inside the ndarray payload
            return _to_ndarray(obj)
        elif t == "Callable":
            return deserialize_function(obj)
        else:
            # Rehydrate all values first, then (optionally) instantiate.
            for key in obj:
                obj[key] = _rehydrate(obj[key], references)
            if t is MISSING:
                # no type tag provided, so no class to instantiate:
                # return the hydrated plain dict
                return obj
            # obj values are now rehydrated: instantiate the class named by t
            else:
                try:
                    module_name, class_name = t.rsplit(".", 1)
                    klass = getattr(import_module(module_name), class_name)
                except Exception as e:
                    # there is a type tag, but the class could not be imported
                    logging.exception(e)
                    raise RuntimeError(f"Error importing {t}: {e}")
                hydrated = _instantiate(klass, t, obj)
                return hydrated
    elif isinstance(obj, list):
        # rehydrate every item
        return [_rehydrate(v, references) for v in obj]
    else:
        # bare JSON scalar — return unchanged
        return obj
def _instantiate(klass: type, typename: str, serialized: dict):
# TODO: why are we copying the top-level dict?
serialized = serialized.copy()
# if klass provides 'from_dict' method, use it -
# otherwise use klass.__init__ directly.
class_factory = getattr(klass, "from_dict", klass)
try:
hydrated = class_factory(**serialized)
except Exception as e:
logging.exception(e)
warnings.warn(f"Error restoring {typename}: {e}")
# Note: users of the failed deserialization may complain that it is None
hydrated = None
return hydrated
def _to_ndarray(obj: dict):
return np.asarray(obj["values"], dtype=np.dtype(obj.get("dtype", float)))
def _find_ref_dependencies(obj, dependencies: set):
    """Recursively collect the ids of all Reference placeholders in ``obj``.

    Used by deserialize() to build the dependency graph for topological
    ordering of the references table.  ``dependencies`` is mutated in place.
    """
    if isinstance(obj, dict):
        # BUG FIX: this previously compared against REFERENCE_TYPE (the
        # typing.Literal object), which never equals a string — so no
        # dependency was ever recorded and nested references could be
        # hydrated out of order.  Compare against the tag string instead.
        if obj.get(TYPE_KEY, None) == REFERENCE_TYPE_NAME:
            dependencies.add(obj["id"])
        else:
            for v in obj.values():
                _find_ref_dependencies(v, dependencies)
    elif isinstance(obj, list):
        for v in obj:
            _find_ref_dependencies(v, dependencies)
#### end deserializer helpers
def serialize(obj, use_refs=True, add_libraries=True):
    """Convert a (dataclass-based) object tree to a JSON-compatible dict.

    When ``use_refs`` is True, objects carrying an ``id`` attribute are
    stored once under the "references" table and replaced in the tree by
    Reference placeholders.  With ``add_libraries``, version info from
    get_libraries() is attached to the output.
    """
    # Shared accumulator: id -> serialized dict for every referenced object.
    references = {}

    def make_ref(obj_id: str):
        # Placeholder dict that deserialize() resolves via the references table.
        return {"id": obj_id, TYPE_KEY: REFERENCE_TYPE_NAME}

    def dataclass_to_dict(dclass, include=None, exclude=None):
        # Serialize dataclass fields; `include` wins over `exclude`, and
        # underscore-prefixed fields are dropped unless explicitly included.
        all_fields = fields(dclass)
        if include is not None:
            all_fields = [f for f in all_fields if f.name in include]
        elif exclude is not None:
            all_fields = [f for f in all_fields if not f.name in exclude and not f.name.startswith("_")]
        else:
            all_fields = [f for f in all_fields if not f.name.startswith("_")]
        cls = dclass.__class__
        fqn = f"{cls.__module__}.{cls.__qualname__}"
        output = dict([(f.name, obj_to_dict(getattr(dclass, f.name))) for f in all_fields])
        output[TYPE_KEY] = fqn
        return output

    def obj_to_dict(obj):
        if hasattr(obj, SCHEMA_ATTRIBUTE_NAME):
            # Object declares its own schema options (include/exclude lists).
            schema_opts = getattr(obj, SCHEMA_ATTRIBUTE_NAME)
            include = schema_opts.get("include", None)
            exclude = schema_opts.get("exclude", None)
            use_ref = use_refs and hasattr(obj, "id")
            if (not use_ref) or (obj.id not in references):
                # only calculate the dict if it's not already in refs,
                # or if we are not using refs at all
                obj_dict = dataclass_to_dict(obj, include=include, exclude=exclude)
                if use_ref:
                    references.setdefault(obj.id, obj_dict)
            return make_ref(obj.id) if use_ref else obj_dict
        elif is_dataclass(obj):
            return dataclass_to_dict(obj)
        elif isinstance(obj, (list, tuple, GeneratorType)):
            # NOTE: tuples and generators are flattened to JSON lists.
            return list(obj_to_dict(v) for v in obj)
        elif isinstance(obj, dict):
            # preserve the dict subclass; keys are serialized too
            return type(obj)((obj_to_dict(k), obj_to_dict(v)) for k, v in obj.items())
        elif isinstance(obj, np.ndarray) and obj.dtype.kind in ["f", "i", "U"]:
            return {TYPE_KEY: "bumps.util.NumpyArray", "dtype": str(obj.dtype), "values": obj.tolist()}
        elif isinstance(obj, np.ndarray) and obj.dtype.kind == "O":
            # object arrays: serialize element-wise as a plain list
            return obj_to_dict(obj.tolist())
        elif isinstance(obj, Enum):
            return obj_to_dict(obj.value)
        elif isinstance(obj, float):
            # Infinities are not valid JSON, so store them as strings.
            # NOTE(review): NaN falls through unconverted and would emit a
            # non-standard JSON NaN token — confirm downstream consumers.
            return str(obj) if np.isinf(obj) else obj
        elif isinstance(obj, int) or isinstance(obj, str) or obj is None:
            return obj
        elif callable(obj):
            return serialize_function(obj)
        else:
            raise ValueError("obj %s is not serializable" % str(obj))

    serialized = {"$schema": SCHEMA, "object": obj_to_dict(obj), REFERENCES_KEY: references}
    if add_libraries:
        serialized["libraries"] = get_libraries(obj)
    return serialized
def deserialize_function(obj):
    """Restore a callable from its serialized {"pickler", "pickle"} payload.

    Supports both dill and cloudpickle payloads.  Imports are done lazily in
    each branch so only the pickler actually used needs to be installed
    (previously both packages were imported unconditionally).  On any
    failure a warning is issued and None is returned.
    """
    try:
        # CRUFT: older versions did not specify the pickler
        pickler = obj.get("pickler", "dill")
        data = deserialize_bytes(obj["pickle"])
        if pickler == "dill":
            import dill

            return dill.loads(data)
        elif pickler == "cloudpickle":
            import cloudpickle

            # reuse the already-decoded payload instead of decoding twice
            return cloudpickle.loads(data)
        else:
            raise ValueError(f"unrecognized pickler {pickler}")
    except Exception as e:
        logging.exception(e)
        warnings.warn(f"Error loading function: {e}")
        return None
def serialize_bytes(b):
    """Encode raw bytes as an ASCII base64 string suitable for JSON."""
    encoded = b64encode(b)
    return encoded.decode("ascii")
def deserialize_bytes(s):
    """Decode a base64 string (as produced by serialize_bytes) back to bytes."""
    return b64decode(s)
def serialize_function(fn):
    """Serialize a callable via cloudpickle, with best-effort source capture.

    Returns a dict tagged "Callable" holding the function name, its source
    (when retrievable), and a base64 cloudpickle payload.
    """
    from cloudpickle import dumps

    name = getattr(fn, "__name__", "unknown")
    # Source capture is best-effort only (builtins, C functions, and some
    # dynamically-created callables have no retrievable source).
    # NOTE: dedent would be needed to handle decorator syntax, but dedent
    # fails when there are triple-quoted strings; an alternative is to wrap
    # indented code in an "if True:" block.
    try:
        source = inspect.getsource(fn)
    except Exception:
        source = None
    payload = serialize_bytes(dumps(fn))
    return {
        TYPE_KEY: "Callable",
        "name": name,
        "source": source,
        "pickle": payload,
        "pickler": "cloudpickle",
    }
def save_file(filename, problem):
    """Serialize ``problem`` and write it as JSON to ``filename``.

    Failures are logged and reported as a warning rather than raised, so a
    failed save does not abort the caller.
    """
    try:
        serialized = serialize(problem)
        with open(filename, "w") as fid:
            json.dump(serialized, fid)
    except Exception as e:
        logging.exception(e)
        # BUG FIX: the message previously contained the literal text
        # "(unknown)" in an f-string with no placeholder; interpolate the
        # actual filename so the warning is actionable.
        warnings.warn(f"failed to create JSON file {filename} for fit problem")
def load_file(filename):
    """Load and deserialize a fit problem from a JSON file.

    Schema migration is handled inside deserialize() (its ``migration``
    flag defaults to True), so the previous explicit migrate() call here was
    redundant — the document was migrated twice, the second pass a no-op.
    """
    with open(filename, "r") as fid:
        serialized: SerializedObject = json.loads(fid.read())
    return deserialize(serialized)
#### MIGRATIONS
def validate_version(version: str, variable_name="from_version"):
    """Raise ValueError unless ``version`` is a recognized schema version.

    Accepts either a SCHEMA_VERSIONS member or its string value (the str
    mixin makes members compare equal to their values).
    """
    valid = [s.value for s in SCHEMA_VERSIONS]
    if version not in valid:
        raise ValueError(f"must choose a valid {variable_name} from this list: {valid}")
def migrate(
    serialized: dict, from_version: Optional[SCHEMA_VERSIONS] = None, to_version: Optional[SCHEMA_VERSIONS] = SCHEMA
):
    """
    Migrate a serialized object from one schema version to another.

    By default, the `from_version` is determined by inspection of the
    serialized object ("$schema" entry, falling back to the first schema
    version when absent).  This is overriden by setting the `from_version`
    keyword argument to a member of `SCHEMA_VERSIONS`.
    Also by default, the target version is the current schema, which can be
    overriden with the `to_version` keyword argument.

    Returns a ``(final_version, migrated_document)`` tuple.
    """
    if from_version is None:
        # fall back to the first version if the document does not say
        from_version = serialized.get("$schema", SCHEMA_VERSIONS.BUMPS_DRAFT_O1)
    validate_version(from_version, "from_version")
    validate_version(to_version, "to_version")
    current_version = from_version
    # Walk the migration chain one step at a time until the target is reached.
    while current_version != to_version:
        # IMPROVEMENT: use logging (lazy %-args) instead of a bare print so
        # library consumers control whether migration progress is shown.
        logging.info("migrating %s...", current_version)
        current_version, serialized = MIGRATIONS[current_version](serialized)
    return current_version, serialized
def _migrate_draft_01_to_draft_02(serialized: dict):
    """Upgrade a draft-01 document to the draft-02 envelope format.

    Two in-place passes over the tree:
      1. rename_type: move the legacy "type" key to TYPE_KEY ("__class__").
      2. build_references: hoist each typed dict carrying an "id" into a
         shared `references` table and replace it in the tree with a
         Reference placeholder.
    The mutated tree is then wrapped in the draft-02 envelope
    {"$schema", "object", "references"}.
    """
    references = {}

    def rename_type(obj):
        # Move "type" (draft-01) to TYPE_KEY (draft-02) throughout the tree.
        if isinstance(obj, dict):
            # NOTE: both keys are popped up front; when both are present the
            # legacy "type" value wins and the TYPE_KEY value is discarded.
            t: str = obj.pop("type", obj.pop(TYPE_KEY, MISSING))
            if t is not MISSING:
                obj[TYPE_KEY] = t
            for v in obj.values():
                rename_type(v)
        elif isinstance(obj, list):
            for v in obj:
                rename_type(v)

    def build_references(obj):
        # Replace each typed dict that carries an "id" with a Reference
        # placeholder, storing the first-seen full copy in `references`.
        if isinstance(obj, dict):
            t: str = obj.get(TYPE_KEY, MISSING)
            obj_id: str = obj.get("id", MISSING)
            if obj_id is not MISSING and not t in [MISSING, REFERENCE_TYPE_NAME]:
                if not obj_id in references:
                    # snapshot before the node is truncated below
                    references[obj_id] = deepcopy(obj)
                obj[TYPE_KEY] = REFERENCE_TYPE_NAME
                # strip everything except the placeholder's id and type tag
                for k in list(obj.keys()):
                    if k not in [TYPE_KEY, "id"]:
                        del obj[k]
            for v in obj.values():
                build_references(v)
        elif isinstance(obj, list):
            for v in obj:
                build_references(v)

    rename_type(serialized)
    build_references(serialized)
    migrated = {
        "$schema": SCHEMA_VERSIONS.BUMPS_DRAFT_02.value,
        "object": serialized,
        "references": references,
    }
    return SCHEMA_VERSIONS.BUMPS_DRAFT_02, migrated
def _migrate_draft_02_to_draft_03(serialized: dict):
    """Upgrade a draft-02 document to draft-03.

    The only change: Expression nodes using the legacy "div" operator are
    rewritten to "truediv".  The input document is left untouched; a
    migrated deep copy is returned along with the new version tag.
    """

    def fix_div_ops(node):
        # Walk the JSON tree, patching Expression nodes in place.
        if isinstance(node, dict) and node.get(TYPE_KEY, MISSING) == "bumps.parameter.Expression":
            if node.get("op", MISSING) == "div":
                node["op"] = "truediv"
            # Expression operands live under "args".
            for arg in node.get("args", []):
                fix_div_ops(arg)
        elif isinstance(node, dict):
            for value in node.values():
                fix_div_ops(value)
        elif isinstance(node, list):
            for item in node:
                fix_div_ops(item)

    migrated = deepcopy(serialized)
    fix_div_ops(migrated)
    return SCHEMA_VERSIONS.BUMPS_DRAFT_03, migrated
# Map each schema version to the migration function that upgrades a document
# one step toward the next version; migrate() chains these until the target
# version is reached.  The newest version has no entry (nothing to upgrade).
MIGRATIONS = {
    SCHEMA_VERSIONS.BUMPS_DRAFT_O1: _migrate_draft_01_to_draft_02,
    SCHEMA_VERSIONS.BUMPS_DRAFT_02: _migrate_draft_02_to_draft_03,
}