{-# LANGUAGE CPP           #-}
{-# LANGUAGE MagicHash     #-}
{-# LANGUAGE UnboxedTuples #-}
-- | Blocked bit array implementation. This uses blocks of 64 bytes, aligned
-- to 64-byte boundaries to match typical cache line sizes. This means that
-- multiple accesses to the same block only require a single cache line load
-- or store.
module Data.BloomFilter.Blocked.BitArray (
    NumBlocks (..),
    bitsToBlocks,
    blocksToBits,
    BlockIx (..),
    BitIx (..),
    BitArray (..),
    unsafeIndex,
    prefetchIndex,
    MBitArray (..),
    new,
    unsafeSet,
    prefetchSet,
    freeze,
    unsafeFreeze,
    thaw,
    serialise,
    deserialise,
) where

import           Control.Exception (assert)
import           Control.Monad.Primitive (PrimMonad, PrimState)
import           Control.Monad.ST (ST)
import           Data.Bits
import           Data.Primitive.ByteArray
import           Data.Primitive.PrimArray
import           Data.Word (Word64, Word8)

import           GHC.Exts (Int (I#), prefetchByteArray0#,
                     prefetchMutableByteArray0#)
import           GHC.ST (ST (ST))

-- | An array of 512-bit blocks.
--
-- Each block is 512 bits (64 bytes), corresponding to a cache line on most
-- current architectures.
--
-- The blocks are stored in an array of 'Word64', eight words per block. The
-- array is meant to be aligned to 64 bytes (see 'new'), so that all accesses
-- within a single block touch only one cache line.
newtype BitArray = BitArray (PrimArray Word64)
  deriving stock Eq
  deriving stock Show

-- | A count of blocks. Each block is 512 bits, i.e. 64 bytes.
newtype NumBlocks = NumBlocks Int
  deriving stock Eq

-- | The number of 512-bit blocks needed to hold the given number of bits:
-- the bit count divided by 512, rounded up.
bitsToBlocks :: Int -> NumBlocks
bitsToBlocks n = NumBlocks (roundedUp `div` bitsPerBlock)
  where
    bitsPerBlock, roundedUp :: Int
    bitsPerBlock = 512
    -- Adding (bitsPerBlock - 1) before the floor division rounds up.
    roundedUp    = n + bitsPerBlock - 1

-- | The total number of bits held by the given number of 512-bit blocks.
blocksToBits :: NumBlocks -> Int
blocksToBits (NumBlocks n) = bitsPerBlock * n
  where
    bitsPerBlock :: Int
    bitsPerBlock = 512

-- | Index of a 512-bit block within a 'BitArray' or 'MBitArray'.
newtype BlockIx =
    BlockIx Word

-- | Index of a bit within a single 512-bit block, expected to be in the
-- range 0..511.
newtype BitIx =
    BitIx Int

{-# INLINE unsafeIndex #-}
-- | Test whether the given bit of the given 512-bit block is set.
--
-- No bounds checking is performed unless assertions are enabled.
unsafeIndex :: BitArray -> BlockIx -> BitIx -> Bool
unsafeIndex (BitArray arr) blockIx blockBitIx =
    let (wordIx, wordBitIx) = wordAndBitIndex blockIx blockBitIx
        inBounds            = wordIx >= 0 && wordIx < sizeofPrimArray arr
     in assert inBounds (unsafeTestBit (indexPrimArray arr wordIx) wordBitIx)

{-# INLINE prefetchIndex #-}
-- | Prefetch the cache line holding the given block, in preparation for
-- reading it (e.g. with 'unsafeIndex').
--
-- The block index is not bounds-checked unless assertions are enabled.
prefetchIndex :: BitArray -> BlockIx -> ST s ()
prefetchIndex (BitArray (PrimArray ba#)) (BlockIx blockIx) =
    -- Byte offset of the start of the block, i.e. blockIx * 64. Note that
    -- this offset is in bytes, not in Word64s.
    let !byteOff@(I# byteOff#) = fromIntegral blockIx `shiftL` 6
     in assert (byteOff >= 0
                && byteOff <= sizeofByteArray (ByteArray ba#) - 64) $
          -- For reading we want to disturb the caches as little as possible:
          -- we typically keep this line no longer than one use of elemHashes,
          -- which does several reads of the same cache line. The 0 in
          -- prefetchByteArray0# requests a "non-temporal" prefetch: a hint
          -- that the value is used soon and then not again (soon), so the
          -- caches may evict it as soon as they like.
          ST (\s -> case prefetchByteArray0# ba# byteOff# s of
                      s' -> (# s', () #))

-- | A mutable array of 512-bit blocks: the mutable counterpart of
-- 'BitArray', with the same 'Word64' representation.
newtype MBitArray s =
    MBitArray (MutablePrimArray s Word64)

-- | Allocate a new bit array holding the given number of 512-bit blocks,
-- with every bit cleared.
--
-- We create an explicitly pinned byte array, aligned to 64 bytes, so that
-- each block lines up with a typical 64-byte cache line.
new :: NumBlocks -> ST s (MBitArray s)
new (NumBlocks numBlocks) = do
    mba <- newAlignedPinnedByteArray numBytes 64
    -- newAlignedPinnedByteArray does not initialise the memory, so zero it.
    setByteArray mba 0 numBytes (0 :: Word8)
    case mba of
      MutableByteArray mba# -> pure (MBitArray (MutablePrimArray mba#))
  where
    -- Every block occupies 64 bytes.
    !numBytes = numBlocks * 64

-- | Expose the raw bytes of a bit array for serialisation.
--
-- Returns the underlying 'ByteArray' (shared, not copied), the offset at
-- which the data starts (always 0), and the length of the data in bytes.
serialise :: BitArray -> (ByteArray, Int, Int)
serialise arr = (bytes, 0, sizeofByteArray bytes)
  where
    bytes = case arr of
      BitArray (PrimArray ba#) -> ByteArray ba#

{-# INLINE deserialise #-}
-- | Do an in-place overwrite of the byte array representing the bit array.
--
-- The @fill@ action is given the underlying mutable byte array, together
-- with the offset (0) and the length in bytes of the whole array, and
-- performs the overwrite.
deserialise :: PrimMonad m
            => MBitArray (PrimState m)
            -> (MutableByteArray (PrimState m) -> Int -> Int -> m ())
            -> m ()
deserialise bitArray fill =
    getSizeofMutableByteArray bytes >>= fill bytes 0
  where
    bytes = case bitArray of
      MBitArray (MutablePrimArray mba#) -> MutableByteArray mba#

-- | Set the given bit of the given 512-bit block.
--
-- The bounds check below is only compiled in when NO_IGNORE_ASSERTS is
-- defined; otherwise no bounds checking is performed.
unsafeSet :: MBitArray s -> BlockIx -> BitIx -> ST s ()
unsafeSet (MBitArray arr) blockIx blockBitIx = do
#ifdef NO_IGNORE_ASSERTS
    sz <- getSizeofMutablePrimArray arr
    assert (wordIx >= 0 && wordIx < sz) $ pure ()
#endif
    -- Read-modify-write of the single Word64 that holds the bit.
    readPrimArray arr wordIx
      >>= writePrimArray arr wordIx . (`unsafeSetBit` wordBitIx)
  where
    (wordIx, wordBitIx) = wordAndBitIndex blockIx blockBitIx

{-# INLINE prefetchSet #-}
-- | Prefetch the cache line holding the given block, in preparation for
-- writing to it (e.g. with 'unsafeSet').
--
-- The block index is only bounds-checked when NO_IGNORE_ASSERTS is defined.
prefetchSet :: MBitArray s -> BlockIx -> ST s ()
prefetchSet (MBitArray (MutablePrimArray mba#)) (BlockIx blockIx) = do
    -- Byte offset (not Word64 offset) of the start of the block, i.e.
    -- blockIx * 64.
    let !(I# byteOff#) = fromIntegral blockIx `shiftL` 6

#ifdef NO_IGNORE_ASSERTS
    sz <- getSizeofMutableByteArray (MutableByteArray mba#)
    assert (let byteOff = I# byteOff# in byteOff >= 0 && byteOff < sz-63) $ pure ()
#endif

    -- For setting we do several writes to this cache line back to back, and
    -- afterwards typically do not need it again for a long time. As with
    -- prefetchIndex we want to disturb the caches as little as possible, so
    -- we use the "non-temporal" prefetchMutableByteArray0#: the 0 hints that
    -- the value will be used soon and then not again (soon), so the caches
    -- may evict it as soon as they like.
    ST (\s -> case prefetchMutableByteArray0# mba# byteOff# s of
                s' -> (# s', () #))

-- | Make an immutable copy of a mutable bit array.
--
-- The copy is allocated the same way as in 'new': as a pinned byte array
-- aligned to 64 bytes. Each 512-bit block of the result therefore still
-- occupies exactly one cache line. ('freezePrimArray' would allocate an
-- unpinned array with no 64-byte alignment guarantee, silently losing that
-- property for every lookup on the frozen array.)
freeze :: MBitArray s -> ST s BitArray
freeze (MBitArray (MutablePrimArray src#)) = do
    let src = MutableByteArray src#
    numBytes <- getSizeofMutableByteArray src
    dst <- newAlignedPinnedByteArray numBytes 64
    copyMutableByteArray dst 0 src 0 numBytes
    -- dst is fresh and not shared, so freezing it in place is safe.
    ByteArray dst# <- unsafeFreezeByteArray dst
    pure (BitArray (PrimArray dst#))

-- | View a mutable bit array as an immutable one, without copying.
--
-- The 'MBitArray' must not be modified afterwards. Because no new array is
-- allocated, the alignment of the underlying array is unchanged.
unsafeFreeze :: MBitArray s -> ST s BitArray
unsafeFreeze (MBitArray arr) = fmap BitArray (unsafeFreezePrimArray arr)

-- | Make a mutable copy of an immutable bit array.
--
-- As in 'new', the copy is a pinned byte array aligned to 64 bytes, so each
-- 512-bit block still occupies exactly one cache line. ('thawPrimArray'
-- would allocate an unpinned array with no 64-byte alignment guarantee.)
thaw :: BitArray -> ST s (MBitArray s)
thaw (BitArray (PrimArray src#)) = do
    let src      = ByteArray src#
        numBytes = sizeofByteArray src
    dst@(MutableByteArray dst#) <- newAlignedPinnedByteArray numBytes 64
    copyByteArray dst 0 src 0 numBytes
    pure (MBitArray (MutablePrimArray dst#))

{-# INLINE wordAndBitIndex #-}
-- | Given the index of a 512-bit block and the index of a bit within that
-- block, compute the index of the 'Word64' in the array that holds the bit,
-- and the index of the bit within that word.
--
-- The bit index is expected to be below 512 (checked only when assertions
-- are enabled). Out-of-range values are masked, so the word index always
-- stays within the given block.
wordAndBitIndex :: BlockIx -> BitIx -> (Int, Int)
wordAndBitIndex (BlockIx blockIx) (BitIx blockBitIx) =
    assert (blockBitIx < 512) (firstWordOfBlock + wordInBlock, bitInWord)
  where
    -- There are 8 Word64s in each 64-byte block, so the block's first word
    -- is at blockIx * 8.
    firstWordOfBlock :: Int
    firstWordOfBlock = fromIntegral blockIx `shiftL` 3

    -- Bits 6..8 of the bit index select one of the block's 8 words
    -- (`div` 64, then `mod` 8).
    wordInBlock :: Int
    wordInBlock = (blockBitIx `shiftR` 6) .&. 7

    -- Bits 0..5 select the bit within that word (`mod` 64).
    bitInWord :: Int
    bitInWord = blockBitIx .&. 63

{-# INLINE unsafeTestBit #-}
-- | Like 'testBit', but uses 'unsafeShiftL' instead of 'shiftL', so the bit
-- index must be in the range 0..63.
unsafeTestBit :: Word64 -> Int -> Bool
unsafeTestBit w k = not (w .&. mask == 0)
  where
    mask :: Word64
    mask = 1 `unsafeShiftL` k

{-# INLINE unsafeSetBit #-}
-- | Like 'setBit', but uses 'unsafeShiftL' instead of 'shiftL', so the bit
-- index must be in the range 0..63.
unsafeSetBit :: Word64 -> Int -> Word64
unsafeSetBit w k = mask .|. w
  where
    mask :: Word64
    mask = 1 `unsafeShiftL` k