Skip to content

Commit 8d33b15

Browse files
authored
refactor: merge cached_schema_for_type into schema_for_type (#581)
1 parent bce0555 commit 8d33b15

File tree

7 files changed

+91
-30
lines changed

7 files changed

+91
-30
lines changed

crates/rmcp-macros/src/tool.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
220220
if let Some(params_ty) = params_ty {
221221
// if found, use the Parameters schema
222222
syn::parse2::<Expr>(quote! {
223-
rmcp::handler::server::common::cached_schema_for_type::<#params_ty>()
223+
rmcp::handler::server::common::schema_for_type::<#params_ty>()
224224
})?
225225
} else {
226226
// if not found, use a default empty JSON schema object

crates/rmcp/src/handler/server/common.rs

Lines changed: 84 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,8 @@ use crate::{
88
RoleServer, model::JsonObject, schemars::generate::SchemaSettings, service::RequestContext,
99
};
1010

11-
/// A shortcut for generating a JSON schema for a type.
12-
pub fn schema_for_type<T: JsonSchema>() -> JsonObject {
13-
// explicitly to align json schema version to official specifications.
14-
// refer to https://github.com/modelcontextprotocol/modelcontextprotocol/pull/655 for details.
15-
let mut settings = SchemaSettings::draft2020_12();
16-
settings.transforms = vec![Box::new(schemars::transform::AddNullable::default())];
17-
let generator = settings.into_generator();
18-
let schema = generator.into_root_schema_for::<T>();
19-
let object = serde_json::to_value(schema).expect("failed to serialize schema");
20-
match object {
21-
serde_json::Value::Object(object) => object,
22-
_ => panic!(
23-
"Schema serialization produced non-object value: expected JSON object but got {:?}",
24-
object
25-
),
26-
}
27-
}
28-
29-
/// Call [`schema_for_type`] with a cache
30-
pub fn cached_schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
11+
/// Generates a JSON schema for a type
12+
pub fn schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
3113
thread_local! {
3214
static CACHE_FOR_TYPE: std::sync::RwLock<HashMap<TypeId, Arc<JsonObject>>> = Default::default();
3315
};
@@ -39,12 +21,26 @@ pub fn cached_schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject
3921
{
4022
x.clone()
4123
} else {
42-
let schema = schema_for_type::<T>();
43-
let schema = Arc::new(schema);
24+
// explicitly to align json schema version to official specifications.
25+
// refer to https://github.com/modelcontextprotocol/modelcontextprotocol/pull/655 for details.
26+
let mut settings = SchemaSettings::draft2020_12();
27+
settings.transforms = vec![Box::new(schemars::transform::AddNullable::default())];
28+
let generator = settings.into_generator();
29+
let schema = generator.into_root_schema_for::<T>();
30+
let object = serde_json::to_value(schema).expect("failed to serialize schema");
31+
let object = match object {
32+
serde_json::Value::Object(object) => object,
33+
_ => panic!(
34+
"Schema serialization produced non-object value: expected JSON object but got {:?}",
35+
object
36+
),
37+
};
38+
let schema = Arc::new(object);
4439
cache
4540
.write()
4641
.expect("schema cache lock poisoned")
4742
.insert(TypeId::of::<T>(), schema.clone());
43+
4844
schema
4945
}
5046
})
@@ -69,7 +65,7 @@ pub fn schema_for_output<T: JsonSchema + std::any::Any>() -> Result<Arc<JsonObje
6965
// Generate and validate schema
7066
let schema = schema_for_type::<T>();
7167
let result = match schema.get("type") {
72-
Some(serde_json::Value::String(t)) if t == "object" => Ok(Arc::new(schema)),
68+
Some(serde_json::Value::String(t)) if t == "object" => Ok(schema.clone()),
7369
Some(serde_json::Value::String(t)) => Err(format!(
7470
"MCP specification requires tool outputSchema to have root type 'object', but found '{}'.",
7571
t
@@ -196,6 +192,71 @@ mod tests {
196192
value: i32,
197193
}
198194

195+
#[derive(serde::Serialize, serde::Deserialize, JsonSchema)]
196+
struct AnotherTestObject {
197+
value: i32,
198+
}
199+
200+
#[test]
201+
fn test_schema_for_type_handles_primitive() {
202+
let schema = schema_for_type::<i32>();
203+
204+
assert_eq!(schema.get("type"), Some(&serde_json::json!("integer")));
205+
}
206+
207+
#[test]
208+
fn test_schema_for_type_handles_array() {
209+
let schema = schema_for_type::<Vec<i32>>();
210+
211+
assert_eq!(schema.get("type"), Some(&serde_json::json!("array")));
212+
let items = schema.get("items").and_then(|v| v.as_object());
213+
assert_eq!(
214+
items.unwrap().get("type"),
215+
Some(&serde_json::json!("integer"))
216+
);
217+
}
218+
219+
#[test]
220+
fn test_schema_for_type_handles_struct() {
221+
let schema = schema_for_type::<TestObject>();
222+
223+
assert_eq!(schema.get("type"), Some(&serde_json::json!("object")));
224+
let properties = schema.get("properties").and_then(|v| v.as_object());
225+
assert!(properties.unwrap().contains_key("value"));
226+
}
227+
228+
#[test]
229+
fn test_schema_for_type_caches_primitive_types() {
230+
let schema1 = schema_for_type::<i32>();
231+
let schema2 = schema_for_type::<i32>();
232+
233+
assert!(Arc::ptr_eq(&schema1, &schema2));
234+
}
235+
236+
#[test]
237+
fn test_schema_for_type_caches_struct_types() {
238+
let schema1 = schema_for_type::<TestObject>();
239+
let schema2 = schema_for_type::<TestObject>();
240+
241+
assert!(Arc::ptr_eq(&schema1, &schema2));
242+
}
243+
244+
#[test]
245+
fn test_schema_for_type_different_types_different_schemas() {
246+
let schema1 = schema_for_type::<TestObject>();
247+
let schema2 = schema_for_type::<AnotherTestObject>();
248+
249+
assert!(!Arc::ptr_eq(&schema1, &schema2));
250+
}
251+
252+
#[test]
253+
fn test_schema_for_type_arc_can_be_shared() {
254+
let schema = schema_for_type::<TestObject>();
255+
let cloned = schema.clone();
256+
257+
assert!(Arc::ptr_eq(&schema, &cloned));
258+
}
259+
199260
#[test]
200261
fn test_schema_for_output_rejects_primitive() {
201262
let result = schema_for_output::<i32>();

crates/rmcp/src/handler/server/prompt.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ impl_prompt_handler_for!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15);
325325
/// as PromptArgument entries with name, description, and required status
326326
pub fn cached_arguments_from_schema<T: schemars::JsonSchema + std::any::Any>()
327327
-> Option<Vec<crate::model::PromptArgument>> {
328-
let schema = super::common::cached_schema_for_type::<T>();
328+
let schema = super::common::schema_for_type::<T>();
329329
let schema_value = serde_json::Value::Object((*schema).clone());
330330

331331
let properties = schema_value.get("properties").and_then(|p| p.as_object());

crates/rmcp/src/handler/server/router/tool.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,8 @@ where
154154
self.attr.description = Some(description.into());
155155
self
156156
}
157-
pub fn parameters<T: JsonSchema>(mut self) -> Self {
158-
self.attr.input_schema = schema_for_type::<T>().into();
157+
pub fn parameters<T: JsonSchema + 'static>(mut self) -> Self {
158+
self.attr.input_schema = schema_for_type::<T>();
159159
self
160160
}
161161
pub fn parameters_value(mut self, schema: serde_json::Value) -> Self {

crates/rmcp/src/handler/server/tool.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use serde::de::DeserializeOwned;
99

1010
use super::common::{AsRequestContext, FromContextPart};
1111
pub use super::{
12-
common::{Extension, RequestId, cached_schema_for_type, schema_for_output, schema_for_type},
12+
common::{Extension, RequestId, schema_for_output, schema_for_type},
1313
router::tool::{ToolRoute, ToolRouter},
1414
};
1515
use crate::{

crates/rmcp/src/model/tool.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ impl Tool {
178178

179179
/// Set the input schema using a type that implements JsonSchema
180180
pub fn with_input_schema<T: JsonSchema + 'static>(mut self) -> Self {
181-
self.input_schema = crate::handler::server::tool::cached_schema_for_type::<T>();
181+
self.input_schema = crate::handler::server::tool::schema_for_type::<T>();
182182
self
183183
}
184184

crates/rmcp/tests/test_json_schema_detection.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ impl TestServer {
5555
}
5656

5757
/// Tool with explicit output_schema attribute - should have output schema
58-
#[tool(name = "explicit-schema", output_schema = rmcp::handler::server::tool::cached_schema_for_type::<TestData>())]
58+
#[tool(name = "explicit-schema", output_schema = rmcp::handler::server::tool::schema_for_type::<TestData>())]
5959
pub async fn explicit_schema(&self) -> Result<String, String> {
6060
Ok("test".to_string())
6161
}

0 commit comments

Comments
 (0)