"""Stores for connecting to AWS data."""

import threading
import warnings
import zlib
from collections.abc import Iterator
from concurrent.futures import wait
from concurrent.futures.thread import ThreadPoolExecutor
from hashlib import sha1
from io import BytesIO
from json import dumps
from typing import Any, Callable, Optional, Union

import msgpack  # type: ignore
from monty.msgpack import default as monty_default

from maggma.core import Sort, Store
from maggma.stores.ssh_tunnel import SSHTunnel
from maggma.utils import grouper, to_isoformat_ceil_ms

try:
    import boto3
    import botocore
    from boto3.session import Session
    from botocore.exceptions import ClientError
except (ImportError, ModuleNotFoundError):
    boto3 = None  # type: ignore


class S3Store(Store):
    """
    GridFS-like storage using Amazon S3 and a regular store for indexing.

    Assumes the AWS access key and secret access key are set in the environment or
    in the default AWS config file.
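
    Example:
        A minimal usage sketch; the bucket name and index store are illustrative,
        and the bucket is assumed to already exist with credentials configured as
        described above::

            from maggma.stores import MemoryStore, S3Store

            index = MemoryStore(collection_name="s3_index", key="fs_id")
            store = S3Store(index=index, bucket="my-bucket", compress=True)

            store.connect()
            store.update([{"fs_id": "doc-1", "data": [1, 2, 3]}])
            doc = next(store.query({"fs_id": "doc-1"}))
            store.close()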
    """

    def __init__(
        self,
        index: Store,
        bucket: str,
        s3_profile: Optional[Union[str, dict]] = None,
        compress: bool = False,
        endpoint_url: Optional[str] = None,
        sub_dir: Optional[str] = None,
        s3_workers: int = 1,
        s3_resource_kwargs: Optional[dict] = None,
        ssh_tunnel: Optional[SSHTunnel] = None,
        key: str = "fs_id",
        store_hash: bool = True,
        unpack_data: bool = True,
        searchable_fields: Optional[list[str]] = None,
        index_store_kwargs: Optional[dict] = None,
        **kwargs,
    ):
        """
        Initializes an S3 Store.

        Args:
            index: a store to use to index the S3 bucket.
            bucket: name of the bucket.
            s3_profile: name of AWS profile containing the credentials. Alternatively
                you can pass in a dictionary with the full credentials:
                    aws_access_key_id (string) -- AWS access key ID
                    aws_secret_access_key (string) -- AWS secret access key
                    aws_session_token (string) -- AWS temporary session token
                    region_name (string) -- Default region when creating new connections
            compress: compress files inserted into the store.
            endpoint_url: the endpoint URL, e.g. for interfacing with a MinIO
                service; ignored if `ssh_tunnel` is provided, in which case the URL
                is inferred from the tunnel's local address.
            sub_dir: subdirectory of the S3 bucket to store the data.
            s3_workers: number of concurrent S3 puts to run.
            s3_resource_kwargs: additional kwargs to pass to the boto3 session resource.
            ssh_tunnel: optional SSH tunnel to use for the S3 connection.
            key: main key to index on.
            store_hash: store the SHA1 hash of the serialized (and possibly
                compressed) data in the index document right before insertion.
            unpack_data: whether to decompress and unpack byte data when querying from
                the bucket.
            searchable_fields: fields to keep in the index store.
            index_store_kwargs: kwargs to pass to the index store. Allows the user to
                use kwargs here to update the index store.
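
        Example:
            A sketch of passing explicit credentials as a dictionary instead of a
            profile name (the values and the `index` store are placeholders)::

                store = S3Store(
                    index=index,
                    bucket="my-bucket",
                    s3_profile={
                        "aws_access_key_id": "ACCESS_KEY",
                        "aws_secret_access_key": "SECRET_KEY",
                        "region_name": "us-east-1",
                    },
                )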
        """
        if boto3 is None:
            raise RuntimeError("boto3 and botocore are required for S3Store")
        self.index_store_kwargs = index_store_kwargs or {}
        if index_store_kwargs:
            d_ = index.as_dict()
            d_.update(index_store_kwargs)
            self.index = index.__class__.from_dict(d_)
        else:
            self.index = index
        self.bucket = bucket
        self.s3_profile = s3_profile
        self.compress = compress
        self.endpoint_url = endpoint_url
        self.sub_dir = sub_dir.strip("/") + "/" if sub_dir else ""
        self.s3: Any = None
        self.s3_bucket: Any = None
        self.s3_workers = s3_workers
        self.s3_resource_kwargs = s3_resource_kwargs if s3_resource_kwargs is not None else {}
        self.ssh_tunnel = ssh_tunnel
        self.unpack_data = unpack_data
        self.searchable_fields = searchable_fields if searchable_fields is not None else []
        self.store_hash = store_hash

        # Force the key to be the same as the index
        assert isinstance(index.key, str), "Since we are using the key as a file name in S3, the key must be a string"
        if key != index.key:
            warnings.warn(
                f'The desired S3Store key "{key}" does not match the index key "{index.key}"; '
                "the index key will be used.",
                UserWarning,
            )
        kwargs["key"] = str(index.key)

        self._thread_local = threading.local()
        super().__init__(**kwargs)

    @property
    def name(self) -> str:
        """String representing this data source."""
        return f"s3://{self.bucket}"

    def connect(self, force_reset: bool = False):  # lgtm[py/conflicting-attributes]
        """Connect to the source data.

        Args:
            force_reset: whether to force a reset of the connection
        """
        if self.s3 is None or force_reset:
            self.s3, self.s3_bucket = self._get_resource_and_bucket()
        self.index.connect(force_reset=force_reset)

    def close(self):
        """Closes any connections."""
        self.index.close()

        self.s3.meta.client.close()
        self.s3 = None
        self.s3_bucket = None

        if self.ssh_tunnel is not None:
            self.ssh_tunnel.stop()

    @property
    def _collection(self):
        """
        A handle to the pymongo collection object.

        Important:
            Not guaranteed to exist in the future.
        """
        # For now returns the index collection since that is what we would "search" on
        return self.index._collection

    def count(self, criteria: Optional[dict] = None) -> int:
        """
        Counts the number of documents matching the query criteria.

        Args:
            criteria: PyMongo filter for documents to count in.
        """
        return self.index.count(criteria)

    def query(
        self,
        criteria: Optional[dict] = None,
        properties: Union[dict, list, None] = None,
        sort: Optional[dict[str, Union[Sort, int]]] = None,
        skip: int = 0,
        limit: int = 0,
    ) -> Iterator[dict]:
        """
        Queries the Store for a set of documents.

        Args:
            criteria: PyMongo filter for documents to search in.
            properties: properties to return in the documents.
            sort: Dictionary of sort order for fields. Keys are field names and values
                are 1 for ascending or -1 for descending.
            skip: number of documents to skip.
            limit: limit on total number of documents returned.
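
        Example:
            A sketch of projecting fields that live in the index (`obj_hash` is
            present when `store_hash=True`). When every requested property is
            found in the index document, the S3 object itself is not fetched::

                for doc in store.query(
                    criteria={"fs_id": "doc-1"},
                    properties=["fs_id", "obj_hash"],
                ):
                    print(doc)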

        """
        prop_keys = set()
        if isinstance(properties, dict):
            prop_keys = set(properties.keys())
        elif isinstance(properties, list):
            prop_keys = set(properties)

        for doc in self.index.query(criteria=criteria, sort=sort, limit=limit, skip=skip):
            if properties is not None and prop_keys.issubset(set(doc.keys())):
                yield {p: doc[p] for p in properties if p in doc}
            else:
                try:
                    # TODO: This is ugly and unsafe, do some real checking before pulling data
                    data = self.s3_bucket.Object(self._get_full_key_path(doc[self.key])).get()["Body"].read()
                except botocore.exceptions.ClientError as e:
                    # If a client error is thrown, then check that it was a NoSuchKey or NoSuchBucket error.
                    # If it was a NoSuchKey error, then the object does not exist.
                    error_code = e.response["Error"]["Code"]
                    if error_code in ["NoSuchKey", "NoSuchBucket"]:
                        error_message = e.response["Error"]["Message"]
                        self.logger.error(
                            f"S3 returned '{error_message}' while querying '{self.bucket}' for '{doc[self.key]}'"
                        )
                        continue
                    else:
                        raise e

                if self.unpack_data:
                    data = self._read_data(data=data, compress_header=doc.get("compression", ""))

                    if self.last_updated_field in doc:
                        data[self.last_updated_field] = doc[self.last_updated_field]

                yield data

    def _read_data(self, data: bytes, compress_header: str) -> dict:
        """Reads the data and transforms it into a dictionary.
        Allows for subclasses to apply custom schemes for transforming
        the data retrieved from S3.

        Args:
            data (bytes): The raw byte representation of the data.
            compress_header (str): String representing the type of compression used on the data.

        Returns:
            Dict: Dictionary representation of the data.
        """
        return self._unpack(data=data, compressed=compress_header == "zlib")

    @staticmethod
    def _unpack(data: bytes, compressed: bool):
        if compressed:
            data = zlib.decompress(data)
        # requires msgpack-python to be installed to fix string encoding problem
        # https://github.com/msgpack/msgpack/issues/121
        # During recursion
        # msgpack.unpackb goes as deep as possible during reconstruction
        # MontyDecoder().process_decode only goes until it finds a from_dict
        # as such, we cannot just use msgpack.unpackb(data, object_hook=monty_object_hook, raw=False)
        # Should just return the unpacked object then let the user run process_decoded
        return msgpack.unpackb(data, raw=False)

    def distinct(self, field: str, criteria: Optional[dict] = None, all_exist: bool = False) -> list:
        """
        Get all distinct values for a field.

        Args:
            field: the field to get distinct values for.
            criteria: PyMongo filter for documents to search in.
        """
        # Index is a store so it should have its own distinct function
        return self.index.distinct(field, criteria=criteria)

    def groupby(
        self,
        keys: Union[list[str], str],
        criteria: Optional[dict] = None,
        properties: Union[dict, list, None] = None,
        sort: Optional[dict[str, Union[Sort, int]]] = None,
        skip: int = 0,
        limit: int = 0,
    ) -> Iterator[tuple[dict, list[dict]]]:
        """
        Simple grouping function that will group documents by keys.

        Args:
            keys: fields to group documents by.
            criteria: PyMongo filter for documents to search in.
            properties: properties to return in grouped documents.
            sort: Dictionary of sort order for fields. Keys are field names and values
                are 1 for ascending or -1 for descending.
            skip: number of documents to skip.
            limit: limit on total number of documents returned.

        Returns:
            generator returning tuples of (dict, list of docs)
        """
        return self.index.groupby(
            keys=keys,
            criteria=criteria,
            properties=properties,
            sort=sort,
            skip=skip,
            limit=limit,
        )

    def ensure_index(self, key: str, unique: bool = False) -> bool:
        """
        Tries to create an index and return true if it succeeded.

        Args:
            key: single key to index.
            unique: whether this index contains only unique keys.

        Returns:
            bool indicating if the index exists/was created.
        """
        return self.index.ensure_index(key, unique=unique)

    def update(
        self,
        docs: Union[list[dict], dict],
        key: Union[list, str, None] = None,
        additional_metadata: Union[str, list[str], None] = None,
    ):
        """
        Update documents into the Store.

        Args:
            docs: the document or list of documents to update.
            key: field name(s) to determine uniqueness for a document, can be a list of
                multiple fields, a single field, or None if the Store's key field is to
                be used.
            additional_metadata: field(s) to include in the S3 store's metadata.
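
        Example:
            A sketch of writing a document while copying an extra field into both
            the index and the S3 object metadata (field names are illustrative)::

                store.update(
                    [{"fs_id": "doc-1", "task_id": "mp-1", "data": {"a": 1}}],
                    additional_metadata=["task_id"],
                )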
        """
        if not isinstance(docs, list):
            docs = [docs]

        if isinstance(key, str):
            key = [key]
        elif not key:
            key = [self.key]

        if additional_metadata is None:
            additional_metadata = []
        elif isinstance(additional_metadata, str):
            additional_metadata = [additional_metadata]
        else:
            additional_metadata = list(additional_metadata)

        self._write_to_s3_and_index(docs, key + additional_metadata + self.searchable_fields)

    def _write_to_s3_and_index(self, docs: list[dict], search_keys: list[str]):
        """Implements updating of the provided documents in S3 and the index.
        Allows for subclasses to apply custom approaches to parellizing the writing.

        Args:
            docs (List[Dict]): The documents to update
            search_keys (List[str]): The keys of the information to be updated in the index
        """
        with ThreadPoolExecutor(max_workers=self.s3_workers) as pool:
            fs = {
                pool.submit(
                    self.write_doc_to_s3,
                    doc=itr_doc,
                    search_keys=search_keys,
                )
                for itr_doc in docs
            }
            fs, _ = wait(fs)

            search_docs = [sdoc.result() for sdoc in fs]

        # Use store's update to remove key clashes
        self.index.update(search_docs, key=self.key)

    def _get_session(self):
        if self.ssh_tunnel is not None:
            self.ssh_tunnel.start()

        if not hasattr(self._thread_local, "s3_bucket"):
            if isinstance(self.s3_profile, dict):
                return Session(**self.s3_profile)
            return Session(profile_name=self.s3_profile)

        return None

    def _get_endpoint_url(self):
        if self.ssh_tunnel is None:
            return self.endpoint_url
        host, port = self.ssh_tunnel.local_address
        return f"http://{host}:{port}"

    def _get_bucket(self):
        """If on the main thread return the bucket created above, else create a new
        bucket on each thread.
        """
        if threading.current_thread().name == "MainThread":
            return self.s3_bucket

        if not hasattr(self._thread_local, "s3_bucket"):
            _, bucket = self._get_resource_and_bucket()
            self._thread_local.s3_bucket = bucket

        return self._thread_local.s3_bucket

    def _get_resource_and_bucket(self):
        """Helper function to create the resource and bucket objects."""
        session = self._get_session()
        endpoint_url = self._get_endpoint_url()
        resource = session.resource("s3", endpoint_url=endpoint_url, **self.s3_resource_kwargs)
        try:
            resource.meta.client.head_bucket(Bucket=self.bucket)
        except ClientError:
            raise RuntimeError("Bucket not present on AWS")
        bucket = resource.Bucket(self.bucket)

        return resource, bucket

    def _get_full_key_path(self, id: str) -> str:
        """Produces the full key path for S3 items.

        Args:
            id (str): The value of the key identifier.

        Returns:
            str: The full key path
        """
        return self.sub_dir + str(id)

    def _get_compression_function(self) -> Callable:
        """Returns the function to use for compressing data."""
        return zlib.compress

    def _get_decompression_function(self) -> Callable:
        """Returns the function to use for decompressing data."""
        return zlib.decompress

    def write_doc_to_s3(self, doc: dict, search_keys: list[str]) -> dict:
        """
        Write the data to s3 and return the metadata to be inserted into the index db.

        Args:
            doc: the document.
            search_keys: list of keys to pull from the docs and be inserted into the
                index db.

        Returns:
            Dict: The metadata to be inserted into the index db
        """
        s3_bucket = self._get_bucket()

        search_doc = {k: doc[k] for k in search_keys}
        search_doc[self.key] = doc[self.key]  # Ensure key is in metadata
        if self.sub_dir != "":
            search_doc["sub_dir"] = self.sub_dir

        # Remove MongoDB _id from search
        if "_id" in search_doc:
            del search_doc["_id"]

        # to make hashing more meaningful, make sure last updated field is removed
        lu_info = doc.pop(self.last_updated_field, None)
        data = msgpack.packb(doc, default=monty_default)

        if self.compress:
            # Compress with zlib if chosen
            search_doc["compression"] = "zlib"
            data = self._get_compression_function()(data)

        # keep a record of original keys, in case these are important for the individual researcher
        # it is not expected that this information will be used except in disaster recovery
        s3_to_mongo_keys = {k: self._sanitize_key(k) for k in search_doc}
        s3_to_mongo_keys["s3-to-mongo-keys"] = "s3-to-mongo-keys"  # inception
        # encode dictionary since values have to be strings
        search_doc["s3-to-mongo-keys"] = dumps(s3_to_mongo_keys)
        s3_bucket.upload_fileobj(
            Fileobj=BytesIO(data),
            Key=self._get_full_key_path(str(doc[self.key])),
            ExtraArgs={"Metadata": {s3_to_mongo_keys[k]: str(v) for k, v in search_doc.items()}},
        )

        if lu_info is not None:
            search_doc[self.last_updated_field] = lu_info

        if self.store_hash:
            hasher = sha1()
            hasher.update(data)
            obj_hash = hasher.hexdigest()
            search_doc["obj_hash"] = obj_hash
        return search_doc

    @staticmethod
    def _sanitize_key(key):
        """Sanitize keys to store in S3/MinIO metadata."""
        # Any underscores are encoded as dashes in metadata, since keys with
        # underscores may result in the corresponding HTTP header being stripped
        # by certain server configurations (e.g. default nginx), leading to:
        # `botocore.exceptions.ClientError: An error occurred (AccessDenied) when
        # calling the PutObject operation: There were headers present in the request
        # which were not signed`
        # Metadata stored in the MongoDB index (self.index) is stored unchanged.

        # Additionally, MinIO requires lowercase keys
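        # e.g. "last_updated" -> "last-updated"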
        return str(key).replace("_", "-").lower()

    def remove_docs(self, criteria: dict, remove_s3_object: bool = False):
        """
        Remove docs matching the query dictionary.

        Args:
            criteria: query dictionary to match.
            remove_s3_object: whether to remove the actual S3 object or not.
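
        Example:
            A sketch of removing both the index entries and the underlying S3
            objects for a single key (the key value is illustrative)::

                store.remove_docs({"fs_id": "doc-1"}, remove_s3_object=True)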
        """
        if not remove_s3_object:
            self.index.remove_docs(criteria=criteria)
        else:
            to_remove = self.index.distinct(self.key, criteria=criteria)
            self.index.remove_docs(criteria=criteria)

            # Can remove up to 1000 items at a time via boto
            to_remove_chunks = list(grouper(to_remove, n=1000))
            for chunk_to_remove in to_remove_chunks:
                objlist = [{"Key": self._get_full_key_path(obj)} for obj in chunk_to_remove]
                self.s3_bucket.delete_objects(Delete={"Objects": objlist})

    @property
    def last_updated(self):
        """The last updated timestamp, taken from the index store."""
        return self.index.last_updated

    def newer_in(self, target: Store, criteria: Optional[dict] = None, exhaustive: bool = False) -> list[str]:
        """
        Returns the keys of documents that are newer in the target Store than this Store.

        Args:
            target: target Store.
            criteria: PyMongo filter for documents to search in.
            exhaustive: triggers an item-by-item check vs. checking the last_updated of
                the target Store and using that to filter out new items.
        """
        if hasattr(target, "index"):
            return self.index.newer_in(target=target.index, criteria=criteria, exhaustive=exhaustive)
        return self.index.newer_in(target=target, criteria=criteria, exhaustive=exhaustive)

    def __hash__(self):
        return hash((self.index.__hash__(), self.bucket))

    def rebuild_index_from_s3_data(self, **kwargs):
        """
        Rebuilds the index Store from the data in S3.

        Reads each object in the bucket, unpacks it, and re-inserts it through
        `update`, which regenerates the index documents. This can help recover a
        lost index database.
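
        Example:
            A sketch of recovering an index for an existing bucket (assumes the
            objects were written by an S3Store with the same `compress` setting)::

                store.connect()
                store.rebuild_index_from_s3_data()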
        """
        bucket = self.s3_bucket
        objects = bucket.objects.filter(Prefix=self.sub_dir)
        for obj in objects:
            # obj.key is already the full key path (it includes the sub_dir prefix)
            data = self.s3_bucket.Object(obj.key).get()["Body"].read()

            if self.compress:
                # Assumes the objects were written with the same compress setting as this store
                data = self._get_decompression_function()(data)
            unpacked_data = msgpack.unpackb(data, raw=False)
            self.update(unpacked_data, **kwargs)

    def rebuild_metadata_from_index(self, index_query: Optional[dict] = None):
        """
        Read data from the index store and populate the metadata of the S3 bucket.
        Forces all keys to lowercase for MinIO compatibility.

        Args:
            index_query: query on the index store.
        """
        qq = {} if index_query is None else index_query
        for index_doc in self.index.query(qq):
            key_ = self._get_full_key_path(index_doc[self.key])
            s3_object = self.s3_bucket.Object(key_)
            new_meta = {self._sanitize_key(k): v for k, v in s3_object.metadata.items()}
            for k, v in index_doc.items():
                new_meta[str(k).lower()] = v
            new_meta.pop("_id")
            if self.last_updated_field in new_meta:
                new_meta[self.last_updated_field] = str(to_isoformat_ceil_ms(new_meta[self.last_updated_field]))
            # S3 object metadata cannot be modified in place; copy the object onto
            # itself with MetadataDirective="REPLACE" to rewrite it
            s3_object.copy_from(
                CopySource={"Bucket": self.s3_bucket.name, "Key": key_},
                Metadata=new_meta,
                MetadataDirective="REPLACE",
            )

    def __eq__(self, other: object) -> bool:
        """
        Check equality for S3Store.

        other: other S3Store to compare with.
        """
        if not isinstance(other, S3Store):
            return False

        fields = ["index", "bucket", "last_updated_field"]
        return all(getattr(self, f) == getattr(other, f) for f in fields)
