Skip to content

Commit 68968c4

Browse files
mcp: make schema caching opt-in via ServerOptions.SchemaCache
Add ServerOptions.SchemaCache to enable optional caching of JSON schemas for tools. When set, the cache avoids repeated reflection-based schema generation and resolution, which significantly improves performance for stateless server deployments where tools are re-registered on every request. This change addresses reviewer feedback to make the caching optimization opt-in, avoiding subtle behavior changes for users who may re-use and modify schemas between requests. Usage: cache := mcp.NewSchemaCache() server := mcp.NewServer(impl, &mcp.ServerOptions{ SchemaCache: cache, }) mcp.AddTool(server, tool, handler)
1 parent 3579d1a commit 68968c4

File tree

5 files changed

+98
-54
lines changed

5 files changed

+98
-54
lines changed

mcp/schema_cache.go

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ import (
1818
// This cache significantly improves performance for stateless server deployments
1919
// where tools are re-registered on every request. Without caching, each AddTool
2020
// call would trigger expensive reflection-based schema generation and resolution.
21+
//
22+
// Create a cache using [NewSchemaCache] and pass it to [ServerOptions.SchemaCache].
2123
type schemaCache struct {
2224
// byType caches schemas generated from Go types via jsonschema.ForType.
2325
// Key: reflect.Type, Value: *cachedSchema
@@ -36,9 +38,11 @@ type cachedSchema struct {
3638
resolved *jsonschema.Resolved
3739
}
3840

39-
// globalSchemaCache is the package-level cache used by setSchema.
40-
// It is unbounded since typical MCP servers have <100 tools.
41-
var globalSchemaCache = &schemaCache{}
41+
// NewSchemaCache creates a new schema cache for use with [ServerOptions.SchemaCache].
42+
// Safe for concurrent use, unbounded.
43+
func NewSchemaCache() *schemaCache {
44+
return &schemaCache{}
45+
}
4246

4347
// getByType retrieves a cached schema by Go type.
4448
// Returns the schema, resolved schema, and whether the cache hit.
@@ -68,9 +72,3 @@ func (c *schemaCache) getBySchema(schema *jsonschema.Schema) (*jsonschema.Resolv
6872
func (c *schemaCache) setBySchema(schema *jsonschema.Schema, resolved *jsonschema.Resolved) {
6973
c.bySchema.Store(schema, resolved)
7074
}
71-
72-
// resetForTesting clears the cache. Only for use in tests.
73-
func (c *schemaCache) resetForTesting() {
74-
c.byType.Clear()
75-
c.bySchema.Clear()
76-
}

mcp/schema_cache_benchmark_test.go

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,16 @@ func BenchmarkAddToolTypedHandler(b *testing.B) {
3434
Description: "Search for items",
3535
}
3636

37-
// Reset cache to simulate cold start for first iteration
38-
globalSchemaCache.resetForTesting()
37+
// Create a shared cache for caching benefit
38+
cache := NewSchemaCache()
3939

4040
b.ResetTimer()
4141
b.ReportAllocs()
4242

4343
for i := 0; i < b.N; i++ {
44-
s := NewServer(&Implementation{Name: "test", Version: "1.0"}, nil)
44+
s := NewServer(&Implementation{Name: "test", Version: "1.0"}, &ServerOptions{
45+
SchemaCache: cache,
46+
})
4547
AddTool(s, tool, handler)
4648
}
4749
}
@@ -69,9 +71,6 @@ func BenchmarkAddToolPreDefinedSchema(b *testing.B) {
6971
InputSchema: schema, // Pre-defined schema like github-mcp-server
7072
}
7173

72-
// Reset cache to simulate cold start for first iteration
73-
globalSchemaCache.resetForTesting()
74-
7574
b.ResetTimer()
7675
b.ReportAllocs()
7776

@@ -108,9 +107,7 @@ func BenchmarkAddToolTypedHandlerNoCache(b *testing.B) {
108107
b.ReportAllocs()
109108

110109
for i := 0; i < b.N; i++ {
111-
// Reset cache every iteration to simulate no caching
112-
globalSchemaCache.resetForTesting()
113-
110+
// No cache - each iteration generates new schemas
114111
s := NewServer(&Implementation{Name: "test", Version: "1.0"}, nil)
115112
AddTool(s, tool, handler)
116113
}
@@ -146,14 +143,16 @@ func BenchmarkAddToolMultipleTools(b *testing.B) {
146143
tool2 := &Tool{Name: "tool2", Description: "Tool 2"}
147144
tool3 := &Tool{Name: "tool3", Description: "Tool 3"}
148145

149-
// Reset cache before benchmark
150-
globalSchemaCache.resetForTesting()
146+
// Create a shared cache for caching benefit
147+
cache := NewSchemaCache()
151148

152149
b.ResetTimer()
153150
b.ReportAllocs()
154151

155152
for i := 0; i < b.N; i++ {
156-
s := NewServer(&Implementation{Name: "test", Version: "1.0"}, nil)
153+
s := NewServer(&Implementation{Name: "test", Version: "1.0"}, &ServerOptions{
154+
SchemaCache: cache,
155+
})
157156
AddTool(s, tool1, handler1)
158157
AddTool(s, tool2, handler2)
159158
AddTool(s, tool3, handler3)

mcp/schema_cache_test.go

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313
)
1414

1515
func TestSchemaCache_ByType(t *testing.T) {
16-
cache := &schemaCache{}
16+
cache := NewSchemaCache()
1717

1818
type TestInput struct {
1919
Name string `json:"name"`
@@ -49,7 +49,7 @@ func TestSchemaCache_ByType(t *testing.T) {
4949
}
5050

5151
func TestSchemaCache_BySchema(t *testing.T) {
52-
cache := &schemaCache{}
52+
cache := NewSchemaCache()
5353

5454
schema := &jsonschema.Schema{
5555
Type: "object",
@@ -89,7 +89,7 @@ func TestSchemaCache_BySchema(t *testing.T) {
8989
}
9090

9191
func TestSetSchema_CachesGeneratedSchemas(t *testing.T) {
92-
globalSchemaCache.resetForTesting()
92+
cache := NewSchemaCache()
9393

9494
type TestInput struct {
9595
Query string `json:"query"`
@@ -100,21 +100,21 @@ func TestSetSchema_CachesGeneratedSchemas(t *testing.T) {
100100
// First call should generate and cache
101101
var sfield1 any
102102
var rfield1 *jsonschema.Resolved
103-
_, err := setSchema[TestInput](&sfield1, &rfield1)
103+
_, err := setSchema[TestInput](&sfield1, &rfield1, cache)
104104
if err != nil {
105105
t.Fatalf("setSchema failed: %v", err)
106106
}
107107

108108
// Verify it's in cache
109-
cachedSchema, cachedResolved, ok := globalSchemaCache.getByType(rt)
109+
cachedSchema, cachedResolved, ok := cache.getByType(rt)
110110
if !ok {
111111
t.Fatal("schema not cached after first setSchema call")
112112
}
113113

114114
// Second call should hit cache
115115
var sfield2 any
116116
var rfield2 *jsonschema.Resolved
117-
_, err = setSchema[TestInput](&sfield2, &rfield2)
117+
_, err = setSchema[TestInput](&sfield2, &rfield2, cache)
118118
if err != nil {
119119
t.Fatalf("setSchema failed on second call: %v", err)
120120
}
@@ -129,7 +129,7 @@ func TestSetSchema_CachesGeneratedSchemas(t *testing.T) {
129129
}
130130

131131
func TestSetSchema_CachesProvidedSchemas(t *testing.T) {
132-
globalSchemaCache.resetForTesting()
132+
cache := NewSchemaCache()
133133

134134
// This simulates the github-mcp-server pattern:
135135
// schema is created once and reused across requests
@@ -143,13 +143,13 @@ func TestSetSchema_CachesProvidedSchemas(t *testing.T) {
143143
// First call should resolve and cache
144144
var sfield1 any = schema
145145
var rfield1 *jsonschema.Resolved
146-
_, err := setSchema[map[string]any](&sfield1, &rfield1)
146+
_, err := setSchema[map[string]any](&sfield1, &rfield1, cache)
147147
if err != nil {
148148
t.Fatalf("setSchema failed: %v", err)
149149
}
150150

151151
// Verify it's in cache
152-
cachedResolved, ok := globalSchemaCache.getBySchema(schema)
152+
cachedResolved, ok := cache.getBySchema(schema)
153153
if !ok {
154154
t.Fatal("resolved schema not cached after first setSchema call")
155155
}
@@ -160,7 +160,7 @@ func TestSetSchema_CachesProvidedSchemas(t *testing.T) {
160160
// Second call with same schema pointer should hit cache
161161
var sfield2 any = schema
162162
var rfield2 *jsonschema.Resolved
163-
_, err = setSchema[map[string]any](&sfield2, &rfield2)
163+
_, err = setSchema[map[string]any](&sfield2, &rfield2, cache)
164164
if err != nil {
165165
t.Fatalf("setSchema failed on second call: %v", err)
166166
}
@@ -170,8 +170,39 @@ func TestSetSchema_CachesProvidedSchemas(t *testing.T) {
170170
}
171171
}
172172

173+
func TestSetSchema_NoCacheWhenNil(t *testing.T) {
174+
type TestInput struct {
175+
Query string `json:"query"`
176+
}
177+
178+
// First call without cache
179+
var sfield1 any
180+
var rfield1 *jsonschema.Resolved
181+
_, err := setSchema[TestInput](&sfield1, &rfield1, nil)
182+
if err != nil {
183+
t.Fatalf("setSchema failed: %v", err)
184+
}
185+
186+
// Second call without cache - should still generate a new schema
187+
var sfield2 any
188+
var rfield2 *jsonschema.Resolved
189+
_, err = setSchema[TestInput](&sfield2, &rfield2, nil)
190+
if err != nil {
191+
t.Fatalf("setSchema failed on second call: %v", err)
192+
}
193+
194+
// Both calls should succeed, schemas should be equivalent but not same pointer
195+
// (since no caching is happening)
196+
if sfield1 == nil || sfield2 == nil {
197+
t.Error("expected schemas to be generated")
198+
}
199+
if rfield1 == nil || rfield2 == nil {
200+
t.Error("expected resolved schemas to be generated")
201+
}
202+
}
203+
173204
func TestAddTool_CachesBetweenCalls(t *testing.T) {
174-
globalSchemaCache.resetForTesting()
205+
cache := NewSchemaCache()
175206

176207
type GreetInput struct {
177208
Name string `json:"name" jsonschema:"the name to greet"`
@@ -190,15 +221,17 @@ func TestAddTool_CachesBetweenCalls(t *testing.T) {
190221
Description: "Greet someone",
191222
}
192223

193-
// Simulate stateless server pattern: create new server each time
224+
// Simulate stateless server pattern: create new server each time, but share cache
194225
for i := 0; i < 3; i++ {
195-
s := NewServer(&Implementation{Name: "test", Version: "1.0"}, nil)
226+
s := NewServer(&Implementation{Name: "test", Version: "1.0"}, &ServerOptions{
227+
SchemaCache: cache,
228+
})
196229
AddTool(s, tool, handler)
197230
}
198231

199232
// Verify schema was cached by type
200233
rt := reflect.TypeFor[GreetInput]()
201-
_, _, ok := globalSchemaCache.getByType(rt)
234+
_, _, ok := cache.getByType(rt)
202235
if !ok {
203236
t.Error("expected schema to be cached by type after multiple AddTool calls")
204237
}

mcp/server.go

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ type ServerOptions struct {
8989
// If true, advertises the tools capability during initialization,
9090
// even if no tools have been registered.
9191
HasTools bool
92+
// SchemaCache, if non-nil, enables caching of JSON schemas for tools.
93+
// This can significantly improve performance for stateless server
94+
// deployments where tools are re-registered on every request.
95+
SchemaCache *schemaCache
9296

9397
// GetSessionID provides the next session ID to use for an incoming request.
9498
// If nil, a default randomly generated ID will be used.
@@ -239,7 +243,7 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) {
239243
func() bool { s.tools.add(st); return true })
240244
}
241245

242-
func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) {
246+
func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out], cache *schemaCache) (*Tool, ToolHandler, error) {
243247
tt := *t
244248

245249
// Special handling for an "any" input: treat as an empty object.
@@ -248,7 +252,7 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
248252
}
249253

250254
var inputResolved *jsonschema.Resolved
251-
if _, err := setSchema[In](&tt.InputSchema, &inputResolved); err != nil {
255+
if _, err := setSchema[In](&tt.InputSchema, &inputResolved, cache); err != nil {
252256
return nil, nil, fmt.Errorf("input schema: %w", err)
253257
}
254258

@@ -263,7 +267,7 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
263267
)
264268
if t.OutputSchema != nil || reflect.TypeFor[Out]() != reflect.TypeFor[any]() {
265269
var err error
266-
elemZero, err = setSchema[Out](&tt.OutputSchema, &outputResolved)
270+
elemZero, err = setSchema[Out](&tt.OutputSchema, &outputResolved, cache)
267271
if err != nil {
268272
return nil, nil, fmt.Errorf("output schema: %v", err)
269273
}
@@ -364,9 +368,11 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
364368
// pointer: if the user provided the schema, they may have intentionally
365369
// derived it from the pointer type, and handling of zero values is up to them.
366370
//
371+
// If cache is non-nil, schemas are cached to avoid repeated reflection.
372+
//
367373
// TODO(rfindley): we really shouldn't ever return 'null' results. Maybe we
368374
// should have a jsonschema.Zero(schema) helper?
369-
func setSchema[T any](sfield *any, rfield **jsonschema.Resolved) (zero any, err error) {
375+
func setSchema[T any](sfield *any, rfield **jsonschema.Resolved, cache *schemaCache) (zero any, err error) {
370376
rt := reflect.TypeFor[T]()
371377
if rt.Kind() == reflect.Pointer {
372378
rt = rt.Elem()
@@ -377,37 +383,43 @@ func setSchema[T any](sfield *any, rfield **jsonschema.Resolved) (zero any, err
377383

378384
if *sfield == nil {
379385
// Case 1: No schema provided - check type cache first
380-
if schema, resolved, ok := globalSchemaCache.getByType(rt); ok {
381-
*sfield = schema
382-
*rfield = resolved
383-
return zero, nil
386+
if cache != nil {
387+
if schema, resolved, ok := cache.getByType(rt); ok {
388+
*sfield = schema
389+
*rfield = resolved
390+
return zero, nil
391+
}
384392
}
385393

386-
// Generate schema via reflection (expensive, but cached for next time)
394+
// Generate schema via reflection (expensive, but cached for next time if cache is set)
387395
internalSchema, err = jsonschema.ForType(rt, &jsonschema.ForOptions{})
388396
if err != nil {
389397
return zero, err
390398
}
391399
*sfield = internalSchema
392400

393-
// Resolve and cache
401+
// Resolve and optionally cache
394402
resolved, err := internalSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
395403
if err != nil {
396404
return zero, err
397405
}
398406
*rfield = resolved
399-
globalSchemaCache.setByType(rt, internalSchema, resolved)
407+
if cache != nil {
408+
cache.setByType(rt, internalSchema, resolved)
409+
}
400410
return zero, nil
401411
}
402412

403413
// Case 2: Schema was provided
404414
// Check if it's a *jsonschema.Schema we can cache by pointer
405415
if providedSchema, ok := (*sfield).(*jsonschema.Schema); ok {
406-
if resolved, ok := globalSchemaCache.getBySchema(providedSchema); ok {
407-
*rfield = resolved
408-
return zero, nil
416+
if cache != nil {
417+
if resolved, ok := cache.getBySchema(providedSchema); ok {
418+
*rfield = resolved
419+
return zero, nil
420+
}
409421
}
410-
// Need to resolve and cache
422+
// Need to resolve and optionally cache
411423
internalSchema = providedSchema
412424
} else {
413425
// Schema provided as different type (e.g., map) - need to remarshal
@@ -424,8 +436,10 @@ func setSchema[T any](sfield *any, rfield **jsonschema.Resolved) (zero any, err
424436
*rfield = resolved
425437

426438
// Cache by schema pointer if we got a direct *jsonschema.Schema
427-
if providedSchema, ok := (*sfield).(*jsonschema.Schema); ok {
428-
globalSchemaCache.setBySchema(providedSchema, resolved)
439+
if cache != nil {
440+
if providedSchema, ok := (*sfield).(*jsonschema.Schema); ok {
441+
cache.setBySchema(providedSchema, resolved)
442+
}
429443
}
430444

431445
return zero, nil
@@ -451,7 +465,7 @@ func setSchema[T any](sfield *any, rfield **jsonschema.Resolved) (zero any, err
451465
// tools to conform to the MCP spec. See [ToolHandlerFor] for a detailed
452466
// description of this automatic behavior.
453467
func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) {
454-
tt, hh, err := toolForErr(t, h)
468+
tt, hh, err := toolForErr(t, h, s.opts.SchemaCache)
455469
if err != nil {
456470
panic(fmt.Sprintf("AddTool: tool %q: %v", t.Name, err))
457471
}

mcp/server_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out
562562
th := func(context.Context, *CallToolRequest, In) (*CallToolResult, Out, error) {
563563
return nil, out, nil
564564
}
565-
gott, goth, err := toolForErr(tool, th)
565+
gott, goth, err := toolForErr(tool, th, nil)
566566
if err != nil {
567567
t.Fatal(err)
568568
}

0 commit comments

Comments
 (0)