{-# LANGUAGE BangPatterns #-}

module Database.LSMTree.Extras.Random (
    -- * Sampling from uniform distributions
    uniformWithoutReplacement
  , uniformWithReplacement
  , sampleUniformWithoutReplacement
  , sampleUniformWithReplacement
  , withoutReplacement
  , withReplacement
    -- * Sampling from multiple distributions
  , frequency
    -- * Generators for specific data types
  , randomByteStringR
  ) where

import qualified Data.ByteString as BS
import           Data.List (unfoldr)
import qualified Data.Set as Set
import qualified System.Random as R
import           System.Random (StdGen, Uniform, uniform, uniformR)
import           Text.Printf (printf)

{-------------------------------------------------------------------------------
  Sampling from uniform distributions
-------------------------------------------------------------------------------}

-- | Draw @n@ pairwise-distinct values, each generated with 'uniform'.
--
-- This is 'withoutReplacement' specialised to the 'uniform' generator: values
-- already produced are rejected and redrawn. Consequently, demanding the whole
-- result does not terminate if type @a@ has fewer than @n@ distinct values.
uniformWithoutReplacement :: (Ord a, Uniform a) => StdGen -> Int -> [a]
uniformWithoutReplacement rng n = withoutReplacement rng n uniform

-- | Draw @n@ values, each generated with 'uniform'. No deduplication is
-- performed, so the result may contain repeats.
uniformWithReplacement :: Uniform a => StdGen -> Int -> [a]
uniformWithReplacement rng n = withReplacement rng n uniform

-- | Sample @n@ elements from @xs@ uniformly at random, without replacement.
--
-- The input is first deduplicated with 'Set.fromList', so sampling is uniform
-- over the /distinct/ elements of @xs@, and no element appears twice in the
-- result. Calls 'error' if @n@ exceeds the number of distinct elements; a
-- non-positive @n@ yields @[]@.
sampleUniformWithoutReplacement :: Ord a => StdGen -> Int -> [a] -> [a]
sampleUniformWithoutReplacement rng0 n xs
  | n > Set.size xs0 =
      error $
        printf "sampleUniformWithoutReplacement: n > number of distinct \
               \elements in xs0 for n=%d, distinct elements=%d"
               n
               (Set.size xs0)
  | otherwise =
      -- Could use 'withoutReplacement', but this is more efficient: each draw
      -- removes the chosen element, so no draw is ever rejected.
      take n $ go xs0 rng0
  where
    xs0 = Set.fromList xs

    -- Conceptually infinite. 'take' stops demanding cells after n draws,
    -- and because n <= Set.size xs0 that happens before the set runs out
    -- (an empty set would make 'Set.elemAt' fail).
    go !remaining !rng = x : go remaining' rng'
      where
        (i, rng')   = uniformR (0, Set.size remaining - 1) rng
        !x          = Set.elemAt i remaining
        !remaining' = Set.deleteAt i remaining

-- | Sample @n@ elements from @xs@ uniformly at random, with replacement, so
-- the result may contain repeats.
--
-- The input is first deduplicated with 'Set.fromList', so every /distinct/
-- element is equally likely regardless of how often it occurs in @xs@.
-- Calls 'error' if @n > 0@ and @xs@ is empty, since there is nothing to
-- sample from.
sampleUniformWithReplacement :: Ord a => StdGen -> Int -> [a] -> [a]
sampleUniformWithReplacement rng0 n xs
  | n > 0 && Set.null xs0 =
      -- Fail fast with a clear message, instead of producing elements that
      -- each crash with an opaque 'Set.elemAt' index error when forced.
      error "sampleUniformWithReplacement: cannot sample from an empty list"
  | otherwise =
      withReplacement rng0 n $ \rng ->
        let (i, rng') = uniformR (0, Set.size xs0 - 1) rng
        in  (Set.elemAt i xs0, rng')
  where
    xs0 = Set.fromList xs

-- | Draw @n0@ pairwise-distinct values using the generator @sample@.
--
-- This is rejection sampling. Any value equal to one already produced is
-- discarded and redrawn, and kept values appear in the order they were first
-- drawn. If @sample@ can produce fewer than @n0@ distinct values, demanding
-- the whole result never terminates.
withoutReplacement :: Ord a => StdGen -> Int -> (StdGen -> (a, StdGen)) -> [a]
withoutReplacement rng0 n0 sample = take n0 $ go Set.empty rng0
  where
    -- 'seen' holds every value emitted so far.
    go !seen !rng
      | Set.member x seen = go seen rng'
      | otherwise         = x : go (Set.insert x seen) rng'
      where
        (!x, !rng') = sample rng

-- | Draw @n0@ values by repeatedly applying @sample@. Each draw uses the
-- generator returned by the previous one. No deduplication is performed, so
-- the result may contain repeats.
withReplacement :: StdGen -> Int -> (StdGen -> (a, StdGen)) -> [a]
withReplacement rng0 n0 sample =
    take n0 $ unfoldr (Just . sample) rng0

{-------------------------------------------------------------------------------
  Sampling from multiple distributions
-------------------------------------------------------------------------------}

-- | Chooses one of the given generators, with a weighted random distribution,
-- and runs it.
--
-- A generator with weight @k@ is chosen with probability @k / tot@, where
-- @tot@ is the sum of all weights. The following conditions are checked with
-- 'error':
--
-- * every weight must be non-negative;
-- * at least one weight must be positive (so the list must be non-empty).
--
-- The chosen generator is run on the generator state left over after drawing
-- the weighted index.
--
-- Based on the implementation in @QuickCheck@.
frequency :: [(Int, StdGen -> (a, StdGen))] -> StdGen -> (a, StdGen)
frequency xs0 g
  | any ((< 0) . fst) xs0 = error "frequency: frequencies must be non-negative"
  | tot == 0              = error "frequency: at least one frequency should be non-zero"
  | otherwise             = pick i xs0
  where
    -- A 1-based ticket in [1, tot]; 'pick' walks the cumulative weights
    -- until the ticket falls within one generator's share.
    (i, g') = uniformR (1, tot) g

    tot = sum (map fst xs0)

    pick n ((k, x) : xs)
      | n <= k    = x g'
      | otherwise = pick (n - k) xs
    -- Unreachable in practice: i <= tot, so the weights always cover i.
    pick _ []     = error "frequency: pick used with empty list"

{-------------------------------------------------------------------------------
  Generators for specific data types
-------------------------------------------------------------------------------}

-- | Generates a random bytestring. Its length is drawn uniformly from the
-- provided inclusive range using 'uniformR'. The bytes are then produced by
-- 'R.genByteString' from the generator returned by that draw.
randomByteStringR :: (Int, Int) -> StdGen -> (BS.ByteString, StdGen)
randomByteStringR range g =
    let (!l, !g') = uniformR range g
    in  R.genByteString l g'