{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE ViewPatterns        #-}
{-# OPTIONS_GHC -fno-warn-unused-top-binds #-}
module PlutusIR.Transform.LetFloat (floatTerm) where

import Control.Arrow ((>>>))
import Control.Lens hiding (Strict)
import Control.Monad.Reader
import Control.Monad.Writer
import Data.Coerce
import Data.Foldable
import Data.List.NonEmpty qualified as NE
import Data.Map qualified as M
import Data.Map.Monoidal qualified as MM
import Data.Semigroup.Foldable
import Data.Semigroup.Generic
import Data.Set qualified as S
import Data.Set.Lens (setOf)
import GHC.Generics
import PlutusCore qualified as PLC
import PlutusCore.Builtin qualified as PLC
import PlutusCore.Name qualified as PLC
import PlutusIR
import PlutusIR.Purity
import PlutusIR.Subst

{- Note [Let Floating pass]

The goal of this pass is to move (float) let-bindings as far outwards as possible,
without breaking the scoping & meaning of the original PIR term.

This transformation (a.k.a. full laziness), together with a possible implementation
is described in Peyton Jones, Simon, Will Partain, and Andre Santos. "Let-Floating: Moving Bindings to Give Faster Programs."
In Proceedings of the First ACM SIGPLAN International Conference on Functional Programming, 1-12.
ICFP '96. New York, NY, USA: ACM, 1996. https://doi.org/10.1145/232627.232630.

An implementation, as described in the paper, is comprised of two "passes":

1) a "mark" pass to traverse the term tree and
  - in case of lam/Lam, mark this lam/Lam name with current depth, and
    increase the depth for the lam/Lam's-abstraction body term and recurse.
  - in case of a letrec group, collect the free term & type variables and mark every let-introduced name
    with the maximum depth among all the free variables (the free variables should already be marked)
  - in case of a letnonrec group, treat it the same as (letrec g in letrec gs)

2) a "float-back" pass which, given the collected marks,
   traverses the term tree again and whenever a let(rec or nonrec) is encountered,
   decides locally if it is worth to float the current let outwards at its marked depth.
   If yes, the let-group's binding is floated exactly outside a lambda abstraction that has lam_depth=let_depth+1

There are some differences from the paper's implementation described above, namely:

a) We use 3 passes. The 1st pass is similar to the original; the 2nd pass
"cleans" the term of all the to-be-floated lets and stores them in a separate table.
The 3rd pass is responsible for floating the removed lets back into the cleaned term
according to their markers. We need the extra pass because we float lets back in a global fashion,
instead of deciding locally.

b) Since the 3rd (float-back) pass operates on the cleaned term, we have lost
the original location of the lets, so we cannot float them "right outside" the **maximum-independent lambda-abstraction**,
but we float them "right inside" the maximum **dependent** lambda-abstraction's body. This has the downside
of allocating&holding the lets for longer than needed, but will not alter the meaning of the original PIR term.

c) Since PIR has strict bindings (compared to the paper's lazy-only language), we have to make
sure that any let-group containing at least one **effectful** (i.e. non-pure) strict binding is
not floated at all. See the implementation of 'hasNoEffects'.

This does not mean that such an "effectful" let
will appear in the same absolute location as the original term:
An outside/parent let may float around, changing the child's (effectful let) absolute location;
however, the child's relative location to the parent *must* remain the same. Consider this example:

`... let (nonstrict) x1= (let (strict) x2 = error in rhs1) in body1`

The parent of x2 is x1 and it is floatable
The child of x1 is x2 and it is unmovable
The x2 binding will not move with respect to x1, but its original absolute program location may change,
because the parent may float upwards somewhere else.

Since another let variable may depend on such an *effectful* let, and to preserve the execution order,
we also treat an effectful let as an "anchor", by increasing the current depth
both on entering any of its rhs'es *and* inside its inTerm.
-}

-- | The nesting depth of an anchor, as used by the marking pass.
--
-- All instances are borrowed from the underlying 'Int' via @deriving newtype@:
-- comparison is plain integer comparison, 'Show' prints the bare number,
-- and 'Num' allows integer literals (e.g. @-1@) to be used directly as depths.
newtype Depth = Depth Int
    deriving newtype (Eq, Ord, Show, Num)

{-| Position of an anchor (lam, Lam, unfloatable-let or Top).
The original paper's algorithm relies just on using the depth as the anchor's position;
for us this is not enough, because we mark/remove/float globally and the depth is not globally unique.
To fix this, we use an extra "representative" identifier ('PLC.Unique') of the anchor.
Since (unfloatable) lets can also be anchors, we also use an extra 'PosType' to differentiate
between two cases of a let-anchor, see 'PosType'.

The derived 'Ord' instance compares the fields in declaration order,
i.e. the depth is the most significant component of the comparison.
-}
data Pos = Pos
    { _posDepth  :: Depth
    , _posUnique :: PLC.Unique -- ^ The lam name or Lam tyname or Let's representative unique
    , _posType   :: PosType
    }
    deriving stock (Eq, Ord, Show)

{-| The type of the anchor's position. We only need this because
we need to differentiate between two cases of a 'let-anchor' position:

A floatable let-binding can (maximally) depend on an (unfloatable, effectful) let anchor,
which means that it will float to one of two different places, depending upon the floatable let's original location:

a) floated *next to* the let-anchor it depends upon (inside its let-group), if it originated from the rhs of the let-anchor
b) floated directly under the `in` of the let-anchor it depends upon, if it originated from the inTerm of the let-anchor.

The derived 'Ord' instance follows constructor order, so @LamBody < LetRhs@.
-}
data PosType = LamBody -- ^ lam, Lam, let body, or Top
             | LetRhs -- ^ let rhs
             deriving stock (Eq, Ord, Show)

-- | The position of the (imaginary) top-level anchor,
-- assembled from 'topDepth', 'topUnique' and 'topType'.
topPos :: Pos
topPos = Pos topDepth topUnique topType

-- | For simplicity, the top position is also linked to a unique number,
-- chosen so as not to clash with any actual uniques of names/tynames of the program.
--
-- The value is built by coercing the 'Int' @-1@ into a 'PLC.Unique'.
topUnique :: PLC.Unique
topUnique = coerce (-1 :: Int)

-- | The depth of the top position; the value @-1@ is arbitrarily chosen.
topDepth :: Depth
topDepth = -1

-- | The position type of the top position.
-- Arbitrarily chosen as 'LamBody', because top can be imagined as a global in-body (of an empty let term).
topType :: PosType
topType = LamBody

-- | Arbitrary: return a single unique among all the uniques introduced by the given let-group.
--
-- The choice is the first unique yielded by 'bindingIds' for the first binding of the group.
-- The group is a 'NE.NonEmpty' and 'bindingIds' is a 'Traversal1'', so a result always exists.
representativeBindingUnique
    :: (PLC.HasUnique name PLC.TermUnique, PLC.HasUnique tyname PLC.TypeUnique)
    => NE.NonEmpty (Binding tyname name uni fun a) -> PLC.Unique
