import inspect
import pickle
import platform
from types import GeneratorType
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import catalogue
import pytest

try:
    from pydantic.v1 import BaseModel, PositiveInt, StrictFloat, constr
    from pydantic.v1.types import StrictBool
except ImportError:
    from pydantic import BaseModel, StrictFloat, PositiveInt, constr  # type: ignore
    from pydantic.types import StrictBool  # type: ignore

from confection import Config, ConfigValidationError
from confection.tests.util import Cat, make_tempdir, my_registry
from confection.util import Generator, partial

# Full example config: an optimizer with a nested learn-rate schedule, plus a
# pipeline component whose embedding width is an interpolated reference
# (${pipeline.classifier.model:token_vector_width}).
EXAMPLE_CONFIG = """
[optimizer]
@optimizers = "Adam.v1"
beta1 = 0.9
beta2 = 0.999
use_averages = true

[optimizer.learn_rate]
@schedules = "warmup_linear.v1"
initial_rate = 0.1
warmup_steps = 10000
total_steps = 100000

[pipeline]

[pipeline.classifier]
name = "classifier"
factory = "classifier"

[pipeline.classifier.model]
@layers = "ClassifierModel.v1"
hidden_depth = 1
hidden_width = 64
token_vector_width = 128

[pipeline.classifier.model.embedding]
@layers = "Embedding.v1"
width = ${pipeline.classifier.model:token_vector_width}

"""

# The same [optimizer] section as EXAMPLE_CONFIG on its own. The to_str and
# roundtrip tests compare serialized output against this exact text
# (after .strip()), so its formatting matters.
OPTIMIZER_CFG = """
[optimizer]
@optimizers = "Adam.v1"
beta1 = 0.9
beta2 = 0.999
use_averages = true

[optimizer.learn_rate]
@schedules = "warmup_linear.v1"
initial_rate = 0.1
warmup_steps = 10000
total_steps = 100000
"""


class HelloIntsSchema(BaseModel):
    """Schema with two required integer fields and no extra keys allowed."""

    hello: int
    world: int

    class Config:
        # Reject any key that is not declared as a field above.
        extra = "forbid"


class DefaultsSchema(BaseModel):
    """Schema with one required field and one optional field that has a default."""

    required: int
    optional: str = "default value"

    class Config:
        # Reject any key that is not declared as a field above.
        extra = "forbid"


class ComplexSchema(BaseModel):
    """Nested schema: required/optional scalar fields plus required/optional
    sub-models.

    This class sets no ``Config.extra`` option of its own.
    """

    outer_req: int
    outer_opt: str = "default value"

    level2_req: HelloIntsSchema
    # Default sub-model used when "level2_opt" is omitted: required=1, with
    # "optional" keeping DefaultsSchema's own default.
    level2_opt: DefaultsSchema = DefaultsSchema(required=1)


def _catsie_promise(evil, cute):
    """Return a fresh promise dict for "catsie.v1" in the "cats" registry."""
    return {"@cats": "catsie.v1", "evil": evil, "cute": cute}


# One promise per combination of the evil/cute flags.
good_catsie = _catsie_promise(evil=False, cute=True)
ok_catsie = _catsie_promise(evil=False, cute=False)
bad_catsie = _catsie_promise(evil=True, cute=True)
worst_catsie = _catsie_promise(evil=True, cute=False)


def test_validate_simple_config():
    """A config that already matches HelloIntsSchema comes back unchanged."""
    cfg = {"hello": 1, "world": 2}
    filled, _, validated = my_registry._fill(cfg, HelloIntsSchema)
    for result in (filled, validated):
        assert result == cfg


def test_invalidate_simple_config():
    """A non-int value for an int field yields exactly one integer type error."""
    bad_cfg = {"hello": 1, "world": "hi!"}
    with pytest.raises(ConfigValidationError) as excinfo:
        my_registry._fill(bad_cfg, HelloIntsSchema)
    err = excinfo.value
    assert len(err.errors) == 1
    assert "type_error.integer" in err.error_types


def test_invalidate_extra_args():
    """Keys not declared in a schema with extra="forbid" are rejected."""
    cfg_with_extra = dict(hello=1, world=2, extra=3)
    with pytest.raises(ConfigValidationError):
        my_registry._fill(cfg_with_extra, HelloIntsSchema)


def test_fill_defaults_simple_config():
    """Missing optional fields get their defaults; a missing required field fails."""
    filled, _, _ = my_registry._fill({"required": 1}, DefaultsSchema)
    assert filled["required"] == 1
    assert filled["optional"] == "default value"
    # "required" is absent here, so validation must fail.
    missing_required = {"optional": "some value"}
    with pytest.raises(ConfigValidationError):
        my_registry._fill(missing_required, DefaultsSchema)


def test_fill_recursive_config():
    """Defaults are filled in at every nesting level of ComplexSchema."""
    cfg = {"outer_req": 1, "level2_req": {"hello": 4, "world": 7}}
    filled, _, _ = my_registry._fill(cfg, ComplexSchema)
    assert filled["outer_req"] == 1
    assert filled["outer_opt"] == "default value"
    level2_req = filled["level2_req"]
    assert level2_req["hello"] == 4
    assert level2_req["world"] == 7
    level2_opt = filled["level2_opt"]
    assert level2_opt["required"] == 1
    assert level2_opt["optional"] == "default value"


def test_is_promise():
    """Constructor dicts are promises; plain dicts and scalars are not."""
    assert my_registry.is_promise(good_catsie)
    for not_a_promise in ({"hello": "world"}, 1):
        assert not my_registry.is_promise(not_a_promise)
    # A dict naming two constructors still counts as a promise here.
    two_constructors = {"@complex": "complex.v1", "rate": 1.0, "@cats": "catsie.v1"}
    assert my_registry.is_promise(two_constructors)


def test_get_constructor():
    """The (registry, function name) pair is taken from the promise's "@" key."""
    expected = ("cats", "catsie.v1")
    assert my_registry.get_constructor(good_catsie) == expected


def test_parse_args():
    """The non-"@" keys of a promise without "*" become keyword arguments only."""
    positional, keyword = my_registry.parse_args(bad_catsie)
    assert positional == []
    assert keyword == {"evil": True, "cute": True}


def test_make_promise_schema():
    """The schema generated for catsie.v1 has "evil" and "cute" fields."""
    fields = my_registry.make_promise_schema(good_catsie).__fields__
    for arg_name in ("evil", "cute"):
        assert arg_name in fields


def test_validate_promise():
    """A promise stays as-is in the filled config but is resolved when validated."""
    cfg = {"required": 1, "optional": good_catsie}
    filled, _, validated = my_registry._fill(cfg, DefaultsSchema)
    assert filled == cfg
    assert validated == dict(required=1, optional="meow")


def test_fill_validate_promise():
    """Arguments missing from a promise are filled in from the function defaults."""
    promise = {"@cats": "catsie.v1", "evil": False}
    filled, _, _ = my_registry._fill({"required": 1, "optional": promise}, DefaultsSchema)
    assert filled["optional"]["cute"] is True


def test_fill_invalidate_promise():
    """Promises fail validation against the wrong schema or with unexpected args."""
    cfg = {"required": 1, "optional": {"@cats": "catsie.v1", "evil": False}}
    # HelloIntsSchema declares neither "required" nor "optional".
    with pytest.raises(ConfigValidationError):
        my_registry._fill(cfg, HelloIntsSchema)
    # An unexpected argument ("whiskers") must be rejected as well.
    cfg["optional"]["whiskers"] = True
    with pytest.raises(ConfigValidationError):
        my_registry._fill(cfg, DefaultsSchema)


def test_create_registry():
    """A new catalogue registry can be attached to my_registry and used."""
    my_registry.dogs = catalogue.create(
        my_registry.namespace, "dogs", entry_points=False
    )
    assert hasattr(my_registry, "dogs")
    dogs = my_registry.dogs
    assert not dogs.get_all()
    dogs.register("good_boy.v1", func=lambda x: x)
    assert len(dogs.get_all()) == 1


def test_registry_methods():
    """Looking up an unknown registry, or a function registered as None, raises."""
    with pytest.raises(ValueError):
        my_registry.get("dfkoofkds", "catsie.v1")
    register_v123 = my_registry.cats.register("catsie.v123")
    register_v123(None)
    with pytest.raises(ValueError):
        my_registry.get("cats", "catsie.v123")


