File: chemenv.py

package info (click to toggle)
python-mp-api 0.45.3-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 6,988 kB
  • sloc: python: 6,712; makefile: 14
file content (149 lines) | stat: -rw-r--r-- 6,363 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from __future__ import annotations

from collections import defaultdict

from emmet.core.chemenv import (
    COORDINATION_GEOMETRIES,
    COORDINATION_GEOMETRIES_IUCR,
    COORDINATION_GEOMETRIES_IUPAC,
    COORDINATION_GEOMETRIES_NAMES,
    ChemEnvDoc,
)

from mp_api.client.core import BaseRester
from mp_api.client.core.utils import validate_ids


class ChemenvRester(BaseRester[ChemEnvDoc]):
    """Rester for chemical environment documents served under ``materials/chemenv``."""

    suffix = "materials/chemenv"
    document_model = ChemEnvDoc  # type: ignore
    primary_key = "material_id"

    def search(
        self,
        material_ids: str | list[str] | None = None,
        chemenv_iucr: COORDINATION_GEOMETRIES_IUCR
        | list[COORDINATION_GEOMETRIES_IUCR]
        | None = None,
        chemenv_iupac: COORDINATION_GEOMETRIES_IUPAC
        | list[COORDINATION_GEOMETRIES_IUPAC]
        | None = None,
        chemenv_name: COORDINATION_GEOMETRIES_NAMES
        | list[COORDINATION_GEOMETRIES_NAMES]
        | None = None,
        chemenv_symbol: COORDINATION_GEOMETRIES
        | list[COORDINATION_GEOMETRIES]
        | None = None,
        species: str | list[str] | None = None,
        elements: str | list[str] | None = None,
        exclude_elements: str | list[str] | None = None,
        csm: tuple[float, float] | None = None,
        density: tuple[float, float] | None = None,
        num_elements: tuple[int, int] | None = None,
        num_sites: tuple[int, int] | None = None,
        volume: tuple[float, float] | None = None,
        num_chunks: int | None = None,
        chunk_size: int = 1000,
        all_fields: bool = True,
        fields: list[str] | None = None,
    ) -> list[ChemEnvDoc] | list[dict]:
        """Query for chemical environment data.

        Every filter that is left as ``None`` (or is otherwise falsy) is omitted
        from the request. List-like filters are sent as comma-separated strings.

        Arguments:
            material_ids (str, List[str]): Search for chemical environments associated with the specified
                Material IDs.
            chemenv_iucr (COORDINATION_GEOMETRIES_IUCR, List[COORDINATION_GEOMETRIES_IUCR]): Coordination
                environment(s) of unique cationic species in IUCR format, e.g. "[3n]".
            chemenv_iupac (COORDINATION_GEOMETRIES_IUPAC, List[COORDINATION_GEOMETRIES_IUPAC]): Coordination
                environment(s) of unique cationic species in IUPAC format, e.g., "T-4".
            chemenv_name (COORDINATION_GEOMETRIES_NAMES, List[COORDINATION_GEOMETRIES_NAMES]): Coordination
                environment descriptions in text form for unique cationic species, e.g. "Tetrahedron".
            chemenv_symbol (COORDINATION_GEOMETRIES, List[COORDINATION_GEOMETRIES]): Coordination environment
                descriptions as used in ChemEnv package for unique cationic species, e.g. "T:4".
            species (str, List[str]): Cationic species in the crystal structure, e.g. "Ti4+".
            elements (str, List[str]): Element names in the crystal structure, e.g., "Ti".
            exclude_elements (str, List[str]): An element or a list of elements to exclude.
            csm (Tuple[float,float]): Minimum and maximum value of continuous symmetry measure to consider.
            density (Tuple[float,float]): Minimum and maximum density to consider.
            num_elements (Tuple[int,int]): Minimum and maximum number of elements to consider. A single
                int is also accepted and is used as both the minimum and the maximum.
            num_sites (Tuple[int,int]): Minimum and maximum number of sites to consider.
            volume (Tuple[float,float]): Minimum and maximum volume to consider.
            num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
            chunk_size (int): Number of data entries per chunk.
            all_fields (bool): Whether to return all fields in the document. Defaults to True.
            fields (List[str]): List of fields in ChemEnvDoc to return data for.

        Returns:
            ([ChemEnvDoc], [dict]) List of chemenv documents or dictionaries.

        Raises:
            ValueError: If any ``chemenv_*`` value is not one of the values allowed
                by the corresponding coordination-geometry literal type.
        """
        query_params = defaultdict(dict)  # type: dict

        if csm:
            query_params.update({"csm_min": csm[0], "csm_max": csm[1]})

        if volume:
            query_params.update({"volume_min": volume[0], "volume_max": volume[1]})

        if density:
            query_params.update({"density_min": density[0], "density_max": density[1]})

        if num_sites:
            query_params.update(
                {"nsites_min": num_sites[0], "nsites_max": num_sites[1]}
            )

        if elements:
            # A bare string must be wrapped first: joining it directly would
            # split it into characters ("Ti" -> "T,i").
            if isinstance(elements, str):
                elements = [elements]

            query_params.update({"elements": ",".join(elements)})

        if exclude_elements:
            # Same single-string guard as for ``elements``.
            if isinstance(exclude_elements, str):
                exclude_elements = [exclude_elements]

            query_params.update({"exclude_elements": ",".join(exclude_elements)})

        if num_elements:
            if isinstance(num_elements, int):
                num_elements = (num_elements, num_elements)
            query_params.update(
                {"nelements_min": num_elements[0], "nelements_max": num_elements[1]}
            )

        if material_ids:
            if isinstance(material_ids, str):
                material_ids = [material_ids]

            query_params.update({"material_ids": ",".join(validate_ids(material_ids))})

        # Query parameter name -> (user supplied value, literal type it is checked against).
        chemenv_literals = {
            "chemenv_iucr": (chemenv_iucr, COORDINATION_GEOMETRIES_IUCR),
            "chemenv_iupac": (chemenv_iupac, COORDINATION_GEOMETRIES_IUPAC),
            "chemenv_name": (chemenv_name, COORDINATION_GEOMETRIES_NAMES),
            "chemenv_symbol": (chemenv_symbol, COORDINATION_GEOMETRIES),
        }

        for chemenv_var_name, (chemenv_var, literals) in chemenv_literals.items():
            if not chemenv_var:
                continue

            # The signature allows a single value as well as a list. Wrap a lone
            # string (which would otherwise be iterated character by character)
            # or a lone non-iterable member before validating.
            if isinstance(chemenv_var, str) or not hasattr(chemenv_var, "__iter__"):
                requested = [chemenv_var]
            else:
                requested = chemenv_var

            # Items that are not plain strings are reduced to their ``.value``.
            t_types = {t if isinstance(t, str) else t.value for t in requested}
            valid_types = {*map(str, literals.__args__)}
            if invalid_types := t_types - valid_types:
                raise ValueError(
                    f"Invalid type(s) passed for {chemenv_var_name}: {invalid_types}, valid types are: {valid_types}"
                )

            # Sorted so the generated query string does not depend on set
            # iteration order, which varies between interpreter runs.
            query_params.update({chemenv_var_name: ",".join(sorted(t_types))})

        if species:
            if isinstance(species, str):
                species = [species]

            query_params.update({"species": ",".join(species)})

        # Drop entries whose value is None (e.g. an open-ended range bound).
        query_params = {
            entry: query_params[entry]
            for entry in query_params
            if query_params[entry] is not None
        }

        return super()._search(
            num_chunks=num_chunks,
            chunk_size=chunk_size,
            all_fields=all_fields,
            fields=fields,
            **query_params,
        )