{-# OPTIONS -v allTactics:100 #-}
{-# OPTIONS --safe #-}
module iog-prelude.Tactic.DeriveComp where

open import Prelude.Init hiding (Type; mapMaybe)
open import Data.Bool using (true; false)
open import Data.List using (List; [_]; dropWhile; take; takeWhile; mapMaybe)
open import Data.List.NonEmpty as NE using (fromList)
open import Data.Maybe using (Maybe; just; nothing; from-just)
open import Data.Maybe.Properties using (just-injective)
open import Data.Nat using (; _∸_)
open import Data.String using (String; _<+>_)
open import Class.DecEq
open import Class.Decidable
open import Class.MonadTC
open import Class.Show
open import Function using (mk⇔)
open import Meta.Init
open import Meta.Prelude using (zipWithIndex)
open import Reflection.AST.DeBruijn using (weaken)
open import stdlib-meta.Reflection using (extendContextTel′)
open import Reflection.Utils hiding (mkRecord)
open import Reflection.Utils.TCI
open import Reflection.Tactic
open import Relation.Nullary.Decidable using (_because_)
open import Relation.Nullary.Reflects using (ofʸ; ofⁿ)
open import Class.MonadTC
  hiding (extendContext)
open import Class.MonadError
  using (MonadError; MonadError-TC)
open import Class.Functor
open import Class.Traversable
open import Tactic.ClauseBuilder
open import Tactic.ReduceDec
open import Tactic.Defaults

open import Class.Computational

open MonadTC ⦃...⦄
open MonadError ⦃...⦄ using (error; catch)

instance
  defaultDebugOptionsI : DebugOptions
  defaultDebugOptionsI = record defaultDebugOptions
    { selection = All
    ; filter = Filter.⊥ }

pattern ``yes x = quote _because_ ◇⟦ quote true    quote ofʸ ◇⟦ x  
pattern ``no x  = quote _because_ ◇⟦ quote false   quote ofⁿ ◇⟦ x  

instance _ = ContextMonad-MonadTC

record STSConstr : Set where
  field
    name    : Name
    params  : List $ Abs $ Arg Type
    clauses : List Type
    context : Pattern
    state   : Pattern
    signal  : Pattern
    result  : Term

conOrVarToPattern :   Term  Maybe Pattern
conOrVarToPattern k ( v) = just (Pattern.var (v  k))
conOrVarToPattern k (con c args) =
  Pattern.con c <$> (sequence $ conOrVarToPattern′ k args)
  where
    conOrVarToPattern′ :   List (Arg Term)  List (Maybe (Arg Pattern))
    conOrVarToPattern′ k = λ where
      []  []
      ((arg i x)  l)  (arg i <$> conOrVarToPattern k x)  conOrVarToPattern′ k l
conOrVarToPattern _ _ = nothing

isArg : (a : Abs (Arg Term))  Dec _
isArg a = ¬? (getVisibility (unAbs a)  visible)

toSTSConstr : Name × TypeView  TC STSConstr
toSTSConstr (n , (cs , def _ args))
  with args | mapMaybe (conOrVarToPattern (length $ dropWhile isArg cs)  unArg)
                       (take 3 args)
... | _  _  _  r  [] | c  s  sig  [] =
  return record
    { name = n
    ; params = takeWhile isArg cs
    ; clauses = zipWithIndex  i  mapVars (_∸ i))
              $ (unArg  unAbs) <$> dropWhile isArg cs
    ; context = c
    ; state = s
    ; signal = sig
    ; result = mapVars (_∸ (length $ dropWhile isArg cs)) $ unArg r }
... | l | l' =
  error1 ("toSTSConstr: wrong number of arguments:"
    <+> show (length l) <+> "," <+> show (length l'))
toSTSConstr _ = error1 "toSTSConstr: wrong constructor"

errorIfNothing :  {a} {A : Set a}  Maybe A  String  TC A
errorIfNothing (just x) s = return x
errorIfNothing nothing s = error1 s

getSTSConstrs : Name  TC (List STSConstr)
getSTSConstrs n = inDebugPath "getSTSConstrs" do
  dataDef  getDataDef n
  res  traverse toSTSConstr (DataDef.constructors dataDef)
  debugLogᵐ (res ᵛⁿ ∷ᵈᵐ []ᵐ)
  return res

generatePred : List Term  Term
generatePred clauses = quote ¿_¿ ∙⟦ helper clauses 
  where
    helper : List Term  Term
    helper [] = quoteTerm 
    helper (x  []) = x
    helper (x  x₁  l) = quote _×_ ∙⟦ x  helper (x₁  l) 

predWitness :   Term
predWitness 0 = quoteTerm tt
predWitness (suc n) = helper n
  where
    helper :   Term
    helper 0 =  0
    helper (suc n) = quote _,_ ◆⟦  (suc n)  helper n 

curryPredProof :   Term  List (Arg Term)
curryPredProof 0 t = []
curryPredProof 1 t = [ vArg t ]
curryPredProof 2 t = quote proj₁ ∙⟦ t  ⟨∷⟩ quote proj₂ ∙⟦ t  ⟨∷⟩ []
curryPredProof (suc (suc (suc k))) t =
  quote proj₁ ∙⟦ t  ⟨∷⟩ curryPredProof (suc (suc k)) (quote proj₂ ∙⟦ t )

computeFunctionBody : Term  Term  Term
computeFunctionBody g result =
  quote if_then_else_ ∙⟦ g  quote just ◆⟦ result   quote nothing  
  where open import Data.Bool

generateFunctionClause : (List Term  Term)  STSConstr  Clause
generateFunctionClause genPred c = let open STSConstr c in
   context  state  signal   { (abs s x)  (s , x) }) <$> params ⦆⇒
    computeFunctionBody (quote ⌊_⌋ ∙⟦ genPred clauses ) result 

generateFunction : List STSConstr  Term
generateFunction c = pat-lam (generateFunctionClause generatePred <$> c) []

rdOpts : ReductionOptions
rdOpts = onlyReduce [ quote ⌊_⌋ ]

open ClauseExprM

derive⇐ : List STSConstr  ITactic
derive⇐ (record { clauses = clauses; result = result }  []) = inDebugPath "derive⇐" do
  (constrPat  [])  currentTyConstrPatterns
    where _  error1 "TODO: Support more than one constructor!"
  expr  singleMatchExpr constrPat $ finishMatch $ withGoalHole $ do
    let n = 1 + length clauses
        pred = weaken n $ generatePred clauses
        scheme =  "g"  (quote _≡_ ∙⟦ computeFunctionBody ( 0) (weaken (1 + n) result)
                                       quote just ◆⟦ weaken (1 + n) result  )
        eq = quote fromWitness' ∙⟦ pred  predWitness (length clauses) 
    unifyWithGoal $ quote subst ∙⟦ scheme  quote sym ∙⟦ eq   quote refl  
  unifyWithGoal $ clauseExprToPatLam expr
derive⇐ _ = error1 "TODO: support multiple constructors"

derive⇒ : Name  List STSConstr  ITactic
derive⇒ n
  (record { name = stsConstrName ; clauses = clauses; result = result }  [])
  = inDebugPath "derive⇒" do
    let pred = weaken 3 (generatePred clauses)
    expr  introsExpr (from-just $ NE.fromList ("h" ⟨∷⟩ [])) $ finishMatch $
      caseMatch (mapVars (_+ 2) $ generatePred clauses) $ matchExprM
        ((multiSinglePattern [ "" ] (vArg (``no  (` 0))) , finishMatch do
          let reducedHyp = quote subst
                ∙⟦  "g"  (quote _≡_ ∙⟦ computeFunctionBody ( 0) (weaken 4 result)
                                         quote just ◆⟦  3  )
                  quote fromWitnessFalse' ∙⟦ pred   0 
                   1
                 
          return $ quote case_of_ ∙⟦ reducedHyp  `λ∅ ) 
        (multiSinglePattern [ "" ] (vArg (``yes (` 0))) , finishMatch do
          let reducedHyp = quote subst
                ∙⟦  "g"  (quote _≡_ ∙⟦ computeFunctionBody ( 0) (weaken 4 result)
                                         quote just ◆⟦  3  )
                  quote fromWitness' ∙⟦ pred   0 
                   1
                 
          ty ∙⟦ c  s  sig  s'   goalTy
            where ty  error ("BUG: Unexpected type" ∷ᵈ ty ∷ᵈ [])
          return $ quote subst
            ∙⟦ ty ∙⟦ c  s  sig 
              quote just-injective ∙⟦ reducedHyp 
              con stsConstrName (curryPredProof (length clauses) ( 0))
             )  [])
    unifyWithGoal $ clauseExprToPatLam expr
derive⇒ _ _ = error1 "TODO: support multiple constructors"

module _  _ : DebugOptions  where
  derive⇔ : Name  List STSConstr  Tactic
  derive⇔ n stsConstrs = initTac $ inDebugPath "derive⇔" do
    hole⇒  newMeta unknown
    hole⇐  newMeta unknown
    unifyWithGoal $ quote mk⇔ ∙⟦ hole⇒  hole⇐ 
    runWithHole hole⇐ $ derive⇐ stsConstrs
    runWithHole hole⇒ $ derive⇒ n stsConstrs

  deriveComp : Name  List STSConstr  TC 
  deriveComp definedType stsConstrs = do
    debugLog ("\nDerive computation function for: " ∷ᵈ definedType ∷ᵈ [])
    unifyWithGoal (generateFunction stsConstrs)

  deriveComputational : Name  Name  UnquoteDecl
  deriveComputational n compName =
    initUnquoteWithGoal (quoteTerm Name) $ inDebugPath "deriveComputational" do
      let goalTy = quote Computational ∙⟦ n  
      debugLog1 goalTy
      declareDef (vArg compName) goalTy
      stsConstrs  getSTSConstrs n
      compRes  withSafeReset $ do
        compHole  newMeta unknown
        equivHole  newMeta unknown
        definition  mkRecord (quote Computational) (compHole ⟨∷⟩ equivHole ⟨∷⟩ [])
        _  checkType definition goalTy
        debugLog1ᵐ (compHole )
        runWithHole compHole $ deriveComp n stsConstrs
        reduce compHole
      debugLog ("compRes: " ∷ᵈ compRes ∷ᵈ [])
      cTy    newMeta unknown
      sTy    newMeta unknown
      sigTy  newMeta unknown
      let isoCxt = Telescope  ("c" , hArg cTy)
                  ("s" , hArg (weaken 1 sTy))
                  ("sig" , hArg (weaken 2 sigTy))
                  ("s'" , hArg (weaken 3 sTy))
                  []
      iso  extendContextTel′ isoCxt $ newMeta unknown
      let isolam = `λ⟅ "c" ⟆⇒ `λ⟅ "s" ⟆⇒ `λ⟅ "sig" ⟆⇒ `λ⟅ "s'" ⟆⇒ iso
      definition  mkRecord (quote Computational) (compRes ⟨∷⟩ isolam ⟨∷⟩ [])
      defineFun compName [ nonBindingClause definition ]
      extendContextTel′ isoCxt $ λ _  derive⇔ n stsConstrs iso

-- private module _ {A B : Set} ⦃ _ : DecEq A ⦄ ⦃ _ : DecEq B ⦄ where
--   variable
--     c : A × B
--     s s' : A
--     sig : B

--   data Test : A × B → A → B → A → Set where
--     test : proj₁ c ≡ s
--          → proj₂ c ≡ sig
--          ------------------------
--          → Test c s sig (proj₁ c)

--   unquoteDecl Computational-Test = deriveComputational (quote Test) Computational-Test

--   -- Sanity checks
--   testFun : A × B → A → B → Maybe A
--   testFun c s sig =
--     if ⌊ ¿ proj₁ c ≡ s × proj₂ c ≡ sig ¿ ⌋ then just (proj₁ c) else nothing

--   testFunPf⇒ : testFun c s sig ≡ just s' → Test c s sig s'
--   testFunPf⇒ {c} {s} {sig} h = case ¿ proj₁ c ≡ s × proj₂ c ≡ sig ¿ of λ where
--     (no ¬p) → case by-reduceDec h of λ ()
--     (yes p) → subst (Test c s sig) (just-injective $ by-reduceDec h)
--             $ test (proj₁ p) (proj₂ p)

--   testFunPf⇐ : Test c s sig s' → testFun c s sig ≡ just s'
--   testFunPf⇐ {s = s} (test x x₁) = by-reduceDecInGoal (refl {x = just s})

--   Computational-Test-Manual : Computational Test
--   Computational-Test-Manual = record
--     { compute = testFun ; ≡-just⇔STS = mk⇔ testFunPf⇒ testFunPf⇐ }

--   _ : ∀ {c s sig} → testFun c s sig ≡ Computational.compute Computational-Test c s sig
--   _ = refl