
|
import os
from core_function import client_search_testing
import pytest
from pymatgen.analysis.magnetism import Ordering
from mp_api.client.routes.materials.electronic_structure import (
BandStructureRester,
DosRester,
ElectronicStructureRester,
)
@pytest.fixture
def es_rester():
    """Provide an ``ElectronicStructureRester``; its HTTP session is closed on teardown."""
    es_client = ElectronicStructureRester()
    yield es_client
    # Teardown: executed after the consuming test has finished.
    es_client.session.close()
# Search keyword arguments that client_search_testing is told not to exercise.
es_excluded_params: list = [
    "sort_fields",
    "chunk_size",
    "num_chunks",
    "all_fields",
    "fields",
]

# Passed to client_search_testing as sub_doc_fields; empty for this endpoint.
sub_doc_fields: list = []

# Search keyword -> document field checked for it (presumably how
# client_search_testing picks the field to inspect; confirm in core_function).
es_alt_name_dict: dict = {
    "material_ids": "material_id",
    "exclude_elements": "material_id",
    "formula": "material_id",
    "num_elements": "nelements",
    "num_sites": "nsites",
}
# Explicit search values handed to client_search_testing as custom_field_tests
# (presumably used in place of its default test values — confirm in core_function).
es_custom_field_tests: dict = {
    "material_ids": ["mp-149"],
    "magnetic_ordering": Ordering.FM,
    "formula": "CoO2",
    "chemsys": "Co-O",
    "elements": ["Co", "O"],
    "exclude_elements": ["Co"],
}
@pytest.mark.skipif(os.getenv("MP_API_KEY", None) is None, reason="No API key found.")
@pytest.mark.skip(reason="magnetic ordering fields not build correctly")
def test_es_client(es_rester):
    """Run the shared client search checks against ``ElectronicStructureRester.search``."""
    client_search_testing(
        search_method=es_rester.search,
        excluded_params=es_excluded_params,
        alt_name_dict=es_alt_name_dict,
        custom_field_tests=es_custom_field_tests,
        sub_doc_fields=sub_doc_fields,
    )
# Search keyword -> value; test_bs_client issues one search per entry.
# Tuple values are presumably (min, max) ranges — confirm against the rester.
bs_custom_field_tests: dict = {
    "magnetic_ordering": Ordering.FM,
    "is_metal": True,
    "is_gap_direct": True,
    "efermi": (0, 100),
    "band_gap": (0, 5),
}
# Sub-document keys test_bs_client descends into before checking a field.
bs_sub_doc_fields: list = ["bandstructure"]
# Optional search keyword -> document field aliases (none needed currently).
bs_alt_name_dict: dict = {}
@pytest.fixture
def bs_rester():
    """Provide a ``BandStructureRester``; its HTTP session is closed on teardown."""
    bs_client = BandStructureRester()
    yield bs_client
    # Teardown: executed after the consuming test has finished.
    bs_client.session.close()
@pytest.mark.skipif(os.getenv("MP_API_KEY", None) is None, reason="No API key found.")
@pytest.mark.skip(reason="magnetic ordering fields not build correctly")
def test_bs_client(bs_rester):
    """Check that each ``bs_custom_field_tests`` search yields a populated field.

    For every (param, value) pair a single-document search is issued
    (``chunk_size=1``, ``num_chunks=1``). The first result is dumped to a dict,
    descended into each key of ``bs_sub_doc_fields`` that is present and, unless
    ``param`` is ``"path_type"``, into its ``"setyawan_curtarolo"`` entry. The
    field named by ``param`` (or its alias in ``bs_alt_name_dict``) must then
    be non-``None``.
    """
    search_method = bs_rester.search

    for param, value in bs_custom_field_tests.items():
        project_field = bs_alt_name_dict.get(param, param)
        results = search_method(**{param: value, "chunk_size": 1, "num_chunks": 1})
        # Fail with a descriptive message instead of a bare IndexError on [0].
        assert results, f"No band structure documents returned for {param}={value!r}"
        doc = results[0].model_dump()

        for sub_field in bs_sub_doc_fields:
            if sub_field in doc:
                doc = doc[sub_field]

        if param != "path_type":
            doc = doc["setyawan_curtarolo"]

        assert doc[project_field] is not None
# Search keyword -> value; test_dos_client issues one search per entry.
# Tuple values are presumably (min, max) ranges — confirm against the rester.
dos_custom_field_tests: dict = {
    "magnetic_ordering": Ordering.FM,
    "efermi": (0, 100),
    "band_gap": (0, 5),
}
# Search keywords test_dos_client skips even if listed in dos_custom_field_tests.
dos_excluded_params: list = ["orbital", "element"]
# Sub-document keys test_dos_client descends into before checking a field.
dos_sub_doc_fields: list = ["dos"]
# Optional search keyword -> document field aliases (none needed currently).
dos_alt_name_dict: dict = {}
@pytest.fixture
def dos_rester():
    """Provide a ``DosRester``; its HTTP session is closed on teardown."""
    dos_client = DosRester()
    yield dos_client
    # Teardown: executed after the consuming test has finished.
    dos_client.session.close()
@pytest.mark.skipif(os.getenv("MP_API_KEY", None) is None, reason="No API key found.")
@pytest.mark.skip(reason="magnetic ordering fields not build correctly")
def test_dos_client(dos_rester):
    """Check that each ``dos_custom_field_tests`` search yields a populated field.

    Parameters listed in ``dos_excluded_params`` are skipped. For every other
    (param, value) pair a single-document search is issued (``chunk_size=1``,
    ``num_chunks=1``). The first result is dumped to a dict and descended into
    each key of ``dos_sub_doc_fields`` that is present. Except for
    ``projection_type`` and ``magnetic_ordering``, the field is then read from
    ``doc["total"]["1"]``. The field named by ``param`` (or its alias in
    ``dos_alt_name_dict``) must be non-``None``.
    """
    search_method = dos_rester.search

    for param, value in dos_custom_field_tests.items():
        if param in dos_excluded_params:
            continue

        project_field = dos_alt_name_dict.get(param, param)
        results = search_method(**{param: value, "chunk_size": 1, "num_chunks": 1})
        # Fail with a descriptive message instead of a bare IndexError on [0].
        assert results, f"No DOS documents returned for {param}={value!r}"
        doc = results[0].model_dump()

        for sub_field in dos_sub_doc_fields:
            if sub_field in doc:
                doc = doc[sub_field]

        # These two fields are read from the sub-document itself; the rest from
        # the "total" summary keyed by "1" (presumably the spin-up channel).
        if param not in ("projection_type", "magnetic_ordering"):
            doc = doc["total"]["1"]

        assert doc[project_field] is not None
|