#!/usr/bin/env python
# -*- coding: utf-8 -*-
# test.py
# This file is part of python-otr
#
# Copyright (C) 2008 - Kjell Braden <fnord@pentabarf.de>
#
# python-otr is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# python-otr is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with python-otr; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, 
# Boston, MA  02110-1301  USA
#
 
 
import errno
import os
import shutil
import tempfile
import unittest

import otr

# Two distinct 20-byte raw fingerprints used as test fixtures.
# fprint_a is a 10-byte pattern repeated twice; fprint_b is fprint_a reversed byte by byte.
fprint_a = "\x12\x34\x56\x78\x90\xab\xcd\xef\x07\x99" * 2
fprint_b = fprint_a[::-1]

class No1_OtrTest(unittest.TestCase):
    """Unit tests for the userstate / context / fingerprint / privkey wrappers.

    Every file the tests write (fingerprint store, private keys) goes into a
    fresh temporary directory that is removed again after each test.  The
    suite therefore neither pollutes nor depends on the current working
    directory.
    """

    def setUp(self):
        self.path = tempfile.mkdtemp(prefix="python-otr-test-")

    def tearDown(self):
        shutil.rmtree(self.path, ignore_errors=True)

    @staticmethod
    def _walk(node):
        """Yield ``node`` and every successor reachable through ``.next``.

        Iterates the linked lists exposed by the bindings (``context_root``,
        ``fingerprint_root``, ``privkey_root``).  It stops at the first falsy
        node, exactly like a ``while node:`` loop.
        """
        while node:
            yield node
            node = node.next

    def _assert_read_fails_with_enoent(self, func_name, ustate):
        """Assert that ``otr.<func_name>(ustate, path)`` fails for a missing file.

        The raised exception must carry an ``os_errno`` attribute equal to
        ENOENT.  The path lives inside the fresh temporary directory, so it is
        guaranteed not to exist.
        """
        missing = os.path.join(self.path, "some.nonexistant.file")
        try:
            getattr(otr, func_name)(ustate, missing)
        except Exception as e:
            self.assertTrue(getattr(e, "os_errno", None) == errno.ENOENT,
                            "%s raised the wrong exception" % func_name)
        else:
            self.fail("%s should've raised an exception" % func_name)

    def _fingerprint_summary(self, ustate):
        """Return a sorted list of (accountname, protocol, human fingerprint or None).

        The list has one entry per fingerprint-list node of every context in
        ``ustate``.  This includes list nodes whose ``fingerprint`` is empty;
        those are reported as None.
        """
        summary = []
        for context in self._walk(ustate.context_root):
            for node in self._walk(context.fingerprint_root):
                human = None
                if node.fingerprint:
                    human = otr.otrl_privkey_hash_to_human(node.fingerprint)
                summary.append((node.context.accountname, node.context.protocol, human))
        summary.sort()
        return summary

    def test_otrl_context_find(self):
        ustate = otr.otrl_userstate_create()

        # Without add_if_missing, no context is created.
        self.assertEqual(otr.otrl_context_find(ustate, "user", "account", "proto", 0), (None, 0))

        # With add_if_missing, a new context is created and reported as added.
        new_ctx, added = otr.otrl_context_find(ustate, "user", "account", "proto", 1)
        self.assertEqual(added, 1)
        self.assertTrue(isinstance(new_ctx, otr.ConnContext))

        # A second lookup finds the same wrapped object (``.this``) without adding.
        ctx, added = otr.otrl_context_find(ustate, "user", "account", "proto", 0)
        self.assertEqual(added, 0)
        self.assertEqual(ctx.this, new_ctx.this)

    def test_otrl_context_find_fingerprint__and__read_write(self):
        ustate = otr.otrl_userstate_create()
        ctx = otr.otrl_context_find(ustate, "user", "account", "proto", 1)[0]

        self.assertEqual(otr.otrl_context_find_fingerprint(ctx, fprint_a, 0), (None, 0))

        new_fprint, added = otr.otrl_context_find_fingerprint(ctx, fprint_b, 1)
        self.assertEqual(added, 1)
        self.assertTrue(isinstance(new_fprint, otr.Fingerprint))

        fprint, _ = otr.otrl_context_find_fingerprint(ctx, fprint_b, 0)
        self.assertEqual(fprint.this, new_fprint.this)

        # Overwrite the stored fingerprint through the binding.
        # The lookup must then find it under the new value.
        ctx.fingerprint_root.next.fingerprint = fprint_a
        fprint, _ = otr.otrl_context_find_fingerprint(ctx, fprint_a, 0)
        self.assertEqual(fprint.this, ctx.fingerprint_root.next.this)

        # fprint_b is gone after the overwrite.
        # Add it again so the context holds two fingerprints.
        otr.otrl_context_find_fingerprint(ctx, fprint_b, 1)

        self._assert_read_fails_with_enoent("otrl_privkey_read_fingerprints", ustate)

        fpr_file = os.path.join(self.path, "test.fpr")
        otr.otrl_privkey_write_fingerprints(ustate, fpr_file)

        new_ustate = otr.otrl_userstate_create()
        otr.otrl_privkey_read_fingerprints(new_ustate, fpr_file)

        written = self._fingerprint_summary(ustate)
        self.assertNotEqual(written, [], "writing / reading failed")
        self.assertEqual(written, self._fingerprint_summary(new_ustate), "writing / reading failed")

    def test_otrl_context_set_trust(self):
        ustate = otr.otrl_userstate_create()
        ctx = otr.otrl_context_find(ustate, "user", "account", "proto", 1)[0]
        fprint = otr.otrl_context_find_fingerprint(ctx, fprint_a, 1)[0]

        for trust in ["smp", "verified", ""]:
            otr.otrl_context_set_trust(fprint, trust)
            self.assertEqual(fprint.trust, trust)

    def test_otrl_privkey_hash_to_human(self):
        # (raw 20-byte hash, expected human-readable rendering)
        cases = [
            ("\x01\x23\x45\x67\x12\x34\x56\x78\x23\x45\x67\x89\x34\x56\x78\x90\x45\x67\x89\x01", "01234567 12345678 23456789 34567890 45678901"),
            ("\xfa\xe9\xa3\xc0\x56\xe2\x19\x55\x74\xfd\x05\xf6\xd1\x9d\x64\xaa\x89\xb8\x2a\xf4", "FAE9A3C0 56E21955 74FD05F6 D19D64AA 89B82AF4"),
            ('\x00' * 20, "00000000 00000000 00000000 00000000 00000000"),
            ('\xff' * 20, "FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF"),
            ('\x30\x67\x8b\xf0\xa6\xcd\xe7\x73\x2d\x8f\x78\x31\x68\x5c\xf5\xcc\x0f\x88\xa6\xc1', '30678BF0 A6CDE773 2D8F7831 685CF5CC 0F88A6C1'),
        ]

        for raw, human in cases:
            self.assertEqual(otr.otrl_privkey_hash_to_human(raw), human)

    def test_otrl_privkey_read__and__generate(self):
        ustate = otr.otrl_userstate_create()

        self._assert_read_fails_with_enoent("otrl_privkey_read", ustate)

        key_file = os.path.join(self.path, "test.key")
        otr.otrl_privkey_generate(ustate, key_file, "accountname", "protocol")

        keys = list(self._walk(ustate.privkey_root))
        for pk in keys:
            # Rendering the raw fingerprint must match the library's human-readable one.
            self.assertEqual(
                otr.otrl_privkey_hash_to_human(otr.otrl_privkey_fingerprint_raw(ustate, pk.accountname, pk.protocol)),
                otr.otrl_privkey_fingerprint(ustate, pk.accountname, pk.protocol))
        self.assertEqual(len(keys), 1, "Wrong number of private keys")

        new_ustate = otr.otrl_userstate_create()
        otr.otrl_privkey_read(new_ustate, key_file)

        # Walk both key lists in lockstep; they must match pairwise and end together.
        pk = ustate.privkey_root
        n_pk = new_ustate.privkey_root
        while pk and n_pk:
            self.assertEqual(pk.pubkey_data, n_pk.pubkey_data)
            pk = pk.next
            n_pk = n_pk.next

        self.assertFalse(pk or n_pk, "Number of private keys differs! (pk = %r, n_pk = %r)" % (pk, n_pk))

    # TODO: test otrl_privkey_fingerprint and otrl_privkey_fingerprint_raw
    # TODO:     against predefined keys with known, static hashes; a freshly
    # TODO:     generated key's hash cannot be predicted.


