Skip to content

Commit 56b6574

Browse files
feat(source/cloud-sql-admin): Add User agent and attach sqldmin in cloud-sql-admin source. (#1441)
## Description --- 1. This change introduces a userAgentRoundTripper that correctly prepends our custom user agent to the existing User-Agent header 2. Moves sqladmin client to source. 3. Updated cloudsql tools for above support. 4. Add test cases to validate User agent. ## PR Checklist --- > Thank you for opening a Pull Request! Before submitting your PR, there are a > few things you can do to make sure it goes smoothly: - [x] Make sure you reviewed [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md) - [x] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [x] Appropriate docs were updated (if necessary) - [ ] Make sure to add `!` if this involve a breaking change 🛠️ Fixes #<issue_number_goes_here>
1 parent 7d384dc commit 56b6574

File tree

11 files changed

+115
-94
lines changed

11 files changed

+115
-94
lines changed

docs/en/resources/tools/cloudsql/cloudsqlgetinstances.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ description: >
99
The `cloud-sql-get-instance` tool retrieves a Cloud SQL instance resource using the Cloud SQL Admin API.
1010

1111
{{< notice info >}}
12-
This tool uses a `source` of kind `cloud-sql-admin`. The source automatically generates a bearer token on behalf of the user with the `https://www.googleapis.com/auth/sqlservice.admin` scope to authenticate requests.
12+
This tool uses a `source` of kind `cloud-sql-admin`.
1313
{{< /notice >}}
1414

1515
## Example
@@ -18,14 +18,14 @@ This tool uses a `source` of kind `cloud-sql-admin`. The source automatically ge
1818
tools:
1919
get-sql-instance:
2020
kind: cloud-sql-get-instance
21-
description: "Get a Cloud SQL instance resource."
22-
source: my-cloud-sql-source
21+
source: my-cloud-sql-admin-source
22+
description: "Gets a particular cloud sql instance."
2323
```
2424
2525
## Reference
2626
27-
| **field** | **type** | **required** | **description** |
28-
| ----------- | :------: | :----------: | ---------------------------------------------------------------------------------------------------------------- |
29-
| kind | string | true | Must be "cloud-sql-get-instance". |
30-
| description | string | true | A description of the tool. |
31-
| source | string | true | The name of the `cloud-sql-admin` source to use. |
27+
| **field** | **type** | **required** | **description** |
28+
| ----------- | :------: | :----------: | ------------------------------------------------ |
29+
| kind | string | true | Must be "cloud-sql-get-instance". |
30+
| source | string | true | The name of the `cloud-sql-admin` source to use. |
31+
| description | string | false | A description of the tool. |

docs/en/resources/tools/cloudsql/cloudsqlwaitforoperation.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ tools:
3232
| ----------- | :------: | :----------: | ---------------------------------------------------------------------------------------------------------------- |
3333
| kind | string | true | Must be "cloud-sql-wait-for-operation". |
3434
| source | string | true | The name of a `cloud-sql-admin` source to use for authentication. |
35-
| description | string | true | A description of the tool. |
36-
| delay | duration | false | The initial delay between polling requests (e.g., `3s`). Defaults to 3 seconds. |
37-
| maxDelay | duration | false | The maximum delay between polling requests (e.g., `4m`). Defaults to 4 minutes. |
38-
| multiplier | float | false | The multiplier for the polling delay. The delay is multiplied by this value after each request. Defaults to 2.0. |
39-
| maxRetries | int | false | The maximum number of polling attempts before giving up. Defaults to 10. |
35+
| description | string | false | A description of the tool. |
36+
| delay | duration | false | The initial delay between polling requests (e.g., `3s`). Defaults to 3 seconds. |
37+
| maxDelay | duration | false | The maximum delay between polling requests (e.g., `4m`). Defaults to 4 minutes. |
38+
| multiplier | float | false | The multiplier for the polling delay. The delay is multiplied by this value after each request. Defaults to 2.0. |
39+
| maxRetries | int | false | The maximum number of polling attempts before giving up. Defaults to 10. |

internal/sources/cloudsqladmin/cloud_sql_admin.go

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,32 @@ import (
2424
"go.opentelemetry.io/otel/trace"
2525
"golang.org/x/oauth2"
2626
"golang.org/x/oauth2/google"
27+
"google.golang.org/api/option"
2728
sqladmin "google.golang.org/api/sqladmin/v1"
2829
)
2930

3031
const SourceKind string = "cloud-sql-admin"
3132

33+
type userAgentRoundTripper struct {
34+
userAgent string
35+
next http.RoundTripper
36+
}
37+
38+
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
39+
newReq := *req
40+
newReq.Header = make(http.Header)
41+
for k, v := range req.Header {
42+
newReq.Header[k] = v
43+
}
44+
ua := newReq.Header.Get("User-Agent")
45+
if ua == "" {
46+
newReq.Header.Set("User-Agent", rt.userAgent)
47+
} else {
48+
newReq.Header.Set("User-Agent", rt.userAgent+" "+ua)
49+
}
50+
return rt.next.RoundTrip(&newReq)
51+
}
52+
3253
// validate interface
3354
var _ sources.SourceConfig = Config{}
3455

@@ -65,22 +86,36 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
6586

6687
var client *http.Client
6788
if r.UseClientOAuth {
68-
client = nil
89+
client = &http.Client{
90+
Transport: &userAgentRoundTripper{
91+
userAgent: ua,
92+
next: http.DefaultTransport,
93+
},
94+
}
6995
} else {
7096
// Use Application Default Credentials
7197
creds, err := google.FindDefaultCredentials(ctx, sqladmin.SqlserviceAdminScope)
7298
if err != nil {
7399
return nil, fmt.Errorf("failed to find default credentials: %w", err)
74100
}
75-
client = oauth2.NewClient(ctx, creds.TokenSource)
101+
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
102+
baseClient.Transport = &userAgentRoundTripper{
103+
userAgent: ua,
104+
next: baseClient.Transport,
105+
}
106+
client = baseClient
107+
}
108+
109+
service, err := sqladmin.NewService(ctx, option.WithHTTPClient(client))
110+
if err != nil {
111+
return nil, fmt.Errorf("error creating new sqladmin service: %w", err)
76112
}
77113

78114
s := &Source{
79115
Name: r.Name,
80116
Kind: SourceKind,
81117
BaseURL: "https://sqladmin.googleapis.com",
82-
Client: client,
83-
UserAgent: ua,
118+
Service: service,
84119
UseClientOAuth: r.UseClientOAuth,
85120
}
86121
return s, nil
@@ -92,24 +127,25 @@ type Source struct {
92127
Name string `yaml:"name"`
93128
Kind string `yaml:"kind"`
94129
BaseURL string
95-
Client *http.Client
96-
UserAgent string
130+
Service *sqladmin.Service
97131
UseClientOAuth bool
98132
}
99133

100134
func (s *Source) SourceKind() string {
101135
return SourceKind
102136
}
103137

104-
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
138+
func (s *Source) GetService(ctx context.Context, accessToken string) (*sqladmin.Service, error) {
105139
if s.UseClientOAuth {
106-
if accessToken == "" {
107-
return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
108-
}
109140
token := &oauth2.Token{AccessToken: accessToken}
110-
return oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)), nil
141+
client := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
142+
service, err := sqladmin.NewService(ctx, option.WithHTTPClient(client))
143+
if err != nil {
144+
return nil, fmt.Errorf("error creating new sqladmin service: %w", err)
145+
}
146+
return service, nil
111147
}
112-
return s.Client, nil
148+
return s.Service, nil
113149
}
114150

115151
func (s *Source) UseClientAuthorization() bool {

internal/tools/cloudsql/cloudsqlcreateusers/cloudsqlcreateusers.go

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import (
2222
"github.com/googleapis/genai-toolbox/internal/sources"
2323
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
2424
"github.com/googleapis/genai-toolbox/internal/tools"
25-
"google.golang.org/api/option"
2625
sqladmin "google.golang.org/api/sqladmin/v1"
2726
)
2827

@@ -135,7 +134,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
135134

136135
iamUser, _ := paramsMap["iamUser"].(bool)
137136

138-
user := &sqladmin.User{
137+
user := sqladmin.User{
139138
Name: name,
140139
}
141140

@@ -150,19 +149,12 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
150149
user.Password = password
151150
}
152151

153-
client, err := t.Source.GetClient(ctx, string(accessToken))
152+
service, err := t.Source.GetService(ctx, string(accessToken))
154153
if err != nil {
155154
return nil, err
156155
}
157156

158-
service, err := sqladmin.NewService(ctx, option.WithHTTPClient(client))
159-
if err != nil {
160-
return nil, fmt.Errorf("error creating new sqladmin service: %w", err)
161-
}
162-
163-
service.UserAgent = t.Source.UserAgent
164-
165-
resp, err := service.Users.Insert(project, instance, user).Do()
157+
resp, err := service.Users.Insert(project, instance, &user).Do()
166158
if err != nil {
167159
return nil, fmt.Errorf("error creating user: %w", err)
168160
}

internal/tools/cloudsql/cloudsqlgetinstances/cloudsqlgetinstances.go

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ import (
2222
"github.com/googleapis/genai-toolbox/internal/sources"
2323
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
2424
"github.com/googleapis/genai-toolbox/internal/tools"
25-
"google.golang.org/api/option"
26-
sqladmin "google.golang.org/api/sqladmin/v1"
2725
)
2826

2927
const kind string = "cloud-sql-get-instance"
@@ -46,7 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
4644
type Config struct {
4745
Name string `yaml:"name" validate:"required"`
4846
Kind string `yaml:"kind" validate:"required"`
49-
Description string `yaml:"description" validate:"required"`
47+
Description string `yaml:"description"`
5048
Source string `yaml:"source" validate:"required"`
5149
AuthRequired []string `yaml:"authRequired"`
5250
}
@@ -80,9 +78,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
8078
inputSchema := allParameters.McpManifest()
8179
inputSchema.Required = []string{"projectId", "instanceId"}
8280

81+
description := cfg.Description
82+
if description == "" {
83+
description = "Gets a particular cloud sql instance."
84+
}
85+
8386
mcpManifest := tools.McpManifest{
8487
Name: cfg.Name,
85-
Description: cfg.Description,
88+
Description: description,
8689
InputSchema: inputSchema,
8790
}
8891

@@ -123,17 +126,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
123126
return nil, fmt.Errorf("missing 'instanceId' parameter")
124127
}
125128

126-
client, err := t.Source.GetClient(ctx, string(accessToken))
129+
service, err := t.Source.GetService(ctx, string(accessToken))
127130
if err != nil {
128131
return nil, err
129132
}
130133

131-
service, err := sqladmin.NewService(ctx, option.WithHTTPClient(client))
132-
if err != nil {
133-
return nil, fmt.Errorf("error creating new sqladmin service: %w", err)
134-
}
135-
service.UserAgent = t.Source.UserAgent
136-
137134
resp, err := service.Instances.Get(projectId, instanceId).Do()
138135
if err != nil {
139136
return nil, fmt.Errorf("error getting instance: %w", err)

internal/tools/cloudsql/cloudsqllistinstances/cloudsqllistinstances.go

Lines changed: 15 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,7 @@ package cloudsqllistinstances
1616

1717
import (
1818
"context"
19-
"encoding/json"
2019
"fmt"
21-
"io"
22-
"net/http"
2320

2421
"github.com/goccy/go-yaml"
2522
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -123,48 +120,34 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
123120
return nil, fmt.Errorf("missing 'project' parameter")
124121
}
125122

126-
client, err := t.source.GetClient(ctx, string(accessToken))
123+
service, err := t.source.GetService(ctx, string(accessToken))
127124
if err != nil {
128125
return nil, err
129126
}
130127

131-
urlString := fmt.Sprintf("%s/v1/projects/%s/instances", t.source.BaseURL, project)
132-
req, err := http.NewRequestWithContext(ctx, http.MethodGet, urlString, nil)
128+
resp, err := service.Instances.List(project).Do()
133129
if err != nil {
134-
return nil, fmt.Errorf("error creating request: %w", err)
130+
return nil, fmt.Errorf("error listing instances: %w", err)
135131
}
136132

137-
resp, err := client.Do(req)
138-
if err != nil {
139-
return nil, fmt.Errorf("error making HTTP request: %w", err)
140-
}
141-
defer resp.Body.Close()
142-
143-
body, err := io.ReadAll(resp.Body)
144-
if err != nil {
145-
return nil, fmt.Errorf("error reading response body: %w", err)
146-
}
147-
148-
if resp.StatusCode != http.StatusOK {
149-
return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body))
150-
}
151-
152-
var v struct {
153-
Items []struct {
154-
Name string `json:"name"`
155-
InstanceType string `json:"instanceType"`
156-
} `json:"items"`
133+
if resp.Items == nil {
134+
return []any{}, nil
157135
}
158136

159-
if err := json.Unmarshal(body, &v); err != nil {
160-
return nil, fmt.Errorf("error unmarshaling response body: %w", err)
137+
type instanceInfo struct {
138+
Name string `json:"name"`
139+
InstanceType string `json:"instanceType"`
161140
}
162141

163-
if v.Items == nil {
164-
return []any{}, nil
142+
var instances []instanceInfo
143+
for _, item := range resp.Items {
144+
instances = append(instances, instanceInfo{
145+
Name: item.Name,
146+
InstanceType: item.InstanceType,
147+
})
165148
}
166149

167-
return v.Items, nil
150+
return instances, nil
168151
}
169152

170153
// ParseParams parses the parameters for the tool.

internal/tools/cloudsql/cloudsqlwaitforoperation/cloudsqlwaitforoperation.go

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@ import (
2727
"github.com/googleapis/genai-toolbox/internal/sources"
2828
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
2929
"github.com/googleapis/genai-toolbox/internal/tools"
30-
"google.golang.org/api/option"
31-
sqladmin "google.golang.org/api/sqladmin/v1"
3230
)
3331

3432
const kind string = "cloud-sql-wait-for-operation"
@@ -93,7 +91,7 @@ type Config struct {
9391
Name string `yaml:"name" validate:"required"`
9492
Kind string `yaml:"kind" validate:"required"`
9593
Source string `yaml:"source" validate:"required"`
96-
Description string `yaml:"description" validate:"required"`
94+
Description string `yaml:"description"`
9795
AuthRequired []string `yaml:"authRequired"`
9896
BaseURL string `yaml:"baseURL"`
9997

@@ -133,9 +131,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
133131
inputSchema := allParameters.McpManifest()
134132
inputSchema.Required = []string{"project", "operation"}
135133

134+
description := cfg.Description
135+
if description == "" {
136+
description = "This will poll on operations API until the operation is done. For checking operation status we need projectId and operationId. Once instance is created give follow up steps on how to use the variables to bring data plane MCP server up in local and remote setup."
137+
}
138+
136139
mcpManifest := tools.McpManifest{
137140
Name: cfg.Name,
138-
Description: cfg.Description,
141+
Description: description,
139142
InputSchema: inputSchema,
140143
}
141144

@@ -219,16 +222,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
219222
return nil, fmt.Errorf("missing 'operation' parameter")
220223
}
221224

222-
client, err := t.Source.GetClient(ctx, string(accessToken))
225+
service, err := t.Source.GetService(ctx, string(accessToken))
223226
if err != nil {
224227
return nil, err
225228
}
226229

227-
service, err := sqladmin.NewService(ctx, option.WithHTTPClient(client))
228-
if err != nil {
229-
return nil, fmt.Errorf("error creating new sqladmin service: %w", err)
230-
}
231-
232230
ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
233231
defer cancel()
234232

@@ -389,14 +387,10 @@ func (t Tool) generateCloudSQLConnectionMessage(opResponse map[string]any) (stri
389387
}
390388

391389
func (t Tool) fetchInstanceData(ctx context.Context, project, instance string) (map[string]any, error) {
392-
client, err := t.Source.GetClient(ctx, "")
390+
service, err := t.Source.GetService(ctx, "")
393391
if err != nil {
394392
return nil, err
395393
}
396-
service, err := sqladmin.NewService(ctx, option.WithHTTPClient(client))
397-
if err != nil {
398-
return nil, fmt.Errorf("error creating new sqladmin service: %w", err)
399-
}
400394

401395
resp, err := service.Instances.Get(project, instance).Do()
402396
if err != nil {

tests/cloudsql/cloud_sql_create_users_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ type masterCreateUserHandler struct {
6262
}
6363

6464
func (h *masterCreateUserHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
65+
if !strings.Contains(r.UserAgent(), "genai-toolbox/") {
66+
h.t.Errorf("User-Agent header not found")
67+
}
68+
6569
var body userCreateRequest
6670
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
6771
h.t.Fatalf("failed to decode request body: %v", err)

0 commit comments

Comments
 (0)