File: State13.hs

package info (click to toggle)
haskell-tls 1.8.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 916 kB
  • sloc: haskell: 12,430; makefile: 3
file content (171 lines) | stat: -rw-r--r-- 6,140 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
{-# LANGUAGE OverloadedStrings #-}

-- |
-- Module      : Network.TLS.Handshake.State13
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
module Network.TLS.Handshake.State13
       ( CryptLevel ( CryptEarlySecret
                    , CryptHandshakeSecret
                    , CryptApplicationSecret
                    )
       , TrafficSecret
       , getTxState
       , getRxState
       , setTxState
       , setRxState
       , clearTxState
       , clearRxState
       , setHelloParameters13
       , transcriptHash
       , wrapAsMessageHash13
       , PendingAction(..)
       , setPendingActions
       , popPendingAction
       ) where

import Control.Concurrent.MVar
import Control.Monad.State
import qualified Data.ByteString as B
import Data.IORef
import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.Context.Internal
import Network.TLS.Crypto
import Network.TLS.Handshake.State
import Network.TLS.KeySchedule (hkdfExpandLabel)
import Network.TLS.Record.State
import Network.TLS.Struct
import Network.TLS.Imports
import Network.TLS.Types
import Network.TLS.Util

-- | Read the transmit-side record state of the context.
--
-- Returns the hash of the cipher in use, the cipher itself, the crypt level
-- and the secret stored in the crypt state, in that order.
getTxState :: Context -> IO (Hash, Cipher, CryptLevel, ByteString)
getTxState = (`getXState` ctxTxState)

-- | Read the receive-side record state of the context.
--
-- Returns the hash of the cipher in use, the cipher itself, the crypt level
-- and the secret stored in the crypt state, in that order.
getRxState :: Context -> IO (Hash, Cipher, CryptLevel, ByteString)
getRxState = (`getXState` ctxRxState)

-- | Snapshot the record state selected by @func@ and return, in order: the
-- hash algorithm of the cipher in use, that cipher, the crypt level, and the
-- secret held in the 'cstMacSecret' field of the crypt state.
--
-- The 'MVar' is only read (via 'readMVar'); it is never modified here.
--
-- When no cipher is installed ('stCipher' is 'Nothing'), forcing the hash or
-- cipher component of the result raises the error of the labelled 'fromJust'
-- helper rather than an anonymous pattern-match failure from a partial
-- @let Just ...@ binding.  The lookup stays lazy: the level and secret
-- components can still be used when no cipher is set.
getXState :: Context
          -> (Context -> MVar RecordState)
          -> IO (Hash, Cipher, CryptLevel, ByteString)
getXState ctx func = do
    st <- readMVar (func ctx)
    let usedCipher = fromJust "cipher" (stCipher st)
        usedHash   = cipherHash usedCipher
        level      = stCryptLevel st
        secret     = cstMacSecret (stCryptState st)
    return (usedHash, usedCipher, level, secret)

-- | Types wrapping a traffic secret from which both the crypt level and the
-- raw secret bytes can be recovered.
class TrafficSecret ty where
    -- | Split a wrapped secret into its crypt level and its secret bytes.
    fromTrafficSecret :: ty -> (CryptLevel, ByteString)

-- | The crypt level comes from the phantom parameter via 'getCryptLevel';
-- the secret is the payload of the 'AnyTrafficSecret' constructor.
instance HasCryptLevel a => TrafficSecret (AnyTrafficSecret a) where
    fromTrafficSecret wrapped = case wrapped of
        AnyTrafficSecret bytes -> (getCryptLevel wrapped, bytes)

-- | The crypt level comes from the phantom parameter via 'getCryptLevel';
-- the secret is the payload of the 'ClientTrafficSecret' constructor.
instance HasCryptLevel a => TrafficSecret (ClientTrafficSecret a) where
    fromTrafficSecret wrapped = case wrapped of
        ClientTrafficSecret bytes -> (getCryptLevel wrapped, bytes)

-- | The crypt level comes from the phantom parameter via 'getCryptLevel';
-- the secret is the payload of the 'ServerTrafficSecret' constructor.
instance HasCryptLevel a => TrafficSecret (ServerTrafficSecret a) where
    fromTrafficSecret wrapped = case wrapped of
        ServerTrafficSecret bytes -> (getCryptLevel wrapped, bytes)

-- | Replace the transmit-side record state with one built from the given
-- hash, cipher and traffic secret.  The bulk cipher is initialised in the
-- 'BulkEncrypt' direction.
setTxState :: TrafficSecret ty => Context -> Hash -> Cipher -> ty -> IO ()
setTxState ctx h cipher ts = setXState ctxTxState BulkEncrypt ctx h cipher ts

-- | Replace the receive-side record state with one built from the given
-- hash, cipher and traffic secret.  The bulk cipher is initialised in the
-- 'BulkDecrypt' direction.
setRxState :: TrafficSecret ty => Context -> Hash -> Cipher -> ty -> IO ()
setRxState ctx h cipher ts = setXState ctxRxState BulkDecrypt ctx h cipher ts

-- | Unwrap a typed traffic secret into its crypt level and raw bytes, then
-- delegate to @setXState'@ to install the resulting record state in the
-- 'MVar' selected by @func@.
setXState :: TrafficSecret ty
          => (Context -> MVar RecordState) -> BulkDirection
          -> Context -> Hash -> Cipher -> ty
          -> IO ()
setXState func encOrDec ctx h cipher ts =
    uncurry (setXState' func encOrDec ctx h cipher) (fromTrafficSecret ts)

-- | Build a fresh 'RecordState' from a traffic secret and store it in the
-- 'MVar' selected by @func@, discarding whatever state was there before.
--
-- The key and IV are expanded from @secret@ with 'hkdfExpandLabel' under the
-- labels @"key"@ and @"iv"@ and an empty context.  The new state starts with
-- sequence number 0 and null compression, and keeps @secret@ itself in
-- 'cstMacSecret'.
setXState' :: (Context -> MVar RecordState) -> BulkDirection
          -> Context -> Hash -> Cipher -> CryptLevel -> ByteString
          -> IO ()
setXState' func encOrDec ctx h cipher lvl secret =
    let bulk = cipherBulk cipher
        -- The IV is never shorter than 8 bytes.
        ivLen = max 8 (bulkIVSize bulk + bulkExplicitIV bulk)
        expand label len = hkdfExpandLabel h secret label "" len
        cryptState = CryptState
            { cstKey       = bulkInit bulk encOrDec (expand "key" (bulkKeySize bulk))
            , cstIV        = expand "iv" ivLen
            , cstMacSecret = secret
            }
        recordState = RecordState
            { stCryptState  = cryptState
            , stMacState    = MacState { msSequence = 0 }
            , stCryptLevel  = lvl
            , stCipher      = Just cipher
            , stCompression = nullCompression
            }
     in modifyMVar_ (func ctx) (const (return recordState))

-- | Remove the cipher from the transmit-side record state.
clearTxState :: Context -> IO ()
clearTxState ctx = clearXState ctxTxState ctx

-- | Remove the cipher from the receive-side record state.
clearRxState :: Context -> IO ()
clearRxState ctx = clearXState ctxRxState ctx

-- | Set 'stCipher' to 'Nothing' in the record state selected by @func@.
-- All other fields of the record state are left as they are.
clearXState :: (Context -> MVar RecordState) -> Context -> IO ()
clearXState func ctx = modifyMVar_ (func ctx) dropCipher
  where
    dropCipher st = return st { stCipher = Nothing }

-- | Record the cipher negotiated in the hello exchange.
--
-- * With no pending cipher yet, the cipher is stored, compression is set to
--   null, and the buffered handshake messages are folded into a running
--   digest that uses the cipher's hash.
-- * With the same cipher already pending, nothing changes.
-- * With a different cipher pending, an 'Error_Protocol' with
--   'IllegalParameter' is returned in 'Left'.
setHelloParameters13 :: Cipher -> HandshakeM (Either TLSError ())
setHelloParameters13 cipher = do
    hst <- get
    case hstPendingCipher hst of
        Just pending
            | cipher == pending -> return (Right ())
            | otherwise ->
                return $ Left $ Error_Protocol ("TLS 1.3 cipher changed after hello retry", True, IllegalParameter)
        Nothing -> do
            put hst
                { hstPendingCipher      = Just cipher
                , hstPendingCompression = nullCompression
                , hstHandshakeDigest    = startDigest (hstHandshakeDigest hst)
                }
            return (Right ())
  where
    emptyCtx = hashInit (cipherHash cipher)
    -- A right fold with the flipped update feeds the last list element
    -- first, exactly like a left fold over the reversed list.
    startDigest (HandshakeMessages msgs)   = HandshakeDigestContext (foldr (flip hashUpdate) emptyCtx msgs)
    startDigest (HandshakeDigestContext _) = error "cannot initialize digest with another digest"

-- | Replace the current handshake transcript with its "message_hash"
-- wrapping, as required once a HelloRetryRequest has been sent or received
-- (RFC 8446 section 4.4.1).  The wrapped transcript feeds the key-schedule
-- computations as well as the PSK binder computations.
--
-- The wrapping is the three bytes @254, 0, 0@, then one byte holding the
-- digest length, then the digest itself.
wrapAsMessageHash13 :: HandshakeM ()
wrapAsMessageHash13 =
    getPendingCipher >>= \cipher ->
        foldHandshakeDigest (cipherHash cipher) wrap
  where
    wrap digest = B.concat ["\254\0\0", lenByte, digest]
      where
        lenByte = B.singleton (fromIntegral (B.length digest))

-- | Finalise the running handshake digest of the context and return the
-- resulting hash.
--
-- Calls 'error' when the digest has not been initialised yet, i.e. the
-- handshake state still holds raw messages.  The handshake state itself is
-- unwrapped with the labelled 'fromJust' helper.
transcriptHash :: MonadIO m => Context -> m ByteString
transcriptHash ctx = do
    mhst <- getHState ctx
    let hst = fromJust "HState" mhst
    case hstHandshakeDigest hst of
      HandshakeMessages      _       -> error "un-initialized handshake digest"
      HandshakeDigestContext running -> return (hashFinal running)

-- | Overwrite the context's list of pending actions with the given list.
setPendingActions :: Context -> [PendingAction] -> IO ()
setPendingActions ctx actions = writeIORef (ctxPendingActions ctx) actions

-- | Remove and return the first pending action of the context, or return
-- 'Nothing' (leaving the list untouched) when there is none.
--
-- The list is read and written back in two separate 'IORef' operations.
popPendingAction :: Context -> IO (Maybe PendingAction)
popPendingAction ctx = readIORef queue >>= dequeue
  where
    queue = ctxPendingActions ctx
    dequeue []            = return Nothing
    dequeue (next : rest) = do
        writeIORef queue rest
        return (Just next)