#! @PYTHON@
# ------------------------------------------------------------------
#
#   Copyright (C) 2013 Canonical Ltd.
#   Author: Steve Beattie <steve@nxnw.org>
#
#   This program is free software; you can redistribute it and/or
#   modify it under the terms of version 2 of the GNU General Public
#   License published by the Free Software Foundation.
#
# ------------------------------------------------------------------

import ctypes
import os
import unittest

import LibAppArmor as libapparmor

TESTDIR = "../../../testsuite/test_multi"

# Field names used in the testsuite .out files that do not map onto the
# record key by simply lower-casing the name and replacing spaces with
# underscores. Every other field name is converted mechanically.

OUTPUT_MAP = {
    # general record fields
    'Event type': 'event',
    'Mask': 'requested_mask',
    'Command': 'comm',
    'Token': 'magic_token',
    'ErrorCode': 'error_code',
    # network related fields
    'Network family': 'net_family',
    'Socket type': 'net_sock_type',
    'Protocol': 'net_protocol',
    'Local addr': 'net_local_addr',
    'Foreign addr': 'net_foreign_addr',
    'Local port': 'net_local_port',
    'Foreign port': 'net_foreign_port',
    # remaining special cases
    'Audit subid': 'audit_sub_id',
    # NOTE(review): presumably underscore-prefixed because `class` is a
    # reserved word in Python -- confirm against the binding definition
    'Class': '_class',
}

# FIXME: pull this automatically out of LibAppArmor, but swig
# classes aren't real enough. (Perhaps we're not using swig correctly)
#
# Maps the value of each libapparmor event constant to the constant's name.
EVENT_MAP = {
    getattr(libapparmor, event_name): event_name
    for event_name in (
        'AA_RECORD_ALLOWED',
        'AA_RECORD_AUDIT',
        'AA_RECORD_DENIED',
        'AA_RECORD_ERROR',
        'AA_RECORD_HINT',
        'AA_RECORD_INVALID',
        'AA_RECORD_STATUS',
    )
}

# Per-field values that stand for "no value" in a parsed record.
# default is None if not in this table
NO_VALUE_MAP = dict.fromkeys(
    ('fsuid', 'ouid'),
    int(ctypes.c_ulong(-1).value),  # -1 wrapped around to an unsigned long
)


