{-# LANGUAGE CPP           #-}
{-# LANGUAGE MagicHash     #-}
{-# LANGUAGE UnboxedTuples #-}
{-# OPTIONS_HADDOCK not-home #-}

-- | @bytestring@ extras
--
module Database.LSMTree.Internal.ByteString (
    tryCheapToShort,
    tryGetByteArray,
    shortByteStringFromTo,
    byteArrayFromTo,
    byteArrayToByteString,
    unsafePinnedByteArrayToByteString,
    byteArrayToSBS,
) where

import           Control.Exception (assert)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BB
import qualified Data.ByteString.Builder.Internal as BB
import qualified Data.ByteString.Internal as BS.Internal
import           Data.ByteString.Short (ShortByteString (SBS))
import qualified Data.ByteString.Short.Internal as SBS
import           Data.Primitive.ByteArray
import           Database.LSMTree.Internal.Assertions (isValidSlice)
import           Foreign.Ptr (minusPtr, plusPtr)
import           GHC.Exts
import qualified GHC.ForeignPtr as Foreign
import           GHC.Stack (HasCallStack)

-- | \( O(1) \) conversion to a 'ShortByteString', when possible.
--
-- The bytestring must satisfy every requirement listed for 'tryGetByteArray'.
-- It must also cover the whole of its underlying byte array; if it does not,
-- a 'Left' explaining the problem is returned.
tryCheapToShort :: BS.ByteString -> Either String ShortByteString
tryCheapToShort bs = do
    (ba, used) <- tryGetByteArray bs
    if used == sizeofByteArray ba
      then case ba of ByteArray ba# -> Right (SBS ba#)
      else Left "ByteString does not use full ByteArray"


-- | \( O(1) \) conversion from a strict 'BS.ByteString' to the pinned
-- 'ByteArray' that backs it, when that is possible. On success it also
-- returns the number of bytes, counted from the start of the array, that the
-- bytestring actually uses.
--
-- Strict bytestrings are allocated with 'mallocPlainForeignPtrBytes', so a
-- 'PlainPtr' is expected here. On GHC 9.2 and later, a 'FinalPtr' is also
-- accepted, but only when the length is 0. The bytestring must point at the
-- very beginning of the byte array; a non-zero offset is rejected.
tryGetByteArray :: BS.ByteString -> Either String (ByteArray, Int)
tryGetByteArray (BS.Internal.BS (Foreign.ForeignPtr addr# contents) n) =
    case contents of
      Foreign.PlainPtr mba#
        | isTrue# (mutableByteArrayContentsShim# mba# `eqAddr#` addr#) ->
            -- Freezing is safe: a ByteString's content is considered immutable.
            Right $ case unsafeFreezeByteArray# mba# realWorld# of
              (# _, ba# #) -> (ByteArray ba#, n)
        | otherwise ->
            Left "non-zero offset into ByteArray"
      Foreign.MallocPtr {} ->
        unsupported "unsupported MallocPtr (length "
      Foreign.PlainForeignPtr {} ->
        unsupported "unsupported PlainForeignPtr (length "
#if __GLASGOW_HASKELL__ >= 902
      Foreign.FinalPtr
        -- Empty bytestrings are accepted too ('BS.empty' uses 'FinalPtr').
        | n == 0    -> Right (emptyByteArray, 0)
        | otherwise -> unsupported "unsupported FinalPtr (length "
#endif
  where
    -- Rejection whose message ends with the bytestring's length.
    unsupported prefix = Left (prefix <> show n <> ")")

-- | Address of the contents of a mutable byte array. Copied from the
-- @primitive@ package.
--
-- On GHC 9.2 and later this is 'mutableByteArrayContents#'. On older
-- compilers, the array is coerced and passed to 'byteArrayContents#' instead.
{-# INLINE mutableByteArrayContentsShim# #-}
mutableByteArrayContentsShim# :: MutableByteArray# s -> Addr#
mutableByteArrayContentsShim# mba# =
#if __GLASGOW_HASKELL__ >= 902
  mutableByteArrayContents# mba#
#else
  byteArrayContents# (unsafeCoerce# mba#)
#endif

-- | Variant of 'SBS.shortByteString' that emits only the bytes at indices
-- @[start, end)@ of the input. The bounds are not checked.
--
-- https://github.com/haskell/bytestring/issues/664
--
-- The definition is deliberately a lambda with nothing on the left-hand side.
-- Because of this, the INLINE pragma applies even when the function is only
-- partially applied.
{-# INLINE shortByteStringFromTo #-}
shortByteStringFromTo :: Int -> Int -> ShortByteString -> BB.Builder
shortByteStringFromTo = \start end sbs ->
    BB.builder (shortByteStringCopyStepFromTo start end sbs)

-- | Like 'shortByteStringFromTo', but for a 'ByteArray'. It emits the bytes at
-- indices @[i, j)@ of the array, and the bounds are not checked. The array is
-- viewed as a 'ShortByteString' without copying.
--
-- https://github.com/haskell/bytestring/issues/664
--
-- Like 'shortByteStringFromTo', this is written as an arity-0 lambda and
-- marked INLINE. Without the pragma, the lambda form would serve no purpose.
{-# INLINE byteArrayFromTo #-}
byteArrayFromTo :: Int -> Int -> ByteArray -> BB.Builder
byteArrayFromTo = \i j (ByteArray ba#) ->
    BB.builder (shortByteStringCopyStepFromTo i j (SBS ba#))

-- | Variant of 'SBS.shortByteStringCopyStep' that copies only the bytes at
-- indices @[ip0, ipe0)@ of the source. The bounds are not checked.
--
-- Bytes are copied into the current output buffer. If they do not all fit,
-- the buffer is filled completely and a 'BB.bufferFull' signal is returned.
-- Copying then resumes in the next buffer. Once the whole range has been
-- copied, control passes to the continuation @k@.
{-# INLINE shortByteStringCopyStepFromTo #-}
shortByteStringCopyStepFromTo ::
  Int -> Int -> ShortByteString -> BB.BuildStep a -> BB.BuildStep a
shortByteStringCopyStepFromTo !ip0 !ipe0 !sbs k = copyFrom ip0
  where
    copyFrom !ip (BB.BufferRange op ope)
      | needed > available = do
          -- Only part of the range fits: fill this buffer and resume later.
          SBS.copyToPtr sbs ip op available
          let !ip' = ip + available
          pure (BB.bufferFull 1 ope (copyFrom ip'))
      | otherwise = do
          -- The rest of the range fits: copy it and hand over to @k@.
          SBS.copyToPtr sbs ip op needed
          k (BB.BufferRange (op `plusPtr` needed) ope)
      where
        available = ope `minusPtr` op
        needed    = ipe0 - ip

-- | Converts a slice of a 'ByteArray', given by offset and length, to a strict
-- 'BS.ByteString'.
--
-- If the array is pinned, this is \( O(1) \): the slice is handed to
-- 'unsafePinnedByteArrayToByteString' directly. Otherwise it is \( O(n) \):
-- the slice is first copied into a fresh pinned array of exactly @len@ bytes.
byteArrayToByteString :: Int -> Int -> ByteArray -> BS.ByteString
byteArrayToByteString off len ba =
    assert (isValidSlice off len ba) result
  where
    result
      | isByteArrayPinned ba = unsafePinnedByteArrayToByteString off len ba
      | otherwise            = unsafePinnedByteArrayToByteString 0 len pinnedCopy

    -- A new pinned array that holds just the requested slice.
    pinnedCopy = runByteArray $ do
      dst <- newPinnedByteArray len
      copyByteArray dst 0 ba off len
      pure dst

-- | \( O(1) \) conversion of a slice of a pinned 'ByteArray', given by offset
-- and length, to a strict 'BS.ByteString'. The bytestring points directly
-- into the array, so no bytes are copied. Calls 'error' if the array is not
-- pinned.
--
-- Based on 'SBS.fromShort'.
unsafePinnedByteArrayToByteString :: HasCallStack => Int -> Int -> ByteArray -> BS.ByteString
unsafePinnedByteArrayToByteString off@(I# off#) len ba@(ByteArray ba#) =
    assert (isValidSlice off len ba) $
      if not (isByteArrayPinned ba)
        then error $ "unsafePinnedByteArrayToByteString: not pinned, length "
                  <> show (sizeofByteArray ba)
        else BS.Internal.BS (Foreign.ForeignPtr start# contents) len
  where
    -- Address of the first byte of the slice.
    start# = byteArrayContents# ba# `plusAddr#` off#
    -- The array itself is stored as the ForeignPtr's contents (coerced to
    -- the mutable array type that 'Foreign.PlainPtr' expects).
    contents = Foreign.PlainPtr (unsafeCoerce# ba#)

-- | Views a 'ByteArray' as a 'ShortByteString' without copying any bytes.
byteArrayToSBS :: ByteArray -> ShortByteString
#if MIN_VERSION_bytestring(0,12,0)
-- From bytestring 0.12 on, 'SBS.ShortByteString' wraps a 'ByteArray' directly.
byteArrayToSBS = SBS.ShortByteString
#else
byteArrayToSBS (ByteArray ba#) = SBS.SBS ba#
#endif