File: utils.py

package info (click to toggle)
python-mp-api 0.45.3-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 6,988 kB
  • sloc: python: 6,712; makefile: 14
file content (126 lines) | stat: -rw-r--r-- 4,264 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from __future__ import annotations

import re
from functools import cache
from typing import Optional, get_args

from maggma.utils import get_flat_models_from_model
from monty.json import MSONable
from pydantic import BaseModel
from pydantic._internal._utils import lenient_issubclass
from pydantic.fields import FieldInfo

from mp_api.client.core.settings import MAPIClientSettings


def validate_ids(id_list: list[str]):
    """Check that every ID in ``id_list`` carries a recognised prefix.

    An ID is accepted when it starts with ``mp-``, ``mvc-``, ``mol-`` or
    ``mpcule-``; whatever follows the dash is not inspected.

    Args:
        id_list (List[str]): Material, molecule or task IDs to check.

    Raises:
        ValueError: If the list is longer than
            ``MAPIClientSettings().MAX_LIST_LENGTH``, or if any ID lacks one of
            the accepted prefixes (the first offending ID is reported).

    Returns:
        id_list: The very same list object, unchanged, when every ID passes.
    """
    n_ids = len(id_list)
    if n_ids > MAPIClientSettings().MAX_LIST_LENGTH:
        raise ValueError(
            "List of material/molecule IDs provided is too long. Consider removing the ID filter to automatically pull"
            " data for all IDs and filter locally."
        )

    # ``match`` anchors at the start of the string only, so just the prefix
    # and the dash are enforced.
    id_regex = re.compile("(mp|mvc|mol|mpcule)-.*")
    for identifier in id_list:
        if not id_regex.match(identifier):
            raise ValueError(f"{identifier} is not formatted correctly!")

    return id_list


def api_sanitize(
    pydantic_model: BaseModel,
    fields_to_leave: list[str] | None = None,
    allow_dict_msonable=False,
):
    """Function to clean up pydantic models for the API by:
        1.) Making fields optional
        2.) Allowing dictionaries in-place of the objects for MSONable quantities.

    WARNING: This works in place, so it mutates the model and all sub-models

    Results are memoised on ``(pydantic_model, fields_to_leave,
    allow_dict_msonable)``, so repeated calls with the same arguments return
    immediately without re-processing the model.

    Args:
        pydantic_model (BaseModel): Pydantic model to alter
        fields_to_leave (list[str] | None): list of strings for model fields as "model__name__.field".
            Any iterable of such strings (list, tuple, ...) is accepted. Defaults to None.
        allow_dict_msonable (bool): Whether to allow dictionaries in place of MSONable quantities.
            Defaults to False

    Raises:
        ValueError: If an entry of ``fields_to_leave`` does not split on "."
            into exactly two parts (model name and field name).

    Returns:
        The same ``pydantic_model`` object that was passed in, after mutation.
    """
    # ``functools.cache`` hashes every argument, and a list (the documented
    # type of ``fields_to_leave``) is unhashable. Decorating this function
    # directly made ``api_sanitize(Model, fields_to_leave=[...])`` raise
    # TypeError, so normalise to hashable values before consulting the cache.
    return _api_sanitize_cached(
        pydantic_model,
        tuple(fields_to_leave or ()),
        bool(allow_dict_msonable),
    )


@cache
def _api_sanitize_cached(
    pydantic_model: BaseModel,
    fields_to_leave: tuple[str, ...],
    allow_dict_msonable: bool,
):
    """Memoised worker behind ``api_sanitize``; every argument must be hashable."""
    models: list[type[BaseModel]] = [
        model
        for model in get_flat_models_from_model(pydantic_model)
        if issubclass(model, BaseModel)
    ]

    # Validate every entry up front (before mutating anything) and group the
    # untouched field names by model name: {"ModelName": {"field", ...}}.
    leave_by_model: dict[str, set[str]] = {}
    for entry in fields_to_leave:
        parts = entry.split(".")
        if len(parts) != 2:
            raise ValueError(
                f"Invalid entry {entry!r} in fields_to_leave: expected 'ModelName.field'."
            )
        model_name, field_name = parts
        leave_by_model.setdefault(model_name, set()).add(field_name)

    for model in models:
        model_fields_to_leave = leave_by_model.get(model.__name__, set())
        # Snapshot the items because entries are replaced while walking them.
        for name, field in list(model.model_fields.items()):
            field_type = field.annotation

            if field_type is not None and allow_dict_msonable:
                if lenient_issubclass(field_type, MSONable):
                    # Patches the class in place and returns the same class.
                    field_type = allow_msonable_dict(field_type)
                else:
                    # Generic annotations (e.g. Optional[X], list[X]): patch
                    # any MSONable among the direct type arguments.
                    for sub_type in get_args(field_type):
                        if lenient_issubclass(sub_type, MSONable):
                            allow_msonable_dict(sub_type)

            if name in model_fields_to_leave:
                continue

            # Replace the field with an Optional[...] one defaulting to None,
            # carrying over any json_schema_extra metadata.
            new_field = FieldInfo.from_annotated_attribute(
                Optional[field_type], None
            )
            new_field.json_schema_extra = field.json_schema_extra or {}
            model.model_fields[name] = new_field

        model.model_rebuild(force=True)

    return pydantic_model


# ``api_sanitize`` used to carry ``@cache`` directly; keep its cache helpers
# reachable for existing callers (e.g. tests that reset the cache).
api_sanitize.cache_info = _api_sanitize_cached.cache_info  # type: ignore[attr-defined]
api_sanitize.cache_clear = _api_sanitize_cached.cache_clear  # type: ignore[attr-defined]


def allow_msonable_dict(monty_cls: type[MSONable]):
    """Patch Monty to allow for dict values for MSONable.

    Installs a ``validate_monty_v2`` classmethod on ``monty_cls`` that accepts
    either an instance of the class or a dictionary whose ``@module`` and
    ``@class`` markers name that class.

    Args:
        monty_cls (type[MSONable]): Class to patch in place.

    Returns:
        The same ``monty_cls`` object, now carrying the patched validator.
    """

    def validate_monty(cls, v, _):
        """Stub validator for MSONable as a dictionary only.

        Args:
            cls: Class the classmethod is accessed through (``monty_cls`` or a
                subclass that inherits this validator).
            v: Value to validate.
            _: Unused third positional argument (presumably pydantic's
                validation info — confirm against the caller).

        Returns:
            ``v`` unchanged when it is an instance of ``cls`` or a dict with
            matching Monty markers.

        Raises:
            ValueError: If ``v`` is neither, or if a dict's markers do not match.
        """
        if isinstance(v, cls):
            return v
        if not isinstance(v, dict):
            raise ValueError(f"Must provide {cls.__name__} or MSONable dictionary")

        # Just validate the simple Monty Dict Model: only the markers are
        # checked, the rest of the dict is passed through as-is. Compare
        # against ``cls`` (like the isinstance check above) rather than the
        # closed-over ``monty_cls`` so an inheriting subclass accepts its own
        # serialized form.
        errors = []
        if v.get("@module", "") != cls.__module__:
            errors.append("@module")

        if v.get("@class", "") != cls.__name__:
            errors.append("@class")

        if errors:
            raise ValueError(
                f"Missing Monty serialization fields in dictionary: {errors}"
            )

        return v

    monty_cls.validate_monty_v2 = classmethod(validate_monty)

    return monty_cls