Skip to content

Commit ef28e39

Browse files
Genesis929Yuan325
andauthored
feat(tools/bigquery-analyze-contribution): add allowed dataset support (googleapis#1675)
## Description > Should include a concise description of the changes (bug or feature), it's > impact, along with a summary of the solution ## PR Checklist > Thank you for opening a Pull Request! Before submitting your PR, there are a > few things you can do to make sure it goes smoothly: - [ ] Make sure you reviewed [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md) - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) - [ ] Make sure to add `!` if this involve a breaking change 🛠️ Fixes #<issue_number_goes_here> --------- Co-authored-by: Yuan Teoh <[email protected]>
1 parent a2006ad commit ef28e39

File tree

3 files changed

+217
-37
lines changed

3 files changed

+217
-37
lines changed

docs/en/resources/tools/bigquery/bigquery-analyze-contribution.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ The behavior of this tool is influenced by the `writeMode` setting on its `bigqu
4646
tools using the same source. This allows the `input_data` parameter to be a query that references temporary resources (e.g.,
4747
`TEMP` tables) created within that session.
4848

49+
The tool's behavior is also influenced by the `allowedDatasets` restriction on the `bigquery` source:
50+
51+
- **Without `allowedDatasets` restriction:** The tool can use any table or query for the `input_data` parameter.
52+
- **With `allowedDatasets` restriction:** The tool verifies that the `input_data` parameter only accesses tables within the allowed datasets.
53+
- If `input_data` is a table ID, the tool checks if the table's dataset is in the allowed list.
54+
- If `input_data` is a query, the tool performs a dry run to analyze the query and rejects it if it accesses any table outside the allowed list.
55+
4956

5057
## Example
5158

internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go

Lines changed: 97 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"github.com/googleapis/genai-toolbox/internal/sources"
2626
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
2727
"github.com/googleapis/genai-toolbox/internal/tools"
28+
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
2829
bigqueryrestapi "google.golang.org/api/bigquery/v2"
2930
"google.golang.org/api/iterator"
3031
)
@@ -50,6 +51,8 @@ type compatibleSource interface {
5051
BigQueryRestService() *bigqueryrestapi.Service
5152
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
5253
UseClientAuthorization() bool
54+
IsDatasetAllowed(projectID, datasetID string) bool
55+
BigQueryAllowedDatasets() []string
5356
BigQuerySession() bigqueryds.BigQuerySessionProvider
5457
}
5558

@@ -86,8 +89,17 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
8689
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
8790
}
8891

89-
inputDataParameter := tools.NewStringParameter("input_data",
90-
"The data that contain the test and control data to analyze. Can be a fully qualified BigQuery table ID or a SQL query.")
92+
allowedDatasets := s.BigQueryAllowedDatasets()
93+
inputDataDescription := "The data that contain the test and control data to analyze. Can be a fully qualified BigQuery table ID or a SQL query."
94+
if len(allowedDatasets) > 0 {
95+
datasetIDs := []string{}
96+
for _, ds := range allowedDatasets {
97+
datasetIDs = append(datasetIDs, fmt.Sprintf("`%s`", ds))
98+
}
99+
inputDataDescription += fmt.Sprintf(" The query or table must only access datasets from the following list: %s.", strings.Join(datasetIDs, ", "))
100+
}
101+
102+
inputDataParameter := tools.NewStringParameter("input_data", inputDataDescription)
91103
contributionMetricParameter := tools.NewStringParameter("contribution_metric",
92104
`The name of the column that contains the metric to analyze.
93105
Provides the expression to use to calculate the metric you are analyzing.
@@ -123,17 +135,19 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
123135

124136
// finish tool setup
125137
t := Tool{
126-
Name: cfg.Name,
127-
Kind: kind,
128-
Parameters: parameters,
129-
AuthRequired: cfg.AuthRequired,
130-
UseClientOAuth: s.UseClientAuthorization(),
131-
ClientCreator: s.BigQueryClientCreator(),
132-
Client: s.BigQueryClient(),
133-
RestService: s.BigQueryRestService(),
134-
SessionProvider: s.BigQuerySession(),
135-
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
136-
mcpManifest: mcpManifest,
138+
Name: cfg.Name,
139+
Kind: kind,
140+
Parameters: parameters,
141+
AuthRequired: cfg.AuthRequired,
142+
UseClientOAuth: s.UseClientAuthorization(),
143+
ClientCreator: s.BigQueryClientCreator(),
144+
Client: s.BigQueryClient(),
145+
RestService: s.BigQueryRestService(),
146+
IsDatasetAllowed: s.IsDatasetAllowed,
147+
AllowedDatasets: allowedDatasets,
148+
SessionProvider: s.BigQuerySession(),
149+
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
150+
mcpManifest: mcpManifest,
137151
}
138152
return t, nil
139153
}
@@ -148,12 +162,14 @@ type Tool struct {
148162
UseClientOAuth bool `yaml:"useClientOAuth"`
149163
Parameters tools.Parameters `yaml:"parameters"`
150164

151-
Client *bigqueryapi.Client
152-
RestService *bigqueryrestapi.Service
153-
ClientCreator bigqueryds.BigqueryClientCreator
154-
SessionProvider bigqueryds.BigQuerySessionProvider
155-
manifest tools.Manifest
156-
mcpManifest tools.McpManifest
165+
Client *bigqueryapi.Client
166+
RestService *bigqueryrestapi.Service
167+
ClientCreator bigqueryds.BigqueryClientCreator
168+
IsDatasetAllowed func(projectID, datasetID string) bool
169+
AllowedDatasets []string
170+
SessionProvider bigqueryds.BigQuerySessionProvider
171+
manifest tools.Manifest
172+
mcpManifest tools.McpManifest
157173
}
158174

