{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}

module Cardano.Crypto.Libsodium.Hash (
  SodiumHashAlgorithm (..),
  digestMLockedStorable,
  digestMLockedBS,
  expandHash,
  expandHashWith,
) where

import Data.Proxy (Proxy (..))
import Data.Word (Word8)
import Foreign.C.Types (CSize)
import Foreign.Ptr (castPtr, plusPtr)
import Foreign.Storable (Storable (poke))
import GHC.TypeLits

import Cardano.Crypto.Hash (HashAlgorithm (SizeHash))
import Cardano.Crypto.Libsodium.Hash.Class
import Cardano.Crypto.Libsodium.MLockedBytes.Internal
import Cardano.Crypto.Libsodium.Memory
import Control.Monad.Class.MonadST (MonadST (..))
import Control.Monad.Class.MonadThrow (MonadThrow)
import Control.Monad.ST.Unsafe (unsafeIOToST)

-------------------------------------------------------------------------------
-- Hash expansion
-------------------------------------------------------------------------------

-- | Expand a hash-sized secret into two hash-sized secrets.
--
-- This is 'expandHashWith' specialised to the 'mlockedMalloc' allocator.
expandHash ::
  forall h m proxy.
  (SodiumHashAlgorithm h, MonadST m, MonadThrow m) =>
  proxy h ->
  MLockedSizedBytes (SizeHash h) ->
  m (MLockedSizedBytes (SizeHash h), MLockedSizedBytes (SizeHash h))
expandHash = expandHashWith mlockedMalloc

-- | Expand a hash-sized secret @seed@ into two hash-sized secrets @(l, r)@:
--
-- * @l@ is the digest of the byte @0x01@ followed by @seed@;
-- * @r@ is the digest of the byte @0x02@ followed by @seed@.
--
-- Digests are computed with 'naclDigestPtr' for the algorithm @h@.
-- The prefixed input (@1 + SizeHash h@ bytes) is assembled in a temporary
-- buffer obtained from the supplied 'MLockedAllocator' via
-- 'mlockedAllocaWith'.
expandHashWith ::
  forall h m proxy.
  (SodiumHashAlgorithm h, MonadST m, MonadThrow m) =>
  MLockedAllocator m ->
  proxy h ->
  MLockedSizedBytes (SizeHash h) ->
  m (MLockedSizedBytes (SizeHash h), MLockedSizedBytes (SizeHash h))
expandHashWith allocator h (MLSB sfptr) =
  withMLockedForeignPtr sfptr $ \ptr -> do
    let -- Digest of the one-byte @prefix@ followed by the seed at @ptr@.
        digestWithPrefix :: Word8 -> m (MLockedSizedBytes (SizeHash h))
        digestWithPrefix prefix =
          mlockedAllocaWith allocator size1 $ \ptr' ->
            -- Only 'MonadST' (not 'MonadIO') is available here, so the
            -- raw IO pointer operations are lifted through 'ST'.
            stToIO . unsafeIOToST $ do
              poke ptr' (prefix :: Word8)
              copyMem (castPtr (plusPtr ptr' 1)) ptr size
              naclDigestPtr h ptr' (fromIntegral size1)

    l <- digestWithPrefix 1
    r <- digestWithPrefix 2
    pure (l, r)
  where
    -- Length of the prefixed input: one prefix byte plus the seed.
    size1 :: CSize
    size1 = size + 1

    -- Length of the seed, i.e. the digest size of @h@ in bytes.
    size :: CSize
    size = fromInteger $ natVal (Proxy @(SizeHash h))