representativeBindingUnique =
    -- Arbitrary: select the first unique from the representative binding
    first1Of bindingIds . representativeBinding
  where
    -- Arbitrary: the binding to be used as the representative when MARKING the group of bindings.
    representativeBinding :: NE.NonEmpty (Binding tyname name uni fun a) -> Binding tyname name uni fun a
    representativeBinding = NE.head

-- | Every term and type variable in the current scope, keyed by its unique
-- and paired with its own computed marker (maximum dependent position).
-- OPTIMIZE: use UniqueMap instead
type Scope = M.Map PLC.Unique Pos

-- | The reader context of the first (marking) pass:
-- the current depth, and the (term & type) variables in scope.
--
-- Lenses for both fields are generated by 'makeLenses' below.
data MarkCtx = MarkCtx { _markCtxDepth :: Depth, _markCtxScope :: Scope }
makeLenses ''MarkCtx

-- | The result of the first pass: a subset of the union of all computed scopes.
-- This subset contains only the marks of the floatable lets.
type Marks = Scope

{-|
A 'BindingGrp' is a group of bindings and a *minimum* recursivity for the group.
We use this intermediate structure when tracking groups of bindings to be floated or re-inserted.

It's convenient when doing this work to be able to combine binding groups (with the 'Semigroup' instance).
However, appending 'BindingGrp's does not account for the possibility that binding groups may *share*
variables. This means that the combination of multiple non-recursive binding groups may be recursive.
As such, if you have reason to believe that the variables used by the combined binding groups may not be disjoint,
you should manually require the term to be recursive when you convert back to a let term with 'bindingGrpToLet'.

The 'Semigroup' instance is derived generically, i.e. it combines the groups field by field
(and therefore requires a 'Semigroup' instance for the annotation type).
-}
data BindingGrp tyname name uni fun a = BindingGrp {
    _bgAnn      :: a,
    _bgRec      :: Recursivity,
    _bgBindings :: NE.NonEmpty (Binding tyname name uni fun a)
    }
    deriving stock Generic
    deriving Semigroup via (GenericSemigroupMonoid (BindingGrp tyname name uni fun a))
-- Note on Semigroup: appending binding groups will not try to fix the well-scopedness by
-- rearranging any bindings, or by promoting to a Rec in case some bindings refer to each other.
makeLenses ''BindingGrp

-- | Turn a 'BindingGrp' into a let, when given a minimum recursivity and a let body.
--
-- The annotation and bindings of the group are used as-is. The recursivity of the
-- resulting 'Let' is the 'Semigroup' combination of the requested minimum recursivity
-- with the group's own recursivity.
-- NOTE(review): presumably the combination is Rec whenever either side is Rec —
-- confirm against the 'Semigroup' instance of 'Recursivity'.
bindingGrpToLet :: Recursivity
        -> BindingGrp tyname name uni fun a
        -> (Term tyname name uni fun a -> Term tyname name uni fun a)
bindingGrpToLet r (BindingGrp a r' bs) = Let a (r <> r') bs

-- | A store of lets to be floated at their new position:
-- each 'Pos' key maps to the non-empty list of binding groups associated with that position.
type FloatTable tyname name uni fun a = MM.MonoidalMap Pos (NE.NonEmpty (BindingGrp tyname name uni fun a))

-- | The 1st pass of marking floatable lets
mark :: forall tyname name uni fun a.
      (Ord tyname, Ord name, PLC.HasUnique tyname PLC.TypeUnique, PLC.HasUnique name PLC.TermUnique, PLC.ToBuiltinMeaning uni fun)
     => Term tyname name uni fun a
     -> Marks
mark :: Term tyname name uni fun a -> Scope
mark = ((), Scope) -> Scope
forall a b. (a, b) -> b
snd (((), Scope) -> Scope)
-> (Term tyname name uni fun a -> ((), Scope))
-> Term tyname name uni fun a
-> Scope
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Writer Scope () -> ((), Scope)
forall w a. Writer w a -> (a, w)
runWriter (Writer Scope () -> ((), Scope))
-> (Term tyname name uni fun a -> Writer Scope ())
-> Term tyname name uni fun a
-> ((), Scope)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (ReaderT MarkCtx (Writer Scope) () -> MarkCtx -> Writer Scope ())
-> MarkCtx -> ReaderT MarkCtx (Writer Scope) () -> Writer Scope ()
forall a b c. (a -> b -> c) -> b -> a -> c
flip ReaderT MarkCtx (Writer Scope) () -> MarkCtx -> Writer Scope ()
forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
runReaderT (Depth -> Scope -> MarkCtx
MarkCtx Depth
topDepth Scope
forall a. Monoid a => a
mempty) (ReaderT MarkCtx (Writer Scope) () -> Writer Scope ())
-> (Term tyname name uni fun a
    -> ReaderT MarkCtx (Writer Scope) ())
-> Term tyname name uni fun a
-> Writer Scope ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Term tyname name uni fun a -> ReaderT MarkCtx (Writer Scope) ()
go
  where
    go :: Term tyname name uni fun a -> ReaderT MarkCtx (Writer Marks) ()
    go :: Term tyname name uni fun a -> ReaderT MarkCtx (Writer Scope) ()
go = Term tyname name uni fun a -> Term tyname name uni fun a
forall tyname name (uni :: * -> *) fun a.
Term tyname name uni fun a -> Term tyname name uni fun a
breakNonRec (Term tyname name uni fun a -> Term tyname name uni fun a)
-> (Term tyname name uni fun a
    -> ReaderT MarkCtx (Writer Scope) ())
-> Term tyname name uni fun a
-> ReaderT MarkCtx (Writer Scope) ()
forall k (cat :: k -> k -> *) (a :: k) (b :: k) (c :: k).
Category cat =>
cat a b -> cat b c -> cat a c
>>> \case
        -- lam/Lam are treated the same.
        LamAbs a
_ name
n Type tyname uni a
_ Term tyname name uni fun a
tBody  -> name
-> ReaderT MarkCtx (Writer Scope) ()
-> ReaderT MarkCtx (Writer Scope) ()
forall r (m :: * -> *) name unique a.
(r ~ MarkCtx, MonadReader r m, HasUnique name unique) =>
name -> m a -> m a
withLam name
n (ReaderT MarkCtx (Writer Scope) ()
 -> ReaderT MarkCtx (Writer Scope) ())
-> ReaderT MarkCtx (Writer Scope) ()
-> ReaderT MarkCtx (Writer Scope) ()
forall a b. (a -> b) -> a -> b
$ Term tyname name uni fun a -> ReaderT MarkCtx (Writer Scope) ()
go Term tyname name uni fun a
tBody
        TyAbs a