class AAPythonBindingsTests(unittest.TestCase):
    """Exercise the non-log-parsing functions of the LibAppArmor bindings.

    Tests that talk to the kernel are decorated with ``skipUnless`` and are
    skipped when ``libapparmor.aa_is_enabled()`` is false at the time this
    class is defined.

    Test methods are documented with comments rather than docstrings so that
    the verbose unittest output keeps showing only the method names.
    """

    def setUp(self):
        # REPORT ALL THE OUTPUT
        self.maxDiff = None

    def test_aa_splitcon(self):
        # Each entry is (context, expected label, expected mode); an expected
        # value of None means the binding must return None for that slot.
        AA_SPLITCON_EXPECT = [
            ("unconfined", "unconfined", None),
            ("unconfined\n", "unconfined", None),
            ("/bin/ping (enforce)", "/bin/ping", "enforce"),
            ("/bin/ping (enforce)\n", "/bin/ping", "enforce"),
            ("/usr/sbin/rsyslog (complain)", "/usr/sbin/rsyslog", "complain"),
        ]
        for context, expected_label, expected_mode in AA_SPLITCON_EXPECT:
            actual_label, actual_mode = libapparmor.aa_splitcon(context)
            if expected_label is None:
                self.assertIsNone(actual_label)
            else:
                self.assertIsInstance(actual_label, str)
                self.assertEqual(expected_label, actual_label)

            if expected_mode is None:
                self.assertIsNone(actual_mode)
            else:
                self.assertIsInstance(actual_mode, str)
                self.assertEqual(expected_mode, actual_mode)

        # An empty context is rejected with a ValueError
        with self.assertRaises(ValueError):
            libapparmor.aa_splitcon("")

    def test_aa_is_enabled(self):
        # Only the return type is checked; either answer is acceptable here
        aa_enabled = libapparmor.aa_is_enabled()
        self.assertIsInstance(aa_enabled, bool)

    @unittest.skipUnless(libapparmor.aa_is_enabled(), "AppArmor is not enabled")
    def test_aa_find_mountpoint(self):
        mount_point = libapparmor.aa_find_mountpoint()
        self.assertIsInstance(mount_point, str)
        self.assertGreater(len(mount_point), 0, "mount point should not be empty")
        self.assertTrue(os.path.isdir(mount_point))

    # TODO: test commented out functions (or at least their prototypes)
    # extern int aa_change_profile(const char *profile);
    # extern int aa_change_onexec(const char *profile);
    @unittest.skipUnless(libapparmor.aa_is_enabled(), "AppArmor is not enabled")
    def test_change_hats(self):
        # Changing hats will fail because we have no valid hats to change to
        # However, we still verify that we get an OSError instead of a TypeError
        with self.assertRaises(OSError):
            libapparmor.aa_change_hat("nonexistent_profile", 12345678)

        # Each call gets its own assertRaises block: the block is left as
        # soon as the first exception is raised, so a second call placed in
        # the same block would never be executed.
        with self.assertRaises(OSError):
            libapparmor.aa_change_hatv(["nonexistent_1", "nonexistent_2"], 0xabcdef)

        with self.assertRaises(OSError):
            libapparmor.aa_change_hatv(("nonexistent_1", "nonexistent_2"), 0xabcdef)

    # extern int aa_stack_profile(const char *profile);
    # extern int aa_stack_onexec(const char *profile);
    # extern int aa_getprocattr(pid_t tid, const char *attr, char **label, char **mode);
    # extern int aa_gettaskcon(pid_t target, char **label, char **mode);

    @unittest.skipUnless(libapparmor.aa_is_enabled(), "AppArmor is not enabled")
    def test_aa_gettaskcon(self):
        # Our test harness should be running us as unconfined
        # Get our own pid and this should be equivalent to aa_getcon
        pid = os.getpid()

        label, mode = libapparmor.aa_gettaskcon(pid)
        self.assertEqual(label, "unconfined", "aa_gettaskcon label should be unconfined")
        self.assertIsNone(mode, "aa_gettaskcon mode should be None when unconfined")

    @unittest.skipUnless(libapparmor.aa_is_enabled(), "AppArmor is not enabled")
    def test_aa_getcon(self):
        # Our test harness should be running us as unconfined
        label, mode = libapparmor.aa_getcon()
        self.assertEqual(label, "unconfined", "aa_getcon label should be unconfined")
        self.assertIsNone(mode, "aa_getcon mode should be None when unconfined")

    # extern int aa_getpeercon(int fd, char **label, char **mode);

    # extern int aa_query_file_path(uint32_t mask, const char *label,
    #                 const char *path, int *allowed, int *audited);
    @unittest.skipUnless(libapparmor.aa_is_enabled(), "AppArmor is not enabled")
    def test_aa_query_file_path(self):
        # Query exec/read/write on a path for the "unconfined" label; the
        # access is expected to be allowed and not audited.
        aa_query_mask = libapparmor.AA_MAY_EXEC | libapparmor.AA_MAY_READ | libapparmor.AA_MAY_WRITE
        allowed, audited = libapparmor.aa_query_file_path(aa_query_mask, "unconfined", "/tmp/hello")
        self.assertTrue(allowed)
        self.assertFalse(audited)

    # extern int aa_query_link_path(const char *label, const char *target,
    #                 const char *link, int *allowed, int *audited);

