#!/usr/bin/python
# vim: ts=4 sts=4 et:
# pylint: disable=invalid-name,line-too-long
"""
OpenSSH AuthorizedKeysCommand: NSSCache input
Copyright 2016 Gentoo Foundation
Written by Robin H. Johnson <robbat2@gentoo.org>
Distributed under the BSD-3 license.

This script returns one or more authorized keys for use by SSH, by extracting
them from a local cache file /etc/sshkey.cache.

Two variants are supported, based on the existing nsscache code:
Format 1:
 username:key1
 username:key2
Format 2:
 username:['key1', 'key2']

Ensure this script is mentioned in the sshd_config like so:
AuthorizedKeysCommand /path/to/nsscache/authorized-keys-command.py

If you have a sufficiently new OpenSSH, you can also narrow down the search:
AuthorizedKeysCommand /path/to/nsscache/authorized-keys-command.py --username="%u" --key-type="%t" --key-fingerprint="%f" --key-blob="%k"

Future improvements:
- Validate SSH keys more strictly:
    - validate options string
    - validate X509 cert strings
- Implement command line options to:
    - filter keys based on options better (beyond regex)
    - filter keys based on comments better (beyond regex)
    - filter X509 keys based on DN/subject
    - support multiple inputs for conditions
    - add an advanced conditional filter language
"""
from __future__ import print_function
from ast import literal_eval
import sys
import errno
import argparse
import re
import base64
import hashlib
import copy
import textwrap

DEFAULT_SSHKEY_CACHE = '/etc/sshkey.cache'

REGEX_BASE64 = r'(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?'
# All SSH blobs start with 3 null bytes, which encode to 'AAAA' in base64
REGEX_BASE64_START3NULL = r'AAAA' + REGEX_BASE64
# This regex needs a lot of work
KEYTYPE_REGEX_STRICT = r'\b(?:ssh-(?:rsa|dss|ed25519)|ecdsa-sha2-nistp(?:256|384|521))\b'
# Docs:
# http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xhtml#ssh-parameters-19
# RFC6187, etc
KEYTYPE_REGEX_LAZY_NOX509 = r'\b(?:(?:spki|pgp|x509|x509v3)-)?(?:(?:ssh|sign)-(?:rsa|dss|ed25519)|ecdsa-[0-9a-z-]+|rsa2048-sha256)(?:-cert-v01@openssh\.com|\@ssh\.com)?\b'
KEYTYPE_REGEX_LAZY_X509 = r'\bx509(?:v3)?-(?:(?:ssh|sign)-(?:rsa|dss|ed25519)|ecdsa-[0-9a-z-]+|rsa2048-sha256)(?:-cert-v01@openssh\.com|\@ssh\.com)?\b'


def _case_insensitive(word):
    """
    Return a regex fragment that matches the literal *word* in any letter case.

    This replaces an inline ``(?i)`` flag. Embedded in the middle of KEY_REGEX,
    ``(?i)`` made the *entire* combined pattern case-insensitive (not just the
    DN word), and Python 3.11+ rejects global inline flags that are not at the
    start of the expression. Per-letter character classes behave the same on
    every Python version.
    """
    return ''.join('[%s%s]' % (c.lower(), c.upper()) if c.isalpha() else re.escape(c)
                   for c in word)


# Word introducing an X509 DN/subject; case insensitive!
X509_WORDDN = r'(?:(?:' + _case_insensitive('Distinguished') + r'[ _-]?' + \
    _case_insensitive('Name') + r'|' + _case_insensitive('DN') + r'|' + \
    _case_insensitive('Subject') + r')[=:]?)'
KEY_REGEX = r'(.*)\s*(?:(' + KEYTYPE_REGEX_LAZY_NOX509 + r')\s+(' + REGEX_BASE64_START3NULL + r')\s*(.*)|(' + KEYTYPE_REGEX_LAZY_X509 + r')\s+('+ X509_WORDDN +'.*))'
# Group 1: options
# Branch 1:
#  Group 2: keytype (any, including x509)
#  Group 3: key blob (non-x509), always starts with AAAA (3 nulls in base64), no whitespace!
#  Group 4: comment (non-x509)
# Branch 2:
#  Group 5: keytype (x509)
#  Group 6: x509 WORDDN followed by x509-specific blob or DN, including whitespace
#
# If the keytype is x509v3-*, then the data block can be either a certificate
# specifier or a base64 block (never both).
# The cert specifier is "DN:/OU=.../SN=.../C=.." etc. By implication, this
# EXCLUDES the use of comments, as you CANNOT detect where the DN ends.

