Rust Best Practices for MCP Servers¶
Rust-Specific Best Practices¶
This guide covers Rust-specific best practices, patterns, and idioms for building robust, efficient, and maintainable MCP servers that leverage Rust's unique strengths.
Code Organization and Architecture¶
Module Structure and Visibility¶
// ✅ Good: Clear module hierarchy with appropriate visibility
pub mod server {
pub mod handlers;
pub mod transport;
// Re-export public API
pub use handlers::Handlers;
pub use transport::Transport;
}
pub mod tools {
mod database; // Private implementation
mod http; // Private implementation
pub use database::DatabaseTool;
pub use http::HttpTool;
// Trait for tool implementations
pub trait ToolExecutor: Send + Sync {
// Method definitions
}
}
// ❌ Bad: Everything public or unclear hierarchy
pub mod everything {
pub mod database;
pub mod http;
pub mod utils;
pub mod helpers;
pub mod stuff;
}
Error Handling with thiserror and anyhow¶
// ✅ Good: Structured error types with context
use thiserror::Error;
#[derive(Error, Debug)]
pub enum McpError {
#[error("Configuration error: {message}")]
Config { message: String },
#[error("Tool '{tool_name}' execution failed: {source}")]
ToolExecution {
tool_name: String,
#[source]
source: Box<dyn std::error::Error + Send + Sync>,
},
#[error("Validation failed for field '{field}': {message}")]
Validation { field: String, message: String },
#[error("Resource '{uri}' not found")]
ResourceNotFound { uri: String },
#[error("Permission denied: {operation}")]
PermissionDenied { operation: String },
#[error("I/O error")]
Io(#[from] std::io::Error),
#[error("HTTP request failed")]
Http(#[from] reqwest::Error),
#[error("JSON serialization error")]
Json(#[from] serde_json::Error),
}
// Extension trait for better error context
pub trait McpErrorExt<T> {
fn with_tool_context(self, tool_name: &str) -> Result<T, McpError>;
fn with_validation_context(self, field: &str) -> Result<T, McpError>;
}
impl<T, E> McpErrorExt<T> for Result<T, E>
where
E: std::error::Error + Send + Sync + 'static,
{
fn with_tool_context(self, tool_name: &str) -> Result<T, McpError> {
self.map_err(|e| McpError::ToolExecution {
tool_name: tool_name.to_string(),
source: Box::new(e),
})
}
fn with_validation_context(self, field: &str) -> Result<T, McpError> {
self.map_err(|e| McpError::Validation {
field: field.to_string(),
message: e.to_string(),
})
}
}
// Usage
fn execute_tool(name: &str) -> Result<String, McpError> {
some_operation()
.with_tool_context(name)?
.into_string()
.with_validation_context("result")
}
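The typed McpError is what library code should return so callers can match on variants. At the application boundary, anyhow (named in this section's title but not shown above) complements it. The sketch below, in which load_config and run_server are hypothetical placeholders, attaches human-readable context in main() and lets anyhow print the full error chain on exit.
// Sketch only: load_config and run_server are illustrative placeholders.
use anyhow::Context;

fn main() -> anyhow::Result<()> {
    // Library layers keep returning the typed McpError...
    let config = load_config("mcp-server.toml")
        .context("failed to load server configuration")?;
    // ...while anyhow attaches context and renders the whole source chain on exit
    run_server(config).context("server terminated with an error")?;
    Ok(())
}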
Type Safety with Newtypes¶
// ✅ Good: Strong typing with newtypes
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ToolName(String);
impl ToolName {
pub fn new(name: impl Into<String>) -> Result<Self, McpError> {
let name = name.into();
if name.is_empty() {
return Err(McpError::Validation {
field: "tool_name".to_string(),
message: "Tool name cannot be empty".to_string(),
});
}
if name.len() > 100 {
return Err(McpError::Validation {
field: "tool_name".to_string(),
message: "Tool name too long".to_string(),
});
}
Ok(Self(name))
}
pub fn as_str(&self) -> &str {
&self.0
}
}
impl std::fmt::Display for ToolName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceUri(String);
impl ResourceUri {
pub fn new(uri: impl Into<String>) -> Result<Self, McpError> {
let uri = uri.into();
url::Url::parse(&uri)
.map_err(|e| McpError::Validation {
field: "resource_uri".to_string(),
message: format!("Invalid URI: {}", e),
})?;
Ok(Self(uri))
}
}
// ❌ Bad: Stringly typed parameters
fn execute_tool(name: String, uri: String) -> Result<String, McpError> {
// No compile-time guarantees about validity
unimplemented!()
}
// ✅ Good: Type-safe parameters
fn execute_tool(name: ToolName, uri: ResourceUri) -> Result<String, McpError> {
// Guaranteed valid at compile time
unimplemented!()
}
Async Programming Best Practices¶
Structured Concurrency¶
// ✅ Good: Structured concurrency with proper error handling
use std::collections::HashMap;
use std::sync::Arc;
use serde_json::Value;
use tokio::select;
use tokio::sync::{mpsc, Mutex};
use tokio::time::{timeout, Duration};
pub struct ToolExecutorPool {
workers: Vec<tokio::task::JoinHandle<()>>,
task_tx: mpsc::UnboundedSender<ToolTask>,
shutdown_tx: mpsc::Sender<()>,
}
impl ToolExecutorPool {
pub async fn new(worker_count: usize) -> Self {
let (task_tx, task_rx) = mpsc::unbounded_channel();
let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
let task_rx = Arc::new(Mutex::new(task_rx));
let shutdown_rx = Arc::new(Mutex::new(shutdown_rx));
let mut workers = Vec::new();
for worker_id in 0..worker_count {
let task_rx = task_rx.clone();
let shutdown_rx = shutdown_rx.clone();
let handle = tokio::spawn(async move {
Self::worker_loop(worker_id, task_rx, shutdown_rx).await;
});
workers.push(handle);
}
Self {
workers,
task_tx,
shutdown_tx,
}
}
async fn worker_loop(
worker_id: usize,
task_rx: Arc<Mutex<mpsc::UnboundedReceiver<ToolTask>>>,
shutdown_rx: Arc<Mutex<mpsc::Receiver<()>>>,
) {
tracing::info!("Worker {} started", worker_id);
loop {
select! {
// Receive shutdown signal
_ = async {
let mut rx = shutdown_rx.lock().await;
rx.recv().await
} => {
tracing::info!("Worker {} shutting down", worker_id);
break;
}
// Process task
task = async {
let mut rx = task_rx.lock().await;
rx.recv().await
} => {
match task {
Some(task) => {
if let Err(e) = Self::process_task(task).await {
tracing::error!("Worker {} task failed: {}", worker_id, e);
}
}
None => break, // Channel closed
}
}
}
}
tracing::info!("Worker {} stopped", worker_id);
}
async fn process_task(task: ToolTask) -> Result<(), McpError> {
// Add timeout to prevent hanging
let result = timeout(Duration::from_secs(30), async {
task.tool.execute(task.arguments).await
}).await;
// mpsc::Sender::send is async: each send below must be awaited,
// otherwise the send future is dropped and the result never reaches the caller
match result {
Ok(Ok(result)) => {
let _ = task.response_tx.send(Ok(result)).await;
}
Ok(Err(e)) => {
let _ = task.response_tx.send(Err(e)).await;
}
Err(_) => {
let _ = task.response_tx.send(Err(McpError::ToolExecution {
tool_name: "unknown".to_string(),
source: Box::new(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"Task execution timed out"
)),
})).await;
}
}
Ok(())
}
pub async fn execute_tool(
&self,
tool: Arc<dyn ToolExecutor>,
arguments: HashMap<String, Value>,
) -> Result<ToolResult, McpError> {
let (response_tx, mut response_rx) = mpsc::channel(1);
let task = ToolTask {
tool,
arguments,
response_tx,
};
self.task_tx.send(task)
.map_err(|_| McpError::Config {
message: "Tool executor pool is shut down".to_string(),
})?;
response_rx.recv().await
.ok_or_else(|| McpError::Config {
message: "Failed to receive task result".to_string(),
})?
}
pub async fn shutdown(self) {
// Close the task channel first so every worker's recv() yields None and its loop exits;
// a single shutdown message on its own would only stop one worker
drop(self.task_tx);
// Signal shutdown
let _ = self.shutdown_tx.send(()).await;
// Wait for all workers to finish
for handle in self.workers {
let _ = handle.await;
}
}
}
struct ToolTask {
tool: Arc<dyn ToolExecutor>,
arguments: HashMap<String, Value>,
response_tx: mpsc::Sender<Result<ToolResult, McpError>>,
}
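A minimal usage sketch: the caller supplies any ToolExecutor implementation, and the argument key shown here is illustrative only.
async fn run_one_tool(tool: Arc<dyn ToolExecutor>) -> Result<(), McpError> {
    let pool = ToolExecutorPool::new(4).await;
    // Illustrative arguments; real tools define their own schema
    let args = HashMap::from([
        ("message".to_string(), Value::String("hello".to_string())),
    ]);
    let _result = pool.execute_tool(tool, args).await?;
    // Close the task channel and join every worker
    pool.shutdown().await;
    Ok(())
}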
Resource Management with RAII¶
// ✅ Good: RAII pattern with proper cleanup
pub struct DatabaseConnection {
pool: sqlx::PgPool,
connection_id: uuid::Uuid,
acquired_at: std::time::Instant,
}
impl DatabaseConnection {
pub async fn acquire(pool: &sqlx::PgPool) -> Result<Self, McpError> {
let connection_id = uuid::Uuid::new_v4();
let acquired_at = std::time::Instant::now();
tracing::debug!("Acquiring database connection {}", connection_id);
Ok(Self {
pool: pool.clone(),
connection_id,
acquired_at,
})
}
pub async fn execute_query(&self, query: &str) -> Result<Vec<serde_json::Value>, McpError> {
tracing::debug!("Executing query on connection {}", self.connection_id);
let rows = sqlx::query(query)
.fetch_all(&self.pool)
.await
.map_err(|e| McpError::ToolExecution {
tool_name: "database".to_string(),
source: Box::new(e),
})?;
// Convert rows to JSON...
Ok(vec![])
}
}
impl Drop for DatabaseConnection {
fn drop(&mut self) {
let duration = self.acquired_at.elapsed();
tracing::debug!(
"Releasing database connection {} after {:?}",
self.connection_id,
duration
);
}
}
// Usage ensures automatic cleanup
async fn use_database() -> Result<(), McpError> {
let pool = get_pool().await?;
let conn = DatabaseConnection::acquire(&pool).await?;
conn.execute_query("SELECT 1").await?;
// Connection automatically cleaned up when `conn` goes out of scope
Ok(())
}
Memory Management and Performance¶
Zero-Copy Operations¶
// ✅ Good: Zero-copy string handling
use std::borrow::Cow;
pub fn process_content<'a>(input: &'a str, should_transform: bool) -> Cow<'a, str> {
if should_transform {
// Only allocate when transformation is needed
Cow::Owned(input.to_uppercase())
} else {
// Return borrowed reference
Cow::Borrowed(input)
}
}
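A small test sketch confirming the allocation behaviour:
#[test]
fn cow_allocates_only_when_transforming() {
    let input = "hello";
    // No transform: the borrowed input is passed straight through
    assert!(matches!(process_content(input, false), Cow::Borrowed(_)));
    // Transform requested: an owned, uppercased copy is allocated
    assert!(matches!(process_content(input, true), Cow::Owned(_)));
}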
// ✅ Good: Efficient JSON handling with serde_json::RawValue
use serde_json::value::RawValue;
pub struct ToolArguments<'a> {
// Store raw JSON to avoid unnecessary parsing
raw: &'a RawValue,
}
impl<'a> ToolArguments<'a> {
pub fn get_string(&self, key: &str) -> Result<Option<String>, McpError> {
// Return an owned String: the parsed map is local to this call, so a
// borrowed &str could not outlive it
let obj: serde_json::Map<String, serde_json::Value> =
serde_json::from_str(self.raw.get())?;
Ok(obj.get(key).and_then(|v| v.as_str()).map(str::to_owned))
}
pub fn parse_into<T: serde::de::DeserializeOwned>(&self) -> Result<T, McpError> {
serde_json::from_str(self.raw.get())
.map_err(|e| McpError::Json(e))
}
}
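Typical usage deserializes the raw JSON exactly once into a typed argument struct; HttpArgs below is an illustrative example, not part of the API above.
#[derive(serde::Deserialize)]
struct HttpArgs {
    url: String,
    method: Option<String>,
}

fn handle_http(args: &ToolArguments<'_>) -> Result<String, McpError> {
    // One parse of the raw JSON, straight into the typed struct
    let parsed: HttpArgs = args.parse_into()?;
    Ok(format!("{} {}", parsed.method.as_deref().unwrap_or("GET"), parsed.url))
}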
Custom Allocators and Memory Pools¶
// ✅ Good: Object pooling for frequently allocated objects
use std::sync::{Arc, Mutex};
pub struct BufferPool {
buffers: Arc<Mutex<Vec<Vec<u8>>>>,
max_capacity: usize,
buffer_size: usize,
}
impl BufferPool {
pub fn new(max_capacity: usize, buffer_size: usize) -> Self {
Self {
buffers: Arc::new(Mutex::new(Vec::new())),
max_capacity,
buffer_size,
}
}
pub fn get(&self) -> PooledBuffer {
let mut pool = self.buffers.lock().unwrap();
let buffer = pool.pop().unwrap_or_else(|| Vec::with_capacity(self.buffer_size));
PooledBuffer {
buffer: Some(buffer),
pool: self.buffers.clone(),
max_capacity: self.max_capacity,
}
}
}
pub struct PooledBuffer {
buffer: Option<Vec<u8>>,
pool: Arc<Mutex<Vec<Vec<u8>>>>,
max_capacity: usize,
}
impl PooledBuffer {
pub fn as_mut(&mut self) -> &mut Vec<u8> {
self.buffer.as_mut().unwrap()
}
}
impl Drop for PooledBuffer {
fn drop(&mut self) {
if let Some(mut buffer) = self.buffer.take() {
buffer.clear(); // Reset but keep capacity
let mut pool = self.pool.lock().unwrap();
// Return the buffer only while the pool is below its configured limit
if pool.len() < self.max_capacity {
pool.push(buffer);
}
}
}
}
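Usage sketch: a buffer is borrowed for one request and recycled when the guard drops.
fn handle_payload(pool: &BufferPool, payload: &[u8]) -> usize {
    let mut buf = pool.get();
    // Reuses capacity left over from earlier requests when available
    buf.as_mut().extend_from_slice(payload);
    buf.as_mut().len()
    // buf is dropped here; the cleared Vec goes back into the pool
}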
Security Best Practices¶
Input Validation with Type System¶
// ✅ Good: Validation at type level
use serde::{Deserialize, Deserializer};
use std::fmt;
#[derive(Debug, Clone)]
pub struct SafeFilePath(std::path::PathBuf);
impl SafeFilePath {
pub fn new(path: impl AsRef<std::path::Path>) -> Result<Self, McpError> {
let path = path.as_ref();
// Prevent path traversal
if path.to_string_lossy().contains("..") {
return Err(McpError::Validation {
field: "path".to_string(),
message: "Path traversal not allowed".to_string(),
});
}
// Ensure absolute path
let canonical = path.canonicalize()
.map_err(|e| McpError::Validation {
field: "path".to_string(),
message: format!("Invalid path: {}", e),
})?;
Ok(Self(canonical))
}
pub fn as_path(&self) -> &std::path::Path {
&self.0
}
}
impl<'de> Deserialize<'de> for SafeFilePath {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
SafeFilePath::new(s).map_err(serde::de::Error::custom)
}
}
#[derive(Debug, Clone)]
pub struct SafeUrl(url::Url);
impl SafeUrl {
pub fn new(url_str: &str, allowed_schemes: &[&str]) -> Result<Self, McpError> {
let url = url::Url::parse(url_str)
.map_err(|e| McpError::Validation {
field: "url".to_string(),
message: format!("Invalid URL: {}", e),
})?;
if !allowed_schemes.contains(&url.scheme()) {
return Err(McpError::Validation {
field: "url".to_string(),
message: format!("Scheme '{}' not allowed", url.scheme()),
});
}
Ok(Self(url))
}
pub fn as_url(&self) -> &url::Url {
&self.0
}
}
// Usage in tool arguments
#[derive(Deserialize)]
pub struct FileReadArgs {
pub path: SafeFilePath,
}
#[derive(Deserialize)]
pub struct HttpRequestArgs {
pub url: SafeUrl,
pub method: Option<HttpMethod>,
}
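Because validation runs inside Deserialize, malformed input is rejected before any handler sees it. A small test sketch:
#[test]
fn rejects_path_traversal_during_deserialization() {
    let json = r#"{ "path": "../../etc/passwd" }"#;
    let parsed: Result<FileReadArgs, _> = serde_json::from_str(json);
    // SafeFilePath::new fails on "..", so deserialization fails as a whole
    assert!(parsed.is_err());
}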
Secure Default Configurations¶
// ✅ Good: Secure defaults with explicit opt-in for permissive settings
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityConfig {
/// Allow dangerous operations (default: false)
#[serde(default)]
pub allow_dangerous_operations: bool,
/// Maximum request size in bytes (default: 1MB)
#[serde(default = "default_max_request_size")]
pub max_request_size: u64,
/// Request timeout in seconds (default: 30)
#[serde(default = "default_request_timeout")]
pub request_timeout: u64,
/// Allowed file extensions (default: empty = all allowed)
#[serde(default)]
pub allowed_file_extensions: Vec<String>,
/// Rate limiting (requests per minute, default: 60)
#[serde(default = "default_rate_limit")]
pub rate_limit: u32,
}
fn default_max_request_size() -> u64 { 1024 * 1024 } // 1MB
fn default_request_timeout() -> u64 { 30 }
fn default_rate_limit() -> u32 { 60 }
impl Default for SecurityConfig {
fn default() -> Self {
Self {
allow_dangerous_operations: false,
max_request_size: default_max_request_size(),
request_timeout: default_request_timeout(),
allowed_file_extensions: Vec::new(), // Empty = all allowed
rate_limit: default_rate_limit(),
}
}
}
impl SecurityConfig {
/// Validate security configuration
pub fn validate(&self) -> Result<(), McpError> {
if self.max_request_size > 100 * 1024 * 1024 {
return Err(McpError::Validation {
field: "max_request_size".to_string(),
message: "Maximum request size too large (>100MB)".to_string(),
});
}
if self.request_timeout > 300 {
return Err(McpError::Validation {
field: "request_timeout".to_string(),
message: "Request timeout too long (>5 minutes)".to_string(),
});
}
Ok(())
}
/// Check if file extension is allowed
pub fn is_file_extension_allowed(&self, path: &std::path::Path) -> bool {
if self.allowed_file_extensions.is_empty() {
return true; // All extensions allowed
}
path.extension()
.and_then(|ext| ext.to_str())
.map(|ext| self.allowed_file_extensions.contains(&ext.to_string()))
.unwrap_or(false)
}
}
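Loading the configuration is then a two-step operation: serde fills in the secure defaults for missing fields, and validate() enforces the hard upper bounds. A minimal sketch (the JSON source is illustrative):
fn load_security_config(raw: &str) -> Result<SecurityConfig, McpError> {
    // Missing fields fall back to the secure defaults via #[serde(default)]
    let config: SecurityConfig = serde_json::from_str(raw)?;
    // Reject configurations that relax limits beyond the hard caps
    config.validate()?;
    Ok(config)
}

// load_security_config("{}") yields the defaults;
// load_security_config(r#"{"request_timeout": 900}"#) fails validation (> 300 seconds)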
Testing Patterns¶
Property-Based Testing¶
// ✅ Good: Property-based testing for robust validation
use quickcheck::TestResult;
use quickcheck_macros::quickcheck;
#[quickcheck]
fn prop_safe_file_path_prevents_traversal(input: String) -> TestResult {
// Skip inputs that are valid paths to focus on traversal attempts
if !input.contains("..") {
return TestResult::discard();
}
let result = SafeFilePath::new(&input);
TestResult::from_bool(result.is_err())
}
#[quickcheck]
fn prop_tool_name_length_limits(input: String) -> TestResult {
let result = ToolName::new(input.clone());
if input.is_empty() || input.len() > 100 {
TestResult::from_bool(result.is_err())
} else {
TestResult::from_bool(result.is_ok())
}
}
// Test tool execution properties
#[quickcheck]
fn prop_tool_execution_timeout(args: HashMap<String, String>) -> TestResult {
// Create a slow tool for testing
struct SlowTool;
#[async_trait::async_trait]
impl ToolExecutor for SlowTool {
fn name(&self) -> &str { "slow_tool" }
async fn execute(&self, _args: HashMap<String, Value>) -> Result<ToolResult, McpError> {
tokio::time::sleep(Duration::from_secs(10)).await;
Ok(ToolResult { content: vec![], is_error: None })
}
}
// quickcheck cannot generate serde_json::Value directly, so convert the generated string map
let args: HashMap<String, Value> = args.into_iter().map(|(k, v)| (k, Value::String(v))).collect();
// Test should complete within reasonable time due to timeout
let rt = tokio::runtime::Runtime::new().unwrap();
let result = rt.block_on(async {
let tool = Arc::new(SlowTool);
tokio::time::timeout(
Duration::from_secs(1),
execute_tool_with_timeout(tool, args, Duration::from_millis(500))
).await
});
TestResult::from_bool(result.is_ok())
}
async fn execute_tool_with_timeout(
tool: Arc<dyn ToolExecutor>,
args: HashMap<String, Value>,
timeout: Duration,
) -> Result<ToolResult, McpError> {
tokio::time::timeout(timeout, tool.execute(args))
.await
.map_err(|_| McpError::ToolExecution {
tool_name: tool.name().to_string(),
source: Box::new(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"Tool execution timed out"
)),
})?
}
Fuzzing Integration¶
// ✅ Good: Fuzz testing for security-critical components
#[cfg(fuzzing)]
pub mod fuzz {
use super::*;
pub fn fuzz_json_parsing(data: &[u8]) {
if let Ok(s) = std::str::from_utf8(data) {
let _ = serde_json::from_str::<serde_json::Value>(s);
}
}
pub fn fuzz_tool_arguments(data: &[u8]) {
if let Ok(s) = std::str::from_utf8(data) {
if let Ok(value) = serde_json::from_str::<serde_json::Value>(s) {
if let Ok(map) = serde_json::from_value::<HashMap<String, serde_json::Value>>(value) {
// Test validation doesn't panic
for tool in get_all_tools() {
let _ = tool.validate_arguments(&map);
}
}
}
}
}
pub fn fuzz_url_validation(data: &[u8]) {
if let Ok(s) = std::str::from_utf8(data) {
let _ = SafeUrl::new(s, &["http", "https"]);
}
}
}
// Cargo.toml should include:
// [dependencies]
// afl = { version = "0.12", optional = true }
//
// [features]
// fuzzing = ["afl"]
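The harnesses above are plain functions, so wiring them into an afl entry point is straightforward. A sketch of a dedicated fuzz binary follows; the exact build flags and cfg/feature wiring depend on your cargo-afl setup.
// Sketch: entry point for a separate [[bin]] target driven by cargo-afl
fn main() {
    afl::fuzz!(|data: &[u8]| {
        fuzz::fuzz_json_parsing(data);
        fuzz::fuzz_url_validation(data);
    });
}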
Performance Optimization¶
Compile-Time Optimizations¶
// ✅ Good: Static schema table built once on first use (LazyLock, &'static str keys)
use std::sync::LazyLock;
use std::collections::HashMap;
static TOOL_SCHEMAS: LazyLock<HashMap<&'static str, serde_json::Value>> = LazyLock::new(|| {
let mut schemas = HashMap::new();
schemas.insert("http_request", serde_json::json!({
"type": "object",
"properties": {
"url": {"type": "string", "format": "uri"},
"method": {"type": "string", "enum": ["GET", "POST", "PUT", "DELETE"]}
},
"required": ["url"]
}));
schemas.insert("read_file", serde_json::json!({
"type": "object",
"properties": {
"path": {"type": "string"}
},
"required": ["path"]
}));
schemas
});
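Looking up a schema is then a cheap map access; the table is built exactly once, on first use:
pub fn schema_for(tool_name: &str) -> Option<&'static serde_json::Value> {
    TOOL_SCHEMAS.get(tool_name)
}

// schema_for("http_request") returns the cached schema; unknown names return None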
// ✅ Good: Const generics for compile-time validation
pub struct BoundedString<const MAX_LEN: usize> {
value: String,
}
impl<const MAX_LEN: usize> BoundedString<MAX_LEN> {
pub fn new(value: String) -> Result<Self, McpError> {
if value.len() > MAX_LEN {
return Err(McpError::Validation {
field: "bounded_string".to_string(),
message: format!("String too long: {} > {}", value.len(), MAX_LEN),
});
}
Ok(Self { value })
}
pub fn as_str(&self) -> &str {
&self.value
}
}
// Usage with different bounds
type ToolName = BoundedString<100>;
type Description = BoundedString<1000>;
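Because the maximum length is part of the type, the two aliases are distinct types that cannot be mixed up, and over-long input is rejected at construction. A short sketch (register_tool is illustrative):
fn register_tool(name: ToolName, description: Description) -> Result<(), McpError> {
    tracing::info!("registering tool {}", name.as_str());
    let _ = description;
    Ok(())
}

fn example() -> Result<(), McpError> {
    let name = ToolName::new("http_request".to_string())?; // within the 100-char bound
    let description = Description::new("Performs HTTP requests".to_string())?;
    register_tool(name, description)
    // register_tool(description, name) would be a compile-time type error
}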
These Rust-specific best practices help your MCP server take full advantage of Rust's strengths in memory safety, performance, and concurrency while keeping the code clear and maintainable.