1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
|
{-# LANGUAGE CPP #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE CApiFFI #-}
-- |An interface to OpenSSL's Base64 codec.
module OpenSSL.EVP.Base64
( -- * Encoding
encodeBase64
, encodeBase64BS
, encodeBase64LBS
-- * Decoding
, decodeBase64
, decodeBase64BS
, decodeBase64LBS
)
where
import Control.Exception (assert)
import Data.ByteString.Internal (createAndTrim)
import Data.ByteString.Unsafe (unsafeUseAsCStringLen)
import qualified Data.ByteString.Lazy.Internal as L8Internal
import qualified Data.ByteString.Char8 as B8
import qualified Data.ByteString.Lazy.Char8 as L8
import Data.List
#if MIN_VERSION_base(4,5,0)
import Foreign.C.Types (CUChar(..), CInt(..))
#else
import Foreign.C.Types (CUChar, CInt)
#endif
import Foreign.Ptr (Ptr, castPtr)
import System.IO.Unsafe (unsafePerformIO)
-- Chunk-gathering helper shared by the lazy encoder and decoder.
--
-- On encoding, we keep pulling strict chunks off the lazy input until
-- we hold at least 3 bytes. The caller then concatenates the gathered
-- chunks, splits the result at the largest offset that is a multiple
-- of 3, and pushes the remaining bytes back in front of the rest of
-- the input.
--
-- On decoding, the same scheme is used with a multiple of 4.
--
-- @nextBlock minLen (xs, src)@ returns the gathered chunks (starting
-- with @xs@, in input order) together with the part of @src@ that has
-- not been consumed. Fewer than @minLen@ bytes are returned only when
-- @src@ runs out.
nextBlock :: Int -> ([B8.ByteString], L8.ByteString) -> ([B8.ByteString], L8.ByteString)
nextBlock minLen (xs, src)
    = gather (foldl' (+) 0 (map B8.length xs)) (reverse xs) src
    where
      -- @have@ is the total length of the chunks gathered so far and
      -- @revAcc@ holds those chunks in reverse order.
      gather have revAcc rest
          | have >= minLen = (reverse revAcc, rest)
          | otherwise
              = case rest of
                  L8Internal.Empty      -> (reverse revAcc, rest)
                  L8Internal.Chunk y ys -> gather (have + B8.length y) (y : revAcc) ys
{- encode -------------------------------------------------------------------- -}
-- Binding to OpenSSL's @EVP_EncodeBlock@. As used below, the arguments
-- are the output buffer, the input buffer and the input length.
foreign import capi unsafe "openssl/evp.h EVP_EncodeBlock"
        _EncodeBlock :: Ptr CUChar -> Ptr CUChar -> CInt -> IO CInt

-- Encode one strict chunk to Base64 with a single call to
-- '_EncodeBlock'. The value returned by the C call is used as the
-- number of output bytes to keep.
--
-- 'unsafePerformIO' is used because the only visible effects are
-- reading @inBS@ and filling a buffer freshly allocated by
-- 'createAndTrim'.
encodeBlock :: B8.ByteString -> B8.ByteString
encodeBlock inBS
    = unsafePerformIO $
      unsafeUseAsCStringLen inBS $ \ (srcPtr, srcLen) ->
          createAndTrim outCapacity $ \ dstPtr ->
              do written <- _EncodeBlock (castPtr dstPtr) (castPtr srcPtr) (fromIntegral srcLen)
                 return (fromIntegral written)
    where
      -- Four output bytes per (possibly partial) group of three input
      -- bytes, plus one byte for the trailing '\0'.
      outCapacity = (B8.length inBS `div` 3 + 1) * 4 + 1
-- |@'encodeBase64' str@ lazily encodes a stream of data to
-- Base64. The string doesn't have to be finite. Note that the string
-- must not contain any characters outside the range U+0000 -
-- U+00FF. It is implemented by packing @str@ into a lazy ByteString
-- and passing it through 'encodeBase64LBS'.
{-# DEPRECATED encodeBase64 "Use encodeBase64BS or encodeBase64LBS instead." #-}
encodeBase64 :: String -> String
encodeBase64 str = L8.unpack (encodeBase64LBS (L8.pack str))
-- |@'encodeBase64BS' bs@ strictly encodes a chunk of data to Base64.
-- The whole of @bs@ is encoded in one step.
encodeBase64BS :: B8.ByteString -> B8.ByteString
encodeBase64BS bs = encodeBlock bs
-- |@'encodeBase64LBS' lbs@ lazily encodes a stream of data to
-- Base64. The string doesn't have to be finite.
--
-- Input is consumed in blocks whose length is a multiple of 3, so
-- that padding can only appear at the very end of the output; bytes
-- that do not fill a group of 3 are carried over to the next block.
encodeBase64LBS :: L8.ByteString -> L8.ByteString
encodeBase64LBS inLBS
    | L8.null inLBS = L8.empty
    | otherwise     = L8.append (L8.fromChunks [encodeBlock current])
                                (encodeBase64LBS pending)
    where
      (parts, rest) = nextBlock 3 ([], inLBS)
      gathered      = B8.concat parts
      gatheredLen   = B8.length gathered

      -- Fewer than 3 bytes means the input is exhausted: this is the
      -- last remnant and it is encoded as is. Otherwise cut at the
      -- largest multiple of 3.
      (current, carry)
          | gatheredLen < 3 = (gathered, B8.empty)
          | otherwise       = B8.splitAt (gatheredLen - gatheredLen `mod` 3) gathered

      -- Bytes cut off above go back in front of the unread input.
      pending
          | B8.null carry = rest
          | otherwise     = L8.append (L8.fromChunks [carry]) rest
{- decode -------------------------------------------------------------------- -}
-- Binding to OpenSSL's @EVP_DecodeBlock@. As used below, the arguments
-- are the output buffer, the input buffer and the input length.
foreign import capi unsafe "openssl/evp.h EVP_DecodeBlock"
        _DecodeBlock :: Ptr CUChar -> Ptr CUChar -> CInt -> IO CInt

-- Decode one strict chunk of Base64 with a single call to
-- '_DecodeBlock'. The input length is expected to be a multiple of 4
-- (checked by an assertion, which may be compiled out).
--
-- The length reported by the C call is reduced by the number of '='
-- characters in the input to obtain the number of bytes to keep.
--
-- Raises an 'IOError' (from pure code, via 'unsafePerformIO') when the
-- C call reports a negative length or when the computed output length
-- would be negative. Either case would otherwise hand a negative
-- length to 'createAndTrim'.
--
-- 'unsafePerformIO' is used because the only visible effects are
-- reading @inBS@ and filling a buffer freshly allocated by
-- 'createAndTrim'.
decodeBlock :: B8.ByteString -> B8.ByteString
decodeBlock inBS
    -- Nothing to decode: skip the FFI call altogether.
    | B8.null inBS = B8.empty
    | otherwise
        = assert (B8.length inBS `mod` 4 == 0) $
          unsafePerformIO $
          unsafeUseAsCStringLen inBS $ \ (inBuf, inLen) ->
              createAndTrim (B8.length inBS) $ \ outBuf ->
                  do outLen <- _DecodeBlock (castPtr outBuf) (castPtr inBuf) (fromIntegral inLen)
                     let keepLen = fromIntegral outLen - paddingLen
                     if outLen < 0 || keepLen < 0
                         then ioError (userError "OpenSSL.EVP.Base64.decodeBlock: malformed Base64 input")
                         else return keepLen
    where
      paddingLen :: Int
      paddingLen = B8.count '=' inBS
-- |@'decodeBase64' str@ lazily decodes a stream of data from
-- Base64. The string doesn't have to be finite. It is implemented by
-- packing @str@ into a lazy ByteString and passing it through
-- 'decodeBase64LBS'.
{-# DEPRECATED decodeBase64 "Use decodeBase64BS or decodeBase64LBS instead." #-}
decodeBase64 :: String -> String
decodeBase64 str = L8.unpack (decodeBase64LBS (L8.pack str))
-- |@'decodeBase64BS' bs@ strictly decodes a chunk of data from
-- Base64. The whole of @bs@ is decoded in one step.
decodeBase64BS :: B8.ByteString -> B8.ByteString
decodeBase64BS bs = decodeBlock bs
-- |@'decodeBase64LBS' lbs@ lazily decodes a stream of data from
-- Base64. The string doesn't have to be finite.
--
-- Input is consumed in blocks whose length is a multiple of 4; bytes
-- that do not fill a group of 4 are carried over to the next block.
-- If the input ends with fewer than 4 bytes left over, its total
-- length is not a multiple of 4 and an error is raised when that part
-- of the result is demanded.
decodeBase64LBS :: L8.ByteString -> L8.ByteString
decodeBase64LBS inLBS
    | L8.null inLBS = L8.empty
    -- 'nextBlock' returns fewer than 4 bytes only when the input is
    -- exhausted. Splitting such a remnant at offset 0 would feed the
    -- same bytes back into the recursion forever, so fail explicitly.
    | blockLen' < 4
        = error "OpenSSL.EVP.Base64.decodeBase64LBS: input length is not a multiple of 4"
    | otherwise
        = L8.fromChunks [decodeBlock block] `L8.append` decodeBase64LBS remain
    where
      (blockParts', remain') = nextBlock 4 ([], inLBS)
      block'                 = B8.concat blockParts'
      blockLen'              = B8.length block'

      -- Cut at the largest multiple of 4.
      (block, leftover) = B8.splitAt (blockLen' - blockLen' `mod` 4) block'

      -- Bytes cut off above go back in front of the unread input.
      remain
          | B8.null leftover = remain'
          | otherwise        = L8.fromChunks [leftover] `L8.append` remain'
|