_ tyname
n Kind a
_ Term tyname name uni fun a
tBody   -> tyname
-> ReaderT MarkCtx (Writer Scope) ()
-> ReaderT MarkCtx (Writer Scope) ()
forall r (m :: * -> *) name unique a.
(r ~ MarkCtx, MonadReader r m, HasUnique name unique) =>
name -> m a -> m a
withLam tyname
n (ReaderT MarkCtx (Writer Scope) ()
 -> ReaderT MarkCtx (Writer Scope) ())
-> ReaderT MarkCtx (Writer Scope) ()
-> ReaderT MarkCtx (Writer Scope) ()
forall a b. (a -> b) -> a -> b
$ Term tyname name uni fun a -> ReaderT MarkCtx (Writer Scope) ()
go Term tyname name uni fun a
tBody

        -- main operation: for letrec or single letnonrec
        Let a
ann Recursivity
r bs :: NonEmpty (Binding tyname name uni fun a)
bs@(NonEmpty (Binding tyname name uni fun a) -> Unique
forall name tyname (uni :: * -> *) fun a.
(HasUnique name TermUnique, HasUnique tyname TypeUnique) =>
NonEmpty (Binding tyname name uni fun a) -> Unique
representativeBindingUnique -> Unique
letU) Term tyname name uni fun a
tIn ->
          let letN :: BindingGrp tyname name uni fun a
letN = a
-> Recursivity
-> NonEmpty (Binding tyname name uni fun a)
-> BindingGrp tyname name uni fun a
forall tyname name (uni :: * -> *) fun a.
a
-> Recursivity
-> NonEmpty (Binding tyname name uni fun a)
-> BindingGrp tyname name uni fun a
BindingGrp a
ann Recursivity
r NonEmpty (Binding tyname name uni fun a)
bs in
          if BindingGrp tyname name uni fun a -> Bool
forall (uni :: * -> *) fun tyname name a.
ToBuiltinMeaning uni fun =>
BindingGrp tyname name uni fun a -> Bool
floatable BindingGrp tyname name uni fun a
letN
          then do
            Scope
scope <- (MarkCtx -> Scope) -> ReaderT MarkCtx (Writer Scope) Scope
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks MarkCtx -> Scope
_markCtxScope
            let freeVars :: Set Unique
freeVars =
                    -- if Rec, remove the here-bindings from free
                    Recursivity
-> (Set Unique -> Set Unique) -> Set Unique -> Set Unique
forall a. Recursivity -> (a -> a) -> a -> a
ifRec Recursivity
r (Set Unique -> Set Unique -> Set Unique
forall a. Ord a => Set a -> Set a -> Set a
S.\\ Getting
  (Set Unique) (NonEmpty (Binding tyname name uni fun a)) Unique
-> NonEmpty (Binding tyname name uni fun a) -> Set Unique
forall a s. Getting (Set a) s a -> s -> Set a
setOf ((Binding tyname name uni fun a
 -> Const (Set Unique) (Binding tyname name uni fun a))
-> NonEmpty (Binding tyname name uni fun a)
-> Const (Set Unique) (NonEmpty (Binding tyname name uni fun a))
forall (f :: * -> *) a b.
Traversable f =>
IndexedTraversal Int (f a) (f b) a b
traversed((Binding tyname name uni fun a
  -> Const (Set Unique) (Binding tyname name uni fun a))
 -> NonEmpty (Binding tyname name uni fun a)
 -> Const (Set Unique) (NonEmpty (Binding tyname name uni fun a)))
-> ((Unique -> Const (Set Unique) Unique)
    -> Binding tyname name uni fun a
    -> Const (Set Unique) (Binding tyname name uni fun a))
-> Getting
     (Set Unique) (NonEmpty (Binding tyname name uni fun a)) Unique
forall b c a. (b -> c) -> (a -> b) -> a -> c
.(Unique -> Const (Set Unique) Unique)
-> Binding tyname name uni fun a
-> Const (Set Unique) (Binding tyname name uni fun a)
forall tyname name (uni :: * -> *) fun a.
(HasUnique tyname TypeUnique, HasUnique name TermUnique) =>
Traversal1' (Binding tyname name uni fun a) Unique
bindingIds) NonEmpty (Binding tyname name uni fun a)
bs) (Set Unique -> Set Unique) -> Set Unique -> Set Unique
forall a b. (a -> b) -> a -> b
$
                       BindingGrp tyname name uni fun a -> Set Unique
forall tyname name (uni :: * -> *) fun a.
(Ord tyname, Ord name, HasUnique tyname TypeUnique,
 HasUnique name TermUnique) =>
BindingGrp tyname name uni fun a -> Set Unique
calcFreeVars BindingGrp tyname name uni fun a
letN

            -- The "heart" of the algorithm: the future position to float this let to
            -- is determined as the maximum among its dependencies (free vars).
            let floatPos :: Pos
floatPos@(Pos Depth
floatDepth Unique
_ PosType
_) = Scope -> Pos
forall k. Map k Pos -> Pos
maxPos (Scope -> Pos) -> Scope -> Pos
forall a b. (a -> b) -> a -> b
$ Scope
scope Scope -> Set Unique -> Scope
forall k a. Ord k => Map k a -> Set k -> Map k a
`M.restrictKeys` Set Unique
freeVars

            -- visit the rhs'es
            -- IMPORTANT: inside the rhs, act like the current depth
            -- is the future floated depth of this rhs.
            (Depth -> Depth)
-> ReaderT MarkCtx (Writer Scope) ()
-> ReaderT MarkCtx (Writer Scope) ()
forall r (m :: * -> *) a.
(r ~ MarkCtx, MonadReader r m) =>
(Depth -> Depth) -> m a -> m a
withDepth (Depth -> Depth -> Depth
forall a b. a -> b -> a
const Depth
floatDepth) (ReaderT MarkCtx (Writer Scope) ()
 -> ReaderT MarkCtx (Writer Scope) ())
-> ReaderT MarkCtx (Writer Scope) ()
-> ReaderT MarkCtx (Writer Scope) ()
forall a b. (a -> b) -> a -> b
$
                -- if rec, then its bindings are in scope in the rhs'es
                Recursivity
-> (ReaderT MarkCtx (Writer Scope) ()
    -> ReaderT MarkCtx (Writer Scope) ())
-> ReaderT MarkCtx (Writer Scope) ()
-> ReaderT MarkCtx (Writer Scope) ()
forall a. Recursivity -> (a -> a) -> a -> a
ifRec Recursivity
r (NonEmpty (Binding tyname name uni fun a)
-> Pos
-> ReaderT MarkCtx (Writer Scope) ()
-> ReaderT MarkCtx (Writer Scope) ()
forall r (m :: * -> *) name tyname (uni :: * -> *) fun a3 a.
(r ~ MarkCtx, MonadReader r m, HasUnique name TermUnique,
 HasUnique tyname TypeUnique) =>
NonEmpty (Binding tyname name uni fun a3) -> Pos -> m a -> m a
withBs NonEmpty (Binding tyname name uni fun a)
bs Pos
floatPos) (ReaderT MarkCtx (Writer Scope) ()
 -> ReaderT MarkCtx (Writer Scope) ())