159175
// Invoke runs the contribution analysis.
@@ -164,6 +180,22 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
164180
return nil, fmt.Errorf("unable to cast input_data parameter %s", paramsMap["input_data"])
165181
}
166182

183+
bqClient := t.Client
184+
restService := t.RestService
185+
var err error
186+
187+
// Initialize new client if using user OAuth token
188+
if t.UseClientOAuth {
189+
tokenStr, err := accessToken.ParseBearerToken()
190+
if err != nil {
191+
return nil, fmt.Errorf("error parsing access token: %w", err)
192+
}
193+
bqClient, restService, err = t.ClientCreator(tokenStr, true)
194+
if err != nil {
195+
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
196+
}
197+
}
198+
167199
modelID := fmt.Sprintf("contribution_analysis_model_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
168200

169201
var options []string
@@ -196,8 +228,54 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
196228
var inputDataSource string
197229
trimmedUpperInputData := strings.TrimSpace(strings.ToUpper(inputData))
198230
if strings.HasPrefix(trimmedUpperInputData, "SELECT") || strings.HasPrefix(trimmedUpperInputData, "WITH") {
231+
if len(t.AllowedDatasets) > 0 {
232+
var connProps []*bigqueryapi.ConnectionProperty
233+
session, err := t.SessionProvider(ctx)
234+
if err != nil {
235+
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
236+
}
237+
if session != nil {
238+
connProps = []*bigqueryapi.ConnectionProperty{
239+
{Key: "session_id", Value: session.ID},
240+
}
241+
}
242+
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, inputData, nil, connProps)
243+
if err != nil {
244+
return nil, fmt.Errorf("query validation failed: %w", err)
245+
}
246+
statementType := dryRunJob.Statistics.Query.StatementType
247+
if statementType != "SELECT" {
248+
return nil, fmt.Errorf("the 'input_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", statementType)
249+
}
250+
251+
queryStats := dryRunJob.Statistics.Query
252+
if queryStats != nil {
253+
for _, tableRef := range queryStats.ReferencedTables {
254+
if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
255+
return nil, fmt.Errorf("query in input_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
256+
}
257+
}
258+
} else {
259+
return nil, fmt.Errorf("could not analyze query in input_data to validate against allowed datasets")
260+
}
261+
}
199262
inputDataSource = fmt.Sprintf("(%s)", inputData)
200263
} else {
264+
if len(t.AllowedDatasets) > 0 {
265+
parts := strings.Split(inputData, ".")
266+
var projectID, datasetID string
267+
switch len(parts) {
268+
case 3: // project.dataset.table
269+
projectID, datasetID = parts[0], parts[1]
270+
case 2: // dataset.table
271+
projectID, datasetID = t.Client.Project(), parts[0]
272+
default:
273+
return nil, fmt.Errorf("invalid table ID format for 'input_data': %q. Expected 'dataset.table' or 'project.dataset.table'", inputData)
274+
}
275+
if !t.IsDatasetAllowed(projectID, datasetID) {
276+
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, inputData)
277+
}
278+
}
201279
inputDataSource = fmt.Sprintf("SELECT * FROM `%s`", inputData)
202280
}
203281

@@ -209,21 +287,6 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
209287
inputDataSource,
210288
)
211289

212-
bqClient := t.Client
213-
var err error
214-
215-
// Initialize new client if using user OAuth token
216-
if t.UseClientOAuth {
217-
tokenStr, err := accessToken.ParseBearerToken()
218-
if err != nil {
219-
return nil, fmt.Errorf("error parsing access token: %w", err)
220-
}
221-
bqClient, _, err = t.ClientCreator(tokenStr, false)
222-
if err != nil {
223-
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
224-
}
225-
}
226-
227290
createModelQuery := bqClient.Query(createModelSQL)
228291

