1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185
|
-- |
-- Module : Data.SecureMem
-- License : BSD-style
-- Maintainer : Vincent Hanquez <vincent@snarc.org>
-- Stability : Stable
-- Portability : GHC
--
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE CPP #-}
module Data.SecureMem
( SecureMem
, secureMemGetSize
, secureMemCopy
, ToSecureMem(..)
-- * Allocation and early termination
, allocateSecureMem
, createSecureMem
, unsafeCreateSecureMem
, finalizeSecureMem
-- * Pointers manipulation
, withSecureMemPtr
, withSecureMemPtrSz
, withSecureMemCopy
-- * Conversion
, secureMemFromByteString
, secureMemFromByteable
) where
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr
import Data.Word (Word8)
#if MIN_VERSION_base(4,9,0)
import Data.Semigroup
import Data.Foldable (toList)
#else
import Data.Monoid
#endif
import Control.Applicative
import Data.Byteable
import Data.ByteString (ByteString)
import Data.ByteArray (ScrubbedBytes)
import qualified Data.ByteArray as B
import qualified Data.Memory.PtrMethods as B (memSet)
import qualified Data.ByteString.Internal as BS
#if MIN_VERSION_base(4,4,0)
import System.IO.Unsafe (unsafeDupablePerformIO)
#else
import System.IO.Unsafe (unsafePerformIO)
#endif
-- | Run an 'IO' action as if it were pure.
--
-- On base >= 4.4 this is 'unsafeDupablePerformIO', which skips the
-- duplicate-evaluation protection of 'unsafePerformIO'; the wrapped action
-- may therefore run more than once if several threads force the same thunk,
-- so it should only be used for actions that are safe to repeat (such as
-- allocating and filling a fresh buffer).
pureIO :: IO a -> a
pureIO =
#if MIN_VERSION_base(4,4,0)
    unsafeDupablePerformIO
#else
    unsafePerformIO
#endif
-- | SecureMem is a memory chunk which has the properties of:
--
-- * Being scrubbed after it goes out of scope.
--
-- * A Show instance that doesn't actually show any content
--
-- * An Eq instance that is constant time
--
-- It is a thin newtype wrapper over 'ScrubbedBytes'; this module adds no
-- scrubbing or comparison code of its own and relies on the wrapped
-- 'ScrubbedBytes' value for both.
newtype SecureMem = SecureMem ScrubbedBytes
-- | Number of bytes held by the secure memory chunk.
secureMemGetSize :: SecureMem -> Int
secureMemGetSize sm = case sm of
    SecureMem bytes -> B.length bytes
-- | Equality of two chunks, delegated to the 'Eq' instance of the
-- underlying 'ScrubbedBytes'.
secureMemEq :: SecureMem -> SecureMem -> Bool
secureMemEq (SecureMem lhs) (SecureMem rhs) = (==) lhs rhs
-- | Concatenate two chunks into a freshly allocated one; the inputs are
-- left untouched.
secureMemAppend :: SecureMem -> SecureMem -> SecureMem
secureMemAppend (SecureMem front) (SecureMem back) =
    SecureMem $ mappend front back
-- | Concatenate a list of chunks in one pass using the 'mconcat' of the
-- underlying 'ScrubbedBytes'.
secureMemConcat :: [SecureMem] -> SecureMem
secureMemConcat chunks = SecureMem (mconcat [ bytes | SecureMem bytes <- chunks ])
-- | Make an independent copy of a secure memory chunk.
secureMemCopy :: SecureMem -> IO SecureMem
secureMemCopy (SecureMem original) =
    SecureMem <$> B.copy original (const (return ()))
-- | Copy a secure memory chunk, then run the given action on a pointer to
-- the copy (not the original) before returning it.
withSecureMemCopy :: SecureMem -> (Ptr Word8 -> IO ()) -> IO SecureMem
withSecureMemCopy (SecureMem original) mutate = do
    copied <- B.copy original mutate
    return (SecureMem copied)
-- | Never reveal the contents: every value renders as the same placeholder.
instance Show SecureMem where
    showsPrec _ _ = showString "<secure-mem>"
-- | 'Byteable' view of a chunk, backed by this module's size, pointer and
-- ByteString-export helpers.
instance Byteable SecureMem where
    byteableLength sm = secureMemGetSize sm
    withBytePtr sm f  = withSecureMemPtr sm f
    toBytes sm        = secureMemToByteString sm
-- | Equality via 'secureMemEq'.
instance Eq SecureMem where
    a == b = secureMemEq a b
#if MIN_VERSION_base(4,9,0)
-- | Appending allocates a new chunk; 'sconcat' goes through the single-pass
-- list concatenation rather than repeated pairwise appends.
instance Semigroup SecureMem where
    sconcat chunks = secureMemConcat (toList chunks)
    a <> b         = secureMemAppend a b
#endif
-- | Monoid over secure memory chunks; the identity is the empty chunk.
instance Monoid SecureMem where
    -- The empty chunk: reuse the identity of the underlying ScrubbedBytes
    -- instead of going through unsafe IO.
    mempty = SecureMem mempty
