{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE StandaloneDeriving #-}

module Test.Crypto.EqST where

import Control.Monad.Class.MonadST (MonadST)
import qualified Data.Vector as Vec
import GHC.TypeLits (KnownNat)

import Cardano.Crypto.DSIGN.Class
import Cardano.Crypto.DSIGN.Ed25519
import Cardano.Crypto.KES.Simple
import Cardano.Crypto.Libsodium.MLockedBytes.Internal
import Cardano.Crypto.Libsodium.MLockedSeed

-- | 'Eq' lifted into a monad: equality for values that can only be
-- compared inside a context satisfying 'MonadST'.
--
-- Tests need to compare mlocked memory types for equality, yet such types
-- cannot be given a sound pure 'Eq' instance; this class fills that gap.
class EqST a where
  -- | Decide, inside @m@, whether the two arguments are equal.
  equalsM :: (MonadST m) => a -> a -> m Bool

-- | Monadic inequality: runs 'equalsM' on the two arguments and negates
-- its result.
nequalsM :: (MonadST m, EqST a) => a -> a -> m Bool
nequalsM a b = not <$> equalsM a b

-- | Infix version of 'equalsM'.
--
-- Declared @infix 4@, the same fixity as '==', so it neither chains nor
-- associates with itself.
(==!) :: (MonadST m, EqST a) => a -> a -> m Bool
(==!) = equalsM

infix 4 ==!

-- | Infix version of 'nequalsM'.
--
-- Declared @infix 4@, the same fixity as '/=', so it neither chains nor
-- associates with itself.
(!=!) :: (MonadST m, EqST a) => a -> a -> m Bool
(!=!) = nequalsM

infix 4 !=!

-- | Two 'Nothing's are equal; two 'Just's are equal when their payloads
-- compare equal under 'equalsM'; a 'Nothing' never equals a 'Just'.
instance EqST a => EqST (Maybe a) where
  equalsM Nothing Nothing = pure True
  equalsM (Just a) (Just b) = equalsM a b
  equalsM _ _ = pure False

-- | Values built with the same constructor are compared with 'equalsM' on
-- their payloads; a 'Left' never equals a 'Right'.
instance (EqST a, EqST b) => EqST (Either a b) where
  equalsM (Left x) (Left y) = equalsM x y
  equalsM (Right x) (Right y) = equalsM x y
  equalsM _ _ = pure False

-- | Component-wise comparison. The two comparisons are combined
-- applicatively, so both always run (there is no short-circuit), and their
-- results are conjoined with '&&'.
instance (EqST a, EqST b) => EqST (a, b) where
  equalsM (a, b) (a', b') = (&&) <$> equalsM a a' <*> equalsM b b'

-- | Reuses the pair instance by regrouping the triple as @((a, b), c)@.
instance (EqST a, EqST b, EqST c) => EqST (a, b, c) where
  equalsM (a, b, c) (a', b', c') = equalsM ((a, b), c) ((a', b'), c')

-- | Reuses the pair and triple instances by regrouping the quadruple as
-- @((a, b, c), d)@.
instance (EqST a, EqST b, EqST c, EqST d) => EqST (a, b, c, d) where
  equalsM (a, b, c, d) (a', b', c', d') =
    equalsM ((a, b, c), d) ((a', b', c'), d')

-- TODO: If anyone needs larger tuples, add more instances here...

-- | Helper newtype for defining 'EqST' in terms of 'Eq', for types that
-- have a sound 'Eq' instance. Use it as the @via@ type of a standalone
-- @DerivingVia@ declaration (which requires @StandaloneDeriving@ as well).
-- No extra context is needed: the 'MonadST' constraint lives on 'equalsM'
-- itself.
--
-- Ex.: @deriving via (PureEqST Int) instance EqST Int@
newtype PureEqST a = PureEqST a

-- | Pure equality lifted into the monad: performs no effects and returns
-- the result of '==' on the wrapped values.
instance Eq a => EqST (PureEqST a) where
  equalsM (PureEqST a) (PureEqST b) = pure (a == b)

-- | Delegates directly to 'mlsbEq' from the mlocked-bytes internals.
instance KnownNat n => EqST (MLockedSizedBytes n) where
  equalsM = mlsbEq

-- Seeds are compared exactly like the 'MLockedSizedBytes' they are coerced
-- to (DerivingVia requires the two types to be representationally equal).
deriving via (MLockedSizedBytes n) instance KnownNat n => EqST (MLockedSeed n)

-- Ed25519 mlocked signing keys borrow the instance of the sized mlocked
-- byte buffer they coerce to; the buffer size is 'SizeSignKeyDSIGN'.
deriving via (MLockedSizedBytes (SizeSignKeyDSIGN Ed25519DSIGN))
  instance EqST (SignKeyDSIGNM Ed25519DSIGN)

-- | Compares the two vectors of DSIGN signing keys element-wise with
-- 'equalsM'; the KES keys are equal only if every pair compares equal.
instance EqST (SignKeyDSIGNM d) => EqST (SignKeyKES (SimpleKES d t)) where
  equalsM (ThunkySignKeySimpleKES a) (ThunkySignKeySimpleKES b) =
    -- No need to check that lengths agree, the types already guarantee this.
    -- (Otherwise 'Vec.zipWithM' would silently stop at the shorter vector.)
    Vec.and <$> Vec.zipWithM equalsM a b