File: State.hs

package info (click to toggle)
haskell-tls 2.1.8-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,056 kB
  • sloc: haskell: 15,695; makefile: 3
file content (176 lines) | stat: -rw-r--r-- 5,060 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
{-# LANGUAGE MultiParamTypeClasses #-}

module Network.TLS.Record.State (
    CryptState (..),
    CryptLevel (..),
    HasCryptLevel (..),
    MacState (..),
    RecordOptions (..),
    RecordState (..),
    newRecordState,
    incrRecordState,
    RecordM,
    runRecordM,
    getRecordOptions,
    getRecordVersion,
    setRecordIV,
    withCompression,
    computeDigest,
    makeDigest,
    getBulk,
    getMacSequence,
) where

import Control.Monad.State.Strict
import qualified Data.ByteString as B

import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.ErrT
import Network.TLS.Imports
import Network.TLS.MAC
import Network.TLS.Packet
import Network.TLS.Struct
import Network.TLS.Types
import Network.TLS.Wire

-- | Cryptographic material kept in the 'stCryptState' field of a
-- 'RecordState'.
data CryptState = CryptState
    { cstKey :: BulkState -- ^ 'BulkStateUninitialized' in 'newRecordState'
    , cstIV :: ByteString -- ^ current IV; replaced by 'setRecordIV'
    , -- In TLS 1.2 or earlier, this holds the MAC secret
      -- ('computeDigest' passes it to 'hmac').
      -- In TLS 1.3, this holds application traffic secret N.
      cstMacSecret :: ByteString
    }
    deriving (Show)

-- | Sequence counter kept in the 'stMacState' field of a 'RecordState'.
-- It is 0 in 'newRecordState' and is advanced by 'incrRecordState'.
newtype MacState = MacState
    { msSequence :: Word64 -- ^ encoded at the front of the MAC input by 'computeDigest'
    }
    deriving (Show)

-- | Options handed to 'runRecordM'.  A 'RecordM' computation can read them
-- with 'getRecordOptions'; sequencing passes them along unchanged.
data RecordOptions = RecordOptions
    { recordVersion :: Version -- ^ version to use when sending/receiving
    , recordTLS13 :: Bool -- ^ TLS13 record processing
    }

-- | TLS encryption level, as stored in the 'stCryptLevel' field of a
-- 'RecordState'.  'newRecordState' starts at 'CryptInitial'.
data CryptLevel
    = -- | Unprotected traffic
      CryptInitial
    | -- | Protected with main secret (TLS < 1.3)
      CryptMainSecret
    | -- | Protected with early traffic secret (TLS 1.3)
      CryptEarlySecret
    | -- | Protected with handshake traffic secret (TLS 1.3)
      CryptHandshakeSecret
    | -- | Protected with application traffic secret (TLS 1.3)
      CryptApplicationSecret
    deriving (Eq, Show)

-- | Types that stand for one particular 'CryptLevel'.  The argument of
-- 'getCryptLevel' is only a proxy that carries the type; each instance
-- below ignores its value.
class HasCryptLevel a where
    getCryptLevel :: proxy a -> CryptLevel

instance HasCryptLevel EarlySecret where
    getCryptLevel = const CryptEarlySecret

instance HasCryptLevel HandshakeSecret where
    getCryptLevel = const CryptHandshakeSecret

instance HasCryptLevel ApplicationSecret where
    getCryptLevel = const CryptApplicationSecret

-- | State threaded through a 'RecordM' computation.  See 'newRecordState'
-- for the initial value.
data RecordState = RecordState
    { stCipher :: Maybe Cipher -- ^ 'Nothing' in 'newRecordState'
    , stCompression :: Compression -- ^ updated through 'withCompression'
    , stCryptLevel :: CryptLevel -- ^ current encryption level
    , stCryptState :: CryptState -- ^ key state, IV and MAC secret
    , stMacState :: MacState -- ^ sequence counter
    }
    deriving (Show)

-- | A computation that reads 'RecordOptions', threads a 'RecordState' and
-- may fail with a 'TLSError'.
--
-- 'runRecordM' runs it with the given options and starting state and
-- yields either the error or the result paired with the final state.
newtype RecordM a = RecordM
    { runRecordM
        :: RecordOptions
        -> RecordState
        -> Either TLSError (a, RecordState)
    }

-- | 'pure' leaves the state untouched.  '<*>' runs the function computation
-- first, feeds its final state to the argument computation, and stops at
-- the first 'Left'.
instance Applicative RecordM where
    pure x = RecordM (\_ s -> Right (x, s))
    mf <*> mx = RecordM $ \opts s0 ->
        case runRecordM mf opts s0 of
            Left e -> Left e
            Right (f, s1) ->
                case runRecordM mx opts s1 of
                    Left e -> Left e
                    Right (x, s2) -> Right (f x, s2)

-- | Sequencing hands the same options to both steps and threads the state
-- from the first into the second.  A 'Left' from the first step is
-- returned as-is and the second step is not run.
instance Monad RecordM where
    m >>= k = RecordM $ \opts s0 -> do
        -- This do block runs in the 'Either' monad, which supplies the
        -- short-circuit on 'Left'.
        (a, s1) <- runRecordM m opts s0
        runRecordM (k a) opts s1

-- | Map over the result of a computation; the final state and any error
-- are passed through unchanged.
instance Functor RecordM where
    fmap f m = RecordM $ \opts s0 ->
        either Left (\(a, s1) -> Right (f a, s1)) (runRecordM m opts s0)

-- | Return the 'RecordOptions' the computation is being run with, leaving
-- the state as it is.
getRecordOptions :: RecordM RecordOptions
getRecordOptions = RecordM (\opts st -> pure (opts, st))

-- | Return the 'recordVersion' field of the current 'RecordOptions'.
getRecordVersion :: RecordM Version
getRecordVersion = do
    opts <- getRecordOptions
    pure (recordVersion opts)

-- | State access for 'RecordM'.  None of these operations looks at the
-- options or fails.
instance MonadState RecordState RecordM where
    get = RecordM (\_ s -> Right (s, s))
    put s' = RecordM (\_ _ -> Right ((), s'))
    state f = RecordM (const (Right . f))