229292
// Get session from provider if in protected mode.

tests/bigquery/bigquery_integration_test.go

Lines changed: 113 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
205205
}
206206

207207
func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
208-
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
208+
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Minute)
209209
defer cancel()
210210

211211
client, err := initBigQueryConnection(BigqueryProject)
@@ -225,6 +225,9 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
225225
allowedForecastTableName2 := "allowed_forecast_table_2"
226226
disallowedForecastTableName := "disallowed_forecast_table"
227227

228+
allowedAnalyzeContributionTableName1 := "allowed_analyze_contribution_table_1"
229+
allowedAnalyzeContributionTableName2 := "allowed_analyze_contribution_table_2"
230+
disallowedAnalyzeContributionTableName := "disallowed_analyze_contribution_table"
228231
// Setup allowed table
229232
allowedTableNameParam1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedTableName1)
230233
createAllowedTableStmt1 := fmt.Sprintf("CREATE TABLE %s (id INT64)", allowedTableNameParam1)
@@ -259,6 +262,23 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
259262
teardownDisallowedForecast := setupBigQueryTable(t, ctx, client, createDisallowedForecastStmt, insertDisallowedForecastStmt, disallowedDatasetName, disallowedForecastTableFullName, disallowedForecastParams)
260263
defer teardownDisallowedForecast(t)
261264

265+
// Setup allowed analyze contribution table
266+
allowedAnalyzeContributionTableFullName1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedAnalyzeContributionTableName1)
267+
createAnalyzeContributionStmt1, insertAnalyzeContributionStmt1, analyzeContributionParams1 := getBigQueryAnalyzeContributionToolInfo(allowedAnalyzeContributionTableFullName1)
268+
teardownAllowedAnalyzeContribution1 := setupBigQueryTable(t, ctx, client, createAnalyzeContributionStmt1, insertAnalyzeContributionStmt1, allowedDatasetName1, allowedAnalyzeContributionTableFullName1, analyzeContributionParams1)
269+
defer teardownAllowedAnalyzeContribution1(t)
270+
271+
allowedAnalyzeContributionTableFullName2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedAnalyzeContributionTableName2)
272+
createAnalyzeContributionStmt2, insertAnalyzeContributionStmt2, analyzeContributionParams2 := getBigQueryAnalyzeContributionToolInfo(allowedAnalyzeContributionTableFullName2)
273+
teardownAllowedAnalyzeContribution2 := setupBigQueryTable(t, ctx, client, createAnalyzeContributionStmt2, insertAnalyzeContributionStmt2, allowedDatasetName2, allowedAnalyzeContributionTableFullName2, analyzeContributionParams2)
274+
defer teardownAllowedAnalyzeContribution2(t)
275+
276+
// Setup disallowed analyze contribution table
277+
disallowedAnalyzeContributionTableFullName := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedAnalyzeContributionTableName)
278+
createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedAnalyzeContributionParams := getBigQueryAnalyzeContributionToolInfo(disallowedAnalyzeContributionTableFullName)
279+
teardownDisallowedAnalyzeContribution := setupBigQueryTable(t, ctx, client, createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedDatasetName, disallowedAnalyzeContributionTableFullName, disallowedAnalyzeContributionParams)
280+
defer teardownDisallowedAnalyzeContribution(t)
281+
262282
// Configure source with dataset restriction.
263283
sourceConfig := getBigQueryVars(t)
264284
sourceConfig["allowedDatasets"] = []string{allowedDatasetName1, allowedDatasetName2}
@@ -300,6 +320,11 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
300320
"source": "my-instance",
301321
"description": "Tool to forecast",
302322
},
323+
"analyze-contribution-restricted": map[string]any{
324+
"kind": "bigquery-analyze-contribution",
325+
"source": "my-instance",
326+
"description": "Tool to analyze contribution",
327+
},
303328
}
304329

305330
// Create config file
@@ -327,8 +352,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
327352

