1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
|
import datetime
import json
import numpy as np
from ase.utils import reader, writer
class MyEncoder(json.JSONEncoder):
    """JSON encoder that also understands NumPy, datetime, complex and
    ``todict()``-capable objects.

    Values the standard encoder cannot serialize are converted by
    :meth:`default` into plain JSON-compatible values.  Where the type must
    be recoverable, a tagged dictionary is emitted (``__ndarray__``,
    ``__datetime__``, ``__complex__``, ``__ase_objtype__``).
    """

    def default(self, obj):
        """Return a JSON-serializable replacement for *obj*.

        Handled, in this order:

        * objects with a ``todict()`` method -> the returned dict, plus an
          ``__ase_objtype__`` entry when the object has ``ase_objtype``;
        * ``np.ndarray`` -> ``{'__ndarray__': (shape, dtype name, flat data)}``;
        * NumPy integer / floating / bool scalars -> ``int`` / ``float`` /
          ``bool``;
        * ``datetime.datetime`` -> ``{'__datetime__': isoformat string}``;
        * complex numbers (Python or NumPy) -> ``{'__complex__': (re, im)}``.

        Raises:
            RuntimeError: if ``obj.todict()`` does not return a dict.
            TypeError: (from the base class) for any other unsupported type.
        """
        if hasattr(obj, 'todict'):
            d = obj.todict()
            if not isinstance(d, dict):
                raise RuntimeError('todict() of {} returned object of type {} '
                                   'but should have returned dict'
                                   .format(obj, type(d)))
            if hasattr(obj, 'ase_objtype'):
                d['__ase_objtype__'] = obj.ase_objtype
            return d
        if isinstance(obj, np.ndarray):
            flatobj = obj.ravel()
            if np.iscomplexobj(obj):
                # Reinterpret the complex buffer as its real-valued
                # components.  view() creates a new array object instead of
                # assigning to the .dtype attribute, which NumPy discourages.
                flatobj = flatobj.view(obj.real.dtype)
            return {'__ndarray__': (obj.shape,
                                    obj.dtype.name,
                                    flatobj.tolist())}
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            # np.float64 subclasses float and never reaches default(), but
            # e.g. np.float32 / np.float16 do and must be converted here.
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, datetime.datetime):
            return {'__datetime__': obj.isoformat()}
        if isinstance(obj, (complex, np.complexfloating)):
            # float() makes the parts serializable for np.complex64 as well,
            # whose .real/.imag are np.float32.
            return {'__complex__': (float(obj.real), float(obj.imag))}
        return json.JSONEncoder.default(self, obj)
# Module-level shortcut: ``encode(obj)`` serializes *obj* to a JSON string
# using the bound ``encode`` method of one shared MyEncoder instance.
encode = MyEncoder().encode
def object_hook(dct):
    """Turn a tagged dictionary back into the object it represents.

    Intended as the ``object_hook`` of a ``json.JSONDecoder``: it receives
    every decoded JSON object as a dict and returns either a reconstructed
    value or the dict itself.

    Recognized tags (checked in this order): ``__datetime__``,
    ``__complex__``, ``__ndarray__``, the legacy ``__complex_ndarray__`` and
    ``__ase_objtype__``.  Dictionaries without any of these keys are returned
    unchanged.

    Raises:
        ValueError: if a ``__datetime__`` string matches neither supported
            timestamp layout.
    """
    if '__datetime__' in dct:
        stamp = dct['__datetime__']
        try:
            return datetime.datetime.strptime(stamp, '%Y-%m-%dT%H:%M:%S.%f')
        except ValueError:
            # datetime.isoformat() drops the fractional part when
            # microsecond == 0, so accept that layout as well.
            return datetime.datetime.strptime(stamp, '%Y-%m-%dT%H:%M:%S')
    if '__complex__' in dct:
        return complex(*dct['__complex__'])
    if '__ndarray__' in dct:
        return create_ndarray(*dct['__ndarray__'])
    # No longer used (only here for backwards compatibility):
    if '__complex_ndarray__' in dct:
        r, i = (np.array(x) for x in dct['__complex_ndarray__'])
        return r + i * 1j
    if '__ase_objtype__' in dct:
        # The tag is removed so that only the object's own data is passed on.
        objtype = dct.pop('__ase_objtype__')
        dct = numpyfy(dct)
        return create_ase_object(objtype, dct)
    return dct