def test_resolve_no_schema():
    """Without a schema, promises are still resolved and their arguments checked."""
    cfg = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": True}}}
    resolved = my_registry.resolve({"cfg": cfg})["cfg"]
    assert resolved["one"] == 1
    assert resolved["two"] == {"three": "scratch!"}
    # "evil" is the string "true" here, not a boolean.
    bad_cfg = {"two": {"three": {"@cats": "catsie.v1", "evil": "true"}}}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(bad_cfg)


def test_resolve_schema():
    """resolve() validates the config against a caller-supplied nested schema."""

    class TestBaseSubSchema(BaseModel):
        three: str

    class TestBaseSchema(BaseModel):
        one: PositiveInt
        two: TestBaseSubSchema

        class Config:
            extra = "forbid"

    class TestSchema(BaseModel):
        cfg: TestBaseSchema

    def cats_promise():
        return {"@cats": "catsie.v1", "evil": True}

    valid = {"one": 1, "two": {"three": cats_promise()}}
    my_registry.resolve({"cfg": valid}, schema=TestSchema)

    # "one" is not a positive int
    negative_one = {"one": -1, "two": {"three": cats_promise()}}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve({"cfg": negative_one}, schema=TestSchema)

    # "three" is required in subschema
    missing_three = {"one": 1, "two": {"four": cats_promise()}}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve({"cfg": missing_three}, schema=TestSchema)


def test_resolve_schema_coerced():
    """Non-strict schema types coerce values, but only in the resolved config."""

    class TestBaseSchema(BaseModel):
        test1: str
        test2: bool
        test3: float

    class TestSchema(BaseModel):
        cfg: TestBaseSchema

    raw = {"test1": 123, "test2": 1, "test3": 5}
    filled = my_registry.fill({"cfg": raw}, schema=TestSchema)
    resolved = my_registry.resolve({"cfg": raw}, schema=TestSchema)
    assert resolved["cfg"] == {"test1": "123", "test2": True, "test3": 5.0}
    # The coercion must not leak back into the filled config.
    assert filled["cfg"] == raw


def test_read_config():
    """Parsing from bytes yields nested sections with interpolated values."""
    cfg = Config().from_bytes(EXAMPLE_CONFIG.encode("utf8"))
    optimizer = cfg["optimizer"]
    assert optimizer["beta1"] == 0.9
    assert optimizer["learn_rate"]["initial_rate"] == 0.1
    classifier = cfg["pipeline"]["classifier"]
    assert classifier["factory"] == "classifier"
    assert classifier["model"]["embedding"]["width"] == 128


def test_optimizer_config():
    """The [optimizer] block resolves to an object carrying its beta1 value."""
    resolved = my_registry.resolve(Config().from_str(OPTIMIZER_CFG), validate=True)
    assert resolved["optimizer"].beta1 == 0.9


def test_config_to_str():
    """Serializing a parsed config reproduces the source text, even if the
    Config object held other content before from_str was called."""
    expected = OPTIMIZER_CFG.strip()
    for initial in (None, {"optimizer": {"foo": "bar"}}):
        base = Config() if initial is None else Config(initial)
        cfg = base.from_str(OPTIMIZER_CFG)
        assert cfg.to_str().strip() == expected


def test_config_to_str_creates_intermediate_blocks():
    """Every nesting level gets its own section header, even if it has no values."""
    cfg = Config({"optimizer": {"foo": {"bar": 1}}})
    expected = "[optimizer]\n\n[optimizer.foo]\nbar = 1"
    assert cfg.to_str().strip() == expected


def test_config_to_str_escapes():
    """Literal "$" characters are written as "$$" in config text.

    Checks parsing escaped text, constructing from an unescaped dict, and a
    full round trip through ``to_str``/``from_str``.
    """
    section_str = """
        [section]
        node1 = "^a$$"
        node2 = "$$b$$c"
        """
    section_dict = {"section": {"node1": "^a$", "node2": "$b$c"}}

    # parse from escaped string
    cfg = Config().from_str(section_str)
    assert cfg == section_dict

    # parse from non-escaped dict
    cfg = Config(section_dict)
    assert cfg == section_dict

    # roundtrip through str: the output must be escaped again, and parsing it
    # back must reproduce the original values.
    cfg_str = cfg.to_str()
    assert "^a$$" in cfg_str
    new_cfg = Config().from_str(cfg_str)
    # Check the re-parsed config, not ``cfg``, so the round trip is verified.
    assert new_cfg == section_dict


def test_config_roundtrip_bytes():
    """to_bytes followed by from_bytes preserves the serialized config text."""
    serialized = Config().from_str(OPTIMIZER_CFG).to_bytes()
    restored = Config().from_bytes(serialized)
    assert restored.to_str().strip() == OPTIMIZER_CFG.strip()


def test_config_roundtrip_disk():
    """to_disk followed by from_disk preserves the serialized config text."""
    original = Config().from_str(OPTIMIZER_CFG)
    with make_tempdir() as tmp_dir:
        target = tmp_dir / "config.cfg"
        original.to_disk(target)
        restored = Config().from_disk(target)
    assert restored.to_str().strip() == OPTIMIZER_CFG.strip()


def test_config_roundtrip_disk_respects_path_subclasses(pathy_fixture):
    """Disk round-trip through the path object from ``pathy_fixture``.

    Per the test name, that fixture presumably provides a Path subclass.
    """
    original = Config().from_str(OPTIMIZER_CFG)
    target = pathy_fixture / "config.cfg"
    original.to_disk(target)
    restored = Config().from_disk(target)
    assert restored.to_str().strip() == OPTIMIZER_CFG.strip()


def test_config_to_str_invalid_defaults():
    """Top-level keys outside any section are rejected.

    Such keys would otherwise end up in a [DEFAULT] section, whose values are
    copied into *all* other sections. An explicit [DEFAULT] section is
    rejected on parsing too.
    """
    top_level_values = {"one": 1, "two": {"@cats": "catsie.v1", "evil": "hello"}}
    with pytest.raises(ConfigValidationError):
        Config(top_level_values).to_str()
    with pytest.raises(ConfigValidationError):
        Config().from_str("[DEFAULT]\none = 1")


def test_validation_custom_types():
    """Pydantic constrained types in a registered signature are enforced."""

    def complex_args(
        rate: StrictFloat,
        steps: PositiveInt = 10,  # type: ignore
        log_level: constr(regex="(DEBUG|INFO|WARNING|ERROR)") = "ERROR",  # noqa: F821
    ):
        return None

    my_registry.complex = catalogue.create(
        my_registry.namespace, "complex", entry_points=False
    )
    my_registry.complex("complex.v1")(complex_args)

    def promise(**changes):
        base = {"@complex": "complex.v1", "rate": 1.0, "steps": 20, "log_level": "INFO"}
        base.update(changes)
        return base

    my_registry.resolve({"config": promise()})
    # steps is not a positive int
    with pytest.raises(ConfigValidationError):
        my_registry.resolve({"config": promise(steps=-1)})
    # log_level is not a string matching the regex
    with pytest.raises(ConfigValidationError):
        my_registry.resolve({"config": promise(log_level="none")})
    # A promise cannot be the top-level object, for either resolve or fill.
    top_level = promise()
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(top_level)
    with pytest.raises(ConfigValidationError):
        my_registry.fill(top_level)
    # A single promise may not name two constructors.
    two_constructors = {"@complex": "complex.v1", "rate": 1.0, "@cats": "catsie.v1"}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve({"config": two_constructors})


def test_validation_no_validate():
    """With validate=False, mistyped arguments are passed through untouched."""
    cfg = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": "false"}}}
    resolved = my_registry.resolve({"cfg": cfg}, validate=False)["cfg"]
    filled = my_registry.fill({"cfg": cfg}, validate=False)["cfg"]
    assert resolved["one"] == 1
    assert resolved["two"] == {"three": "scratch!"}
    three = filled["two"]["three"]
    assert three["evil"] == "false"
    assert three["cute"] is True


def test_validation_fill_defaults():
    """fill() adds missing defaults, and validates unless validate=False."""
    cfg = {"cfg": {"one": 1, "two": {"@cats": "catsie.v1", "evil": "hello"}}}
    unvalidated = my_registry.fill(cfg, validate=False)
    assert len(unvalidated["cfg"]["two"]) == 3
    # "evil" is present but holds the string "hello" rather than a boolean,
    # so (presumably via the argument's type check) a validating fill fails.
    with pytest.raises(ConfigValidationError):
        my_registry.fill(cfg)
    # catsie.v2 has an additional default ("cute_level") to fill in.
    cfg = {"cfg": {"one": 1, "two": {"@cats": "catsie.v2", "evil": False}}}
    two = my_registry.fill(cfg)["cfg"]["two"]
    assert len(two) == 4
    assert two["evil"] is False
    assert two["cute"] is True
    assert two["cute_level"] == 1


