{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

module Cardano.Crypto.Libsodium.Hash.Class (
  SodiumHashAlgorithm (..),
  digestMLockedStorable,
  digestMLockedBS,
) where

import Control.Monad (unless)
import Data.Proxy (Proxy (..))
import Data.Type.Equality ((:~:) (..))
import Foreign.C.Error (errnoToIOError, getErrno)
import Foreign.Ptr (Ptr, castPtr, nullPtr)
import Foreign.Storable (Storable (sizeOf))
import GHC.IO.Exception (ioException)
import GHC.TypeLits

import qualified Data.ByteString as BS

import Cardano.Crypto.Hash (Blake2b_256, HashAlgorithm (SizeHash), SHA256)
import Cardano.Crypto.Libsodium.C
import Cardano.Crypto.Libsodium.MLockedBytes.Internal

-------------------------------------------------------------------------------
-- Type-Class
-------------------------------------------------------------------------------

-- | Hash algorithms that can digest a raw memory buffer and return the result
-- as an 'MLockedSizedBytes' whose size is the algorithm's 'SizeHash'.
class HashAlgorithm h => SodiumHashAlgorithm h where
  -- | Hash the buffer of the given length starting at the given pointer.
  --
  -- This function is in 'IO', but it is "morally pure"
  -- and can be 'unsafePerformDupableIO'd.
  naclDigestPtr ::
    proxy h ->
    -- | input
    Ptr a ->
    -- | input length
    Int ->
    IO (MLockedSizedBytes (SizeHash h))

-- TODO: provide an interface for multi-part hashing?
-- That would be useful for hashing ('1' <> oldseed).

-- | Hash the in-memory representation of a single 'Storable' value.
--
-- The number of bytes hashed is @'sizeOf' (undefined :: a)@, read starting at
-- the given pointer. The element type @a@ is taken from the pointer's type,
-- so the value behind the pointer is never inspected on the Haskell side.
digestMLockedStorable ::
  forall h a proxy.
  (SodiumHashAlgorithm h, Storable a) =>
  proxy h -> Ptr a -> IO (MLockedSizedBytes (SizeHash h))
digestMLockedStorable p ptr =
  -- 'sizeOf' only needs the type of its argument, hence the 'undefined'.
  naclDigestPtr p ptr (sizeOf (undefined :: a))

-- | Hash the contents of a strict 'BS.ByteString'.
--
-- The bytes are exposed to 'naclDigestPtr' through 'BS.useAsCStringLen'; the
-- pointer handed to the callback must not be used after the callback returns.
digestMLockedBS ::
  forall h proxy.
  SodiumHashAlgorithm h =>
  proxy h -> BS.ByteString -> IO (MLockedSizedBytes (SizeHash h))
digestMLockedBS p bs =
  BS.useAsCStringLen bs $ \(ptr, len) ->
    -- 'naclDigestPtr' accepts any pointer type, so the 'CChar' pointer is
    -- simply cast.
    naclDigestPtr p (castPtr ptr) len

-------------------------------------------------------------------------------
-- Instances
-------------------------------------------------------------------------------

-- | SHA-256 digest computed by @c_crypto_hash_sha256@.
instance SodiumHashAlgorithm SHA256 where
  -- Allocates a fresh 'MLockedSizedBytes' with 'mlsbNew', lets the C function
  -- write the digest into it, and returns it. A non-zero return code from the
  -- C call is turned into an 'IOException' built from the current @errno@.
  naclDigestPtr ::
    forall proxy a. proxy SHA256 -> Ptr a -> Int -> IO (MLockedSizedBytes (SizeHash SHA256))
  naclDigestPtr _ input inputlen = do
    output <- mlsbNew
    mlsbUseAsSizedPtr output $ \output' -> do
      res <- c_crypto_hash_sha256 output' (castPtr input) (fromIntegral inputlen)
      unless (res == 0) $ do
        errno <- getErrno
        ioException $
          errnoToIOError "digestMLocked @SHA256: c_crypto_hash_sha256" errno Nothing Nothing
    return output

-- Test that manually written numbers are the same as in libsodium:
-- this only type-checks if 'SizeHash' 'SHA256' and 'CRYPTO_SHA256_BYTES'
-- are the same type-level number.
_testSHA256 :: SizeHash SHA256 :~: CRYPTO_SHA256_BYTES
_testSHA256 = Refl

-- | Blake2b-256 digest computed by @c_crypto_generichash_blake2b@ with no key.
instance SodiumHashAlgorithm Blake2b_256 where
  -- Allocates a fresh 'MLockedSizedBytes' with 'mlsbNew', lets the C function
  -- write the digest into it, and returns it. A non-zero return code from the
  -- C call is turned into an 'IOException' built from the current @errno@.
  naclDigestPtr ::
    forall proxy a. proxy Blake2b_256 -> Ptr a -> Int -> IO (MLockedSizedBytes (SizeHash Blake2b_256))
  naclDigestPtr _ input inputlen = do
    output <- mlsbNew
    mlsbUseAsCPtr output $ \output' -> do
      res <-
        c_crypto_generichash_blake2b
          output' -- output buffer
          (fromInteger $ natVal (Proxy @CRYPTO_BLAKE2B_256_BYTES)) -- output length
          (castPtr input) -- input buffer
          (fromIntegral inputlen) -- input length
          nullPtr -- key, unused
          0 -- key length
      unless (res == 0) $ do
        errno <- getErrno
        ioException $
          errnoToIOError
            "digestMLocked @Blake2b_256: c_crypto_generichash_blake2b"
            errno
            Nothing
            Nothing
    return output

-- Compile-time check that 'SizeHash' 'Blake2b_256' and
-- 'CRYPTO_BLAKE2B_256_BYTES' are the same type-level number.
_testBlake2b256 :: SizeHash Blake2b_256 :~: CRYPTO_BLAKE2B_256_BYTES
_testBlake2b256 = Refl