Skip to content
This repository was archived by the owner on Dec 31, 2023. It is now read-only.

Commit b89fb00

Browse files
authored
fix: Pass the 'params' parameter to the underlying 'BatchPredictRequest' object in 'batch_predict()' method (#110)
1 parent df22fd5 commit b89fb00

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

google/cloud/automl_v1beta1/services/tables/tables_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2999,7 +2999,10 @@ def batch_predict(
29992999
)
30003000

30013001
req = google.cloud.automl_v1beta1.BatchPredictRequest(
3002-
name=model_name, input_config=input_request, output_config=output_request,
3002+
name=model_name,
3003+
input_config=input_request,
3004+
output_config=output_request,
3005+
params=params,
30033006
)
30043007

30053008
method_kwargs = self.__process_request_kwargs(req, **kwargs)

tests/unit/test_tables_client_v1beta1.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1599,6 +1599,24 @@ def test_batch_predict_bigquery(self):
15991599
)
16001600
)
16011601

1602+
def test_batch_predict_bigquery_with_params(self):
1603+
client = self.tables_client({}, {})
1604+
client.batch_predict(
1605+
model_name="my_model",
1606+
bigquery_input_uri="bq://input",
1607+
bigquery_output_uri="bq://output",
1608+
params={"feature_importance": "true"},
1609+
)
1610+
1611+
client.prediction_client.batch_predict.assert_called_with(
1612+
request=automl_v1beta1.BatchPredictRequest(
1613+
name="my_model",
1614+
input_config={"bigquery_source": {"input_uri": "bq://input"}},
1615+
output_config={"bigquery_destination": {"output_uri": "bq://output"}},
1616+
params={"feature_importance": "true"},
1617+
)
1618+
)
1619+
16021620
def test_batch_predict_mixed(self):
16031621
client = self.tables_client({}, {})
16041622
client.batch_predict(

0 commit comments

Comments
 (0)