1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
|
import pytest
# ml requires matcalc, not packaged for debian
pytest.importorskip("matcalc")
# from matcalc.utils import get_universal_calculator
from pymatgen.core import Structure
from pymatgen.util.testing import PymatgenTest
from emmet.core.elasticity import BulkModulus, ElasticTensorDoc, ShearModulus
from emmet.core.ml import MLDoc
# if TYPE_CHECKING:
# from ase.calculators.calculator import Calculator
# Input structure shared by the tests: pymatgen's bundled "Si" test structure.
struct = PymatgenTest.get_structure("Si")

# MLDoc field name -> type its value must be an instance of, grouped by the
# property calculation that fills it in.
expected_keys = dict(
    # -- metadata --
    material_id=str,
    structure=Structure,
    deprecated=bool,
    # These three are currently expected to be None; the eventual type is str.
    matcalc_version=type(None),
    model_name=type(None),
    model_version=type(None),
    # -- relaxation --
    final_structure=Structure,
    energy=float,
    volume=float,
    a=float,
    b=float,
    c=float,
    alpha=float,
    beta=float,
    gamma=float,
    # -- eos --
    eos=dict,
    bulk_modulus_bm=float,
    # -- phonon --
    temperatures=list,
    free_energy=list,
    entropy=list,
    heat_capacity=list,
    # -- elasticity --
    elastic_tensor=ElasticTensorDoc,
    shear_modulus=ShearModulus,
    bulk_modulus=BulkModulus,
    young_modulus=float,
)
# NOTE(review): with this parametrization commented out, nothing supplies the
# `calculator` / `prop_kwargs` arguments. Removing the skip marker alone will
# fail with "fixture 'calculator' not found", so restore both together.
# @pytest.mark.parametrize(
#     ("calculator", "prop_kwargs"),
#     [
#         (get_universal_calculator("chgnet"), None),
#         ("M3GNet-MP-2021.2.8-PES", {"ElasticityCalc": {"relax_structure": False}}),
#     ],
# )
@pytest.mark.skip(reason="Temporary skip. Needs attention.")
def test_ml_doc(calculator, prop_kwargs: "dict | None") -> None:
    """Build an MLDoc for the Si test structure and validate its fields.

    Checks that every name in ``expected_keys`` is a field of the resulting
    document and that each field's value is an instance of the expected type.

    Args:
        calculator: Forwarded to ``MLDoc``. The disabled parametrization above
            passes either a calculator object or a model-name string.
        prop_kwargs: Forwarded to ``MLDoc``. May be None, or a mapping such as
            ``{"ElasticityCalc": {"relax_structure": False}}``.
    """
    doc = MLDoc(
        structure=struct,
        calculator=calculator,
        material_id="mp-33",
        deprecated=False,
        prop_kwargs=prop_kwargs,
    )

    # Check that all expected keys are present.
    missing = sorted(set(expected_keys) - set(doc.__dict__))
    assert not missing, f"missing keys: {missing}"

    # Check that all keys have the expected type. Collect every mismatch so a
    # single run reports all of them instead of stopping at the first.
    wrong_types = {}
    for key, typ in expected_keys.items():
        actual = getattr(doc, key)
        if not isinstance(actual, typ):
            wrong_types[key] = (
                f"expected type={typ.__name__}, got {type(actual).__name__}"
            )
    assert not wrong_types, f"unexpected field types: {wrong_types}"
|