def test_make_config_positional_args():
    """A "*" list in a promise is passed to the function as positional args."""

    @my_registry.cats("catsie.v567")
    def catsie_567(*args: Optional[str], foo: str = "bar"):
        assert args[:2] == ("^_^", "^(*.*)^")
        assert foo == "baz"
        return args[0]

    positional = ["^_^", "^(*.*)^"]
    promise = {"@cats": "catsie.v567", "foo": "baz", "*": positional}
    assert my_registry.resolve({"config": promise})["config"] == "^_^"


def test_fill_config_positional_args_w_promise():
    """fill() leaves promises under "*" unresolved and fills keyword defaults."""

    @my_registry.cats("catsie.v568")
    def catsie_568(*args: str, foo: str = "bar"):
        assert args[0] == "^(*.*)^"
        assert foo == "baz"
        return args[0]

    @my_registry.cats("cat_promise.v568")
    def cat_promise() -> str:
        return "^(*.*)^"

    positional_promises = {"promise": {"@cats": "cat_promise.v568"}}
    cfg = {"config": {"@cats": "catsie.v568", "*": positional_promises}}
    filled = my_registry.fill(cfg, validate=True)["config"]
    assert filled["foo"] == "bar"
    assert filled["*"] == {"promise": {"@cats": "cat_promise.v568"}}


def test_make_config_positional_args_complex():
    """Positional args are validated against a Union annotation on *args."""

    @my_registry.cats("catsie.v890")
    def catsie_890(*args: Optional[Union[StrictBool, PositiveInt]]):
        assert args[0] == 123
        return args[0]

    def wrap(positional):
        return {"config": {"@cats": "catsie.v890", "*": positional}}

    assert my_registry.resolve(wrap([123, True, 1, False]))["config"] == 123
    # "True" is not a valid boolean or positive int
    invalid = wrap([123, "True"])
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(invalid)


