1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176
|
{-# LANGUAGE MultiParamTypeClasses #-}
module Network.TLS.Record.State (
CryptState (..),
CryptLevel (..),
HasCryptLevel (..),
MacState (..),
RecordOptions (..),
RecordState (..),
newRecordState,
incrRecordState,
RecordM,
runRecordM,
getRecordOptions,
getRecordVersion,
setRecordIV,
withCompression,
computeDigest,
makeDigest,
getBulk,
getMacSequence,
) where
import Control.Monad.State.Strict
import qualified Data.ByteString as B
import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.ErrT
import Network.TLS.Imports
import Network.TLS.MAC
import Network.TLS.Packet
import Network.TLS.Struct
import Network.TLS.Types
import Network.TLS.Wire
-- | Cryptographic material used by the record layer for protecting records.
data CryptState = CryptState
    { cstKey :: BulkState
    -- ^ Bulk cipher state; 'newRecordState' starts it as 'BulkStateUninitialized'.
    , cstIV :: ByteString
    -- ^ Initialization vector bytes; replaced by 'setRecordIV'.
    , -- In TLS 1.2 or earlier, this holds the MAC secret.
      -- In TLS 1.3, this holds application traffic secret N.
      cstMacSecret :: ByteString
      -- ^ 'computeDigest' uses this as the HMAC key.
    }
    deriving (Show)
-- | Per-record counter feeding the MAC computation.
newtype MacState = MacState
    { msSequence :: Word64
    -- ^ Record sequence number. 'incrRecordState' bumps it with plain
    -- 'Word64' addition, and 'computeDigest' encodes it as the first part
    -- of the MAC input.
    }
    deriving (Show)
-- | Read-only parameters for record processing, passed to 'runRecordM'
-- alongside the threaded 'RecordState'.
data RecordOptions = RecordOptions
    { recordVersion :: Version
    -- ^ Version to use when sending/receiving.
    , recordTLS13 :: Bool
    -- ^ Whether TLS 1.3 record processing is in effect.
    }
-- | TLS encryption level: which secret, if any, currently protects record
-- traffic. 'newRecordState' starts at 'CryptInitial'.
data CryptLevel
    = -- | Unprotected traffic
      CryptInitial
    | -- | Protected with main secret (TLS < 1.3)
      CryptMainSecret
    | -- | Protected with early traffic secret (TLS 1.3)
      CryptEarlySecret
    | -- | Protected with handshake traffic secret (TLS 1.3)
      CryptHandshakeSecret
    | -- | Protected with application traffic secret (TLS 1.3)
      CryptApplicationSecret
    deriving (Eq, Show)
-- | Secret types that correspond to a fixed 'CryptLevel'. Only the type of
-- the proxy argument matters; its value is never inspected.
class HasCryptLevel a where
    getCryptLevel :: proxy a -> CryptLevel

instance HasCryptLevel EarlySecret where
    getCryptLevel = const CryptEarlySecret

instance HasCryptLevel HandshakeSecret where
    getCryptLevel = const CryptHandshakeSecret

instance HasCryptLevel ApplicationSecret where
    getCryptLevel = const CryptApplicationSecret
-- | State threaded through 'RecordM'; see 'newRecordState' for initial values.
data RecordState = RecordState
    { stCipher :: Maybe Cipher
    -- ^ Installed cipher; 'Nothing' initially. 'computeDigest' and 'getBulk'
    -- expect it to be 'Just'.
    , stCompression :: Compression
    -- ^ Compression state, updated through 'withCompression'.
    , stCryptLevel :: CryptLevel
    -- ^ Which secret currently protects the records.
    , stCryptState :: CryptState
    -- ^ Bulk key state, IV and MAC secret.
    , stMacState :: MacState
    -- ^ Record sequence number; advanced by 'incrRecordState'.
    }
    deriving (Show)
-- | Record-layer computation.
--
-- It reads 'RecordOptions', threads 'RecordState', and may fail with a
-- 'TLSError'. On failure ('Left') no final state is returned.
newtype RecordM a = RecordM
    { runRecordM
        :: RecordOptions
        -> RecordState
        -> Either TLSError (a, RecordState)
    }
-- | Map over the result; a 'Left' is passed through unchanged.
instance Functor RecordM where
    fmap f m = RecordM $ \opt st ->
        either Left (\(a, st2) -> Right (f a, st2)) (runRecordM m opt st)

-- | 'pure' leaves the state untouched; '<*>' is derived from the monad.
instance Applicative RecordM where
    pure a = RecordM $ \_ st -> Right (a, st)
    (<*>) = ap

-- | Run the first action, then feed its result and updated state to the
-- continuation under the same options. 'Either' bind short-circuits on 'Left'.
instance Monad RecordM where
    m1 >>= m2 = RecordM $ \opt st ->
        runRecordM m1 opt st >>= \(a, st2) -> runRecordM (m2 a) opt st2
-- | Return the read-only options and leave the state untouched.
getRecordOptions :: RecordM RecordOptions
getRecordOptions = RecordM (curry Right)
-- | The 'recordVersion' field of the current options.
getRecordVersion :: RecordM Version
getRecordVersion = fmap recordVersion getRecordOptions
-- | State access that ignores the options and never fails.
-- 'get' and 'put' are expressed through 'state', which is the primitive here.
instance MonadState RecordState RecordM where
    state f = RecordM $ const (Right . f)
    get = state (\cur -> (cur, cur))
    put newSt = state (const ((), newSt))
-- | Failure handling. When the guarded action fails, the handler runs from
-- the state as it was at the 'catchError' call. Any state changes made by
-- the failed action are discarded.
instance MonadError TLSError RecordM where
    throwError e = RecordM $ \_ _ -> Left e
    catchError m handler = RecordM $ \opt st ->
        either (\err -> runRecordM (handler err) opt st) Right (runRecordM m opt st)
-- | Starting record state: no cipher, null compression, unprotected level,
-- uninitialized bulk key with empty IV and MAC secret, sequence number 0.
newRecordState :: RecordState
newRecordState =
    RecordState
        { stCipher = Nothing
        , stCompression = nullCompression
        , stCryptLevel = CryptInitial
        , stCryptState =
            CryptState
                { cstKey = BulkStateUninitialized
                , cstIV = B.empty
                , cstMacSecret = B.empty
                }
        , stMacState = MacState{msSequence = 0}
        }
-- | Advance the record sequence number by one; all other fields are kept.
-- Uses plain 'Word64' addition.
incrRecordState :: RecordState -> RecordState
incrRecordState ts =
    let next = msSequence (stMacState ts) + 1
     in ts{stMacState = MacState next}
-- | Replace the IV in the crypt state, keeping the bulk key and MAC secret.
setRecordIV :: ByteString -> RecordState -> RecordState
setRecordIV iv st = st{stCryptState = withIV (stCryptState st)}
  where
    withIV cst = cst{cstIV = iv}
-- | Apply @f@ to the current compression state, store the compression it
-- returns, and yield its second component.
withCompression :: (Compression -> (Compression, a)) -> RecordM a
withCompression f = state $ \st ->
    let (nc, a) = f (stCompression st)
     in (a, st{stCompression = nc})
-- | Compute the record MAC for @content@ under header @hdr@.
--
-- Returns the MAC together with the state whose sequence number has been
-- advanced by 'incrRecordState'. The MAC is an HMAC keyed with
-- 'cstMacSecret', using the installed cipher's hash. Its input is the
-- concatenation of:
--
-- * the encoded sequence number,
-- * the encoded header,
-- * the content.
--
-- The 'Version' argument is currently unused. Forcing the digest calls
-- 'error' with a descriptive message if no cipher is installed.
computeDigest
    :: Version -> RecordState -> Header -> ByteString -> (ByteString, RecordState)
computeDigest _ver tstate hdr content = (digest, incrRecordState tstate)
  where
    digest = hmac (cipherHash cipher) (cstMacSecret (stCryptState tstate)) macInput
    macInput =
        B.concat
            [ encodeWord64 (msSequence (stMacState tstate))
            , encodeHeader hdr
            , content
            ]
    -- Explicit match instead of 'fromJust' so a missing cipher reports
    -- where it happened rather than "Maybe.fromJust: Nothing".
    cipher = case stCipher tstate of
        Just c -> c
        Nothing -> error "computeDigest: no cipher set in record state"
-- | Compute the MAC for a record via 'computeDigest' under the current
-- version. The resulting state (sequence number advanced) is stored.
makeDigest :: Header -> ByteString -> RecordM ByteString
makeDigest hdr content = do
    ver <- getRecordVersion
    -- computeDigest already returns (result, newState), the shape 'state' expects.
    state $ \st -> computeDigest ver st hdr content
-- | Bulk cipher of the currently installed cipher.
--
-- Forcing the result calls 'error' with a descriptive message if no cipher
-- is installed, instead of the opaque "Maybe.fromJust: Nothing".
getBulk :: RecordM Bulk
getBulk = gets (maybe noCipher cipherBulk . stCipher)
  where
    noCipher = error "getBulk: no cipher set in record state"
-- | Current record sequence number (see 'incrRecordState').
getMacSequence :: RecordM Word64
getMacSequence = gets (msSequence . stMacState)
|