{-# LANGUAGE DataKinds         #-}
{-# LANGUAGE NamedFieldPuns    #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TemplateHaskell   #-}
{-# LANGUAGE TypeFamilies      #-}
{-# LANGUAGE ViewPatterns      #-}

{-# OPTIONS_GHC -fno-specialise #-}
{-# OPTIONS_GHC -Wno-simplifiable-class-constraints #-}
{-# OPTIONS_GHC -fno-omit-interface-pragmas #-}
{-# OPTIONS_GHC -fno-ignore-interface-pragmas #-}
module Plutus.Script.Utils.V2.Typed.Scripts.StakeValidators
    ( mkForwardingStakeValidator
    , forwardToValidator
    ) where

import Plutus.Script.Utils.Typed (mkUntypedStakeValidator)
import Plutus.V2.Ledger.Api (Address (Address, addressCredential), Credential (ScriptCredential), StakeValidator,
                             ValidatorHash, mkStakeValidatorScript)
import Plutus.V2.Ledger.Contexts (ScriptContext (ScriptContext, scriptContextPurpose, scriptContextTxInfo),
                                  ScriptPurpose (Certifying, Rewarding), TxInfo (TxInfo, txInfoInputs))
import Plutus.V2.Ledger.Contexts qualified as PV2
import Plutus.V2.Ledger.Tx (TxOut (TxOut, txOutAddress))
import PlutusTx qualified
import PlutusTx.Prelude (Bool (False), any, ($), (.), (==))

-- TODO: we should add a TypedStakeValidator interface here

-- | Build a 'StakeValidator' bound to the spending validator whose hash is
--   given.
--
--   The on-chain code is 'forwardToValidator', adapted to the untyped
--   @BuiltinData@ interface by 'mkUntypedStakeValidator'. The hash argument is
--   baked into the script with 'PlutusTx.liftCode' / 'PlutusTx.applyCode'.
--   The script's acceptance therefore follows 'forwardToValidator', which
--   checks that the transaction spends an input locked by that validator.
--   NOTE(review): presumably 'mkUntypedStakeValidator' fails the script when
--   the check returns False; confirm against its definition.
mkForwardingStakeValidator :: ValidatorHash -> StakeValidator
mkForwardingStakeValidator validatorHash =
    mkStakeValidatorScript $ parameterised `PlutusTx.applyCode` PlutusTx.liftCode validatorHash
  where
    -- Compiled untyped stake validator, still abstracted over the target hash;
    -- the concrete hash is supplied above as a lifted constant.
    parameterised =
        $$(PlutusTx.compile [|| \(target :: ValidatorHash) ->
            mkUntypedStakeValidator (forwardToValidator target)
            ||])

{-# INLINABLE forwardToValidator #-}
-- | On-chain check used by 'mkForwardingStakeValidator'.
--
--   When the script purpose is 'Rewarding' or 'Certifying', the result is
--   whether at least one of the transaction's inputs ('txInfoInputs') resolves
--   to an output whose address credential is a 'ScriptCredential' equal to the
--   given 'ValidatorHash'. For every other purpose the result is 'False'.
--   The unit redeemer is ignored.
forwardToValidator :: ValidatorHash -> () -> ScriptContext -> Bool
forwardToValidator h _ ScriptContext{ scriptContextTxInfo  = TxInfo{txInfoInputs = inputs}
                                    , scriptContextPurpose = purpose
                                    } =
    case purpose of
        Rewarding _  -> spendsFromTarget
        Certifying _ -> spendsFromTarget
        _            -> False
  where
    -- Whether any spent input was locked by the target validator.
    spendsFromTarget :: Bool
    spendsFromTarget = any (lockedByTarget . PV2.txInInfoResolved) inputs

    -- Whether an output sits at a script address whose hash is the target;
    -- non-script credentials never match.
    lockedByTarget :: TxOut -> Bool
    lockedByTarget TxOut{txOutAddress = Address{addressCredential = cred}} =
        case cred of
            ScriptCredential vh -> vh == h
            _                   -> False