def test_positional_args_to_from_string():
    """Positional "*" args survive to_str/from_str, fill() and resolve()."""

    def roundtrip(text):
        return Config().from_str(text).to_str()

    def fill_str(text):
        return my_registry.fill(Config().from_str(text)).to_str()

    def resolve_str(text):
        return my_registry.resolve(Config().from_str(text))

    # "*" given inline as a list, or as sub-sections.
    for text in (
        """[a]\nb = 1\n* = ["foo","bar"]""",
        """[a]\nb = 1\n\n[a.*.bar]\ntest = 2\n\n[a.*.foo]\ntest = 1""",
    ):
        assert roundtrip(text) == text

    @my_registry.cats("catsie.v666")
    def catsie_666(*args, meow=False):
        return args

    text = """[a]\n@cats = "catsie.v666"\n* = ["foo","bar"]"""
    assert fill_str(text) == """[a]\n@cats = "catsie.v666"\n* = ["foo","bar"]\nmeow = false"""
    assert resolve_str(text) == {"a": ("foo", "bar")}
    text = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\nx = 1"""
    assert fill_str(text) == """[a]\n@cats = "catsie.v666"\nmeow = false\n\n[a.*.foo]\nx = 1"""
    assert resolve_str(text) == {"a": ({"x": 1},)}

    @my_registry.cats("catsie.v777")
    def catsie_777(y: int = 1):
        return "meow" * y

    text = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\n@cats = "catsie.v777\""""
    expected = """[a]\n@cats = "catsie.v666"\nmeow = false\n\n[a.*.foo]\n@cats = "catsie.v777"\ny = 1"""
    assert fill_str(text) == expected
    text = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\n@cats = "catsie.v777"\ny = 3"""
    assert resolve_str(text) == {"a": ("meowmeowmeow",)}


def test_validation_generators_iterable():
    """Resolve a config after registering a generator schedule annotated as
    ``Iterable[float]``."""

    @my_registry.optimizers("test_optimizer.v1")
    def test_optimizer_v1(rate: float) -> None:
        return None

    @my_registry.schedules("test_schedule.v1")
    def test_schedule_v1(some_value: float = 1.0) -> Iterable[float]:
        while True:
            yield some_value

    # NOTE(review): "test_schedule.v1" is registered but not referenced by the
    # config below, so its Iterable return annotation is not exercised.
    my_registry.resolve({"optimizer": {"@optimizers": "test_optimizer.v1", "rate": 0.1}})


def test_validation_unset_type_hints():
    """Arguments without a type hint are accepted (treated as Any)."""

    @my_registry.optimizers("test_optimizer.v2")
    def test_optimizer_v2(rate, steps: int = 10) -> None:
        return None

    promise = {"@optimizers": "test_optimizer.v2", "rate": 0.1, "steps": 20}
    my_registry.resolve({"test": promise})


def test_validation_bad_function():
    """An error raised inside a function surfaces as ValueError; an unknown
    argument fails validation."""

    @my_registry.optimizers("bad.v1")
    def bad() -> None:
        raise ValueError("This is an error in the function")

    @my_registry.optimizers("good.v1")
    def good() -> None:
        return None

    # The registered function itself raises.
    with pytest.raises(ValueError):
        my_registry.resolve({"test": {"@optimizers": "bad.v1"}})
    # The function is fine, but it is given an argument it doesn't accept.
    with pytest.raises(ConfigValidationError):
        my_registry.resolve({"test": {"@optimizers": "good.v1", "invalid_arg": 1}})


def test_objects_from_config():
    """A nested promise is resolved and passed into the outer constructor."""
    schedule = {
        "@schedules": "my_cool_repetitive_schedule.v1",
        "base_rate": 0.001,
        "repeat": 4,
    }
    optimizer_cfg = {
        "@optimizers": "my_cool_optimizer.v1",
        "beta1": 0.2,
        "learn_rate": schedule,
    }
    optimizer = my_registry.resolve({"optimizer": optimizer_cfg})["optimizer"]
    assert optimizer.beta1 == 0.2
    assert optimizer.learn_rate == [0.001] * 4


def test_partials_from_config():
    """Registered functions that return partials (e.g. initializers) resolve
    to that partial, with the configured keyword values bound as defaults."""
    numpy = pytest.importorskip("numpy")

    def uniform_init(
        shape: Tuple[int, ...], *, lo: float = -0.1, hi: float = 0.1
    ) -> List[float]:
        return numpy.random.uniform(lo, hi, shape).tolist()

    @my_registry.initializers("uniform_init.v1")
    def configure_uniform_init(
        *, lo: float = -0.1, hi: float = 0.1
    ) -> Callable[[List[float]], List[float]]:
        return partial(uniform_init, lo=lo, hi=hi)

    name = "uniform_init.v1"
    func = my_registry.resolve({"test": {"@initializers": name, "lo": -0.2}})["test"]
    assert callable(func)
    params = inspect.signature(func).parameters
    # The partial keeps "lo" as a parameter, now defaulting to the configured value.
    assert len(params) == 3
    assert params["lo"].default == -0.2
    # Calling it must produce output of the requested shape.
    assert numpy.asarray(func((2, 3))).shape == (2, 3)
    # Validation still applies: a list where a float is expected, then an
    # unknown argument.
    for bad_args in ({"lo": [0.5]}, {"lo": -0.2, "other": 10}):
        with pytest.raises(ConfigValidationError):
            my_registry.resolve({"test": {"@initializers": name, **bad_args}})


def test_partials_from_config_nested():
    """A partial produced by one promise can be consumed by another registered
    function (e.g. initializers -> layers)."""

    def test_initializer(a: int, b: int = 1) -> int:
        return a * b

    @my_registry.initializers("test_initializer.v1")
    def configure_test_initializer(b: int = 1) -> Callable[[int], int]:
        return partial(test_initializer, b=b)

    @my_registry.layers("test_layer.v1")
    def test_layer(init: Callable[[int], int], c: int = 1) -> Callable[[int], int]:
        return lambda x: x + init(c)

    layer_cfg = {
        "@layers": "test_layer.v1",
        "c": 5,
        "init": {"@initializers": "test_initializer.v1", "b": 10},
    }
    layer = my_registry.resolve({"test": layer_cfg})["test"]
    # init(5) == 5 * 10 == 50, so layer(x) == x + 50.
    for x, expected in ((1, 51), (100, 150)):
        assert layer(x) == expected


def test_validate_generator():
    """Generators survive validation.

    The resolved value is still the generator returned by the registered
    function, whether it sits at top level, is passed as an argument, or is
    nested inside a dict argument.
    """

    @my_registry.schedules("test_schedule.v2")
    def test_schedule():
        while True:
            yield 10

    def resolve_test(promise):
        return my_registry.resolve({"test": promise})["test"]

    top_level = {"@schedules": "test_schedule.v2"}
    assert isinstance(resolve_test(top_level), GeneratorType)

    @my_registry.optimizers("test_optimizer.v2")
    def test_optimizer2(rate: Generator) -> Generator:
        return rate

    as_argument = {
        "@optimizers": "test_optimizer.v2",
        "rate": {"@schedules": "test_schedule.v2"},
    }
    assert isinstance(resolve_test(as_argument), GeneratorType)

    @my_registry.optimizers("test_optimizer.v3")
    def test_optimizer3(schedules: Dict[str, Generator]) -> Generator:
        return schedules["rate"]

    in_dict_argument = {
        "@optimizers": "test_optimizer.v3",
        "schedules": {"rate": {"@schedules": "test_schedule.v2"}},
    }
    assert isinstance(resolve_test(in_dict_argument), GeneratorType)

    # NOTE(review): test_optimizer.v4 (generators as *args) is registered but
    # never resolved, so that case is not actually checked here.
    @my_registry.optimizers("test_optimizer.v4")
    def test_optimizer4(*schedules: Generator) -> Generator:
        return schedules[0]


def test_handle_generic_type():
    """Validation copes with arbitrary generic types in argument annotations."""
    inner = {"@cats": "int_cat.v1", "value_in": 3}
    outer = {"@cats": "generic_cat.v1", "cat": inner}
    cat = my_registry.resolve({"test": outer})["test"]
    assert isinstance(cat, Cat)
    assert cat.value_in == 3
    assert cat.value_out is None
    assert cat.name == "generic_cat"


@pytest.mark.parametrize(
    "cfg",
    [
        # "c" is both a value in [a] and a subsection [a.c].
        "[a]\nb = 1\nc = 2\n\n[a.c]\nd = 3",
        # "d" is both a value in [a.c] and a subsection [a.c.d].
        "[a]\nb = 1\n\n[a.c]\nd = 2\n\n[a.c.d]\ne = 3",
    ],
)
def test_handle_error_duplicate_keys(cfg):
    """A key defined both as a value and as a section raises ConfigValidationError.

    Without this check, the config fails later with a very cryptic error:
    "TypeError: 'X' object does not support item assignment".
    """
    with pytest.raises(ConfigValidationError):
        Config().from_str(cfg)


@pytest.mark.parametrize(
    "cfg,is_valid",
    [("[a]\nb = 1\n\n[a.c]\nd = 3", True), ("[a]\nb = 1\n\n[A.c]\nd = 2", False)],
)
def test_cant_expand_undefined_block(cfg, is_valid):
    """A subsection of a section that was never defined is rejected.

    This typically happens when a section name is mistyped (here "A" instead
    of "a"). Allowing such expansion would make good error messages for these
    typos very hard to produce.
    """
    if not is_valid:
        with pytest.raises(ConfigValidationError):
            Config().from_str(cfg)
        return
    Config().from_str(cfg)


def test_fill_config_overrides():
    """Dotted-path and nested-dict overrides are applied and validated by fill().

    Overrides can replace plain values, replace promises with plain values,
    and replace plain values with promises (which are then filled with
    defaults). Overrides with wrong types, incomplete promises, or paths that
    don't exist in the config raise ConfigValidationError.
    """
    config = {
        "cfg": {
            "one": 1,
            "two": {"three": {"@cats": "catsie.v1", "evil": True, "cute": False}},
        }
    }
    overrides = {"cfg.two.three.evil": False}
    result = my_registry.fill(config, overrides=overrides, validate=True)
    assert result["cfg"]["two"]["three"]["evil"] is False
    # Test that promises can be overwritten as well
    overrides = {"cfg.two.three": 3}
    result = my_registry.fill(config, overrides=overrides, validate=True)
    assert result["cfg"]["two"]["three"] == 3
    # Test that value can be overwritten with promises and that the result is
    # interpreted and filled correctly
    overrides = {"cfg": {"one": {"@cats": "catsie.v1", "evil": False}, "two": None}}
    result = my_registry.fill(config, overrides=overrides)
    assert result["cfg"]["two"] is None
    assert result["cfg"]["one"]["@cats"] == "catsie.v1"
    assert result["cfg"]["one"]["evil"] is False
    assert result["cfg"]["one"]["cute"] is True
    # Overwriting with wrong types should cause validation error
    overrides = {"cfg.two.three.evil": 20}
    with pytest.raises(ConfigValidationError):
        my_registry.fill(config, overrides=overrides, validate=True)
    # Overwriting with incomplete promises should cause validation error
    overrides = {"cfg": {"one": {"@cats": "catsie.v1"}, "two": None}}
    with pytest.raises(ConfigValidationError):
        my_registry.fill(config, overrides=overrides)
    # Overrides that don't match the config should raise an error. The unknown
    # key is a full path under "cfg", so this checks a missing *nested* key
    # rather than just an unknown top-level section.
    overrides = {"cfg.two.three.evil": False, "cfg.two.four": True}
    with pytest.raises(ConfigValidationError):
        my_registry.fill(config, overrides=overrides, validate=True)
    overrides = {"cfg.five": False}
    with pytest.raises(ConfigValidationError):
        my_registry.fill(config, overrides=overrides, validate=True)


def test_resolve_overrides():
    """Dotted-path and nested-dict overrides are applied before promises are
    resolved; invalid overrides raise ConfigValidationError."""
    config = {
        "cfg": {
            "one": 1,
            "two": {"three": {"@cats": "catsie.v1", "evil": True, "cute": False}},
        }
    }

    def resolve_with(overrides, **kwargs):
        return my_registry.resolve(config, overrides=overrides, **kwargs)

    result = resolve_with({"cfg.two.three.evil": False}, validate=True)
    assert result["cfg"]["two"]["three"] == "meow"
    # A promise can be replaced by a plain value.
    result = resolve_with({"cfg.two.three": 3}, validate=True)
    assert result["cfg"]["two"]["three"] == 3
    # A plain value can be replaced by a promise, which is then resolved.
    result = resolve_with({"cfg": {"one": {"@cats": "catsie.v1", "evil": False}, "two": None}})
    assert result["cfg"]["one"] == "meow"
    assert result["cfg"]["two"] is None
    invalid_cases = [
        # Wrong type for "evil".
        ({"cfg.two.three.evil": 20}, {"validate": True}),
        # Incomplete promise (no "evil" argument).
        ({"cfg": {"one": {"@cats": "catsie.v1"}, "two": None}}, {}),
        # Paths that do not exist in the config.
        ({"cfg.two.three.evil": False, "cfg.two.four": True}, {"validate": True}),
        ({"cfg.five": False}, {"validate": True}),
    ]
    for overrides, kwargs in invalid_cases:
        with pytest.raises(ConfigValidationError):
            resolve_with(overrides, **kwargs)


@pytest.mark.parametrize(
    "prop,expected",
    [("a.b.c", True), ("a.b", True), ("a", True), ("a.e", True), ("a.b.c.d", False)],
)
def test_is_in_config(prop, expected):
    """Dotted paths match nested keys at any depth, but not beyond a leaf value."""
    nested = {"a": {"b": {"c": 5, "d": 6}, "e": [1, 2]}}
    found = my_registry._is_in_config(prop, nested)
    assert found is expected


def test_resolve_prefilled_values():
    """Arbitrary objects already placed in the config are passed through."""

    class Language:
        def __init__(self):
            pass

    @my_registry.optimizers("prefilled.v1")
    def prefilled(nlp: Language, value: int = 10):
        return (nlp, value)

    # Putting a Language instance in the config is bad practice because it
    # can't be serialized to a string, but it should still resolve.
    promise = {"@optimizers": "prefilled.v1", "nlp": Language(), "value": 50}
    nlp_out, value_out = my_registry.resolve({"test": promise}, validate=True)["test"]
    assert isinstance(nlp_out, Language)
    assert value_out == 50


def test_fill_config_dict_return_type():
    """A registered function returning a dict stays a promise under fill() and
    is called only by resolve()."""

    @my_registry.cats.register("catsie_with_dict.v1")
    def catsie_with_dict(evil: StrictBool) -> Dict[str, bool]:
        return {"not_evil": not evil}

    config = {"test": {"@cats": "catsie_with_dict.v1", "evil": False}, "foo": 10}
    filled_test = my_registry.fill({"cfg": config}, validate=True)["cfg"]["test"]
    assert filled_test["evil"] is False
    assert "not_evil" not in filled_test
    resolved_test = my_registry.resolve({"cfg": config}, validate=True)["cfg"]["test"]
    assert resolved_test["not_evil"] is True


def test_deepcopy_config():
    """Config.copy() returns an equal but fully independent (deep) copy."""
    config = Config({"a": 1, "b": {"c": 2, "d": 3}})
    copied = config.copy()
    # Same values but not same object
    assert config == copied
    assert config is not copied
    # Nested sections must be copied too rather than shared, otherwise
    # mutating the copy would leak into the original.
    assert config["b"] is not copied["b"]
    copied["b"]["c"] = 100
    assert config["b"]["c"] == 2


@pytest.mark.skipif(
    platform.python_implementation() == "PyPy", reason="copy does not fail for pypy"
)
def test_deepcopy_config_pickle():
    """copy() raises ValueError if a value (here a module) can't be deep-copied."""
    numpy = pytest.importorskip("numpy")
    uncopyable = Config({"a": 1, "b": numpy})
    with pytest.raises(ValueError):
        uncopyable.copy()