def warning(*objs):
    """Print *objs* to standard error on one line, prefixed with "WARNING: "."""
    message_parts = ("WARNING: ",) + objs
    print(*message_parts, file=sys.stderr)

def parse_key(full_key_line):
    """
    Split an authorized_keys line (including any options) into its parts.

    Returns a 4-tuple (options, key_type, blob, comment). For lines matching
    the x509 branch of KEY_REGEX, blob is the DN/subject text and comment is
    always None. If the line does not match KEY_REGEX, a warning is printed
    and (None, None, None, None) is returned.
    """
    found = re.match(KEY_REGEX, full_key_line)
    if found is None:
        warning("Failed to match", full_key_line)
        return (None, None, None, None)
    (options, plain_type, plain_blob, plain_comment,
     x509_type, x509_data) = found.group(1, 2, 3, 4, 5, 6)
    if x509_type is None:
        return (options, plain_type, plain_blob, plain_comment)
    # x509 branch: the DN runs to the end of the line, so there is no comment
    return (options, x509_type, x509_data, None)

def fingerprint_key(keyblob, fingerprint_format='SHA256'):
    """
    Generate an SSH key fingerprint, using the requested format.

    keyblob: base64-encoded key blob. Values that are None or do not start
        with 'AAAA' (e.g. x509 DN strings) are not fingerprinted.
    fingerprint_format: 'MD5' (colon-separated lowercase hex) or one of
        'SHA1', 'SHA256', 'SHA512' (base64 with the '=' padding stripped).

    Returns the fingerprint prefixed with the format name and ':', e.g.
    'SHA256:...' or 'MD5:aa:bb:...'. Returns None if the blob cannot be
    decoded or the format is not recognised.
    """
    # Don't try to fingerprint x509 blobs
    if keyblob is None or not keyblob.startswith('AAAA'):
        return None
    try:
        binary_blob = base64.b64decode(keyblob)
    except (TypeError, ValueError) as e:
        # Python 2 raises TypeError for malformed base64; Python 3 raises
        # binascii.Error, which is a ValueError subclass.
        warning(e, keyblob)
        return None
    if fingerprint_format == 'MD5':
        raw = hashlib.md5(binary_blob).digest()
        # bytearray yields ints on both Python 2 and 3; ord() on the items of
        # a Python 3 bytes object raised TypeError.
        return 'MD5:' + ':'.join('{:02x}'.format(b) for b in bytearray(raw))
    if fingerprint_format in ('SHA256', 'SHA512', 'SHA1'):
        raw = hashlib.new(fingerprint_format.lower(), binary_blob).digest()
        # b64encode returns bytes on Python 3; decode before string handling
        encoded = base64.b64encode(raw).decode('ascii')
        return fingerprint_format + ':' + encoded.rstrip('=')
    return None

def detect_fingerprint_format(fpr):
    """
    Guess which hash format a fingerprint string uses.

    Returns 'SHA256', 'SHA512', 'SHA1' or 'MD5' when *fpr* starts with that
    name followed by ':'. Also returns 'MD5' for bare colon-separated
    lowercase hex pairs. Returns None when *fpr* is None or unrecognised.
    """
    if fpr is None:
        return None
    known_prefixes = ('SHA256', 'SHA512', 'SHA1', 'MD5')
    prefixed = next((name for name in known_prefixes
                     if fpr.startswith(name + ':')), None)
    if prefixed is not None:
        return prefixed
    # Colon-separated lowercase hex pairs (optionally "MD5:"-prefixed) => MD5
    hex_pairs = re.match(r'^(MD5:)?([0-9a-f]{2}:)+[0-9a-f]{2}$', fpr)
    return None if hex_pairs is None else 'MD5'