-> ReaderT MarkCtx (Writer Scope) ()
-> ReaderT MarkCtx (Writer Scope) ()
forall a b. (a -> b) -> a -> b
$
                    (Term tyname name uni fun a -> ReaderT MarkCtx (Writer Scope) ())
-> [Term tyname name uni fun a]
-> ReaderT MarkCtx (Writer Scope) ()
forall (t :: * -> *) (f :: * -> *) a b.
(Foldable t, Applicative f) =>
(a -> f b) -> t a -> f ()
traverse_ Term tyname name uni fun a -> ReaderT MarkCtx (Writer Scope) ()
go (NonEmpty (Binding tyname name uni fun a)
bsNonEmpty (Binding tyname name uni fun a)
-> Getting
     (Endo [Term tyname name uni fun a])
     (NonEmpty (Binding tyname name uni fun a))
     (Term tyname name uni fun a)
-> [Term tyname name uni fun a]
forall s a. s -> Getting (Endo [a]) s a -> [a]
^..(Binding tyname name uni fun a
 -> Const
      (Endo [Term tyname name uni fun a])
      (Binding tyname name uni fun a))
-> NonEmpty (Binding tyname name uni fun a)
-> Const
     (Endo [Term tyname name uni fun a])
     (NonEmpty (Binding tyname name uni fun a))
forall (f :: * -> *) a b.
Traversable f =>
IndexedTraversal Int (f a) (f b) a b
traversed((Binding tyname name uni fun a
  -> Const
       (Endo [Term tyname name uni fun a])
       (Binding tyname name uni fun a))
 -> NonEmpty (Binding tyname name uni fun a)
 -> Const
      (Endo [Term tyname name uni fun a])
      (NonEmpty (Binding tyname name uni fun a)))
-> ((Term tyname name uni fun a
     -> Const
          (Endo [Term tyname name uni fun a]) (Term tyname name uni fun a))
    -> Binding tyname name uni fun a
    -> Const
         (Endo [Term tyname name uni fun a])
         (Binding tyname name uni fun a))
-> Getting
     (Endo [Term tyname name uni fun a])
     (NonEmpty (Binding tyname name uni fun a))
     (Term tyname name uni fun a)
forall b c a. (b -> c) -> (a -> b) -> a -> c
.(Term tyname name uni fun a
 -> Const
      (Endo [Term tyname name uni fun a]) (Term tyname name uni fun a))
-> Binding tyname name uni fun a
-> Const
     (Endo [Term tyname name uni fun a]) (Binding tyname name uni fun a)
forall tyname name (uni :: * -> *) fun a.
Traversal'
  (Binding tyname name uni fun a) (Term tyname name uni fun a)
bindingSubterms)

            -- visit the inTerm
            -- bindings are inscope in the InTerm for both rec&nonrec
            NonEmpty (Binding tyname name uni fun a)
-> Pos
-> ReaderT MarkCtx (Writer Scope) ()
-> ReaderT MarkCtx (Writer Scope) ()
forall r (m :: * -> *) name tyname (uni :: * -> *) fun a3 a.
(r ~ MarkCtx, MonadReader r m, HasUnique name TermUnique,
 HasUnique tyname TypeUnique) =>
NonEmpty (Binding tyname name uni fun a3) -> Pos -> m a -> m a
withBs NonEmpty (Binding tyname name uni fun a)
bs Pos
floatPos (ReaderT MarkCtx (Writer Scope) ()
 -> ReaderT MarkCtx (Writer Scope) ())
-> ReaderT MarkCtx (Writer Scope) ()
-> ReaderT MarkCtx (Writer Scope) ()
forall a b. (a -> b) -> a -> b
$ Term tyname name uni fun a -> ReaderT MarkCtx (Writer Scope) ()
go Term tyname name uni fun a
tIn

            -- collect here the new mark and propagate all
            Scope -> ReaderT MarkCtx (Writer Scope) ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell (Scope -> ReaderT MarkCtx (Writer Scope) ())
-> Scope -> ReaderT MarkCtx (Writer Scope) ()
forall a b. (a -> b) -> a -> b
$ Unique -> Pos -> Scope
forall k a. k -> a -> Map k a
M.singleton Unique
letU Pos
floatPos
          else do
            -- since it is unfloatable (effectful), this let is a new anchor
            -- acts as anchor both in rhs'es and inTerm
            (Depth -> Depth)
-> ReaderT MarkCtx (Writer Scope) ()
-> ReaderT MarkCtx (Writer Scope) ()
forall r (m :: * -> *) a.
(r ~ MarkCtx, MonadReader r m) =>
(Depth -> Depth) -> m a -> m a
withDepth (Depth -> Depth -> Depth
forall a. Num a => a -> a -> a
+Depth
1) (ReaderT MarkCtx (Writer Scope) ()
 -> ReaderT MarkCtx (Writer Scope) ())
-> ReaderT MarkCtx (Writer Scope) ()
-> ReaderT MarkCtx (Writer Scope) ()
forall a b. (a -> b) -> a -> b
$ do
                Depth
depth <- (MarkCtx -> Depth) -> ReaderT MarkCtx (Writer Scope) Depth
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks MarkCtx -> Depth
_markCtxDepth
                let toPos :: PosType -> Pos
toPos = Depth -> Unique -> PosType -> Pos
Pos Depth
depth Unique
letU
                -- visit the rhs'es
                -- if rec, then its bindings are in scope in the rhs'es
                Recursivity
-> (ReaderT MarkCtx (Writer Scope) ()
    -> ReaderT MarkCtx (Writer Scope) ())
-> ReaderT MarkCtx (Writer Scope) ()
-> ReaderT MarkCtx (Writer Scope) ()
forall a. Recursivity -> (a -> a) -> a -> a
ifRec Recursivity
r (NonEmpty (Binding tyname name uni fun a)
-> Pos
-> ReaderT MarkCtx (Writer Scope) ()
-> ReaderT MarkCtx (Writer Scope) ()
forall r (m :: * -> *) name tyname (uni :: * -> *) fun a3 a.
(r ~ MarkCtx, MonadReader r m, HasUnique name TermUnique,
 HasUnique tyname TypeUnique) =>
NonEmpty (Binding tyname name uni fun a3) -> Pos -> m a -> m a
withBs NonEmpty (Binding tyname name uni fun a)
bs (Pos
 -> ReaderT MarkCtx (Writer Scope) ()
 -> ReaderT MarkCtx (Writer Scope) ())
-> Pos
-> ReaderT MarkCtx (Writer Scope) ()
-> ReaderT MarkCtx (Writer Scope) ()
forall a b. (a -> b) -> a -> b
$ PosType -> Pos
toPos PosType
LetRhs) (ReaderT MarkCtx (Writer Scope) ()
 -> ReaderT MarkCtx (Writer Scope) ())
