-- | A read-write-locked mutable variable with a bias towards write locks.
--
-- This module is intended to be imported qualified:
--
-- @
--   import           Control.Concurrent.Class.MonadSTM.RWVar (RWVar)
--   import qualified Control.Concurrent.Class.MonadSTM.RWVar as RW
-- @
module Control.Concurrent.Class.MonadSTM.RWVar (
    RWVar (..)
  , RWState (..)
  , new
  , unsafeAcquireReadAccess
  , unsafeReleaseReadAccess
  , withReadAccess
  , unsafeAcquireWriteAccess
  , unsafeReleaseWriteAccess
  , withWriteAccess
  , withWriteAccess_
  ) where

import           Control.Concurrent.Class.MonadSTM.Strict
import           Control.DeepSeq
import           Control.Monad.Class.MonadThrow
import           Data.Word

-- | A read-write-locked mutable variable with a bias towards write locks.
--
-- The lock state and the protected value are kept together in a single
-- 'StrictTVar', so every lock transition is a single STM transaction over
-- that one variable.
newtype RWVar m a =
    RWVar (StrictTVar m (RWState a))

-- | __NOTE__: Only strict in the reference and not the referenced value.
--
-- 'rnf' evaluates the 'RWVar' to weak head normal form only (via 'rwhnf');
-- the value stored inside the variable is not forced.
instance NFData (RWVar m a) where
  rnf = rwhnf

-- | The lock state of an 'RWVar'. While no writer holds the lock, the
-- protected value is stored alongside the reader count; while a writer holds
-- it, the value is with the writer.
data RWState a
  = Reading !Word64 !a
    -- ^ @n@ concurrent readers and no writer.
  | WaitingToWrite !Word64 !a
    -- ^ @n@ concurrent readers and no writer, but no new readers can get
    -- access.
  | Writing
    -- ^ A single writer and no concurrent readers.