def test_config_to_str_simple_promises():
    """Argument-less promises are written inline as a dict and read back."""
    inline = """[section]\nsubsection = {"@registry":"value"}"""
    parsed = Config().from_str(inline)
    assert parsed["section"]["subsection"]["@registry"] == "value"
    assert parsed.to_str() == inline


def test_config_from_str_invalid_section():
    """A section may not extend a key that already holds a value (here null)."""
    invalid_configs = (
        """[a]\nb = null\n\n[a.b]\nc = 1""",
        """[a]\nb = null\n\n[a.b.c]\nd = 1""",
    )
    for config_str in invalid_configs:
        with pytest.raises(ConfigValidationError):
            Config().from_str(config_str)


def test_config_to_str_order():
    """Config.to_str writes each section's values before its subsections."""
    config = Config(
        {"a": {"b": {"c": 1, "d": 2}, "e": 3}, "f": {"g": {"h": {"i": 4, "j": 5}}}}
    )
    expected = "\n\n".join(
        [
            "[a]\ne = 3",
            "[a.b]\nc = 1\nd = 2",
            "[f]",
            "[f.g]",
            "[f.g.h]\ni = 4\nj = 5",
        ]
    )
    assert config.to_str() == expected


@pytest.mark.parametrize("d", [".", ":"])
def test_config_interpolation(d):
    """Config values are interpolated correctly.

    ``d`` is the final divider of the reference (${a.b} vs. ${a:b}); both
    forms must work. Braces are doubled ({{ }}) in the f-strings so that
    ${...} stays part of the config text instead of becoming an f-string
    replacement field.
    """
    # A bare ${foo} does not name a section, so it cannot be resolved.
    with pytest.raises(ConfigValidationError):
        Config().from_str("""[a]\nfoo = "hello"\n\n[b]\nbar = ${foo}""")

    def bar_of(text):
        return Config().from_str(text)["b"]["bar"]

    cases = [
        (f"""[a]\nfoo = "hello"\n\n[b]\nbar = ${{a{d}foo}}""", "hello"),
        (f"""[a]\nfoo = "hello"\n\n[b]\nbar = ${{a{d}foo}}!""", "hello!"),
        (f"""[a]\nfoo = "hello"\n\n[b]\nbar = "${{a{d}foo}}!\"""", "hello!"),
        (f"""[a]\nfoo = 15\n\n[b]\nbar = ${{a{d}foo}}!""", "15!"),
        (f"""[a]\nfoo = ["x", "y"]\n\n[b]\nbar = ${{a{d}foo}}""", ["x", "y"]),
    ]
    for text, expected in cases:
        assert bar_of(text) == expected
    # Interpolation within the same section
    same_section = f"""[a]\nfoo = "x"\nbar = ${{a{d}foo}}\nbaz = "${{a{d}foo}}y\""""
    section_a = Config().from_str(same_section)["a"]
    assert section_a["bar"] == "x"
    assert section_a["baz"] == "xy"


def test_config_interpolation_lists():
    """Test that variables inside list values are interpolated correctly,
    both when loading a config and via ``Config.interpolate``."""
    # Test that lists are preserved correctly
    c_str = """[a]\nb = 1\n\n[c]\nd = ["hello ${a.b}", "world"]"""
    config = Config().from_str(c_str, interpolate=False)
    assert config["c"]["d"] == ["hello ${a.b}", "world"]
    config = config.interpolate()
    assert config["c"]["d"] == ["hello 1", "world"]
    # A bare variable as a list item keeps the referenced value's type
    c_str = """[a]\nb = 1\n\n[c]\nd = [${a.b}, "hello ${a.b}", "world"]"""
    config = Config().from_str(c_str)
    assert config["c"]["d"] == [1, "hello 1", "world"]
    # Loading the same string without interpolation must at least not fail.
    # NOTE: Its value can't be checked yet, because we can't know how to
    # JSON-load the uninterpolated list [${a.b}]. The desired result would be
    # ["${a.b}", "hello ${a.b}", "world"], interpolating to
    # [1, "hello 1", "world"].
    Config().from_str(c_str, interpolate=False)
    # A whole section can be referenced as a list item...
    c_str = """[a]\nb = 1\n\n[c]\nd = ["hello", ${a}]"""
    config = Config().from_str(c_str)
    assert config["c"]["d"] == ["hello", {"b": 1}]
    # ...but not interpolated into a string inside a list
    c_str = """[a]\nb = 1\n\n[c]\nd = ["hello", "hello ${a}"]"""
    with pytest.raises(ConfigValidationError):
        Config().from_str(c_str)
    # Variables are also interpolated inside dicts nested in lists, both
    # within strings...
    config_str = """[a]\nb = 1\n\n[c]\nd = ["hello", {"x": ["hello ${a.b}"], "y": 2}]"""
    config = Config().from_str(config_str)
    assert config["c"]["d"] == ["hello", {"x": ["hello 1"], "y": 2}]
    # ...and as bare list items: with interpolation enabled the reference is
    # substituted before the value is JSON-parsed, like the top-level
    # [${a.b}, ...] case above.
    config_str = """[a]\nb = 1\n\n[c]\nd = ["hello", {"x": [${a.b}], "y": 2}]"""
    config = Config().from_str(config_str)
    assert config["c"]["d"] == ["hello", {"x": [1], "y": 2}]


@pytest.mark.parametrize("d", [".", ":"])
def test_config_interpolation_sections(d):
    """Test that references to whole config sections are interpolated
    correctly, and that unsupported section references are rejected.

    The parametrized ``d`` is the divider before the final key (${a.b} vs.
    ${a:b}); both forms must work. The doubled braces {{ }} in the f-strings
    keep the ${...} references as literal text instead of f-string
    replacement fields.
    """
    # Simple block references: a value that is a whole section
    c_str = """[a]\nfoo = "hello"\nbar = "world"\n\n[b]\nc = ${a}"""
    config = Config().from_str(c_str)
    assert config["b"]["c"] == config["a"]
    # References to a subsection containing non-string values
    c_str = f"""[a]\nfoo = "hello"\n\n[a.x]\ny = ${{a{d}b}}\n\n[a.b]\nc = 1\nd = [10]"""
    config = Config().from_str(c_str)
    assert config["a"]["x"]["y"] == config["a"]["b"]
    # Multiple references in the same string
    c_str = f"""[a]\nx = "string"\ny = 10\n\n[b]\nz = "${{a{d}x}}/${{a{d}y}}\""""
    config = Config().from_str(c_str)
    assert config["b"]["z"] == "string/10"
    # Non-string references in a string are converted to their string form
    c_str = f"""[a]\nx = ["hello", "world"]\n\n[b]\ny = "result: ${{a{d}x}}\""""
    config = Config().from_str(c_str)
    assert config["b"]["y"] == 'result: ["hello", "world"]'
    # References to sections whose values reference other sections
    c_str = """[a]\nfoo = "x"\n\n[b]\nbar = ${a}\n\n[c]\nbaz = ${b}"""
    config = Config().from_str(c_str)
    assert config["b"]["bar"] == config["a"]
    assert config["c"]["baz"] == config["b"]
    # References to section values that themselves reference other sections
    c_str = f"""[a]\nfoo = "x"\n\n[b]\nbar = ${{a}}\n\n[c]\nbaz = ${{b{d}bar}}"""
    config = Config().from_str(c_str)
    assert config["c"]["baz"] == config["b"]["bar"]
    # References to sections with subsections
    c_str = """[a]\nfoo = "x"\n\n[a.b]\nbar = 100\n\n[c]\nbaz = ${a}"""
    config = Config().from_str(c_str)
    assert config["c"]["baz"] == config["a"]
    # A subsection referencing its own parent section (a potential infinite
    # recursion) still loads
    c_str = """[a]\nfoo ="x"\n\n[a.b]\nbar = ${a}"""
    config = Config().from_str(c_str)
    assert config["a"]["b"]["bar"] == config["a"]
    # b.bar is itself the reference ${a}; reaching *through* it (b.bar.foo)
    # isn't supported
    c_str = f"""[a]\nfoo = "x"\n\n[b]\nbar = ${{a}}\n\n[c]\nbaz = ${{b.bar{d}foo}}"""
    # We can't reference not-yet interpolated subsections
    with pytest.raises(ConfigValidationError):
        Config().from_str(c_str)
    # References to an undefined section/value are invalid
    c_str = f"""[a]\nfoo = ${{b{d}bar}}"""
    with pytest.raises(ConfigValidationError):
        Config().from_str(c_str)
    # We can't reference sections or promises within strings
    c_str = """[a]\n\n[a.b]\nfoo = "x: ${c}"\n\n[c]\nbar = 1\nbaz = 2"""
    with pytest.raises(ConfigValidationError):
        Config().from_str(c_str)


