{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE DeriveAnyClass      #-}
{-# LANGUAGE DerivingStrategies  #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE TypeFamilies        #-}
{-# OPTIONS_GHC -Wno-orphans #-}
{-# OPTIONS_GHC -Wno-simplifiable-class-constraints #-}
{-# OPTIONS_GHC -fno-omit-interface-pragmas #-}
{-# OPTIONS_GHC -fno-specialise #-}

module Plutus.Script.Utils.V1.Typed.Scripts.Validators
  ( UntypedValidator,
    mkUntypedValidator,
    ---
    ValidatorTypes (..),
    ValidatorType,
    TypedValidator,
    mkTypedValidator,
    mkTypedValidatorParam,
    validatorHash,
    validatorAddress,
    validatorScript,
    vValidatorScript,
    forwardingMintingPolicy,
    vForwardingMintingPolicy,
    forwardingMintingPolicyHash,
    generalise,
    ---
    WrongOutTypeError (..),
    ConnectionError (..),
    checkValidatorAddress,
    checkDatum,
    checkRedeemer,
  )
where

import Control.Monad (unless)
import Control.Monad.Except (MonadError (throwError))
import Data.Aeson (FromJSON, ToJSON)
import Data.Kind (Type)
import GHC.Generics (Generic)
import Plutus.Script.Utils.Scripts (Datum, Language (PlutusV1), Versioned (Versioned))
import Plutus.Script.Utils.Typed (DatumType, RedeemerType,
                                  TypedValidator (TypedValidator, tvForwardingMPS, tvForwardingMPSHash, tvValidator, tvValidatorHash),
                                  UntypedValidator, ValidatorTypes, forwardingMintingPolicy,
                                  forwardingMintingPolicyHash, generalise, mkUntypedValidator, vForwardingMintingPolicy,
                                  vValidatorScript, validatorAddress, validatorHash, validatorScript)
import Plutus.Script.Utils.V1.Scripts qualified as Scripts
import Plutus.Script.Utils.V1.Typed.Scripts.MonetaryPolicies qualified as MPS
import Plutus.V1.Ledger.Address qualified as PV1
import Plutus.V1.Ledger.Api qualified as PV1
import PlutusCore.Default (DefaultUni)
import PlutusTx (CompiledCode, Lift, applyCode, liftCode)
import Prettyprinter (Pretty (pretty), viaShow, (<+>))

-- | The type of validators for the given connection type @a@: a function from
-- the connection's 'DatumType' and 'RedeemerType', together with the Plutus V1
-- 'PV1.ScriptContext', to a 'Bool'.
type ValidatorType (a :: Type) = DatumType a -> RedeemerType a -> PV1.ScriptContext -> Bool

-- | Make a 'TypedValidator' from the 'CompiledCode' of a validator script and its wrapper.
--
-- The wrapper is applied to the compiled validator with 'applyCode' and the
-- result is turned into a validator script. The validator hash is computed
-- from that script, and a forwarding minting policy is built from the hash.
-- Both scripts are tagged as 'PlutusV1'.
mkTypedValidator ::
  -- | Validator script (compiled)
  CompiledCode (ValidatorType a) ->
  -- | A wrapper for the compiled validator
  CompiledCode (ValidatorType a -> UntypedValidator) ->
  TypedValidator a
mkTypedValidator vc wrapper =
  TypedValidator
    { tvValidator = Versioned val PlutusV1
    , tvValidatorHash = hsh
    , tvForwardingMPS = Versioned mps PlutusV1
    , tvForwardingMPSHash = Scripts.mintingPolicyHash mps
    }
  where
    -- The untyped validator script: wrapper applied to the typed validator.
    val = PV1.mkValidatorScript $ wrapper `applyCode` vc
    -- Hash of the validator script above.
    hsh = Scripts.validatorHash val
    -- Minting policy derived from the validator hash.
    mps = MPS.mkForwardingMintingPolicy hsh

-- | Make a 'TypedValidator' from the 'CompiledCode' of a parameterized validator script and its wrapper.
--
-- The parameter is lifted into compiled code with 'liftCode' and applied to
-- the parameterized validator with 'applyCode'; the resulting compiled
-- validator is then passed to 'mkTypedValidator' together with the wrapper.
mkTypedValidatorParam ::
  forall a param.
  Lift DefaultUni param =>
  -- | Validator script (compiled)
  CompiledCode (param -> ValidatorType a) ->
  -- | A wrapper for the compiled validator
  CompiledCode (ValidatorType a -> UntypedValidator) ->
  -- | The extra parameter for the validator script
  param ->
  TypedValidator a
mkTypedValidatorParam vc wrapper param =
  mkTypedValidator (vc `applyCode` liftCode param) wrapper

-- | The reason carried by the 'WrongOutType' constructor of 'ConnectionError'.
--
-- NOTE(review): judging by the constructor names, these describe a mismatch
-- between the expected and the actual kind of output (script vs. public key);
-- confirm against the call sites, which are not in this module.
data WrongOutTypeError
  = ExpectedScriptGotPubkey
  | ExpectedPubkeyGotScript
  deriving stock (Show, Eq, Ord, Generic)
  deriving anyclass (ToJSON, FromJSON)

-- | An error we can get while trying to type an existing transaction part.
data ConnectionError
  = -- | The address did not match the validator's address. The first field
    -- is the expected address, the second the actual one.
    WrongValidatorAddress PV1.Address PV1.Address
  | -- | The output was of the wrong kind; see 'WrongOutTypeError'.
    WrongOutType WrongOutTypeError
  | -- | The validator had the wrong type; carries a textual description.
    WrongValidatorType String
  | -- | The redeemer data could not be decoded to the expected redeemer
    -- type; carries the raw data.
    WrongRedeemerType PV1.BuiltinData
  | -- | The datum data could not be decoded to the expected datum type;
    -- carries the raw data.
    WrongDatumType PV1.BuiltinData
  | -- | No datum with the given hash was found for the given tx output.
    NoDatum PV1.TxOutRef PV1.DatumHash
  | -- | The given output reference is unknown.
    UnknownRef PV1.TxOutRef
  deriving stock (Show, Eq, Ord, Generic)

-- | Human-readable rendering of a 'ConnectionError'. Each message is a fixed
-- label followed by the pretty-printed payload, joined with '<+>'.
instance Pretty ConnectionError where
  pretty (WrongValidatorAddress a1 a2) = "Wrong validator address. Expected:" <+> pretty a1 <+> "Actual:" <+> pretty a2
  -- 'WrongOutTypeError' is rendered through its 'Show' instance.
  pretty (WrongOutType t)              = "Wrong out type:" <+> viaShow t
  pretty (WrongValidatorType t)        = "Wrong validator type:" <+> pretty t
  -- Builtin data is converted to plain 'Data' before pretty-printing.
  pretty (WrongRedeemerType d)         = "Wrong redeemer type" <+> pretty (PV1.builtinDataToData d)
  pretty (WrongDatumType d)            = "Wrong datum type" <+> pretty (PV1.builtinDataToData d)
  -- The label's trailing space is kept as-is so the rendered text is unchanged.
  pretty (NoDatum t d)                 = "No datum with hash " <+> pretty d <+> "for tx output" <+> pretty t
  pretty (UnknownRef d)                = "Unknown reference" <+> pretty d

-- | Checks that the given address is the address of the given typed validator.
--
-- The expected address is obtained with 'validatorAddress'. If it differs from
-- the supplied address, a 'WrongValidatorAddress' error is thrown, carrying
-- the expected address first and the actual address second.
checkValidatorAddress :: forall a m. (MonadError ConnectionError m) => TypedValidator a -> PV1.Address -> m ()
checkValidatorAddress ct actualAddr = do
  let expectedAddr = validatorAddress ct
  unless (expectedAddr == actualAddr) $
    throwError $ WrongValidatorAddress expectedAddr actualAddr

-- | Checks that the given redeemer has the right type.
--
-- The redeemer's builtin data is decoded with 'PV1.fromBuiltinData'. On
-- success the decoded value is returned; otherwise a 'WrongRedeemerType'
-- error carrying the raw data is thrown. The 'TypedValidator' argument is not
-- inspected; it only determines the connection type @inn@.
checkRedeemer ::
  forall inn m.
  (PV1.FromData (RedeemerType inn), MonadError ConnectionError m) =>
  TypedValidator inn ->
  PV1.Redeemer ->
  m (RedeemerType inn)
checkRedeemer _ (PV1.Redeemer d) =
  maybe (throwError $ WrongRedeemerType d) pure (PV1.fromBuiltinData d)

-- | Checks that the given datum has the right type.
--
-- The datum's builtin data is decoded with 'PV1.fromBuiltinData'. On success
-- the decoded value is returned; otherwise a 'WrongDatumType' error carrying
-- the raw data is thrown. The 'TypedValidator' argument is not inspected; it
-- only determines the connection type @a@.
checkDatum ::
  forall a m.
  (PV1.FromData (DatumType a), MonadError ConnectionError m) =>
  TypedValidator a ->
  Datum ->
  m (DatumType a)
checkDatum _ (PV1.Datum d) =
  maybe (throwError $ WrongDatumType d) pure (PV1.fromBuiltinData d)