File: Token.hs

package info (click to toggle)
haskell-crypto-token 0.1.2-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 68 kB
  • sloc: haskell: 246; makefile: 3
file content (291 lines) | stat: -rw-r--r-- 8,311 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Encrypted tokens/tickets to keep state in the client side.
module Crypto.Token (
    -- * Configuration
    Config,
    defaultConfig,
    interval,
    tokenLifetime,
    threadName,

    -- * Token manager
    TokenManager,
    spawnTokenManager,
    killTokenManager,

    -- * Encryption and decryption
    encryptToken,
    decryptToken,
) where

import Control.Concurrent
import Crypto.Cipher.AES (AES256)
import Crypto.Cipher.Types (AEADMode (..), AuthTag (..))
import qualified Crypto.Cipher.Types as C
import Crypto.Error (maybeCryptoError, throwCryptoError)
import Crypto.Random (getRandomBytes)
import Data.Array.IO
import Data.Bits (xor)
import qualified Data.ByteArray as BA
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BS
import qualified Data.IORef as I
import Data.Word
import Foreign.Ptr
import Foreign.Storable
import GHC.Conc.Sync (labelThread)
import Network.ByteOrder

----------------------------------------------------------------

type Index = Word16
type Counter = Word64

-- | Configuration for a token manager. Both durations are in seconds.
data Config = Config
    { -- | Seconds between secret rotations; each rotation generates a new
      -- secret that replaces the oldest one.
      interval :: Int
    , -- | Seconds during which a token can still be decrypted.
      tokenLifetime :: Int
    , -- | Label given to the background secret-rotation thread.
      threadName :: String
    }
    deriving (Show, Eq)

-- | Default configuration: a new secret every 30 minutes (1,800 seconds)
-- and a token lifetime of 2 hours (7,200 seconds).
--
-- >>> defaultConfig
-- Config {interval = 1800, tokenLifetime = 7200, threadName = "Crypto token manager"}
defaultConfig :: Config
defaultConfig =
    Config
        { interval = 1800
        , tokenLifetime = 7200
        , threadName = "Crypto token manager"
        }

----------------------------------------------------------------

-- fixme: mask

-- | Opaque handle to a running token manager, created by
-- 'spawnTokenManager' and stopped by 'killTokenManager'.
data TokenManager = TokenManager
    { -- | Random header XORed over every token header.
      headerMask :: Header
    , -- | Fetch the current secret together with its slot index.
      getEncryptSecret :: IO (Secret, Index)
    , -- | Fetch the secret stored for a slot index.
      getDecryptSecret :: Index -> IO Secret
    , -- | The background thread that rotates secrets.
      threadId :: ThreadId
    }

-- | Spawning a token manager.
--
-- The manager keeps a ring of @tokenLifetime `div` interval@ secret slots
-- (clamped to between 1 and 65,535) and fills slot 0 with a fresh secret.
-- It then forks a thread, labelled with 'threadName'. Every 'interval'
-- seconds that thread writes a fresh secret into the next slot and makes it
-- current. A token can therefore only be decrypted until its slot is
-- overwritten.
spawnTokenManager :: Config -> IO TokenManager
spawnTokenManager Config{..} = do
    emp <- emptySecret
    arr <- newArray (0, lastSlot) emp
    ent <- generateSecret
    writeArray arr 0 ent
    ref <- I.newIORef 0
    tid <- forkIO $ loop arr ref
    labelThread tid threadName
    msk <- newHeaderMask
    return $ TokenManager msk (readCurrentSecret arr ref) (readSecret arr) tid
  where
    -- Number of slots, computed in 'Int' and clamped before narrowing to
    -- 'Index'. Narrowing first is wrong: a lifetime shorter than the
    -- interval gives 0 slots, so @0 - 1@ wraps to 65,535 (a 65,536-slot
    -- ring). The slot count @n + 1@ then wraps to 0 and every @mod@ below
    -- throws a divide-by-zero. Capping at 65,535 keeps @n + 1@
    -- representable as an 'Index'. A zero 'interval' still throws from
    -- @div@, as before.
    slots :: Int
    slots = max 1 (min maxSlots (tokenLifetime `div` interval))
    maxSlots :: Int
    maxSlots = fromIntegral (maxBound :: Index)
    lastSlot :: Index
    lastSlot = fromIntegral (slots - 1)
    -- Rotate secrets forever; the thread is stopped by 'killTokenManager'.
    loop arr ref = do
        threadDelay (interval * 1000000)
        update arr ref
        loop arr ref
    -- Overwrite the slot after the current one (the oldest in the ring) with
    -- a fresh secret, then publish that slot as current.
    update :: IOArray Index Secret -> I.IORef Index -> IO ()
    update arr ref = do
        idx0 <- I.readIORef ref
        (_, n) <- getBounds arr
        let idx = (idx0 + 1) `mod` (n + 1)
        sec <- generateSecret
        writeArray arr idx sec
        I.writeIORef ref idx