#if !(MIN_VERSION_base(4,11,0))
    -- Since base 4.11 'mappend' defaults to the Semigroup '(<>)', so it only
    -- needs spelling out on older bases.
    mappend = secureMemAppend
#endif
    -- Defined unconditionally: the class default is a right fold of
    -- 'mappend', which re-copies the accumulated suffix at every step
    -- (quadratic in the number of chunks) and leaves one intermediate
    -- secret-bearing buffer per step behind for the GC to scrub.
    mconcat = secureMemConcat
-- | Types that can be converted to a secure mem object.
class ToSecureMem a where
    -- | Produce a 'SecureMem' holding the bytes of the argument.
    toSecureMem :: a -> SecureMem
-- | A 'SecureMem' is already secure memory; it is returned unchanged.
instance ToSecureMem SecureMem where
    toSecureMem = id
-- | Copies the bytes of a 'ByteString' into newly allocated secure memory.
instance ToSecureMem ByteString where
    toSecureMem = secureMemFromByteString
-- | Allocate a new SecureMem of the given size in bytes.
--
-- The memory is allocated on the haskell heap, and will be scrubbed
-- before being released.
--
-- This function writes nothing into the buffer, so callers should not rely
-- on its initial contents; use 'createSecureMem' to initialise it.
allocateSecureMem :: Int -> IO SecureMem
allocateSecureMem sz = createSecureMem sz (const (return ()))
-- | Allocate a new secure mem of @sz@ bytes and run an initializer on a
-- pointer to it before returning it.
createSecureMem :: Int -> (Ptr Word8 -> IO ()) -> IO SecureMem
createSecureMem sz initializer = do
    bytes <- B.create sz initializer
    return (SecureMem bytes)
-- | Pure variant of 'createSecureMem', running the allocation and the
-- initializer through unsafe IO. The initializer must be safe to run more
-- than once, since the underlying unsafe IO may be duplicated.
unsafeCreateSecureMem :: Int -> (Ptr Word8 -> IO ()) -> SecureMem
unsafeCreateSecureMem size initializer =
    pureIO $ createSecureMem size initializer
-- Kept out of line so GHC cannot inline the unsafe IO into callers.
{-# NOINLINE unsafeCreateSecureMem #-}
-- | Give an action temporary access to the raw pointer of a secure memory
-- chunk and return the action's result.
--
-- This is similar to 'withForeignPtr' for a ForeignPtr: the pointer should
-- not be used after the action returns.
withSecureMemPtr :: SecureMem -> (Ptr Word8 -> IO b) -> IO b
withSecureMemPtr (SecureMem bytes) action = B.withByteArray bytes action
-- | Like 'withSecureMemPtr', but the action also receives the size in
-- bytes of the pointed-to memory (as its first argument).
withSecureMemPtrSz :: SecureMem -> (Int -> Ptr Word8 -> IO b) -> IO b
withSecureMemPtrSz (SecureMem bytes) action =
    B.withByteArray bytes $ \ptr -> action (B.length bytes) ptr
-- | Finalize a SecureMem early by overwriting every byte with zero.
--
-- Nothing here invalidates the value: it stays usable and simply reads
-- back as zeros afterwards.
finalizeSecureMem :: SecureMem -> IO ()
finalizeSecureMem sm =
    withSecureMemPtrSz sm $ \len ptr -> B.memSet ptr 0 len
-- | Copy the contents of a Secure Mem into a new 'ByteString'.
--
-- The result is an ordinary ByteString: nothing in this function arranges
-- for the copy to be scrubbed, so the secret outlives the 'SecureMem'
-- guarantees once exported.
secureMemToByteString :: SecureMem -> ByteString
secureMemToByteString sm = BS.unsafeCreate size copyOut
  where
    !size = secureMemGetSize sm
    copyOut dst =
        withSecureMemPtr sm $ \src -> BS.memcpy dst src (fromIntegral size)
-- | Create a SecureMem holding a copy of a bytestring's bytes.
secureMemFromByteString :: ByteString -> SecureMem
secureMemFromByteString b = pureIO $
    createSecureMem len $ \dst ->
        withForeignPtr fp $ \base ->
            -- A ByteString may start at an offset into its buffer.
            BS.memcpy dst (base `plusPtr` off) (fromIntegral len)
  where
    (fp, off, !len) = BS.toForeignPtr b
-- Kept out of line so GHC cannot inline the unsafe IO into callers.
{-# NOINLINE secureMemFromByteString #-}
-- | Create a SecureMem holding a copy of the bytes of any 'Byteable' value.
secureMemFromByteable :: Byteable b => b -> SecureMem
secureMemFromByteable bs = pureIO $
    createSecureMem len $ \dst ->
        withBytePtr bs $ \src -> BS.memcpy dst src (fromIntegral len)
  where
    len = byteableLength bs
-- Kept out of line so GHC cannot inline the unsafe IO into callers.
{-# NOINLINE secureMemFromByteable #-}
|