class AALogParsePythonBindingsTests(unittest.TestCase):
    """Compare records parsed by libapparmor with the testsuite's .out files.

    No ``test_*`` method is defined in the class body; ``_runtest`` is the
    shared implementation that runs one testcase, identified by the base
    name of its ``.in``/``.out`` file pair inside TESTDIR.
    """

    def setUp(self):
        # REPORT ALL THE OUTPUT
        self.maxDiff = None

    def _runtest(self, testname):
        """Parse ``testname + '.in'`` and compare it with ``testname + '.out'``."""
        infile = testname + ".in"
        outfile = testname + ".out"
        # infile *should* only contain one line
        with open(os.path.join(TESTDIR, infile), 'r') as f:
            line = f.read()

        swig_record = libapparmor.parse_record(line)
        try:
            record = self.create_record_dict(swig_record)
        finally:
            # Release the record even if converting it to a dict failed
            libapparmor.free_record(swig_record)
        record['file'] = infile

        expected = self.parse_output_file(outfile)
        self.assertEqual(expected, record,
                         "expected records did not match\n"
                         "expected = {}\nactual = {}".format(expected, record))

    def parse_output_file(self, outfile):
        """parse testcase .out file and return dict

        A "START" line is only accepted as the first line of the file and
        is skipped. Every other line must have the form "Field name: value".
        The field name is translated through OUTPUT_MAP, or lower-cased with
        spaces replaced by underscores when it has no entry there. When a
        field occurs more than once, the last value wins.
        """

        output = {}
        with open(os.path.join(TESTDIR, outfile), 'r') as f:
            lines = f.readlines()

        for count, raw_line in enumerate(lines, 1):
            line = raw_line.rstrip('\n')
            if line == "START":
                self.assertEqual(count, 1,
                                 "Unexpected output format in " + outfile)
                continue

            key, separator, value = line.partition(": ")
            # Fail with the offending line instead of an opaque unpack error
            self.assertTrue(separator,
                            "Unexpected line {} in {}: {!r}".format(count, outfile, line))
            newkey = OUTPUT_MAP.get(key, key.lower().replace(' ', '_'))
            # check if entry already exists?
            output[newkey] = value

        return output

    def create_record_dict(self, record):
        """parse the swig created record and construct a dict from it

        Attributes whose name starts with '__', and the attributes 'this',
        'thisown' and 'version', are ignored. Values are stored as strings,
        except for 'event', which is replaced by its name from EVENT_MAP
        when the value is known there. A field is left out when it holds
        its NO_VALUE_MAP sentinel, or when it is not listed in NO_VALUE_MAP
        and its value is falsy (the empty string is kept).
        """

        new_record = {}
        for key in dir(record):
            if key.startswith('__') or key == 'this':
                continue
            value = getattr(record, key)
            if key == "event" and value in EVENT_MAP:
                new_record[key] = EVENT_MAP[value]
            elif key == "version":
                # FIXME: out files should report log version?
                # FIXME: or can we just deprecate v1 logs?
                continue
            elif key == "thisown":
                # SWIG generates this key to track memory allocation
                continue
            elif key in NO_VALUE_MAP:
                # Skip the field only when it holds its "unset" sentinel
                if NO_VALUE_MAP[key] != value:
                    new_record[key] = str(value)
            elif value or value == '':
                new_record[key] = str(value)

        return new_record


def find_testcases(testdir):
    """Yield the base name of every ``*.in`` entry found in testdir.

    The ".in" suffix is stripped from each yielded name; entries are
    produced in the order os.listdir() returns them.
    """

    suffix = ".in"
    for entry in os.listdir(testdir):
        if not entry.endswith(suffix):
            continue
        yield entry[:-len(suffix)]


def main():
    """Attach one test method per testcase in TESTDIR, then run unittest."""

    def make_stub(testname):
        # A factory call gives every stub its own testname binding
        def stub_test(self):
            self._runtest(testname)
        stub_test.__doc__ = "test " + testname
        return stub_test

    for testname in find_testcases(TESTDIR):
        setattr(AALogParsePythonBindingsTests, 'test_' + testname, make_stub(testname))
    return unittest.main(verbosity=2)


if __name__ == "__main__":
    main()