328353
// Run tests
329354
runListDatasetIdsWithRestriction(t, allowedDatasetName1, allowedDatasetName2)
330-
runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1)
331-
runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2)
355+
runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1, allowedAnalyzeContributionTableName1)
356+
runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2, allowedAnalyzeContributionTableName2)
332357
runGetDatasetInfoWithRestriction(t, allowedDatasetName1, disallowedDatasetName)
333358
runGetDatasetInfoWithRestriction(t, allowedDatasetName2, disallowedDatasetName)
334359
runGetTableInfoWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, disallowedTableName)
@@ -339,6 +364,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
339364
runConversationalAnalyticsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, disallowedTableName)
340365
runForecastWithRestriction(t, allowedForecastTableFullName1, disallowedForecastTableFullName)
341366
runForecastWithRestriction(t, allowedForecastTableFullName2, disallowedForecastTableFullName)
367+
runAnalyzeContributionWithRestriction(t, allowedAnalyzeContributionTableFullName1, disallowedAnalyzeContributionTableFullName)
368+
runAnalyzeContributionWithRestriction(t, allowedAnalyzeContributionTableFullName2, disallowedAnalyzeContributionTableFullName)
342369
}
343370

344371
func TestBigQueryWriteModeAllowed(t *testing.T) {
@@ -3125,3 +3152,86 @@ func runForecastWithRestriction(t *testing.T, allowedTableFullName, disallowedTa
31253152
})
31263153
}
31273154
}
3155+
3156+
func runAnalyzeContributionWithRestriction(t *testing.T, allowedTableFullName, disallowedTableFullName string) {
3157+
allowedTableUnquoted := strings.ReplaceAll(allowedTableFullName, "`", "")
3158+
disallowedTableUnquoted := strings.ReplaceAll(disallowedTableFullName, "`", "")
3159+
disallowedDatasetFQN := strings.Join(strings.Split(disallowedTableUnquoted, ".")[0:2], ".")
3160+
3161+
testCases := []struct {
3162+
name string
3163+
inputData string
3164+
wantStatusCode int
3165+
wantInResult string
3166+
wantInError string
3167+
}{
3168+
{
3169+
name: "invoke with allowed table name",
3170+
inputData: allowedTableUnquoted,
3171+
wantStatusCode: http.StatusOK,
3172+
wantInResult: `"relative_difference"`,
3173+
},
3174+
{
3175+
name: "invoke with disallowed table name",
3176+
inputData: disallowedTableUnquoted,
3177+
wantStatusCode: http.StatusBadRequest,
3178+
wantInError: fmt.Sprintf("access to dataset '%s' (from table '%s') is not allowed", disallowedDatasetFQN, disallowedTableUnquoted),
3179+
},
3180+
{
3181+
name: "invoke with query on allowed table",
3182+
inputData: fmt.Sprintf("SELECT * FROM %s", allowedTableFullName),
3183+
wantStatusCode: http.StatusOK,
3184+
wantInResult: `"relative_difference"`,
3185+
},
3186+
{
3187+
name: "invoke with query on disallowed table",
3188+
inputData: fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName),
3189+
wantStatusCode: http.StatusBadRequest,
3190+
wantInError: fmt.Sprintf("query in input_data accesses dataset '%s', which is not in the allowed list", disallowedDatasetFQN),
3191+
},
3192+
}
3193+
3194+
for _, tc := range testCases {
3195+
t.Run(tc.name, func(t *testing.T) {
3196+
requestBodyMap := map[string]any{
3197+
"input_data": tc.inputData,
3198+
"contribution_metric": "SUM(metric)",
3199+
"is_test_col": "is_test",
3200+
"dimension_id_cols": []string{"dim1", "dim2"},
3201+
}
3202+
bodyBytes, err := json.Marshal(requestBodyMap)
3203+
if err != nil {
3204+
t.Fatalf("failed to marshal request body: %v", err)
3205+
}
3206+
body := bytes.NewBuffer(bodyBytes)
3207+
3208+
resp, bodyBytes := tests.RunRequest(t, http.MethodPost, "http://127.0.0.1:5000/api/tool/analyze-contribution-restricted/invoke", body, nil)
3209+
3210+
if resp.StatusCode != tc.wantStatusCode {
3211+
t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes))
3212+
}
3213+
3214+
var respBody map[string]interface{}
3215+
if err := json.Unmarshal(bodyBytes, &respBody); err != nil {
3216+
t.Fatalf("error parsing response body: %v", err)
3217+
}
3218+
3219+
if tc.wantInResult != "" {
3220+
got, ok := respBody["result"].(string)
3221+
if !ok {
3222+
t.Fatalf("unable to find result in response body")
3223+
}
3224+
3225+
if !strings.Contains(got, tc.wantInResult) {
3226+
t.Errorf("unexpected result: got %q, want to contain %q", string(bodyBytes), tc.wantInResult)
3227+
}
3228+
}
3229+
3230+
if tc.wantInError != "" {
3231+
if !strings.Contains(string(bodyBytes), tc.wantInError) {
3232+
t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError)
3233+
}
3234+
}
3235+
})
3236+
}
3237+
}

0 commit comments

Comments
 (0)