27
27
28
28
29
29
def _parse_doc_from_row(
    content_columns: Iterable[str],
    metadata_columns: Iterable[str],
    row: Dict,
    metadata_json_column: str = DEFAULT_METADATA_COL,
) -> Document:
    """Convert a database row (as a dict) into a langchain Document.

    Args:
        content_columns: Columns whose values are concatenated (space-separated,
            in the given order) to form the document's page content.
        metadata_columns: Columns copied into the document metadata.
        row: Mapping of column name to value for a single result row.
        metadata_json_column: Name of the JSON column whose contents are
            unnested into the metadata as the base dictionary.
            Default: `langchain_metadata`.

    Returns:
        Document with joined page content and merged metadata. Explicit
        metadata columns overwrite same-named keys from the JSON column.
    """
    page_content = " ".join(
        str(row[column]) for column in content_columns if column in row
    )
    metadata: Dict[str, Any] = {}
    # Unnest the base metadata stored as JSON (skipped when absent/empty/None).
    if row.get(metadata_json_column):
        metadata.update(row[metadata_json_column])
    # Load metadata from the remaining columns; these take precedence over
    # keys unnested from the JSON column.
    for column in metadata_columns:
        if column in row and column != metadata_json_column:
            metadata[column] = row[column]
    return Document(page_content=page_content, metadata=metadata)
45
48
46
49
47
def _parse_row_from_doc(
    column_names: Iterable[str],
    doc: Document,
    content_column: str = DEFAULT_CONTENT_COL,
    metadata_json_column: str = DEFAULT_METADATA_COL,
) -> Dict:
    """Convert a langchain Document into a row dict ready for insertion.

    Args:
        column_names: Column names present in the destination table; metadata
            keys matching one of these become their own row entries.
        doc: Source document.
        content_column: Column that receives the page content.
            Default: `page_content`.
        metadata_json_column: JSON column that receives any metadata keys not
            matching a table column. Default: `langchain_metadata`.

    Returns:
        Mapping of column name to value for the document.
    """
    doc_metadata = doc.metadata.copy()
    row: Dict[str, Any] = {content_column: doc.page_content}
    # Iterate the original metadata (not the copy) so entries can be removed
    # from the copy while iterating without invalidating the loop.
    for entry in doc.metadata:
        if entry in column_names:
            row[entry] = doc_metadata[entry]
            del doc_metadata[entry]
    # Store leftover metadata in the JSON column, if the table has one.
    if metadata_json_column in column_names and doc_metadata:
        row[metadata_json_column] = doc_metadata
    return row