-> ReaderT MarkCtx (Writer Scope) ()
-> ReaderT MarkCtx (Writer Scope) ()
forall a b. (a -> b) -> a -> b
$ (Term tyname name uni fun a -> ReaderT MarkCtx (Writer Scope) ())
-> [Term tyname name uni fun a]
-> ReaderT MarkCtx (Writer Scope) ()
forall (t :: * -> *) (f :: * -> *) a b.
(Foldable t, Applicative f) =>
(a -> f b) -> t a -> f ()
traverse_ Term tyname name uni fun a -> ReaderT MarkCtx (Writer Scope) ()
go (NonEmpty (Binding tyname name uni fun a)
bsNonEmpty (Binding tyname name uni fun a)
-> Getting
     (Endo [Term tyname name uni fun a])
     (NonEmpty (Binding tyname name uni fun a))
     (Term tyname name uni fun a)
-> [Term tyname name uni fun a]
forall s a. s -> Getting (Endo [a]) s a -> [a]
^..(Binding tyname name uni fun a
 -> Const
      (Endo [Term tyname name uni fun a])
      (Binding tyname name uni fun a))
-> NonEmpty (Binding tyname name uni fun a)
-> Const
     (Endo [Term tyname name uni fun a])
     (NonEmpty (Binding tyname name uni fun a))
forall (f :: * -> *) a b.
Traversable f =>
IndexedTraversal Int (f a) (f b) a b
traversed((Binding tyname name uni fun a
  -> Const
       (Endo [Term tyname name uni fun a])
       (Binding tyname name uni fun a))
 -> NonEmpty (Binding tyname name uni fun a)
 -> Const
      (Endo [Term tyname name uni fun a])
      (NonEmpty (Binding tyname name uni fun a)))
-> ((Term tyname name uni fun a
     -> Const
          (Endo [Term tyname name uni fun a]) (Term tyname name uni fun a))
    -> Binding tyname name uni fun a
    -> Const
         (Endo [Term tyname name uni fun a])
         (Binding tyname name uni fun a))
-> Getting
     (Endo [Term tyname name uni fun a])
     (NonEmpty (Binding tyname name uni fun a))
     (Term tyname name uni fun a)
forall b c a. (b -> c) -> (a -> b) -> a -> c
.(Term tyname name uni fun a
 -> Const
      (Endo [Term tyname name uni fun a]) (Term tyname name uni fun a))
-> Binding tyname name uni fun a
-> Const
     (Endo [Term tyname name uni fun a]) (Binding tyname name uni fun a)
forall tyname name (uni :: * -> *) fun a.
Traversal'
  (Binding tyname name uni fun a) (Term tyname name uni fun a)
bindingSubterms)

                -- bindings are inscope in the InTerm for both rec&nonrec
                NonEmpty (Binding tyname name uni fun a)
-> Pos
-> ReaderT MarkCtx (Writer Scope) ()
-> ReaderT MarkCtx (Writer Scope) ()
forall r (m :: * -> *) name tyname (uni :: * -> *) fun a3 a.
(r ~ MarkCtx, MonadReader r m, HasUnique name TermUnique,
 HasUnique tyname TypeUnique) =>
NonEmpty (Binding tyname name uni fun a3) -> Pos -> m a -> m a
withBs NonEmpty (Binding tyname name uni fun a)
bs (PosType -> Pos
toPos PosType
LamBody) (ReaderT MarkCtx (Writer Scope) ()
 -> ReaderT MarkCtx (Writer Scope) ())
-> ReaderT MarkCtx (Writer Scope) ()
-> ReaderT MarkCtx (Writer Scope) ()
forall a b. (a -> b) -> a -> b
$ Term tyname name uni fun a -> ReaderT MarkCtx (Writer Scope) ()
go Term tyname name uni fun a
tIn

        -- descend and collect
        Term tyname name uni fun a
t -> Getting
  (Traversed () (ReaderT MarkCtx (Writer Scope)))
  (Term tyname name uni fun a)
  (Term tyname name uni fun a)
-> (Term tyname name uni fun a
    -> ReaderT MarkCtx (Writer Scope) ())
-> Term tyname name uni fun a
-> ReaderT MarkCtx (Writer Scope) ()
forall (f :: * -> *) r s a.
Functor f =>
Getting (Traversed r f) s a -> (a -> f r) -> s -> f ()
traverseOf_ Getting
  (Traversed () (ReaderT MarkCtx (Writer Scope)))
  (Term tyname name uni fun a)
  (Term tyname name uni fun a)
forall tyname name (uni :: * -> *) fun a.
Traversal'
  (Term tyname name uni fun a) (Term tyname name uni fun a)
termSubterms Term tyname name uni fun a -> ReaderT MarkCtx (Writer Scope) ()
go Term tyname name uni fun a
t

-- | Given a 'BindingGrp', calculate its free vars and free tyvars and collect them in a set.
--
-- The result is the union, over every binding of the group, of
--
--   * the uniques of the binding's free term variables (as reported by 'fvBinding'), and
--   * the uniques of the binding's free type variables (as reported by 'ftvBinding',
--     which is additionally given the recursivity of the group).
--
-- Term and type uniques end up in one and the same set of 'PLC.Unique'.
calcFreeVars :: forall tyname name uni fun a.
             (Ord tyname, Ord name, PLC.HasUnique tyname PLC.TypeUnique, PLC.HasUnique name PLC.TermUnique)
             => BindingGrp tyname name uni fun a
             -> S.Set PLC.Unique
calcFreeVars (BindingGrp _ r bs) = foldMap1 calcBinding bs
  where
    -- given a binding return all its free term *AND* free type variables
    calcBinding :: Binding tyname name uni fun a -> S.Set PLC.Unique
    calcBinding b =
        -- OPTIMIZE: safe to change to S.mapMonotonic?
        S.map (^.PLC.theUnique) (fvBinding b)
        <> S.map (^.PLC.theUnique) (ftvBinding r b)

-- | The second pass of cleaning the term of the floatable lets, and placing them in a separate map
--
-- Every @let@ whose representative unique (see 'representativeBindingUnique') is a key of
-- the given marks is removed from the term: the let is replaced by its (already cleaned)
-- body, and its binding group is recorded in the returned 'FloatTable' under the position
-- the mark points to. A @let@ that has no mark is kept in place, rebuilt from its cleaned
-- right-hand sides and cleaned body.
--
-- Returns the cleaned term paired with the table of removed binding groups.
--
-- OPTIMIZE: use State for building the FloatTable, and for reducing the Marks
removeLets :: forall tyname name uni fun a term.
            (term~Term tyname name uni fun a
            ,PLC.HasUnique tyname PLC.TypeUnique, PLC.HasUnique name PLC.TermUnique)
           => Marks
           -> term
           -> (term, FloatTable tyname name uni fun a)
