{-# LANGUAGE MagicHash           #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
{-# OPTIONS_HADDOCK not-home #-}

module Database.LSMTree.Internal.Vector (
    mkPrimVector,
    byteVectorFromPrim,
    noRetainedExtraMemory,
    primArrayToPrimVector,
    mapStrict,
    mapMStrict,
    imapMStrict,
    forMStrict,
    zipWithStrict,
    binarySearchL,
    unsafeInsertWithMStrict,
    unfoldrNM',
) where

import           Control.Monad
import           Control.Monad.Primitive (PrimMonad, PrimState)
import qualified Data.Primitive as P
import           Data.Primitive.ByteArray (ByteArray, newByteArray,
                     runByteArray, sizeofByteArray, writeByteArray)
import           Data.Primitive.Types (Prim (sizeOfType#), sizeOfType)
import           Data.Proxy (Proxy (..))
import qualified Data.Vector as V
import qualified Data.Vector.Algorithms.Search as VA
import qualified Data.Vector.Mutable as VM
import qualified Data.Vector.Primitive as VP
import           Data.Word (Word8)
import           Database.LSMTree.Internal.Assertions
import           GHC.Exts (Int (..))
import           GHC.ST (runST)

-- | Wrap a region of a 'ByteArray' as a primitive vector of @len@ elements
-- starting at element offset @off@. The byte array is shared, not copied.
--
-- Offset and length are counted in elements of type @a@. The assertion
-- passes the corresponding byte offset and byte length to 'isValidSlice'.
mkPrimVector :: forall a. Prim a => Int -> Int -> ByteArray -> VP.Vector a
mkPrimVector off len ba = assert sliceOk (VP.Vector off len ba)
  where
    elemBytes :: Int
    elemBytes = I# (sizeOfType# (Proxy @a))

    sliceOk :: Bool
    sliceOk = isValidSlice (off * elemBytes) (len * elemBytes) ba
{-# INLINE mkPrimVector #-}

-- | Write a single 'Prim' value into a freshly allocated byte array of
-- exactly @sizeOfType \@a@ bytes. Return those bytes as a 'Word8' vector
-- covering the whole array.
byteVectorFromPrim :: forall a. Prim a => a -> VP.Vector Word8
byteVectorFromPrim x = mkPrimVector 0 nbytes bytes
  where
    nbytes :: Int
    nbytes = sizeOfType @a

    -- Allocate, store the value at element index 0, and hand back the array.
    bytes :: ByteArray
    bytes = runByteArray
      (newByteArray nbytes >>= \mba -> writeByteArray mba 0 x >> return mba)
{-# INLINE byteVectorFromPrim #-}

-- | @True@ exactly when the vector spans its entire underlying byte array.
-- That means it starts at offset 0 and its elements account for every byte
-- of the array. A vector failing this check keeps bytes it does not use
-- reachable through its 'ByteArray'.
noRetainedExtraMemory :: forall a. Prim a => VP.Vector a -> Bool
noRetainedExtraMemory (VP.Vector off len ba)
  | off /= 0  = False
  | otherwise = sizeofByteArray ba == len * I# (sizeOfType# (Proxy @a))

{-# INLINE primArrayToPrimVector #-}
-- | View a whole 'P.PrimArray' as a primitive vector without copying. The
-- resulting vector reuses the array's underlying 'ByteArray#', starts at
-- offset 0 and has the array's element count as its length.
primArrayToPrimVector :: Prim a => P.PrimArray a -> VP.Vector a
primArrayToPrimVector pa =
    case pa of
      P.PrimArray arr# ->
        VP.Vector 0 (P.sizeofPrimArray pa) (P.ByteArray arr#)

{-# INLINE mapStrict #-}
-- | /O(n)/ Like 'V.map', but each produced element of type @b@ is forced to
-- weak head normal form before it is stored in the result vector.
mapStrict :: forall a b. (a -> b) -> V.Vector a -> V.Vector b
mapStrict f v =
    -- The traversal runs in 'ST' only so that each result can be forced
    -- before the vector holds it.
    runST (V.forM v (\x -> let y = f x in y `seq` pure y))

{-# INLINE mapMStrict #-}
-- | /O(n)/ Like 'V.mapM', but each element of type @b@ returned by the
-- action is forced to weak head normal form before it is stored.
mapMStrict :: Monad m => (a -> m b) -> V.Vector a -> m (V.Vector b)
mapMStrict f v = V.mapM step v
  where
    step x = do
      y <- f x
      y `seq` pure y

{-# INLINE imapMStrict #-}
-- | /O(n)/ Like 'V.imapM', but each element of type @b@ returned by the
-- indexed action is forced to weak head normal form before it is stored.
imapMStrict :: Monad m => (Int -> a -> m b) -> V.Vector a -> m (V.Vector b)
imapMStrict f v = V.imapM step v
  where
    step ix x = do
      y <- f ix x
      pure $! y

{-# INLINE zipWithStrict #-}
-- | /O(min(m,n))/ Like 'V.zipWith', but each produced element of type @c@ is
-- forced to weak head normal form before it is stored. As with 'V.zipWith',
-- the result is as long as the shorter input.
zipWithStrict :: forall a b c. (a -> b -> c) -> V.Vector a -> V.Vector b -> V.Vector c
zipWithStrict f xs ys =
    runST (V.zipWithM (\x y -> let z = f x y in z `seq` pure z) xs ys)

{-# INLINE forMStrict #-}
-- | /O(n)/ Like 'V.forM', but each element of type @b@ returned by the
-- action is forced to weak head normal form before it is stored. This is
-- 'mapMStrict' with its arguments flipped.
forMStrict :: Monad m => V.Vector a -> (a -> m b) -> m (V.Vector b)
forMStrict xs f = mapMStrict f xs

{-|
    Finds the lowest index in a given sorted vector at which the given element
    could be inserted while maintaining the sortedness.

    This is a variant of 'Data.Vector.Algorithms.Search.binarySearchL' for
    immutable vectors. The vector is exposed to the mutable search via
    'V.unsafeThaw', which does not copy. This relies on the search only
    reading from that mutable view.
-}
binarySearchL :: Ord a => V.Vector a -> a -> Int
binarySearchL vec val = runST $ do
    mvec <- V.unsafeThaw vec
    VA.binarySearchL mvec val
{-# INLINE binarySearchL #-}

{-# INLINE unsafeInsertWithMStrict #-}
-- | Insert (in a broad sense) an entry in a mutable vector at a given index.
-- If a @Just@ entry already exists at that index, combine the two entries
-- using @f@.
--
-- The combining function receives the entry /already stored/ at the index as
-- its first argument and the entry being inserted as its second. The slot is
-- accessed with 'VM.unsafeModifyM', so the index is not bounds-checked. Both
-- the stored entry and its 'Just' wrapper are forced to weak head normal
-- form.
--
-- NOTE(review): an earlier version of this comment said @f new old@, but the
-- implementation passes the existing entry first. The implementation was kept
-- as-is because callers may rely on it. Confirm call sites expect
-- @f old new@.
unsafeInsertWithMStrict ::
     PrimMonad m
  => VM.MVector (PrimState m) (Maybe a)
  -> (a -> a -> a)  -- ^ function @f@, called as @f old new@: @old@ is the
                    --   entry currently at the index, @new@ the one inserted
  -> Int            -- ^ index into the vector (not bounds-checked)
  -> a              -- ^ entry to insert
  -> m ()
unsafeInsertWithMStrict mvec f i new = VM.unsafeModifyM mvec combineSlot i
  where
    combineSlot slot =
      let entry = case slot of
                    Nothing  -> new
                    Just old -> f old new  -- existing entry goes first
      in  pure $! Just $! entry

{-# INLINE unfoldrNM' #-}
-- | A version of 'V.unfoldrNM' that also returns the final state.
--
-- /O(n)/ Construct a vector of at most @len@ elements by repeatedly applying
-- the monadic generator function to a seed. The generator returns 'Just' the
-- next element, or 'Nothing' if there are no more elements, paired with the
-- next seed.
--
-- Generation stops at the first 'Nothing', or once @len@ elements have been
-- produced. In the second case the generator is not called again. The
-- returned state is the last seed seen. A non-positive @len@ yields an empty
-- vector together with the initial seed.
--
-- The state as well as all elements of the result vector are forced to weak
-- head normal form.
unfoldrNM' :: PrimMonad m => Int -> (b -> m (Maybe a, b)) -> b -> m (V.Vector a, b)
unfoldrNM' len f = \b0 -> do
    -- Clamp the allocation size: 'VM.unsafeNew' does not validate its
    -- argument in normal builds, so a negative @len@ must never reach it.
    -- With a negative @len@ the @n >= len@ guard below fires immediately,
    -- so the empty buffer is never written to.
    vec <- VM.unsafeNew (max 0 len)
    go vec 0 b0
  where
    go !vec !n !b
      | n >= len = (, b) <$!> V.unsafeFreeze vec
      | otherwise =
          f b >>= \case
            (Nothing, !b') ->
              -- Early stop: only the first n slots were written.
              (, b') <$!> V.unsafeFreeze (VM.slice 0 n vec)
            (Just !a, !b') -> do
              VM.unsafeWrite vec n a
              go vec (n + 1) b'