{-# LANGUAGE CPP           #-}
{-# LANGUAGE MagicHash     #-}
{-# LANGUAGE UnboxedTuples #-}
-- | Minimal bit array implementation.
module Data.BloomFilter.Classic.BitArray (
    BitArray (..),
    unsafeIndex,
    prefetchIndex,
    MBitArray (..),
    new,
    unsafeSet,
    unsafeRead,
    freeze,
    unsafeFreeze,
    thaw,
    serialise,
    deserialise,
) where

import           Control.Exception (assert)
import           Control.Monad.Primitive (PrimMonad, PrimState)
import           Control.Monad.ST (ST)
import           Data.Bits
import           Data.Primitive.ByteArray
import           Data.Primitive.PrimArray
import           Data.Word (Word64, Word8)

import           GHC.Exts (Int (I#), prefetchByteArray0#)
import           GHC.ST (ST (ST))

-- | Immutable bit vector backed by an array of 'Word64'.
--
-- Storage is allocated in whole 64-bit words, so the length of the vector in
-- bits is always a multiple of 64.
newtype BitArray = BitArray (PrimArray Word64)
  deriving stock (Eq, Show)

{-# INLINE unsafeIndex #-}
-- | Test whether the bit at bit index @i@ is set.
--
-- The word index derived from @i@ is only checked by an 'assert'; when
-- assertions are compiled out there is no bounds check at all, so the caller
-- must guarantee that @i@ is within the bit array.
unsafeIndex :: BitArray -> Int -> Bool
unsafeIndex (BitArray arr) !i =
    assert (j >= 0 && j < sizeofPrimArray arr) $
    unsafeTestBit (indexPrimArray arr j) k
  where
    !j = unsafeShiftR i 6 -- `div` 64, bit index to Word64 index.
    !k = i .&. 63         -- `mod` 64, bit within Word64

{-# INLINE prefetchIndex #-}
-- | Issue a prefetch for the byte that holds the bit at bit index @i@.
--
-- This only invokes the 'prefetchByteArray0#' primop on the underlying byte
-- array; it does not read or return any bit. No bounds check is performed on
-- @i@ here.
prefetchIndex :: BitArray -> Int -> ST s ()
prefetchIndex (BitArray (PrimArray ba#)) !i =
    let !(I# bi#) = i `unsafeShiftR` 3 in
    -- The primop works on the raw state token, so the ST action is built by
    -- hand and returns unit alongside the new state.
    ST (\s -> case prefetchByteArray0# ba# bi# s of
                s' -> (# s', () #))
    -- We only need to shiftR 3 here, not 6, because we're going from a bit
    -- offset to a byte offset for prefetch. Whereas in unsafeIndex, we go from
    -- a bit offset to a Word64 offset, so an extra shiftR 3, for 6 total.

-- | Mutable bit vector backed by a mutable array of 'Word64', parameterised
-- by the state token @s@ of the monad it is mutated in.
newtype MBitArray s = MBitArray (MutablePrimArray s Word64)

-- | Allocate a new bit array with room for at least @s@ bits, all cleared.
--
-- The requested size is rounded up to a whole number of 64-bit words, and
-- every byte of the allocation is set to zero.
--
-- Will create an explicitly pinned byte array.
-- This is done because pinned byte arrays allow for more efficient
-- serialisation, but the definition of 'isByteArrayPinned' changed in GHC 9.6,
-- see <https://gitlab.haskell.org/ghc/ghc/-/issues/22255>.
--
-- TODO: remove this workaround once a solution exists, e.g. a new primop that
-- allows checking for implicit pinning.
new :: Int -> ST s (MBitArray s)
new s = do
    mba@(MutableByteArray mba#) <- newPinnedByteArray numBytes
    setByteArray mba 0 numBytes (0 :: Word8)
    -- Re-wrap the same underlying mutable byte array as an array of Word64.
    pure (MBitArray (MutablePrimArray mba#))
  where
    !numWords = roundUpTo64 s
    !numBytes = unsafeShiftL numWords 3 -- * 8

    -- this may overflow, but so be it (2^64 bits is a lot)
    roundUpTo64 :: Int -> Int
    roundUpTo64 i = unsafeShiftR (i + 63) 6 -- `div` 64, rounded up

-- | Expose the bit array as a byte array slice for serialisation.
--
-- Returns the underlying byte array (re-wrapped, not copied), together with
-- an offset of @0@ and the size of that byte array in bytes.
serialise :: BitArray -> (ByteArray, Int, Int)
serialise bitArray =
    let ba = asByteArray bitArray
     in (ba, 0, sizeofByteArray ba)
  where
    -- Same underlying 'ByteArray#', viewed as untyped bytes.
    asByteArray (BitArray (PrimArray ba#)) = ByteArray ba#

{-# INLINE deserialise #-}
-- | Do an inplace overwrite of the byte array representing the bit block.
--
-- The supplied action is called exactly once with the underlying mutable byte
-- array, an offset of @0@, and the size of that byte array in bytes. What it
-- writes into the array is up to the caller.
deserialise :: PrimMonad m
            => MBitArray (PrimState m)
            -> (MutableByteArray (PrimState m) -> Int -> Int -> m ())
            -> m ()
deserialise bitArray fill = do
    let mba = asMutableByteArray bitArray
    len <- getSizeofMutableByteArray mba
    fill mba 0 len
  where
    -- Same underlying 'MutableByteArray#', viewed as untyped bytes.
    asMutableByteArray (MBitArray (MutablePrimArray mba#)) =
      MutableByteArray mba#

-- | Set the bit at bit index @i@ to 1, leaving all other bits unchanged.
--
-- The index is only checked when the module is compiled with
-- @NO_IGNORE_ASSERTS@ defined; otherwise there is no bounds check and the
-- caller must guarantee that @i@ is within the bit array.
unsafeSet :: MBitArray s -> Int -> ST s ()
unsafeSet (MBitArray arr) i = do
#ifdef NO_IGNORE_ASSERTS
    sz <- getSizeofMutablePrimArray arr
    assert (j >= 0 && j < sz) $ pure ()
#endif
    -- Read-modify-write of the single word that contains the bit.
    w <- readPrimArray arr j
    writePrimArray arr j (unsafeSetBit w k)
  where
    !j = unsafeShiftR i 6 -- `div` 64
    !k = i .&. 63         -- `mod` 64

-- | Read the bit at bit index @i@ from a mutable bit array.
--
-- The index is only checked when the module is compiled with
-- @NO_IGNORE_ASSERTS@ defined; otherwise there is no bounds check and the
-- caller must guarantee that @i@ is within the bit array.
unsafeRead :: MBitArray s -> Int -> ST s Bool
unsafeRead (MBitArray arr) i = do
#ifdef NO_IGNORE_ASSERTS
    sz <- getSizeofMutablePrimArray arr
    assert (j >= 0 && j < sz) $ pure ()
#endif
    w <- readPrimArray arr j
    -- Force the result so no thunk referring to the word is returned.
    pure $! unsafeTestBit w k
  where
    !j = unsafeShiftR i 6 -- `div` 64
    !k = i .&. 63         -- `mod` 64

-- | Produce an immutable 'BitArray' from the full contents of a mutable one.
--
-- The whole array (from index 0 to its current size) is passed through
-- 'freezePrimArray', so the mutable array can still be modified afterwards
-- without affecting the result.
--
-- NOTE(review): 'new' allocates pinned memory; whether the array returned by
-- 'freezePrimArray' is pinned as well is not visible here -- confirm if
-- serialisation relies on it.
freeze :: MBitArray s -> ST s BitArray
freeze (MBitArray arr) = do
    len <- getSizeofMutablePrimArray arr
    BitArray <$> freezePrimArray arr 0 len

-- | Convert a mutable bit array into an immutable 'BitArray' via
-- 'unsafeFreezePrimArray'.
--
-- Unsafe: the mutable array must not be written to after this call, since
-- the result is obtained with the unsafe freeze primitive rather than by an
-- explicit copy (contrast with 'freeze').
unsafeFreeze :: MBitArray s -> ST s BitArray
unsafeFreeze (MBitArray arr) =
    BitArray <$> unsafeFreezePrimArray arr

-- | Produce a mutable bit array initialised with the full contents of an
-- immutable 'BitArray'.
--
-- The whole array (from index 0 to its size) is passed through
-- 'thawPrimArray'.
--
-- NOTE(review): 'new' allocates pinned memory; whether the array returned by
-- 'thawPrimArray' is pinned as well is not visible here -- confirm if
-- serialisation relies on it.
thaw :: BitArray -> ST s (MBitArray s)
thaw (BitArray arr) =
    MBitArray <$> thawPrimArray arr 0 (sizeofPrimArray arr)

{-# INLINE unsafeTestBit #-}
-- like testBit but using unsafeShiftL instead of shiftL
--
-- Returns True iff bit @k@ of @w@ is set. The shift amount is not checked, so
-- @k@ must be in the range 0..63.
unsafeTestBit :: Word64 -> Int -> Bool
unsafeTestBit w k = w .&. (1 `unsafeShiftL` k) /= 0

{-# INLINE unsafeSetBit #-}
-- like setBit but using unsafeShiftL instead of shiftL
--
-- Returns @w@ with bit @k@ set and all other bits unchanged. The shift amount
-- is not checked, so @k@ must be in the range 0..63.
unsafeSetBit :: Word64 -> Int -> Word64
unsafeSetBit w k = w .|. (1 `unsafeShiftL` k)