import json

import numpy as np
from psycopg2 import connect
from psycopg2.extras import execute_values

from ase.db.sqlite import (init_statements, index_statements, VERSION,
                           SQLite3Database)
from ase.io.jsonio import (encode as ase_encode,
                           create_ase_object, create_ndarray)

jsonb_indices = [
    'CREATE INDEX idxkeys ON systems USING GIN (key_value_pairs);',
    'CREATE INDEX idxcalc ON systems USING GIN (calculator_parameters);']
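
# The GIN indices above let PostgreSQL answer JSONB containment queries
# without a sequential scan, e.g. (hypothetical key-value pair shown):
#
#     SELECT id FROM systems WHERE key_value_pairs @> '{"relaxed": true}';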


def remove_nan_and_inf(obj):
    """Recursively replace NaN and Inf floats with JSON-safe markers.

    JSON cannot represent non-finite numbers, so they are wrapped as
    {'__special_number__': 'nan'/'inf'/'-inf'} before encoding.
    """
    if isinstance(obj, float) and not np.isfinite(obj):
        return {'__special_number__': str(obj)}
    if isinstance(obj, list):
        return [remove_nan_and_inf(x) for x in obj]
    if isinstance(obj, dict):
        return {key: remove_nan_and_inf(value) for key, value in obj.items()}
    if isinstance(obj, np.ndarray) and not np.isfinite(obj).all():
        return remove_nan_and_inf(obj.tolist())
    return obj


def insert_nan_and_inf(obj):
    """Recursively restore NaN and Inf floats from their JSON-safe markers."""
    if isinstance(obj, dict) and '__special_number__' in obj:
        return float(obj['__special_number__'])
    if isinstance(obj, list):
        return [insert_nan_and_inf(x) for x in obj]
    if isinstance(obj, dict):
        return {key: insert_nan_and_inf(value) for key, value in obj.items()}
    return obj
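
# Round-trip sketch: once wrapped, non-finite numbers survive a pass
# through standard JSON:
#
#     >>> remove_nan_and_inf(float('inf'))
#     {'__special_number__': 'inf'}
#     >>> insert_nan_and_inf(json.loads(json.dumps(
#     ...     remove_nan_and_inf(float('inf')))))
#     inf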


class Connection:
    """Wrapper around a psycopg2 connection that hands out wrapped cursors."""

    def __init__(self, con):
        self.con = con

    def cursor(self):
        return Cursor(self.con.cursor())

    def commit(self):
        self.con.commit()

    def close(self):
        self.con.close()


class Cursor:
    """Wrapper translating SQLite-style '?' placeholders to psycopg2 '%s'."""

    def __init__(self, cur):
        self.cur = cur

    def fetchone(self):
        return self.cur.fetchone()

    def fetchall(self):
        return self.cur.fetchall()

    def execute(self, statement, *args):
        self.cur.execute(statement.replace('?', '%s'), *args)

    def executemany(self, statement, *args):
        if len(args[0]) > 0:
            N = len(args[0][0])
        else:
            return
        if 'INSERT INTO systems' in statement:
            q = 'DEFAULT' + ', ' + ', '.join('?' * N)  # DEFAULT for id
        else:
            q = ', '.join('?' * N)
        # Collapse the placeholder group to a single '%s' so that
        # execute_values() can expand it with a per-row template, e.g.
        # 'INSERT INTO systems VALUES (DEFAULT, ?, ?)' becomes
        # 'INSERT INTO systems VALUES %s' with template '(DEFAULT, %s, %s)'.
        statement = statement.replace('({})'.format(q), '%s')
        q = '({})'.format(q.replace('?', '%s'))
        execute_values(self.cur, statement.replace('?', '%s'),
                       argslist=args[0], template=q, page_size=len(args[0]))


def insert_ase_and_ndarray_objects(obj):
    """Recursively rebuild ASE objects and ndarrays from their JSON forms."""
    if isinstance(obj, dict):
        objtype = obj.pop('__ase_objtype__', None)
        if objtype is not None:
            return create_ase_object(objtype,
                                     insert_ase_and_ndarray_objects(obj))
        data = obj.get('__ndarray__')
        if data is not None:
            return create_ndarray(*data)
        return {key: insert_ase_and_ndarray_objects(value)
                for key, value in obj.items()}
    if isinstance(obj, list):
        return [insert_ase_and_ndarray_objects(value) for value in obj]
    return obj
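
# Decoding sketch: ase.io.jsonio encodes an array as a dict of the form
# {'__ndarray__': (shape, dtype, flat_data)}, so for example:
#
#     >>> insert_ase_and_ndarray_objects(
#     ...     {'__ndarray__': [[2, 2], 'float64', [0.0, 1.0, 2.0, 3.0]]})
#     array([[0., 1.],
#            [2., 3.]])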


class PostgreSQLDatabase(SQLite3Database):
    type = 'postgresql'
    default = 'DEFAULT'  # used instead of an explicit id in INSERTs

    def encode(self, obj, binary=False):
        return ase_encode(remove_nan_and_inf(obj))

    def decode(self, obj, lazy=False):
        return insert_ase_and_ndarray_objects(insert_nan_and_inf(obj))

    def blob(self, array):
        """Convert array to blob/buffer object.

        PostgreSQL stores arrays in native array columns, so a plain
        nested list is returned rather than a binary buffer.
        """
        if array is None:
            return None
        if len(array) == 0:
            array = np.zeros(0)
        if array.dtype == np.int64:
            array = array.astype(np.int32)
        return array.tolist()

    def deblob(self, buf, dtype=float, shape=None):
        """Convert blob/buffer object to ndarray of correct dtype and shape
        (without creating an extra view)."""
        if buf is None:
            return None
        # PostgreSQL returns nested lists, so the shape is already encoded
        # in the nesting; the shape argument is not needed.
        return np.array(buf, dtype=dtype)

    def _connect(self):
        return Connection(connect(self.filename))

    def _initialize(self, con):
        if self.initialized:
            return

        self._metadata = {}

        cur = con.cursor()
        cur.execute('show search_path;')
        schema = cur.fetchone()[0].split(', ')
        if schema[0] == '"$user"':
            schema = schema[1]
        else:
            schema = schema[0]

        cur.execute("""
        SELECT EXISTS(select * from information_schema.tables where
        table_name='information' and table_schema='{}');
        """.format(schema))

        if not cur.fetchone()[0]:  # 'information' table doesn't exist.
            # Initialize database:
            sql = ';\n'.join(init_statements)
            sql = schema_update(sql)
            cur.execute(sql)
            if self.create_indices:
                cur.execute(';\n'.join(index_statements))
                cur.execute(';\n'.join(jsonb_indices))
            con.commit()
            self.version = VERSION
        else:
            cur.execute('select * from information;')
            for name, value in cur.fetchall():
                if name == 'version':
                    self.version = int(value)
                elif name == 'metadata':
                    self._metadata = json.loads(value)

        assert 5 < self.version <= VERSION
        self.initialized = True

    def get_offset_string(self, offset, limit=None):
        # Unlike SQLite, PostgreSQL allows OFFSET without a LIMIT clause,
        # which is very practical here.
        return '\nOFFSET {0}'.format(offset)

    def get_last_id(self, cur):
        cur.execute('SELECT last_value FROM systems_id_seq')
        id = cur.fetchone()[0]
        return int(id)


def schema_update(sql):
    """Translate SQLite schema statements to their PostgreSQL equivalents.

    Scalar types are widened, BLOB columns become native arrays, and
    free-form text columns become JSONB; for example, 'positions BLOB,'
    turns into 'positions DOUBLE PRECISION[][],'.
    """
    for a, b in [('REAL', 'DOUBLE PRECISION'),
                 ('INTEGER PRIMARY KEY AUTOINCREMENT',
                  'SERIAL PRIMARY KEY')]:
        sql = sql.replace(a, b)

    arrays_1D = ['numbers', 'initial_magmoms', 'initial_charges', 'masses',
                 'tags', 'momenta', 'stress', 'dipole', 'magmoms', 'charges']
    arrays_2D = ['positions', 'cell', 'forces']
    txt2jsonb = ['calculator_parameters', 'key_value_pairs']

    for column in arrays_1D:
        if column in ['numbers', 'tags']:
            dtype = 'INTEGER'
        else:
            dtype = 'DOUBLE PRECISION'
        sql = sql.replace('{} BLOB,'.format(column),
                          '{} {}[],'.format(column, dtype))

    for column in arrays_2D:
        sql = sql.replace('{} BLOB,'.format(column),
                          '{} DOUBLE PRECISION[][],'.format(column))

    for column in txt2jsonb:
        sql = sql.replace('{} TEXT,'.format(column),
                          '{} JSONB,'.format(column))

    sql = sql.replace('data BLOB,', 'data JSONB,')
    return sql
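
# A minimal usage sketch, assuming a reachable PostgreSQL server (the URL
# below is a placeholder); this backend is normally reached through
# ase.db.connect():
#
#     from ase.build import molecule
#     from ase.db import connect
#
#     db = connect('postgresql://user:password@localhost:5432/testase')
#     db.write(molecule('H2O'), relaxed=False)
#     row = db.get('H2O')  # select by chemical formula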