File: State13.hs

package info (click to toggle)
haskell-tls 2.1.8-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,056 kB
  • sloc: haskell: 15,695; makefile: 3
file content (203 lines) | stat: -rw-r--r-- 6,549 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
{-# LANGUAGE OverloadedStrings #-}

module Network.TLS.Handshake.State13 (
    CryptLevel (
        CryptEarlySecret,
        CryptHandshakeSecret,
        CryptApplicationSecret
    ),
    TrafficSecret,
    getTxRecordState,
    getRxRecordState,
    setTxRecordState,
    setRxRecordState,
    getTxLevel,
    getRxLevel,
    clearTxRecordState,
    clearRxRecordState,
    setHelloParameters13,
    transcriptHash,
    wrapAsMessageHash13,
    PendingRecvAction (..),
    setPendingRecvActions,
    popPendingRecvAction,
) where

import Control.Concurrent.MVar
import Control.Monad.State
import qualified Data.ByteString as B
import Data.IORef

import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.Context.Internal
import Network.TLS.Crypto
import Network.TLS.Handshake.State
import Network.TLS.Imports
import Network.TLS.KeySchedule (hkdfExpandLabel)
import Network.TLS.Record.State
import Network.TLS.Struct
import Network.TLS.Types

-- | Read the record state stored in 'ctxTxRecordState' and return its
-- hash, cipher, crypt level and secret (see 'getXState').
getTxRecordState :: Context -> IO (Hash, Cipher, CryptLevel, ByteString)
getTxRecordState = flip getXState ctxTxRecordState

-- | Read the record state stored in 'ctxRxRecordState' and return its
-- hash, cipher, crypt level and secret (see 'getXState').
getRxRecordState :: Context -> IO (Hash, Cipher, CryptLevel, ByteString)
getRxRecordState = flip getXState ctxRxRecordState

-- | Read the record state selected by @func@ and return, in order: the hash
-- of the cipher in use, the cipher itself, the crypt level, and the secret
-- held in 'cstMacSecret'.
--
-- The tuple components are computed lazily.  When the record state has no
-- cipher ('stCipher' is 'Nothing'), demanding the hash or the cipher raises an
-- error with an explicit message rather than an anonymous 'fromJust' failure.
getXState
    :: Context
    -> (Context -> MVar RecordState)
    -> IO (Hash, Cipher, CryptLevel, ByteString)
getXState ctx func = do
    st <- readMVar (func ctx)
    let usedCipher = case stCipher st of
            Just c -> c
            -- Kept inside the lazy binding so that the failure is raised
            -- only when the cipher (or its hash) is actually demanded.
            Nothing ->
                error "getXState: no cipher in record state (stCipher is Nothing)"
        usedHash = cipherHash usedCipher
        level = stCryptLevel st
        secret = cstMacSecret $ stCryptState st
    return (usedHash, usedCipher, level, secret)

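-- In the case of QUIC, stCipher is Nothing, so getTxRecordState and
-- getRxRecordState fail when the cipher is demanded.  The level-only
-- accessors below never look at stCipher.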
-- | Crypt level of the record state stored in 'ctxTxRecordState'.
getTxLevel :: Context -> IO CryptLevel
getTxLevel = flip getXLevel ctxTxRecordState

-- | Crypt level of the record state stored in 'ctxRxRecordState'.
getRxLevel :: Context -> IO CryptLevel
getRxLevel = flip getXLevel ctxRxRecordState

-- | Read the record state selected by @func@ and return only its
-- 'stCryptLevel'; no other field of the record state is inspected.
getXLevel
    :: Context
    -> (Context -> MVar RecordState)
    -> IO CryptLevel
getXLevel ctx func = stCryptLevel <$> readMVar (func ctx)

-- | Traffic-secret wrappers that can be split into the crypt level they
-- belong to and their raw secret bytes.
class TrafficSecret ty where
    -- | Return the crypt level of the wrapper together with the wrapped
    -- secret.
    fromTrafficSecret :: ty -> (CryptLevel, ByteString)

-- The level is obtained by passing the wrapper itself to 'getCryptLevel'.
instance HasCryptLevel a => TrafficSecret (AnyTrafficSecret a) where
    fromTrafficSecret wrapped@(AnyTrafficSecret bytes) =
        (getCryptLevel wrapped, bytes)

instance HasCryptLevel a => TrafficSecret (ClientTrafficSecret a) where
    fromTrafficSecret wrapped@(ClientTrafficSecret bytes) =
        (getCryptLevel wrapped, bytes)

instance HasCryptLevel a => TrafficSecret (ServerTrafficSecret a) where
    fromTrafficSecret wrapped@(ServerTrafficSecret bytes) =
        (getCryptLevel wrapped, bytes)

-- | Install a new record state in 'ctxTxRecordState', derived from the given
-- hash, cipher and traffic secret, with the bulk cipher initialised for
-- 'BulkEncrypt'.
setTxRecordState :: TrafficSecret ty => Context -> Hash -> Cipher -> ty -> IO ()
setTxRecordState ctx h cipher ts =
    setXState ctxTxRecordState BulkEncrypt ctx h cipher ts

-- | Install a new record state in 'ctxRxRecordState', derived from the given
-- hash, cipher and traffic secret, with the bulk cipher initialised for
-- 'BulkDecrypt'.
setRxRecordState :: TrafficSecret ty => Context -> Hash -> Cipher -> ty -> IO ()
setRxRecordState ctx h cipher ts =
    setXState ctxRxRecordState BulkDecrypt ctx h cipher ts

-- | Split the traffic secret into its crypt level and raw bytes with
-- 'fromTrafficSecret', then hand everything over to @setXState'@.
setXState
    :: TrafficSecret ty
    => (Context -> MVar RecordState)
    -> BulkDirection
    -> Context
    -> Hash
    -> Cipher
    -> ty
    -> IO ()
setXState func encOrDec ctx h cipher ts =
    uncurry (setXState' func encOrDec ctx h cipher) (fromTrafficSecret ts)

-- | Replace the record state selected by @func@ with a fresh one built from
-- the given secret.
--
-- The key and IV are derived from @secret@ with 'hkdfExpandLabel' using the
-- labels @"key"@ and @"iv"@.  The previous record state is discarded; the new
-- one starts at sequence number 0, uses 'nullCompression', and keeps @secret@
-- in 'cstMacSecret'.
setXState'
    :: (Context -> MVar RecordState)
    -> BulkDirection
    -> Context
    -> Hash
    -> Cipher
    -> CryptLevel
    -> ByteString
    -> IO ()
setXState' func encOrDec ctx h cipher lvl secret =
    modifyMVar_ (func ctx) (const (return freshState))
  where
    bulkAlgo = cipherBulk cipher
    keyLen = bulkKeySize bulkAlgo
    -- The derived IV is never shorter than 8.
    ivLen = max 8 (bulkIVSize bulkAlgo + bulkExplicitIV bulkAlgo)
    keyBytes = hkdfExpandLabel h secret "key" "" keyLen
    ivBytes = hkdfExpandLabel h secret "iv" "" ivLen
    cryptState =
        CryptState
            { cstKey = bulkInit bulkAlgo encOrDec keyBytes
            , cstIV = ivBytes
            , cstMacSecret = secret
            }
    freshState =
        RecordState
            { stCryptState = cryptState
            , stMacState = MacState{msSequence = 0}
            , stCryptLevel = lvl
            , stCipher = Just cipher
            , stCompression = nullCompression
            }

-- | Remove the cipher from the record state stored in 'ctxTxRecordState'.
clearTxRecordState :: Context -> IO ()
clearTxRecordState ctx = clearXState ctxTxRecordState ctx

-- | Remove the cipher from the record state stored in 'ctxRxRecordState'.
clearRxRecordState :: Context -> IO ()
clearRxRecordState ctx = clearXState ctxRxRecordState ctx

-- | Set 'stCipher' to 'Nothing' in the record state selected by @func@,
-- leaving every other field of the record state as it was.
clearXState :: (Context -> MVar RecordState) -> Context -> IO ()
clearXState func ctx = modifyMVar_ (func ctx) (return . dropCipher)
  where
    dropCipher st = st{stCipher = Nothing}

-- | Record the cipher chosen for the handshake.
--
-- * No pending cipher yet: store @cipher@ and 'nullCompression', and turn the
--   accumulated handshake messages into a running digest that uses the
--   cipher's hash.  Returns @Right ()@.
-- * Pending cipher already set: returns @Right ()@ when it equals @cipher@,
--   otherwise a @Left@ protocol error with 'IllegalParameter'.
--
-- Calls 'error' if the handshake digest has already been turned into a digest
-- context when the first branch is taken.
setHelloParameters13 :: Cipher -> HandshakeM (Either TLSError ())
setHelloParameters13 cipher = get >>= decide
  where
    decide hst = case hstPendingCipher hst of
        Just previous
            | cipher == previous -> return (Right ())
            | otherwise -> return (Left cipherChanged)
        Nothing -> do
            put
                hst
                    { hstPendingCipher = Just cipher
                    , hstPendingCompression = nullCompression
                    , hstHandshakeDigest = startDigest (hstHandshakeDigest hst)
                    }
            return (Right ())
    cipherChanged =
        Error_Protocol "TLS 1.3 cipher changed after hello retry" IllegalParameter
    -- The stored message list is reversed before being fed to the hash.
    startDigest (HandshakeMessages msgs) =
        HandshakeDigestContext $
            foldl hashUpdate (hashInit (cipherHash cipher)) (reverse msgs)
    startDigest (HandshakeDigestContext _) =
        error "cannot initialize digest with another digest"

-- | Wrap the existing transcript in a "message_hash" construct.
--
-- This is required when a HelloRetryRequest is sent or received (RFC 8446
-- section 4.4.1) and applies to key-schedule computations as well as the ones
-- for PSK binders.
--
-- The replacement consists of the bytes @254, 0, 0@, one byte holding the
-- digest length, and then the digest itself.
wrapAsMessageHash13 :: HandshakeM ()
wrapAsMessageHash13 =
    getPendingCipher >>= \cipher ->
        foldHandshakeDigest (cipherHash cipher) asMessageHash
  where
    asMessageHash dig =
        B.append "\254\0\0" (B.cons (fromIntegral (B.length dig)) dig)

-- | Finalize the running handshake digest of @ctx@ and return the hash.
--
-- Calls 'error' when the context has no handshake state, or when the
-- handshake digest has not yet been turned into a digest context.
transcriptHash :: MonadIO m => Context -> m ByteString
transcriptHash ctx = do
    mhst <- getHState ctx
    case mhst of
        -- Explicit message instead of an anonymous 'fromJust' failure.
        Nothing -> error "transcriptHash: no handshake state"
        Just hst -> case hstHandshakeDigest hst of
            HandshakeDigestContext hashCtx -> return $ hashFinal hashCtx
            HandshakeMessages _ -> error "un-initialized handshake digest"

-- | Overwrite the list held in 'ctxPendingRecvActions' with @actions@.
setPendingRecvActions :: Context -> [PendingRecvAction] -> IO ()
setPendingRecvActions ctx actions =
    writeIORef (ctxPendingRecvActions ctx) actions

-- | Remove and return the first entry of the list held in
-- 'ctxPendingRecvActions', or return 'Nothing' (leaving the reference
-- untouched) when the list is empty.
popPendingRecvAction :: Context -> IO (Maybe PendingRecvAction)
popPendingRecvAction ctx = readIORef ref >>= pop
  where
    ref = ctxPendingRecvActions ctx
    pop [] = return Nothing
    pop (next : rest) = do
        writeIORef ref rest
        return (Just next)