{-# LANGUAGE LambdaCase   #-}
{-# LANGUAGE ViewPatterns #-}
{-|
A simple beta-reduction pass.
-}
module PlutusIR.Transform.Beta (
  beta
  ) where

import PlutusIR
import PlutusIR.Core.Type

import Control.Lens.Setter ((%~))
import Data.Function ((&))
import Data.List.NonEmpty qualified as NE

{- Note [Beta for types]
We can do beta on type abstractions too, turning them into type-lets. We don't do that because
a) It can lead to us inlining types too much, which can slow down compilation a lot.
b) It's currently unsound: https://input-output.atlassian.net/browse/SCP-2570.

We should fix both of these in due course, though.
-}

{- Note [Multi-beta]
Consider two examples where applying beta should be helpful.

1: [(\x . [(\y . t) b]) a]
2: [[(\x . (\y . t)) a] b]

(1) is the typical "let-binding" pattern: each binding corresponds to an immediately-applied lambda.
(2) is the typical "function application" pattern: a multi-argument function applied to multiple arguments.

In both cases we would like to produce something like

let
  x = a
  y = b
in t

However, if we naively do a bottom-up pattern-matching transformation on the AST
to look for immediately-applied lambda abstractions then we will get the following:

1:
  [(\x . [(\y . t) b]) a]
  -->
  [(\x . let y = b in t) a]
  -->
  let x = a in let y = b in t

2:
  [[(\x . (\y . t)) a] b]
  -->
  [(let x = a in (\y . t)) b]

Now, if we later lift the let out, then we will be able to see that we can transform (2) further.
But that means that a) we'd have to do the expensive let-floating pass in every iteration of the simplifier, and
b) we can only inline one function argument per iteration of the simplifier, so for a function of
arity N we *must* do at least N passes.

This isn't great, so the solution is to recognize case (2) properly and handle all the arguments in one go.
That will also match cases like (1) just fine, since it's just made up of unary function applications.

That does mean that we need to do a manual traversal rather than doing standard bottom-up processing.
-}

{-| Extract the list of bindings from a term, a bit like a "multi-beta" reduction.

The term is viewed as a head applied to a spine of arguments. The arguments are
paired off, left to right, with the chain of lambda abstractions at the head;
each pair becomes a strict term binding whose annotation is taken from the lambda.

Some examples will help:

[(\x . t) a] -> Just ([x |-> a], t)

[[[(\x . (\y . (\z . t))) a] b] c] -> Just ([x |-> a, y |-> b, z |-> c], t)

[[(\x . t) a] b] -> Nothing

The result is 'Nothing' when there are no arguments at all, or when there are more
arguments than lambdas at the head (the leftover arguments would have nowhere to go).
If there are fewer arguments than lambdas, the remaining lambdas stay in the returned body.

When we decide that we want to do beta for types, we will need to extend this to handle type instantiations too.
-}
extractBindings
    :: Term tyname name uni fun a
    -> Maybe (NE.NonEmpty (Binding tyname name uni fun a), Term tyname name uni fun a)
extractBindings = collectArgs []
  where
      -- Walk down the left spine of applications, pushing each argument onto a stack.
      -- The outermost application (the last argument) is peeled off first, so when we
      -- reach the head the first argument is on top of the stack.
      collectArgs argStack (Apply _ f arg) = collectArgs (arg:argStack) f
      collectArgs argStack t               = matchArgs argStack [] t

      -- Pair each argument with the next enclosing lambda, accumulating the
      -- resulting bindings in reverse order.
      matchArgs (arg:rest) acc (LamAbs a n ty body) =
          matchArgs rest (TermBind a Strict (VarDecl a n ty) arg : acc) body
      -- All arguments consumed: succeed iff at least one binding was produced.
      matchArgs []         acc t                    =
          (\bs -> (bs, t)) <$> NE.nonEmpty (reverse acc)
      -- Arguments remain but the head is not a lambda: not a (multi-)beta redex.
      matchArgs (_:_)      _   _                    = Nothing

{-|
Recursively apply the beta transformation on the code. Every chain of
immediately-applied lambda abstractions is turned into a non-recursive let:

@
    (\ (x : A). M) N
    ==>
    let x : A = N in M
@

and, handling all arguments of a multi-argument function at once
(see Note [Multi-beta]):

@
    (\ (x : A). \ (y : B). M) N O
    ==>
    let x : A = N; y : B = O in M
@

Type abstractions applied to type arguments are deliberately left alone;
see Note [Beta for types].

The pass is applied to every subterm of the result, including the right-hand
sides of the bindings it creates, so redexes inside arguments are reduced too.
-}
beta
    :: Term tyname name uni fun a
    -> Term tyname name uni fun a
-- Rewrite at the root first, then recurse into all subterms of the result.
-- Every recursive call is on a strict subterm of the input, so this terminates.
beta term = betaStep term & termSubterms %~ beta
  where
    -- Rewrite only the root of the term; recursion is handled by 'beta'.
    betaStep = \case
        -- See Note [Multi-beta]
        -- This maybe isn't the best annotation for this term, but it will do.
        (extractBindings -> Just (bs, t)) -> Let (termAnn t) NonRec bs t
        t                                 -> t