def validate_key(candidate_key, conditions, strict=False):
    # pylint: disable=invalid-name,line-too-long,too-many-locals
    """
    Validate a potential authorized_key line against multiple conditions.

    candidate_key: one authorized_keys line (options, key type, blob, comment).
    conditions: dict that may contain 'key_type', 'key_blob' and
        'key_fingerprint' (strings), and 'key_comment_re' / 'key_options_re'
        (compiled regexes). Missing or None entries are ignored.
    strict: when true, also require that at least one condition was actually
        evaluated against this key.

    A condition is only evaluated when the corresponding part of the candidate
    is available. For example, x509 lines have no comment, non-'AAAA' blobs
    have no fingerprint, and unparseable lines have no parts at all. Otherwise
    the condition is skipped rather than counted as a failure.

    Returns True if every evaluated condition matched (and, in strict mode,
    at least one condition was evaluated).
    """
    # Explode the key
    (candidate_key_options,
     candidate_key_type,
     candidate_key_blob,
     candidate_key_comment) = parse_key(candidate_key)

    # Set up our conditions with their defaults
    key_type = conditions.get('key_type')
    key_blob = conditions.get('key_blob')
    key_fingerprint = conditions.get('key_fingerprint')
    key_options_re = conditions.get('key_options_re')
    key_comment_re = conditions.get('key_comment_re')

    # Try to detect the fingerprint format
    fingerprint_format = detect_fingerprint_format(key_fingerprint)
    if key_fingerprint is not None:
        if fingerprint_format == 'MD5':
            # Force MD5 prefix on old fingerprints
            if not key_fingerprint.startswith('MD5:'):
                key_fingerprint = 'MD5:' + key_fingerprint
        else:
            # The OpenSSH base64 fingerprints drop the trailing padding;
            # do the same on provided input
            key_fingerprint = key_fingerprint.rstrip('=')
    # Build the fingerprint for the candidate key
    # (fingerprint_key does the padding strip as well)
    candidate_key_fingerprint = fingerprint_key(candidate_key_blob,
                                                fingerprint_format)

    match = True
    strict_pass = False
    if key_type is not None and candidate_key_type is not None:
        strict_pass = True
        match = match and candidate_key_type == key_type
    if key_fingerprint is not None and candidate_key_fingerprint is not None:
        strict_pass = True
        match = match and candidate_key_fingerprint == key_fingerprint
    if key_blob is not None and candidate_key_blob is not None:
        strict_pass = True
        match = match and candidate_key_blob == key_blob
    if key_comment_re is not None and candidate_key_comment is not None:
        strict_pass = True
        match = match and \
            key_comment_re.search(candidate_key_comment) is not None
    # parse_key() yields None options only when the line failed to parse;
    # guard like the other conditions instead of passing None to search(),
    # which raised TypeError.
    if key_options_re is not None and candidate_key_options is not None:
        strict_pass = True
        match = match and \
            key_options_re.search(candidate_key_options) is not None
    if strict:
        return match and strict_pass
    return match

# Short description shown at the top of --help output
PROG_DESC = 'OpenSSH AuthorizedKeysCommand to read from cached keys file'
# Footer of --help output, explaining --strict and the "X" help markers
PROG_EPILOG = textwrap.dedent(
    'Strict match will require that at least one condition matched.\n'
    'Conditions marked with X may not work correctly with X509 authorized_keys lines.\n'
)

# argparse destinations that act as key match conditions
_CONDITION_ARGS = ('key_type', 'key_fingerprint', 'key_blob',
                   'key_comment_re', 'key_options_re')


