﻿#------------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. 
# All rights reserved.
# 
# This code is licensed under the MIT License.
# 
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files(the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions :
# 
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#------------------------------------------------------------------------------

import base64
import copy
import hashlib
from datetime import datetime, timedelta
from dateutil import parser

from .adal_error import AdalError
from .constants import TokenResponseFields, Misc
from . import log

# Suppress warnings such as access to a protected member like "_AUTHORITY".
# pylint: disable=W0212

def _create_token_hash(token):
    """Return the base64-encoded SHA-256 digest of ``token``.

    ``token`` is encoded as UTF-8 before hashing. The return value is
    whatever ``base64.b64encode`` produces for the raw digest.
    """
    digest = hashlib.sha256(token.encode('utf8')).digest()
    return base64.b64encode(digest)

def _create_token_id_message(entry):
    """Build a log message identifying the tokens in ``entry`` by hash.

    The raw tokens are never placed in the message; only the values returned
    by ``_create_token_hash`` are.

    :param entry: mapping that must contain the access token field and may
        contain a refresh token field.
    :returns: ``'AccessTokenId: <hash>'``, followed by
        ``', RefreshTokenId: <hash>'`` when the entry holds a non-empty
        refresh token.
    :raises KeyError: if the entry has no access token field.
    """
    def _hash_as_text(token):
        token_hash = _create_token_hash(token)
        # base64.b64encode returns bytes on Python 3, and str() on bytes
        # yields the "b'...'" repr. Decode so the message carries only the
        # hash text itself. On Python 2 the hash is already a str.
        if not isinstance(token_hash, str):
            token_hash = token_hash.decode('ascii')
        return token_hash

    message = 'AccessTokenId: ' + _hash_as_text(
        entry[TokenResponseFields.ACCESS_TOKEN])
    if entry.get(TokenResponseFields.REFRESH_TOKEN):
        message += ', RefreshTokenId: ' + _hash_as_text(
            entry[TokenResponseFields.REFRESH_TOKEN])
    return message

def _is_mrrt(entry):
    """Return True when ``entry`` carries a truthy resource field."""
    resource = entry.get(TokenResponseFields.RESOURCE)
    return True if resource else False

def _entry_has_metadata(entry):
    """Return True when ``entry`` has both the client id and authority keys."""
    required_keys = (TokenResponseFields._CLIENT_ID,
                     TokenResponseFields._AUTHORITY)
    return all(key in entry for key in required_keys)


class CacheDriver(object):
    """Find, refresh and store token entries for one authority/resource/client.

    The driver sits between a token cache (any object exposing ``find``,
    ``add`` and ``remove``) and a refresh callable. ``find`` returns a cached
    entry for the configured resource, refreshing it through the refresh
    callable when it is about to expire, or deriving it from a cached
    multi-resource refresh token (MRRT) entry when no resource specific entry
    exists. ``add`` stamps new entries with the client id and authority
    before storing them.
    """

    def __init__(self, call_context, authority, resource, client_id, cache,
                 refresh_function):
        """Create a driver bound to one authority, resource and client id.

        :param call_context: mapping whose ``'log_context'`` item is handed
            to the logger.
        :param authority: authority value compared against the authority
            field of cached entries.
        :param resource: resource identifier tokens are looked up for.
        :param client_id: client id stamped onto added entries.
        :param cache: token cache providing ``find(query)``,
            ``add(entries)`` and ``remove(entries)``.
        :param refresh_function: callable invoked as
            ``refresh_function(entry, resource)``; its result is merged into
            the entry as the refreshed token response.
        """
        self._call_context = call_context
        self._log = log.Logger("CacheDriver", call_context['log_context'])
        self._authority = authority
        self._resource = resource
        self._client_id = client_id
        self._cache = cache
        self._refresh_function = refresh_function

    def _get_potential_entries(self, query):
        """Return cache entries matching the client id / user id in ``query``.

        Only the client id and user id fields of ``query`` are forwarded to
        the cache, and only when they are truthy.
        """
        potential_entries_query = {}
        for field in (TokenResponseFields._CLIENT_ID,
                      TokenResponseFields.USER_ID):
            if query.get(field):
                potential_entries_query[field] = query[field]

        self._log.debug(
            'Looking for potential cache entries: %(query)s',
            {"query": log.scrub_pii(potential_entries_query)})
        entries = self._cache.find(potential_entries_query)
        self._log.debug(
            'Found %(quantity)s potential entries.', {"quantity": len(entries)})
        return entries

    def _find_mrrt_tokens_for_user(self, user):
        """Return cached MRRT entries for ``user`` under this client id."""
        return self._cache.find({
            TokenResponseFields.IS_MRRT: True,
            TokenResponseFields.USER_ID: user,
            TokenResponseFields._CLIENT_ID: self._client_id
            })

    def _load_single_entry_from_cache(self, query):
        """Pick the single best cache entry for ``query``.

        An entry whose resource and authority both match this driver is
        preferred. If there is none, the first MRRT entry among the
        candidates is used instead.

        :returns: tuple ``(entry, is_resource_tenant_specific)``. ``entry``
            is an empty list when nothing suitable was found.
        :raises AdalError: if more than one entry matches both the resource
            and the authority.
        """
        return_val = []
        is_resource_tenant_specific = False

        potential_entries = self._get_potential_entries(query)
        if potential_entries:
            resource_tenant_specific_entries = [
                x for x in potential_entries
                if x[TokenResponseFields.RESOURCE] == self._resource and
                x[TokenResponseFields._AUTHORITY] == self._authority]

            if not resource_tenant_specific_entries:
                self._log.debug('No resource specific cache entries found.')

                # There are no resource specific entries. Find an MRRT token.
                # Entries added as non-MRRT never receive the IS_MRRT key, so
                # read it with .get() instead of indexing.
                mrrt_tokens = (x for x in potential_entries
                               if x.get(TokenResponseFields.IS_MRRT))
                token = next(mrrt_tokens, None)
                if token:
                    self._log.debug('Found an MRRT token.')
                    return_val = token
                else:
                    self._log.debug('No MRRT tokens found.')
            elif len(resource_tenant_specific_entries) == 1:
                self._log.debug('Resource specific token found.')
                return_val = resource_tenant_specific_entries[0]
                is_resource_tenant_specific = True
            else:
                raise AdalError('More than one token matches the criteria. The result is ambiguous.')

        if return_val:
            self._log.debug('Returning token from cache lookup, %(token_hash)s',
                            {"token_hash": _create_token_id_message(return_val)})

        return return_val, is_resource_tenant_specific

    def _create_entry_from_refresh(self, entry, refresh_response):
        """Return a new entry: a deep copy of ``entry`` updated with the response.

        ``entry`` itself is left unmodified.
        """
        new_entry = copy.deepcopy(entry)
        new_entry.update(refresh_response)

        # It is possible the response payload has no 'resource' field, like in ADFS, so we manually
        # fill it here. Note, 'resource' is part of the token cache key, so we have to set it to avoid
        # corrupting the cache.
        if 'resource' not in refresh_response:
            new_entry['resource'] = self._resource

        # Non-MRRT entries carry no IS_MRRT key at all, so use .get() here;
        # indexing would raise KeyError when such an entry is refreshed.
        if (entry.get(TokenResponseFields.IS_MRRT) and
                self._authority != entry[TokenResponseFields._AUTHORITY]):
            new_entry[TokenResponseFields._AUTHORITY] = self._authority

        self._log.debug('Created new cache entry from refresh response.')
        return new_entry

    def _replace_entry(self, entry_to_replace, new_entry):
        """Remove ``entry_to_replace`` from the cache and add ``new_entry``."""
        self.remove(entry_to_replace)
        self.add(new_entry)

    def _refresh_expired_entry(self, entry):
        """Refresh ``entry``, swap it in the cache and return the new entry."""
        token_response = self._refresh_function(entry, None)
        new_entry = self._create_entry_from_refresh(entry, token_response)
        self._replace_entry(entry, new_entry)
        self._log.info('Returning token refreshed after expiry.')
        return new_entry

    def _acquire_new_token_from_mrrt(self, entry):
        """Use the MRRT ``entry`` to obtain and cache a token for this resource.

        The MRRT entry stays in the cache; the new entry is added next to it.
        """
        token_response = self._refresh_function(entry, self._resource)
        new_entry = self._create_entry_from_refresh(entry, token_response)
        self.add(new_entry)
        self._log.info('Returning token derived from mrrt refresh.')
        return new_entry

    def _refresh_entry_if_necessary(self, entry, is_resource_specific):
        """Return a usable entry for ``entry``, refreshing it when required.

        * A resource specific entry within the clock buffer of its expiry is
          refreshed.
        * A non resource specific MRRT entry is used to acquire a new token
          for this driver's resource.
        * In both cases an entry without a refresh token is removed from the
          cache and ``None`` is returned.
        * Any other entry is returned unchanged.
        """
        expiry_date = parser.parse(entry[TokenResponseFields.EXPIRES_ON])
        # Use the expiry's own tzinfo so the two datetimes are comparable.
        now = datetime.now(expiry_date.tzinfo)

        # Add some buffer in to the time comparison to account for clock skew or latency.
        now_plus_buffer = now + timedelta(minutes=Misc.CLOCK_BUFFER)

        if is_resource_specific and now_plus_buffer > expiry_date:
            if TokenResponseFields.REFRESH_TOKEN in entry:
                self._log.info('Cached token is expired at %(date)s.  Refreshing',
                               {"date": expiry_date})
                return self._refresh_expired_entry(entry)
            else:
                self.remove(entry)
                return None
        elif not is_resource_specific and entry.get(TokenResponseFields.IS_MRRT):
            if TokenResponseFields.REFRESH_TOKEN in entry:
                self._log.info('Acquiring new access token from MRRT token.')
                return self._acquire_new_token_from_mrrt(entry)
            else:
                self.remove(entry)
                return None
        else:
            return entry

    def find(self, query):
        """Return a usable cache entry matching ``query``, or ``None``.

        :param query: mapping with optional client id / user id fields;
            ``None`` is treated as an empty query.
        :raises AdalError: if several cached entries match this driver's
            resource and authority.
        """
        if query is None:
            query = {}
        self._log.debug('finding with query keys: %(query)s',
                        {"query": log.scrub_pii(query)})
        entry, is_resource_tenant_specific = self._load_single_entry_from_cache(query)
        if entry:
            return self._refresh_entry_if_necessary(entry,
                                                    is_resource_tenant_specific)
        else:
            return None

    def remove(self, entry):
        """Remove a single ``entry`` from the underlying cache."""
        self._log.debug('Removing entry.')
        self._cache.remove([entry])

    def _remove_many(self, entries):
        """Remove all ``entries`` from the underlying cache."""
        self._log.debug('Remove many: %(number)s', {"number": len(entries)})
        self._cache.remove(entries)

    def _add_many(self, entries):
        """Add all ``entries`` to the underlying cache."""
        self._log.debug('Add many: %(number)s', {"number": len(entries)})
        self._cache.add(entries)

    def _update_refresh_tokens(self, entry):
        """Copy the refresh token of ``entry`` onto the user's cached MRRT entries.

        Does nothing unless ``entry`` is an MRRT entry with a refresh token.
        """
        if _is_mrrt(entry) and entry.get(TokenResponseFields.REFRESH_TOKEN):
            mrrt_tokens = self._find_mrrt_tokens_for_user(entry.get(TokenResponseFields.USER_ID))
            if mrrt_tokens:
                self._log.debug('Updating %(number)s cached refresh tokens',
                                {"number": len(mrrt_tokens)})
                # Remove before mutating, then re-add the updated entries.
                self._remove_many(mrrt_tokens)

                for t in mrrt_tokens:
                    t[TokenResponseFields.REFRESH_TOKEN] = entry[TokenResponseFields.REFRESH_TOKEN]

                self._add_many(mrrt_tokens)

    def _argument_entry_with_cached_metadata(self, entry):
        """Stamp ``entry`` in place with MRRT/resource, client id and authority.

        Entries that already carry both the client id and authority keys are
        left untouched. Note that only MRRT entries get the IS_MRRT key;
        other entries get this driver's resource instead.
        """
        if _entry_has_metadata(entry):
            return

        if _is_mrrt(entry):
            self._log.debug('Added entry is MRRT')
            entry[TokenResponseFields.IS_MRRT] = True
        else:
            entry[TokenResponseFields.RESOURCE] = self._resource

        entry[TokenResponseFields._CLIENT_ID] = self._client_id
        entry[TokenResponseFields._AUTHORITY] = self._authority

    def add(self, entry):
        """Stamp ``entry`` with metadata and store it in the cache.

        Before the entry is stored, existing MRRT entries for the same user
        are updated with its refresh token when it is an MRRT entry.
        ``entry`` is modified in place.
        """
        self._log.debug('Adding entry %(token_hash)s',
                        {"token_hash": _create_token_id_message(entry)})
        self._argument_entry_with_cached_metadata(entry)
        self._update_refresh_tokens(entry)
        self._cache.add([entry])
