{-# language BangPatterns #-}
{-# language BlockArguments #-}
{-# language DataKinds #-}
{-# language ExplicitNamespaces #-}
{-# language GADTs #-}
{-# language KindSignatures #-}
{-# language MagicHash #-}
{-# language PatternSynonyms #-}
{-# language RankNTypes #-}
{-# language ScopedTypeVariables #-}
{-# language TypeApplications #-}
{-# language TypeOperators #-}
{-# language UnboxedTuples #-}
{-# language UnliftedNewtypes #-}

module PermuteVector
  ( permute
  ) where

import Prelude hiding (Bounded,max,min,maximum)

import Rep (R)
import FinType (Finite#,weaken)
import GHC.ST (ST(ST),runST)
import Arithmetic.Types (type (<),Fin(Fin),Nat#)
import Arithmetic.Types (type (:=:),type (<=))
import Arithmetic.Types (type (<#),type (<=#))
import Arithmetic.Nat ((<?),(<?#))
import GHC.TypeNats (type (+))
import GHC.Exts (TYPE,State#)
import Data.Either.Void (pattern LeftVoid#, pattern RightVoid#)

import qualified GHC.TypeNats as GHC
import qualified Element
import qualified Arithmetic.Lt as Lt
import qualified Arithmetic.Lte as Lte
import qualified Arithmetic.Nat as Nat
import qualified Arithmetic.Fin as Fin
import qualified Vector as V
import qualified FinVector as FV

-- | Permute the source vector according to the indices, i.e. gather
-- elements of the source through the index vector:
--
-- > forall ix. output[ix] = source[indices[ix]]
--
-- The output has the same length as the index vector. No uniqueness check
-- is performed on the indices, so an index may appear more than once and
-- some source elements may never be referenced.
permute :: forall (m :: GHC.Nat) (n :: GHC.Nat) (a :: TYPE R).
     Nat# m -- ^ indices length
  -> FV.Vector m (Finite# n) -- ^ indices
  -> V.Vector n a -- ^ source
  -> V.Vector m a -- ^ output
{-# noinline permute #-}
permute m !ixs !v = case Nat.testZero# m of
  -- m = 0: nothing to gather; transport the empty vector along 0 = m.
  LeftVoid# zeq -> V.substitute zeq V.empty
  -- m > 0: zlt (0 < m) proves that index 0 is a valid output position.
  RightVoid# zlt -> runST $ do
    -- V.initialized requires a seed element, so every slot starts out as
    -- source[indices[0]]. The loop below writes each output index again,
    -- so the seed is only a placeholder.
    dst <- V.initialized m
      (V.index v (weaken (FV.index ixs (Fin.construct# zlt (Nat.zero# (# #))))))
    -- output[fin] := source[indices[fin]] for every fin in [0, m)
    Fin.ascendM_# m $ \fin ->
      V.write dst fin (V.index v (weaken (FV.index ixs fin)))
    -- dst is not touched after this point, so freezing in place is safe.
    V.unsafeFreeze dst