def test_config_from_str_overrides():
    """Test that ``overrides`` passed to ``Config.from_str`` replace or add
    values, reject invalid targets, and feed into variable interpolation."""
    base_str = """[a]\nb = 1\n\n[a.c]\nd = 2\ne = 3\n\n[f]\ng = {"x": "y"}"""
    # Overriding existing values leaves the untouched ones alone
    cfg = Config().from_str(base_str, overrides={"a.b": 10, "a.c.d": 20})
    assert (cfg["a"]["b"], cfg["a"]["c"]["d"], cfg["a"]["c"]["e"]) == (10, 20, 3)
    # A key that wasn't in the config before can be added
    section_c = Config().from_str(base_str, overrides={"a.c.f": 100})["a"]["c"]
    assert section_c["d"] == 2
    assert section_c["e"] == 3
    assert section_c["f"] == 100
    # Invalid targets: a whole section ("f"), and a key inside a dict value
    # ("f.g.x"), which isn't treated as a section while the config is still
    # just the configparser data
    for bad_overrides in ({"f": 10}, {"f.g.x": "z"}):
        with pytest.raises(ConfigValidationError):
            Config().from_str(base_str, overrides=bad_overrides)
    # Overridden values propagate into variables referencing them...
    var_str = """[a]\nb = 1\n\n[a.c]\nd = 2\ne = ${a:b}"""
    cfg = Config().from_str(var_str, overrides={"a.b": 10})
    assert cfg["a"]["b"] == 10
    assert cfg["a"]["c"]["e"] == 10
    # ...and into references to whole sections
    section_str = """[a]\nb = 1\n\n[a.c]\nd = 2\n[e]\nf = ${a.c}"""
    cfg = Config().from_str(section_str, overrides={"a.c.d": 20})
    assert cfg["a"]["c"]["d"] == 20
    assert cfg["e"]["f"] == {"d": 20}


def test_config_reserved_aliases():
    """Check that a registered function argument named like a reserved
    attribute ("validate") is auto-aliased in the generated pydantic schema
    instead of causing a NameError, and that the argument is still validated."""

    @my_registry.cats("catsie.with_alias")
    def catsie_with_alias(validate: StrictBool = False):
        return validate

    block = {"@cats": "catsie.with_alias", "validate": True}
    resolved_value = my_registry.resolve({"test": block})["test"]
    filled_block = my_registry.fill({"test": block})["test"]
    assert resolved_value is True
    assert filled_block == block
    # 20 isn't a strict boolean, so the aliased field must fail validation
    bad_block = {"@cats": "catsie.with_alias", "validate": 20}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve({"test": bad_block})


@pytest.mark.parametrize("d", [".", ":"])
def test_config_no_interpolation(d):
    """Test loading a config with ``interpolate=False``.

    Variables must stay as raw text. Interpolating later, either through a
    ``to_str``/``from_str`` round-trip or via ``Config.interpolate``, must
    resolve them. ``d`` is the divider before the final key (``${a.b}`` vs.
    ``${a:b}``); both forms must work. Braces are doubled inside the
    f-strings so that ``${...}`` stays literal text.
    """
    numpy = pytest.importorskip("numpy")
    raw_str = f"""[a]\nb = 1\n\n[c]\nd = ${{a{d}b}}\ne = \"hello${{a{d}b}}"\nf = ${{a}}"""
    raw = Config().from_str(raw_str, interpolate=False)
    assert not raw.is_interpolated
    section_c = raw["c"]
    assert section_c["d"] == f"${{a{d}b}}"
    assert section_c["e"] == f'"hello${{a{d}b}}"'
    assert section_c["f"] == "${a}"
    # Both ways of interpolating afterwards must produce the same values
    for resolved in (
        Config().from_str(raw.to_str(), interpolate=True),
        raw.interpolate(),
    ):
        assert resolved.is_interpolated
        assert resolved["c"]["d"] == 1
        assert resolved["c"]["e"] == "hello1"
        assert resolved["c"]["f"] == {"b": 1}
    # A reference to a non-serializable value (a numpy array) can't be
    # interpolated
    array = numpy.asarray([[1, 2], [4, 5]], dtype="f")
    with pytest.raises(ConfigValidationError):
        Config({"x": {"y": array, "z": f"${{x{d}y}}"}}).interpolate()


def test_config_no_interpolation_registry():
    """Test filling and resolving registry configs with and without variable
    interpolation, including resolving a filled config whose variables were
    never interpolated."""
    source = """[a]\nbad = true\n[b]\n@cats = "catsie.v1"\nevil = ${a:bad}\n\n[c]\n d = ${b}"""

    # Loaded without interpolation: variables stay literal "${...}" strings
    raw = Config().from_str(source, interpolate=False)
    assert not raw.is_interpolated
    assert (raw["b"]["evil"], raw["c"]["d"]) == ("${a:bad}", "${b}")
    raw_filled = my_registry.fill(raw)
    raw_resolved = my_registry.resolve(raw)
    assert raw_resolved["b"] == "scratch!"
    assert raw_resolved["c"]["d"] == "scratch!"
    assert raw_filled["b"]["evil"] == "${a:bad}"
    assert raw_filled["b"]["cute"] is True
    assert raw_filled["c"]["d"] == "${b}"
    # Interpolating the filled config afterwards substitutes the variables
    late = raw_filled.interpolate()
    assert late.is_interpolated
    assert late["b"]["evil"] is True
    assert late["c"]["d"] == late["b"]

    # Loaded with interpolation
    eager = Config().from_str(source, interpolate=True)
    assert eager.is_interpolated
    eager_filled = my_registry.fill(eager)
    eager_resolved = my_registry.resolve(eager)
    assert eager_resolved["b"] == "scratch!"
    assert eager_resolved["c"]["d"] == "scratch!"
    assert eager_filled["b"]["evil"] is True
    assert eager_filled["c"]["d"] == eager_filled["b"]

    # Resolving a filled config that was never interpolated
    fresh = Config().from_str(source, interpolate=False)
    assert not fresh.is_interpolated
    fresh_filled = my_registry.fill(fresh)
    assert not fresh_filled.is_interpolated
    assert fresh_filled["c"]["d"] == "${b}"
    assert my_registry.resolve(fresh_filled)["c"]["d"] == "scratch!"


