{-# OPTIONS --safe #-}

import      Data.Nat as  renaming (_⊔_ to max)
import      Data.Integer as  renaming (_⊔_ to max)
import      Data.Integer.Properties as 
open import Data.Rational using (; floor; _*_; _÷_; _/_; _-_)
import      Data.Rational as  renaming (_⊓_ to min; _⊔_ to max)
open import Data.Rational.Literals using (number; fromℤ)
import      Data.Rational.Properties as 

open import Ledger.Conway.Abstract
open import Ledger.Conway.Transaction
open import Ledger.Conway.Types.Numeric.UnitInterval

open import Agda.Builtin.FromNat
open        Number number renaming (fromNat to fromℕ)

module Ledger.Conway.Rewards
  (txs : _) (open TransactionStructure txs)
  (abs : AbstractFunctions txs)
  where

open import Ledger.Conway.Certs govStructure
open import Ledger.Conway.Ledger txs abs
open import Ledger.Prelude hiding (_/_; _*_; _-_)
open import Ledger.Conway.Utxo txs abs



-- `ℕ.max 1 n` is never zero.
-- The split on `n` lets `ℕ.max 1 n` compute to a `suc _` form in both
-- cases, so the stdlib instance `ℕ.nonZero` applies directly.
nonZero-max-1 : ∀ (n : ℕ) → ℕ.NonZero (ℕ.max 1 n)
nonZero-max-1 zero    = ℕ.nonZero
nonZero-max-1 (suc n) = ℕ.nonZero

-- For a nonzero natural `n`, the rational `1 / n` is nonzero.
-- The proof goes through positivity: `ℚ.normalize-pos` shows `1 / n` is
-- positive, and `ℚ.pos⇒nonZero` turns that into `ℚ.NonZero`.
nonZero-1/n : ∀ (n : ℕ) → .{{_ : ℕ.NonZero n}} → ℚ.NonZero (1 / n)
nonZero-1/n n {{prf}} =
  ℚ.pos⇒nonZero (1 / n) {{ℚ.normalize-pos 1 n {{prf}} {{_}} }}

-- `1 + ℚ.max 0 x` is nonzero for every rational `x`.
-- It is strictly positive: `0 < 1` (`ℚ.positive⁻¹ 1`) and
-- `0 ≤ ℚ.max 0 x` (`ℚ.p≤p⊔q 0 x`), combined with `ℚ.+-mono-<-≤`.
nonZero-1+max0-x : ∀ (x : ℚ) → ℚ.NonZero (1 + ℚ.max 0 x)
nonZero-1+max0-x x =
  ℚ.>-nonZero (ℚ.+-mono-<-≤ (ℚ.positive⁻¹ 1) (ℚ.p≤p⊔q 0 x))

-- Instance: `ℤ.max 0 i` is non-negative, because `0 ≤ ℤ.max 0 i`.
-- NOTE(review): presumably resolved implicitly by integer operations
-- elsewhere in this module (e.g. `posPart`); confirm.
private instance
  nonNegative : ∀ {i} → ℤ.NonNegative (ℤ.max 0 i)
  nonNegative {i} = ℤ.nonNegative (ℤ.i≤i⊔j 0 i)


-- Maximal reward a pool can earn from the reward pot `rewardPot`, given
-- the pool's relative stake `stake` and relative pledge `pledge`:
--
--   R / (1 + a0) * (σ' + p' * a0 * (σ' - p' * (z0 - σ') / z0) / z0)
--
-- where z0 = 1 / nopt, σ' = min stake z0 and p' = min pledge z0.
-- `a0` is clamped below at 0 and `nopt` below at 1. As a result, the
-- divisions by `1 + a0`, by `nopt` and by `z0` always have the `NonZero`
-- proofs they require. The rational result is floored, and its positive
-- part is returned as a `Coin`.
maxPool : PParams → Coin → UnitInterval → UnitInterval → Coin
maxPool pparams rewardPot stake pledge = rewardℕ
  where
    a0      = ℚ.max 0 (pparams .PParams.a0)
    1+a0    = 1 + a0
    nopt    = ℕ.max 1 (pparams .PParams.nopt)

    -- Needed by `1 / nopt` below.
    instance
      nonZero-nopt : ℕ.NonZero nopt
      nonZero-nopt = nonZero-max-1 (pparams .PParams.nopt)

    -- Stake and pledge are both capped at z0 = 1 / nopt.
    z0       = 1 / nopt
    stake'   = ℚ.min (fromUnitInterval stake) z0
    pledge'  = ℚ.min (fromUnitInterval pledge) z0

    -- Needed by the `_÷_` divisions in `rewardℚ`.
    instance
      nonZeroz0 : ℚ.NonZero z0
      nonZeroz0 = nonZero-1/n nopt

      nonZero-1+a0 : ℚ.NonZero (1+a0)
      nonZero-1+a0 = nonZero-1+max0-x (pparams .PParams.a0)

    rewardℚ =
        ((fromℕ rewardPot) ÷ 1+a0)
        * (stake' + pledge' * a0 * (stake' - pledge' * (z0 - stake') ÷ z0) ÷ z0)
    rewardℕ = posPart (floor rewardℚ)


-- Apparent performance of a pool: its share of the blocks made
-- (`poolBlocks / totalBlocks`) divided by its relative stake `stake`.
-- `totalBlocks` is clamped below at 1 so the block ratio always has a
-- nonzero denominator.
-- NOTE(review): `_÷₀_` is used presumably so that a zero `stake` needs no
-- `NonZero` proof; confirm what it returns in that case.
mkApparentPerformance : UnitInterval → ℕ → ℕ → ℚ
mkApparentPerformance stake poolBlocks totalBlocks = ratioBlocks ÷₀ stake'
  where
    stake' = fromUnitInterval stake

    -- Needed by `_/_` in `ratioBlocks`.
    instance
      nonZero-totalBlocks : ℕ.NonZero (ℕ.max 1 totalBlocks)
      nonZero-totalBlocks = nonZero-max-1 totalBlocks

    ratioBlocks = (ℤ.+ poolBlocks) / (ℕ.max 1 totalBlocks)


-- Share of the pool reward `rewards` paid to the pool's owners.
-- If `rewards` does not exceed the pool's declared cost, the owners get all
-- of it. Otherwise they get the cost plus
--   (rewards - cost) * (margin + (1 - margin) * ownerStake / stake),
-- floored, with its positive part taken.
rewardOwners : Coin → PoolParams → UnitInterval → UnitInterval → Coin
rewardOwners rewards pool ownerStake stake = if rewards ≤ cost
  then rewards
  else cost + posPart (floor (
        (fromℕ rewards - fromℕ cost) * (margin + (1 - margin) * ratioStake)))
  where
    ratioStake   = fromUnitInterval ownerStake ÷₀ fromUnitInterval stake
    cost         = pool .PoolParams.cost
    margin       = fromUnitInterval (pool .PoolParams.margin)


-- Share of the pool reward `rewards` paid to one member whose relative
-- stake is `memberStake`. The member gets nothing if `rewards` does not
-- exceed the pool's cost. Otherwise the member gets
--   (rewards - cost) * (1 - margin) * memberStake / stake,
-- floored, with its positive part taken.
rewardMember : Coin → PoolParams → UnitInterval → UnitInterval → Coin
rewardMember rewards pool memberStake stake = if rewards ≤ cost
  then 0
  else posPart (floor (
         (fromℕ rewards - fromℕ cost) * ((1 - margin) * ratioStake)))
  where
    ratioStake    = fromUnitInterval memberStake ÷₀ fromUnitInterval stake
    cost          = pool .PoolParams.cost
    margin        = fromUnitInterval (pool .PoolParams.margin)


-- Amount of coin associated with each credential.
Stake = Credential ⇀ Coin

-- Rewards paid out by a single pool, as a map from credentials to coin.
--
--   rewardPot  : reward pot handed to `maxPool`
--   n, N       : blocks made by this pool / in total (for
--                `mkApparentPerformance`)
--   pool       : the pool's registered parameters
--   stakeDistr : per-credential stake considered for this pool
--   σ          : relative stake used by `maxPool` and by the member and
--                owner shares
--   σa         : relative stake used for the apparent performance
--   tot        : denominator used to turn coin amounts into relative stake
--
-- The pool earns nothing (`maxP = 0`) unless the owners' combined stake
-- covers the declared pledge. Every non-owner credential in `stakeDistr`
-- gets `rewardMember`. The owners' combined share goes to the pool's
-- reward account.
rewardOnePool : PParams → Coin → ℕ → ℕ → PoolParams
  → Stake → UnitInterval → UnitInterval → Coin → (Credential ⇀ Coin)
rewardOnePool pparams rewardPot n N pool stakeDistr σ σa tot = rewards
  where
    -- NOTE(review): `clamp` presumably maps into the unit interval.
    mkRelativeStake = λ coin → clamp (coin /₀ tot)
    owners = mapˢ KeyHashObj (pool .PoolParams.owners)
    ownerStake = ∑[ c ← stakeDistr ∣ owners ] c
    pledge = pool .PoolParams.pledge
    -- Pledge not covered by the owners' stake: the pool earns nothing.
    maxP = if pledge ≤ ownerStake
      then maxPool pparams rewardPot σ (mkRelativeStake pledge)
      else 0
    apparentPerformance = mkApparentPerformance σa n N
    poolReward = posPart (floor (apparentPerformance * fromℕ maxP))
    -- Members: stake restricted to the complement of the owners.
    memberRewards =
      mapValues (λ coin → rewardMember poolReward pool (mkRelativeStake coin) σ)
        (stakeDistr ∣ owners ᶜ)
    ownersRewards  =
      ❴ pool .PoolParams.rewardAccount
      , rewardOwners poolReward pool (mkRelativeStake ownerStake) σ ❵ᵐ
    -- NOTE(review): `∪⁺` presumably adds values on overlapping keys; confirm.
    rewards = memberRewards ∪⁺ ownersRewards


-- Pool (by key hash) that each credential delegates to.
Delegations = Credential ⇀ KeyHash

-- Stake delegated to pool `hk`: `stake` restricted to the credentials
-- whose delegation target (`delegs` range-restricted to `hk`) is `hk`.
poolStake  : KeyHash → Delegations → Stake → Stake
poolStake hk delegs stake = stake ∣ dom (delegs ∣^ ❴ hk ❵)


-- Number of blocks made, per pool key hash.
BlocksMade = KeyHash ⇀ ℕ

-- Flatten a nested map `A ⇀ (B ⇀ C)` into a map keyed by pairs.
-- The domain is every `(a , b)` such that `a` is in the outer map and `b`
-- is in the inner map at `a`. The value is the inner lookup of `b`.
uncurryᵐ :
  ∀ {A B C : Type} ⦃ _ : DecEq A ⦄ ⦃ _ : DecEq B ⦄ →
  A ⇀ (B ⇀ C) → (A × B) ⇀ C
uncurryᵐ {A} {B} {C} abc = mapFromPartialFun lookup' domain'
  where
    -- Two-step lookup: outer map at `a`, then inner map at `b`.
    lookup' : (A × B) → Maybe C
    lookup' (a , b) = lookupᵐ? abc a >>= (λ bc → lookupᵐ? bc b)

    -- Union of a set of sets.
    joinˢ : ∀ {X} → ℙ (ℙ X) → ℙ X
    joinˢ = concatMapˢ id

    -- For each outer key `a`, the set of `(a , b)` over the inner keys `b`.
    domain' : ℙ (A × B)
    domain' = joinˢ (range (mapWithKey (λ a bc → range (mapWithKey (λ b _ → (a , b)) bc)) abc))



-- Rewards for all pools, combined into a single map keyed by credential.
--
-- Only pools with an entry in `blocks` are considered: `mkPoolData`
-- returns `nothing` otherwise, and `mapMaybeWithKeyᵐ` drops those pools.
-- For each remaining pool, its delegated stake comes from `poolStake`.
-- The pool's relative stake is computed both w.r.t. `total` and w.r.t. the
-- active stake (the sum of `stake`). The per-pool reward maps are
-- flattened to `(pool , credential)` keys and then grouped by credential.
-- NOTE(review): `aggregateBy` presumably sums the grouped amounts; confirm.
reward : PParams → BlocksMade → Coin → (KeyHash ⇀ PoolParams)
  → Stake → Delegations → Coin → (Credential ⇀ Coin)
reward pp blocks rewardPot poolParams stake delegs total = rewards
  where
    active = ∑[ c ← stake ] c
    Σ_/total = λ st → clamp ((∑[ c ← st ] c) /₀ total)
    Σ_/active = λ st → clamp ((∑[ c ← st ] c) /₀ active)
    N = ∑[ m ← blocks ] m
    -- (blocks made, parameters, delegated stake) for pools present in `blocks`.
    mkPoolData = λ hk p →
      map (λ n → (n , p , poolStake hk delegs stake)) (lookupᵐ? blocks hk)
    pdata = mapMaybeWithKeyᵐ mkPoolData poolParams

    results  : (KeyHash × Credential) ⇀ Coin
    results = uncurryᵐ (mapValues (λ (n , p , s)
      → rewardOnePool pp rewardPot n N p s (Σ s /total) (Σ s /active) total)
      pdata)
    rewards  = aggregateBy
      (mapˢ (λ (kh , cred) → (kh , cred) , cred) (dom results))
      results


-- Outcome of a reward calculation.
-- NOTE(review): presumably `Δt`, `Δr`, `Δf` are the (signed) changes to the
-- treasury, the reserves and the fee pot, and `rs` are the rewards per
-- credential; confirm against the users of this record.
record RewardUpdate : Set where
  constructor ⟦_,_,_,_⟧ʳᵘ
  field
    Δt Δr Δf : ℤ
    rs : Credential ⇀ Coin


-- A stake snapshot: coin per credential, each credential's pool delegation
-- (by key hash), and the registered pool parameters.
record Snapshot : Set where
  field
    stake           : Credential ⇀ Coin
    delegations     : Credential ⇀ KeyHash
    poolParameters  : KeyHash ⇀ PoolParams



-- Derive a `HasCast` instance for `Snapshot` by reflection (`derive-HasCast`).
-- NOTE(review): presumably this is what allows a `Snapshot` to be written
-- as a bracketed tuple of its fields (as in `stakeDistr`); confirm.
instance
  unquoteDecl HasCast-Snapshot =
    derive-HasCast [ (quote Snapshot , HasCast-Snapshot) ]


-- Stake credential (if any) of the address `a` in the first component of
-- a transaction output, via `stakeCred`.
private
  getStakeCred : TxOut → Maybe Credential
  getStakeCred (a , _ , _ , _) = stakeCred a


-- Build a stake `Snapshot` from the UTxO set and the delegation/pool state.
--   stake          : for every credential in the domain of `rewards`, the
--                    balance of the UTxO outputs whose stake credential is
--                    that credential, together with its entry in `rewards`,
--                    combined by `aggregate₊`
--                    (NOTE(review): presumably a per-credential sum)
--   delegations    : `stakeDelegs` of the given `DState`
--   poolParameters : `pools` of the given `PState`
stakeDistr : UTxO → DState → PState → Snapshot
stakeDistr utxo stᵈ pState =
    ⟦ aggregate₊ (stakeRelation ᶠˢ) , stakeDelegs , poolParams ⟧
  where
    poolParams = pState .PState.pools
    open DState stᵈ using (stakeDelegs; rewards)
    -- (credential , UTxO balance staked to it) for each registered credential.
    m = mapˢ (λ a → (a , cbalance (utxo ∣^' λ i → getStakeCred i ≡ just a))) (dom rewards)
    stakeRelation = m ∪ ∣ rewards ∣


-- The three stake snapshots `mark`, `set`, `go` (in that field order),
-- followed by the fee amount `feeSS`.
record Snapshots : Set where
  field
    mark  : Snapshot
    set   : Snapshot
    go    : Snapshot
    feeSS : Coin



-- Derive a `HasCast` instance for `Snapshots` by reflection (`derive-HasCast`).
-- NOTE(review): presumably this is what allows `Snapshots` values to be
-- written as bracketed tuples in the SNAP rule; confirm.
instance
  unquoteDecl HasCast-Snapshots =
    derive-HasCast [ (quote Snapshots , HasCast-Snapshots) ]


-- Generalisable variables, bound implicitly in the transition rules below.
private variable
  lstate : LState
  mark   : Snapshot
  set    : Snapshot
  go     : Snapshot
  feeSS  : Coin


-- SNAP rotates the stake snapshots.
-- The new `mark` is the stake distribution computed from the ledger
-- state's UTxO, `dState` and `pState`. The old `mark` becomes `set`, the
-- old `set` becomes `go`, and the old `go` is dropped. `feeSS` is replaced
-- by the current `fees` of the UTxO state.
data _⊢_⇀⦇_,SNAP⦈_ : LState → Snapshots → ⊤ → Snapshots → Type where
  SNAP : let open LState lstate; open UTxOState utxoSt; open CertState certState
             stake = stakeDistr utxo dState pState
    in
    lstate ⊢ ⟦ mark , set , go , feeSS ⟧ ⇀⦇ tt ,SNAP⦈ ⟦ stake , mark , set , fees ⟧