1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206
|
{-# LANGUAGE FlexibleContexts #-}
module Network.TLS.Record.Decrypt (
decryptRecord,
) where
import Control.Monad.State.Strict
import Crypto.Cipher.Types (AuthTag (..))
import qualified Data.ByteArray as B (constEq, convert, xor)
import qualified Data.ByteString as B
import Network.TLS.Cipher
import Network.TLS.Crypto
import Network.TLS.ErrT
import Network.TLS.Imports
import Network.TLS.Packet
import Network.TLS.Record.State
import Network.TLS.Record.Types
import Network.TLS.Struct
import Network.TLS.Util
import Network.TLS.Wire
-- | Turn a received ciphertext record into a plaintext record, enforcing
-- the plaintext size limit @lim@.
--
-- * No cipher installed: the fragment is already plaintext and only its
--   length is checked.
-- * TLS 1.3 record options: application-data records are decrypted with the
--   limit raised by one byte (the inner plaintext also carries the content
--   type byte), then split into their real content type and payload with
--   'unInnerPlaintext'. ChangeCipherSpec and Alert records are passed through
--   as plaintext. Any other outer content type is rejected with
--   'UnexpectedMessage'.
-- * Otherwise the fragment is decrypted with 'decryptData'.
decryptRecord :: Record Ciphertext -> Int -> RecordM (Record Plaintext)
decryptRecord record@(Record ct ver fragment) lim =
    get >>= \st -> maybe passThrough (const $ withCipher st) (stCipher st)
  where
    -- Treat the fragment as plaintext, enforcing only the size limit.
    passThrough = onRecordFragment record $ fragmentUncipher $ checkPlainLimit lim

    -- A cipher is installed: choose between the TLS 1.3 and legacy layouts.
    withCipher st = do
        opts <- getRecordOptions
        let recVer = recordVersion opts
        if recordTLS13 opts
            then tls13Record recVer st
            else onRecordFragment record $ fragmentUncipher $ \enc ->
                decryptData recVer record enc st lim

    -- TLS 1.3: only application data is actually protected.
    tls13Record recVer st = case ct of
        ProtocolType_ChangeCipherSpec -> passThrough
        ProtocolType_Alert -> passThrough
        ProtocolType_AppData ->
            decryptData recVer record (fragmentGetBytes fragment) st (lim + 1)
                >>= either rejectInner acceptInner . unInnerPlaintext
        _ -> throwError $ Error_Protocol "illegal plain text" UnexpectedMessage

    rejectInner msg = throwError $ Error_Protocol msg UnexpectedMessage
    acceptInner (innerType, payload) =
        return $ Record innerType ver $ fragmentPlaintext payload
-- | Decode a TLS 1.3 inner plaintext. Trailing zero bytes are dropped as
-- padding. The last remaining byte is taken as the real content type and
-- everything before it as the payload.
--
-- Returns 'Left' with a message when only zero bytes are present, or when a
-- Handshake or Alert record carries an empty payload.
unInnerPlaintext :: ByteString -> Either String (ProtocolType, ByteString)
unInnerPlaintext inner =
    maybe (Left missingType) splitType $ B.unsnoc withoutPadding
  where
    withoutPadding = fst $ B.spanEnd (== 0) inner
    missingType = "unknown TLS 1.3 content type: " ++ show (0 :: Word8)
    splitType (payload, typeByte)
        | B.null payload
        , innerType `elem` [ProtocolType_Handshake, ProtocolType_Alert] =
            Left ("empty " ++ show innerType ++ " record disallowed")
        | otherwise = Right (innerType, payload)
      where
        innerType = ProtocolType typeByte
-- | Authenticate the output of a block or stream cipher and return its
-- content.
--
-- * MAC (when present): recomputed with 'makeDigest' over the content. The
--   header is built from the record's type and version, with the content
--   length. The result is compared with the received MAC.
-- * Padding (block ciphers only): every padding byte, including the final
--   length byte, must equal the padding length minus one.
--
-- Both checks are always evaluated and combined with the strict '&&!'. The
-- byte strings are compared with 'B.constEq' rather than '=='. Together these
-- avoid leaking, through timing, which check failed or how many leading bytes
-- of a forged MAC were correct. Any failure is reported as 'BadRecordMac'.
getCipherData :: Record a -> CipherData -> RecordM ByteString
getCipherData (Record pt ver _) cdata = do
    let content = cipherDataContent cdata
    -- check if the MAC is valid.
    macValid <- case cipherDataMAC cdata of
        Nothing -> return True
        Just digest -> do
            let new_hdr = Header pt ver (fromIntegral $ B.length content)
            expected_digest <- makeDigest new_hdr content
            -- '==' on ByteString may return at the first differing byte;
            -- constEq examines every byte of equal-length inputs.
            return (expected_digest `B.constEq` digest)
    -- check that every padding byte carries the padding length
    -- (the block size component is not needed for this check)
    paddingValid <- case cipherDataPadding cdata of
        Nothing -> return True
        Just (pad, _blksz) -> do
            let padLenByte = fromIntegral (B.length pad - 1)
            return $ B.replicate (B.length pad) padLenByte `B.constEq` pad
    unless (macValid &&! paddingValid) $
        throwError $
            Error_Protocol "bad record mac Stream/Block" BadRecordMac
    return content
-- | Return the plaintext unchanged when it is at most @lim@ bytes long.
-- Otherwise fail with a 'RecordOverflow' protocol error that reports both
-- lengths.
checkPlainLimit :: Int -> ByteString -> RecordM ByteString
checkPlainLimit lim plain
    | plainLen <= lim = return plain
    | otherwise = throwError $ Error_Protocol overflowMsg RecordOverflow
  where
    plainLen = B.length plain
    overflowMsg =
        concat
            [ "plaintext exceeding record size limit: "
            , show plainLen
            , " > "
            , show lim
            ]
-- | Decrypt one record fragment using the bulk cipher state found in the
-- given record state (@cstKey@ of its crypt state), then check the resulting
-- plaintext length against @lim@ with 'checkPlainLimit'.
--
-- Three cipher shapes are handled:
--
-- * 'BulkStateBlock': CBC-style block cipher. The MAC and padding are
--   verified by 'getCipherData', and the crypt state's IV is replaced.
-- * 'BulkStateStream': stream cipher. The MAC is verified by
--   'getCipherData', and the advanced stream replaces @cstKey@.
-- * 'BulkStateAEAD': the authentication tag is checked here, and
--   'incrRecordState' is applied only after the tag matches.
--
-- Malformed outer framing and too-short input are reported as
-- 'Error_Packet'. Format errors in decrypted content and authentication
-- failures are reported as 'BadRecordMac'.
--
-- NOTE(review): 'stCipher' is unwrapped with 'fromJust'. The caller visible
-- here ('decryptRecord') only calls this after seeing a 'Just' cipher.
decryptData
    :: Version
    -- ^ version compared against 'TLS13' to choose the AEAD additional data
    -> Record Ciphertext
    -- ^ record being decrypted; its header feeds the MAC / AEAD data
    -> ByteString
    -- ^ encrypted fragment bytes
    -> RecordState
    -- ^ state snapshot supplying the cipher, crypt state and sequence number
    -> Int
    -- ^ maximum allowed plaintext length
    -> RecordM ByteString
decryptData ver record econtent tst lim =
    decryptOf (cstKey cst) >>= checkPlainLimit lim
  where
    cipher = fromJust $ stCipher tst
    bulk = cipherBulk cipher
    -- Crypt state taken from the snapshot; the updates below write a
    -- modified copy of it back as the current stCryptState.
    cst = stCryptState tst
    macSize = hashDigestSize $ cipherHash cipher
    blockSize = bulkBlockSize bulk
    econtentLen = B.length econtent
    sanityCheckError =
        throwError
            (Error_Packet "encrypted content too small for encryption parameters")

    decryptOf :: BulkState -> RecordM ByteString
    -- Block cipher layout: [IV][content][MAC][padding]. The last decrypted
    -- byte plus one is the number of trailing padding bytes, with that
    -- length byte itself included.
    decryptOf (BulkStateBlock decryptF) = do
        let minContent = bulkIVSize bulk + max (macSize + 1) blockSize
        -- check if we have enough bytes to cover the minimum for this cipher
        when
            ((econtentLen `mod` blockSize) /= 0 || econtentLen < minContent)
            sanityCheckError
        {- update IV -}
        (iv, econtent') <-
            get2o econtent (bulkIVSize bulk, econtentLen - bulkIVSize bulk)
        let (content', iv') = decryptF iv econtent'
        modify $ \txs -> txs{stCryptState = cst{cstIV = iv'}}
        -- B.last relies on content' being non-empty. The minContent check
        -- ensures this, assuming decryptF preserves length (TODO confirm).
        let paddinglength = fromIntegral (B.last content') + 1
        let contentlen = B.length content' - paddinglength - macSize
        -- NOTE(review): a bogus padding length (presumably a negative
        -- contentlen makes partition3 return Nothing; confirm) fails here,
        -- before getCipherData computes the MAC. That error path may
        -- therefore be faster than a MAC failure.
        (content, mac, padding) <- get3i content' (contentlen, macSize, paddinglength)
        getCipherData
            record
            CipherData
                { cipherDataContent = content
                , cipherDataMAC = Just mac
                , cipherDataPadding = Just (padding, blockSize)
                }
    -- Stream cipher layout: [content][MAC].
    decryptOf (BulkStateStream (BulkStream decryptF)) = do
        -- check if we have enough bytes to cover the minimum for this cipher
        when (econtentLen < macSize) sanityCheckError
        let (content', bulkStream') = decryptF econtent
        {- update Ctx -}
        let contentlen = B.length content' - macSize
        (content, mac) <- get2i content' (contentlen, macSize)
        -- the advanced stream is stored only once the split succeeded
        modify $ \txs -> txs{stCryptState = cst{cstKey = BulkStateStream bulkStream'}}
        getCipherData
            record
            CipherData
                { cipherDataContent = content
                , cipherDataMAC = Just mac
                , cipherDataPadding = Nothing
                }
    -- AEAD layout: [explicit nonce][ciphertext][auth tag].
    decryptOf (BulkStateAEAD decryptF) = do
        let authTagLen = bulkAuthTagLen bulk
            nonceExpLen = bulkExplicitIV bulk
            cipherLen = econtentLen - authTagLen - nonceExpLen
        -- check if we have enough bytes to cover the minimum for this cipher
        when (econtentLen < (authTagLen + nonceExpLen)) sanityCheckError
        (enonce, econtent', authTag) <-
            get3o econtent (nonceExpLen, cipherLen, authTagLen)
        let encodedSeq = encodeWord64 $ msSequence $ stMacState tst
            iv = cstIV (stCryptState tst)
            ivlen = B.length iv
            Header typ v _ = recordToHeader record
            -- TLS 1.3 authenticates the length of the whole encrypted record;
            -- earlier versions use the length of the ciphertext alone.
            hdrLen = if ver >= TLS13 then econtentLen else cipherLen
            hdr = Header typ v $ fromIntegral hdrLen
            -- additional data: header only for TLS 1.3; sequence number
            -- followed by header otherwise
            ad
                | ver >= TLS13 = encodeHeader hdr
                | otherwise = B.concat [encodedSeq, encodeHeader hdr]
            -- sequence number left-padded with zeros to the IV length
            sqnc = B.replicate (ivlen - 8) 0 `B.append` encodedSeq
            -- no explicit nonce: IV xor padded sequence number;
            -- otherwise: IV followed by the explicit nonce from the record
            nonce
                | nonceExpLen == 0 = B.xor iv sqnc
                | otherwise = iv `B.append` enonce
            (content, authTag2) = decryptF nonce econtent' ad
        when (AuthTag (B.convert authTag) /= authTag2) $
            throwError $
                Error_Protocol "bad record mac on AEAD" BadRecordMac
        -- record state is only advanced after successful authentication
        modify incrRecordState
        return content
    decryptOf BulkStateUninitialized =
        throwError $ Error_Protocol "decrypt state uninitialized" InternalError
    -- handling of outer format can report errors with Error_Packet
    get3o s ls =
        maybe (throwError $ Error_Packet "record bad format") return $ partition3 s ls
    get2o s (d1, d2) = get3o s (d1, d2, 0) >>= \(r1, r2, _) -> return (r1, r2)
    -- all format errors related to decrypted content are reported
    -- externally as integrity failures, i.e. BadRecordMac
    get3i s ls =
        maybe (throwError $ Error_Protocol "record bad format" BadRecordMac) return $
            partition3 s ls
    get2i s (d1, d2) = get3i s (d1, d2, 0) >>= \(r1, r2, _) -> return (r1, r2)
|