"""Mixin that adds model instance loading behavior.

.. warning::

    This module is treated as private API.
    Users should not need to use this module directly.
"""
import marshmallow as ma

from .fields import get_primary_keys


class LoadInstanceMixin:
    """Mixin adding optional model-instance loading to SQLAlchemy-backed schemas.

    It provides two nested mixins that cooperate with their sibling bases through
    ``super()``: ``Opts`` (for the schema options class) and ``Schema`` (for the
    schema class). ``Schema`` relies on ``self.opts.model`` being provided
    elsewhere.
    """

    class Opts:
        """Options mixin that reads instance-loading settings from ``Meta``.

        Sets ``sqla_session`` (default ``None``), ``load_instance`` (default
        ``False``) and ``transient`` (default ``False``) from the attributes of
        the same name on ``meta``.
        """

        def __init__(self, meta, *args, **kwargs):
            super().__init__(meta, *args, **kwargs)
            self.sqla_session = getattr(meta, "sqla_session", None)
            self.load_instance = getattr(meta, "load_instance", False)
            self.transient = getattr(meta, "transient", False)

    class Schema:
        """Schema mixin that can deserialize data into SQLAlchemy model instances."""

        @property
        def session(self):
            """Session used for lookups: the schema's own one, else ``Meta.sqla_session``."""
            return self._session or self.opts.sqla_session

        @session.setter
        def session(self, session):
            self._session = session

        @property
        def transient(self):
            """Whether primary-key lookups are skipped (new instances are always built).

            An explicitly set value (constructor, setter, or ``load``) takes
            precedence over ``Meta.transient``.
            """
            if self._transient is not None:
                return self._transient
            return self.opts.transient

        @transient.setter
        def transient(self, transient):
            self._transient = transient

        def __init__(self, *args, **kwargs):
            """Pop the instance-loading keyword arguments, then initialize the parent.

            :param session: Optional SQLAlchemy session (overrides ``Meta.sqla_session``).
            :param instance: Optional existing instance that ``load`` will update.
            :param transient: Optional override of ``Meta.transient``.
            :param load_instance: Optional override of ``Meta.load_instance``.
            """
            self._session = kwargs.pop("session", None)
            self.instance = kwargs.pop("instance", None)
            self._transient = kwargs.pop("transient", None)
            self._load_instance = kwargs.pop("load_instance", self.opts.load_instance)
            super().__init__(*args, **kwargs)

        def get_instance(self, data):
            """Retrieve an existing record by primary key(s).

            Returns ``None`` when the schema is transient or when any primary-key
            value is missing from ``data`` (or is ``None``). Otherwise returns the
            result of ``.first()`` on a ``filter_by`` query built from the keys.

            :param data: Deserialized data; primary-key values are read with
                ``data.get(prop.key)``.
            """
            if self.transient:
                return None
            props = get_primary_keys(self.opts.model)
            filters = {prop.key: data.get(prop.key) for prop in props}
            # Only query when the full primary key is known; filtering on a
            # partial key could match an unrelated row.
            if None not in filters.values():
                return self.session.query(self.opts.model).filter_by(**filters).first()
            return None

        @ma.post_load
        def make_instance(self, data, **kwargs):
            """Deserialize data to an instance of the model if instance loading is on.

            When ``self._load_instance`` is false, ``data`` is returned unchanged.
            Otherwise the target is ``self.instance`` if set, else the row found
            by ``get_instance``; that target's attributes are updated from
            ``data``. If there is no target, a new model instance is created.

            :param data: Deserialized data.
            :return: ``data`` itself, or a model instance.
            """
            if not self._load_instance:
                return data
            # Compare against None instead of relying on truthiness: a model that
            # defines ``__bool__``/``__len__`` can be falsy yet still be the target.
            instance = self.instance
            if instance is None:
                instance = self.get_instance(data)
            if instance is not None:
                for key, value in data.items():
                    setattr(instance, key, value)
                return instance
            model_kwargs, association_attrs = self._split_model_kwargs_association(data)
            instance = self.opts.model(**model_kwargs)
            # Association proxies are assigned after construction so that their
            # intermediate relationship has already been populated.
            for attr, value in association_attrs.items():
                setattr(instance, attr, value)
            return instance

        def load(self, data, *, session=None, instance=None, transient=False, **kwargs):
            """Deserialize data to internal representation.

            ``session``, ``instance`` and ``transient`` apply to this call only. The
            schema's previous values are restored afterwards, even if an error is
            raised. Previously, ``transient=True`` stuck to the schema forever, and
            an ``instance`` passed to the constructor was discarded after one load.

            :param data: Data to deserialize.
            :param session: Optional SQLAlchemy session.
            :param instance: Optional existing instance to modify.
            :param transient: Optional switch to allow transient instantiation.
            :raises ValueError: If instance loading is enabled but the schema is
                neither transient nor has a session.
            """
            saved_state = (self._session, self._transient, self.instance)
            try:
                self._session = session or self._session
                self._transient = transient or self._transient
                if instance is not None:
                    self.instance = instance
                if self._load_instance and not (self.transient or self.session):
                    raise ValueError("Deserialization requires a session")
                return super().load(data, **kwargs)
            finally:
                self._session, self._transient, self.instance = saved_state

        def validate(self, data, *, session=None, **kwargs):
            """Validate ``data`` using the parent schema's ``validate``.

            ``session`` applies to this call only; the previous session is
            restored afterwards.

            :param data: Data to validate.
            :param session: Optional SQLAlchemy session.
            :raises ValueError: If the schema is neither transient nor has a session.
            :return: The result of the parent ``validate`` call.
            """
            saved_session = self._session
            try:
                self._session = session or self._session
                if not (self.transient or self.session):
                    raise ValueError("Validation requires a session")
                return super().validate(data, **kwargs)
            finally:
                self._session = saved_session

        def _split_model_kwargs_association(self, data):
            """Split deserialized attrs into constructor kwargs and association proxies.

            Association proxies (model attributes exposing ``remote_attr``) are
            returned separately. The caller can then assign them after the model
            is constructed, once their intermediate relationship exists (unless
            their ``creator`` has been set).

            Keys that are not attributes of the model are dropped here; handling
            of unknown fields belongs elsewhere.

            :param data: Deserialized dictionary of attrs to split on association_proxy.
            :return: ``(kwargs, association_attrs)`` dictionaries.
            """
            model = self.opts.model
            kwargs = {}
            association_attrs = {}
            for key, value in data.items():
                if hasattr(getattr(model, key, None), "remote_attr"):
                    association_attrs[key] = value
                elif hasattr(model, key):
                    kwargs[key] = value
            return kwargs, association_attrs
