Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from google.cloud.spanner_dbapi.parse_utils import get_param_types
from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner
from google.cloud.spanner_dbapi.utils import PeekIterator
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets

_UNSET_COUNT = -1

Expand Down Expand Up @@ -210,8 +211,20 @@ def executemany(self, operation, seq_of_params):
"""
self._raise_if_closed()

classification = parse_utils.classify_stmt(operation)
if classification == parse_utils.STMT_DDL:
raise ProgrammingError(
"Executing DDL statements with executemany() method is not allowed."
)

many_result_set = StreamedManyResultSets()

for params in seq_of_params:
self.execute(operation, params)
many_result_set.add_iter(self._itr)

self._result_set = many_result_set
self._itr = many_result_set

def fetchone(self):
"""Fetch the next row of a query result set, returning a single
Expand Down
40 changes: 39 additions & 1 deletion google/cloud/spanner_dbapi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import re

re_UNICODE_POINTS = re.compile(r"([^\s]*[\u0080-\uFFFF]+[^\s]*)")


class PeekIterator:
"""
Expand Down Expand Up @@ -55,7 +57,43 @@ def __iter__(self):
return self


re_UNICODE_POINTS = re.compile(r"([^\s]*[\u0080-\uFFFF]+[^\s]*)")
class StreamedManyResultSets:
"""Iterator to walk through several `StreamedResultsSet` iterators.
This type of iterator is used by `Cursor.executemany()`
method to iterate through several `StreamedResultsSet`
iterators like they all are merged into single iterator.
"""

def __init__(self):
self._iterators = []
self._index = 0

def add_iter(self, iterator):
"""Add new iterator into this one.
:type iterator: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet`
:param iterator: Iterator to merge into this one.
"""
self._iterators.append(iterator)

def __next__(self):
"""Return the next value from the currently streamed iterator.
If the current iterator is streamed to the end,
start to stream the next one.
:rtype: list
:returns: The next result row.
"""
try:
res = next(self._iterators[self._index])
except StopIteration:
self._index += 1
res = self.__next__()
except IndexError:
raise StopIteration

return res

def __iter__(self):
return self


def backtick_unicode(sql):
Expand Down
40 changes: 40 additions & 0 deletions tests/system/test_system_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,46 @@ def test_results_checksum(self):

self.assertEqual(cursor._checksum.checksum.digest(), checksum.digest())

def test_execute_many(self):
# connect to the test database
conn = Connection(Config.INSTANCE, self._db)
cursor = conn.cursor()

cursor.execute(
"""
INSERT INTO contacts (contact_id, first_name, last_name, email)
VALUES (1, 'first-name', 'last-name', '[email protected]'),
(2, 'first-name2', 'last-name2', '[email protected]')
"""
)
conn.commit()

cursor.executemany(
"""
SELECT * FROM contacts WHERE contact_id = @a1
""",
({"a1": 1}, {"a1": 2}),
)
res = cursor.fetchall()
conn.commit()

self.assertEqual(len(res), 2)
self.assertEqual(res[0][0], 1)
self.assertEqual(res[1][0], 2)

# checking that execute() and executemany()
# results are not mixed together
cursor.execute(
"""
SELECT * FROM contacts WHERE contact_id = 1
""",
)
res = cursor.fetchone()
conn.commit()

self.assertEqual(res[0], 1)
conn.close()


def clear_table(transaction):
"""Clear the test table."""
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/spanner_dbapi/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,22 @@ def test_executemany_on_closed_cursor(self):
with self.assertRaises(InterfaceError):
cursor.executemany("""SELECT * FROM table1 WHERE "col1" = @a1""", ())

def test_executemany_DLL(self):
from google.cloud.spanner_dbapi import connect, ProgrammingError

with mock.patch(
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True,
):
with mock.patch(
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
):
connection = connect("test-instance", "test-database")

cursor = connection.cursor()

with self.assertRaises(ProgrammingError):
cursor.executemany("""DROP DATABASE database_name""", ())

def test_executemany(self):
from google.cloud.spanner_dbapi import connect

Expand All @@ -272,6 +288,9 @@ def test_executemany(self):
connection = connect("test-instance", "test-database")

cursor = connection.cursor()
cursor._result_set = [1, 2, 3]
cursor._itr = iter([1, 2, 3])

with mock.patch(
"google.cloud.spanner_dbapi.cursor.Cursor.execute"
) as execute_mock:
Expand Down