1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
|
# -*- coding: utf-8 -*-
from ..invalidkeyidexception import InvalidKeyIdException
from ..invalidkeyexception import InvalidKeyException
from ..invalidmessageexception import InvalidMessageException
from ..duplicatemessagexception import DuplicateMessageException
from ..nosessionexception import NoSessionException
from ..protocol.senderkeymessage import SenderKeyMessage
from ..sessioncipher import AESCipher
from ..groups.state.senderkeystore import SenderKeyStore
class GroupCipher:
    """Encrypt and decrypt group messages with sender keys.

    Every call loads the sender key record for ``senderKeyName`` from
    ``senderKeyStore``. On success, it writes the (ratcheted) record back.
    """

    # Upper bound on how far ahead of the current sender chain key an
    # incoming message's iteration may be. Without a bound, a forged
    # iteration would force unbounded key derivation and caching of
    # skipped message keys.
    MAX_FUTURE_MESSAGES = 2000

    def __init__(self, senderKeyStore, senderKeyName):
        """
        :type senderKeyStore: SenderKeyStore
        :type senderKeyName: SenderKeyName
        """
        self.senderKeyStore = senderKeyStore
        self.senderKeyName = senderKeyName

    def encrypt(self, paddedPlaintext):
        """Encrypt ``paddedPlaintext`` with the current sender message key.

        Builds a SenderKeyMessage from the ciphertext, using the state's key
        id, the message key's iteration and the state's private signing key.
        Then it advances the sender chain key and stores the updated record.

        :type paddedPlaintext: bytes
        :return: the serialized SenderKeyMessage.
        :raises NoSessionException: when loading the sender key state raises
            InvalidKeyIdException.
        """
        try:
            record = self.senderKeyStore.loadSenderKey(self.senderKeyName)
            senderKeyState = record.getSenderKeyState()
            senderKey = senderKeyState.getSenderChainKey().getSenderMessageKey()
            ciphertext = self.getCipherText(senderKey.getIv(),
                                            senderKey.getCipherKey(),
                                            paddedPlaintext)
            senderKeyMessage = SenderKeyMessage(senderKeyState.getKeyId(),
                                                senderKey.getIteration(),
                                                ciphertext,
                                                senderKeyState.getSigningKeyPrivate())
            # Move the chain forward so the next encrypt derives a fresh key.
            senderKeyState.setSenderChainKey(senderKeyState.getSenderChainKey().getNext())
            self.senderKeyStore.storeSenderKey(self.senderKeyName, record)
            return senderKeyMessage.serialize()
        except InvalidKeyIdException as e:
            raise NoSessionException(e)

    def decrypt(self, senderKeyMessageBytes):
        """Verify and decrypt a serialized SenderKeyMessage.

        The record is written back to the store only after decryption
        succeeds. NOTE(review): whether a failed decrypt leaves the stored
        state untouched also depends on ``loadSenderKey`` returning a record
        that is not shared with the store — confirm.

        :type senderKeyMessageBytes: bytearray
        :return: the decrypted plaintext.
        :raises NoSessionException: if the loaded record is empty.
        :raises InvalidMessageException: on an unknown key id, an invalid key,
            an iteration too far in the future, or a decryption failure.
        :raises DuplicateMessageException: if the message key for an older
            iteration is no longer cached.
        """
        try:
            record = self.senderKeyStore.loadSenderKey(self.senderKeyName)
            if record.isEmpty():
                raise NoSessionException("No sender key for: %s" % self.senderKeyName)
            senderKeyMessage = SenderKeyMessage(serialized = bytes(senderKeyMessageBytes))
            senderKeyState = record.getSenderKeyState(senderKeyMessage.getKeyId())
            senderKeyMessage.verifySignature(senderKeyState.getSigningKeyPublic())
            senderKey = self.getSenderKey(senderKeyState, senderKeyMessage.getIteration())
            plaintext = self.getPlainText(senderKey.getIv(),
                                          senderKey.getCipherKey(),
                                          senderKeyMessage.getCipherText())
            self.senderKeyStore.storeSenderKey(self.senderKeyName, record)
            return plaintext
        except (InvalidKeyException, InvalidKeyIdException) as e:
            raise InvalidMessageException(e)

    def getSenderKey(self, senderKeyState, iteration):
        """Return the sender message key for ``iteration``, ratcheting as needed.

        - ``iteration`` behind the chain: remove and return the cached key for
          it, or raise DuplicateMessageException if none is cached.
        - otherwise: cache the message keys for every skipped iteration, then
          set the chain key to the one after ``iteration``.

        :raises DuplicateMessageException: old iteration with no cached key.
        :raises InvalidMessageException: if ``iteration`` is more than
            ``MAX_FUTURE_MESSAGES`` ahead of the current chain key.
        """
        senderChainKey = senderKeyState.getSenderChainKey()
        currentIteration = senderChainKey.getIteration()

        if currentIteration > iteration:
            if senderKeyState.hasSenderMessageKey(iteration):
                return senderKeyState.removeSenderMessageKey(iteration)
            raise DuplicateMessageException("Received message with old counter: %s, %s" %
                                            (currentIteration, iteration))

        # Here iteration >= currentIteration, so measure the *forward*
        # distance. The reversed subtraction is never positive and would
        # disable this limit entirely.
        if iteration - currentIteration > self.MAX_FUTURE_MESSAGES:
            raise InvalidMessageException("Over %d messages into the future!" %
                                          self.MAX_FUTURE_MESSAGES)

        # Keep keys for skipped iterations so out-of-order messages can
        # still be decrypted later.
        while senderChainKey.getIteration() < iteration:
            senderKeyState.addSenderMessageKey(senderChainKey.getSenderMessageKey())
            senderChainKey = senderChainKey.getNext()

        senderKeyState.setSenderChainKey(senderChainKey.getNext())
        return senderChainKey.getSenderMessageKey()

    def getPlainText(self, iv, key, ciphertext):
        """Decrypt ``ciphertext`` with AESCipher(key, iv).

        :type iv: bytearray
        :type key: bytearray
        :type ciphertext: bytearray
        :return: the decrypted bytes.
        :raises InvalidMessageException: wrapping any error from the cipher,
            so callers see a single failure type for bad ciphertexts.
        """
        try:
            cipher = AESCipher(key, iv)
            return cipher.decrypt(ciphertext)
        except Exception as e:
            raise InvalidMessageException(e)

    def getCipherText(self, iv, key, plaintext):
        """Encrypt ``plaintext`` with AESCipher(key, iv) and return the result.

        :type iv: bytearray
        :type key: bytearray
        :type plaintext: bytearray
        """
        cipher = AESCipher(key, iv)
        return cipher.encrypt(plaintext)
|