def _build_parser():
    """Build the command-line parser for this AuthorizedKeysCommand."""
    parser = argparse.ArgumentParser(prog='AUTHKEYCMD',
                                     description=PROG_DESC,
                                     epilog=PROG_EPILOG,
                                     formatter_class=argparse.RawDescriptionHelpFormatter,
                                     add_help=False)
    # Arguments
    group = parser.add_argument_group('Mandatory arguments')
    group.add_argument('username', metavar='USERNAME',
                       nargs='?',
                       type=str, help='Username')
    group.add_argument('--username', metavar='USERNAME',
                       dest='username_opt',
                       type=str, help='Username (alternative form)')
    # Conditions
    group = parser.add_argument_group('Match Conditions (optional)')
    group.add_argument('--key-type', metavar='KEY-TYPE',
                       type=str, help='Key type')
    group.add_argument('--key-fingerprint', '--key-fp', metavar='KEY-FP',
                       type=str, help='Key fingerprint X')
    group.add_argument('--key-blob', metavar='KEY-BLOB',
                       type=str, help='Key blob (Base64 section) X')
    group.add_argument('--key-comment-re', metavar='REGEX',
                       type=str, help='Regex to match on comments X')
    group.add_argument('--key-options-re', metavar='REGEX',
                       type=str, help='Regex to match on options')
    # Setup parameters:
    group = parser.add_argument_group('Misc settings')
    # Keep the path as a plain string and open it in main(). With
    # argparse.FileType, even the default path was opened inside parse_args(),
    # so a missing cache file became a usage error before the ENOENT handling
    # could run.
    group.add_argument('--cache-file', metavar='FILENAME',
                       default=DEFAULT_SSHKEY_CACHE,
                       type=str,
                       help='Cache file [%s]' % (DEFAULT_SSHKEY_CACHE, ), )
    group.add_argument('--strict', action="store_true",
                       default=False, help='Strict match required')
    group.add_argument('--help', action="help",
                       default=False, help='This help')
    return parser


def _iter_cached_keys(cache, username):
    """
    Yield every cached key line that belongs to *username*.

    Both cache formats are supported: 'username:key' (one key per line) and
    "username:['key1', 'key2']". Lines without a ':' separator (e.g. blank
    lines) are skipped. A list literal that cannot be parsed is reported
    through warning() and skipped instead of aborting the whole lookup.
    """
    for line in cache:
        (entry_user, separator, key) = line.partition(':')
        if not separator or entry_user != username:
            continue
        key = key.strip()
        if key.startswith("[") and key.endswith("]"):
            # Python array, but handle it safely! literal_eval only accepts
            # literals, so the cache file cannot execute code here.
            try:
                entries = literal_eval(key)
            except (ValueError, SyntaxError) as err:
                warning("Failed to parse key list for", entry_user, err)
                continue
            for entry in entries:
                yield entry.strip()
        else:
            # Raw key
            yield key


def main(argv=None):
    """
    Print each cached key of the requested user that passes the conditions.

    argv: argument list to parse; None means sys.argv[1:].
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Handle that we support both variants
    given = [u for u in (args.username, args.username_opt) if u is not None]
    if not given:
        parser.error('Username was not specified')
    if len(given) > 1:
        parser.error('Username must be specified either as an option XOR argument.')
    args.username = given[0]

    # Strict makes no sense without at least one condition being specified.
    # Inspect only the condition options. Scanning all of vars(args) also
    # counted username_opt and help (whose default is False, not None), so
    # the check effectively never fired.
    if args.strict and all(getattr(args, name) is None for name in _CONDITION_ARGS):
        parser.error('At least one condition must be specified with --strict')

    for name in ('key_comment_re', 'key_options_re'):
        pattern = getattr(args, name)
        if pattern is not None:
            try:
                setattr(args, name, re.compile(pattern))
            except re.error as err:
                parser.error('Invalid regex for --%s: %s'
                             % (name.replace('_', '-'), err))

    key_conditions = dict((name, getattr(args, name)) for name in _CONDITION_ARGS)

    try:
        # '-' means stdin, as it did with argparse.FileType
        cache = sys.stdin if args.cache_file == '-' else open(args.cache_file, 'r')
        with cache as f:
            for k in _iter_cached_keys(f, args.username):
                if validate_key(candidate_key=k,
                                conditions=key_conditions,
                                strict=args.strict):
                    print(k)
    except IOError as err:
        # A missing or unreadable cache is treated as "no keys", not an error.
        # open() reports permission problems as EACCES, which the former
        # EPERM-only check missed.
        if err.errno not in (errno.EPERM, errno.EACCES, errno.ENOENT):
            raise


if __name__ == "__main__":
    main()
