File: Sized.hs

package info (click to toggle)
haskell-memory 0.18.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 324 kB
  • sloc: haskell: 3,362; makefile: 7
file content (398 lines) | stat: -rw-r--r-- 12,920 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
-- |
-- Module      : Data.ByteArray.Sized
-- License     : BSD-style
-- Maintainer  : Nicolas Di Prima <nicolas@primetype.co.uk>
-- Stability   : stable
-- Portability : Good
--

{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE UndecidableInstances #-}
#if __GLASGOW_HASKELL__ >= 806
{-# LANGUAGE NoStarIsType #-}
#endif

module Data.ByteArray.Sized
    ( ByteArrayN(..)
    , SizedByteArray
    , unSizedByteArray
    , sizedByteArray
    , unsafeSizedByteArray

    , -- * ByteArrayN operators
      alloc
    , create
    , allocAndFreeze
    , unsafeCreate
    , inlineUnsafeCreate
    , empty
    , pack
    , unpack
    , cons
    , snoc
    , xor
    , index
    , splitAt
    , take
    , drop
    , append
    , copy
    , copyRet
    , copyAndFreeze
    , replicate
    , zero
    , convert
    , fromByteArrayAccess
    , unsafeFromByteArrayAccess
    ) where

import Basement.Imports
import Basement.NormalForm
import Basement.Nat
import Basement.Numerical.Additive ((+))
import Basement.Numerical.Subtractive ((-))

import Basement.Sized.List (ListN, unListN, toListN)

import           Foreign.Storable
import           Foreign.Ptr
import           Data.Maybe (fromMaybe)

import           Data.Memory.Internal.Compat
import           Data.Memory.PtrMethods

import Data.Proxy (Proxy(..))

import Data.ByteArray.Types (ByteArrayAccess(..), ByteArray)
import qualified Data.ByteArray.Types as ByteArray (allocRet)

#if MIN_VERSION_basement(0,0,7)
import           Basement.BlockN (BlockN)
import qualified Basement.BlockN as BlockN
import qualified Basement.PrimType as Base
import           Basement.Types.OffsetSize (Countable)
#endif

-- | Type class to emulate exactly the behaviour of 'ByteArray' but with
-- a known length at compile time.
--
-- The functional dependency @c -> n@ means the container type @c@ alone
-- determines its type-level size @n@.
--
class (ByteArrayAccess c, KnownNat n) => ByteArrayN (n :: Nat) c | c -> n where
    -- | Just like 'Data.ByteArray.Types.allocRet' but with the size given at
    -- the type level through the 'Proxy'. It allocates a new @c@, runs the
    -- initializer on a pointer to its memory, and returns the initializer's
    -- result paired with the new container. (For the 'SizedByteArray'
    -- instance the allocation is exactly @n@ bytes.)
    allocRet :: forall p a
              . Proxy n
             -> (Ptr p -> IO a)
             -> IO (a, c)

-- | Wrapper around any collection type @ba@ that records its size @n@ as a
-- type parameter.
--
-- The data constructor is not exported from this module. Values are built
-- with 'sizedByteArray' (which checks the runtime length against @n@) or
-- 'unsafeSizedByteArray' (which calls 'error' on a mismatch), and are
-- unwrapped with 'unSizedByteArray'.
--
newtype SizedByteArray (n :: Nat) ba = SizedByteArray { unSizedByteArray :: ba }
  deriving (Eq, Show, Typeable, Ord, NormalForm)

-- | Wrap a 'ByteArrayAccess' value in a 'SizedByteArray' of type-level
-- size @n@.
--
-- Returns 'Nothing' when the runtime 'length' of the input is not @n@.
--
sizedByteArray :: forall n ba . (KnownNat n, ByteArrayAccess ba)
               => ba
               -> Maybe (SizedByteArray n ba)
sizedByteArray ba =
    if length ba == expected
        then Just (SizedByteArray ba)
        else Nothing
  where
    expected = fromInteger (natVal (Proxy @n))

-- | Like 'sizedByteArray', but calls 'error' instead of returning
-- 'Nothing' when the input length does not match @n@.
unsafeSizedByteArray :: forall n ba . (ByteArrayAccess ba, KnownNat n) => ba -> SizedByteArray n ba
unsafeSizedByteArray ba =
    maybe (error "The size is invalid") id (sizedByteArray ba)

-- | 'length' reports the type-level size @n@ and never inspects the wrapped
-- value. 'withByteArray' delegates to the wrapped value's instance.
instance (ByteArrayAccess ba, KnownNat n) => ByteArrayAccess (SizedByteArray n ba) where
    length = const (fromInteger (natVal (Proxy @n)))
    withByteArray = withByteArray . unSizedByteArray

-- | Allocation is delegated to the underlying 'ByteArray' instance with a
-- size of exactly @n@ bytes, and the result is wrapped.
instance (KnownNat n, ByteArray ba) => ByteArrayN n (SizedByteArray n ba) where
    allocRet proxy initializer =
        ByteArray.allocRet size initializer >>= \(r, raw) ->
            return (r, SizedByteArray raw)
      where
        size = fromInteger (natVal proxy)

#if MIN_VERSION_basement(0,0,7)
-- | A 'BlockN' of @n@ elements of type @ty@ is treated as a sized byte array
-- of @nbytes ~ (PrimSize ty * n)@. The 'Proxy' argument is ignored; the
-- element count comes from the instance head via @BlockN.new \@n@.
-- The two 'Bool' flags are passed to 'BlockN.withMutablePtrHint' unchanged;
-- see basement for their meaning.
instance ( ByteArrayAccess (BlockN n ty)
         , PrimType ty
         , KnownNat n
         , Countable ty n
         , KnownNat nbytes
         , nbytes ~ (Base.PrimSize ty * n)
         ) => ByteArrayN nbytes (BlockN n ty) where
    allocRet _ initializer = do
        mutable <- BlockN.new @n
        result  <- BlockN.withMutablePtrHint True False mutable (initializer . castPtr)
        frozen  <- BlockN.freeze mutable
        pure (result, frozen)
#endif


-- | Allocate a new bytearray whose size is the type-level @n@, run the
-- initializer on a pointer to its memory, and return the bytearray.
-- The initializer's @()@ result is discarded.
alloc :: forall n ba p . (ByteArrayN n ba, KnownNat n)
      => (Ptr p -> IO ())
      -> IO ba
alloc initializer = fmap snd $ allocRet (Proxy @n) initializer

-- | Synonym for 'alloc'. It allocates a bytearray of the type-level size
-- @n@ and runs the initializer on its memory.
create :: forall n ba p . (ByteArrayN n ba, KnownNat n)
       => (Ptr p -> IO ())
       -> IO ba
create initializer = alloc @n initializer
{-# NOINLINE create #-}

-- | Pure variant of 'alloc'. It allocates the bytearray of type-level size
-- @n@, runs the initializer, and returns the result outside 'IO' through
-- 'unsafeDoIO'.
allocAndFreeze :: forall n ba p . (ByteArrayN n ba, KnownNat n)
               => (Ptr p -> IO ()) -> ba
allocAndFreeze initializer = unsafeDoIO $ alloc @n initializer
{-# NOINLINE allocAndFreeze #-}

-- | Pure variant of 'create'. It runs the allocation and initializer
-- through 'unsafeDoIO' and returns the resulting bytearray.
unsafeCreate :: forall n ba p . (ByteArrayN n ba, KnownNat n)
             => (Ptr p -> IO ()) -> ba
unsafeCreate = unsafeDoIO . alloc @n
{-# NOINLINE unsafeCreate #-}

-- | Same computation as 'unsafeCreate' (allocate, initialize, escape 'IO'
-- via 'unsafeDoIO'), but marked @INLINE@ rather than @NOINLINE@.
inlineUnsafeCreate :: forall n ba p . (ByteArrayN n ba, KnownNat n)
                   => (Ptr p -> IO ()) -> ba
inlineUnsafeCreate initializer = unsafeDoIO $ alloc @n initializer
{-# INLINE inlineUnsafeCreate #-}

-- | A bytearray of size 0. The initializer performs no writes.
empty :: forall ba . ByteArrayN 0 ba => ba
empty = unsafeDoIO $ alloc @0 (\_ -> pure ())

-- | Pack a length-indexed list of bytes into a bytearray of size @n@.
--
-- The bytes are poked one at a time, in list order, starting at offset 0.
pack :: forall n ba . (ByteArrayN n ba, KnownNat n) => ListN n Word8 -> ba
pack bytes = inlineUnsafeCreate @n (writeAll (unListN bytes))
  where
    writeAll []         _    = return ()
    writeAll (w : rest) !dst = do
        poke dst w
        writeAll rest (dst `plusPtr` 1)
    {-# INLINE writeAll #-}
{-# NOINLINE pack #-}

-- | Un-pack a bytearray into a length-indexed list of its bytes.
--
-- The list has one element per byte reported by 'length', and each byte is
-- read through its own 'withByteArray' / 'unsafeDoIO' call. If 'toListN'
-- returns 'Nothing' (presumably because the list length is not @n@), this
-- calls 'error'.
unpack :: forall n ba
        . (ByteArrayN n ba, KnownNat n, NatWithinBound Int n, ByteArrayAccess ba)
       => ba -> ListN n Word8
unpack bs = case toListN @n (go 0) of
    Just bytes -> bytes
    Nothing    -> error "the impossible appened"
  where
    !len = length bs
    go i
        | i == len  = []
        | otherwise =
            let !byte = unsafeDoIO $ withByteArray bs (\ptr -> peekByteOff ptr i)
             in byte : go (i + 1)

-- | Prepend a single byte. The output of size @ni + 1@ holds the byte at
-- offset 0, followed by the @ni@ bytes of the input.
cons :: forall ni no bi bo
      . ( ByteArrayN ni bi, ByteArrayN no bo, ByteArrayAccess bi
        , KnownNat ni, KnownNat no
        , (ni + 1) ~ no
        )
     => Word8 -> bi -> bo
cons b ba = unsafeCreate @no fill
  where
    fill dst = withByteArray ba $ \src -> do
        pokeByteOff dst 0 b
        memCopy (dst `plusPtr` 1) src inputLen
    !inputLen = fromInteger $ natVal (Proxy @ni)

-- | Append a single byte. The output of size @ni + 1@ holds the @ni@ input
-- bytes, followed by the byte at offset @ni@.
snoc :: forall bi bo ni no
      . ( ByteArrayN ni bi, ByteArrayN no bo, ByteArrayAccess bi
        , KnownNat ni, KnownNat no
        , (ni + 1) ~ no
        )
     => bi -> Word8 -> bo
snoc ba b = unsafeCreate @no $ \dst ->
    withByteArray ba $ \src ->
        memCopy dst src inputLen >> pokeByteOff dst inputLen b
  where
    !inputLen = fromInteger $ natVal (Proxy @ni)

-- | Byte-wise exclusive-or of @a@ and @b@.
--
-- Both inputs and the result share the same type-level size @n@, and
-- exactly @n@ bytes are combined.
xor :: forall n a b c
     . ( ByteArrayN n a, ByteArrayN n b, ByteArrayN n c
       , ByteArrayAccess a, ByteArrayAccess b
       , KnownNat n
       )
    => a -> b -> c
xor a b = unsafeCreate @n combine
  where
    combine out =
        withByteArray a $ \left ->
        withByteArray b $ \right ->
            memXor out left right size
    size = fromInteger (natVal (Proxy @n))

-- | Return the byte at the type-level offset @n@ (counting from 0).
--
-- No runtime bounds check is performed. The constraint @(n + 1) <= na@
-- (i.e. @n < na@) ensures the offset lies inside the @na@-byte input.
-- The previous bound @n <= na@ also admitted @n == na@, which reads one
-- byte past the end.
index :: forall n na ba
       . ( ByteArrayN na ba, ByteArrayAccess ba
         , KnownNat na, KnownNat n
         , (n + 1) <= na
         )
      => ba -> Proxy n -> Word8
index b pi = unsafeDoIO $ withByteArray b $ \p -> peek (p `plusPtr` i)
  where
    i = fromInteger $ natVal pi

-- | Split a bytearray into two new bytearrays. The first holds the leading
-- @nblhs@ bytes; the second holds the remaining @length - nblhs@ bytes.
splitAt :: forall nblhs nbi nbrhs bi blhs brhs
         . ( ByteArrayN nbi bi, ByteArrayN nblhs blhs, ByteArrayN nbrhs brhs
           , ByteArrayAccess bi
           , KnownNat nbi, KnownNat nblhs, KnownNat nbrhs
           , nblhs <= nbi, (nbrhs + nblhs) ~ nbi
           )
        => bi -> (blhs, brhs)
splitAt bs = unsafeDoIO $ withByteArray bs $ \src ->
    (,) <$> alloc @nblhs (\dst -> memCopy dst src cut)
        <*> alloc @nbrhs (\dst -> memCopy dst (src `plusPtr` cut) (total - cut))
  where
    cut   = fromInteger $ natVal (Proxy @nblhs)
    total = length bs

-- | Copy the leading bytes of the input into a new bytearray of size @nbo@.
--
-- The number of bytes copied is the smaller of the input 'length' and
-- @nbo@. With @nbo <= nbi@ this is normally @nbo@.
take :: forall nbo nbi bi bo
      . ( ByteArrayN nbi bi, ByteArrayN nbo bo
        , ByteArrayAccess bi
        , KnownNat nbi, KnownNat nbo
        , nbo <= nbi
        )
     => bi -> bo
take bs = unsafeCreate @nbo copyPrefix
  where
    copyPrefix dst = withByteArray bs $ \src -> memCopy dst src count
    !count     = min inputLen outputLen
    !inputLen  = length bs
    !outputLen = fromInteger $ natVal (Proxy @nbo)

-- | Skip the first @n@ bytes of the input and copy the rest into a new
-- bytearray of size @nbo@.
--
-- The skip offset is clamped to the input 'length', so at most
-- @length - offset@ bytes are copied.
drop :: forall n nbi nbo bi bo
      . ( ByteArrayN nbi bi, ByteArrayN nbo bo
        , ByteArrayAccess bi
        , KnownNat n, KnownNat nbi, KnownNat nbo
        , (nbo + n) ~ nbi
        )
     => Proxy n -> bi -> bo
drop pn bs = unsafeCreate @nbo $ \dst ->
    withByteArray bs $ \src ->
        let offset = min total (fromInteger (natVal pn))
         in memCopy dst (src `plusPtr` offset) (total - offset)
  where
    total = length bs

-- | Concatenate two bytearrays. The bytes of @blhs@ come first, followed
-- immediately by the bytes of @brhs@.
append :: forall nblhs nbrhs nbout blhs brhs bout
        . ( ByteArrayN nblhs blhs, ByteArrayN nbrhs brhs, ByteArrayN nbout bout
          , ByteArrayAccess blhs, ByteArrayAccess brhs
          , KnownNat nblhs, KnownNat nbrhs, KnownNat nbout
          , (nbrhs + nblhs) ~ nbout
          )
       => blhs -> brhs -> bout
append blhs brhs = unsafeCreate @nbout $ \dst ->
    withByteArray blhs $ \front ->
    withByteArray brhs $ \back -> do
        memCopy dst front lhsLen
        memCopy (dst `plusPtr` lhsLen) back rhsLen
  where
    lhsLen = length blhs
    rhsLen = length brhs

-- | Allocate a new bytearray of size @n@, copy the input's bytes into it,
-- then run the initializer on the new memory.
copy :: forall n bs1 bs2 p
      . ( ByteArrayN n bs1, ByteArrayN n bs2
        , ByteArrayAccess bs1
        , KnownNat n
        )
     => bs1 -> (Ptr p -> IO ()) -> IO bs2
copy bs f = alloc @n $ \dst ->
    withByteArray bs (\src -> memCopy dst src (length bs)) >> f (castPtr dst)

-- | Like 'copy', but the initializer's result is returned alongside the new
-- bytearray.
copyRet :: forall n bs1 bs2 p a
         . ( ByteArrayN n bs1, ByteArrayN n bs2
           , ByteArrayAccess bs1
           , KnownNat n
           )
        => bs1 -> (Ptr p -> IO a) -> IO (a, bs2)
copyRet bs f = allocRet (Proxy @n) initialize
  where
    initialize dst = do
        withByteArray bs $ \src -> memCopy dst src (length bs)
        f (castPtr dst)

-- | Pure variant of 'copy'. The input's bytes are copied into a new
-- bytearray of size @n@ (via 'copyByteArrayToPtr'), the initializer runs on
-- that memory, and the result is returned outside 'IO'.
copyAndFreeze :: forall n bs1 bs2 p
               . ( ByteArrayN n bs1, ByteArrayN n bs2
                 , ByteArrayAccess bs1
                 , KnownNat n
                 )
              => bs1 -> (Ptr p -> IO ()) -> bs2
copyAndFreeze bs f = inlineUnsafeCreate @n $ \dst ->
    copyByteArrayToPtr bs dst >> f (castPtr dst)
{-# NOINLINE copyAndFreeze #-}

-- | A bytearray of type-level size @n@ in which every byte equals the given
-- value.
replicate :: forall n ba . (ByteArrayN n ba, KnownNat n)
          => Word8 -> ba
replicate b = inlineUnsafeCreate @n fill
  where
    fill ptr = memSet ptr b size
    size     = fromInteger (natVal (Proxy @n))
{-# NOINLINE replicate #-}

-- | A bytearray of type-level size @n@ with every byte set to 0.
zero :: forall n ba . (ByteArrayN n ba, KnownNat n) => ba
zero = unsafeCreate @n (\ptr -> memSet ptr 0 byteCount)
  where
    byteCount = fromInteger (natVal (Proxy @n))
{-# NOINLINE zero #-}

-- | Convert between two bytearray types of the same type-level size @n@ by
-- copying the input's bytes into a freshly allocated output.
convert :: forall n bin bout
         . ( ByteArrayN n bin, ByteArrayN n bout
           , KnownNat n
           )
        => bin -> bout
convert input = inlineUnsafeCreate @n $ \dst -> copyByteArrayToPtr input dst

-- | Copy any 'ByteArrayAccess' value into a bytearray of type-level size
-- @n@.
--
-- Returns 'Nothing' when the input's runtime 'length' is not @n@.
fromByteArrayAccess :: forall n bin bout
                     . ( ByteArrayAccess bin, ByteArrayN n bout
                       , KnownNat n
                       )
                    => bin -> Maybe bout
fromByteArrayAccess bs =
    if length bs == expected
        then Just (inlineUnsafeCreate @n (copyByteArrayToPtr bs))
        else Nothing
  where
    expected = fromInteger $ natVal (Proxy @n)

-- | Like 'fromByteArrayAccess', but calls 'error' when the input's length
-- does not match @n@.
unsafeFromByteArrayAccess :: forall n bin bout
                           . ( ByteArrayAccess bin, ByteArrayN n bout
                             , KnownNat n
                           )
                          => bin -> bout
unsafeFromByteArrayAccess bs =
    fromMaybe (error "Invalid Size") (fromByteArrayAccess @n @bin @bout bs)