def create_ndarray(shape, dtype, data):
    """Rebuild an ndarray from its shape, dtype and flattened data.

    An uninitialized array of the requested shape and dtype is allocated and
    then filled through a flat view.  For complex dtypes the flat view is
    reinterpreted with the dtype of the array's real part, so *data* must
    hold the real-valued components rather than complex numbers.
    """
    result = np.empty(shape, dtype=dtype)
    flat = result.ravel()
    if np.iscomplexobj(result):
        # Same memory, seen as real numbers.
        flat = flat.view(result.real.dtype)
    flat[:] = data
    return result
def create_ase_object(objtype, dct):
    """Instantiate the ASE object named by *objtype* from the data in *dct*.

    Supported values of *objtype* are ``'atoms'``, ``'bandpath'``,
    ``'bandstructure'`` and ``'cell'``; anything else raises ``ValueError``.
    The classes are imported inside the matching branch only.  Note that
    *dct* may be modified (entries are popped for ``'cell'`` and
    ``'bandpath'``).
    """
    # Each kind is handled by hand; this can be formalized later if it ever
    # becomes necessary.
    if objtype == 'atoms':
        from ase import Atoms
        instance = Atoms.fromdict(dct)
    elif objtype == 'bandpath':
        from ase.dft.kpoints import BandPath
        labels = dct.pop('labelseq')
        instance = BandPath(path=labels, **dct)
    elif objtype == 'bandstructure':
        from ase.spectrum.band_structure import BandStructure
        instance = BandStructure(**dct)
    elif objtype == 'cell':
        from ase.cell import Cell
        # Compatibility: 'pbc' was once stored here and is discarded.
        dct.pop('pbc', None)
        instance = Cell(**dct)
    else:
        message = ('Do not know how to decode object type {} '
                   'into an actual object'.format(objtype))
        raise ValueError(message)
    # Sanity check: the created object must report the requested type tag.
    assert instance.ase_objtype == objtype
    return instance
# Module-level shortcut: parses a JSON string with a decoder that passes
# every decoded JSON object (dict) through object_hook.
mydecode = json.JSONDecoder(object_hook=object_hook).decode
def intkey(key):
    """Return ``int(key)`` if *key* converts cleanly, otherwise *key* itself."""
    try:
        as_int = int(key)
    except ValueError:
        # Not an integer literal: leave the key untouched.
        return key
    return as_int
def fix_int_keys_in_dicts(obj):
    """Convert "int" keys: "1" -> 1.

    json.dump() writes integer dictionary keys as strings; this function
    undoes that.  Dictionaries are rebuilt recursively through their values,
    with every key passed through intkey().  Anything that is not a dict
    (including lists) is returned unchanged.
    """
    if not isinstance(obj, dict):
        return obj
    converted = {}
    for name, child in obj.items():
        converted[intkey(name)] = fix_int_keys_in_dicts(child)
    return converted
def numpyfy(obj):
    """Convert lists of bools/ints/floats into NumPy arrays.

    * A dict holding the legacy ``__complex_ndarray__`` key becomes a complex
      array built from its (real, imaginary) parts; any other dict is
      returned as-is.
    * A non-empty list becomes an array when ``np.array`` accepts it and the
      resulting dtype is bool, int or float; otherwise its elements are
      processed recursively and a new list is returned.
    * Everything else (including empty lists) is returned unchanged.
    """
    if isinstance(obj, dict) and '__complex_ndarray__' in obj:
        real, imag = (np.array(part) for part in obj['__complex_ndarray__'])
        return real + imag * 1j

    if not isinstance(obj, list) or len(obj) == 0:
        return obj

    try:
        candidate = np.array(obj)
    except ValueError:
        # np.array rejected the list; fall through to per-element handling.
        candidate = None

    if candidate is not None and candidate.dtype in [bool, int, float]:
        return candidate
    return [numpyfy(item) for item in obj]
def decode(txt, always_array=True):
    """Decode the JSON string *txt* into Python/NumPy objects.

    The text is parsed with mydecode(), string keys that look like integers
    are converted back with fix_int_keys_in_dicts(), and, when *always_array*
    is true, the result is additionally passed through numpyfy().
    """
    decoded = fix_int_keys_in_dicts(mydecode(txt))
    return numpyfy(decoded) if always_array else decoded
@reader
def read_json(fd, always_array=True):
    """Read everything from *fd* and return the decoded object.

    *always_array* is forwarded to decode().
    NOTE(review): the ``reader`` decorator comes from ase.utils; presumably
    it lets callers pass a filename instead of an open file -- confirm.
    """
    return decode(fd.read(), always_array=always_array)
@writer
def write_json(fd, obj):
    """Serialize *obj* with encode() and write the resulting text to *fd*.

    NOTE(review): the ``writer`` decorator comes from ase.utils; presumably
    it lets callers pass a filename instead of an open file -- confirm.
    """
    text = encode(obj)
    fd.write(text)
|