{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fexpose-all-unfoldings #-}
module KMerge.Heap (
MutableHeap (..),
newMutableHeap,
replaceRoot,
extract,
) where
import Control.Monad (when)
import Control.Monad.Primitive (PrimMonad (PrimState), RealWorld)
import qualified Control.Monad.ST as Lazy
import qualified Control.Monad.ST as Strict
import Data.Bits (unsafeShiftL, unsafeShiftR)
import Data.Foldable.WithIndex (ifor_)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Primitive (SmallMutableArray, newSmallArray,
readSmallArray, writeSmallArray)
import Data.Primitive.PrimVar (PrimVar, newPrimVar, readPrimVar,
writePrimVar)
import Unsafe.Coerce (unsafeCoerce)
-- | A binary min-heap kept in a fixed-size mutable array.
--
-- Slot 0 is the root. It normally holds the element most recently returned
-- to the caller, or 'placeholder' right after 'newMutableHeap'. The children
-- of slot @i@ are at @2i+1@ and @2i+2@.
data MutableHeap s a = MH
    -- Number of occupied slots, counting slot 0.
    !(PrimVar s Int)
    -- Backing array; its length is set once, when the heap is built.
    !(SmallMutableArray s a)
-- | Filler stored in array slots that do not hold a live heap element: the
-- whole array before it is populated, slot 0 right after 'newMutableHeap',
-- and the slot vacated by 'extract'.
--
-- This is a lazy 'error' thunk, not a value coerced to the wrong type.
-- The 'MutableHeap' constructor is exported, so a caller may read such a
-- slot. Forcing it then raises a descriptive exception instead of
-- reinterpreting @()@ as an arbitrary type. Storing the thunk never forces
-- it, so the heap's own operations are unaffected.
placeholder :: a
placeholder = error "KMerge.Heap.placeholder: forced an empty heap slot"
-- | Build a min-heap from a non-empty collection and pop its smallest element.
--
-- Each element is written at its position in a fresh array of length
-- @length xs@ and sifted up. The element that ends up in slot 0 (the
-- minimum) is read out and returned next to the heap. 'placeholder' is
-- written into slot 0 in its place, while the stored size still counts
-- that slot.
newMutableHeap :: forall a m. (PrimMonad m, Ord a) => NonEmpty a -> m (MutableHeap (PrimState m) a, a)
newMutableHeap xs = do
    let !n = length xs
    sizeVar <- newPrimVar n
    heapArr <- newSmallArray n placeholder
    ifor_ xs $ \i el -> do
        writeSmallArray heapArr i el
        siftUp heapArr el i
    smallest <- readSmallArray heapArr 0
    writeSmallArray heapArr 0 placeholder
    pure $! (MH sizeVar heapArr, smallest)
-- | Replace the root with @val@ and return the new root.
--
-- If more than one slot is occupied, @val@ overwrites slot 0 and is sifted
-- down across all occupied slots. Whatever then sits in slot 0 (the smaller
-- of @val@ and the remaining elements) is returned and stays there as the
-- new root. With at most one occupied slot, the array is not touched and
-- @val@ is returned as-is.
replaceRoot :: forall a m. (PrimMonad m, Ord a) => MutableHeap (PrimState m) a -> a -> m a
replaceRoot (MH sizeRef arr) val = do
    n <- readPrimVar sizeRef
    if n > 1
        then do
            writeSmallArray arr 0 val
            siftDown arr n val 0
            readSmallArray arr 0
        else pure val
-- NOTE(review): this module imports Control.Monad.ST as both Lazy and Strict,
-- so the Lazy.ST pragma repeats the Strict.ST one; confirm whether
-- Control.Monad.ST.Lazy was intended.
{-# SPECIALIZE replaceRoot :: forall a. Ord a => MutableHeap RealWorld a -> a -> IO a #-}
{-# SPECIALIZE replaceRoot :: forall a s. Ord a => MutableHeap s a -> a -> Strict.ST s a #-}
{-# SPECIALIZE replaceRoot :: forall a s. Ord a => MutableHeap s a -> a -> Lazy.ST s a #-}
-- | Drop the root and return the next smallest element, if there is one.
--
-- Slot 0 holds the element most recently returned to the caller, so this
-- discards that element. The last occupied element is moved into the root,
-- the heap shrinks by one, and the new root is returned. Returns 'Nothing'
-- when only the root slot is occupied; the stored size is then left at 1.
extract :: forall a m. (PrimMonad m, Ord a) => MutableHeap (PrimState m) a -> m (Maybe a)
extract (MH sizeRef arr) = do
    size <- readPrimVar sizeRef
    if size <= 1
        then return Nothing
        else do
            -- After shrinking, newSize is both the new element count and the
            -- index of the slot that just fell out of the heap.
            let !newSize = size - 1
            writePrimVar sizeRef newSize
            val <- readSmallArray arr newSize
            -- Clear the vacated slot so the array does not keep a second
            -- reference to the moved element.
            writeSmallArray arr newSize placeholder
            writeSmallArray arr 0 val
            -- Sift over the shrunken heap only. The vacated slot is no longer
            -- part of it and must not take part in comparisons.
            siftDown arr newSize val 0
            x <- readSmallArray arr 0
            return $! Just x
-- NOTE(review): this module imports Control.Monad.ST as both Lazy and Strict,
-- so the Lazy.ST pragma repeats the Strict.ST one; confirm whether
-- Control.Monad.ST.Lazy was intended.
{-# SPECIALIZE extract :: forall a. Ord a => MutableHeap RealWorld a -> IO (Maybe a) #-}
{-# SPECIALIZE extract :: forall a s. Ord a => MutableHeap s a -> Strict.ST s (Maybe a) #-}
{-# SPECIALIZE extract :: forall a s. Ord a => MutableHeap s a -> Lazy.ST s (Maybe a) #-}
-- | Move @x@ towards the root, starting from index @idx@.
--
-- Expects @x@ to already be stored at @idx@ ('newMutableHeap' writes it just
-- before calling). While @x@ is strictly less than its parent, the two are
-- swapped. Equal elements stay where they are.
siftUp :: forall a m. (PrimMonad m, Ord a) => SmallMutableArray (PrimState m) a -> a -> Int -> m ()
siftUp !arr !x = ascend
  where
    ascend :: Int -> m ()
    ascend !i
        | i > 0 = do
            let !up = halfOf (i - 1)
            pv <- readSmallArray arr up
            if x < pv
                then do
                    writeSmallArray arr up x
                    writeSmallArray arr i pv
                    ascend up
                else pure ()
        | otherwise = pure ()
-- NOTE(review): this module imports Control.Monad.ST as both Lazy and Strict,
-- so the Lazy.ST pragma repeats the Strict.ST one; confirm whether
-- Control.Monad.ST.Lazy was intended.
{-# SPECIALIZE siftUp :: forall a. Ord a => SmallMutableArray RealWorld a -> a -> Int -> IO () #-}
{-# SPECIALIZE siftUp :: forall a s. Ord a => SmallMutableArray s a -> a -> Int -> Strict.ST s () #-}
{-# SPECIALIZE siftUp :: forall a s. Ord a => SmallMutableArray s a -> a -> Int -> Lazy.ST s () #-}
-- | Move @x@ away from the root, starting from index @idx@.
--
-- Expects @x@ to already be stored at @idx@ (both callers write it to slot 0
-- first). Only indices below @size@ are inspected. At each step @x@ trades
-- places with its smaller child when that child is strictly less than @x@.
-- When the two children compare equal, the left one is chosen.
siftDown :: forall a m. (PrimMonad m, Ord a) => SmallMutableArray (PrimState m) a -> Int -> a -> Int -> m ()
siftDown !arr !size !x = descend
  where
    descend :: Int -> m ()
    descend !i
        | r < size = do
            lv <- readSmallArray arr l
            rv <- readSmallArray arr r
            if x <= lv
                then if x <= rv then pure () else swapWith r rv
                else if lv <= rv then swapWith l lv else swapWith r rv
        | l < size = do
            lv <- readSmallArray arr l
            if x <= lv
                then pure ()
                else do
                    -- A lone left child is the last occupied slot, so there
                    -- is nothing below it to descend into.
                    writeSmallArray arr i lv
                    writeSmallArray arr l x
        | otherwise = pure ()
      where
        !l = doubleOf i + 1
        !r = doubleOf i + 2

        -- Lift the child value cv into slot i, drop x into the child slot c,
        -- and continue sifting from c.
        swapWith c cv = do
            writeSmallArray arr i cv
            writeSmallArray arr c x
            descend c
-- NOTE(review): this module imports Control.Monad.ST as both Lazy and Strict,
-- so the Lazy.ST pragma repeats the Strict.ST one; confirm whether
-- Control.Monad.ST.Lazy was intended.
{-# SPECIALIZE siftDown :: forall a. Ord a => SmallMutableArray RealWorld a -> Int -> a -> Int -> IO () #-}
{-# SPECIALIZE siftDown :: forall a s. Ord a => SmallMutableArray s a -> Int -> a -> Int -> Strict.ST s () #-}
{-# SPECIALIZE siftDown :: forall a s. Ord a => SmallMutableArray s a -> Int -> a -> Int -> Lazy.ST s () #-}
-- | Halve an index with a one-bit right shift; 'siftUp' uses it to find a
-- parent index.
halfOf :: Int -> Int
halfOf n = n `unsafeShiftR` 1
{-# INLINE halfOf #-}
-- | Double an index with a one-bit left shift; 'siftDown' uses it to locate
-- child indices.
doubleOf :: Int -> Int
doubleOf n = n `unsafeShiftL` 1
{-# INLINE doubleOf #-}