-- | Stop a token manager by killing its secret-rotation thread.
killTokenManager :: TokenManager -> IO ()
killTokenManager = killThread . threadId

----------------------------------------------------------------

-- | Look up the secret stored for a slot index.
--
-- The index may come from a token header, so it is reduced modulo the
-- number of slots instead of being used directly. The arithmetic is done
-- in 'Int' because the slot count @n + 1@ does not fit in an 'Index' when
-- the ring has 65,536 slots. It would wrap to 0, and @mod@ would then throw
-- a divide-by-zero. Assumes the array's lower bound is 0, which is how
-- 'spawnTokenManager' creates it.
readSecret :: IOArray Index Secret -> Index -> IO Secret
readSecret secrets idx0 = do
    (_, n) <- getBounds secrets
    let size = fromIntegral n + 1 :: Int
        idx = fromIntegral (fromIntegral idx0 `mod` size)
    readArray secrets idx

-- | Read the current slot index and return the secret stored there,
-- paired with that index.
readCurrentSecret :: IOArray Index Secret -> I.IORef Index -> IO (Secret, Index)
readCurrentSecret arr ref =
    I.readIORef ref >>= \cur ->
        fmap (\s -> (s, cur)) (readSecret arr cur)

----------------------------------------------------------------

-- | One encryption secret stored in a token manager slot.
data Secret = Secret
    { -- | Random value XORed with the counter to build each nonce.
      secretIV :: ByteString
    , -- | AES-256 key.
      secretKey :: ByteString
    , -- | Next counter value to hand out; bumped on every encryption.
      secretCounter :: I.IORef Counter
    }

-- | Placeholder secret with an empty IV and key and a counter at 0. Used to
-- fill slots that have not been generated yet.
emptySecret :: IO Secret
emptySecret = do
    counterRef <- I.newIORef 0
    return (Secret BS.empty BS.empty counterRef)

-- | Create a fresh secret: a random IV, a random key, and a counter
-- starting at 0.
generateSecret :: IO Secret
generateSecret = do
    iv <- genIV
    key <- genKey
    counterRef <- I.newIORef 0
    return Secret{secretIV = iv, secretKey = key, secretCounter = counterRef}

-- | Draw a fresh random AES-256 key of 'keyLength' bytes.
genKey :: IO ByteString
genKey = getRandomBytes keyLength

-- | Draw a fresh random IV of 'ivLength' bytes. 'makeNonce' XORs it with
-- the per-encryption counter to form the GCM nonce.
genIV :: IO ByteString
genIV = getRandomBytes ivLength

----------------------------------------------------------------

-- | IV size in bytes. It must equal the size of a 'Counter' (64 bits)
-- because 'makeNonce' writes the counter into a buffer of this many bytes.
ivLength :: Int
ivLength = 64 `div` 8

-- | AES-256 key size in bytes.
keyLength :: Int
keyLength = 256 `div` 8

