Skip to content

Commit b489a43

Browse files
authoredFeb 15, 2024
feat: support saving with customized content column and saving/loading with non-default metadata JSON column. (#19)
1 parent 5aecbd0 commit b489a43

File tree

4 files changed

+135
-40
lines changed

4 files changed

+135
-40
lines changed
 

‎src/langchain_google_cloud_sql_mysql/mysql_engine.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,9 @@ def init_document_table(
233233
self,
234234
table_name: str,
235235
metadata_columns: List[sqlalchemy.Column] = [],
236-
store_metadata: bool = True,
236+
content_column: str = "page_content",
237+
metadata_json_column: Optional[str] = "langchain_metadata",
238+
overwrite_existing: bool = False,
237239
) -> None:
238240
"""
239241
Create a table for saving of langchain documents.
@@ -242,22 +244,29 @@ def init_document_table(
242244
table_name (str): The MySQL database table name.
243245
metadata_columns (List[sqlalchemy.Column]): A list of SQLAlchemy Columns
244246
to create for custom metadata. Optional.
245-
store_metadata (bool): Whether to store extra metadata in a metadata column
246-
if not described in 'metadata' field list (Default: True).
247+
content_column (str): The column to store document content.
248+
Default: `page_content`.
249+
metadata_json_column (Optional[str]): The column to store extra metadata in JSON format.
250+
Default: `langchain_metadata`. Optional.
251+
overwrite_existing (bool): Whether to drop existing table. Default: False.
247252
"""
253+
if overwrite_existing:
254+
with self.engine.connect() as conn:
255+
conn.execute(sqlalchemy.text(f"DROP TABLE IF EXISTS `{table_name}`;"))
256+
248257
columns = [
249258
sqlalchemy.Column(
250-
"page_content",
259+
content_column,
251260
sqlalchemy.UnicodeText,
252261
primary_key=False,
253262
nullable=False,
254263
)
255264
]
256265
columns += metadata_columns
257-
if store_metadata:
266+
if metadata_json_column:
258267
columns.append(
259268
sqlalchemy.Column(
260-
"langchain_metadata",
269+
metadata_json_column,
261270
sqlalchemy.JSON,
262271
primary_key=False,
263272
nullable=True,

‎src/langchain_google_cloud_sql_mysql/mysql_loader.py

Lines changed: 79 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,33 +27,41 @@
2727

2828

2929
def _parse_doc_from_row(
30-
content_columns: Iterable[str], metadata_columns: Iterable[str], row: Dict
30+
content_columns: Iterable[str],
31+
metadata_columns: Iterable[str],
32+
row: Dict,
33+
metadata_json_column: str = DEFAULT_METADATA_COL,
3134
) -> Document:
3235
page_content = " ".join(
3336
str(row[column]) for column in content_columns if column in row
3437
)
3538
metadata: Dict[str, Any] = {}
3639
# unnest metadata from langchain_metadata column
37-
if DEFAULT_METADATA_COL in metadata_columns and row.get(DEFAULT_METADATA_COL):
38-
for k, v in row[DEFAULT_METADATA_COL].items():
40+
if row.get(metadata_json_column):
41+
for k, v in row[metadata_json_column].items():
3942
metadata[k] = v
4043
# load metadata from other columns
4144
for column in metadata_columns:
42-
if column in row and column != DEFAULT_METADATA_COL:
45+
if column in row and column != metadata_json_column:
4346
metadata[column] = row[column]
4447
return Document(page_content=page_content, metadata=metadata)
4548

4649

47-
def _parse_row_from_doc(column_names: Iterable[str], doc: Document) -> Dict:
50+
def _parse_row_from_doc(
51+
column_names: Iterable[str],
52+
doc: Document,
53+
content_column: str = DEFAULT_CONTENT_COL,
54+
metadata_json_column: str = DEFAULT_METADATA_COL,
55+
) -> Dict:
4856
doc_metadata = doc.metadata.copy()
49-
row: Dict[str, Any] = {DEFAULT_CONTENT_COL: doc.page_content}
57+
row: Dict[str, Any] = {content_column: doc.page_content}
5058
for entry in doc.metadata:
5159
if entry in column_names:
5260
row[entry] = doc_metadata[entry]
5361
del doc_metadata[entry]
5462
# store extra metadata in langchain_metadata column in json format
55-
if DEFAULT_METADATA_COL in column_names and len(doc_metadata) > 0:
56-
row[DEFAULT_METADATA_COL] = doc_metadata
63+
if metadata_json_column in column_names and len(doc_metadata) > 0:
64+
row[metadata_json_column] = doc_metadata
5765
return row
5866

5967

@@ -67,6 +75,7 @@ def __init__(
6775
query: str = "",
6876
content_columns: Optional[List[str]] = None,
6977
metadata_columns: Optional[List[str]] = None,
78+
metadata_json_column: Optional[str] = None,
7079
):
7180
"""
7281
Document page content defaults to the first column present in the query or table and
@@ -85,12 +94,15 @@ def __init__(
8594
of the document. Optional.
8695
metadata_columns (List[str]): The columns to write into the `metadata` of the document.
8796
Optional.
97+
metadata_json_column (str): The name of the JSON column to use as the metadata’s base
98+
dictionary. Default: `langchain_metadata`. Optional.
8899
"""
89100
self.engine = engine
90101
self.table_name = table_name
91102
self.query = query
92103
self.content_columns = content_columns
93104
self.metadata_columns = metadata_columns
105+
self.metadata_json_column = metadata_json_column
94106
if not self.table_name and not self.query:
95107
raise ValueError("One of 'table_name' or 'query' must be specified.")
96108
if self.table_name and self.query:
@@ -139,6 +151,25 @@ def lazy_load(self) -> Iterator[Document]:
139151
metadata_columns = self.metadata_columns or [
140152
col for col in column_names if col not in content_columns
141153
]
154+
# check validity of metadata json column
155+
if (
156+
self.metadata_json_column
157+
and self.metadata_json_column not in column_names
158+
):
159+
raise ValueError(
160+
f"Column {self.metadata_json_column} not found in query result {column_names}."
161+
)
162+
# check validity of other column
163+
all_names = content_columns + metadata_columns
164+
for name in all_names:
165+
if name not in column_names:
166+
raise ValueError(
167+
f"Column {name} not found in query result {column_names}."
168+
)
169+
# use default metadata json column if not specified
170+
metadata_json_column = self.metadata_json_column or DEFAULT_METADATA_COL
171+
172+
# load document one by one
142173
while True:
143174
row = result_proxy.fetchone()
144175
if not row:
@@ -151,7 +182,12 @@ def lazy_load(self) -> Iterator[Document]:
151182
row_data[column] = json.loads(value)
152183
else:
153184
row_data[column] = value
154-
yield _parse_doc_from_row(content_columns, metadata_columns, row_data)
185+
yield _parse_doc_from_row(
186+
content_columns,
187+
metadata_columns,
188+
row_data,
189+
metadata_json_column,
190+
)
155191

156192

157193
class MySQLDocumentSaver:
@@ -161,6 +197,8 @@ def __init__(
161197
self,
162198
engine: MySQLEngine,
163199
table_name: str,
200+
content_column: Optional[str] = None,
201+
metadata_json_column: Optional[str] = None,
164202
):
165203
"""
166204
MySQLDocumentSaver allows for saving of langchain documents in a database. If the table
@@ -169,17 +207,33 @@ def __init__(
169207
- langchain_metadata (type: JSON)
170208
171209
Args:
172-
engine: MySQLEngine object to connect to the MySQL database.
173-
table_name: The name of table for saving documents.
210+
engine (MySQLEngine): MySQLEngine object to connect to the MySQL database.
211+
table_name (str): The name of table for saving documents.
212+
content_column (str): The column to store document content.
213+
Default: `page_content`. Optional.
214+
metadata_json_column (str): The name of the JSON column to use as the metadata’s base
215+
dictionary. Default: `langchain_metadata`. Optional.
174216
"""
175217
self.engine = engine
176218
self.table_name = table_name
177219
self._table = self.engine._load_document_table(table_name)
178-
if DEFAULT_CONTENT_COL not in self._table.columns.keys():
220+
221+
self.content_column = content_column or DEFAULT_CONTENT_COL
222+
if self.content_column not in self._table.columns.keys():
179223
raise ValueError(
180-
f"Missing '{DEFAULT_CONTENT_COL}' field in table {table_name}."
224+
f"Missing '{self.content_column}' field in table {table_name}."
181225
)
182226

227+
# check metadata_json_column existence if it's provided.
228+
if (
229+
metadata_json_column
230+
and metadata_json_column not in self._table.columns.keys()
231+
):
232+
raise ValueError(
233+
f"Cannot find '{metadata_json_column}' column in table {table_name}."
234+
)
235+
self.metadata_json_column = metadata_json_column or DEFAULT_METADATA_COL
236+
183237
def add_documents(self, docs: List[Document]) -> None:
184238
"""
185239
Save documents in the DocumentSaver table. Document’s metadata is added to columns if found or
@@ -190,7 +244,12 @@ def add_documents(self, docs: List[Document]) -> None:
190244
"""
191245
with self.engine.connect() as conn:
192246
for doc in docs:
193-
row = _parse_row_from_doc(self._table.columns.keys(), doc)
247+
row = _parse_row_from_doc(
248+
self._table.columns.keys(),
249+
doc,
250+
self.content_column,
251+
self.metadata_json_column,
252+
)
194253
conn.execute(sqlalchemy.insert(self._table).values(row))
195254
conn.commit()
196255

@@ -204,7 +263,12 @@ def delete(self, docs: List[Document]) -> None:
204263
"""
205264
with self.engine.connect() as conn:
206265
for doc in docs:
207-
row = _parse_row_from_doc(self._table.columns.keys(), doc)
266+
row = _parse_row_from_doc(
267+
self._table.columns.keys(),
268+
doc,
269+
self.content_column,
270+
self.metadata_json_column,
271+
)
208272
# delete by matching all fields of document
209273
where_conditions = []
210274
for col in self._table.columns:

‎tests/integration/test_mysql_loader.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,6 @@ def test_load_from_query_with_langchain_metadata(engine):
249249
query=query,
250250
metadata_columns=[
251251
"fruit_name",
252-
"langchain_metadata",
253252
],
254253
)
255254

@@ -294,8 +293,9 @@ def test_save_doc_with_default_metadata(engine):
294293
]
295294

296295

297-
@pytest.mark.parametrize("store_metadata", [True, False])
298-
def test_save_doc_with_customized_metadata(engine, store_metadata):
296+
@pytest.mark.parametrize("metadata_json_column", [None, "metadata_col_test"])
297+
def test_save_doc_with_customized_metadata(engine, metadata_json_column):
298+
content_column = "content_col_test"
299299
engine.init_document_table(
300300
table_name,
301301
metadata_columns=[
@@ -312,35 +312,43 @@ def test_save_doc_with_customized_metadata(engine, store_metadata):
312312
nullable=True,
313313
),
314314
],
315-
store_metadata=store_metadata,
315+
content_column=content_column,
316+
metadata_json_column=metadata_json_column,
317+
overwrite_existing=True,
316318
)
317319
test_docs = [
318320
Document(
319321
page_content="Granny Smith 150 0.99",
320322
metadata={"fruit_id": 1, "fruit_name": "Apple", "organic": 1},
321323
),
322324
]
323-
saver = MySQLDocumentSaver(engine=engine, table_name=table_name)
325+
saver = MySQLDocumentSaver(
326+
engine=engine,
327+
table_name=table_name,
328+
content_column=content_column,
329+
metadata_json_column=metadata_json_column,
330+
)
324331
loader = MySQLLoader(
325332
engine=engine,
326333
table_name=table_name,
334+
content_columns=[content_column],
327335
metadata_columns=[
328-
"fruit_id",
329336
"fruit_name",
330337
"organic",
331338
],
339+
metadata_json_column=metadata_json_column,
332340
)
333341

334342
saver.add_documents(test_docs)
335343
docs = loader.load()
336344

337-
if store_metadata:
345+
if metadata_json_column:
338346
docs == test_docs
339347
assert engine._load_document_table(table_name).columns.keys() == [
340-
"page_content",
348+
content_column,
341349
"fruit_name",
342350
"organic",
343-
"langchain_metadata",
351+
metadata_json_column,
344352
]
345353
else:
346354
assert docs == [
@@ -350,7 +358,7 @@ def test_save_doc_with_customized_metadata(engine, store_metadata):
350358
),
351359
]
352360
assert engine._load_document_table(table_name).columns.keys() == [
353-
"page_content",
361+
content_column,
354362
"fruit_name",
355363
"organic",
356364
]
@@ -359,7 +367,7 @@ def test_save_doc_with_customized_metadata(engine, store_metadata):
359367
def test_save_doc_without_metadata(engine):
360368
engine.init_document_table(
361369
table_name,
362-
store_metadata=False,
370+
metadata_json_column=None,
363371
)
364372
test_docs = [
365373
Document(
@@ -413,8 +421,9 @@ def test_delete_doc_with_default_metadata(engine):
413421
assert len(loader.load()) == 0
414422

415423

416-
@pytest.mark.parametrize("store_metadata", [True, False])
417-
def test_delete_doc_with_customized_metadata(engine, store_metadata):
424+
@pytest.mark.parametrize("metadata_json_column", [None, "metadata_col_test"])
425+
def test_delete_doc_with_customized_metadata(engine, metadata_json_column):
426+
content_column = "content_col_test"
418427
engine.init_document_table(
419428
table_name,
420429
metadata_columns=[
@@ -431,7 +440,9 @@ def test_delete_doc_with_customized_metadata(engine, store_metadata):
431440
nullable=True,
432441
),
433442
],
434-
store_metadata=store_metadata,
443+
content_column=content_column,
444+
metadata_json_column=metadata_json_column,
445+
overwrite_existing=True,
435446
)
436447
test_docs = [
437448
Document(
@@ -443,8 +454,18 @@ def test_delete_doc_with_customized_metadata(engine, store_metadata):
443454
metadata={"fruit_id": 2, "fruit_name": "Banana", "organic": 1},
444455
),
445456
]
446-
saver = MySQLDocumentSaver(engine=engine, table_name=table_name)
447-
loader = MySQLLoader(engine=engine, table_name=table_name)
457+
saver = MySQLDocumentSaver(
458+
engine=engine,
459+
table_name=table_name,
460+
content_column=content_column,
461+
metadata_json_column=metadata_json_column,
462+
)
463+
loader = MySQLLoader(
464+
engine=engine,
465+
table_name=table_name,
466+
content_columns=[content_column],
467+
metadata_json_column=metadata_json_column,
468+
)
448469

449470
saver.add_documents(test_docs)
450471
docs = loader.load()
@@ -474,7 +495,6 @@ def test_delete_doc_with_query(engine):
474495
nullable=True,
475496
),
476497
],
477-
store_metadata=True,
478498
)
479499
test_docs = [
480500
Document(

‎tests/unit/test_doc2row.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,13 @@ def test_row2doc_ovrride_default_metadata():
8989

9090

9191
def test_row2doc_metadata_col_nonexist():
92-
assert _parse_doc_from_row(
92+
doc = _parse_doc_from_row(
9393
["variety", "quantity_in_stock", "price_per_unit"],
9494
["fruit-id"],
9595
row_customized_nested,
96-
) == Document(page_content="Granny Smith 150 0.99")
96+
metadata_json_column="non-exist",
97+
)
98+
assert doc == Document(page_content="Granny Smith 150 0.99")
9799

98100

99101
def test_doc2row_default():

0 commit comments

Comments
 (0)