{-# LANGUAGE FlexibleContexts #-}
-- | Functions for computing variable usage inside terms.
module UntypedPlutusCore.Analysis.Usages (runTermUsages, Usages, getUsageCount, allUsed) where

import UntypedPlutusCore.Core.Plated
import UntypedPlutusCore.Core.Type

import PlutusCore qualified as PLC
import PlutusCore.Name qualified as PLC

import Control.Lens
import Control.Monad.State

import Data.Coerce
import Data.Foldable
import Data.Map qualified as Map
import Data.Set qualified as Set

-- | Variable uses, as a map from the 'PLC.Unique' to its usage count. Unused variables may be missing
-- or have usage count 0.
type Usages = Map.Map PLC.Unique Int

addUsage :: (PLC.HasUnique n unique) => n -> Usages -> Usages
addUsage :: n -> Usages -> Usages
addUsage n
n Usages
usages =
    let
        u :: Unique
u = unique -> Unique
coerce (unique -> Unique) -> unique -> Unique
forall a b. (a -> b) -> a -> b
$ n
n n -> Getting unique n unique -> unique
forall s a. s -> Getting a s a -> a
^. Getting unique n unique
forall a unique. HasUnique a unique => Lens' a unique
PLC.unique
        old :: Int
old = Int -> Unique -> Usages -> Int
forall k a. Ord k => a -> k -> Map k a -> a
Map.findWithDefault Int
0 Unique
u Usages
usages
    in Unique -> Int -> Usages -> Usages
forall k a. Ord k => k -> a -> Map k a -> Map k a
Map.insert Unique
u (Int
oldInt -> Int -> Int
forall a. Num a => a -> a -> a
+Int
1) Usages
usages

-- | Get the usage count of @n@.
getUsageCount :: (PLC.HasUnique n unique) => n -> Usages -> Int
getUsageCount :: n -> Usages -> Int
getUsageCount n
n Usages
usages = Int -> Unique -> Usages -> Int
forall k a. Ord k => a -> k -> Map k a -> a
Map.findWithDefault Int
0 (n
n n -> Getting Unique n Unique -> Unique
forall s a. s -> Getting a s a -> a
^. (unique -> Const Unique unique) -> n -> Const Unique n
forall a unique. HasUnique a unique => Lens' a unique
PLC.unique ((unique -> Const Unique unique) -> n -> Const Unique n)
-> ((Unique -> Const Unique Unique)
    -> unique -> Const Unique unique)
-> Getting Unique n Unique
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Unique -> Const Unique Unique) -> unique -> Const Unique unique
forall s t a b. (Coercible s a, Coercible t b) => Iso s t a b
coerced) Usages
usages

-- | Get a set of @n@s which are used at least once.
allUsed :: Usages -> Set.Set PLC.Unique
allUsed :: Usages -> Set Unique
allUsed Usages
usages = Usages -> Set Unique
forall k a. Map k a -> Set k
Map.keysSet (Usages -> Set Unique) -> Usages -> Set Unique
forall a b. (a -> b) -> a -> b
$ (Int -> Bool) -> Usages -> Usages
forall a k. (a -> Bool) -> Map k a -> Map k a
Map.filter (Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
0) Usages
usages

-- | Compute the 'Usages' for a 'Term'.
runTermUsages
    :: (PLC.HasUnique name PLC.TermUnique)
    => Term name uni fun a
    -> Usages
runTermUsages :: Term name uni fun a -> Usages
runTermUsages Term name uni fun a
term = State Usages () -> Usages -> Usages
forall s a. State s a -> s -> s
execState (Term name uni fun a -> State Usages ()
forall (m :: * -> *) name (uni :: * -> *) fun a.
(MonadState Usages m, HasUnique name TermUnique) =>
Term name uni fun a -> m ()
termUsages Term name uni fun a
term) Usages
forall a. Monoid a => a
mempty

termUsages
    :: (MonadState Usages m, PLC.HasUnique name PLC.TermUnique)
    => Term name uni fun a
    -> m ()
termUsages :: Term name uni fun a -> m ()
termUsages (Var a
_ name
n) = (Usages -> Usages) -> m ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify (name -> Usages -> Usages
forall n unique. HasUnique n unique => n -> Usages -> Usages
addUsage name
n)
termUsages Term name uni fun a
term      = (Term name uni fun a -> m ()) -> [Term name uni fun a] -> m ()
forall (t :: * -> *) (f :: * -> *) a b.
(Foldable t, Applicative f) =>
(a -> f b) -> t a -> f ()
traverse_ Term name uni fun a -> m ()
forall (m :: * -> *) name (uni :: * -> *) fun a.
(MonadState Usages m, HasUnique name TermUnique) =>
Term name uni fun a -> m ()
termUsages (Term name uni fun a
term Term name uni fun a
-> Getting
     (Endo [Term name uni fun a])
     (Term name uni fun a)
     (Term name uni fun a)
-> [Term name uni fun a]
forall s a. s -> Getting (Endo [a]) s a -> [a]
^.. Getting
  (Endo [Term name uni fun a])
  (Term name uni fun a)
  (Term name uni fun a)
forall name (uni :: * -> *) fun ann.
Traversal' (Term name uni fun ann) (Term name uni fun ann)
termSubterms)