{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fexpose-all-unfoldings #-}

-- | Mutable heap for k-merge algorithm.
--
-- This data structure represents a min-heap with its root node *removed*
-- (internally, refilling the root slot and sifting down are delayed until the next operation).
--
-- Also, there is no *insert* operation, i.e. the heap can only shrink.
-- The other usual heap operations, *create-heap*, *extract-min* and *replace*, are provided.
-- However, as a 'MutableHeap' always represents a heap with its root (minimum value)
-- already extracted, *extract-min* is "fused" into the other operations.
module KMerge.Heap (
    MutableHeap (..),
    newMutableHeap,
    replaceRoot,
    extract,
) where

import           Control.Monad (when)
import           Control.Monad.Primitive (PrimMonad (PrimState), RealWorld)
import qualified Control.Monad.ST as Lazy
import qualified Control.Monad.ST as Strict
import           Data.Bits (unsafeShiftL, unsafeShiftR)
import           Data.Foldable.WithIndex (ifor_)
import           Data.List.NonEmpty (NonEmpty (..))
import           Data.Primitive (SmallMutableArray, newSmallArray,
                     readSmallArray, writeSmallArray)
import           Data.Primitive.PrimVar (PrimVar, newPrimVar, readPrimVar,
                     writePrimVar)
import           Unsafe.Coerce (unsafeCoerce)

-- | Mutable heap for the k-merge algorithm.
--
-- Elements are stored in a 'SmallMutableArray' laid out as an implicit
-- binary tree: the children of slot @i@ are slots @2i+1@ and @2i+2@.
-- Slot 0 is the root, which is always treated as already extracted.
data MutableHeap s a = MH
    !(PrimVar s Int)          -- ^ number of occupied slots, counting the (extracted) root slot
    !(SmallMutableArray s a)  -- ^ backing storage for the implicit binary tree

-- | Filler stored in array slots that do not hold a live element
-- (unused capacity, the vacated root slot, and slots dropped by 'extract').
--
-- Safety: the heap operations never inspect a slot holding this value as an
-- @a@. Such slots are either outside the live range or overwritten before
-- they are next read. Coercing @()@ is therefore sufficient. Evaluating or
-- pattern-matching it at any other type would be undefined behaviour.
placeholder :: forall a. a
placeholder = unsafeCoerce ()

-- | Create a new heap from the given elements, and immediately extract its
-- minimum value.
--
-- Returns the heap, with its root slot vacated, together with that minimum.
newMutableHeap :: forall a m. (PrimMonad m, Ord a) => NonEmpty a -> m (MutableHeap (PrimState m) a, a)
newMutableHeap xs = do
    let !n = length xs
    slots <- newSmallArray n placeholder

    -- Build the heap by successive insertion: store each element just past
    -- the already-heapified prefix, then sift it up towards the root.
    ifor_ xs $ \i y ->
        writeSmallArray slots i y >> siftUp slots y i

    countRef <- newPrimVar n

    -- Slot 0 is always populated: a NonEmpty has at least one element.
    smallest <- readSmallArray slots 0
    -- Vacate the root slot: a MutableHeap always has its root extracted.
    writeSmallArray slots 0 placeholder
    pure (MH countRef slots, smallest)

-- | Replace the minimum value, and immediately extract the new minimum value.
--
-- Conceptually this pushes @val@ into the vacant root slot and pops the
-- minimum. When the heap has no elements besides the vacant root, @val@
-- itself is the minimum and is returned without touching the array.
replaceRoot :: forall a m. (PrimMonad m, Ord a) => MutableHeap (PrimState m) a -> a -> m a
replaceRoot (MH sizeRef arr) val = readPrimVar sizeRef >>= go
  where
    go :: Int -> m a
    go size
        | size > 1 = do
            writeSmallArray arr 0 val
            siftDown arr size val 0
            -- After sifting, slot 0 holds the minimum; it stays there as the
            -- (now extracted) root.
            readSmallArray arr 0
        | otherwise = pure val

{-# SPECIALIZE replaceRoot :: forall a.   Ord a => MutableHeap RealWorld a -> a -> IO          a #-}
{-# SPECIALIZE replaceRoot :: forall a s. Ord a => MutableHeap s         a -> a -> Strict.ST s a #-}
{-# SPECIALIZE replaceRoot :: forall a s. Ord a => MutableHeap s         a -> a -> Lazy.ST   s a #-}

-- | Extract the next minimum value.
--
-- Returns 'Nothing' once only the vacant root slot is left. Otherwise the
-- last live element is moved into the root slot, the live range shrinks by
-- one, and that element is sifted down. The new minimum ends up in slot 0,
-- which is then treated as extracted.
extract :: forall a m. (PrimMonad m, Ord a) => MutableHeap (PrimState m) a -> m (Maybe a)
extract (MH sizeRef arr) = do
    size <- readPrimVar sizeRef
    if size <= 1
    then return Nothing
    else do
        let !newSize = size - 1
        writePrimVar sizeRef newSize
        -- Take the last live element out of its slot and clear that slot,
        -- since it is no longer part of the heap.
        val <- readSmallArray arr newSize
        writeSmallArray arr newSize placeholder
        writeSmallArray arr 0 val
        -- Sift within the *shrunken* range [0, newSize). Sifting over the old
        -- size would needlessly visit the vacated slot.
        siftDown arr newSize val 0
        x <- readSmallArray arr 0
        return $! Just x

-- NOTE(review): 'Lazy' and 'Strict' are both imported from "Control.Monad.ST",
-- so the @Lazy.ST@ specialisation below duplicates the strict one. Confirm
-- whether "Control.Monad.ST.Lazy" was intended for the 'Lazy' qualifier.
{-# SPECIALIZE extract :: forall a.   Ord a => MutableHeap RealWorld a -> IO          (Maybe a) #-}
{-# SPECIALIZE extract :: forall a s. Ord a => MutableHeap s         a -> Strict.ST s (Maybe a) #-}
{-# SPECIALIZE extract :: forall a s. Ord a => MutableHeap s         a -> Lazy.ST   s (Maybe a) #-}

{-------------------------------------------------------------------------------
  Internal operations
-------------------------------------------------------------------------------}

-- | Move @x@ from the given index towards the root while it is strictly
-- smaller than its parent.
--
-- The caller is expected to have already written @x@ at the starting index
-- (as 'newMutableHeap' does). Only the parent/child pairs on the path to the
-- root are rewritten.
siftUp :: forall a m. (PrimMonad m, Ord a) => SmallMutableArray (PrimState m) a -> a -> Int -> m ()
siftUp !arr !x = climb
  where
    climb :: Int -> m ()
    climb !child
        | child > 0 = do
            let !parent = halfOf (child - 1)
            p <- readSmallArray arr parent
            -- Stop as soon as the parent is not greater than x.
            when (x < p) $ do
                writeSmallArray arr parent x
                writeSmallArray arr child  p
                climb parent
        | otherwise = pure ()

{-# SPECIALIZE siftUp :: forall a.   Ord a => SmallMutableArray RealWorld a -> a -> Int -> IO          () #-}
{-# SPECIALIZE siftUp :: forall a s. Ord a => SmallMutableArray s         a -> a -> Int -> Strict.ST s () #-}
{-# SPECIALIZE siftUp :: forall a s. Ord a => SmallMutableArray s         a -> a -> Int -> Lazy.ST   s () #-}

-- | Sift @x@, which sits at the given index, down through the first @size@
-- slots until neither live child is smaller than it.
--
-- If @x@ equals a child, it stays put. If both children are smaller than
-- @x@ and equal to each other, the left child is chosen.
siftDown :: forall a m. (PrimMonad m, Ord a) => SmallMutableArray (PrimState m) a -> Int -> a -> Int -> m ()
siftDown !arr !size !x = descend
  where
    descend :: Int -> m ()
    descend !hole
        -- both children are live
        | ri < size = do
            lv <- readSmallArray arr li
            rv <- readSmallArray arr ri
            if x <= lv
                then if x <= rv
                    then pure ()          -- x <= lv, x <= rv: heap property holds
                    else sinkTo ri rv     -- rv < x <= lv
                else if lv <= rv
                    then sinkTo li lv     -- lv < x, lv <= rv
                    else sinkTo ri rv     -- rv < lv < x

        -- only the left child is live, and it is the last live slot
        | li < size = do
            lv <- readSmallArray arr li
            if x <= lv
                then pure ()
                else do
                    writeSmallArray arr hole lv
                    writeSmallArray arr li x
                    -- li has no live children, so there is nothing left to do.

        | otherwise = pure ()
      where
        !li = doubleOf hole + 1
        !ri = doubleOf hole + 2
        -- Lift the child's value into the hole, put x in the child's slot,
        -- and continue sifting from there.
        sinkTo child cv = do
            writeSmallArray arr hole cv
            writeSmallArray arr child x
            descend child

{-# SPECIALIZE siftDown :: forall a.   Ord a => SmallMutableArray RealWorld a -> Int -> a -> Int -> IO          () #-}
{-# SPECIALIZE siftDown :: forall a s. Ord a => SmallMutableArray s         a -> Int -> a -> Int -> Strict.ST s () #-}
{-# SPECIALIZE siftDown :: forall a s. Ord a => SmallMutableArray s         a -> Int -> a -> Int -> Lazy.ST   s () #-}

{-------------------------------------------------------------------------------
  Helpers
-------------------------------------------------------------------------------}

-- | Halve an index using an arithmetic right shift (floor division by 2).
-- A shift amount of 1 is always in range, so the unchecked shift is safe.
halfOf :: Int -> Int
halfOf i = i `unsafeShiftR` 1
{-# INLINE halfOf #-}

-- | Double an index using a left shift.
-- A shift amount of 1 is always in range, so the unchecked shift is safe.
doubleOf :: Int -> Int
doubleOf i = i `unsafeShiftL` 1
{-# INLINE doubleOf #-}