-- | Error handling for 'RecordM'.
--
-- 'catchError' runs the handler from the state the failed action started
-- with: a 'Left' carries no state, so whatever the action changed before
-- failing is not seen by the handler.  A successful action is returned
-- untouched.
instance MonadError TLSError RecordM where
    throwError err = RecordM (\_ _ -> Left err)
    catchError action handler = RecordM $ \opts s0 ->
        either
            (\err -> runRecordM (handler err) opts s0)
            Right
            (runRecordM action opts s0)

-- | Initial record state: no cipher, 'nullCompression', level
-- 'CryptInitial', an uninitialized bulk state with empty IV and MAC
-- secret, and a sequence counter of 0.
newRecordState :: RecordState
newRecordState =
    RecordState
        { stCipher = Nothing
        , stCompression = nullCompression
        , stCryptLevel = CryptInitial
        , stCryptState = emptyCryptState
        , stMacState = MacState{msSequence = 0}
        }
  where
    emptyCryptState =
        CryptState
            { cstKey = BulkStateUninitialized
            , cstIV = B.empty
            , cstMacSecret = B.empty
            }

-- | Advance the sequence counter in 'stMacState' by one, leaving every
-- other field as it was.  The counter is a 'Word64', so the addition wraps
-- around at 'maxBound'.
incrRecordState :: RecordState -> RecordState
incrRecordState st = st{stMacState = MacState nextSeq}
  where
    nextSeq = msSequence (stMacState st) + 1

-- | Replace the IV ('cstIV') inside the state's 'CryptState'; nothing else
-- is modified.
setRecordIV :: ByteString -> RecordState -> RecordState
setRecordIV iv st = st{stCryptState = withNewIV}
  where
    withNewIV = (stCryptState st){cstIV = iv}

-- | Apply a function to the current 'Compression' state.  The function
-- returns the replacement compression state together with a result; the
-- replacement is stored in 'stCompression' and the result is returned.
withCompression :: (Compression -> (Compression, a)) -> RecordM a
withCompression f = state step
  where
    step st =
        let (comp', result) = f (stCompression st)
         in (result, st{stCompression = comp'})

-- | Compute the record MAC and return it together with the state whose
-- sequence counter has been advanced by 'incrRecordState'.
--
-- The MAC is 'hmac', using the hash of the state's cipher and
-- 'cstMacSecret', applied to the concatenation of the encoded sequence
-- number, the encoded header and the content.  The 'Version' argument is
-- not used.
--
-- Evaluating the digest fails ('fromJust') when 'stCipher' is 'Nothing'.
computeDigest
    :: Version -> RecordState -> Header -> ByteString -> (ByteString, RecordState)
computeDigest _ver tstate hdr content = (mac, incrRecordState tstate)
  where
    mac = hmac hashAlg secret macInput
    hashAlg = cipherHash . fromJust $ stCipher tstate
    secret = cstMacSecret (stCryptState tstate)
    seqBytes = encodeWord64 (msSequence (stMacState tstate))
    macInput = B.concat [seqBytes, encodeHeader hdr, content]

-- | 'computeDigest' lifted into 'RecordM': the digest is computed from the
-- current state and record version, and the state returned by
-- 'computeDigest' becomes the new state.
makeDigest :: Header -> ByteString -> RecordM ByteString
makeDigest hdr content = do
    ver <- getRecordVersion
    state (\st -> computeDigest ver st hdr content)

-- | Return the 'Bulk' of the cipher currently held in 'stCipher'.
--
-- When the state holds no cipher, the computation fails inside 'RecordM'
-- with an 'Error_Misc'.  Previously it handed back a value that crashed
-- with a bare 'fromJust' exception once it was evaluated.
getBulk :: RecordM Bulk
getBulk = do
    st <- get
    case stCipher st of
        Just cipher -> return (cipherBulk cipher)
        -- Report the missing cipher through the monad's error channel so
        -- callers see a 'TLSError' rather than an exception from pure code.
        Nothing -> throwError (Error_Misc "getBulk: record state has no cipher")

-- | Return the current sequence counter ('msSequence' of 'stMacState')
-- without changing the state.
getMacSequence :: RecordM Word64
getMacSequence = do
    st <- get
    pure (msSequence (stMacState st))