{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE OverloadedStrings     #-}
{-# LANGUAGE RankNTypes            #-}

-- | This module checks that all types occurring inside programs are in normal form.
module PlutusCore.Check.Normal
    ( checkProgram
    , checkTerm
    , isNormalType
    , NormCheckError (..)
    ) where

import PlutusPrelude

import PlutusCore.Core
import PlutusCore.Error

import Control.Monad.Except

-- | Check that every type occurring in a 'Program' is in normal form.
--
-- Only the program's body term is inspected; the annotation and version
-- components are ignored. Any failure is reported exactly as 'checkTerm'
-- reports it.
checkProgram
    :: (AsNormCheckError e tyname name uni fun ann, MonadError e m)
    => Program tyname name uni fun ann -> m ()
checkProgram prog = case prog of
    Program _ _ body -> checkTerm body

-- | Check that every type occurring in a 'Term' is in normal form.
--
-- The pure checker's first failure is rethrown in @m@ through the
-- '_NormCheckError' prism, so callers can use any error type that embeds
-- 'NormCheckError'.
checkTerm
    :: (AsNormCheckError e tyname name uni fun ann, MonadError e m)
    => Term tyname name uni fun ann -> m ()
checkTerm = throwingEither _NormCheckError . check

-- | Walk a term and verify that every type it carries is in normal form.
--
-- Sub-components are visited left to right in the order listed per
-- constructor below, and the first failing type is returned as 'Left'.
-- Variables, constants and builtins carry no types and always succeed.
check
    :: Term tyname name uni fun ann
    -> Either (NormCheckError tyname name uni fun ann) ()
check term = case term of
    Error _ errTy -> normalType errTy
    TyInst _ body instTy -> do
        -- The instantiated term is checked before the instantiating type.
        check body
        normalType instTy
    IWrap _ patTy argTy body -> do
        normalType patTy
        normalType argTy
        check body
    Unwrap _ body -> check body
    LamAbs _ _ binderTy body -> do
        -- The binder's type annotation is checked before the lambda body.
        check' binderTy body
    Apply _ funTerm argTerm -> do
        check funTerm
        check argTerm
    TyAbs _ _ _ body -> check body
    Var{} -> pure ()
    Constant{} -> pure ()
    Builtin{} -> pure ()
  where
    check' annTy body = do
        normalType annTy
        check body

-- | 'True' exactly when 'normalType' accepts the given type.
isNormalType :: Type tyname uni ann -> Bool
isNormalType ty = case normalType ty of
    Left _   -> False
    Right () -> True

-- | Check that a type is in normal form.
--
-- Function types, @forall@s, @ifix@es and type lambdas are normal when their
-- components are. Built-in types are always accepted. Any other type must be
-- neutral (see 'neutralType'). For example, an application whose head is a
-- type lambda falls through to 'neutralType' and is rejected there.
normalType
    :: Type tyname uni ann
    -> Either (NormCheckError tyname name uni fun ann) ()
normalType ty = case ty of
    TyFun _ domTy codTy -> do
        normalType domTy
        normalType codTy
    TyForall _ _ _ bodyTy -> normalType bodyTy
    TyIFix _ patTy argTy -> do
        normalType patTy
        normalType argTy
    TyLam _ _ _ bodyTy -> normalType bodyTy
    -- See Note [PLC types and universes].
    TyBuiltin{} -> pure ()
    _ -> neutralType ty

-- | Check that a type is neutral.
--
-- A neutral type is a type variable, or a type application whose head is
-- neutral and whose argument is normal. Anything else is rejected with a
-- 'BadType' error that carries the offending type and its annotation.
neutralType
    :: Type tyname uni ann
    -> Either (NormCheckError tyname name uni fun ann) ()
neutralType ty = case ty of
    TyVar{} -> pure ()
    TyApp _ headTy argTy -> do
        neutralType headTy
        normalType argTy
    _ -> Left $ BadType (typeAnn ty) ty "neutral type"