{-# LANGUAGE ScopedTypeVariables #-}

-- | Direct (de-)serialisation to / from raw memory.
--
-- The purpose of the typeclasses in this module is to abstract over data
-- structures that can expose the data they store as one or more raw 'Ptr's,
-- without any additional memory copying or conversion to intermediate data
-- structures.
--
-- This is useful for transmitting data like KES SignKeys over a socket
-- connection: by accessing the memory directly and copying it into or out of
-- a file descriptor, without going through an intermediate @ByteString@
-- representation (or other data structure that resides in the GHC heap), we
-- can more easily assure that the data is never written to disk, including
-- swap, which is an important requirement for KES.
module Cardano.Crypto.DirectSerialise
where

import Cardano.Crypto.Libsodium.Memory (copyMem)
import Control.Exception hiding (throwIO)
import Control.Monad (when)
import Control.Monad.Class.MonadST (MonadST, stToIO)
import Control.Monad.Class.MonadThrow (MonadThrow (throwIO))
import Data.STRef (newSTRef, readSTRef, writeSTRef)
import Foreign.C.Types
import Foreign.Ptr

-- | Raised by the size-checking helpers in this module when a memory block,
-- or a complete (de-)serialisation, does not have the size that was required.
data SizeCheckException = SizeCheckException
  { expectedSize :: Int
  -- ^ The size that was required: either the number of bytes still available
  -- in the source / destination, or the exact total that had to be moved.
  , actualSize :: Int
  -- ^ The size that was actually requested or transferred.
  }
  deriving (Show)

-- | Relies entirely on the default methods, so 'displayException' renders
-- the exception via its 'Show' instance.
instance Exception SizeCheckException

-- | Raise a 'SizeCheckException' carrying the expected and the actual size.
--
-- This uses the 'MonadThrow' method 'throwIO' rather than the pure 'throw'.
-- With 'throw' the result is a bottom /value/ of type @m ()@, and the error
-- only surfaces when that value is evaluated. With 'throwIO' the exception is
-- raised as an effect when the action runs, through the monad's own
-- exception mechanism, so non-'IO' instances of 'MonadThrow' can handle it.
-- Every caller in this module already has a 'MonadThrow' constraint in scope.
sizeCheckFailed :: MonadThrow m => Int -> Int -> m ()
sizeCheckFailed expected actual =
  throwIO $ SizeCheckException expected actual

-- | Types whose values can be deserialised directly from raw memory.
--
-- @'directDeserialise' fill@ should allocate a fresh value of type @a@ and
-- then call @fill ptr len@ so that the @len@ bytes starting at @ptr@ get
-- populated. A type that keeps its data in several non-contiguous memory
-- blocks may call @fill@ several times, once per block.
--
-- The order in which the blocks are visited is significant. The helpers in
-- this module treat successive calls as consecutive regions of a single
-- byte stream.
class DirectDeserialise a where
  directDeserialise ::
    (MonadST m, MonadThrow m) =>
    (Ptr CChar -> CSize -> m ()) ->
    m a

-- | Types whose values can be serialised directly to raw memory.
--
-- @'directSerialise' f x@ should call @f ptr len@ to expose the raw memory
-- underlying @x@ (the @len@ bytes starting at @ptr@). For data types that
-- store their data in multiple non-contiguous blocks of memory, @f@ may be
-- called multiple times, once for each block.
--
-- The order in which memory blocks are visited matters. The helpers in this
-- module lay successive blocks out back to back, so presumably an instance's
-- 'directSerialise' must visit its blocks in the same order as the matching
-- 'DirectDeserialise' instance for a round trip to succeed.
class DirectSerialise a where
  directSerialise :: (MonadST m, MonadThrow m) => (Ptr CChar -> CSize -> m ()) -> a -> m ()

-- | Bounds-checked serialisation through a caller-supplied write action.
--
-- @directSerialiseTo writeBytes dstsize val@ runs 'directSerialise' on @val@
-- and forwards each exposed memory block to @writeBytes offset ptr len@.
-- Here @offset@ is the number of bytes already written by earlier blocks, so
-- blocks are laid out back to back starting at offset 0.
--
-- Before each write, the block length is compared with the remaining
-- capacity @dstsize - offset@. If the block would not fit, a
-- 'SizeCheckException' is raised (expected = remaining capacity, actual =
-- block length) and that block is not written.
--
-- Returns the total number of bytes written, which never exceeds @dstsize@.
directSerialiseTo ::
  forall m a.
  DirectSerialise a =>
  MonadST m =>
  MonadThrow m =>
  (Int -> Ptr CChar -> CSize -> m ()) ->
  Int ->
  a ->
  m Int
directSerialiseTo writeBytes dstsize val = do
  offsetRef <- stToIO (newSTRef 0)
  let emitBlock :: Ptr CChar -> CSize -> m ()
      emitBlock block blockLen = do
        offset <- stToIO (readSTRef offsetRef)
        let len = fromIntegral blockLen :: Int
            end = offset + len
        -- Refuse to write past the end of the destination.
        when (end > dstsize) $
          sizeCheckFailed (dstsize - offset) len
        writeBytes offset block blockLen
        stToIO (writeSTRef offsetRef end)
  directSerialise emitBlock val
  stToIO (readSTRef offsetRef)

-- | Size-checked serialisation through a caller-supplied write action.
--
-- Behaves like 'directSerialiseTo', but additionally requires that exactly
-- @dstsize@ bytes end up being written. Otherwise a 'SizeCheckException' is
-- raised with expected = @dstsize@ and actual = the number of bytes written.
directSerialiseToChecked ::
  forall m a.
  DirectSerialise a =>
  MonadST m =>
  MonadThrow m =>
  (Int -> Ptr CChar -> CSize -> m ()) ->
  Int ->
  a ->
  m ()
directSerialiseToChecked writeBytes dstsize val =
  directSerialiseTo writeBytes dstsize val >>= \written ->
    when (written /= dstsize) $
      sizeCheckFailed dstsize written

-- | Bounds-checked serialisation into an in-memory buffer.
--
-- @directSerialiseBuf dst dstsize val@ copies each memory block exposed by
-- @val@ into @dst@ at the running offset, using 'copyMem'.
--
-- At most @dstsize@ bytes are written; if more would be needed, a
-- 'SizeCheckException' is raised, exactly as for 'directSerialiseTo'.
-- Returns the number of bytes actually written.
--
-- The function trusts @dstsize@, so the caller must make sure @dst@ really
-- has room for that many bytes.
directSerialiseBuf ::
  forall m a.
  DirectSerialise a =>
  MonadST m =>
  MonadThrow m =>
  Ptr CChar ->
  Int ->
  a ->
  m Int
directSerialiseBuf dst =
  directSerialiseTo $ \offset block len ->
    copyMem (dst `plusPtr` offset) block len

-- | Size-checked serialisation into an in-memory buffer.
--
-- Copies the memory blocks exposed by @val@ back to back into @buf@ (via
-- 'copyMem') and requires that exactly @dstsize@ bytes are written.
--
-- A 'SizeCheckException' is raised if a block would overflow the buffer, or
-- if the total written differs from @dstsize@.
directSerialiseBufChecked ::
  forall m a.
  DirectSerialise a =>
  MonadST m =>
  MonadThrow m =>
  Ptr CChar ->
  Int ->
  a ->
  m ()
directSerialiseBufChecked buf dstsize val =
  directSerialiseToChecked copyToBuf dstsize val
  where
    -- Copy one block into the buffer at the given byte offset.
    copyToBuf :: Int -> Ptr CChar -> CSize -> m ()
    copyToBuf offset block = copyMem (buf `plusPtr` offset) block

-- | Bounds-checked deserialisation through a caller-supplied read action.
--
-- @directDeserialiseFrom readBytes srcsize@ runs 'directDeserialise'. Each
-- memory block the new value asks to be filled is passed to
-- @readBytes offset ptr len@, where @offset@ is the number of bytes consumed
-- by earlier blocks, so blocks are read back to back starting at offset 0.
--
-- Before each read, the block length is compared with the remaining input
-- @srcsize - offset@. If the block would overrun it, a 'SizeCheckException'
-- is raised (expected = remaining input, actual = block length) and nothing
-- is read for that block.
--
-- Returns the value together with the total number of bytes read, which
-- never exceeds @srcsize@.
directDeserialiseFrom ::
  forall m a.
  DirectDeserialise a =>
  MonadST m =>
  MonadThrow m =>
  (Int -> Ptr CChar -> CSize -> m ()) ->
  Int ->
  m (a, Int)
directDeserialiseFrom readBytes srcsize = do
  offsetRef <- stToIO (newSTRef 0)
  let fillBlock :: Ptr CChar -> CSize -> m ()
      fillBlock block blockLen = do
        offset <- stToIO (readSTRef offsetRef)
        let len = fromIntegral blockLen :: Int
            end = offset + len
        -- Refuse to read past the end of the source.
        when (end > srcsize) $
          sizeCheckFailed (srcsize - offset) len
        readBytes offset block blockLen
        stToIO (writeSTRef offsetRef end)
  value <- directDeserialise fillBlock
  consumed <- stToIO (readSTRef offsetRef)
  pure (value, consumed)

-- | Size-checked deserialisation through a caller-supplied read action.
--
-- Behaves like 'directDeserialiseFrom', but additionally requires that
-- exactly @srcsize@ bytes are consumed. Otherwise a 'SizeCheckException' is
-- raised with expected = @srcsize@ and actual = the number of bytes read.
directDeserialiseFromChecked ::
  forall m a.
  DirectDeserialise a =>
  MonadST m =>
  MonadThrow m =>
  (Int -> Ptr CChar -> CSize -> m ()) ->
  Int ->
  m a
directDeserialiseFromChecked readBytes srcsize = do
  (value, consumed) <- directDeserialiseFrom readBytes srcsize
  when (consumed /= srcsize) $
    sizeCheckFailed srcsize consumed
  pure value

-- | Bounds-checked deserialisation from an in-memory buffer.
--
-- Each memory block of the new value is filled by copying (via 'copyMem')
-- from @src@ at the running offset.
--
-- At most @srcsize@ bytes are read; if more would be needed, a
-- 'SizeCheckException' is raised, exactly as for 'directDeserialiseFrom'.
-- Returns the value and the number of bytes actually read.
directDeserialiseBuf ::
  forall m a.
  DirectDeserialise a =>
  MonadST m =>
  MonadThrow m =>
  Ptr CChar ->
  Int ->
  m (a, Int)
directDeserialiseBuf src = directDeserialiseFrom copyFromSrc
  where
    -- Fill one block from the source buffer at the given byte offset.
    copyFromSrc :: Int -> Ptr CChar -> CSize -> m ()
    copyFromSrc offset block = copyMem block (src `plusPtr` offset)

-- | Size-checked deserialisation from an in-memory buffer.
--
-- Fills the new value's memory blocks back to back from @buf@ (via
-- 'copyMem') and requires that exactly @srcsize@ bytes are consumed.
--
-- A 'SizeCheckException' is raised if a block would read past @srcsize@, or
-- if the total read differs from @srcsize@.
directDeserialiseBufChecked ::
  forall m a.
  DirectDeserialise a =>
  MonadST m =>
  MonadThrow m =>
  Ptr CChar ->
  Int ->
  m a
directDeserialiseBufChecked buf srcsize =
  directDeserialiseFromChecked copyFromBuf srcsize
  where
    -- Fill one block from the source buffer at the given byte offset.
    copyFromBuf :: Int -> Ptr CChar -> CSize -> m ()
    copyFromBuf offset block = copyMem block (buf `plusPtr` offset)