{-# SPECIALISE new :: a -> IO (RWVar IO a) #-}
-- | Create a new 'RWVar' holding the given value. The variable starts in the
-- @Reading 0@ state: no readers and no writer.
new :: MonadSTM m => a -> m (RWVar m a)
new !x = RWVar <$> newTVarIO (Reading 0 x)

{-# SPECIALISE unsafeAcquireReadAccess :: RWVar IO a -> STM IO a #-}
-- | Acquire read access and return the current value, incrementing the
-- reader count.
--
-- The transaction 'retry's while a writer holds the lock ('Writing') or is
-- waiting for it ('WaitingToWrite'). This is what gives the lock its bias
-- towards writers.
--
-- Unsafe: every successful acquisition must be paired with exactly one
-- 'unsafeReleaseReadAccess'. Otherwise the reader count never returns to 0
-- and writers can block indefinitely. Prefer 'withReadAccess'.
unsafeAcquireReadAccess :: MonadSTM m => RWVar m a -> STM m a
unsafeAcquireReadAccess (RWVar !var) =
    readTVar var >>= \case
      Reading n x -> do
        writeTVar var (Reading (n + 1) x)
        pure x
      -- A writer is waiting or active: new readers must wait.
      WaitingToWrite{} -> retry
      Writing          -> retry

{-# SPECIALISE unsafeReleaseReadAccess :: RWVar IO a -> STM IO () #-}
-- | Release read access by decrementing the reader count. A pending
-- 'WaitingToWrite' state is preserved.
--
-- Calls 'error' if there is no reader to release: the reader count is
-- already 0, or the state is 'Writing'. Must be paired with a preceding
-- 'unsafeAcquireReadAccess'.
unsafeReleaseReadAccess :: MonadSTM m => RWVar m a -> STM m ()
unsafeReleaseReadAccess (RWVar !var) =
    readTVar var >>= \case
      Reading n x
        | n == 0    -> error "releasing a reader without read access (Reading)"
        | otherwise -> writeTVar var (Reading (n - 1) x)
      WaitingToWrite n x
        | n == 0    -> error "releasing a reader without read access (WaitingToWrite)"
        | otherwise -> writeTVar var (WaitingToWrite (n - 1) x)
      Writing -> error "releasing a reader without read access (Writing)"

{-# SPECIALISE withReadAccess :: RWVar IO a -> (a -> IO b) -> IO b #-}
-- | Run an action with read access to the current value.
--
-- Acquisition and release each run in their own 'atomically' transaction.
-- Because the pair is wrapped in 'bracket', read access is released even if
-- the action throws.
withReadAccess :: (MonadSTM m, MonadThrow m) => RWVar m a -> (a -> m b) -> m b
withReadAccess rwvar k =
    bracket
      (atomically $ unsafeAcquireReadAccess rwvar)
      (\_ -> atomically $ unsafeReleaseReadAccess rwvar)
      k

{-# SPECIALISE unsafeAcquireWriteAccess :: RWVar IO a -> IO a #-}
-- | Acquire write access. This function assumes that it runs in a masked
-- context, and that it is properly paired with an 'unsafeReleaseWriteAccess'!
--
-- If multiple threads try to acquire write access concurrently, then they will
-- race for access. However, if a thread has set the 'RWState' to
-- 'WaitingToWrite', then it is guaranteed that the same thread will acquire
-- write access when all readers have finished. That is, other writers cannot
-- "jump the queue". When the writer finishes, all other waiting threads will
-- race for write access again.
--
-- TODO: 'unsafeReleaseWriteAccess' will set the 'RWState' to @Reading 0@. In
-- case we have readers /and/ writers waiting for a writer to finish, once the
-- writer is finished there will be a race. In this race, readers and writers
-- are just as likely to acquire access first. However, if we wanted to make
-- 'RWVar' even more biased towards writers, then we could ensure that all
-- waiting writers get access before the readers get a chance. This would
-- probably require us to change 'RWState' to represent the case where writers
-- are waiting for a writer to finish.
unsafeAcquireWriteAccess :: (MonadSTM m, MonadCatch m) => RWVar m a -> m a
unsafeAcquireWriteAccess (RWVar !var) =
    -- trySetWriting is interruptible, but it is fine if it is interrupted
    -- because the RWState cannot be changed before the interruption.
    --
    -- trySetWriting might update the RWState. There are interruptible
    -- operations in the body of the bracketOnError (in waitToWrite), so async
    -- exceptions can be delivered there. If an async exception happens because
    -- of an interrupt, we undo the RWState change using undoWaitingToWrite.
    --
    -- Note that if waitToWrite is interrupted, then it is impossible for the
    -- RWState to have changed from WaitingToWrite to either Reading or Writing.
    -- Therefore, undoWaitingToWrite can assume that it will find WaitingToWrite
    -- in the lock.
    bracketOnError trySetWriting undoWaitingToWrite $
      -- When Nothing is returned, it means that we set the RWState to
      -- WaitingToWrite, and so we wait to acquire the final write access.
      --
      -- When Just is returned, we already have write access.
      maybe waitToWrite pure
  where
    -- Try to acquire a write lock immediately, or otherwise set the internal
    -- state to WaitingToWrite as soon as possible.
    --
    -- Note: this is interruptible
    trySetWriting = atomically $ readTVar var >>= \case
        Reading n x
          | n == 0 -> do
              writeTVar var Writing
              pure (Just x)
          | otherwise -> do
              writeTVar var (WaitingToWrite n x)
              pure Nothing
        -- The following two branches are interruptible
        WaitingToWrite _n _x -> retry
        Writing -> retry

    -- Note: this is uninterruptible
    undoWaitingToWrite Nothing = atomically $ readTVar var >>= \case
        Reading _n _x ->
          error "undoWaitingToWrite: found Reading but expected WaitingToWrite"
        WaitingToWrite n x -> writeTVar var (Reading n x)
        Writing ->
          error "undoWaitingToWrite: found Writing but expected WaitingToWrite"
    -- With Just we already hold the write lock and the body is just 'pure'.
    -- This case should therefore be unreachable when the "masked context"
    -- precondition above holds.
    undoWaitingToWrite (Just _) =
      error "undoWaitingToWrite: found Just but expected Nothing"

    -- Wait for the number of readers to go to 0, and then finally acquire write
    -- access.
    --
    -- Note: this is interruptible
    waitToWrite = atomically $ readTVar var >>= \case
        Reading _n _x ->
          error "waitToWrite: found Reading but expected WaitingToWrite"
        WaitingToWrite n x
          | n == 0 -> do
              writeTVar var Writing
              pure x
          -- This branch is interruptible
          | otherwise -> retry
        Writing ->
          error "waitToWrite: found Writing but expected WaitingToWrite"

{-# SPECIALISE unsafeReleaseWriteAccess :: RWVar IO a -> a -> STM IO () #-}
-- | Release write access. This function assumes that it runs in a masked
-- context, and that it is properly paired with an 'unsafeAcquireWriteAccess'!
--
-- Stores the given value and moves the lock from 'Writing' to @Reading 0@.
-- The value is forced to WHNF (bang pattern) before the transaction runs.
-- Calls 'error' if the lock is not currently in the 'Writing' state.
unsafeReleaseWriteAccess :: MonadSTM m => RWVar m a -> a -> STM m ()
unsafeReleaseWriteAccess (RWVar !var) !x =
    readTVar var >>= \case
      Reading _ _ ->
        error "releasing a writer without write access (Reading)"
      WaitingToWrite _ _ ->
        error "releasing a writer without write access (WaitingToWrite)"
      Writing -> writeTVar var (Reading 0 x)

{-# SPECIALISE withWriteAccess :: RWVar IO a -> (a -> IO (a, b)) -> IO b #-}
-- | Run an action with exclusive write access.
--
-- The action receives the current value and returns a pair of the new value
-- and a result. On normal exit, the new value is stored. If the action throws
-- (or is aborted), the original value is put back. In every case, write
-- access is released.
--
-- The result pair and the new value are forced to WHNF /inside/ the protected
-- action. 'unsafeReleaseWriteAccess' forces the value it stores. Without this
-- step, a lazily-failing new value would throw from the release handler
-- before the state is written back, leaving the variable stuck in 'Writing'
-- forever. Forcing it earlier turns such a failure into an ordinary exception
-- from the action, so the original value is restored instead.
withWriteAccess :: (MonadSTM m, MonadCatch m) => RWVar m a -> (a -> m (a, b)) -> m b
withWriteAccess rwvar k = snd . fst <$>
    generalBracket
      (unsafeAcquireWriteAccess rwvar)
      (\x ec -> atomically $ case ec of
          ExitCaseSuccess (x', _) -> unsafeReleaseWriteAccess rwvar x'
          ExitCaseException _     -> unsafeReleaseWriteAccess rwvar x
          ExitCaseAbort           -> unsafeReleaseWriteAccess rwvar x
      )
      kStrict
  where
    -- Run the user action, then evaluate the pair (via the pattern match) and
    -- the new value (via seq) while still inside the bracketed body.
    kStrict x = do
      (x', b) <- k x
      x' `seq` pure (x', b)

{-# SPECIALISE withWriteAccess_ :: RWVar IO a -> (a -> IO a) -> IO () #-}
-- | Like 'withWriteAccess', but the action only produces the new value and
-- there is no separate result.
withWriteAccess_ :: (MonadSTM m, MonadCatch m) => RWVar m a -> (a -> m a) -> m ()
withWriteAccess_ rwvar k = withWriteAccess rwvar (fmap (,()) . k)