Skip to content

Commit f475da6

Browse files
feat(neo4j): Add dry_run parameter to validate Cypher queries (googleapis#1769)
This pull request adds support for a new `dry_run` mode to the Neo4j Cypher execution tool, allowing users to validate queries and view execution plans without running them. It also sets a custom user agent for Neo4j connections and improves error handling and documentation. The most important changes are grouped below. ### New dry run feature for Cypher execution * Added an optional `dry_run` boolean parameter to the `neo4j-execute-cypher` tool, allowing users to validate Cypher queries and receive execution plan details without running the query. The tool now prepends `EXPLAIN` to the query when `dry_run` is true and returns a structured summary of the execution plan. [[1]](diffhunk://#diff-de7fdd7e68c95ea9813c704a89fffb8fd6de34e81b43a484623fdff7683e18f3L87-R93) [[2]](diffhunk://#diff-de7fdd7e68c95ea9813c704a89fffb8fd6de34e81b43a484623fdff7683e18f3R155-R188) [[3]](diffhunk://#diff-de7fdd7e68c95ea9813c704a89fffb8fd6de34e81b43a484623fdff7683e18f3R219-R236) [[4]](diffhunk://#diff-1dca93fc9450e9b9ea64bc1ae02774c3198ea6f8310b2437815bd1a5eae11e79L30-R32) * Updated integration tests to cover the new `dry_run` functionality, including successful dry runs, error handling for invalid syntax, and enforcement of read-only mode. [[1]](diffhunk://#diff-b07de4a304bc72964b5de9481cbc6aec6cf9bb9dabd903a837eb8974e7100a90R163-R169) [[2]](diffhunk://#diff-b07de4a304bc72964b5de9481cbc6aec6cf9bb9dabd903a837eb8974e7100a90R250-R291) ### Improved error handling * Enhanced error messages for parameter casting in the tool's `Invoke` method to clarify issues with input parameters. ### Neo4j driver configuration * Set a custom user agent (`genai-toolbox/neo4j-source`) for Neo4j driver connections to help identify requests from this tool. 
[[1]](diffhunk://#diff-3f0444add0913f1722d678118ffedc70039cca3603f31c9927c06be5e00ffb29R24-R29) [[2]](diffhunk://#diff-3f0444add0913f1722d678118ffedc70039cca3603f31c9927c06be5e00ffb29L109-R113) ### Documentation updates * Updated the documentation to describe the new `dry_run` parameter and its usage for query validation. --------- Co-authored-by: Yuan Teoh <[email protected]>
1 parent eb04e0d commit f475da6

File tree

4 files changed

+126
-5
lines changed

4 files changed

+126
-5
lines changed

docs/en/resources/tools/neo4j/neo4j-execute-cypher.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@ Cypher](https://neo4j.com/docs/cypher-manual/current/queries/) syntax and
2727
supports all Cypher features, including pattern matching, filtering, and
2828
aggregation.
2929

30-
`neo4j-execute-cypher` takes one input parameter `cypher` and run the cypher
31-
query against the `source`.
30+
`neo4j-execute-cypher` takes a required input parameter `cypher` and runs the cypher
31+
query against the `source`. It also supports an optional `dry_run`
32+
parameter to validate a query without executing it.
3233

3334
> **Note:** This tool is intended for developer assistant workflows with
3435
> human-in-the-loop and shouldn't be used for production agents.

internal/sources/neo4j/neo4j.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ import (
2020

2121
"github.com/goccy/go-yaml"
2222
"github.com/googleapis/genai-toolbox/internal/sources"
23+
"github.com/googleapis/genai-toolbox/internal/util"
2324
"github.com/neo4j/neo4j-go-driver/v5/neo4j"
25+
neo4jconf "github.com/neo4j/neo4j-go-driver/v5/neo4j/config"
2426
"go.opentelemetry.io/otel/trace"
2527
)
2628

@@ -106,7 +108,13 @@ func initNeo4jDriver(ctx context.Context, tracer trace.Tracer, uri, user, passwo
106108
defer span.End()
107109

108110
auth := neo4j.BasicAuth(user, password, "")
109-
driver, err := neo4j.NewDriverWithContext(uri, auth)
111+
userAgent, err := util.UserAgentFromContext(ctx)
112+
if err != nil {
113+
return nil, err
114+
}
115+
driver, err := neo4j.NewDriverWithContext(uri, auth, func(config *neo4jconf.Config) {
116+
config.UserAgent = userAgent
117+
})
110118
if err != nil {
111119
return nil, fmt.Errorf("unable to create connection driver: %w", err)
112120
}

internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,13 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
8484
}
8585

8686
cypherParameter := tools.NewStringParameter("cypher", "The cypher to execute.")
87-
parameters := tools.Parameters{cypherParameter}
87+
dryRunParameter := tools.NewBooleanParameterWithDefault(
88+
"dry_run",
89+
false,
90+
"If set to true, the query will be validated and information about the execution "+
91+
"will be returned without running the query. Defaults to false.",
92+
)
93+
parameters := tools.Parameters{cypherParameter, dryRunParameter}
8894

8995
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters)
9096

@@ -124,13 +130,18 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
124130
paramsMap := params.AsMap()
125131
cypherStr, ok := paramsMap["cypher"].(string)
126132
if !ok {
127-
return nil, fmt.Errorf("unable to get cast %s", paramsMap["cypher"])
133+
return nil, fmt.Errorf("unable to cast cypher parameter %s", paramsMap["cypher"])
128134
}
129135

130136
if cypherStr == "" {
131137
return nil, fmt.Errorf("parameter 'cypher' must be a non-empty string")
132138
}
133139

140+
dryRun, ok := paramsMap["dry_run"].(bool)
141+
if !ok {
142+
return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"])
143+
}
144+
134145
// validate the cypher query before executing
135146
cf := t.classifier.Classify(cypherStr)
136147
if cf.Error != nil {
@@ -141,16 +152,40 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
141152
return nil, fmt.Errorf("this tool is read-only and cannot execute write queries")
142153
}
143154

155+
if dryRun {
156+
// Add EXPLAIN to the beginning of the query to validate it without executing
157+
cypherStr = "EXPLAIN " + cypherStr
158+
}
159+
144160
config := neo4j.ExecuteQueryWithDatabase(t.Database)
145161
results, err := neo4j.ExecuteQuery(ctx, t.Driver, cypherStr, nil,
146162
neo4j.EagerResultTransformer, config)
147163
if err != nil {
148164
return nil, fmt.Errorf("unable to execute query: %w", err)
149165
}
150166

167+
// If dry run, return the summary information only
168+
if dryRun {
169+
summary := results.Summary
170+
plan := summary.Plan()
171+
execPlan := map[string]any{
172+
"queryType": cf.Type.String(),
173+
"statementType": summary.StatementType(),
174+
"operator": plan.Operator(),
175+
"arguments": plan.Arguments(),
176+
"identifiers": plan.Identifiers(),
177+
"childrenCount": len(plan.Children()),
178+
}
179+
if len(plan.Children()) > 0 {
180+
execPlan["children"] = addPlanChildren(plan)
181+
}
182+
return []map[string]any{execPlan}, nil
183+
}
184+
151185
var out []any
152186
keys := results.Keys
153187
records := results.Records
188+
154189
for _, record := range records {
155190
vMap := make(map[string]any)
156191
for col, value := range record.Values {
@@ -181,3 +216,21 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
181216
func (t Tool) RequiresClientAuthorization() bool {
182217
return false
183218
}
219+
220+
// Recursive function to add plan children
221+
func addPlanChildren(p neo4j.Plan) []map[string]any {
222+
var children []map[string]any
223+
for _, child := range p.Children() {
224+
childMap := map[string]any{
225+
"operator": child.Operator(),
226+
"arguments": child.Arguments(),
227+
"identifiers": child.Identifiers(),
228+
"children_count": len(child.Children()),
229+
}
230+
if len(child.Children()) > 0 {
231+
childMap["children"] = addPlanChildren(child)
232+
}
233+
children = append(children, childMap)
234+
}
235+
return children
236+
}

tests/neo4j/neo4j_integration_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,13 @@ func TestNeo4jToolEndpoints(t *testing.T) {
160160
"description": "The cypher to execute.",
161161
"authSources": []any{},
162162
},
163+
map[string]any{
164+
"name": "dry_run",
165+
"type": "boolean",
166+
"required": false,
167+
"description": "If set to true, the query will be validated and information about the execution will be returned without running the query. Defaults to false.",
168+
"authSources": []any{},
169+
},
163170
},
164171
"authRequired": []any{},
165172
},
@@ -240,13 +247,65 @@ func TestNeo4jToolEndpoints(t *testing.T) {
240247
want: "[{\"a\":1}]",
241248
wantStatus: http.StatusOK,
242249
},
250+
{
251+
name: "invoke my-simple-execute-cypher-tool with dry_run",
252+
api: "http://127.0.0.1:5000/api/tool/my-simple-execute-cypher-tool/invoke",
253+
requestBody: bytes.NewBuffer([]byte(`{"cypher": "MATCH (n:Test) RETURN n", "dry_run": true}`)),
254+
wantStatus: http.StatusOK,
255+
validateFunc: func(t *testing.T, body string) {
256+
var result []map[string]any
257+
if err := json.Unmarshal([]byte(body), &result); err != nil {
258+
t.Fatalf("failed to unmarshal dry_run result: %v", err)
259+
}
260+
if len(result) == 0 {
261+
t.Fatalf("expected a query plan, but got an empty result")
262+
}
263+
264+
operatorValue, ok := result[0]["operator"]
265+
if !ok {
266+
t.Fatalf("expected key 'Operator' not found in dry_run response: %s", body)
267+
}
268+
269+
operatorStr, ok := operatorValue.(string)
270+
if !ok {
271+
t.Fatalf("expected 'Operator' to be a string, but got %T", operatorValue)
272+
}
273+
274+
if operatorStr != "ProduceResults@neo4j" {
275+
t.Errorf("unexpected operator: got %q, want %q", operatorStr, "ProduceResults@neo4j")
276+
}
277+
278+
childrenCount, ok := result[0]["childrenCount"]
279+
if !ok {
280+
t.Fatalf("expected key 'ChildrenCount' not found in dry_run response: %s", body)
281+
}
282+
283+
if childrenCount.(float64) != 1 {
284+
t.Errorf("unexpected children count: got %v, want %d", childrenCount, 1)
285+
}
286+
},
287+
},
288+
{
289+
name: "invoke my-simple-execute-cypher-tool with dry_run and invalid syntax",
290+
api: "http://127.0.0.1:5000/api/tool/my-simple-execute-cypher-tool/invoke",
291+
requestBody: bytes.NewBuffer([]byte(`{"cypher": "RTN 1", "dry_run": true}`)),
292+
wantStatus: http.StatusBadRequest,
293+
wantErrorSubstring: "unable to execute query",
294+
},
243295
{
244296
name: "invoke readonly tool with write query",
245297
api: "http://127.0.0.1:5000/api/tool/my-readonly-execute-cypher-tool/invoke",
246298
requestBody: bytes.NewBuffer([]byte(`{"cypher": "CREATE (n:TestNode)"}`)),
247299
wantStatus: http.StatusBadRequest,
248300
wantErrorSubstring: "this tool is read-only and cannot execute write queries",
249301
},
302+
{
303+
name: "invoke readonly tool with write query and dry_run",
304+
api: "http://127.0.0.1:5000/api/tool/my-readonly-execute-cypher-tool/invoke",
305+
requestBody: bytes.NewBuffer([]byte(`{"cypher": "CREATE (n:TestNode)", "dry_run": true}`)),
306+
wantStatus: http.StatusBadRequest,
307+
wantErrorSubstring: "this tool is read-only and cannot execute write queries",
308+
},
250309
{
251310
name: "invoke my-schema-tool",
252311
api: "http://127.0.0.1:5000/api/tool/my-schema-tool/invoke",

0 commit comments

Comments
 (0)