1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
|
from __future__ import annotations
import re
from functools import cache
from typing import Optional, get_args
from maggma.utils import get_flat_models_from_model
from monty.json import MSONable
from pydantic import BaseModel
from pydantic._internal._utils import lenient_issubclass
from pydantic.fields import FieldInfo
from mp_api.client.core.settings import MAPIClientSettings
def validate_ids(id_list: list[str]):
    """Check that a list of IDs is short enough and that each ID has a known prefix.

    Args:
        id_list (List[str]): Material, task or molecule IDs to check.

    Raises:
        ValueError: If the list has more entries than
            ``MAPIClientSettings().MAX_LIST_LENGTH``, or if any entry does not
            begin with ``mp-``, ``mvc-``, ``mol-`` or ``mpcule-``.

    Returns:
        id_list: The very same list that was passed in, unmodified.
    """
    n_ids = len(id_list)
    max_ids = MAPIClientSettings().MAX_LIST_LENGTH
    if n_ids > max_ids:
        raise ValueError(
            "List of material/molecule IDs provided is too long. Consider removing the ID filter to automatically pull"
            " data for all IDs and filter locally."
        )

    # Only the prefix and the dash are constrained; anything may follow.
    id_regex = re.compile("(mp|mvc|mol|mpcule)-.*")

    for candidate in id_list:
        if id_regex.match(candidate):
            continue
        raise ValueError(f"{candidate} is not formatted correctly!")

    return id_list
@cache
def _api_sanitize_cached(
    pydantic_model: type[BaseModel],
    fields_to_leave: tuple[str, ...],
    allow_dict_msonable: bool,
):
    """Memoised worker for ``api_sanitize``; every argument must be hashable."""
    models = [
        model
        for model in get_flat_models_from_model(pydantic_model)
        if issubclass(model, BaseModel)
    ]  # type: list[BaseModel]

    fields_tuples = [f.split(".") for f in fields_to_leave]
    assert all(
        len(f) == 2 for f in fields_tuples
    ), 'fields_to_leave entries must look like "ModelName.field_name"'

    for model in models:
        model_fields_to_leave = {f[1] for f in fields_tuples if model.__name__ == f[0]}
        for name in model.model_fields:
            field = model.model_fields[name]
            field_json_extra = field.json_schema_extra
            field_type = field.annotation

            if field_type is not None and allow_dict_msonable:
                if lenient_issubclass(field_type, MSONable):
                    field_type = allow_msonable_dict(field_type)
                else:
                    # Generic annotation: patch any MSONable type arguments.
                    # Only the direct arguments returned by get_args are inspected.
                    for sub_type in get_args(field_type):
                        if lenient_issubclass(sub_type, MSONable):
                            allow_msonable_dict(sub_type)

            if name not in model_fields_to_leave:
                # Replace the field by an Optional one defaulting to None,
                # carrying over any JSON-schema extras of the original field.
                new_field = FieldInfo.from_annotated_attribute(
                    Optional[field_type], None
                )
                new_field.json_schema_extra = field_json_extra or {}
                model.model_fields[name] = new_field

        # Rebuild so the edited model_fields take effect on this model.
        model.model_rebuild(force=True)

    return pydantic_model


def api_sanitize(
    pydantic_model: type[BaseModel],
    fields_to_leave: list[str] | None = None,
    allow_dict_msonable=False,
):
    """Function to clean up pydantic models for the API by:
        1.) Making fields optional
        2.) Allowing dictionaries in-place of the objects for MSONable quantities.

    WARNING: This works in place, so it mutates the model and all sub-models

    Results are memoised on the combination of arguments. ``fields_to_leave``
    is converted to a tuple before the cache lookup, because ``functools.cache``
    needs hashable arguments and a plain list would raise ``TypeError``.

    Args:
        pydantic_model (BaseModel): Pydantic model class to alter
        fields_to_leave (list[str] | None): list of strings for model fields as "model__name__.field".
            Fields named here are not replaced by an optional version.
            Defaults to None.
        allow_dict_msonable (bool): Whether to allow dictionaries in place of MSONable quantities.
            Defaults to False

    Returns:
        The same ``pydantic_model`` object that was passed in, after mutation.

    Raises:
        AssertionError: If an entry of ``fields_to_leave`` does not contain
            exactly one ``"."`` separator.
    """
    return _api_sanitize_cached(
        pydantic_model,
        tuple(fields_to_leave or ()),
        bool(allow_dict_msonable),
    )


# Keep the cache-management helpers reachable from the public name, as they
# were when ``@cache`` decorated ``api_sanitize`` directly.
api_sanitize.cache_clear = _api_sanitize_cached.cache_clear
api_sanitize.cache_info = _api_sanitize_cached.cache_info
def allow_msonable_dict(monty_cls: type[MSONable]):
    """Patch Monty to allow for dict values for MSONable.

    Attaches a ``validate_monty_v2`` classmethod to ``monty_cls`` that accepts
    either an instance of the class or a dictionary whose ``"@module"`` and
    ``"@class"`` entries match ``monty_cls``.

    Args:
        monty_cls (type[MSONable]): Class to patch. It is modified in place.

    Returns:
        The same ``monty_cls`` object, now carrying the patched validator.
    """

    def validate_monty(cls, v, _):
        """Stub validator for MSONable as a dictionary only.

        Raises:
            ValueError: If ``v`` is a dict whose ``"@module"``/``"@class"``
                entries are absent or do not match ``monty_cls``, or if ``v``
                is neither a dict nor an instance of ``cls``.
        """
        if isinstance(v, cls):
            return v

        if isinstance(v, dict):
            # Just validate the simple Monty Dict Model
            expected = (
                ("@module", monty_cls.__module__),
                ("@class", monty_cls.__name__),
            )
            errors = [key for key, value in expected if v.get(key, "") != value]
            if errors:
                # f-string so the offending keys are actually reported.
                raise ValueError(
                    f"Missing Monty serialization fields in dictionary: {errors}"
                )
            return v

        raise ValueError(f"Must provide {cls.__name__} or MSONable dictionary")

    # NOTE(review): presumably the hook name monty looks up for pydantic v2
    # validation - confirm against the installed monty version.
    monty_cls.validate_monty_v2 = classmethod(validate_monty)

    return monty_cls
|