{-# LANGUAGE CPP                    #-}
{-# LANGUAGE DefaultSignatures      #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MagicHash              #-}
{-# LANGUAGE PatternSynonyms        #-}
{-# LANGUAGE TypeFamilies           #-}

module Control.RefCount (
    -- * Using references
    Ref(DeRef)
  , releaseRef
  , withRef
  , dupRef
  , RefException (..)
    -- ** Weak references
  , WeakRef (..)
  , mkWeakRef
  , mkWeakRefFromRaw
  , deRefWeak
    -- * Implementing objects with finalisers
  , RefCounted (..)
  , newRef
    -- ** Low level reference counts
  , RefCounter (RefCounter)
  , newRefCounter
  , incrementRefCounter
  , decrementRefCounter
  , tryIncrementRefCounter

  -- * Test API
  , checkForgottenRefs
  , ignoreForgottenRefs
  , enableForgottenRefChecks
  , disableForgottenRefChecks
  ) where

import           Control.DeepSeq
import           Control.Exception (assert)
import           Control.Monad (void, when)
import           Control.Monad.Class.MonadThrow
import           Control.Monad.Primitive
import           Data.Primitive.PrimVar
import           GHC.Show (appPrec)
import           GHC.Stack (CallStack, prettyCallStack)

#ifdef NO_IGNORE_ASSERTS
import           Control.Concurrent (yield)
import           Data.IORef
import           GHC.Stack (HasCallStack, callStack)
import           System.IO.Unsafe (unsafeDupablePerformIO, unsafePerformIO)
import           System.Mem.Weak hiding (deRefWeak)
#if MIN_VERSION_base(4,20,0)
import           System.Mem (performBlockingMajorGC)
#else
import           System.Mem (performMajorGC)
#endif
#endif


-------------------------------------------------------------------------------
-- Low level RefCounter API
--

-- | A reference count paired with a finaliser action. When the count drops
-- to @0@, the finaliser is run.
data RefCounter m = RefCounter {
    countVar  :: !(PrimVar (PrimState m) Int)
    -- ^ The current count, only ever updated with atomic operations.
  , finaliser :: !(m ())
    -- ^ Action run when the count reaches zero.
  }

-- | Counters are opaque: every counter is shown as the fixed placeholder
-- @\<RefCounter\>@, regardless of precedence.
instance Show (RefCounter m) where
  showsPrec _ _ = showString "<RefCounter>"

-- | NOTE: forces only the counter variable and the finaliser action (to weak
-- head normal form); the count stored inside the variable is not evaluated.
instance NFData (RefCounter m) where
  rnf (RefCounter var fin) = rwhnf var `seq` rwhnf fin

{-# SPECIALISE newRefCounter :: IO () -> IO (RefCounter IO) #-}
-- | Create a reference counter whose count starts at @1@.
--
-- The supplied finaliser runs once the count reaches @0@, and it runs with
-- async exceptions masked.
--
newRefCounter :: PrimMonad m => m () -> m (RefCounter m)
newRefCounter fin = do
    var <- newPrimVar 1
    pure $! RefCounter { countVar = var, finaliser = fin }

{-# SPECIALISE incrementRefCounter :: RefCounter IO -> IO () #-}
-- | Add one to the reference count.
--
-- The caller must already know, from context, that the count is non-zero.
-- Usually this is because it holds a reference itself and is handing another
-- one out to other code. This precondition is only checked by an 'assert'.
incrementRefCounter :: PrimMonad m => RefCounter m -> m ()
incrementRefCounter RefCounter{countVar = var} = do
    before <- fetchAddInt var 1
    assert (before > 0) (pure ())

{-# SPECIALISE decrementRefCounter :: RefCounter IO -> IO () #-}
-- | Subtract one from the reference count. If that takes the count from @1@
-- to @0@, the finaliser is run.
--
-- The caller must already know, from context, that the count is non-zero.
-- Usually this is because it holds a reference, taken out by itself or
-- handed to it. The whole operation, including the finaliser, runs with
-- async exceptions masked.
decrementRefCounter :: (PrimMonad m, MonadMask m) => RefCounter m -> m ()
decrementRefCounter RefCounter{countVar = var, finaliser = fin} =
    --TODO: remove mask and require all uses to run with exceptions mask.
    mask_ $ do
      before <- fetchSubInt var 1
      assert (before > 0) (pure ())
      -- Only the caller that drops the last reference runs the finaliser.
      if before == 1
        then fin
        else pure ()

{-# SPECIALISE tryIncrementRefCounter :: RefCounter IO -> IO Bool #-}
-- | Try to turn a \"weak\" reference on something into a proper reference.
-- This is by analogy with @deRefWeak :: Weak v -> IO (Maybe v)@, but for
-- reference counts.
--
-- This amounts to trying to increase the reference count, but if it is already
-- zero then this will fail. Unlike with 'incrementRefCounter', where such
-- failure would be a programmer error, here it simply means that the thing the
-- reference count is tracking has been closed already.
--
-- The result is @True@ when a strong reference has been obtained and @False@
-- when upgrading fails.
--
tryIncrementRefCounter :: PrimMonad m => RefCounter m -> m Bool
tryIncrementRefCounter RefCounter{countVar} = do
    prevCount <- atomicReadInt countVar
    casLoop prevCount
  where
    -- A classic lock-free compare-and-swap loop:
    --
    -- * If the last observed count is zero (or below), the counted thing has
    --   already been finalised, so report failure.
    -- * Otherwise try to atomically replace the observed count with its
    --   successor. 'casInt' returns the value it found in the variable,
    --   whether or not the write happened.
    -- * If that value is the one we expected, no other thread interfered and
    --   the increment took effect. Otherwise go round again with the fresh
    --   value.
    casLoop prevCount
      | prevCount <= 0 = return False
      | otherwise      = do
          prevCount' <- casInt countVar prevCount (prevCount + 1)
          if prevCount' == prevCount
            then return True
            else casLoop prevCount'


-------------------------------------------------------------------------------
-- Ref API
--

-- | A counted reference to an object of type @obj@. References allow the
-- resources held by an object to be finalised promptly, as soon as the last
-- reference to it goes away.
--
-- Rules that callers must follow:
--
-- * Every 'Ref' is released with 'releaseRef' /exactly/ once.
-- * The underlying object is reached (temporarily) through 'withRef' or the
--   'DeRef' pattern.
-- * Once 'releaseRef' has been called on a 'Ref', neither 'withRef' nor the
--   'DeRef' pattern may be used on it any more.
-- * Once 'releaseRef' has been called, an object obtained earlier through
--   'DeRef' must /not/ be used either. Prefer 'withRef' where possible, and
--   take care when using 'DeRef'.
-- * 'dupRef' produces an additional, independent 'Ref' to the same object.
--   That new 'Ref' must itself be released with 'releaseRef'.
--
-- All of these operations are thread safe, but they are not async-exception
-- safe: the operations that allocate or deallocate ('newRef', 'dupRef' and
-- 'releaseRef') must be called with async exceptions masked.
--
-- When all of these rules are followed, the object's finaliser is guaranteed
-- to run exactly once, promptly, when the final reference is released.
--
-- In debug mode (when using CPP define @NO_IGNORE_ASSERTS@), adherence to
-- these rules is checked dynamically. These dynamic checks are not thread
-- safe, however, so not every violation is guaranteed to be detected.
--
#ifndef NO_IGNORE_ASSERTS
newtype Ref obj = Ref { refobj :: obj }
#else
data    Ref obj = Ref { refobj :: !obj, reftracker :: !RefTracker }
#endif

-- | Renders as a constructor application, e.g. @Ref 3@, parenthesised when
-- it appears as an argument. Only the referenced object is shown.
instance Show obj => Show (Ref obj) where
  showsPrec prec Ref{refobj = o} =
      showParen (prec > appPrec) (showString "Ref " . showsPrec (appPrec + 1) o)

-- | Deeply evaluates the referenced object (via its own 'NFData' instance);
-- nothing else in the 'Ref' is forced.
instance NFData obj => NFData (Ref obj) where
  rnf = rnf . refobj

-- | Objects that can be managed through 'Ref's, because each one carries a
-- 'RefCounter'.
--
-- For these objects, following the 'Ref' rules guarantees that the object's
-- finaliser is called exactly once.
--
class RefCounted m obj | obj -> m where
  -- | The reference counter that tracks this object. The functional
  -- dependency means the object type determines the monad @m@.
  getRefCounter :: obj -> RefCounter m

-- Constraint macro used in the signatures of the Ref operations below. In
-- debug builds it expands to 'HasCallStack', so that allocation, release and
-- use sites can be reported in 'RefException's. Otherwise it expands to the
-- empty constraint @()@, so no call stacks are threaded through.
#ifdef NO_IGNORE_ASSERTS
#define HasCallStackIfDebug HasCallStack
#else
#define HasCallStackIfDebug ()
#endif

{-# SPECIALISE
  newRef ::
       RefCounted IO obj
    => HasCallStackIfDebug
    => IO ()
    -> (RefCounter IO -> obj)
    -> IO (Ref obj)
  #-}
-- | Create an object together with its first reference.
--
-- A fresh 'RefCounter' (count @1@) is made for the given finaliser and passed
-- to the constructor function. The constructed object must return exactly
-- that counter from 'getRefCounter'. The finaliser runs, with async
-- exceptions masked, when the last reference to the object is released.
--
newRef ::
     (RefCounted m obj, PrimMonad m)
  => HasCallStackIfDebug
  => m ()
  -> (RefCounter m -> obj)
  -> m (Ref obj)
newRef fin mk = do
    counter <- newRefCounter fin
    let !o = mk counter
        -- The constructor function must store the very counter it was given.
        ownsCounter = countVar (getRefCounter o) == countVar counter
    assert ownsCounter (newRefWithTracker o)

-- | Give up a reference to an object that will no longer be used through
-- this reference. Releasing the last reference to an object runs its
-- finaliser. This must be called with async exceptions masked.
--
{-# SPECIALISE
  releaseRef ::
       RefCounted IO obj
    => HasCallStackIfDebug
    => Ref obj
    -> IO ()
  #-}
releaseRef ::
     (RefCounted m obj, PrimMonad m, MonadMask m)
  => HasCallStackIfDebug
  => Ref obj
  -> m ()
releaseRef r@Ref{refobj = o} = do
    assertNoDoubleRelease r
    assertNoForgottenRefs
    releaseRefTracker r
    decrementRefCounter (getRefCounter o)

{-# COMPLETE DeRef #-}
#if MIN_VERSION_GLASGOW_HASKELL(9,0,0,0)
{-# INLINE DeRef #-}
#endif
-- | Matches any 'Ref' and binds the object inside it. Do not hold on to the
-- object for long: it must not be used once 'releaseRef' has been called.
--
pattern DeRef :: HasCallStackIfDebug => obj -> Ref obj
#ifndef NO_IGNORE_ASSERTS
pattern DeRef x <- Ref x
#else
-- Matching goes through 'deRef', so the use-after-release check runs.
pattern DeRef x <- (deRef -> !x)

-- | Return the object in a 'Ref' after checking (with
-- 'assertNoUseAfterRelease') that the reference has not been released.
deRef :: HasCallStack => Ref obj -> obj
deRef r@Ref{refobj = o} =
          unsafeDupablePerformIO (assertNoUseAfterRelease r)
    `seq` o
#endif

{-# SPECIALISE
  withRef ::
       HasCallStackIfDebug
    => Ref obj
    -> (obj -> IO a)
    -> IO a
  #-}
{-# INLINE withRef #-}
-- | Run an action on the object in a 'Ref'. The object must not escape the
-- scope of that action. When a scoped \"with\" style is not an option, use
-- the 'DeRef' pattern instead.
--
withRef ::
     forall m obj a.
     (PrimMonad m, MonadThrow m)
  => HasCallStackIfDebug
  => Ref obj
  -> (obj -> m a)
  -> m a
withRef r@Ref{refobj = o} body = do
    assertNoUseAfterRelease r
    assertNoForgottenRefs
    body o
#ifndef NO_IGNORE_ASSERTS
  where
    -- Mentions the MonadThrow constraint, which the no-op assertions of a
    -- non-debug build do not need (presumably to keep it from being
    -- reported as redundant).
    _unused = throwIO @m @SomeException
#endif

{-# SPECIALISE
  dupRef ::
       RefCounted IO obj
    => HasCallStackIfDebug
    => Ref obj
    -> IO (Ref obj)
  #-}
-- | Make an additional, independent reference to the object behind an
-- existing one. The new reference must itself be released with 'releaseRef'.
--
dupRef ::
     forall m obj. (RefCounted m obj, PrimMonad m, MonadThrow m)
  => HasCallStackIfDebug
  => Ref obj
  -> m (Ref obj)
dupRef r@Ref{refobj = o} = do
    assertNoUseAfterRelease r
    assertNoForgottenRefs
    incrementRefCounter (getRefCounter o)
    newRefWithTracker o
#ifndef NO_IGNORE_ASSERTS
  where
    -- Mentions the MonadThrow constraint, which the no-op assertions of a
    -- non-debug build do not need (presumably to keep it from being
    -- reported as redundant).
    _unused = throwIO @m @SomeException
#endif

-- | A \"weak\" reference to an object. Holding one does not keep the object
-- alive. While normal references to the object still exist, 'deRefWeak' can
-- turn it back into a normal 'Ref'.
--
-- Weak references never need to be released.
--
newtype WeakRef a = WeakRef a
  deriving stock Show

-- | Derive a weak reference from a normal one. The reference count is not
-- changed, so the normal reference still has to be released as usual.
--
mkWeakRef :: Ref obj -> WeakRef obj
mkWeakRef = WeakRef . refobj

-- | Make a weak reference directly from an object, without going through a
-- 'Ref'. The reference count is not touched.
--
mkWeakRefFromRaw :: obj -> WeakRef obj
mkWeakRefFromRaw = WeakRef

{-# SPECIALISE
  deRefWeak ::
       RefCounted IO obj
    => HasCallStackIfDebug
    => WeakRef obj
    -> IO (Maybe (Ref obj))
  #-}
-- | Upgrade a weak reference. If the object's count is still positive, a
-- /new/ normal reference is returned, and it is subject to all the 'Ref'
-- rules, including eventually calling 'releaseRef'. If the object has
-- already been finalised, the result is 'Nothing'.
--
deRefWeak ::
     (RefCounted m obj, PrimMonad m)
  => HasCallStackIfDebug
  => WeakRef obj
  -> m (Maybe (Ref obj))
deRefWeak (WeakRef o) = do
    stillAlive <- tryIncrementRefCounter (getRefCounter o)
    if not stillAlive
      then pure Nothing
      else Just <$> newRefWithTracker o

-- Wrap an object in a fresh Ref, evaluating it strictly. In debug builds a
-- tracker is also allocated, recording the call stack of the allocation
-- site. Otherwise this is just the Ref constructor.
{-# INLINE newRefWithTracker #-}
#ifndef NO_IGNORE_ASSERTS
newRefWithTracker :: PrimMonad m => obj -> m (Ref obj)
newRefWithTracker o = pure $! Ref o
#else
newRefWithTracker :: (PrimMonad m, HasCallStack) => obj -> m (Ref obj)
newRefWithTracker o = do
    tracker <- newRefTracker callStack
    pure $! Ref o tracker
#endif

-- | Violations of the 'Ref' usage rules, raised by the debug-mode checks.
-- Each constructor carries the 'RefId' of the offending reference and the
-- call stacks of the relevant sites.
data RefException
  = -- | A reference was used after it had been released.
    RefUseAfterRelease
      RefId
      CallStack -- ^ Allocation site
      CallStack -- ^ Release site
      CallStack -- ^ Use site
  | -- | A reference was released a second time.
    RefDoubleRelease
      RefId
      CallStack -- ^ Allocation site
      CallStack -- ^ First release site
      CallStack -- ^ Second release site
  | -- | A reference was garbage collected without ever being released.
    RefNeverReleased
      RefId
      CallStack -- ^ Allocation site

-- | Identifier of a single reference, reported in 'RefException's. In debug
-- builds identifiers are drawn from a global counter, so a larger identifier
-- means a later allocation.
newtype RefId = RefId Int
  deriving stock (Eq, Ord, Show)

instance Show RefException where
  -- Shows the same text as 'displayException', because some tools (such as
  -- QuickCheck) still report exceptions using 'show'.
  showsPrec d = showParen (d > appPrec) . showString . displayException

-- | Multi-line descriptions: a headline naming the reference, followed by
-- the pretty-printed call stack of each relevant site.
instance Exception RefException where
  displayException (RefUseAfterRelease refid allocSite releaseSite useSite) =
      concat
        [ "Reference is used after release: ", show refid
        , "\nAllocation site: ",                prettyCallStack allocSite
        , "\nRelease site: ",                   prettyCallStack releaseSite
        , "\nUse site: ",                       prettyCallStack useSite
        ]
  displayException (RefDoubleRelease refid allocSite releaseSite1 releaseSite2) =
      concat
        [ "Reference is released twice: ", show refid
        , "\nAllocation site: ",            prettyCallStack allocSite
        , "\nFirst release site: ",         prettyCallStack releaseSite1
        , "\nSecond release site: ",        prettyCallStack releaseSite2
        ]
  displayException (RefNeverReleased refid allocSite) =
      concat
        [ "Reference is never released: ", show refid
        , "\nAllocation site: ",            prettyCallStack allocSite
        ]

#ifndef NO_IGNORE_ASSERTS

-- Non-debug builds do no reference tracking: each function below is a no-op
-- standing in for the tracking or check performed by the debug build.

{-# INLINE releaseRefTracker #-}
releaseRefTracker :: PrimMonad m => Ref a -> m ()
releaseRefTracker _ = pure ()

{-# INLINE assertNoForgottenRefs #-}
assertNoForgottenRefs :: PrimMonad m => m ()
assertNoForgottenRefs = pure ()

{-# INLINE assertNoUseAfterRelease #-}
assertNoUseAfterRelease :: PrimMonad m => Ref a -> m ()
assertNoUseAfterRelease _ = pure ()

{-# INLINE assertNoDoubleRelease #-}
assertNoDoubleRelease :: PrimMonad m => Ref a -> m ()
assertNoDoubleRelease _ = pure ()

#else

-- | Debug-mode bookkeeping attached to every 'Ref'. Its fields are:
--
-- * a unique 'RefId';
-- * a weak pointer to an /outer/ 'IORef' that holds an /inner/ 'IORef'. The
--   inner one stays @Nothing@ until the reference is explicitly released,
--   at which point it records the call stack of the release site;
-- * the outer 'IORef' itself, stored directly;
-- * the call stack of the allocation site.
--
-- The weak pointer's finaliser is given the inner 'IORef', so it can tell
-- whether the reference became garbage without ever being released.
--
-- Storing the outer 'IORef' directly keeps it, and hence the weak pointer
-- to it, from being collected until the 'RefTracker' itself (and thus the
-- owning 'Ref') is collected. The outer 'IORef' is never modified; an
-- 'IORef' is used so that the weak pointer is reliable.
--
-- The recorded release site is included in exceptions, to make debugging
-- easier.
data RefTracker =
    RefTracker
      !RefId
      !(Weak (IORef (IORef (Maybe CallStack))))
      !(IORef (IORef (Maybe CallStack))) -- ^ Release site
      !CallStack                          -- ^ Allocation site

-- | Process-wide source of fresh 'RefId's, starting at @0@.
--
-- This is a top-level mutable variable created with 'unsafePerformIO'. The
-- @NOINLINE@ pragma stops GHC from duplicating that call, so every use
-- shares the one 'PrimVar'.
{-# NOINLINE globalRefIdSupply #-}
globalRefIdSupply :: PrimVar RealWorld Int
globalRefIdSupply = unsafePerformIO (newPrimVar 0)

-- | A setting that can be switched off: either 'Enabled' with a (strict)
-- payload, or 'Disabled'.
data Enabled a
  = Enabled !a
  | Disabled

-- | Debug-mode record of a reference that was garbage collected without being
-- released. It is written by 'finaliserRefTracker', and polled and cleared by
-- 'assertNoForgottenRefs'. While 'Disabled', nothing is recorded. It starts
-- out 'Enabled' with nothing recorded.
--
-- This is a top-level mutable variable created with 'unsafePerformIO'. The
-- @NOINLINE@ pragma keeps it a single shared 'IORef'.
{-# NOINLINE globalForgottenRef #-}
globalForgottenRef :: IORef (Enabled (Maybe (RefId, CallStack)))
globalForgottenRef = unsafePerformIO (newIORef (Enabled Nothing))

-- | A variant of 'unsafeIOToPrim' that evaluates the result of the 'IO'
-- action (to weak head normal form) before returning it.
--
-- Without this strictness, some 'IO' side effects appear not to happen at
-- the right time, such as clearing @globalForgottenRef@ in
-- @assertNoForgottenRefs@.
unsafeIOToPrimStrict :: PrimMonad m => IO a -> m a
unsafeIOToPrimStrict io = unsafeIOToPrim io >>= (pure $!)

-- | Allocate the debug-mode tracker for a new reference. It gets a fresh
-- 'RefId' and a weak pointer whose finaliser ('finaliserRefTracker') reports
-- the reference if it is collected without having been released.
newRefTracker :: PrimMonad m => CallStack -> m RefTracker
newRefTracker allocSite = unsafeIOToPrimStrict $ do
    releaseVar <- newIORef Nothing
    holder     <- newIORef releaseVar
    rid        <- RefId <$> fetchAddInt globalRefIdSupply 1
    weakHolder <- mkWeakIORef holder (finaliserRefTracker releaseVar rid allocSite)
    pure (RefTracker rid weakHolder holder allocSite)

-- | Mark a reference as released by storing the call stack of the release
-- site in the tracker's inner 'IORef'.
releaseRefTracker :: (HasCallStack, PrimMonad m) => Ref a -> m ()
releaseRefTracker Ref{reftracker = RefTracker _ _ holder _} =
  unsafeIOToPrimStrict $
    readIORef holder >>= \releaseVar -> writeIORef releaseVar (Just callStack)

-- | Finaliser for a tracker's weak pointer. If the reference was never
-- explicitly released, it is recorded in @globalForgottenRef@ (unless
-- tracking is 'Disabled'), so that a later 'assertNoForgottenRefs' reports it.
finaliserRefTracker :: IORef (Maybe CallStack) -> RefId -> CallStack -> IO ()
finaliserRefTracker releaseVar rid allocSite =
    readIORef releaseVar >>= \mreleaseSite -> case mreleaseSite of
      Just _  -> pure ()
      Nothing -> recordForgotten
  where
    -- Uh oh! The reference was forgotten without being released. Record it
    -- in a global variable that is polled elsewhere.
    recordForgotten = do
      tracked <- readIORef globalForgottenRef
      case tracked of
        Disabled -> pure ()
        -- Keep only one, namely the most recently allocated (largest id).
        -- With nested structures holding refs, the last allocated one is
        -- likely to be the outermost, which is the best place to start
        -- hunting for a leak. Otherwise one can end up chasing inner refs
        -- that were only forgotten because an outer ref was forgotten.
        Enabled (Just (other, _)) | rid < other -> pure ()
        Enabled _ -> writeIORef globalForgottenRef (Enabled (Just (rid, allocSite)))

-- | Throw 'RefNeverReleased' if a forgotten reference has been recorded in
-- @globalForgottenRef@. The record is cleared before throwing, so each
-- recorded reference is reported only once.
assertNoForgottenRefs :: (PrimMonad m, MonadThrow m) => m ()
assertNoForgottenRefs = do
    recorded <- unsafeIOToPrimStrict (readIORef globalForgottenRef)
    case recorded of
      Enabled (Just (rid, allocSite)) -> do
        -- Clear the var so we don't assert again.
        --
        -- The strict conversion matters here: with @m ~ IOSim s@, the
        -- non-strict version leads to spurious @RefNeverReleased@
        -- exceptions.
        unsafeIOToPrimStrict (writeIORef globalForgottenRef (Enabled Nothing))
        throwIO (RefNeverReleased rid allocSite)
      _ -> pure ()


-- | Throw 'RefUseAfterRelease' if the reference has already been released.
-- The exception carries the allocation site, the recorded release site, and
-- the call stack of the current (offending) use.
assertNoUseAfterRelease :: (PrimMonad m, MonadThrow m, HasCallStack) => Ref a -> m ()
assertNoUseAfterRelease Ref{reftracker = RefTracker rid _ holder allocSite} = do
    mreleaseSite <- unsafeIOToPrimStrict (readIORef holder >>= readIORef)
    case mreleaseSite of
      Nothing          -> pure ()
      -- Here callStack is the site of the use after release.
      Just releaseSite -> throwIO (RefUseAfterRelease rid allocSite releaseSite callStack)
#if !(MIN_VERSION_base(4,20,0))
  where
    _unused = callStack
#endif

-- | Throw 'RefDoubleRelease' if the reference has already been released.
-- The exception carries the allocation site, the first (recorded) release
-- site, and the call stack of the current, second release.
assertNoDoubleRelease :: (PrimMonad m, MonadThrow m, HasCallStack) => Ref a -> m ()
assertNoDoubleRelease Ref{reftracker = RefTracker rid _ holder allocSite} = do
    mfirstRelease <- unsafeIOToPrimStrict (readIORef holder >>= readIORef)
    case mfirstRelease of
      Nothing           -> pure ()
      -- Here callStack is the site of the second release.
      Just releaseSite1 -> throwIO (RefDoubleRelease rid allocSite releaseSite1 callStack)
#if !(MIN_VERSION_base(4,20,0))
  where
    _unused = callStack
#endif

#endif

-- | Run a GC to try to detect refs that were forgotten without being
-- released. If one is found, a synchronous exception is thrown.
--
-- Note that this is not the only place where 'RefNeverReleased' exceptions
-- can be thrown: all Ref operations poll for forgotten refs.
--
checkForgottenRefs :: forall m. (PrimMonad m, MonadThrow m) => m ()
checkForgottenRefs = do
#ifndef NO_IGNORE_ASSERTS
    pure ()
#else
    -- A major GC followed by 'yield' is meant to (1) start the finaliser
    -- threads for all dropped weak references, and (2) move the current
    -- thread to the back of the run queue, so that those finalisers have
    -- finished by the time it resumes. This relies on how the GHC scheduler
    -- happens to behave, not on any Haskell specification, so it is
    -- non-portable and presumably brittle. For good measure it is done twice.
    let collectThenYield = performMajorGCWithBlockingIfAvailable >> yield
    unsafeIOToPrimStrict (collectThenYield >> collectThenYield)
    assertNoForgottenRefs
#endif
  where
    _unused = throwIO @m @SomeException

-- | Ignore and reset the state of forgotten reference tracking. This ensures
-- that any stale forgotten references are not reported later.
--
-- This is especially important in QC tests with shrinking, which otherwise
-- leads to confusion.
--
-- Only the 'RefException' raised by the check is discarded. Any other
-- exception, in particular an asynchronous one (such as the thread being
-- killed while the GC runs), still propagates instead of being swallowed.
ignoreForgottenRefs :: (PrimMonad m, MonadCatch m) => m ()
ignoreForgottenRefs = void $ try @_ @RefException checkForgottenRefs

#ifdef NO_IGNORE_ASSERTS
-- | Run a major garbage collection. With @base >= 4.20@ the blocking variant
-- 'performBlockingMajorGC' is used; older versions fall back to
-- 'performMajorGC'.
performMajorGCWithBlockingIfAvailable :: IO ()
#if MIN_VERSION_base(4,20,0)
performMajorGCWithBlockingIfAvailable = performBlockingMajorGC
#else
performMajorGCWithBlockingIfAvailable = performMajorGC
#endif
#endif

-- | Enable forgotten reference checks. In debug builds this errors if the
-- checks are already enabled; otherwise it does nothing.
enableForgottenRefChecks :: IO ()

-- | Disable forgotten reference checks. This will error if the checks are
-- already disabled, or if there are already forgotten references while we
-- are trying to disable the checks. In non-debug builds it does nothing.
disableForgottenRefChecks :: IO ()

#ifdef NO_IGNORE_ASSERTS
-- Both use the strict variant of modifyIORef, so that the new state is
-- evaluated before it is written. A misuse then raises its error right here,
-- at the call site, and leaves the tracking state untouched. With the lazy
-- variant, the error thunk would be stored in the global IORef and would only
-- be raised later, by whichever Ref operation or finaliser next inspects it.
enableForgottenRefChecks =
    modifyIORef' globalForgottenRef $ \case
      Disabled  -> Enabled Nothing
      Enabled _ -> error "enableForgottenRefChecks: already enabled"

disableForgottenRefChecks =
    modifyIORef' globalForgottenRef $ \case
      Disabled        -> error "disableForgottenRefChecks: already disabled"
      Enabled Nothing -> Disabled
      Enabled _       -> error "disableForgottenRefChecks: can not disable when there are forgotten references"
#else
enableForgottenRefChecks = pure ()
disableForgottenRefChecks = pure ()
#endif