class No2_OtrCommTest(unittest.TestCase):
    """End-to-end messaging test: Alice and Bob talk to each other through libotr.

    The instance doubles as the callback object.  The test passes
    ``(self, data)`` pairs to the otr functions, and the callback methods
    below receive the per-side ``data`` dict as ``opdata``.  Outgoing protocol
    messages are routed through ``inject_message`` into ``self.queue`` and
    delivered by ``do_queue``.  Key and fingerprint files go into a temporary
    directory that is removed after each test.
    """

    account_a = "Alice"
    account_b = "Bob"
    protocol = "P"
    recipient_a = account_b
    recipient_b = account_a

    def setUp(self):
        self.path = tempfile.mkdtemp(prefix="python-otr-test-")

    def tearDown(self):
        shutil.rmtree(self.path, ignore_errors=True)

    # ---- callbacks -------------------------------------------------------
    # Parameter names and defaults are kept as-is.  Some callbacks only accept
    # **kwargs, which suggests the bindings pass arguments by keyword.

    def policy(self, opdata=None, context=None):
        # opdata["me"] is "Alice" or "Bob", selecting Alice_policy / Bob_policy.
        return getattr(self, opdata["me"] + "_policy")

    def create_privkey(self, opdata="", accountname="", protocol=""):
        otr.otrl_privkey_generate(opdata['ustate'], os.path.join(self.path, str(opdata['me']) + ".key"), accountname, protocol)

    def is_logged_in(self, opdata="", accountname="", protocol="", recipient=""):
        return 1

    def inject_message(self, opdata=None, accountname="", protocol="", recipient="", message=""):
        # Hand the message to the partner's receive hook, which queues it.
        opdata['recv'](message)

    def notify(self, opdata=None, level=None, accountname="", protocol="", username="", title="", primary="", secondary=""):
        pass

    def display_otr_message(self, **kwargs):
        return 0

    def update_context_list(self, **kwargs):
        pass

    def protocol_name(self, opdata=None, protocol=""):
        return "human-" + str(protocol)

    def new_fingerprint(self, **kwargs):
        pass

    def write_fingerprints(self, opdata=""):
        otr.otrl_privkey_write_fingerprints(opdata['ustate'], os.path.join(self.path, str(opdata['me']) + ".fps"))

    def gone_secure(self, opdata="", context=None):
        pass

    def gone_insecure(self, opdata="", context=None):
        pass

    def still_secure(self, opdata=None, context=None, is_reply=0):
        pass

    def log_message(self, opdata=None, message=""):
        pass

    def max_message_size(self, **kwargs):
        return 0

    def account_name(self, opdata=None, accountname="", protocol=""):
        return "human-" + str(accountname) + "@" + str(protocol)

    # ---- test ------------------------------------------------------------

    def test_01otrl_message(self):
        """ implements some basic messaging """

        # create the userstates
        self.ustate_a = otr.otrl_userstate_create()
        self.ustate_b = otr.otrl_userstate_create()

        # Callback data passed to the callbacks above.
        # "recv" is the partner's receive hook, used to "deliver" messages.
        self.data_a = {"ustate": self.ustate_a, "partner": self.recipient_a, "me": self.account_a, "protocol": self.protocol, "self": self, "recv": self.b_recv}
        self.data_b = {"ustate": self.ustate_b, "partner": self.recipient_b, "me": self.account_b, "protocol": self.protocol, "self": self, "recv": self.a_recv}

        # generate the keys
        otr.otrl_privkey_generate(self.ustate_a, os.path.join(self.path, "Alice.key"), self.account_a, self.protocol)
        otr.otrl_privkey_generate(self.ustate_b, os.path.join(self.path, "Bob.key"), self.account_b, self.protocol)

        # create the contexts; both must start out PLAINTEXT
        ctx_a = otr.otrl_context_find(self.ustate_a, self.recipient_a, self.account_a, self.protocol, 1)[0]
        ctx_b = otr.otrl_context_find(self.ustate_b, self.recipient_b, self.account_b, self.protocol, 1)[0]
        self.assertEqual(ctx_a.msgstate, otr.OTRL_MSGSTATE_PLAINTEXT)
        self.assertEqual(ctx_b.msgstate, otr.OTRL_MSGSTATE_PLAINTEXT)

        # initialise the queues
        self.queue = []
        self.a_received = []
        self.b_received = []
        # set the policies to NEVER ENCRYPT
        self.Alice_policy = otr.OTRL_POLICY_NEVER
        self.Bob_policy = otr.OTRL_POLICY_NEVER
        # send / receive a plaintext message from A to B
        otr.otrl_message_fragment_and_send((self, self.data_a), ctx_a, self.a_send("Hello."), otr.OTRL_FRAGMENT_SEND_ALL)
        self.do_queue()
        # check the message in B's queue
        self.assertEqual(self.b_received, [(0, "Hello.", None, "plaintext")], "plaintext message exchange failed")

        # clear the queues
        self.a_received = []
        self.b_received = []
        # set the policy to "check whether the other one understands OTR"
        self.Alice_policy = otr.OTRL_POLICY_OPPORTUNISTIC
        self.Bob_policy = otr.OTRL_POLICY_OPPORTUNISTIC
        # send / receive a plaintext message from B to A
        otr.otrl_message_fragment_and_send((self, self.data_b), ctx_b, self.b_send("Hello!"), otr.OTRL_FRAGMENT_SEND_ALL)
        self.do_queue()
        # The message in A's queue must be plaintext.
        # Both sides must be ENCRYPTED afterwards.
        self.assertEqual(self.a_received, [(0, "Hello!", None, "plaintext")], "opportunistic message exchange failed")
        self.assertEqual(ctx_a.msgstate, otr.OTRL_MSGSTATE_ENCRYPTED, "opportunistic message exchange failed")
        self.assertEqual(ctx_b.msgstate, otr.OTRL_MSGSTATE_ENCRYPTED, "opportunistic message exchange failed")

        # clear the queues
        self.a_received = []
        self.b_received = []
        # send / receive an encrypted message from A to B
        otr.otrl_message_fragment_and_send((self, self.data_a), ctx_a, self.a_send("are we encrypted?"), otr.OTRL_FRAGMENT_SEND_ALL)
        self.do_queue()
        # check for integrity at B
        self.assertEqual(self.b_received, [(0, "are we encrypted?", None, "not verified")], "encrypted message exchange failed")

        # clear the queues
        self.a_received = []
        self.b_received = []
        # test disconnecting
        otr.otrl_message_disconnect(self.ustate_b, (self, self.data_b), self.account_b, self.protocol, self.recipient_b)
        self.do_queue()
        # the "disconnector" ends up PLAINTEXT, the partner FINISHED
        self.assertEqual(ctx_a.msgstate, otr.OTRL_MSGSTATE_FINISHED, "disconnecting the otr session failed")
        self.assertEqual(ctx_b.msgstate, otr.OTRL_MSGSTATE_PLAINTEXT, "disconnecting the otr session failed")

        # clear the queues
        self.a_received = []
        self.b_received = []
        self.Alice_policy = otr.OTRL_POLICY_ALWAYS
        self.Bob_policy = otr.OTRL_POLICY_ALWAYS
        # Sending in the FINISHED state must not queue anything, so no do_queue here.
        otr.otrl_message_fragment_and_send((self, self.data_a), ctx_a, self.a_send("huh?"), otr.OTRL_FRAGMENT_SEND_ALL)
        self.assertEqual(self.queue, [], "must not send messages in the FINISHED state")

        # restart the encrypted messaging
        self.b_recv("?OTR?")
        self.do_queue()
        self.assertEqual(ctx_a.msgstate, otr.OTRL_MSGSTATE_ENCRYPTED, "failed to start OTR session")
        self.assertEqual(ctx_b.msgstate, otr.OTRL_MSGSTATE_ENCRYPTED, "failed to start OTR session")

    # ---- message plumbing ------------------------------------------------

    def do_queue(self):
        """Deliver queued messages in FIFO order until none are left.

        Entries are (message, side) pairs: side 0 is processed by Alice,
        anything else by Bob.  Replies injected while processing are
        appended to the queue and delivered by the same loop.
        """
        while self.queue:
            if self.queue[0][1] == 0:
                self.a_queue()
            else:
                self.b_queue()

    def _receive_next(self, ustate, data, account, recipient):
        """Pop the oldest queued message and pass it through one side's libotr state.

        Returns None when the first element of the receive result is non-zero
        (presumably libotr's "ignore this message" flag).  Otherwise it returns
        that result extended by one label:
        - the active fingerprint's trust string, when encrypted and trusted;
        - "not verified", when encrypted but untrusted;
        - "plaintext", otherwise.
        """
        msg = self.queue.pop(0)[0]
        result = otr.otrl_message_receiving(ustate, (self, data), account, self.protocol, recipient, msg)
        if result[0] != 0:
            return None
        ctx = otr.otrl_context_find(data['ustate'], data["partner"], data["me"], data["protocol"], 0)[0]
        if ctx.msgstate == otr.OTRL_MSGSTATE_ENCRYPTED:
            # Exactly one label: the trust string if set, else "not verified".
            trust = ctx.active_fingerprint.trust
            label = trust if trust else "not verified"
        else:
            label = "plaintext"
        return result + (label,)

    def a_send(self, msg):
        return otr.otrl_message_sending(self.ustate_a, (self, self.data_a), self.account_a, self.protocol, self.recipient_a, msg, None)

    def a_recv(self, msg):
        # Queue a message for Alice; empty/None messages are dropped.
        if msg:
            self.queue.append((msg, 0))

    def a_queue(self):
        received = self._receive_next(self.ustate_a, self.data_a, self.account_a, self.recipient_a)
        if received is not None:
            self.a_received.append(received)

    def b_send(self, msg):
        return otr.otrl_message_sending(self.ustate_b, (self, self.data_b), self.account_b, self.protocol, self.recipient_b, msg, None)

    def b_recv(self, msg):
        # Queue a message for Bob; empty/None messages are dropped.
        if msg:
            self.queue.append((msg, 1))

    def b_queue(self):
        received = self._receive_next(self.ustate_b, self.data_b, self.account_b, self.recipient_b)
        if received is not None:
            self.b_received.append(received)

if __name__ == '__main__':
    # Run every TestCase defined in this module, honoring unittest's CLI options.
    unittest.main()