-- | Encoded size in bytes of the 16-bit slot 'Index' in a token header.
indexLength :: Int
indexLength = 16 `div` 8

-- | Encoded size in bytes of the 64-bit 'Counter' in a token header.
counterLength :: Int
counterLength = 64 `div` 8

-- | GCM authentication tag size in bytes (128 bits).
tagLength :: Int
tagLength = 128 `div` 8

----------------------------------------------------------------

-- | Fields carried at the front of every token. The same shape is reused
-- for the random mask that is XORed over them.
data Header = Header
    { -- | Slot index of the secret used for encryption.
      headerIndex :: Index
    , -- | Counter value that was combined with the IV to form the nonce.
      headerCounter :: Counter
    }

-- | Serialise a header as the 16-bit index followed by the 64-bit counter,
-- 'indexLength' + 'counterLength' bytes in total.
encodeHeader :: Header -> IO ByteString
encodeHeader hdr =
    withWriteBuffer (indexLength + counterLength) $ \buf ->
        write16 buf (headerIndex hdr) >> write64 buf (headerCounter hdr)

-- | Parse a header serialised by 'encodeHeader': a 16-bit index, then a
-- 64-bit counter. Callers in this module always pass exactly
-- 'indexLength' + 'counterLength' bytes.
-- NOTE(review): shorter input presumably makes the reads throw; confirm
-- against Network.ByteOrder.
decodeHeader :: ByteString -> IO Header
decodeHeader bin = withReadBuffer bin $ \rbuf -> do
    idx <- read16 rbuf
    ctr <- read64 rbuf
    return (Header idx ctr)

-- | Draw a random header. It is XORed over every real header so the raw
-- slot index and counter do not appear in tokens unmodified.
newHeaderMask :: IO Header
newHeaderMask = decodeHeader =<< getRandomBytes (indexLength + counterLength)

----------------------------------------------------------------

-- | Field-wise XOR of two headers. XOR is its own inverse, so the same
-- function applies the mask ('addHeader') and removes it ('delHeader').
xorHeader :: Header -> Header -> Header
xorHeader x y = Header maskedIndex maskedCounter
  where
    maskedIndex = headerIndex x `xor` headerIndex y
    maskedCounter = headerCounter x `xor` headerCounter y

-- | Mask the header (slot index and counter) with the manager's header
-- mask, serialise it, and prepend it to the ciphertext.
addHeader :: TokenManager -> Index -> Counter -> ByteString -> IO ByteString
addHeader mgr idx counter cipher =
    (`BS.append` cipher) <$> encodeHeader masked
  where
    masked = headerMask mgr `xorHeader` Header idx counter

-- | Split a token into its unmasked slot index, counter, and the remaining
-- ciphertext. Returns 'Nothing' when the token is shorter than a header.
delHeader
    :: TokenManager -> ByteString -> IO (Maybe (Index, Counter, ByteString))
delHeader mgr token
    | BS.length token >= hdrLen = do
        hdr <- xorHeader (headerMask mgr) <$> decodeHeader hdrBytes
        return $ Just (headerIndex hdr, headerCounter hdr, body)
    | otherwise = return Nothing
  where
    hdrLen = indexLength + counterLength
    (hdrBytes, body) = BS.splitAt hdrLen token

-- | Encrypt a value into a token.
--
-- The token is the masked header (slot index and counter) followed by the
-- AES-256-GCM ciphertext and its authentication tag.
encryptToken
    :: TokenManager
    -> ByteString
    -> IO ByteString
encryptToken mgr x =
    getEncryptSecret mgr >>= \(secret, idx) ->
        encrypt secret x >>= \(counter, cipher) ->
            addHeader mgr idx counter cipher

-- | Encrypt with a secret, atomically taking the next counter value for the
-- nonce. Returns that counter together with @ciphertext ++ tag@.
encrypt
    :: Secret
    -> ByteString
    -> IO (Counter, ByteString)
