1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278
|
-- |
-- Module : Basement.Sized.Block
-- License : BSD-style
-- Maintainer : Haskell Foundation
--
-- A Nat-sized version of Block
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
#if __GLASGOW_HASKELL__ >= 806
{-# LANGUAGE NoStarIsType #-}
#endif
module Basement.Sized.Block
( BlockN
, MutableBlockN
, length
, lengthBytes
, toBlockN
, toBlock
, new
, newPinned
, singleton
, replicate
, thaw
, freeze
, index
, indexStatic
, map
, foldl'
, foldr
, cons
, snoc
, elem
, sub
, uncons
, unsnoc
, splitAt
, all
, any
, find
, reverse
, sortBy
, intersperse
, withPtr
, withMutablePtr
, withMutablePtrHint
, cast
, mutableCast
) where
import Data.Proxy (Proxy(..))
import Basement.Compat.Base
import Basement.Numerical.Additive (scale)
import Basement.Block (Block, MutableBlock(..), unsafeIndex)
import qualified Basement.Block as B
import qualified Basement.Block.Base as B
import Basement.Monad (PrimMonad, PrimState)
import Basement.Nat
import Basement.Types.OffsetSize
import Basement.NormalForm
import Basement.PrimType (PrimType, PrimSize, primSizeInBytes)
-- | Sized version of 'Block'
--
newtype BlockN (n :: Nat) a = BlockN { unBlock :: Block a }
deriving (NormalForm, Eq, Show, Data, Ord)
newtype MutableBlockN (n :: Nat) ty st = MutableBlockN { unMBlock :: MutableBlock ty st }
toBlockN :: forall n ty . (PrimType ty, KnownNat n, Countable ty n) => Block ty -> Maybe (BlockN n ty)
toBlockN b
| expected == B.length b = Just (BlockN b)
| otherwise = Nothing
where
expected = toCount @n
length :: forall n ty
. (KnownNat n, Countable ty n)
=> BlockN n ty
-> CountOf ty
length _ = toCount @n
lengthBytes :: forall n ty
. PrimType ty
=> BlockN n ty
-> CountOf Word8
lengthBytes = B.lengthBytes . unBlock
toBlock :: BlockN n ty -> Block ty
toBlock = unBlock
cast :: forall n m a b
. ( PrimType a, PrimType b
, KnownNat n, KnownNat m
, ((PrimSize b) * m) ~ ((PrimSize a) * n)
)
=> BlockN n a
-> BlockN m b
cast (BlockN b) = BlockN (B.unsafeCast b)
mutableCast :: forall n m a b st
. ( PrimType a, PrimType b
, KnownNat n, KnownNat m
, ((PrimSize b) * m) ~ ((PrimSize a) * n)
)
=> MutableBlockN n a st
-> MutableBlockN m b st
mutableCast (MutableBlockN b) = MutableBlockN (B.unsafeRecast b)
-- | Create a new unpinned mutable block of a specific N size of 'ty' elements
--
-- If the size exceeds a GHC-defined threshold, then the memory will be
-- pinned. To be certain about pinning status with small size, use 'newPinned'
new :: forall n ty prim
. (PrimType ty, KnownNat n, Countable ty n, PrimMonad prim)
=> prim (MutableBlockN n ty (PrimState prim))
new = MutableBlockN <$> B.new (toCount @n)
-- | Create a new pinned mutable block of a specific N size of 'ty' elements
newPinned :: forall n ty prim
. (PrimType ty, KnownNat n, Countable ty n, PrimMonad prim)
=> prim (MutableBlockN n ty (PrimState prim))
newPinned = MutableBlockN <$> B.newPinned (toCount @n)
singleton :: PrimType ty => ty -> BlockN 1 ty
singleton a = BlockN (B.singleton a)
replicate :: forall n ty . (KnownNat n, Countable ty n, PrimType ty) => ty -> BlockN n ty
replicate a = BlockN (B.replicate (toCount @n) a)
thaw :: (KnownNat n, PrimMonad prim, PrimType ty) => BlockN n ty -> prim (MutableBlockN n ty (PrimState prim))
thaw b = MutableBlockN <$> B.thaw (unBlock b)
freeze :: (PrimMonad prim, PrimType ty, Countable ty n) => MutableBlockN n ty (PrimState prim) -> prim (BlockN n ty)
freeze b = BlockN <$> B.freeze (unMBlock b)
indexStatic :: forall i n ty . (KnownNat i, CmpNat i n ~ 'LT, PrimType ty, Offsetable ty i) => BlockN n ty -> ty
indexStatic b = unsafeIndex (unBlock b) (toOffset @i)
index :: forall i n ty . PrimType ty => BlockN n ty -> Offset ty -> ty
index b ofs = B.index (unBlock b) ofs
map :: (PrimType a, PrimType b) => (a -> b) -> BlockN n a -> BlockN n b
map f b = BlockN (B.map f (unBlock b))
foldl' :: PrimType ty => (a -> ty -> a) -> a -> BlockN n ty -> a
foldl' f acc b = B.foldl' f acc (unBlock b)
foldr :: PrimType ty => (ty -> a -> a) -> a -> BlockN n ty -> a
foldr f acc b = B.foldr f acc (unBlock b)
cons :: PrimType ty => ty -> BlockN n ty -> BlockN (n+1) ty
cons e = BlockN . B.cons e . unBlock
snoc :: PrimType ty => BlockN n ty -> ty -> BlockN (n+1) ty
snoc b = BlockN . B.snoc (unBlock b)
sub :: forall i j n ty
. ( (i <=? n) ~ 'True
, (j <=? n) ~ 'True
, (i <=? j) ~ 'True
, PrimType ty
, KnownNat i
, KnownNat j
, Offsetable ty i
, Offsetable ty j )
=> BlockN n ty
-> BlockN (j-i) ty
sub block = BlockN (B.sub (unBlock block) (toOffset @i) (toOffset @j))
uncons :: forall n ty . (CmpNat 0 n ~ 'LT, PrimType ty, KnownNat n, Offsetable ty n)
=> BlockN n ty
-> (ty, BlockN (n-1) ty)
uncons b = (indexStatic @0 b, BlockN (B.sub (unBlock b) 1 (toOffset @n)))
unsnoc :: forall n ty . (CmpNat 0 n ~ 'LT, KnownNat n, PrimType ty, Offsetable ty n)
=> BlockN n ty
-> (BlockN (n-1) ty, ty)
unsnoc b =
( BlockN (B.sub (unBlock b) 0 (toOffset @n `offsetSub` 1))
, unsafeIndex (unBlock b) (toOffset @n `offsetSub` 1))
splitAt :: forall i n ty . (CmpNat i n ~ 'LT, PrimType ty, KnownNat i, Countable ty i) => BlockN n ty -> (BlockN i ty, BlockN (n-i) ty)
splitAt b =
let (left, right) = B.splitAt (toCount @i) (unBlock b)
in (BlockN left, BlockN right)
elem :: PrimType ty => ty -> BlockN n ty -> Bool
elem e b = B.elem e (unBlock b)
all :: PrimType ty => (ty -> Bool) -> BlockN n ty -> Bool
all p b = B.all p (unBlock b)
any :: PrimType ty => (ty -> Bool) -> BlockN n ty -> Bool
any p b = B.any p (unBlock b)
find :: PrimType ty => (ty -> Bool) -> BlockN n ty -> Maybe ty
find p b = B.find p (unBlock b)
reverse :: PrimType ty => BlockN n ty -> BlockN n ty
reverse = BlockN . B.reverse . unBlock
sortBy :: PrimType ty => (ty -> ty -> Ordering) -> BlockN n ty -> BlockN n ty
sortBy f b = BlockN (B.sortBy f (unBlock b))
intersperse :: (CmpNat n 1 ~ 'GT, PrimType ty) => ty -> BlockN n ty -> BlockN (n+n-1) ty
intersperse sep b = BlockN (B.intersperse sep (unBlock b))
toCount :: forall n ty . (KnownNat n, Countable ty n) => CountOf ty
toCount = natValCountOf (Proxy @n)
toOffset :: forall n ty . (KnownNat n, Offsetable ty n) => Offset ty
toOffset = natValOffset (Proxy @n)
-- | Get a Ptr pointing to the data in the Block.
--
-- Since a Block is immutable, this Ptr shouldn't be
-- to use to modify the contents
--
-- If the Block is pinned, then its address is returned as is,
-- however if it's unpinned, a pinned copy of the Block is made
-- before getting the address.
withPtr :: (PrimMonad prim, KnownNat n)
=> BlockN n ty
-> (Ptr ty -> prim a)
-> prim a
withPtr b = B.withPtr (unBlock b)
-- | Create a pointer on the beginning of the MutableBlock
-- and call a function 'f'.
--
-- The mutable block can be mutated by the 'f' function
-- and the change will be reflected in the mutable block
--
-- If the mutable block is unpinned, a trampoline buffer
-- is created and the data is only copied when 'f' return.
--
-- it is all-in-all highly inefficient as this cause 2 copies
withMutablePtr :: (PrimMonad prim, KnownNat n)
=> MutableBlockN n ty (PrimState prim)
-> (Ptr ty -> prim a)
-> prim a
withMutablePtr mb = B.withMutablePtr (unMBlock mb)
-- | Same as 'withMutablePtr' but allow to specify 2 optimisations
-- which is only useful when the MutableBlock is unpinned and need
-- a pinned trampoline to be called safely.
--
-- If skipCopy is True, then the first copy which happen before
-- the call to 'f', is skipped. The Ptr is now effectively
-- pointing to uninitialized data in a new mutable Block.
--
-- If skipCopyBack is True, then the second copy which happen after
-- the call to 'f', is skipped. Then effectively in the case of a
-- trampoline being used the memory changed by 'f' will not
-- be reflected in the original Mutable Block.
--
-- If using the wrong parameters, it will lead to difficult to
-- debug issue of corrupted buffer which only present themselves
-- with certain Mutable Block that happened to have been allocated
-- unpinned.
--
-- If unsure use 'withMutablePtr', which default to *not* skip
-- any copy.
withMutablePtrHint :: forall n ty prim a . (PrimMonad prim, KnownNat n)
=> Bool -- ^ hint that the buffer doesn't need to have the same value as the mutable block when calling f
-> Bool -- ^ hint that the buffer is not supposed to be modified by call of f
-> MutableBlockN n ty (PrimState prim)
-> (Ptr ty -> prim a)
-> prim a
withMutablePtrHint skipCopy skipCopyBack (MutableBlockN mb) f =
B.withMutablePtrHint skipCopy skipCopyBack mb f
|