@@ -205,7 +205,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
205205}
206206
207207func TestBigQueryToolWithDatasetRestriction (t * testing.T ) {
208- ctx , cancel := context .WithTimeout (context .Background (), 3 * time .Minute )
208+ ctx , cancel := context .WithTimeout (context .Background (), 4 * time .Minute )
209209 defer cancel ()
210210
211211 client , err := initBigQueryConnection (BigqueryProject )
@@ -225,6 +225,9 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
225225 allowedForecastTableName2 := "allowed_forecast_table_2"
226226 disallowedForecastTableName := "disallowed_forecast_table"
227227
228+ allowedAnalyzeContributionTableName1 := "allowed_analyze_contribution_table_1"
229+ allowedAnalyzeContributionTableName2 := "allowed_analyze_contribution_table_2"
230+ disallowedAnalyzeContributionTableName := "disallowed_analyze_contribution_table"
228231 // Setup allowed table
229232 allowedTableNameParam1 := fmt .Sprintf ("`%s.%s.%s`" , BigqueryProject , allowedDatasetName1 , allowedTableName1 )
230233 createAllowedTableStmt1 := fmt .Sprintf ("CREATE TABLE %s (id INT64)" , allowedTableNameParam1 )
@@ -259,6 +262,23 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
259262 teardownDisallowedForecast := setupBigQueryTable (t , ctx , client , createDisallowedForecastStmt , insertDisallowedForecastStmt , disallowedDatasetName , disallowedForecastTableFullName , disallowedForecastParams )
260263 defer teardownDisallowedForecast (t )
261264
265+ // Setup allowed analyze contribution table
266+ allowedAnalyzeContributionTableFullName1 := fmt .Sprintf ("`%s.%s.%s`" , BigqueryProject , allowedDatasetName1 , allowedAnalyzeContributionTableName1 )
267+ createAnalyzeContributionStmt1 , insertAnalyzeContributionStmt1 , analyzeContributionParams1 := getBigQueryAnalyzeContributionToolInfo (allowedAnalyzeContributionTableFullName1 )
268+ teardownAllowedAnalyzeContribution1 := setupBigQueryTable (t , ctx , client , createAnalyzeContributionStmt1 , insertAnalyzeContributionStmt1 , allowedDatasetName1 , allowedAnalyzeContributionTableFullName1 , analyzeContributionParams1 )
269+ defer teardownAllowedAnalyzeContribution1 (t )
270+
271+ allowedAnalyzeContributionTableFullName2 := fmt .Sprintf ("`%s.%s.%s`" , BigqueryProject , allowedDatasetName2 , allowedAnalyzeContributionTableName2 )
272+ createAnalyzeContributionStmt2 , insertAnalyzeContributionStmt2 , analyzeContributionParams2 := getBigQueryAnalyzeContributionToolInfo (allowedAnalyzeContributionTableFullName2 )
273+ teardownAllowedAnalyzeContribution2 := setupBigQueryTable (t , ctx , client , createAnalyzeContributionStmt2 , insertAnalyzeContributionStmt2 , allowedDatasetName2 , allowedAnalyzeContributionTableFullName2 , analyzeContributionParams2 )
274+ defer teardownAllowedAnalyzeContribution2 (t )
275+
276+ // Setup disallowed analyze contribution table
277+ disallowedAnalyzeContributionTableFullName := fmt .Sprintf ("`%s.%s.%s`" , BigqueryProject , disallowedDatasetName , disallowedAnalyzeContributionTableName )
278+ createDisallowedAnalyzeContributionStmt , insertDisallowedAnalyzeContributionStmt , disallowedAnalyzeContributionParams := getBigQueryAnalyzeContributionToolInfo (disallowedAnalyzeContributionTableFullName )
279+ teardownDisallowedAnalyzeContribution := setupBigQueryTable (t , ctx , client , createDisallowedAnalyzeContributionStmt , insertDisallowedAnalyzeContributionStmt , disallowedDatasetName , disallowedAnalyzeContributionTableFullName , disallowedAnalyzeContributionParams )
280+ defer teardownDisallowedAnalyzeContribution (t )
281+
262282 // Configure source with dataset restriction.
263283 sourceConfig := getBigQueryVars (t )
264284 sourceConfig ["allowedDatasets" ] = []string {allowedDatasetName1 , allowedDatasetName2 }
@@ -300,6 +320,11 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
300320 "source" : "my-instance" ,
301321 "description" : "Tool to forecast" ,
302322 },
323+ "analyze-contribution-restricted" : map [string ]any {
324+ "kind" : "bigquery-analyze-contribution" ,
325+ "source" : "my-instance" ,
326+ "description" : "Tool to analyze contribution" ,
327+ },
303328 }
304329
305330 // Create config file
@@ -327,8 +352,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
327352
328353 // Run tests
329354 runListDatasetIdsWithRestriction (t , allowedDatasetName1 , allowedDatasetName2 )
330- runListTableIdsWithRestriction (t , allowedDatasetName1 , disallowedDatasetName , allowedTableName1 , allowedForecastTableName1 )
331- runListTableIdsWithRestriction (t , allowedDatasetName2 , disallowedDatasetName , allowedTableName2 , allowedForecastTableName2 )
355+ runListTableIdsWithRestriction (t , allowedDatasetName1 , disallowedDatasetName , allowedTableName1 , allowedForecastTableName1 , allowedAnalyzeContributionTableName1 )
356+ runListTableIdsWithRestriction (t , allowedDatasetName2 , disallowedDatasetName , allowedTableName2 , allowedForecastTableName2 , allowedAnalyzeContributionTableName2 )
332357 runGetDatasetInfoWithRestriction (t , allowedDatasetName1 , disallowedDatasetName )
333358 runGetDatasetInfoWithRestriction (t , allowedDatasetName2 , disallowedDatasetName )
334359 runGetTableInfoWithRestriction (t , allowedDatasetName1 , disallowedDatasetName , allowedTableName1 , disallowedTableName )
@@ -339,6 +364,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
339364 runConversationalAnalyticsWithRestriction (t , allowedDatasetName2 , disallowedDatasetName , allowedTableName2 , disallowedTableName )
340365 runForecastWithRestriction (t , allowedForecastTableFullName1 , disallowedForecastTableFullName )
341366 runForecastWithRestriction (t , allowedForecastTableFullName2 , disallowedForecastTableFullName )
367+ runAnalyzeContributionWithRestriction (t , allowedAnalyzeContributionTableFullName1 , disallowedAnalyzeContributionTableFullName )
368+ runAnalyzeContributionWithRestriction (t , allowedAnalyzeContributionTableFullName2 , disallowedAnalyzeContributionTableFullName )
342369}
343370
344371func TestBigQueryWriteModeAllowed (t * testing.T ) {
@@ -3125,3 +3152,86 @@ func runForecastWithRestriction(t *testing.T, allowedTableFullName, disallowedTa
31253152 })
31263153 }
31273154}
3155+
3156+ func runAnalyzeContributionWithRestriction (t * testing.T , allowedTableFullName , disallowedTableFullName string ) {
3157+ allowedTableUnquoted := strings .ReplaceAll (allowedTableFullName , "`" , "" )
3158+ disallowedTableUnquoted := strings .ReplaceAll (disallowedTableFullName , "`" , "" )
3159+ disallowedDatasetFQN := strings .Join (strings .Split (disallowedTableUnquoted , "." )[0 :2 ], "." )
3160+
3161+ testCases := []struct {
3162+ name string
3163+ inputData string
3164+ wantStatusCode int
3165+ wantInResult string
3166+ wantInError string
3167+ }{
3168+ {
3169+ name : "invoke with allowed table name" ,
3170+ inputData : allowedTableUnquoted ,
3171+ wantStatusCode : http .StatusOK ,
3172+ wantInResult : `"relative_difference"` ,
3173+ },
3174+ {
3175+ name : "invoke with disallowed table name" ,
3176+ inputData : disallowedTableUnquoted ,
3177+ wantStatusCode : http .StatusBadRequest ,
3178+ wantInError : fmt .Sprintf ("access to dataset '%s' (from table '%s') is not allowed" , disallowedDatasetFQN , disallowedTableUnquoted ),
3179+ },
3180+ {
3181+ name : "invoke with query on allowed table" ,
3182+ inputData : fmt .Sprintf ("SELECT * FROM %s" , allowedTableFullName ),
3183+ wantStatusCode : http .StatusOK ,
3184+ wantInResult : `"relative_difference"` ,
3185+ },
3186+ {
3187+ name : "invoke with query on disallowed table" ,
3188+ inputData : fmt .Sprintf ("SELECT * FROM %s" , disallowedTableFullName ),
3189+ wantStatusCode : http .StatusBadRequest ,
3190+ wantInError : fmt .Sprintf ("query in input_data accesses dataset '%s', which is not in the allowed list" , disallowedDatasetFQN ),
3191+ },
3192+ }
3193+
3194+ for _ , tc := range testCases {
3195+ t .Run (tc .name , func (t * testing.T ) {
3196+ requestBodyMap := map [string ]any {
3197+ "input_data" : tc .inputData ,
3198+ "contribution_metric" : "SUM(metric)" ,
3199+ "is_test_col" : "is_test" ,
3200+ "dimension_id_cols" : []string {"dim1" , "dim2" },
3201+ }
3202+ bodyBytes , err := json .Marshal (requestBodyMap )
3203+ if err != nil {
3204+ t .Fatalf ("failed to marshal request body: %v" , err )
3205+ }
3206+ body := bytes .NewBuffer (bodyBytes )
3207+
3208+ resp , bodyBytes := tests .RunRequest (t , http .MethodPost , "http://127.0.0.1:5000/api/tool/analyze-contribution-restricted/invoke" , body , nil )
3209+
3210+ if resp .StatusCode != tc .wantStatusCode {
3211+ t .Fatalf ("unexpected status code: got %d, want %d. Body: %s" , resp .StatusCode , tc .wantStatusCode , string (bodyBytes ))
3212+ }
3213+
3214+ var respBody map [string ]interface {}
3215+ if err := json .Unmarshal (bodyBytes , & respBody ); err != nil {
3216+ t .Fatalf ("error parsing response body: %v" , err )
3217+ }
3218+
3219+ if tc .wantInResult != "" {
3220+ got , ok := respBody ["result" ].(string )
3221+ if ! ok {
3222+ t .Fatalf ("unable to find result in response body" )
3223+ }
3224+
3225+ if ! strings .Contains (got , tc .wantInResult ) {
3226+ t .Errorf ("unexpected result: got %q, want to contain %q" , string (bodyBytes ), tc .wantInResult )
3227+ }
3228+ }
3229+
3230+ if tc .wantInError != "" {
3231+ if ! strings .Contains (string (bodyBytes ), tc .wantInError ) {
3232+ t .Errorf ("unexpected error message: got %q, want to contain %q" , string (bodyBytes ), tc .wantInError )
3233+ }
3234+ }
3235+ })
3236+ }
3237+ }
0 commit comments