def test_config_deep_merge():
    """Test that ``Config.merge`` deep-merges overrides into defaults, and
    how factory keys ("@...") decide whether default arguments carry over."""
    # (overrides, defaults, expected merge result)
    cases = [
        # Plain nested dicts: overrides win, missing defaults are kept
        (
            {"a": "hello", "b": {"c": "d"}},
            {"a": "world", "b": {"c": "e", "f": "g"}},
            {"a": "hello", "b": {"c": "d", "f": "g"}},
        ),
        # Same factory: defaults fill in the missing arguments
        (
            {"a": "hello", "b": {"@test": "x", "foo": 1}},
            {"a": "world", "b": {"@test": "x", "foo": 100, "bar": 2}, "c": 100},
            {"a": "hello", "b": {"@test": "x", "foo": 1, "bar": 2}, "c": 100},
        ),
        # Different factory value: default arguments are not carried over
        (
            {"a": "hello", "b": {"@test": "x", "foo": 1}, "c": 100},
            {"a": "world", "b": {"@test": "y", "foo": 100, "bar": 2}},
            {"a": "hello", "b": {"@test": "x", "foo": 1}, "c": 100},
        ),
        # Leaving out the factory just adds to the existing block
        (
            {"a": "hello", "b": {"foo": 1}, "c": 100},
            {"a": "world", "b": {"@test": "y", "foo": 100, "bar": 2}},
            {"a": "hello", "b": {"@test": "y", "foo": 1, "bar": 2}, "c": 100},
        ),
        # Switching to a different factory key prevents defaults being added
        (
            {"a": "hello", "b": {"@foo": 1}, "c": 100},
            {"a": "world", "b": {"@bar": "y"}},
            {"a": "hello", "b": {"@foo": 1}, "c": 100},
        ),
        # A block replaces a plain (non-dict) default value
        (
            {"a": "hello", "b": {"@foo": 1}, "c": 100},
            {"a": "world", "b": "y"},
            {"a": "hello", "b": {"@foo": 1}, "c": 100},
        ),
    ]
    for overrides, defaults, expected in cases:
        assert Config(defaults).merge(overrides) == expected


def test_config_deep_merge_variables():
    """Test merging configs that contain uninterpolated variables."""
    # Variables in the overriding config survive the merge and can still be
    # interpolated afterwards
    overrides = Config().from_str(
        """[a]\nb= 1\nc = 2\n\n[d]\ne = ${a:b}""", interpolate=False
    )
    defaults = Config().from_str("""[a]\nx = 100\n\n[d]\ny = 500""")
    merged = defaults.merge(overrides)
    assert merged["a"] == {"b": 1, "c": 2, "x": 100}
    assert merged["d"] == {"e": "${a:b}", "y": 500}
    assert merged.interpolate()["d"] == {"e": 1, "y": 500}
    # A variable in the defaults is overwritten by a concrete new value
    concrete = Config().from_str("""[a]\nb= 1\nc = 2""")
    templated = Config().from_str("""[a]\nb = 100\nc = ${a:b}""", interpolate=False)
    assert templated.merge(concrete)["a"]["c"] == 2


def test_config_to_str_roundtrip():
    """Test that ``to_str``/``from_str`` round-trip values. Booleans and
    boolean-looking strings must stay distinct, unserializable values must
    be rejected, and variables (quoted and unquoted) must be preserved."""
    numpy = pytest.importorskip("numpy")
    # A real boolean is written unquoted; the string "false" stays quoted
    for cfg, expected_str in (
        ({"cfg": {"foo": False}}, "[cfg]\nfoo = false"),
        ({"cfg": {"foo": "false"}}, '[cfg]\nfoo = "false"'),
    ):
        serialized = Config(cfg).to_str()
        assert serialized == expected_str
        assert dict(Config().from_str(serialized)) == cfg
    # A non-serializable value (a numpy array) can't be written out
    array = numpy.asarray([[1, 2, 3, 4], [4, 5, 3, 4]], dtype="f")
    array_config = Config({"cfg": {"x": array}})
    with pytest.raises(ConfigValidationError):
        array_config.to_str()
    # Variables survive an uninterpolated round-trip unchanged
    var_str = """[a]\nb = 1\n\n[c]\nd = ${a:b}\ne = \"hello${a:b}"\nf = "${a:b}\""""
    assert Config().from_str(var_str, interpolate=False).to_str() == var_str


def test_config_is_interpolated():
    """Check that ``is_interpolated`` is tracked correctly through merging,
    re-wrapping in ``Config`` and ``interpolate``."""
    source = """[a]\nb = 1\n\n[c]\nd = ${a:b}\ne = \"hello${a:b}"\nf = ${a}"""
    # (transformation applied to the current config, expected flag after it)
    steps = [
        (lambda cfg: cfg.merge(Config({"x": {"y": "z"}})), False),
        (Config, False),
        (lambda cfg: cfg.interpolate(), True),
        # Merging an uninterpolated config makes the result uninterpolated
        (lambda cfg: cfg.merge(Config().from_str(source, interpolate=False)), False),
    ]
    current = Config().from_str(source, interpolate=False)
    assert not current.is_interpolated
    for transform, expected_flag in steps:
        current = transform(current)
        assert bool(current.is_interpolated) == expected_flag


@pytest.mark.parametrize(
    "section_order,expected_str,expected_keys",
    [
        # fmt: off
        ([], "[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4\n\n[h]\ni = 5\n\n[j]\nk = 6", ["a", "h", "j"]),
        (["j", "h", "a"], "[j]\nk = 6\n\n[h]\ni = 5\n\n[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4", ["j", "h", "a"]),
        (["h"], "[h]\ni = 5\n\n[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4\n\n[j]\nk = 6", ["h", "a", "j"])
        # fmt: on
    ],
)
def test_config_serialize_custom_sort(section_order, expected_str, expected_keys):
    """Test that ``section_order`` controls the order of top-level sections.

    This applies when serializing, when loading from a string and when
    constructing from a dict. Sections missing from ``section_order`` follow
    the listed ones.
    """
    cfg = {
        "j": {"k": 6},
        "a": {"b": 1, "d": {"e": 3}, "c": 2, "f": {"g": 4}},
        "h": {"i": 5},
    }
    default_str = Config(cfg).to_str()
    assert Config(cfg, section_order=section_order).to_str() == expected_str
    loaded = Config(section_order=section_order).from_str(default_str)
    assert list(loaded.keys()) == expected_keys
    constructed = Config(cfg, section_order=section_order)
    assert list(constructed.keys()) == expected_keys


def test_config_custom_sort_preserve():
    """Check that a custom ``section_order`` survives copying, merging and
    re-wrapping a config, as well as filling it via the registry."""
    order = ["y", "z", "x"]
    expected = "[y]\n\n[z]\n\n[x]"
    base = Config({"x": {}, "y": {}, "z": {}}, section_order=order)
    assert base.to_str() == expected
    assert base.copy().to_str() == expected
    # New sections from a merge are appended after the ordered ones
    assert base.merge({"a": {}}).to_str() == f"{expected}\n\n[a]"
    assert Config(base).to_str() == expected
    # A config loaded with a custom order keeps it after filling
    source = """[a]\nb = 1\n[c]\n@cats = "catsie.v1"\nevil = true\n\n[t]\n x = 2"""
    loaded_order = ["c", "a", "t"]
    loaded = Config(section_order=loaded_order).from_str(source)
    assert list(loaded.keys()) == loaded_order
    assert my_registry.fill(loaded).section_order == loaded_order


def test_config_pickle():
    """Check that a Config survives a pickle round-trip with both its content
    and its custom ``section_order``."""
    order = ["foo", "bar", "baz"]
    original = Config({"foo": "bar"}, section_order=order)
    restored = pickle.loads(pickle.dumps(original))
    assert restored == {"foo": "bar"}
    assert restored.section_order == order


def test_config_fill_extra_fields():
    """Test filling against a schema when the config has extra fields.

    A schema that forbids extra fields raises with validation on and drops
    them with validation off. A schema that allows extra fields keeps them.
    """

    class ForbidContent(BaseModel):
        a: str
        b: int

        class Config:
            extra = "forbid"

    class ForbidSchema(BaseModel):
        cfg: ForbidContent

    class AllowContent(BaseModel):
        a: str
        b: int

        class Config:
            extra = "allow"

    class AllowSchema(BaseModel):
        cfg: AllowContent

    def fill_unvalidated(cfg, schema):
        # Fill without validation and return only the "cfg" section
        return my_registry.fill(cfg, schema=schema, validate=False)["cfg"]

    config = Config({"cfg": {"a": "1", "b": 2, "c": True}})
    with pytest.raises(ConfigValidationError):
        my_registry.fill(config, schema=ForbidSchema)
    assert fill_unvalidated(config, ForbidSchema) == {"a": "1", "b": 2}
    assert fill_unvalidated(config.interpolate(), ForbidSchema) == {"a": "1", "b": 2}
    uninterpolated = Config(
        {"cfg": {"a": "1", "b": 2, "c": True}}, is_interpolated=False
    )
    assert fill_unvalidated(uninterpolated, ForbidSchema) == {"a": "1", "b": 2}
    assert fill_unvalidated(config, AllowSchema) == {"a": "1", "b": 2, "c": True}


