{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -Wno-orphans #-}

module Test.Crypto.Instances (
  withMLSBFromPSB,
  withMLockedSeedFromPSB,
) where

import Cardano.Crypto.Libsodium
import Cardano.Crypto.Libsodium.MLockedSeed
import Cardano.Crypto.PinnedSizedBytes
import Control.Monad.Class.MonadST
import Control.Monad.Class.MonadThrow
import Data.Maybe (mapMaybe)
import Data.Proxy (Proxy (Proxy))
import GHC.Exts (fromList, fromListN, toList)
import GHC.TypeLits (KnownNat, natVal)
import Test.QuickCheck (Arbitrary (..))
import qualified Test.QuickCheck.Gen as Gen

-- We deliberately do not provide this instance: it cannot guarantee that the
-- mlocked memory is released in a timely manner, and in a QuickCheck context,
-- where tens of thousands of these values may be generated, waiting for the GC
-- to clean up after us could push us over our mlock quota.
--
-- Instead, use 'arbitrary' to generate a suitably sized 'PinnedSizedBytes'
-- value, and then 'mlsbFromPSB' or 'withMLSBFromPSB' to convert it to an
-- 'MLockedSizedBytes' value.
--
-- instance KnownNat n => Arbitrary (MLockedSizedBytes n) where
--     arbitrary = unsafePerformIO . mlsbFromByteString . BS.pack <$> vectorOf size arbitrary
--       where
--         size :: Int
--         size = fromInteger (natVal (Proxy :: Proxy n))

-- | Copy the contents of a 'PinnedSizedBytes' value into a new
-- 'MLockedSizedBytes' value of the same size.
--
-- The caller owns the result and must release it with 'mlsbFinalize'.
-- Prefer 'withMLSBFromPSB', which does that automatically.
mlsbFromPSB :: (MonadST m, KnownNat n) => PinnedSizedBytes n -> m (MLockedSizedBytes n)
mlsbFromPSB = mlsbFromByteString . psbToByteString

-- | Run an action on an 'MLockedSizedBytes' copy of the given
-- 'PinnedSizedBytes'.
--
-- The copy is created with 'mlsbFromPSB' and released with 'mlsbFinalize'
-- via 'bracket'. It is therefore released even if the action throws, so
-- tests do not rely on the GC to free mlocked memory.
withMLSBFromPSB ::
  (MonadST m, MonadThrow m, KnownNat n) =>
  PinnedSizedBytes n ->
  (MLockedSizedBytes n -> m a) ->
  m a
withMLSBFromPSB psb = bracket (mlsbFromPSB psb) mlsbFinalize

-- | Copy a 'PinnedSizedBytes' value into mlocked memory (via 'mlsbFromPSB')
-- and wrap the result as an 'MLockedSeed'.
--
-- The caller owns the result and must release it with 'mlockedSeedFinalize'.
-- Prefer 'withMLockedSeedFromPSB'.
mlockedSeedFromPSB :: (MonadST m, KnownNat n) => PinnedSizedBytes n -> m (MLockedSeed n)
mlockedSeedFromPSB = fmap MLockedSeed . mlsbFromPSB

-- | Run an action on an 'MLockedSeed' built from the given
-- 'PinnedSizedBytes'.
--
-- The seed is created with 'mlockedSeedFromPSB' and released with
-- 'mlockedSeedFinalize' via 'bracket', so it is released even if the action
-- throws.
withMLockedSeedFromPSB ::
  (MonadST m, MonadThrow m, KnownNat n) =>
  PinnedSizedBytes n ->
  (MLockedSeed n -> m a) ->
  m a
withMLockedSeedFromPSB psb = bracket (mlockedSeedFromPSB psb) mlockedSeedFinalize

-- | Orphan 'Arbitrary' instance for 'PinnedSizedBytes'.
--
-- 'arbitrary' generates exactly @n@ random bytes, where @n@ is the
-- type-level size. 'shrink' shrinks the underlying byte list and keeps only
-- the candidates that 'psbFromByteStringCheck' accepts.
instance KnownNat n => Arbitrary (PinnedSizedBytes n) where
  arbitrary = do
    let size :: Int
        size = fromIntegral (natVal (Proxy @n))
    -- Exactly 'size' bytes are generated. 'psbFromByteStringCheck' is
    -- presumably a length check (TODO confirm), so the Maybe should always
    -- be Just here; 'suchThatMap' is used to unwrap it without a partial
    -- function.
    Gen.suchThatMap
      (fromListN size <$> Gen.vectorOf size arbitrary)
      psbFromByteStringCheck

  -- List shrinking also proposes shorter byte lists. Those are expected to be
  -- rejected by 'psbFromByteStringCheck' and dropped by 'mapMaybe'.
  shrink psb =
    mapMaybe (psbFromByteStringCheck . fromList) . shrink . toList $ psbToByteString psb