removeLets marks term = runWriter $ go term
  where
    -- TODO: use State for the Marks to safeguard against any bugs where floatable lets are not removed as they should be.
    --
    -- Every visited term is first passed through 'breakNonRec' and only then inspected.
    -- NOTE(review): presumably 'breakNonRec' splits a multi-binding nonrec let into nested
    -- single-binding lets, so that one representative unique suffices per group -- confirm.
    go :: term -> Writer (FloatTable tyname name uni fun a) term
    go = breakNonRec >>> \case
        -- main operation: for letrec or single letnonrec
        Let a r bs@(representativeBindingUnique -> letU) tIn -> do
            -- go to rhs'es and collect their floattable + cleanedterm
            bs' <- traverse goBinding bs
            -- go to inTerm and collect its floattable + cleanedterm
            tIn' <- go tIn
            case M.lookup letU marks of
                -- this is not a floatable let
                Nothing  -> pure $ Let a r bs' tIn'
                -- floatable let found.
                -- move this let to the floattable, and just return the body
                Just pos -> do
                    tell (MM.singleton pos (pure $ BindingGrp a r bs'))
                    pure tIn'

        -- descend and collect
        Apply a t1 t2 -> Apply a <$> go t1 <*> go t2
        TyInst a t ty -> TyInst a <$> go t <*> pure ty
        TyAbs a tyname k t -> TyAbs a tyname k <$> go t
        LamAbs a name ty t -> LamAbs a name ty <$> go t
        IWrap a ty1 ty2 t -> IWrap a ty1 ty2 <$> go t
        Unwrap a t -> Unwrap a <$> go t

        -- no term inside here, nothing to do
        t@Var{} -> pure t
        t@Constant{} -> pure t
        t@Builtin{} -> pure t
        t@Error{} -> pure t

    -- Clean the right-hand side of a single binding; only term bindings contain a term.
    goBinding :: Binding tyname name uni fun a
              -> Writer (FloatTable tyname name uni fun a) (Binding tyname name uni fun a)
    goBinding = \case
        TermBind x s d t -> TermBind x s d <$> go t
        -- no term inside here, nothing to do
        b@TypeBind{}     -> pure b
        b@DatatypeBind{} -> pure b

-- | The 3rd and last pass that, given the result of 'removeLets', places the lets back (floats) at the right marked positions.
floatBackLets :: forall tyname name uni fun a term m.
                ( term~Term tyname name uni fun a
                , m~Reader Depth
                , PLC.HasUnique tyname PLC.TypeUnique, PLC.HasUnique name PLC.TermUnique, Semigroup a)
              => term -- ^ the cleanedup, reducted term
              -> FloatTable tyname name uni fun a -- ^ the lets to be floated
              -> term -- ^ the final, floated, and correctly-scoped term
floatBackLets :: term -> FloatTable tyname name uni fun a -> term
floatBackLets term
term FloatTable tyname name uni fun a
fTable =
    -- our reader context is only the depth this time.
    (Reader Depth term -> Depth -> term)
-> Depth -> Reader Depth term -> term
forall a b c. (a -> b -> c) -> b -> a -> c
flip Reader Depth term -> Depth -> term
forall r a. Reader r a -> r -> a
runReader Depth
topDepth (Reader Depth term -> term) -> Reader Depth term -> term
forall a b. (a -> b) -> a -> b
$ term -> m term
goTop term
term
  where

    -- TODO: use State for FloatTable to safeguard against any bugs where floatable-lets were not floated as they should to.
    goTop, go :: term -> m term

    -- after traversing the cleaned term, try to float the lets that are destined for top (global lets)
    goTop :: term -> m term
goTop = Unique -> term -> m term
floatLam Unique
topUnique (term -> m term) -> (term -> m term) -> term -> m term
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< term -> m term
go

    go :: term -> m term
go = \case
        -- lam anchor, increase depth & try to float inside the lam's body
        LamAbs a n ty tBody -> (Depth -> Depth)
-> m (Term tyname name uni fun a) -> m (Term tyname name uni fun a)
forall r (m :: * -> *) a. MonadReader r m => (r -> r) -> m a -> m a
local (Depth -> Depth -> Depth
forall a. Num a => a -> a -> a
+Depth
1) (m (Term tyname name uni fun a) -> m (Term tyname name uni fun a))
-> m (Term tyname name uni fun a) -> m (Term tyname name uni fun a)
forall a b. (a -> b) -> a -> b
$
            a
-> name
-> Type tyname uni a
-> Term tyname name uni fun a
-> Term tyname name uni fun a
forall tyname name (uni :: * -> *) fun a.
a
-> name
-> Type tyname uni a
-> Term tyname name uni fun a
-> Term tyname name uni fun a
LamAbs a
a name
n Type tyname uni a
ty (Term tyname name uni fun a -> Term tyname name uni fun a)
-> m (Term tyname name uni fun a) -> m (Term tyname name uni fun a)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Unique -> term -> m term
floatLam (name
nname -> Getting Unique name Unique -> Unique
forall s a. s -> Getting a s a -> a
^.Getting Unique name Unique
forall name unique. HasUnique name unique => Lens' name Unique
PLC.theUnique) (term -> m term) -> m term -> m term
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< term -> m term
go term
Term tyname name uni fun a
tBody)
        -- Lam anchor, increase depth & try to float inside the Lam's body
        TyAbs a n k tBody -> (Depth -> Depth)
-> m (Term tyname name uni fun a) -> m (Term tyname name uni fun a)
forall r (m :: * -> *) a. MonadReader r m => (r -> r) -> m a -> m a
local (Depth -> Depth -> Depth
forall a. Num a => a -> a -> a
+Depth
1) (m (Term tyname name uni fun a) -> m (Term tyname name uni fun a))
-> m (Term tyname name uni fun a) -> m (Term tyname name uni fun a)
forall a b. (a -> b) -> a -> b
$
            a
-> tyname
-> Kind a
-> Term tyname name uni fun a
-> Term tyname name uni fun a
forall tyname name (uni :: * -> *) fun a.
a
-> tyname
-> Kind a
-> Term tyname name uni fun a
-> Term tyname name uni fun a
TyAbs a
a tyname
n Kind a
k (Term tyname name uni fun a -> Term tyname name uni fun a)
-> m (Term tyname name uni fun a) -> m (Term tyname name uni fun a)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Unique -> term -> m term
floatLam (tyname
ntyname -> Getting Unique tyname Unique -> Unique
forall s a. s -> Getting a s a -> a
^.Getting Unique tyname Unique
forall name unique. HasUnique name unique => Lens' name Unique
PLC.theUnique) (term -> m term) -> m term -> m term
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< term -> m term
go term
Term tyname name uni fun a
tBody)
        -- Unfloatable-let anchor, increase depth
        Let a r bs@(representativeBindingUnique -> letU) tIn -> (Depth -> Depth)
-> m (Term tyname name uni fun a) -> m (Term tyname name uni fun a)
forall r (m :: * -> *) a. MonadReader r m => (r -> r) -> m a -> m a
local (Depth -> Depth -> Depth
forall a. Num a => a -> a -> a
+Depth
1) (m (Term tyname name uni fun a) -> m (Term tyname name uni fun a))
-> m (Term tyname name uni fun a) -> m (Term tyname name uni fun a)
forall a b. (a -> b) -> a -> b
$ do
            -- note that we do not touch the original recursivity of the unfloatable-let
            BindingGrp tyname name uni fun a