58
66
59
67
@@ -67,6 +75,7 @@ def __init__(
67
75
query : str = "" ,
68
76
content_columns : Optional [List [str ]] = None ,
69
77
metadata_columns : Optional [List [str ]] = None ,
78
+ metadata_json_column : Optional [str ] = None ,
70
79
):
71
80
"""
72
81
Document page content defaults to the first column present in the query or table and
@@ -85,12 +94,15 @@ def __init__(
85
94
of the document. Optional.
86
95
metadata_columns (List[str]): The columns to write into the `metadata` of the document.
87
96
Optional.
97
+ metadata_json_column (str): The name of the JSON column to use as the metadata’s base
98
+ dictionary. Default: `langchain_metadata`. Optional.
88
99
"""
89
100
self .engine = engine
90
101
self .table_name = table_name
91
102
self .query = query
92
103
self .content_columns = content_columns
93
104
self .metadata_columns = metadata_columns
105
+ self .metadata_json_column = metadata_json_column
94
106
if not self .table_name and not self .query :
95
107
raise ValueError ("One of 'table_name' or 'query' must be specified." )
96
108
if self .table_name and self .query :
@@ -139,6 +151,25 @@ def lazy_load(self) -> Iterator[Document]:
139
151
metadata_columns = self .metadata_columns or [
140
152
col for col in column_names if col not in content_columns
141
153
]
154
+ # check validity of metadata json column
155
+ if (
156
+ self .metadata_json_column
157
+ and self .metadata_json_column not in column_names
158
+ ):
159
+ raise ValueError (
160
+ f"Column { self .metadata_json_column } not found in query result { column_names } ."
161
+ )
162
+ # check validity of other column
163
+ all_names = content_columns + metadata_columns
164
+ for name in all_names :
165
+ if name not in column_names :
166
+ raise ValueError (
167
+ f"Column { name } not found in query result { column_names } ."
168
+ )
169
+ # use default metadata json column if not specified
170
+ metadata_json_column = self .metadata_json_column or DEFAULT_METADATA_COL
171
+
172
+ # load document one by one
142
173
while True :
143
174
row = result_proxy .fetchone ()
144
175
if not row :
@@ -151,7 +182,12 @@ def lazy_load(self) -> Iterator[Document]:
151
182
row_data [column ] = json .loads (value )
152
183
else :
153
184
row_data [column ] = value
154
- yield _parse_doc_from_row (content_columns , metadata_columns , row_data )
185
+ yield _parse_doc_from_row (
186
+ content_columns ,
187
+ metadata_columns ,
188
+ row_data ,
189
+ metadata_json_column ,
190
+ )
155
191
156
192
157
193
class MySQLDocumentSaver :
@@ -161,6 +197,8 @@ def __init__(
161
197
self ,
162
198
engine : MySQLEngine ,
163
199
table_name : str ,
200
+ content_column : Optional [str ] = None ,
201
+ metadata_json_column : Optional [str ] = None ,
164
202
):
165
203
"""
166
204
MySQLDocumentSaver allows for saving of langchain documents in a database. If the table
@@ -169,17 +207,33 @@ def __init__(
169
207
- langchain_metadata (type: JSON)
170
208
171
209
Args:
172
- engine: MySQLEngine object to connect to the MySQL database.
173
- table_name: The name of table for saving documents.
210
+ engine (MySQLEngine): MySQLEngine object to connect to the MySQL database.
211
+ table_name (str): The name of table for saving documents.
212
+ content_column (str): The column to store document content.
213
+ Deafult: `page_content`. Optional.
214
+ metadata_json_column (str): The name of the JSON column to use as the metadata’s base
215
+ dictionary. Default: `langchain_metadata`. Optional.
174
216
"""
175
217
self .engine = engine
176
218
self .table_name = table_name
177
219
self ._table = self .engine ._load_document_table (table_name )
178
- if DEFAULT_CONTENT_COL not in self ._table .columns .keys ():
220
+
221
+ self .content_column = content_column or DEFAULT_CONTENT_COL
222
+ if self .content_column not in self ._table .columns .keys ():
179
223
raise ValueError (
180
- f"Missing '{ DEFAULT_CONTENT_COL } ' field in table { table_name } ."
224
+ f"Missing '{ self . content_column } ' field in table { table_name } ."
181
225
)
182
226
227
+ # check metadata_json_column existence if it's provided.
228
+ if (
229
+ metadata_json_column
230
+ and metadata_json_column not in self ._table .columns .keys ()
231
+ ):
232
+ raise ValueError (
233
+ f"Cannot find '{ metadata_json_column } ' column in table { table_name } ."
234
+ )
235
+ self .metadata_json_column = metadata_json_column or DEFAULT_METADATA_COL
236
+
183
237
def add_documents (self , docs : List [Document ]) -> None :
184
238
"""
185
239
Save documents in the DocumentSaver table. Document’s metadata is added to columns if found or
@@ -190,7 +244,12 @@ def add_documents(self, docs: List[Document]) -> None:
190
244
"""
191
245
with self .engine .connect () as conn :
192
246
for doc in docs :
193
- row = _parse_row_from_doc (self ._table .columns .keys (), doc )
247
+ row = _parse_row_from_doc (
248
+ self ._table .columns .keys (),
249
+ doc ,
250
+ self .content_column ,
251
+ self .metadata_json_column ,
252
+ )
194
253
conn .execute (sqlalchemy .insert (self ._table ).values (row ))
195
254
conn .commit ()
196
255
@@ -204,7 +263,12 @@ def delete(self, docs: List[Document]) -> None:
204
263
"""
205
264
with self .engine .connect () as conn :
206
265
for doc in docs :
207
- row = _parse_row_from_doc (self ._table .columns .keys (), doc )
266
+ row = _parse_row_from_doc (
267
+ self ._table .columns .keys (),
268
+ doc ,
269
+ self .content_column ,
270
+ self .metadata_json_column ,
271
+ )
208
272
# delete by matching all fields of document
209
273
where_conditions = []
210
274
for col in self ._table .columns :
0 commit comments