def test_config_validation_error_custom():
    """Test the attributes of a ConfigValidationError raised while filling,
    and that ``ConfigValidationError.from_error`` keeps the errors while
    overriding the title, description and ``show_config``."""

    class Schema(BaseModel):
        hello: int
        world: int

    config = {"hello": 1, "world": "hi!"}
    with pytest.raises(ConfigValidationError) as exc_info:
        my_registry._fill(config, Schema)
    e1 = exc_info.value
    assert e1.title == "Config validation error"
    assert e1.desc is None
    assert not e1.parent
    assert e1.show_config is True
    # Exactly one error: "world" isn't an integer
    assert len(e1.errors) == 1
    assert e1.errors[0]["loc"] == ("world",)
    assert e1.errors[0]["msg"] == "value is not a valid integer"
    assert e1.errors[0]["type"] == "type_error.integer"
    assert e1.error_types == {"type_error.integer"}
    # Create a new error with overrides
    title = "Custom error"
    desc = "Some error description here"
    e2 = ConfigValidationError.from_error(e1, title=title, desc=desc, show_config=False)
    assert e2.errors == e1.errors
    assert e2.error_types == e1.error_types
    assert e2.title == title
    assert e2.desc == desc
    assert e2.show_config is False
    assert e1.text != e2.text


def test_config_parsing_error():
    """A line that is neither a section header nor a ``key = value`` pair
    must be reported as a ConfigValidationError."""
    with pytest.raises(ConfigValidationError):
        Config().from_str("[a]\nb c")


def test_config_fill_without_resolve():
    """Test that ``fill`` succeeds against a schema that ``resolve`` rejects
    (the schema types the function's return value), and that ``fill``
    tolerates a block naming an unavailable function."""

    class IntSchema(BaseModel):
        catsie: int

    cat_cfg = {"catsie": {"@cats": "catsie.v1", "evil": False}}
    filled = my_registry.fill(cat_cfg)
    assert my_registry.resolve(cat_cfg)["catsie"] == "meow"
    assert filled["catsie"]["cute"] is True
    # "meow" isn't an int, so resolving against the schema fails...
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(cat_cfg, schema=IntSchema)
    # ...but filling against it doesn't, and the result still resolves
    filled_with_schema = my_registry.fill(cat_cfg, schema=IntSchema)
    assert filled_with_schema["catsie"]["cute"] is True
    assert my_registry.resolve(filled_with_schema)["catsie"] == "meow"

    # With an unavailable function, the block is left as-is
    class AnySchema(BaseModel):
        catsie: Any
        other: int = 12

    dog_cfg = {"catsie": {"@cats": "dog", "evil": False}}
    filled_dog = my_registry.fill(dog_cfg, schema=AnySchema)
    assert filled_dog["catsie"] == dog_cfg["catsie"]
    assert filled_dog["other"] == 12


def test_config_dataclasses():
    """Test that resolving a block whose argument is a ``Cat`` instance
    yields a ``Cat`` with the same attribute values."""
    cat = Cat("testcat", value_in=1, value_out=2)
    result = my_registry.resolve({"cfg": {"@cats": "catsie.v3", "arg": cat}})["cfg"]
    assert isinstance(result, Cat)
    for attr in ("name", "value_in", "value_out"):
        assert getattr(result, attr) == getattr(cat, attr)


@pytest.mark.parametrize(
    "greeting,value,expected",
    [
        # a whole-value reference keeps the override's type
        [342, "${vars.a}", int],
        ["342", "${vars.a}", str],
        ["everyone", "${vars.a}", str],
    ],
)
def test_config_interpolates(greeting, value, expected):
    """Test that an overridden variable referenced as a whole value keeps the
    exact type of the override; for example, the string "342" stays a str."""
    str_cfg = f"""
    [project]
    my_par = {value}

    [vars]
    a = "something"
    """
    overrides = {"vars.a": greeting}
    cfg = Config().from_str(str_cfg, overrides=overrides)
    # Exact type check on purpose: isinstance() would also accept subclasses
    assert type(cfg["project"]["my_par"]) is expected


@pytest.mark.parametrize(
    "greeting,value,expected",
    [
        # fmt: off
        # substituting a whole value
        ["hello 342", "${vars.a}", "hello 342"],
        ["hello everyone", "${vars.a}", "hello everyone"],
        ["hello tout le monde", "${vars.a}", "hello tout le monde"],
        ["hello 42", "${vars.a}", "hello 42"],
        # substituting an element in a list
        ["hello 342", "[1, ${vars.a}, 3]", "hello 342"],
        ["hello everyone", "[1, ${vars.a}, 3]", "hello everyone"],
        ["hello tout le monde", "[1, ${vars.a}, 3]", "hello tout le monde"],
        ["hello 42", "[1, ${vars.a}, 3]", "hello 42"],
        # substituting part of a string
        [342, "hello ${vars.a}", "hello 342"],
        ["everyone", "hello ${vars.a}", "hello everyone"],
        ["tout le monde", "hello ${vars.a}", "hello tout le monde"],
        pytest.param("42", "hello ${vars.a}", "hello 42", marks=pytest.mark.xfail),
        # substituting part of an implicit (unquoted) string inside a list
        [342, "[1, hello ${vars.a}, 3]", "hello 342"],
        ["everyone", "[1, hello ${vars.a}, 3]", "hello everyone"],
        ["tout le monde", "[1, hello ${vars.a}, 3]", "hello tout le monde"],
        pytest.param("42", "[1, hello ${vars.a}, 3]", "hello 42", marks=pytest.mark.xfail),
        # substituting part of an explicit (quoted) string inside a list
        [342, "[1, 'hello ${vars.a}', '3']", "hello 342"],
        ["everyone", "[1, 'hello ${vars.a}', '3']", "hello everyone"],
        ["tout le monde", "[1, 'hello ${vars.a}', '3']", "hello tout le monde"],
        pytest.param("42", "[1, 'hello ${vars.a}', '3']", "hello 42", marks=pytest.mark.xfail),
        # a string nested in a list of dicts
        [342, "[{'name':'x','script':['hello ${vars.a}']}]", "hello 342"],
        ["everyone", "[{'name':'x','script':['hello ${vars.a}']}]", "hello everyone"],
        ["tout le monde", "[{'name':'x','script':['hello ${vars.a}']}]", "hello tout le monde"],
        pytest.param("42", "[{'name':'x','script':['hello ${vars.a}']}]", "hello 42", marks=pytest.mark.xfail),
        # fmt: on
    ],
)
def test_config_overrides(greeting, value, expected):
    """Test that an overridden variable shows up wherever it is referenced:
    as a whole value, as a list item, inside strings and inside nested
    list/dict structures."""
    template = f"""
    [project]
    commands = {value}

    [vars]
    a = "world"
    """
    # Sanity check: the template really references the overridden variable
    assert "${vars.a}" in template
    parsed = Config().from_str(template, overrides={"vars.a": greeting})
    assert expected in str(parsed)


def test_warn_single_quotes():
    """Test that a value wrapped in single quotes triggers a UserWarning
    mentioning "single-quoted", while a single quote in the middle of a
    value does not."""
    import warnings

    str_cfg = """
    [project]
    commands = 'do stuff'
    """

    with pytest.warns(UserWarning, match="single-quoted"):
        Config().from_str(str_cfg)

    # should not warn if single quotes are in the middle
    str_cfg = """
    [project]
    commands = some'thing
    """
    # Record every warning so the absence of the single-quote warning is
    # actually asserted, not just assumed
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        Config().from_str(str_cfg)
    assert not any(
        issubclass(w.category, UserWarning) and "single-quoted" in str(w.message)
        for w in caught
    )


def test_parse_strings_interpretable_as_ints():
    """Test that strings interpretable as integers are parsed correctly, i.e.
    as strings: "00${b.bar}" interpolates to the string "003", while a bare
    ${b.bar} reference keeps the int 3."""
    # A plain string (not an f-string), so ${...} needs no brace escaping
    cfg = Config().from_str(
        """[a]\nfoo = [${b.bar}, "00${b.bar}", "y"]\n\n[b]\nbar = 3"""
    )
    assert cfg["a"]["foo"] == [3, "003", "y"]
    assert cfg["b"]["bar"] == 3
