1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
|
{-# LANGUAGE OverloadedStrings #-}
-- |
-- Module : Network.TLS.Handshake.State13
-- License : BSD-style
-- Maintainer : Vincent Hanquez <vincent@snarc.org>
-- Stability : experimental
-- Portability : unknown
--
module Network.TLS.Handshake.State13
( CryptLevel ( CryptEarlySecret
, CryptHandshakeSecret
, CryptApplicationSecret
)
, TrafficSecret
, getTxState
, getRxState
, setTxState
, setRxState
, clearTxState
, clearRxState
, setHelloParameters13
, transcriptHash
, wrapAsMessageHash13
, PendingAction(..)
, setPendingActions
, popPendingAction
) where
import Control.Concurrent.MVar
import Control.Monad.State
import qualified Data.ByteString as B
import Data.IORef
import Data.List (foldl')

import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.Context.Internal
import Network.TLS.Crypto
import Network.TLS.Handshake.State
import Network.TLS.Imports
import Network.TLS.KeySchedule (hkdfExpandLabel)
import Network.TLS.Record.State
import Network.TLS.Struct
import Network.TLS.Types
import Network.TLS.Util
-- | Snapshot of the transmit-side record state: hash, cipher, crypt level
-- and the secret currently stored for that direction.
getTxState :: Context -> IO (Hash, Cipher, CryptLevel, ByteString)
getTxState = flip getXState ctxTxState
-- | Snapshot of the receive-side record state: hash, cipher, crypt level
-- and the secret currently stored for that direction.
getRxState :: Context -> IO (Hash, Cipher, CryptLevel, ByteString)
getRxState = flip getXState ctxRxState
-- | Read the record state held in the MVar selected by @func@ and return
-- its hash, cipher, crypt level and the secret kept in 'cstMacSecret'.
--
-- A cipher is expected to be installed. If 'stCipher' is 'Nothing', forcing
-- the returned hash or cipher fails through the named 'fromJust' instead of
-- an anonymous irrefutable-pattern error. The failure is still lazy: it is
-- raised only when those components are demanded.
getXState :: Context
          -> (Context -> MVar RecordState)  -- ^ selects the Tx or Rx state
          -> IO (Hash, Cipher, CryptLevel, ByteString)
getXState ctx func = do
    tx <- readMVar (func ctx)
    let usedCipher = fromJust "stCipher" (stCipher tx)
        usedHash   = cipherHash usedCipher
        level      = stCryptLevel tx
        secret     = cstMacSecret (stCryptState tx)
    return (usedHash, usedCipher, level, secret)
-- | Traffic secrets that can be installed into a record state.
--
-- 'fromTrafficSecret' pairs the raw secret bytes with the 'CryptLevel'
-- obtained from the type parameter through 'HasCryptLevel'.
class TrafficSecret ty where
    fromTrafficSecret :: ty -> (CryptLevel, ByteString)

instance HasCryptLevel a => TrafficSecret (AnyTrafficSecret a) where
    fromTrafficSecret wrapped = case wrapped of
        AnyTrafficSecret raw -> (getCryptLevel wrapped, raw)

instance HasCryptLevel a => TrafficSecret (ClientTrafficSecret a) where
    fromTrafficSecret wrapped = case wrapped of
        ClientTrafficSecret raw -> (getCryptLevel wrapped, raw)

instance HasCryptLevel a => TrafficSecret (ServerTrafficSecret a) where
    fromTrafficSecret wrapped = case wrapped of
        ServerTrafficSecret raw -> (getCryptLevel wrapped, raw)
-- | Install a traffic secret as the transmit-side record state, with the
-- bulk cipher initialised for encryption.
setTxState :: TrafficSecret ty => Context -> Hash -> Cipher -> ty -> IO ()
setTxState ctx h cipher ts = setXState ctxTxState BulkEncrypt ctx h cipher ts
-- | Install a traffic secret as the receive-side record state, with the
-- bulk cipher initialised for decryption.
setRxState :: TrafficSecret ty => Context -> Hash -> Cipher -> ty -> IO ()
setRxState ctx h cipher ts = setXState ctxRxState BulkDecrypt ctx h cipher ts
-- | Unpack a traffic secret into its crypt level and raw bytes and hand
-- both to 'setXState'' for the record state selected by @func@.
setXState :: TrafficSecret ty
          => (Context -> MVar RecordState) -> BulkDirection
          -> Context -> Hash -> Cipher -> ty
          -> IO ()
setXState func encOrDec ctx h cipher ts =
    uncurry (setXState' func encOrDec ctx h cipher) (fromTrafficSecret ts)
-- | Replace the record state selected by @func@ with one derived from
-- @secret@.
--
-- The bulk key and IV are produced with 'hkdfExpandLabel' using the labels
-- "key" and "iv" and an empty context. The IV length is
-- @max 8 (bulkIVSize + bulkExplicitIV)@. The previous state is discarded:
-- the MAC sequence number is set to 0, compression to 'nullCompression',
-- and the secret itself is kept in 'cstMacSecret'.
setXState' :: (Context -> MVar RecordState) -> BulkDirection
           -> Context -> Hash -> Cipher -> CryptLevel -> ByteString
           -> IO ()
setXState' func encOrDec ctx h cipher lvl secret =
    modifyMVar_ (func ctx) (const (return freshState))
  where
    bulkAlg = cipherBulk cipher
    expand label len = hkdfExpandLabel h secret label "" len
    ivLen = max 8 (bulkIVSize bulkAlg + bulkExplicitIV bulkAlg)
    freshState = RecordState
        { stCryptState = CryptState
            { cstKey       = bulkInit bulkAlg encOrDec (expand "key" (bulkKeySize bulkAlg))
            , cstIV        = expand "iv" ivLen
            , cstMacSecret = secret
            }
        , stMacState    = MacState { msSequence = 0 }
        , stCryptLevel  = lvl
        , stCipher      = Just cipher
        , stCompression = nullCompression
        }
-- | Remove the cipher from the transmit-side record state.
clearTxState :: Context -> IO ()
clearTxState ctx = clearXState ctxTxState ctx
-- | Remove the cipher from the receive-side record state.
clearRxState :: Context -> IO ()
clearRxState ctx = clearXState ctxRxState ctx
-- | Set 'stCipher' to 'Nothing' in the record state selected by @func@,
-- leaving every other field untouched.
clearXState :: (Context -> MVar RecordState) -> Context -> IO ()
clearXState func ctx = modifyMVar_ (func ctx) (return . withoutCipher)
  where
    withoutCipher st = st { stCipher = Nothing }
-- | Record the cipher chosen in a TLS 1.3 hello and start the transcript
-- hash with that cipher's hash algorithm.
--
-- * On the first call (no pending cipher) the cipher is stored, compression
--   is set to 'nullCompression', and the buffered handshake messages are fed
--   into a fresh hash context.
-- * On a later call (e.g. after a HelloRetryRequest) the same cipher is
--   accepted as a no-op. A different cipher yields an 'IllegalParameter'
--   protocol error.
--
-- Calling 'error' is reserved for the case where the digest was already
-- initialised while no cipher was pending.
setHelloParameters13 :: Cipher -> HandshakeM (Either TLSError ())
setHelloParameters13 cipher = do
    hst <- get
    case hstPendingCipher hst of
        Nothing -> do
            put hst
                { hstPendingCipher      = Just cipher
                , hstPendingCompression = nullCompression
                , hstHandshakeDigest    = updateDigest $ hstHandshakeDigest hst
                }
            return $ Right ()
        Just oldcipher
            | cipher == oldcipher -> return $ Right ()
            | otherwise -> return $ Left $ Error_Protocol ("TLS 1.3 cipher changed after hello retry", True, IllegalParameter)
  where
    hashAlg = cipherHash cipher
    -- The buffered messages are reversed before hashing, so the list is
    -- presumably kept newest-first (TODO confirm against the producer).
    -- foldl' keeps the running hash context evaluated at each step instead
    -- of building a chain of thunks.
    updateDigest (HandshakeMessages bytes) =
        HandshakeDigestContext $ foldl' hashUpdate (hashInit hashAlg) $ reverse bytes
    updateDigest (HandshakeDigestContext _) =
        error "cannot initialize digest with another digest"
-- | Collapse the current transcript into a synthetic "message_hash"
-- handshake message.
--
-- When a HelloRetryRequest is sent or received, the existing transcript must
-- be wrapped in a "message_hash" construct (RFC 8446, section 4.4.1). This
-- applies to the key-schedule computations as well as to the PSK binder
-- computations. The digest algorithm is taken from the pending cipher.
wrapAsMessageHash13 :: HandshakeM ()
wrapAsMessageHash13 =
    getPendingCipher >>= \cipher ->
        foldHandshakeDigest (cipherHash cipher) asMessageHash
  where
    -- Byte 254 is the message_hash handshake type. It is followed by a 3-byte
    -- length whose two high bytes are zero here, so the code assumes the
    -- digest is shorter than 256 bytes.
    asMessageHash digest =
        "\254\0\0" `B.append` (fromIntegral (B.length digest) `B.cons` digest)
-- | Finalise and return the current handshake transcript hash.
--
-- Fails (via 'fromJust') when the context has no handshake state, and with
-- 'error' when the digest is still an un-hashed message buffer, i.e. the
-- hash algorithm has not been fixed yet.
transcriptHash :: MonadIO m => Context -> m ByteString
transcriptHash ctx = getHState ctx >>= finish . fromJust "HState"
  where
    finish hst = case hstHandshakeDigest hst of
        HandshakeDigestContext hashCtx -> return (hashFinal hashCtx)
        HandshakeMessages _            -> error "un-initialized handshake digest"
-- | Overwrite the context's queue of pending actions.
setPendingActions :: Context -> [PendingAction] -> IO ()
setPendingActions ctx actions = writeIORef (ctxPendingActions ctx) actions
-- | Remove and return the head of the pending-action queue, or 'Nothing'
-- when the queue is empty.
--
-- The read and the write are separate 'IORef' operations, not a single
-- atomic update.
popPendingAction :: Context -> IO (Maybe PendingAction)
popPendingAction ctx = readIORef ref >>= pop
  where
    ref = ctxPendingActions ctx
    pop [] = return Nothing
    pop (next:rest) = do
        writeIORef ref rest
        return (Just next)
|