encrypt secret plain = do
    n <- I.atomicModifyIORef' (secretCounter secret) bump
    cipher <- aes256gcmEncrypt plain (secretKey secret) <$> makeNonce n (secretIV secret)
    return (n, cipher)
  where
    -- Hand out the current value and store its successor.
    bump cur = (cur + 1, cur)

-- | Decrypt a token back into its value.
--
-- Returns 'Nothing' when the token is shorter than a header or when
-- AES-256-GCM decryption with the referenced slot's secret fails.
decryptToken
    :: TokenManager
    -> ByteString
    -> IO (Maybe ByteString)
decryptToken mgr token = delHeader mgr token >>= maybe (return Nothing) openToken
  where
    openToken (idx, counter, cipher) = do
        secret <- getDecryptSecret mgr idx
        decrypt secret counter cipher

-- | Rebuild the nonce from the secret's IV and the token's counter, then
-- authenticate and decrypt @ciphertext ++ tag@.
decrypt
    :: Secret
    -> Counter
    -> ByteString
    -> IO (Maybe ByteString)
decrypt secret counter cipher =
    aes256gcmDecrypt cipher (secretKey secret) <$> makeNonce counter (secretIV secret)

-- | Build a nonce by XORing the IV with the counter's in-memory bytes.
--
-- The counter is poked into a buffer of 'ivLength' bytes, so 'ivLength'
-- must be at least the size of a 'Counter' (8 bytes); otherwise the poke
-- writes past the buffer.
makeNonce :: Counter -> ByteString -> IO ByteString
makeNonce counter iv = BA.xor iv <$> BS.create ivLength writeCounter
  where
    writeCounter p = poke (castPtr p) counter

----------------------------------------------------------------

-- | Additional authenticated data for GCM: always empty.
constantAdditionalData :: ByteString
constantAdditionalData = mempty

-- | AES-256-GCM encrypt @plain@ under @key@ and @nonce@. Returns the
-- ciphertext with the 'tagLength'-byte authentication tag appended. Throws
-- if the key or nonce is rejected.
aes256gcmEncrypt
    :: ByteString
    -> ByteString
    -> ByteString
    -> ByteString
aes256gcmEncrypt plain key nonce = BS.append ctext (BA.convert tagBytes)
  where
    aes :: AES256
    aes = throwCryptoError $ C.cipherInit key
    gcm = throwCryptoError $ C.aeadInit AEAD_GCM aes nonce
    (AuthTag tagBytes, ctext) =
        C.aeadSimpleEncrypt gcm constantAdditionalData plain tagLength

-- | Authenticate and decrypt @ciphertext ++ tag@ with AES-256-GCM.
--
-- Returns 'Nothing' in three cases: the key or nonce is rejected, the input
-- is too short to hold a full 'tagLength'-byte tag, or the tag does not
-- verify.
aes256gcmDecrypt
    :: ByteString
    -> ByteString
    -> ByteString
    -> Maybe ByteString
aes256gcmDecrypt ctexttag key nonce
    -- Every output of 'aes256gcmEncrypt' ends with a full tag. Without this
    -- check, shorter input is split into a truncated (possibly empty) tag.
    -- cryptonite's 'C.aeadSimpleDecrypt' computes the expected tag at the
    -- length of the tag it is given, so a header-only token would
    -- authenticate as an empty plaintext, and short tags could be
    -- brute-forced.
    | BS.length ctexttag < tagLength = Nothing
    | otherwise = do
        aes <- maybeCryptoError $ C.cipherInit key :: Maybe AES256
        aead <- maybeCryptoError $ C.aeadInit AEAD_GCM aes nonce
        let (ctext, tag) = BS.splitAt (BS.length ctexttag - tagLength) ctexttag
            authtag = AuthTag $ BA.convert tag
        C.aeadSimpleDecrypt aead constantAdditionalData ctext authtag