module MCG (
    MCG,
    make,
    period,
    next,
    reject,
) where

import           Data.Bits (countLeadingZeros, unsafeShiftR)
import           Data.List (nub)
import           Data.Numbers.Primes (isPrime, primeFactors)
import           Data.Word (Word64)

-- $setup
-- >>> import Data.List (unfoldr, nub)

-- | A multiplicative congruential (Lehmer) generator.
--
-- See <https://en.wikipedia.org/wiki/Lehmer_random_number_generator>.
data MCG = MCG
    { m :: !Word64  -- ^ modulus
    , a :: !Word64  -- ^ multiplier
    , x :: !Word64  -- ^ current state
    }
  deriving stock Show

-- invariants: m is a prime
--             a is a primitive element of Z_m
--             x is in [1..m-1]

-- | Create a MCG.
--
-- The modulus is the smallest prime strictly greater than the requested
-- lower bound (which is first raised to at least 4), so 'period' is at
-- least as large as asked for. The multiplier is chosen to be a primitive
-- element modulo that prime, and the seed is reduced into the valid state
-- range @[1 .. modulus - 1]@ (seeds below @modulus - 1@ map to @seed + 1@).
--
-- Lower bounds close to @maxBound :: Word64@ are not supported: the prime
-- search is done in 'Word64' and would wrap around.
--
-- >>> make 20 04
-- MCG {m = 23, a = 11, x = 5}
--
-- >>> let g = make 101_000_000 20240429
-- >>> (period g, fst (next g))
-- (101000022,20240429)
--
make ::
       Word64  -- ^ a lower bound for the period
    -> Word64  -- ^ initial seed.
    -> MCG
make lowerBound seed = MCG modulus multiplier (1 + mod seed order)
  where
    period_ :: Word64
    period_ = max 4 lowerBound

    -- start prime search from an odd number larger than asked period;
    -- even candidates are skipped as they cannot be prime here.
    modulus :: Word64
    modulus = until isPrime (+ 2) (if odd period_ then period_ + 2 else period_ + 1)

    -- size of the multiplicative group of Z_modulus (= the period).
    order :: Word64
    order = modulus - 1

    -- distinct prime factors of the group order.
    qs :: [Word64]
    qs = nub (primeFactors order)

    -- we find the multiplier using a "brute-force" approach: luckily many
    -- elements are primitive, so we don't need to try too hard.
    multiplier :: Word64
    multiplier = until isPrimitive (+ 1) (guessA modulus)

    -- y is primitive iff y^(order / q) /= 1 (mod modulus) for every prime
    -- factor q of order; so only prime factors of modulus - 1 are checked.
    isPrimitive :: Word64 -> Bool
    isPrimitive y = all (\q -> powMod y (div order q) /= 1) qs

    -- y^e mod modulus. Done in Integer on purpose: evaluating y ^ e in
    -- Word64 wraps around 2^64 for large exponents, which would make the
    -- primitivity test above meaningless.
    powMod :: Word64 -> Word64 -> Word64
    powMod y e = fromInteger (go 1 (toInteger y `mod` modI) e)
      where
        modI :: Integer
        modI = toInteger modulus

        -- square-and-multiply; acc and b are always reduced mod modI.
        go :: Integer -> Integer -> Word64 -> Integer
        go acc _ 0 = acc
        go acc b k
            | odd k     = go (acc * b `mod` modI) (b * b `mod` modI) (k `div` 2)
            | otherwise = go acc (b * b `mod` modI) (k `div` 2)

-- | Period of the MCG: its modulus minus one.
--
-- Period is usually a bit larger than asked for, we look for the next prime:
--
-- >>> let g = make 9 04
-- >>> period g
-- 10
--
-- >>> take 22 (unfoldr (Just . next) g)
-- [4,7,3,1,0,5,2,6,8,9,4,7,3,1,0,5,2,6,8,9,4,7]
--
period :: MCG -> Word64
period g = m g - 1

-- | Generate next number.
--
-- Returns the current state minus one (so, while the state stays in
-- @[1 .. modulus - 1]@, outputs lie in @[0 .. period - 1]@) together with
-- the generator advanced by one multiplication modulo the modulus.
next :: MCG -> (Word64, MCG)
next (MCG modulus multiplier state) =
    (state - 1, MCG modulus multiplier (mulMod state multiplier modulus))

-- | @mulMod u v w@ computes @(u * v) `mod` w@ without 'Word64' overflow.
--
-- When the operands' bit lengths add up to at most 64 the product cannot
-- exceed @2^64 - 1@, so the cheap 'Word64' path is exact; otherwise the
-- product is formed in 'Integer'.
mulMod :: Word64 -> Word64 -> Word64 -> Word64
mulMod u v w
    | countLeadingZeros u + countLeadingZeros v >= 64 = mod (u * v) w
    | otherwise = fromInteger (mod (toInteger u * toInteger v) (toInteger w))

-- | Generate numbers with 'next' until one that is less than the given
-- bound is produced; return it with the generator state after it.
--
-- Replacing 'next' with @'reject' n@ effectively cuts the period to @n@:
--
-- >>> let g = make 9 04
-- >>> period g
-- 10
--
-- >>> take 22 (unfoldr (Just . reject 9) g)
-- [4,7,3,1,0,5,2,6,8,4,7,3,1,0,5,2,6,8,4,7,3,1]
--
-- if @n@ is close enough to actual period of 'MCG', the rejection ratio
-- is very small.
--
-- A bound of @0@ is rejected with an error: no output is below @0@, so the
-- search would otherwise never terminate.
--
reject :: Word64 -> MCG -> (Word64, MCG)
reject 0 _  = error "MCG.reject: the bound must be positive"
reject ub g = go g
  where
    go h = case next h of
        r@(y, h')
            | y < ub    -> r
            | otherwise -> go h'

-------------------------------------------------------------------------------
-- guessing some initial a
-------------------------------------------------------------------------------

-- | Number of significant bits of a word, i.e. @log2 (x + 1)@ rounded up;
-- @0@ for @0@ and @64@ for 'maxBound'.
word64Log2m1 :: Word64 -> Int
word64Log2m1 = (64 -) . countLeadingZeros

-- | Starting guess for the multiplier search: shift out a third of the
-- argument's significant bits, giving roughly @n^(2/3)@. The intent is a
-- value whose square exceeds the modulus (this holds for all but tiny @n@).
guessA :: Word64 -> Word64
guessA n = n `unsafeShiftR` shiftBy
  where
    shiftBy = word64Log2m1 n `div` 3