{-# language DataKinds #-}
{-# language MagicHash #-}
{-# language NumericUnderscores #-}
{-# language BangPatterns #-}
{-# language TypeApplications #-}
{-# language TypeOperators #-}

module Vector.Int32
  ( -- Types
    Vector(..)
  , Vector#
  , MutableVector(..)
  , MutableVector#
  , Bounded(..)
  , Vector_(..)
  , FromMutability#
    -- * Primitives
  , write#
  , write
  , read#
  , index#
  , index
  , unlift
  , substitute
  , substitute#
  , initialized
  , initialized#
  , unsafeCoerceLength
  , expose
  , expose#
    -- * Ranges
  , set
  , setSlice
    -- * Freeze
  , unsafeShrinkFreeze
  , unsafeFreeze
  , freeze
  , freezeSlice
  , freeze#
  , freezeSlice#
    -- * Copy
  , thaw
    -- * Composite
  , map
  , traverse_
  , itraverse_
  , itraverse_#
  , traverseST#
  , ifoldl'
  , ifoldlSlice'
  , replicate
  , empty
  , empty_
  , construct1
  , construct2
  , construct3
  , construct4
  , construct5
  , construct6
  , construct7
  , append
  , clone
  , cloneSlice
  , copySlice
    -- * Index
  , index0
  , index1
  , index2
  , index3
    -- * Ordered
  , unique
  , equals
  , findIndexEq
  , maximum
  , maximumSlice
  , maximumSliceInitial
  , bubbleSort
  , bubbleSortSlice
  , bubbleSortSliceInPlace
  , mapEq
    -- * Custom
  , cumulativeSum1
  , toFins
  , weakenFins
    -- * Show
  , show
    -- * Interop with primitive
  , cloneFromByteArray
  ) where

import Prelude hiding (replicate,map,maximum,Bounded,all,show)

import Vector.Std.Int32
import Vector.Ord.Int32
import Vector.Eq.Int32

import Control.Monad.ST (runST)
import GHC.Exts (Int32#)
import GHC.Int (Int(I#),Int32(I32#),Int64(I64#))
import GHC.TypeNats (type (+))
import Arithmetic.Types (Nat#,Fin32#,type (<=#))
import Data.Primitive (ByteArray(ByteArray))
import Data.Unlifted (PrimArray#(PrimArray#))

import qualified GHC.Exts as Exts
import qualified Arithmetic.Fin as Fin
import qualified Arithmetic.Nat as Nat
import qualified Data.Primitive as PM
import qualified Vector.Prim.Int32

-- | Prefix sums of the vector, with a leading zero. The result has one
-- more element than the input: element @0@ is @0@, and element @i + 1@
-- is the sum of input elements @0@ through @i@.
--
-- The running sum is accumulated in 'Int64'. Crashes (via
-- 'errorWithoutStackTrace') if any running sum is greater than
-- @2^31-1@ or less than @-2^31@, because it could not be stored as an
-- 'Int32#'. This is checked for every prefix, not only the final total,
-- so an input whose total fits can still crash if an intermediate sum
-- does not.
cumulativeSum1 ::
     Nat# n -- ^ length of the input vector
  -> Vector n Int32#
  -> Vector (n + 1) Int32#
cumulativeSum1 n !v = runST $ do
  -- Slot 0 keeps this initial zero; slots 1..n are written by the loop.
  dst <- initialized (Nat.succ# n) (Exts.intToInt32# 0#)
  _ <- Fin.ascendM# n (0 :: Int64)
    (\fin acc0 -> do
      let x = index v fin
      -- acc0 is always within Int32 range (it starts at 0 and is range
      -- checked below), so widening to Int64 means this addition cannot
      -- itself overflow and the range checks are meaningful.
      let !acc1@(I64# acc1#) = acc0 + I64# (Exts.intToInt64# (Exts.int32ToInt# x))
      if acc1 > 2_147_483_647
        then errorWithoutStackTrace "Vector.Int32.cumulativeSum1: sum > 2^31-1"
        else if acc1 < (-2_147_483_648)
          then errorWithoutStackTrace "Vector.Int32.cumulativeSum1: sum < -2^31"
          else do
            -- The sum through input element fin goes into output slot fin + 1.
            write dst (Fin.incrementR# Nat.N1# fin)
              (Exts.intToInt32# (Exts.int64ToInt# acc1#))
            pure acc1
    )
  unsafeFreeze dst

-- | Check that every element lies in the range @[0, m)@. If so, return the
-- vector reinterpreted (via 'unsafeCoerceVector') as a vector of 'Fin32#'
-- values bounded by @m@. Returns 'Nothing' if any element is negative or
-- is greater than or equal to @m@.
toFins ::
     Nat# m -- ^ exclusive upper bound on the elements
  -> Nat# n -- ^ vector length
  -> Vector n Int32#
  -> Maybe (Vector n (Fin32# m))
toFins m n !v
  | all inRange n v = Just (unsafeCoerceVector v)
  | otherwise = Nothing
  where
  -- The bound does not depend on the element, so compute it once
  -- instead of once per element.
  bound :: Int
  bound = I# (Nat.demote# m)
  inRange :: Int32# -> Bool
  inRange x =
    let w = I32# x
     in w >= 0 && fromIntegral @Int32 @Int w < bound

-- | Loosen the bound of every 'Fin32#' in the vector from @a@ to @b@.
--
-- The proof argument is ignored at runtime. It only witnesses @a <= b@,
-- which is why every value below @a@ is also below @b@. The underlying
-- 'ByteArray#' is unwrapped and re-wrapped unchanged; only the phantom
-- bound in the element type changes.
weakenFins ::
     (a <=# b) -- ^ proof that @a <= b@; not inspected
  -> Vector n (Fin32# a)
  -> Vector n (Fin32# b)
{-# inline weakenFins #-}
weakenFins _ (Vector x) = case expose# x of
  PrimArray# bytes -> Vector (unsafeConstruct# (PrimArray# bytes))

-- | Clone a slice of a byte array as a vector, interpreting the bytes in
-- a native-endian fashion.
--
-- The offset and length are measured in elements (4 bytes each), not in
-- bytes. Crashes the program (via 'errorWithoutStackTrace') if the offset
-- is negative or the requested range extends past the end of the byte
-- array, so behavior is always well defined.
cloneFromByteArray ::
     Int    -- ^ Offset into byte array, units are elements, not bytes
  -> Nat# n -- ^ Length of the vector, units are elements, not bytes
  -> ByteArray
  -> Vector n Int32#
cloneFromByteArray !off n !b
  -- Written as @off > available - len@ rather than
  -- @off + len > available@ so that a huge offset cannot overflow.
  | off < 0 || off > available - len =
      errorWithoutStackTrace "Vector.Int32.cloneFromByteArray: range out of bounds"
  | otherwise = Vector.Prim.Int32.unsafeCloneFromByteArray off n b
  where
  -- Number of whole 4-byte elements the byte array holds.
  available = PM.sizeofByteArray b `quot` 4
  len = I# (Nat.demote# n)

-- | Render the vector as a 'String'. Each element is shown through the
-- 'Show' instance of 'Int32'; the surrounding layout (delimiters,
-- separators) is whatever 'liftShows' produces.
show :: Nat# n -> Vector n Int32# -> String
show n v = liftShows showElement n v ""
  where
  -- Box the unlifted element so the standard 'Show' instance can be used.
  showElement :: Int32# -> ShowS
  showElement i = shows (I32# i)