{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE ScopedTypeVariables #-}
module UntypedPlutusCore.Subst
    ( substVarA
    , substVar
    , termSubstNamesM
    , termSubstNames
    , termSubstFreeNamesA
    , termSubstFreeNames
    , termMapNames
    , programMapNames
    , uniquesTerm
    , vTerm
    ) where

import PlutusPrelude

import PlutusCore.Core (HasUniques)
import PlutusCore.Name
import UntypedPlutusCore.Core

import Control.Lens
import Data.Set as Set
import Data.Set.Lens (setOf)

-- | Turn an 'Identity'-specialised effectful transformation into a pure one.
--
-- 'Identity' is a newtype, so both the argument function and the result can be
-- converted with 'coerce' instead of explicit wrapping and unwrapping.
purely :: ((a -> Identity b) -> c -> Identity d) -> (a -> b) -> c -> d
purely = coerce

-- | Applicatively replace a variable using the given function.
--
-- Only the top-level node is inspected: if the term is a 'Var', the function is
-- run on its name and the term is replaced when the result is 'Just'; a
-- 'Nothing' result keeps the original variable. Any other term is returned
-- unchanged without running the function.
substVarA
    :: Applicative f
    => (name -> f (Maybe (Term name uni fun ann)))
    -> Term name uni fun ann
    -> f (Term name uni fun ann)
substVarA nameF t@(Var _ name) = fromMaybe t <$> nameF name
substVarA _     t              = pure t

-- | Replace a variable using the given function.
--
-- Pure counterpart of 'substVarA': it inspects only the top-level node and
-- leaves every non-'Var' term unchanged.
substVar
    :: (name -> Maybe (Term name uni fun ann))
    -> Term name uni fun ann
    -> Term name uni fun ann
substVar = purely substVarA

-- | Naively monadically substitute names using the given function (i.e. do not substitute binders).
--
-- 'substVarA' is applied throughout the term by 'transformMOf' over
-- 'termSubterms'. No scoping information is tracked, so the function is
-- consulted for every variable occurrence, bound or free.
termSubstNamesM
    :: Monad m
    => (name -> m (Maybe (Term name uni fun ann)))
    -> Term name uni fun ann
    -> m (Term name uni fun ann)
termSubstNamesM = transformMOf termSubterms . substVarA

-- | Naively substitute names using the given function (i.e. do not substitute binders).
--
-- Pure counterpart of 'termSubstNamesM'.
termSubstNames
    :: (name -> Maybe (Term name uni fun ann))
    -> Term name uni fun ann
    -> Term name uni fun ann
termSubstNames = purely termSubstNamesM

-- | Applicatively substitute *free* names using the given function.
--
-- The term is traversed while collecting the uniques of the names bound by the
-- enclosing 'LamAbs' nodes. A 'Var' whose unique is in that set is kept as is;
-- for any other 'Var' the function is run on its name and the variable is
-- replaced when the result is 'Just'. A replacement term is inserted as
-- returned and is not traversed again.
termSubstFreeNamesA
    :: (Applicative f, HasUnique name TermUnique)
    => (name -> f (Maybe (Term name uni fun ann)))
    -> Term name uni fun ann
    -> f (Term name uni fun ann)
termSubstFreeNamesA f = go Set.empty where
    -- @bvs@ is the set of uniques bound by the lambdas enclosing the current node.
    go bvs var@(Var _ name)
        | (name ^. unique) `Set.member` bvs = pure var
        | otherwise                         = fromMaybe var <$> f name
    -- The binder's unique is in scope for the body only.
    go bvs (LamAbs ann name body) = LamAbs ann name <$> go (Set.insert (name ^. unique) bvs) body
    go bvs (Apply ann fun arg)    = Apply ann <$> go bvs fun <*> go bvs arg
    go bvs (Delay ann term)       = Delay ann <$> go bvs term
    go bvs (Force ann term)       = Force ann <$> go bvs term
    -- Leaves contain no variables, so they are returned untouched.
    go _   term@Constant{}        = pure term
    go _   term@Builtin{}         = pure term
    go _   term@Error{}           = pure term

-- | Substitute *free* names using the given function.
--
-- Pure counterpart of 'termSubstFreeNamesA'.
termSubstFreeNames
    :: HasUnique name TermUnique
    => (name -> Maybe (Term name uni fun ann))
    -> Term name uni fun ann
    -> Term name uni fun ann
termSubstFreeNames = purely termSubstFreeNamesA

-- | Completely replace the names with a new name type.
--
-- The function is applied to every name stored in the term, both at binders
-- ('LamAbs') and at variable occurrences ('Var'); annotations, constants and
-- builtins are carried over unchanged.
termMapNames
    :: forall name name' uni fun ann
    . (name -> name')
    -> Term name uni fun ann
    -> Term name' uni fun ann
termMapNames f = go
    where
        -- This is all a bit clunky because of the type-changing, I'm not sure of a nicer way to do it
        go :: Term name uni fun ann -> Term name' uni fun ann
        go = \case
            LamAbs ann name body -> LamAbs ann (f name) (go body)
            Var ann name         -> Var ann (f name)

            Apply ann t1 t2      -> Apply ann (go t1) (go t2)
            Delay ann t          -> Delay ann (go t)
            Force ann t          -> Force ann (go t)

            -- Name-free leaves still have to be rebuilt, since the result
            -- has a different type parameter.
            Constant ann c       -> Constant ann c
            Builtin ann b        -> Builtin ann b
            Error ann            -> Error ann

-- | Completely replace the names in a program with a new name type.
--
-- The program's annotation and version are kept; the names of the contained
-- term are mapped with 'termMapNames'.
programMapNames
    :: forall name name' uni fun ann
    . (name -> name')
    -> Program name uni fun ann
    -> Program name' uni fun ann
programMapNames f (Program a v term) = Program a v (termMapNames f term)

-- | Get all the term variables in a term.
--
-- Collects into a 'Set' every name targeted by 'termVars' in any of the
-- subterms visited by 'termSubtermsDeep'.
vTerm :: Ord name => Term name uni fun ann -> Set name
vTerm = setOf (termSubtermsDeep . termVars)

-- All uniques

-- | Get all the uniques in a term.
--
-- Collects into a 'Set' every 'Unique' targeted by 'termUniquesDeep'.
uniquesTerm :: HasUniques (Term name uni fun ann) => Term name uni fun ann -> Set Unique
uniquesTerm = setOf termUniquesDeep