unfloatableGrp <- a
-> Recursivity
-> NonEmpty (Binding tyname name uni fun a)
-> BindingGrp tyname name uni fun a
forall tyname name (uni :: * -> *) fun a.
a
-> Recursivity
-> NonEmpty (Binding tyname name uni fun a)
-> BindingGrp tyname name uni fun a
BindingGrp a
a Recursivity
r (NonEmpty (Binding tyname name uni fun a)
 -> BindingGrp tyname name uni fun a)
-> m (NonEmpty (Binding tyname name uni fun a))
-> m (BindingGrp tyname name uni fun a)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> LensLike
  m
  (NonEmpty (Binding tyname name uni fun a))
  (NonEmpty (Binding tyname name uni fun a))
  (Term tyname name uni fun a)
  (Term tyname name uni fun a)
-> LensLike
     m
     (NonEmpty (Binding tyname name uni fun a))
     (NonEmpty (Binding tyname name uni fun a))
     (Term tyname name uni fun a)
     (Term tyname name uni fun a)
forall (f :: * -> *) s t a b.
LensLike f s t a b -> LensLike f s t a b
traverseOf ((Binding tyname name uni fun a
 -> m (Binding tyname name uni fun a))
-> NonEmpty (Binding tyname name uni fun a)
-> m (NonEmpty (Binding tyname name uni fun a))
forall (f :: * -> *) a b.
Traversable f =>
IndexedTraversal Int (f a) (f b) a b
traversed((Binding tyname name uni fun a
  -> m (Binding tyname name uni fun a))
 -> NonEmpty (Binding tyname name uni fun a)
 -> m (NonEmpty (Binding tyname name uni fun a)))
-> ((Term tyname name uni fun a -> m (Term tyname name uni fun a))
    -> Binding tyname name uni fun a
    -> m (Binding tyname name uni fun a))
-> LensLike
     m
     (NonEmpty (Binding tyname name uni fun a))
     (NonEmpty (Binding tyname name uni fun a))
     (Term tyname name uni fun a)
     (Term tyname name uni fun a)
forall b c a. (b -> c) -> (a -> b) -> a -> c
.(Term tyname name uni fun a -> m (Term tyname name uni fun a))
-> Binding tyname name uni fun a
-> m (Binding tyname name uni fun a)
forall tyname name (uni :: * -> *) fun a.
Traversal'
  (Binding tyname name uni fun a) (Term tyname name uni fun a)
bindingSubterms) term -> m term
Term tyname name uni fun a -> m (Term tyname name uni fun a)
go NonEmpty (Binding tyname name uni fun a)
bs
            -- rebuild the let-group (we take the minimum bound, i.e. NonRec)
            Recursivity
-> BindingGrp tyname name uni fun a
-> Term tyname name uni fun a
-> Term tyname name uni fun a
forall tyname name (uni :: * -> *) fun a.
Recursivity
-> BindingGrp tyname name uni fun a
-> Term tyname name uni fun a
-> Term tyname name uni fun a
bindingGrpToLet Recursivity
NonRec
              (BindingGrp tyname name uni fun a
 -> Term tyname name uni fun a -> Term tyname name uni fun a)
-> m (BindingGrp tyname name uni fun a)
-> m (Term tyname name uni fun a -> Term tyname name uni fun a)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> -- float inside the rhs of the unfloatable group, and merge the bindings
                  Unique
-> BindingGrp tyname name uni fun a
-> m (BindingGrp tyname name uni fun a)
forall grp.
(grp ~ BindingGrp tyname name uni fun a) =>
Unique -> grp -> m grp
floatRhs Unique
letU BindingGrp tyname name uni fun a
unfloatableGrp
                  -- float right inside the inTerm (similar to lam/Lam)
              m (Term tyname name uni fun a -> Term tyname name uni fun a)
-> m (Term tyname name uni fun a) -> m (Term tyname name uni fun a)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (Unique -> term -> m term
floatLam Unique
letU (term -> m term) -> m term -> m term
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< term -> m term
go term
Term tyname name uni fun a
tIn)

        -- descend
        term
t                  -> term
t term
-> (term -> m (Term tyname name uni fun a))
-> m (Term tyname name uni fun a)
forall a b. a -> (a -> b) -> b
& (Term tyname name uni fun a -> m (Term tyname name uni fun a))
-> Term tyname name uni fun a -> m (Term tyname name uni fun a)
forall tyname name (uni :: * -> *) fun a.
Traversal'
  (Term tyname name uni fun a) (Term tyname name uni fun a)
termSubterms term -> m term
Term tyname name uni fun a -> m (Term tyname name uni fun a)
go

    -- Make a brand new let-group comprised of all the floatable lets just inside the lam-body/Lam-body/let-InTerm.
    --
    -- The floating position is built from the current depth (read from the environment),
    -- the unique of the enclosing anchor, and the 'LamBody' position type.
    floatLam :: PLC.Unique -- ^ the unique of the anchor (lam/Lam name or unfloatable-let representative)
             -> term -- ^ the (already visited) body to wrap
             -> m term
    floatLam lamU t = do
        herePos <- asks $ \d -> Pos d lamU LamBody
        -- We need to force to Rec because we might merge lets which depend on each other,
        -- but we can't tell because we don't do dependency resolution at this pass.
        -- So we have to be conservative. See Note [LetRec splitting pass]
        floatAt herePos (bindingGrpToLet Rec) t

    -- Float into the rhs'es of an unfloatable let-group.
    --
    -- The floating position is built from the current depth (read from the environment),
    -- the representative unique of the unfloatable let, and the 'LetRhs' position type.
    floatRhs :: (grp ~ BindingGrp tyname name uni fun a)
             => PLC.Unique
             -> grp -- ^ the unfloatable group
             -> m grp -- ^ the result group extended with the floatable rhs'es (size(result_group) >= size(unfloatable_group))
    floatRhs letU bs = do
        herePos <- asks $ \d -> Pos d letU LetRhs
        -- we don't know from which rhs the floatable-let(s) came from originally,
        -- so we instead are going to semigroup-append the floatable-let bindings together with the unfloatable let-group's bindings
        floatAt herePos (<>) bs

    -- Place whatever the float table holds for the given position around (or into) some PIR.
    --
    -- If the table has no entry for the position, the input is returned unchanged.
    -- Otherwise the subterms of every floated binding are visited with @go@, the visited groups
    -- are merged into a single group with '<>', and that group is combined with the input
    -- through the supplied placing function.
    floatAt :: Pos -- ^ floating position
            -> (BindingGrp tyname name uni fun a -> c -> c) -- ^ how to place the unfloatable-group into the PIR result
            -> c -- ^ term or bindings to float AROUND
            -> m c -- ^ the combined PIR result (terms or bindings)
    floatAt herePos placeIntoFn termOrBindings =
        -- is there something to be floated here?
        case MM.lookup herePos fTable of
            -- nothing to float, just descend
            Nothing -> pure termOrBindings
            -- all the naked-lets to be floated here
            Just floatableGrps -> do
                -- visit the rhs'es of these floated lets for any potential floatings as well
                -- NOTE: we do not directly run `go(bgGroup)` because that would increase the depth,
                -- and the floated lets are not anchors themselves; instead we run go on the floated-let bindings' subterms.
                floatableGrps' <- floatableGrps & (traversed.bgBindings.traversed.bindingSubterms) go
                -- fold the floatable groups into a *single* floatablegroup and combine that with some pir (term or bindings).
                pure $ fold1 floatableGrps' `placeIntoFn` termOrBindings

