{-# LANGUAGE FlexibleContexts  #-}
{-# LANGUAGE LambdaCase        #-}
{-# LANGUAGE OverloadedStrings #-}
-- | Compile non-strict bindings into strict bindings.
module PlutusIR.Transform.NonStrict (compileNonStrictBindings) where

import PlutusIR
import PlutusIR.Transform.Rename ()
import PlutusIR.Transform.Substitute

import PlutusCore.Quote
import PlutusCore.StdLib.Data.ScottUnit qualified as Unit

import Control.Lens hiding (Strict)
import Control.Monad.State

import Data.Map qualified as Map

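{- Note [Compiling non-strict bindings]
Given `let x : ty = rhs in body`, we
- Replace `let x : ty = rhs in ...` with `let x : () -> ty = \arg : () -> rhs in ...`
- Replace all references to `x` in `body` with `x ()`

This is the unit-lambda encoding. When we are not using unit, we instead delay with a trivial
type abstraction: `let x : all a. ty = /\a -> rhs in ...`, and replace all references to `x`
with `x {all a. a}`. See Note [Using unit versus force/delay].

To avoid quadratic behaviour, we do not perform the latter substitution binding by binding.
Instead we collect all the substitutions as we go, and then apply them in a single pass at
the end.

Since we are constructing a global substitution, we need globally unique
names to avoid clashes.
-}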

{- Note [Using unit versus force/delay]
We don't have force/delay in PIR, but we can use trivial type-abstractions and instantiations,
which will erase to force and delay in UPLC. Not quite as nice, but it doesn't require an extension
to the language.

However, we retain the *option* to use unit-lambdas instead, since we rely on this pass to
handle recursive, non-function bindings and give them function types. `delayed x` is not a
function type but `() -> x` is!
-}

-- | Pending substitution from term variables to the terms that should replace their
-- occurrences. This pass maps each variable whose binding was made strict to the
-- expression that forces it (see Note [Compiling non-strict bindings]).
type Substs uni fun a = Map.Map Name (Term TyName Name uni fun a)

-- | Compile all the non-strict bindings in a term into strict bindings. Note: requires globally
-- unique names.
--
-- The 'Bool' chooses how right-hand sides are delayed: 'True' wraps them in a lambda over the
-- unit type, 'False' wraps them in a trivial type abstraction.
-- See Note [Using unit versus force/delay].
compileNonStrictBindings :: MonadQuote m => Bool -> Term TyName Name uni fun a -> m (Term TyName Name uni fun a)
compileNonStrictBindings useUnit t = do
    -- First pass: make every non-strict binding strict, starting from an empty substitution
    -- and recording, for each such variable, the term that must replace its occurrences.
    (strictified, substs) <- liftQuote $ runStateT (strictifyTerm useUnit t) mempty
    -- Second pass: rewrite all occurrences at once.
    -- See Note [Compiling non-strict bindings]
    pure $ termSubstNames (`Map.lookup` substs) strictified

-- | Visit every subterm bottom-up and make each of its non-strict term bindings strict,
-- recording the required use-site substitutions in the state.
strictifyTerm
    :: (MonadState (Substs uni fun a) m, MonadQuote m)
    => Bool -> Term TyName Name uni fun a -> m (Term TyName Name uni fun a)
strictifyTerm useUnit = transformMOf termSubterms (traverseOf termBindings strictifyOne)
  where
    -- Pick the delaying strategy once for the whole traversal.
    -- See Note [Using unit versus force/delay]
    strictifyOne
        | useUnit   = strictifyBindingWithUnit
        | otherwise = strictifyBinding

-- | Make a single non-strict term binding strict by delaying its right-hand side with a
-- trivial type abstraction over a fresh type variable @a@:
--
-- > let x : ty = rhs   ~~>   let x : all a. ty = /\a -> rhs
--
-- and record in the state that every reference to @x@ must be replaced by @x {all a. a}@.
-- Any other binding is returned unchanged.
-- See Note [Compiling non-strict bindings].
strictifyBinding
    :: (MonadState (Substs uni fun a) m, MonadQuote m)
    => Binding TyName Name uni fun a -> m (Binding TyName Name uni fun a)
strictifyBinding = \case
    TermBind x NonStrict (VarDecl x' name ty) rhs -> do
        -- The annotation to use for new synthetic nodes
        let ann = x'

        -- A fresh type variable that is bound but never referenced by 'ty' or 'rhs':
        -- the abstraction exists only to delay evaluation of the right-hand side.
        a <- freshTyName "dead"
        let star = Type ann
            -- Instantiating the delayed binding (here at @all a. a@) forces it.
            forced = TyInst ann (Var ann name) (TyForall ann a star (TyVar ann a))

        -- See Note [Compiling non-strict bindings]
        -- Use the strict modify' so the substitution map is updated eagerly. The caller runs
        -- this in the lazy StateT from Control.Monad.State, where the lazy modify would
        -- accumulate a chain of unevaluated Map.insert thunks over the whole traversal.
        modify' $ Map.insert name forced

        pure $ TermBind x Strict (VarDecl x' name (TyForall ann a star ty)) (TyAbs ann a star rhs)
    x -> pure x

-- | Make a single non-strict term binding strict by delaying its right-hand side with a
-- lambda over the unit type:
--
-- > let x : ty = rhs   ~~>   let x : () -> ty = \arg : () -> rhs
--
-- and record in the state that every reference to @x@ must be replaced by @x ()@.
-- Any other binding is returned unchanged. Unlike 'strictifyBinding', this gives the
-- binding a function type. See Note [Using unit versus force/delay].
strictifyBindingWithUnit
    :: (MonadState (Substs uni fun a) m, MonadQuote m)
    => Binding TyName Name uni fun a -> m (Binding TyName Name uni fun a)
strictifyBindingWithUnit = \case
    TermBind x NonStrict (VarDecl x' name ty) rhs -> do
        -- The annotation to use for new synthetic nodes
        let ann = x'

        argName <- liftQuote $ freshName "arg"
        -- TODO: the unit type and value are rebuilt for every binding, and the unit value is
        -- copied into every use site by the substitution; we should bind them globally.
        let unit = ann <$ Unit.unit
            unitval = ann <$ Unit.unitval
            forced = Apply ann (Var ann name) unitval

        -- See Note [Compiling non-strict bindings]
        -- Use the strict modify' so the substitution map is updated eagerly. The caller runs
        -- this in the lazy StateT from Control.Monad.State, where the lazy modify would
        -- accumulate a chain of unevaluated Map.insert thunks over the whole traversal.
        modify' $ Map.insert name forced

        pure $ TermBind x Strict (VarDecl x' name (TyFun ann unit ty)) (LamAbs ann argName unit rhs)
    x -> pure x