{-# LANGUAGE OverloadedStrings  #-}

------------------------------------------------------------------------------
-- |
-- Module:      Database.PostgreSQL.Simple.Cursor
-- Copyright:   (c) 2011-2012 Leon P Smith
--              (c) 2017 Bardur Arantsson
-- License:     BSD3
-- Maintainer:  Leon P Smith <[email protected]>
--
------------------------------------------------------------------------------

module Database.PostgreSQL.Simple.Cursor
    (
    -- * Types
      Cursor
    -- * Cursor management
    , declareCursor
    , closeCursor
    -- * Folding over rows from a cursor
    , foldForward
    , foldForwardWithParser
    ) where

import           Data.ByteString.Builder (intDec)
import           Control.Applicative ((<$>))
import           Control.Exception as E
import           Control.Monad (unless, void)
import           Data.Monoid (mconcat)
import           Database.PostgreSQL.Simple.Compat ((<>), toByteString)
import           Database.PostgreSQL.Simple.FromRow (FromRow(..))
import           Database.PostgreSQL.Simple.Types (Query(..))
import           Database.PostgreSQL.Simple.Internal as Base hiding (result, row)
import           Database.PostgreSQL.Simple.Internal.PQResultUtils
import           Database.PostgreSQL.Simple.Transaction
import qualified Database.PostgreSQL.LibPQ as PQ

-- | Handle to a cursor within a transaction.
--
-- Values are produced by 'declareCursor' and released with 'closeCursor'.
data Cursor
  = Cursor
      !Query       -- ^ Name under which the cursor was declared.
      !Connection  -- ^ Connection the cursor was declared on.

-- | Declare a temporary cursor for the given query. The cursor is given a
-- unique name for the given connection.
--
-- The name is obtained from 'newTempName' and the cursor is created with
-- @DECLARE \<name\> NO SCROLL CURSOR FOR \<query\>@. The returned 'Cursor'
-- records both the generated name and the connection, so later operations
-- address the same server-side cursor.
declareCursor :: Connection -> Query -> IO Cursor
declareCursor conn q = do
  name <- newTempName conn
  -- The affected-row count reported by 'execute_' carries no information
  -- for a DECLARE statement, so it is discarded.
  void $ execute_ conn $ mconcat ["DECLARE ", name, " NO SCROLL CURSOR FOR ", q]
  return $ Cursor name conn

-- | Close the given cursor by executing @CLOSE \<name\>@ on the connection
-- it was declared on.
--
-- An 'SqlError' for which 'isFailedTransactionError' holds is swallowed;
-- every other 'SqlError' is rethrown unchanged.
closeCursor :: Cursor -> IO ()
closeCursor (Cursor name conn) =
  void (execute_ conn ("CLOSE " <> name)) `E.catch` \ex ->
     -- Don't throw exception if CLOSE failed because the transaction is
     -- aborted.  Otherwise, it will throw away the original error.
     unless (isFailedTransactionError ex) $ throwIO ex

-- | Fold over a chunk of rows from the given cursor, calling the
-- supplied fold-like function on each row as it is received. In case
-- the cursor is exhausted, a 'Left' value is returned, otherwise a
-- 'Right' value is returned.
--
-- One @FETCH FORWARD \<chunkSize\> FROM \<name\>@ statement is issued per
-- call. If it yields no rows, the initial accumulator is returned
-- untouched inside 'Left'. Otherwise every fetched row is decoded with the
-- supplied 'RowParser' and fed, in order, to the step function, and the
-- final accumulator is returned inside 'Right'. Any result status other
-- than 'PQ.TuplesOk' is reported through 'throwResultError'.
foldForwardWithParser :: Cursor -> RowParser r -> Int -> (a -> r -> IO a) -> a -> IO (Either a a)
foldForwardWithParser (Cursor name conn) parser chunkSize f a0 = do
  let q = "FETCH FORWARD "
            <> (toByteString $ intDec chunkSize)
            <> " FROM "
            <> fromQuery name
  result <- exec conn q
  status <- PQ.resultStatus result
  case status of
      PQ.TuplesOk -> do
          nrows <- PQ.ntuples result
          ncols <- PQ.nfields result
          if nrows > 0
            then do
              -- Decode row number @row@ of this result and hand it to the
              -- caller's step function together with the accumulator.
              let inner acc row = do
                    x <- getRowWith parser row ncols conn result
                    f acc x
              -- Row indices are zero-based, hence the inclusive upper
              -- bound of @nrows - 1@.
              Right <$> foldM' inner a0 0 (nrows - 1)
            else
              return $ Left a0
      _   -> throwResultError "foldForwardWithParser" result status

-- | Fold over a chunk of rows, calling the supplied fold-like function
-- on each row as it is received. In case the cursor is exhausted,
-- a 'Left' value is returned, otherwise a 'Right' value is returned.
--
-- This is 'foldForwardWithParser' with the row parser taken from the
-- 'FromRow' instance of the row type; the remaining arguments (chunk size,
-- step function, initial accumulator) are passed through unchanged.
foldForward :: FromRow r => Cursor -> Int -> (a -> r -> IO a) -> a -> IO (Either a a)
foldForward cursor = foldForwardWithParser cursor fromRow


-- | Monadic left fold over the inclusive index range @[lo .. hi]@.
--
-- The step function receives the current accumulator and the current
-- index; indices are visited in ascending order, one at a time. When
-- @lo > hi@ the step function is never called and the initial accumulator
-- is returned as is.
foldM' :: (Ord n, Num n) => (a -> n -> IO a) -> a -> n -> n -> IO a
foldM' f a lo hi = loop a $! lo
  where
    -- The index is forced with '$!' at every call so no chain of @(+1)@
    -- thunks builds up; this avoids needing the BangPatterns extension.
    loop x n
      | n > hi    = return x
      | otherwise = do
           x' <- f x n
           loop x' $! n + 1
{-# INLINE foldM' #-}