-- | The compiler pass of the algorithm (comprised of 3 connected passes).
--
-- The passes are chained left-to-right:
--
-- 1. 'mark' computes a scope from the input term;
-- 2. 'removeLets' uses that scope on the same input term, yielding a term
--    paired with a float table;
-- 3. 'floatBackLets' consumes that term and float table to produce the final term.
floatTerm :: (PLC.ToBuiltinMeaning uni fun,
            PLC.HasUnique tyname PLC.TypeUnique, PLC.HasUnique name PLC.TermUnique,
            Ord tyname, Ord name, Semigroup a
            )
          => Term tyname name uni fun a -> Term tyname name uni fun a
floatTerm t =
    mark t
    & flip removeLets t
    & uncurry floatBackLets

-- HELPERS

-- | The greatest position among the values of a map, using 'topPos' as the seed of the fold.
-- Hence an empty map yields 'topPos', and the result is never smaller than 'topPos'.
maxPos :: M.Map k Pos -> Pos
maxPos = foldr max topPos

-- | Run an action in a local environment whose depth (the 'markCtxDepth' field of the 'MarkCtx')
-- has been modified by the given function; the rest of the context is left as is.
withDepth :: (r ~ MarkCtx, MonadReader r m)
          => (Depth -> Depth) -> m a -> m a
withDepth = local . over markCtxDepth

-- | Run an action in a local environment that accounts for entering the body of a
-- lambda-like binder with the given name:
--
-- * the depth is increased by one, and
-- * the binder's unique is recorded in the scope at position
--   @Pos newDepth unique LamBody@ (replacing any previous entry for that unique).
withLam :: (r ~ MarkCtx, MonadReader r m, PLC.HasUnique name unique)
        => name
        -> m a -> m a
withLam n = local $ \ (MarkCtx d scope) ->
    let u = n^.PLC.theUnique
        d' = d+1
        pos' = Pos d' u LamBody
    in MarkCtx d' (M.insert u pos' scope)

-- | Run an action in a local environment whose scope additionally maps every identifier
-- introduced by the given bindings (as enumerated by 'bindingIds') to the given position.
--
-- The new entries are put on the left of the map union, so they take precedence
-- over existing scope entries with the same unique. The depth is not changed.
withBs :: (r ~ MarkCtx, MonadReader r m, PLC.HasUnique name PLC.TermUnique, PLC.HasUnique tyname PLC.TypeUnique)
       => NE.NonEmpty (Binding tyname name uni fun a3)
       -> Pos
       -> m a -> m a
withBs bs pos = local . over markCtxScope $ \scope ->
    M.fromList [(bid, pos) | bid <- bs^..traversed.bindingIds] <> scope

-- | A helper to apply a function iff recursive:
-- for 'Rec' the function is applied to the value, for 'NonRec' the value is returned untouched.
ifRec :: Recursivity -> (a -> a) -> a -> a
ifRec r f a = case r of
    Rec    -> f a
    NonRec -> a

-- | A binding group is floatable iff every one of its bindings 'hasNoEffects'.
-- The group's annotation and recursivity play no role in the decision.
floatable :: PLC.ToBuiltinMeaning uni fun => BindingGrp tyname name uni fun a -> Bool
floatable (BindingGrp _ _ bs) = all hasNoEffects bs

{-| Returns if a binding has absolutely no effects  (see Value.hs)
See Note [Purity, strictness, and variables]

Type bindings, datatype bindings and non-strict term bindings are always reported
as effect-free; a strict term binding is effect-free only if its rhs passes 'isPure'.

An extreme alternative implementation is to treat *all strict* bindings as unfloatable, e.g.:
`hasNoEffects = \case {TermBind _ Strict _  _ -> False; _ -> True}`
-}
hasNoEffects :: PLC.ToBuiltinMeaning uni fun => Binding tyname name uni fun a -> Bool
hasNoEffects = \case
    TypeBind{}               -> True
    DatatypeBind{}           -> True
    TermBind _ NonStrict _ _ -> True
    -- have to check for purity; the strictness lookup handed to 'isPure'
    -- reports every variable as NonStrict
    -- TODO: We could maybe do better here, but not worth it at the moment
    TermBind _ Strict _ t    -> isPure (const NonStrict) t

-- | Breaks down linear let nonrecs by
-- the rule: {let nonrec (b:bs) in t} === {let nonrec b in let nonrec bs in t}
--
-- Only the outermost node is rewritten, and only one binding is peeled off per call:
-- the first binding gets its own let, the remaining ones stay together in the inner let.
-- Both resulting lets reuse the annotation of the original let.
-- A non-recursive let with a single binding, a recursive let, or any other term
-- is returned unchanged.
breakNonRec :: Term tyname name uni fun a -> Term tyname name uni fun a
breakNonRec = \case
    Let a NonRec (NE.uncons -> (b, Just bs)) tIn ->
      Let a NonRec (pure b) $ Let a NonRec bs tIn
    t -> t

{- Note [Floating rhs-nested lets]

A nested let inside a let-rhs that depends on that rhs is for example:

let rec parent = (let (rec|nonrec) child = parent in ...)  in ...
OR
let rec grandparent = (let rec parent = (let (rec|nonrec) child = grandparent in ...) in ...) in ...

If such a child is floatable and its calculated float marker (maximum position)
is another let's position (e.g. parent or grandparent),
we have to float right inside the let-rhs and not right inside the let-interm.
However, we have lost the information about which specific rhs (from the group of rhs'es of the (grand)parent let-group)
the dependent let came from.

Squeezing with such a parent, unfloatable let means that the parent let *must* be recursive.
Since the child let is depending on the parent let --- uses some parent-introduced variable(s) ---,
it is implied that the parent was originally rec, to begin with; we do not touch the original recursivity of an unfloatable let.

Note about squeezing order:
(floatable<>unfloatable) VERSUS (unfloatable<>floatable) does not matter, because it does not change the meaning.

The end result is that no nested, floatable let will appear anymore inside another let's rhs at the algorithm's output,
(e.g. invalid output:  let x=1+(let y=3 in y) in ...)
*EXCEPT* if the nested let is intercepted by a lam/Lam anchor (depends on a lam/Lam that is located inside the parent-let's rhs)
e.g. valid output: let x= \z -> (let y = 3+z in y) in ...
-}