diff --git a/Directory.Packages.props b/Directory.Packages.props index 4caf048c6..498e2a75c 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -25,6 +25,7 @@ + diff --git a/ModelContextProtocol.slnx b/ModelContextProtocol.slnx index 1090c5377..fda2f311e 100644 --- a/ModelContextProtocol.slnx +++ b/ModelContextProtocol.slnx @@ -44,7 +44,6 @@ - diff --git a/docs/concepts/tasks/tasks.md b/docs/concepts/tasks/tasks.md index c0b571f77..2c782ccf1 100644 --- a/docs/concepts/tasks/tasks.md +++ b/docs/concepts/tasks/tasks.md @@ -1,604 +1,140 @@ --- title: Tasks -author: eiriktsarpalis description: MCP Tasks for Long-Running Operations uid: tasks --- # MCP Tasks - -> [!WARNING] -> Tasks are an **experimental feature** in the MCP specification (version 2025-11-25). The API may change in future releases. See the [Experimental APIs](../../experimental.md) documentation for details on working with experimental APIs. +> **Status**: Experimental (`MCPEXP001`). Based on [SEP-2663](https://github.com/nicholasgasior/specification/blob/main/docs/specification/2025-03-26/extensions/tasks.md). -The Model Context Protocol (MCP) supports [task-based execution] for long-running operations. Tasks enable a "call-now, fetch-later" pattern where clients can initiate operations that may take significant time to complete, then poll for status and retrieve results when ready. - -[task-based execution]: https://modelcontextprotocol.io/specification/draft/basic/utilities/tasks +Tasks allow MCP servers to run tool invocations asynchronously, reporting progress and requesting additional input from the client while execution continues in the background. ## Overview -Tasks are useful when operations may take a long time to complete, such as: - -- Large dataset processing or analysis -- Complex report generation -- Code migration or refactoring operations -- Machine learning inference or training -- Batch data transformations - -Without tasks, clients must keep connections open for the entire duration of long-running operations. Tasks allow clients to: - -1. Initiate an operation and receive a task ID immediately -2. Disconnect and reconnect later -3. Poll for status updates -4. Retrieve results when complete -5. Cancel operations if needed - -## Task Lifecycle - -Tasks follow a defined lifecycle through these status values: - -| Status | Description | -|--------|-------------| -| `working` | Task is actively being processed | -| `input_required` | Task is waiting for additional input (e.g., elicitation) | -| `completed` | Task finished successfully; results are available | -| `failed` | Task encountered an error | -| `cancelled` | Task was cancelled by the client | - -Tasks begin in the `working` status and transition to one of the terminal states (`completed`, `failed`, or `cancelled`). Once in a terminal state, the status cannot change. - -## Server Implementation - -### Configuring Task Support - -To enable task support on a server, configure a task store when setting up the MCP server: - -```csharp -var builder = WebApplication.CreateBuilder(args); - -// Create a task store for managing task state -var taskStore = new InMemoryMcpTaskStore(); - -builder.Services.AddMcpServer(options => -{ - // Enable tasks by providing a task store - options.TaskStore = taskStore; -}) -.WithHttpTransport(o => o.Stateless = true) -.WithTools(); -``` - -The is a reference implementation suitable for development and single-server deployments. For production multi-server scenarios, implement with a persistent backing store (database, Redis, etc.). - -### Task Store Configuration - -The `InMemoryMcpTaskStore` constructor accepts several optional parameters: - -```csharp -var taskStore = new InMemoryMcpTaskStore( - defaultTtl: TimeSpan.FromHours(1), // Default task retention time - maxTtl: TimeSpan.FromHours(24), // Maximum allowed TTL - pollInterval: TimeSpan.FromSeconds(1), // Suggested client poll interval - cleanupInterval: TimeSpan.FromMinutes(5), // Background cleanup frequency - pageSize: 100, // Tasks per page for listing - maxTasks: 1000, // Maximum total tasks allowed - maxTasksPerSession: 100 // Maximum tasks per session -); -``` - -### Tool Task Support - -Tools automatically advertise task support when they return `Task`, `ValueTask`, `Task`, or `ValueTask`: - -```csharp -[McpServerToolType] -public class MyTools -{ - // This tool automatically supports task-augmented calls - // because it returns Task (async method) - [McpServerTool, Description("Processes a large dataset")] - public static async Task ProcessDataset( - int recordCount, - CancellationToken cancellationToken) - { - // Long-running operation - await Task.Delay(5000, cancellationToken); - return $"Processed {recordCount} records"; - } - - // Synchronous tools don't support task augmentation by default - [McpServerTool, Description("Quick operation")] - public static string QuickOperation(string input) => $"Result: {input}"; -} -``` +When a client calls a tool and includes the `io.modelcontextprotocol/tasks` extension key in `_meta`, the server may return a `CreateTaskResult` instead of an immediate `CallToolResult`. The client then polls via `tasks/get` until the task reaches a terminal state. -You can explicitly control task support using : +### Task Lifecycle -```csharp -// In Program.cs or configuration -builder.Services.AddMcpServer() - .WithTools([ - McpServerTool.Create( - (int count, CancellationToken ct) => ProcessAsync(count, ct), - new McpServerToolCreateOptions - { - Name = "requiredTaskTool", - Execution = new ToolExecution - { - // Require clients to use task augmentation - TaskSupport = ToolTaskSupport.Required - } - }) - ]); ``` - -Task support levels: -- `Forbidden` (default for sync methods): Tool cannot be called with task augmentation -- `Optional` (default for async methods): Tool can be called with or without task augmentation -- `Required`: Tool must be called with task augmentation - -### Explicit Task Creation with `IMcpTaskStore` - -For more control over task lifecycle, tools can directly interact with and return an `McpTask`. This approach allows you to: - -- Create a task and return immediately while work continues in the background -- Control exactly when and how task status and results are updated -- Integrate with external systems for task execution - -Here's a simple example using `Task.Run` to schedule background work: - -```csharp -[McpServerToolType] -public class MyTools(IMcpTaskStore taskStore) -{ - [McpServerTool] - [Description("Starts a background job and returns a task for polling.")] - public async Task StartBackgroundJob( - [Description("Number of items to process")] int itemCount, - RequestContext context, - CancellationToken cancellationToken) - { - // Create a task in the store - this records the task metadata - var task = await taskStore.CreateTaskAsync( - new McpTaskMetadata { TimeToLive = TimeSpan.FromMinutes(30) }, - context.JsonRpcRequest.Id!, - context.JsonRpcRequest, - context.Server.SessionId, - cancellationToken); - - // Schedule work to run in the background (fire-and-forget) - _ = Task.Run(async () => - { - try - { - // Simulate long-running work - await Task.Delay(TimeSpan.FromSeconds(10)); - var result = $"Processed {itemCount} items successfully"; - - // Store the completed result - await taskStore.StoreTaskResultAsync( - task.TaskId, - McpTaskStatus.Completed, - JsonSerializer.SerializeToElement(new CallToolResult - { - Content = [new TextContentBlock { Text = result }] - }), - context.Server.SessionId); - } - catch (Exception ex) - { - // Mark task as failed on error - await taskStore.StoreTaskResultAsync( - task.TaskId, - McpTaskStatus.Failed, - JsonSerializer.SerializeToElement(new CallToolResult - { - Content = [new TextContentBlock { Text = ex.Message }], - IsError = true - }), - context.Server.SessionId); - } - }, CancellationToken.None); - - // Return immediately - client will poll for completion - return task; - } -} +Working → Completed +Working → Failed +Working → Cancelled +Working → InputRequired → Working (loop) ``` -When a tool returns `McpTask`, the SDK bypasses automatic task wrapping and returns the task directly to the client. - - -> [!IMPORTANT] -> **No Fault Tolerance Guarantees**: Both `InMemoryMcpTaskStore` and the automatic task support for `Task`-returning tool methods do **not** provide fault tolerance. Task state and execution are bounded by the memory of the server process. If the server crashes or restarts: -> - All in-memory task metadata is lost -> - Any in-flight task execution is terminated -> - Clients will receive errors when polling for previously created tasks -> -> For fault-tolerant task execution, see the [Fault-Tolerant Task Implementations](#fault-tolerant-task-implementations) section. +## Server Configuration -### Task Status Notifications +### Using the Task Store -When `SendTaskStatusNotifications` is enabled, the server automatically sends status updates to connected clients: +The `InMemoryMcpTaskStore` provides a ready-to-use in-memory implementation: ```csharp -builder.Services.AddMcpServer(options => +var builder = McpServerBuilder.Create(options => { - options.TaskStore = taskStore; - options.SendTaskStatusNotifications = true; // Enable notifications + options.TaskStore = new InMemoryMcpTaskStore(); }); +builder.WithTools(); ``` -Clients receive `notifications/tasks/status` messages when task status changes. +When a `TaskStore` is configured: +- `tasks/get`, `tasks/update`, and `tasks/cancel` handlers are auto-wired from the store. +- Built-in tools are automatically wrapped: if the client signals task support, the tool is offloaded to a background task via the store. +- Server-initiated requests (elicitation, sampling) are redirected through the store's input request mechanism while inside a task scope. -## Client Implementation +### Custom Task Handlers -### Calling Tools as Tasks - -To execute a tool as a task, include the `Task` property in the request: +For full control without a store, set handlers directly: ```csharp -using ModelContextProtocol.Client; -using ModelContextProtocol.Protocol; - -var client = await McpClient.CreateAsync(transport); - -// Call tool with task augmentation -var result = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "processDataset", - Arguments = new Dictionary - { - ["recordCount"] = JsonSerializer.SerializeToElement(1000) - }, - Task = new McpTaskMetadata - { - TimeToLive = TimeSpan.FromHours(2) // Request 2-hour retention - } - }, - cancellationToken); - -// Check if a task was created -if (result.Task != null) -{ - Console.WriteLine($"Task created: {result.Task.TaskId}"); - Console.WriteLine($"Status: {result.Task.Status}"); -} +options.Handlers.GetTaskHandler = async (request, ct) => { ... }; +options.Handlers.UpdateTaskHandler = async (request, ct) => { ... }; +options.Handlers.CancelTaskHandler = async (request, ct) => { ... }; ``` -### Polling for Task Status +### Task-Aware Tool Handlers -Use to check task status: +The `CallToolWithTaskHandler` returns `ResultOrCreatedTask`, allowing the handler to return either an immediate result or a task: ```csharp -var task = await client.GetTaskAsync(taskId, cancellationToken: cancellationToken); -Console.WriteLine($"Status: {task.Status}"); -Console.WriteLine($"Last Updated: {task.LastUpdatedAt}"); - -if (task.StatusMessage != null) +options.Handlers.CallToolWithTaskHandler = async (request, ct) => { - Console.WriteLine($"Message: {task.StatusMessage}"); -} -``` - -### Waiting for Completion - -The SDK provides helper methods for polling until a task completes: + // Return immediate result + return new CallToolResult { ... }; -```csharp -// Poll until task reaches terminal state -var completedTask = await client.PollTaskUntilCompleteAsync( - taskId, - cancellationToken: cancellationToken); - -if (completedTask.Status == McpTaskStatus.Completed) -{ - // Get the result as raw JSON - var resultJson = await client.GetTaskResultAsync( - taskId, - cancellationToken: cancellationToken); - - // Deserialize to the expected type - var result = resultJson.Deserialize(McpJsonUtilities.DefaultOptions); - - foreach (var content in result?.Content ?? []) - { - if (content is TextContentBlock text) - { - Console.WriteLine(text.Text); - } - } -} -else if (completedTask.Status == McpTaskStatus.Failed) -{ - Console.WriteLine($"Task failed: {completedTask.StatusMessage}"); -} -``` - -### Listing Tasks - -List all tasks for the current session: - -```csharp -var tasks = await client.ListTasksAsync(cancellationToken: cancellationToken); - -foreach (var task in tasks) -{ - Console.WriteLine($"{task.TaskId}: {task.Status}"); -} -``` - -### Cancelling Tasks - -Cancel a running task: - -```csharp -var cancelledTask = await client.CancelTaskAsync( - taskId, - cancellationToken: cancellationToken); - -Console.WriteLine($"Task status: {cancelledTask.Status}"); // Cancelled -``` - -### Handling Status Notifications - -Register a handler to receive real-time status updates: - -```csharp -var options = new McpClientOptions -{ - Handlers = new McpClientHandlers - { - TaskStatusHandler = (task, cancellationToken) => - { - Console.WriteLine($"Task {task.TaskId} status changed to {task.Status}"); - return ValueTask.CompletedTask; - } - } + // Or return a task + return new CreateTaskResult { TaskId = "...", Status = McpTaskStatus.Working, ... }; }; - -var client = await McpClient.CreateAsync(transport, options); ``` - -> [!NOTE] -> Clients should not rely on receiving status notifications. Notifications are optional and may not be sent in all scenarios. Always use polling as the primary mechanism for tracking task status. +> **Note**: `CallToolHandler` and `CallToolWithTaskHandler` are mutually exclusive. If both are set, an exception is thrown. -## Implementing a Custom Task Store +### Task Scope for Server-Initiated Requests -For production deployments, implement with a persistent backing store: +When executing tool logic as a background task, use `CreateMcpTaskScope` to redirect elicitation/sampling/roots requests through the task store: ```csharp -public class DatabaseTaskStore : IMcpTaskStore +using (server.CreateMcpTaskScope(taskId, taskStore)) { - private readonly IDbConnection _db; - - public DatabaseTaskStore(IDbConnection db) => _db = db; - - public async Task CreateTaskAsync( - McpTaskMetadata taskMetadata, - RequestId requestId, - JsonRpcRequest request, - string? sessionId, - CancellationToken cancellationToken) - { - var task = new McpTask - { - TaskId = Guid.NewGuid().ToString(), - Status = McpTaskStatus.Working, - CreatedAt = DateTimeOffset.UtcNow, - LastUpdatedAt = DateTimeOffset.UtcNow, - TimeToLive = taskMetadata.TimeToLive ?? TimeSpan.FromHours(1) - }; - - // Store in database - await _db.ExecuteAsync( - "INSERT INTO Tasks (TaskId, SessionId, Status, ...) VALUES (@TaskId, @SessionId, @Status, ...)", - new { task.TaskId, sessionId, task.Status, ... }); - - return task; - } - - public async Task GetTaskAsync( - string taskId, - string? sessionId, - CancellationToken cancellationToken) - { - // Retrieve from database with session isolation - return await _db.QuerySingleOrDefaultAsync( - "SELECT * FROM Tasks WHERE TaskId = @TaskId AND SessionId = @SessionId", - new { taskId, sessionId }); - } - - // Implement other interface methods... + // Any ElicitAsync/SampleAsync calls here will be stored as + // input requests and await client responses via tasks/update. + var result = await server.ElicitAsync(...); } ``` -### Task Store Best Practices - -1. **Session Isolation**: Always filter tasks by session ID to prevent cross-session access -2. **TTL Enforcement**: Implement background cleanup of expired tasks -3. **Thread Safety**: Ensure all operations are thread-safe for concurrent access -4. **Atomic Updates**: Use database transactions for status transitions -5. **Optimistic Concurrency**: Prevent lost updates with version checking or row locks - -## Error Handling +## Client Usage -Task operations may throw with these error codes: +### Automatic Polling -| Error Code | Scenario | -|------------|----------| -| `InvalidParams` | Invalid or nonexistent task ID or invalid cursor | -| `InvalidParams` | Tool with `taskSupport: forbidden` called with task metadata, or tool with `taskSupport: required` called without task metadata | -| `InternalError` | Task execution failure or result unavailable | - -Example error handling: +`CallToolAsync` handles the full lifecycle automatically: ```csharp -try -{ - var task = await client.GetTaskAsync(taskId, cancellationToken: ct); -} -catch (McpProtocolException ex) when (ex.ErrorCode == McpErrorCode.InvalidParams) +var result = await client.CallToolAsync(new CallToolRequestParams { - Console.WriteLine($"Task not found: {taskId}"); -} + Name = "long-running-tool", + Arguments = { ... }, +}, cancellationToken); +// Blocks until completed, resolving input requests along the way. ``` -## Complete Example - - - -See the [LongRunningTasks sample](https://github.com/modelcontextprotocol/csharp-sdk/tree/main/samples/LongRunningTasks) for a complete working example demonstrating: - - -- Server setup with a file-based `IMcpTaskStore` for durability -- Explicit task creation via `IMcpTaskStore` in tools returning `McpTask` -- Task polling and result retrieval across server restarts -- Cancellation support - -## Fault-Tolerant Task Implementations - -The default `InMemoryMcpTaskStore` and automatic task support for async tools are convenient for development, but they provide no durability or fault tolerance. When the server process terminates—whether due to a crash, deployment, or scaling event—all task state and in-flight computations are lost. - -### Why Fault Tolerance Requires External Systems - -True fault tolerance for long-running tasks requires two key capabilities that cannot be provided by an in-process solution: - -1. **Durable Task State**: Task metadata (ID, status, results) must survive process termination. This requires an external persistent store such as a database, Redis, or distributed cache. - -2. **Resumable Compute**: The actual work being performed must be executed by an external system that can continue running independently of the MCP server process—such as a job queue (Azure Service Bus, RabbitMQ), workflow engine (Temporal, Azure Durable Functions), or batch processing system (Azure Batch, Kubernetes Jobs). - -### Explicit Task Creation with `IMcpTaskStore` +### Manual Control -To implement fault-tolerant tasks, tools can directly interact with `IMcpTaskStore` and return an `McpTask` instead of relying on automatic task wrapping. This approach gives you full control over task lifecycle and enables integration with external compute fabrics: +Use `CallToolRawAsync` for manual lifecycle management: ```csharp -[McpServerToolType] -public class FaultTolerantTools(IMcpTaskStore taskStore, IJobQueue jobQueue) +var raw = await client.CallToolRawAsync(requestParams, cancellationToken); +if (raw.IsTask) { - [McpServerTool] - [Description("Submits a long-running job with fault-tolerant execution.")] - public async Task SubmitJob( - [Description("The job parameters")] string jobInput, - RequestContext context, - CancellationToken cancellationToken) - { - // 1. Create a task in the durable store - var task = await taskStore.CreateTaskAsync( - new McpTaskMetadata { TimeToLive = TimeSpan.FromHours(24) }, - context.JsonRpcRequest.Id!, - context.JsonRpcRequest, - context.Server.SessionId, - cancellationToken); - - // 2. Submit work to an external compute fabric - // The job queue handles execution independently of this process - await jobQueue.EnqueueAsync(new JobMessage - { - TaskId = task.TaskId, - SessionId = context.Server.SessionId, - Input = jobInput - }, cancellationToken); - - // 3. Return the task immediately - client will poll for completion - return task; - } + // Poll manually via client.GetTaskAsync(raw.TaskCreated!.TaskId, ...) } ``` -The external job processor updates the task store when work completes: +## Input Requests (Multi-Round-Trip) -```csharp -// In a separate worker process or Azure Function -public class JobProcessor(IMcpTaskStore taskStore) -{ - public async Task ProcessJobAsync(JobMessage job, CancellationToken cancellationToken) - { - try - { - // Perform the actual long-running work - var result = await DoExpensiveWorkAsync(job.Input, cancellationToken); - - // Store the result in the durable task store - await taskStore.StoreTaskResultAsync( - job.TaskId, - McpTaskStatus.Completed, - JsonSerializer.SerializeToElement(new CallToolResult - { - Content = [new TextContentBlock { Text = result }] - }), - job.SessionId, - cancellationToken); - } - catch (Exception ex) - { - // Mark task as failed - await taskStore.StoreTaskResultAsync( - job.TaskId, - McpTaskStatus.Failed, - JsonSerializer.SerializeToElement(new CallToolResult - { - Content = [new TextContentBlock { Text = ex.Message }], - IsError = true - }), - job.SessionId, - cancellationToken); - } - } -} -``` +Per [SEP-2322 (MRTR)](https://modelcontextprotocol.io/seps/2322-MRTR), tasks can request additional input from the client. The server adds input requests to the store, and the client provides responses via `tasks/update`. -### Simplified Example: File-Based Task Store +Supported input request types: +- **Elicitation** (`elicitation/create`) +- **Sampling** (`sampling/createMessage`) - - -The [LongRunningTasks sample](https://github.com/modelcontextprotocol/csharp-sdk/tree/main/samples/LongRunningTasks) demonstrates a simplified fault-tolerant approach using the file system. The `FileBasedMcpTaskStore` persists task state to disk, allowing tasks to survive server restarts: - +The client deduplicates input requests across polling cycles to avoid re-resolving the same request. -```csharp -// Use a file-based task store for durability -var taskStorePath = Path.Combine(Path.GetTempPath(), "mcp-tasks"); -var taskStore = new FileBasedMcpTaskStore(taskStorePath); +## Architecture Notes -builder.Services.AddMcpServer(options => -{ - options.TaskStore = taskStore; -}) -.WithHttpTransport(o => o.Stateless = true) -.WithTools(); -``` +### Filter Model (3 Cases) -The sample's tool returns an `McpTask` directly by calling `CreateTaskAsync`: +1. **Non-task filter + non-task handler**: Filters applied normally, final result converted to task shape. +2. **Task filter + task handler**: Filters applied directly to the task-augmented handler. +3. **Mixed**: Throws `InvalidOperationException`. -```csharp -[McpServerToolType] -public class TaskTools(IMcpTaskStore taskStore) -{ - [McpServerTool] - [Description("Submits a job and returns a task that can be polled for completion.")] - public async Task SubmitJob( - [Description("A label for the job")] string jobName, - RequestContext context, - CancellationToken cancellationToken) - { - return await taskStore.CreateTaskAsync( - new McpTaskMetadata { TimeToLive = TimeSpan.FromMinutes(10) }, - context.JsonRpcRequest.Id!, - context.JsonRpcRequest, - context.Server.SessionId, - cancellationToken); - } -} -``` +### Immutable Store Design + +`InMemoryMcpTaskStore` uses immutable records with compare-and-swap (CAS) updates for lock-free thread safety. `ImmutableDictionary` is used for input requests/responses. -While this file-based approach demonstrates the pattern, production systems should use proper distributed storage and compute infrastructure for true fault tolerance and scalability. +## Known Limitations / TODOs -## See Also +- **Task status notifications (SEP-2575)**: Server-to-client push notifications for task state changes are not yet implemented. The client currently relies on polling only. +- **Lazy task creation**: Currently, `CreateTaskAsync` is called eagerly before the inner handler runs. Ideally, task creation should be deferred until the handler actually needs it (avoids unnecessary store writes for tools that return immediately). +- **Mid-execution promotion to task**: There is currently no way for a tool to start executing synchronously and then transition the remaining work to a background task. A user can achieve this manually with a custom `CallToolWithTaskHandler`, but there is no built-in support for `[McpServerTool]`-attributed methods to say "the remaining work should continue as a task." This could be addressed with an API like `McpServer.PromoteToTaskAsync()` callable from within tool execution. +- **Extensions serialization round-trip**: `ServerCapabilities.Extensions` (backed by `IDictionary`) does not survive JSON round-trip via source-generated serialization. The `object` values cannot be deserialized by the source generator. -- -- -- -- -- [MCP Tasks Specification](https://modelcontextprotocol.io/specification/draft/basic/utilities/tasks) diff --git a/docs/experimental.md b/docs/experimental.md index 1ad75a9b4..59a1d7579 100644 --- a/docs/experimental.md +++ b/docs/experimental.md @@ -26,8 +26,8 @@ Add the diagnostic ID to `` in your project file: Use `#pragma warning disable` around specific call sites: ```csharp -#pragma warning disable MCPEXP001 // The Tasks feature is experimental per the MCP specification and is subject to change. -tool.Execution = new ToolExecution { ... }; +#pragma warning disable MCPEXP001 // The Extensions feature is part of a future MCP specification version that has not yet been ratified and is subject to change. +capabilities.Extensions = new Dictionary { ... }; #pragma warning restore MCPEXP001 ``` @@ -67,4 +67,3 @@ By placing the SDK's resolver first, MCP types are serialized using the SDK's co - [Versioning](versioning.md) - [List of diagnostics](list-of-diagnostics.md#experimental-apis) -- [Tasks](concepts/tasks/tasks.md) (an experimental feature) diff --git a/docs/list-of-diagnostics.md b/docs/list-of-diagnostics.md index 515472817..fb44442ef 100644 --- a/docs/list-of-diagnostics.md +++ b/docs/list-of-diagnostics.md @@ -23,7 +23,7 @@ If you use experimental APIs, you will get one of the diagnostics shown below. T | Diagnostic ID | Description | | :------------ | :---------- | -| `MCPEXP001` | Experimental APIs for features in the MCP specification itself, including Tasks and Extensions. Tasks provide a mechanism for asynchronous long-running operations that can be polled for status and results (see [MCP Tasks specification](https://modelcontextprotocol.io/specification/draft/basic/utilities/tasks)). Extensions provide a framework for extending the Model Context Protocol while maintaining interoperability (see [SEP-2133](https://github.com/modelcontextprotocol/modelcontextprotocol/pull/2133)). | +| `MCPEXP001` | Experimental APIs for features in the MCP specification itself, including Extensions. Extensions provide a framework for extending the Model Context Protocol while maintaining interoperability (see [SEP-2133](https://github.com/modelcontextprotocol/modelcontextprotocol/pull/2133)). | | `MCPEXP002` | Experimental SDK APIs unrelated to the MCP specification itself, including subclassing `McpClient`/`McpServer` (see [#1363](https://github.com/modelcontextprotocol/csharp-sdk/pull/1363)) and `RunSessionHandler`, which may be removed or change signatures in a future release (consider using `ConfigureSessionOptions` instead). | ## Obsolete APIs diff --git a/docs/roadmap.md b/docs/roadmap.md index 81955a710..105a039d0 100644 --- a/docs/roadmap.md +++ b/docs/roadmap.md @@ -12,7 +12,7 @@ The C# SDK tracks implementation of MCP spec components using the [modelcontextp ### Next Spec Revision -The next MCP specification revision is being developed in the [protocol repository](https://github.com/modelcontextprotocol/modelcontextprotocol). The C# SDK already has experimental support for [Tasks](concepts/tasks/tasks.md) (experimental in the specification), which will be updated as the specification is revised. +The next MCP specification revision is being developed in the [protocol repository](https://github.com/modelcontextprotocol/modelcontextprotocol). ### Feedback and End-to-End Scenarios diff --git a/samples/LongRunningTasks/FileBasedMcpTaskStore.cs b/samples/LongRunningTasks/FileBasedMcpTaskStore.cs deleted file mode 100644 index 55a6e77d5..000000000 --- a/samples/LongRunningTasks/FileBasedMcpTaskStore.cs +++ /dev/null @@ -1,393 +0,0 @@ -using ModelContextProtocol; -using ModelContextProtocol.Protocol; -using System.Text.Json; -using System.Text.Json.Nodes; -using System.Text.Json.Serialization; - -namespace LongRunningTasks; - -/// -/// A minimal file-based implementation of that demonstrates -/// durable, fault-tolerant task storage using simple time-based completion. -/// -/// -/// -/// This implementation stores task data to disk: task ID, creation timestamp, execution duration, -/// session ID, TTL, and optional result. Task completion is determined by: -/// -/// Explicit completion or failure via -/// Explicit cancellation via -/// Time-based auto-completion when execution time has elapsed -/// -/// -/// -/// The file-based approach enables durability across process restarts - if the server -/// crashes and restarts, tasks can still be queried and will complete based on elapsed time. -/// -/// -public sealed partial class FileBasedMcpTaskStore : IMcpTaskStore -{ - private readonly string _storePath; - private readonly TimeSpan _executionTime; - - /// - /// Initializes a new instance of the class. - /// - /// The directory path where task files will be stored. - /// - /// The fixed execution time for all tasks. Tasks are reported as completed once this - /// duration has elapsed since creation. Defaults to 5 seconds. - /// - public FileBasedMcpTaskStore(string storePath, TimeSpan? executionTime = null) - { - _storePath = storePath ?? throw new ArgumentNullException(nameof(storePath)); - _executionTime = executionTime ?? TimeSpan.FromSeconds(5); - Directory.CreateDirectory(_storePath); - } - - /// - public async Task CreateTaskAsync( - McpTaskMetadata taskParams, - RequestId requestId, - JsonRpcRequest request, - string? sessionId = null, - CancellationToken cancellationToken = default) - { - var taskId = Guid.NewGuid().ToString("N"); - var now = DateTimeOffset.UtcNow; - - var entry = new TaskFileEntry - { - TaskId = taskId, - SessionId = sessionId, - Status = McpTaskStatus.Working, - CreatedAt = now, - ExecutionTime = _executionTime, - TimeToLive = taskParams.TimeToLive, - Result = JsonSerializer.SerializeToElement(request.Params, JsonContext.Default.JsonNode) - }; - - await WriteTaskEntryAsync(GetTaskFilePath(taskId), entry); - - return ToMcpTask(entry); - } - - /// - public async Task GetTaskAsync( - string taskId, - string? sessionId = null, - CancellationToken cancellationToken = default) - { - var entry = await ReadTaskEntryAsync(taskId); - if (entry is null) - { - return null; - } - - // Session isolation - if (sessionId is not null && entry.SessionId != sessionId) - { - return null; - } - - // Skip if TTL has expired - if (IsExpired(entry)) - { - return null; - } - - return ToMcpTask(entry); - } - - /// - public async Task StoreTaskResultAsync( - string taskId, - McpTaskStatus status, - JsonElement result, - string? sessionId = null, - CancellationToken cancellationToken = default) - { - if (status is not (McpTaskStatus.Completed or McpTaskStatus.Failed)) - { - throw new ArgumentException( - $"Status must be {nameof(McpTaskStatus.Completed)} or {nameof(McpTaskStatus.Failed)}.", - nameof(status)); - } - - var updatedEntry = await UpdateTaskEntryAsync(taskId, sessionId, entry => - { - var effectiveStatus = GetEffectiveStatus(entry); - if (IsTerminalStatus(effectiveStatus)) - { - throw new InvalidOperationException( - $"Cannot store result for task in terminal state: {effectiveStatus}"); - } - - return entry with - { - Status = status, - Result = result - }; - }); - - return ToMcpTask(updatedEntry); - } - - /// - public async Task GetTaskResultAsync( - string taskId, - string? sessionId = null, - CancellationToken cancellationToken = default) - { - var entry = await ReadTaskEntryAsync(taskId) - ?? throw new InvalidOperationException($"Task not found: {taskId}"); - - if (sessionId is not null && entry.SessionId != sessionId) - { - throw new InvalidOperationException($"Task not found: {taskId}"); - } - - var effectiveStatus = GetEffectiveStatus(entry); - if (!IsTerminalStatus(effectiveStatus)) - { - throw new InvalidOperationException($"Task not yet completed: {taskId}"); - } - - // Return stored result - return entry.Result ?? default; - } - - /// - public async Task UpdateTaskStatusAsync( - string taskId, - McpTaskStatus status, - string? statusMessage, - string? sessionId = null, - CancellationToken cancellationToken = default) - { - var updatedEntry = await UpdateTaskEntryAsync(taskId, sessionId, entry => - entry with - { - Status = status, - StatusMessage = statusMessage - }); - - return ToMcpTask(updatedEntry); - } - - /// - public async Task ListTasksAsync( - string? cursor = null, - string? sessionId = null, - CancellationToken cancellationToken = default) - { - var tasks = new List(); - - foreach (var file in Directory.EnumerateFiles(_storePath, "*.json")) - { - try - { - var entry = await ReadTaskEntryFromFileAsync(file); - if (entry is not null) - { - // Session isolation - if (sessionId is not null && entry.SessionId != sessionId) - { - continue; - } - - // Skip expired tasks - if (IsExpired(entry)) - { - continue; - } - - tasks.Add(ToMcpTask(entry)); - } - } - catch - { - // Skip corrupted or inaccessible files - } - } - - tasks.Sort((a, b) => a.CreatedAt.CompareTo(b.CreatedAt)); - - return new ListTasksResult { Tasks = [.. tasks] }; - } - - /// - public async Task CancelTaskAsync( - string taskId, - string? sessionId = null, - CancellationToken cancellationToken = default) - { - var updatedEntry = await UpdateTaskEntryAsync(taskId, sessionId, entry => - { - var effectiveStatus = GetEffectiveStatus(entry); - if (IsTerminalStatus(effectiveStatus)) - { - // Already terminal, return unchanged - return entry; - } - - return entry with { Status = McpTaskStatus.Cancelled }; - }); - - return ToMcpTask(updatedEntry); - } - - private string GetTaskFilePath(string taskId) => Path.Combine(_storePath, $"{taskId}.json"); - - /// - /// Reads, transforms, and writes a task entry while holding an exclusive file lock. - /// - /// The task ID to update. - /// Optional session ID for access control. - /// A function that transforms the entry. May throw to abort the update. - /// The updated task entry. - private async Task UpdateTaskEntryAsync( - string taskId, - string? sessionId, - Func updateFunc) - { - var filePath = GetTaskFilePath(taskId); - - // Acquire exclusive lock on the file for the entire read-modify-write cycle - using var stream = await AcquireFileStreamAsync(filePath, FileMode.Open, FileAccess.ReadWrite); - - var entry = await JsonSerializer.DeserializeAsync(stream, JsonContext.Default.TaskFileEntry) - ?? throw new InvalidOperationException($"Task not found: {taskId}"); - - // Enforce session isolation - if (sessionId is not null && entry.SessionId != sessionId) - { - throw new InvalidOperationException($"Task not found: {taskId}"); - } - - // Apply the transformation (may throw to abort) - var updatedEntry = updateFunc(entry); - - // Write back to the same stream - stream.SetLength(0); - stream.Position = 0; - await JsonSerializer.SerializeAsync(stream, updatedEntry, JsonContext.Default.TaskFileEntry); - - return updatedEntry; - } - - private async Task ReadTaskEntryAsync(string taskId) - { - var filePath = GetTaskFilePath(taskId); - return File.Exists(filePath) ? await ReadTaskEntryFromFileAsync(filePath) : null; - } - - private static async Task ReadTaskEntryFromFileAsync(string filePath) - { - try - { - using var stream = await AcquireFileStreamAsync(filePath, FileMode.Open, FileAccess.Read); - return await JsonSerializer.DeserializeAsync(stream, JsonContext.Default.TaskFileEntry); - } - catch - { - return null; - } - } - - private static async Task WriteTaskEntryAsync(string filePath, TaskFileEntry entry) - { - using var stream = await AcquireFileStreamAsync(filePath, FileMode.Create, FileAccess.Write); - await JsonSerializer.SerializeAsync(stream, entry, JsonContext.Default.TaskFileEntry); - } - - private static async Task AcquireFileStreamAsync(string filePath, FileMode fileMode, FileAccess fileAccess) - { - const int MaxRetries = 10; - const int RetryDelayMs = 50; - - for (int attempt = 0; ; attempt++) - { - try - { - return new FileStream(filePath, fileMode, fileAccess, FileShare.None); - } - catch (IOException) when (attempt < MaxRetries) - { - await Task.Delay(RetryDelayMs); // File is locked by another process, wait and retry - } - } - } - - private McpTask ToMcpTask(TaskFileEntry entry) - { - var now = DateTimeOffset.UtcNow; - return new McpTask - { - TaskId = entry.TaskId, - Status = GetEffectiveStatus(entry), - StatusMessage = entry.StatusMessage, - CreatedAt = entry.CreatedAt, - LastUpdatedAt = now, - TimeToLive = entry.TimeToLive - }; - } - - private static McpTaskStatus GetEffectiveStatus(TaskFileEntry entry) - { - // If already in a terminal state, return it - if (IsTerminalStatus(entry.Status)) - { - return entry.Status; - } - - // Check if execution time has elapsed - auto-complete - if (DateTimeOffset.UtcNow - entry.CreatedAt >= entry.ExecutionTime) - { - return McpTaskStatus.Completed; - } - - return entry.Status; - } - - private static bool IsTerminalStatus(McpTaskStatus status) => - status is McpTaskStatus.Completed or McpTaskStatus.Failed or McpTaskStatus.Cancelled; - - private static bool IsExpired(TaskFileEntry entry) => - entry.TimeToLive.HasValue && DateTimeOffset.UtcNow - entry.CreatedAt > entry.TimeToLive.Value; - - /// - /// Represents the data stored for each task. - /// - private sealed record TaskFileEntry - { - /// The unique task identifier. - public required string TaskId { get; init; } - - /// The session that created this task. - public string? SessionId { get; init; } - - /// The current task status. - public required McpTaskStatus Status { get; init; } - - /// Optional status message describing the current state. - public string? StatusMessage { get; init; } - - /// When the task was created. - public required DateTimeOffset CreatedAt { get; init; } - - /// How long until the task is considered complete (if not explicitly completed). - public required TimeSpan ExecutionTime { get; init; } - - /// Time to live - task is filtered out after this duration from creation. - public TimeSpan? TimeToLive { get; init; } - - /// The task result - initialized with request params, updated via StoreTaskResultAsync. - public JsonElement? Result { get; init; } - } - - [JsonSourceGenerationOptions(WriteIndented = true)] - [JsonSerializable(typeof(TaskFileEntry))] - [JsonSerializable(typeof(JsonNode))] - private sealed partial class JsonContext : JsonSerializerContext; -} diff --git a/samples/LongRunningTasks/LongRunningTasks.csproj b/samples/LongRunningTasks/LongRunningTasks.csproj deleted file mode 100644 index ffe1fc716..000000000 --- a/samples/LongRunningTasks/LongRunningTasks.csproj +++ /dev/null @@ -1,14 +0,0 @@ - - - - net9.0 - enable - enable - $(NoWarn);MCPEXP001 - - - - - - - diff --git a/samples/LongRunningTasks/Program.cs b/samples/LongRunningTasks/Program.cs deleted file mode 100644 index ee9174554..000000000 --- a/samples/LongRunningTasks/Program.cs +++ /dev/null @@ -1,34 +0,0 @@ -// This sample demonstrates using a custom IMcpTaskStore implementation for -// durable task storage. The FileBasedMcpTaskStore persists tasks to disk, -// allowing them to survive server restarts. -// -// To test: -// 1. Start the server and call the SubmitJob tool -// 2. Poll the returned task using tasks/get -// 3. Optionally restart the server - the task will still be queryable - -using LongRunningTasks; -using LongRunningTasks.Tools; - -var builder = WebApplication.CreateBuilder(args); - -// Use a file-based task store for persistence across server restarts. -// Tasks survive server restarts and can be resumed or queried after a crash. -var taskStorePath = Path.Combine(Path.GetTempPath(), "mcp-tasks"); -var taskStore = new FileBasedMcpTaskStore(taskStorePath); - -builder.Services.AddMcpServer(options => -{ - options.TaskStore = taskStore; - options.ServerInfo = new() - { - Name = "LongRunningTasksServer", - Version = "1.0.0" - }; -}) -.WithHttpTransport(o => o.Stateless = true) -.WithTools(); - -var app = builder.Build(); -app.MapMcp(); -app.Run(); diff --git a/samples/LongRunningTasks/Properties/launchSettings.json b/samples/LongRunningTasks/Properties/launchSettings.json deleted file mode 100644 index 9a7c84f4b..000000000 --- a/samples/LongRunningTasks/Properties/launchSettings.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "profiles": { - "LongRunningTasks": { - "commandName": "Project", - "launchBrowser": true, - "environmentVariables": { - "ASPNETCORE_ENVIRONMENT": "Development" - }, - "applicationUrl": "https://localhost:60964;http://localhost:60965" - } - } -} \ No newline at end of file diff --git a/samples/LongRunningTasks/README.md b/samples/LongRunningTasks/README.md deleted file mode 100644 index 71130e44a..000000000 --- a/samples/LongRunningTasks/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Long-Running Tasks Sample - -This sample demonstrates **explicit task handling** in MCP servers using the `IMcpTaskStore` interface directly. Unlike implicit task handling (where the server framework manages tasks automatically), this approach gives you full control over task lifecycle. \ No newline at end of file diff --git a/samples/LongRunningTasks/Tools/TaskTools.cs b/samples/LongRunningTasks/Tools/TaskTools.cs deleted file mode 100644 index 30eb43335..000000000 --- a/samples/LongRunningTasks/Tools/TaskTools.cs +++ /dev/null @@ -1,31 +0,0 @@ -using ModelContextProtocol; -using ModelContextProtocol.Protocol; -using ModelContextProtocol.Server; -using System.ComponentModel; - -namespace LongRunningTasks.Tools; - -/// -/// Demonstrates creating and returning tasks via . -/// -[McpServerToolType] -public class TaskTools(IMcpTaskStore taskStore) -{ - /// - /// Submits a job to the task store and returns a task handle for polling. - /// - [McpServerTool] - [Description("Submits a job and returns a task that can be polled for completion.")] - public Task SubmitJob( - [Description("A label for the job")] string jobName, - RequestContext context, - CancellationToken cancellationToken) - { - return taskStore.CreateTaskAsync( - new McpTaskMetadata { TimeToLive = TimeSpan.FromMinutes(10) }, - context.JsonRpcRequest.Id!, - context.JsonRpcRequest, - context.Server.SessionId, - cancellationToken); - } -} diff --git a/samples/LongRunningTasks/appsettings.json b/samples/LongRunningTasks/appsettings.json deleted file mode 100644 index 757d8426e..000000000 --- a/samples/LongRunningTasks/appsettings.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "Logging": { - "LogLevel": { - "Default": "Information", - "Microsoft.AspNetCore": "Warning" - } - }, - "AllowedHosts": "localhost;127.0.0.1;[::1]" -} diff --git a/src/Common/Experimentals.cs b/src/Common/Experimentals.cs index 7e7e969bb..1fe7979b0 100644 --- a/src/Common/Experimentals.cs +++ b/src/Common/Experimentals.cs @@ -10,7 +10,7 @@ namespace ModelContextProtocol; /// /// /// MCPEXP001 covers APIs related to experimental features in the MCP specification itself, -/// such as Tasks and Extensions. These APIs may change as the specification evolves. +/// such as Extensions. These APIs may change as the specification evolves. /// /// /// MCPEXP002 covers experimental SDK APIs that are unrelated to the MCP specification, @@ -35,30 +35,9 @@ namespace ModelContextProtocol; /// internal static class Experimentals { - /// - /// Diagnostic ID for the experimental MCP Tasks feature. - /// - public const string Tasks_DiagnosticId = "MCPEXP001"; - - /// - /// Message for the experimental MCP Tasks feature. - /// - public const string Tasks_Message = "The Tasks feature is experimental per the MCP specification and is subject to change."; - - /// - /// URL for the experimental MCP Tasks feature. - /// - public const string Tasks_Url = "https://github.com/modelcontextprotocol/csharp-sdk/blob/main/docs/list-of-diagnostics.md#mcpexp001"; - /// /// Diagnostic ID for the experimental MCP Extensions feature. /// - /// - /// This uses the same diagnostic ID as because both - /// Tasks and Extensions are covered by the same MCPEXP001 diagnostic for experimental - /// MCP features. Having separate constants improves code clarity while maintaining a - /// single diagnostic suppression point. - /// public const string Extensions_DiagnosticId = "MCPEXP001"; /// diff --git a/src/ModelContextProtocol.Core/Client/McpClient.Methods.cs b/src/ModelContextProtocol.Core/Client/McpClient.Methods.cs index f04c32ffd..ac924beaf 100644 --- a/src/ModelContextProtocol.Core/Client/McpClient.Methods.cs +++ b/src/ModelContextProtocol.Core/Client/McpClient.Methods.cs @@ -1,4 +1,4 @@ -using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using System.Diagnostics; @@ -955,394 +955,302 @@ async ValueTask SendRequestWithProgressAsync( /// The result of the request. /// is . /// The request failed or the server returned an error response. - public ValueTask CallToolAsync( + /// + /// This method automatically includes the io.modelcontextprotocol/tasks extension capability + /// in the request metadata. If the server returns a task handle instead of an immediate result, + /// this method transparently polls tasks/get until the task completes, fails, or is cancelled. + /// Use + /// to receive the raw without automatic polling. + /// + public async ValueTask CallToolAsync( CallToolRequestParams requestParams, CancellationToken cancellationToken = default) { Throw.IfNull(requestParams); - return SendRequestAsync( - RequestMethods.ToolsCall, - requestParams, - McpJsonUtilities.JsonContext.Default.CallToolRequestParams, - McpJsonUtilities.JsonContext.Default.CallToolResult, - cancellationToken: cancellationToken); + var augmented = await CallToolRawAsync(requestParams, cancellationToken).ConfigureAwait(false); + + if (!augmented.IsTask) + { + return augmented.Result!; + } + + return await PollTaskToCompletionAsync(augmented.TaskCreated!, cancellationToken).ConfigureAwait(false); } /// - /// Invokes a tool on the server as a task for long-running operations. + /// Polls a task until it reaches a terminal state and returns the final . /// - /// The name of the tool to call on the server. - /// An optional dictionary of arguments to pass to the tool. - /// Metadata for task augmentation, including optional TTL. If , an empty metadata is used. - /// An optional progress reporter for server notifications. - /// Optional request options including metadata, serialization settings, and progress tracking. - /// The to monitor for cancellation requests. The default is . - /// - /// An representing the created task. Use to poll for status updates - /// and to retrieve the final result. - /// - /// is . - /// The request failed or the server returned an error response. - /// - /// - /// Task-augmented tool calls allow long-running operations to be executed asynchronously. Instead of blocking - /// until the tool completes, the server immediately returns a task identifier that can be used to poll for - /// status updates and retrieve the final result. - /// - /// - /// The server must advertise task support via capabilities.tasks.requests.tools.call and the tool - /// must have execution.taskSupport set to "optional" or "required". - /// - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public ValueTask CallToolAsTaskAsync( - string toolName, - IReadOnlyDictionary? arguments = null, - McpTaskMetadata? taskMetadata = null, - IProgress? progress = null, - RequestOptions? options = null, - CancellationToken cancellationToken = default) + private async ValueTask PollTaskToCompletionAsync( + CreateTaskResult taskCreated, + CancellationToken cancellationToken) { - Throw.IfNull(toolName); + string taskId = taskCreated.TaskId; + long pollIntervalMs = taskCreated.PollIntervalMs ?? 1000; + HashSet? resolvedRequestKeys = null; - var serializerOptions = options?.JsonSerializerOptions ?? McpJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); - - if (progress is null) + while (true) { - return SendTaskAugmentedCallToolRequestAsync(toolName, arguments, taskMetadata, options?.GetMetaForRequest(), serializerOptions, cancellationToken); - } + await Task.Delay(TimeSpan.FromMilliseconds(pollIntervalMs), cancellationToken).ConfigureAwait(false); - return SendTaskAugmentedCallToolRequestWithProgressAsync(toolName, arguments, taskMetadata, progress, options?.GetMetaForRequest(), serializerOptions, cancellationToken); + var taskResult = await GetTaskAsync(taskId, cancellationToken).ConfigureAwait(false); - async ValueTask SendTaskAugmentedCallToolRequestAsync( - string toolName, - IReadOnlyDictionary? arguments, - McpTaskMetadata? taskMetadata, - JsonObject? meta, - JsonSerializerOptions serializerOptions, - CancellationToken cancellationToken) - { - var result = await SendRequestAsync( - RequestMethods.ToolsCall, - new CallToolRequestParams - { - Name = toolName, - Arguments = ToArgumentsDictionary(arguments, serializerOptions), - Meta = meta, - Task = taskMetadata ?? new McpTaskMetadata(), - }, - McpJsonUtilities.JsonContext.Default.CallToolRequestParams, - McpJsonUtilities.JsonContext.Default.CreateTaskResult, - cancellationToken: cancellationToken).ConfigureAwait(false); + // Update poll interval if the server changed it. + if (taskResult.PollIntervalMs is { } newInterval) + { + pollIntervalMs = newInterval; + } - return result.Task; - } + switch (taskResult) + { + case CompletedTaskResult completed: + return JsonSerializer.Deserialize(completed.TaskResult, McpJsonUtilities.JsonContext.Default.CallToolResult) + ?? throw new JsonException("Failed to deserialize CallToolResult from completed task."); - async ValueTask SendTaskAugmentedCallToolRequestWithProgressAsync( - string toolName, - IReadOnlyDictionary? arguments, - McpTaskMetadata? taskMetadata, - IProgress progress, - JsonObject? meta, - JsonSerializerOptions serializerOptions, - CancellationToken cancellationToken) - { - ProgressToken progressToken = new(Guid.NewGuid().ToString("N")); + case FailedTaskResult failed: + throw new McpException($"Task '{taskId}' failed: {failed.Error}"); - await using var _ = RegisterNotificationHandler(NotificationMethods.ProgressNotification, - (notification, cancellationToken) => - { - if (JsonSerializer.Deserialize(notification.Params, McpJsonUtilities.JsonContext.Default.ProgressNotificationParams) is { } pn && - pn.ProgressToken == progressToken) + case CancelledTaskResult: + throw new OperationCanceledException($"Task '{taskId}' was cancelled by the server."); + + case InputRequiredTaskResult inputRequired: + // Dedup: only resolve input requests we haven't already responded to. + var newRequests = new Dictionary(); + foreach (var kvp in inputRequired.InputRequests) { - progress.Report(pn.Progress); + if (resolvedRequestKeys is null || !resolvedRequestKeys.Contains(kvp.Key)) + { + newRequests[kvp.Key] = kvp.Value; + } } - return default; - }).ConfigureAwait(false); - - JsonObject metaWithProgress = meta is not null ? (JsonObject)meta.DeepClone() : []; - metaWithProgress["progressToken"] = progressToken.ToString(); + if (newRequests.Count > 0) + { + var inputResponses = await ResolveInputRequestsAsync(newRequests, cancellationToken).ConfigureAwait(false); + await UpdateTaskAsync(new UpdateTaskRequestParams + { + TaskId = taskId, + InputResponses = inputResponses, + }, cancellationToken).ConfigureAwait(false); + + resolvedRequestKeys ??= new HashSet(StringComparer.Ordinal); + foreach (var key in inputResponses.Keys) + { + resolvedRequestKeys.Add(key); + } + } - var result = await SendRequestAsync( - RequestMethods.ToolsCall, - new CallToolRequestParams - { - Name = toolName, - Arguments = ToArgumentsDictionary(arguments, serializerOptions), - Meta = metaWithProgress, - Task = taskMetadata ?? new McpTaskMetadata(), - }, - McpJsonUtilities.JsonContext.Default.CallToolRequestParams, - McpJsonUtilities.JsonContext.Default.CreateTaskResult, - cancellationToken: cancellationToken).ConfigureAwait(false); + break; - return result.Task; + case WorkingTaskResult: + // Continue polling. + break; + } } } /// - /// Retrieves the current state of a specific task from the server. + /// Invokes a tool on the server with task extension support, returning the raw response + /// without automatic polling. The caller is responsible for handling task lifecycle. /// - /// The unique identifier of the task to retrieve. - /// Optional request options including metadata, serialization settings, and progress tracking. + /// The request parameters to send. The tasks extension capability will be injected into the request metadata. /// The to monitor for cancellation requests. The default is . - /// The current state of the task. - /// is . - /// is empty or composed entirely of whitespace. + /// A that is either an immediate result or a task handle. + /// is . /// The request failed or the server returned an error response. - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public async ValueTask GetTaskAsync( - string taskId, - RequestOptions? options = null, + /// + /// + /// Unlike , this method does not + /// automatically poll for task completion. If the server returns a , + /// the caller must manage polling via . + /// + /// + public async ValueTask> CallToolRawAsync( + CallToolRequestParams requestParams, CancellationToken cancellationToken = default) { - Throw.IfNullOrWhiteSpace(taskId); + Throw.IfNull(requestParams); - var result = await SendRequestAsync( - RequestMethods.TasksGet, - new GetTaskRequestParams { TaskId = taskId, Meta = options?.GetMetaForRequest() }, - McpJsonUtilities.JsonContext.Default.GetTaskRequestParams, - McpJsonUtilities.JsonContext.Default.GetTaskResult, - cancellationToken: cancellationToken).ConfigureAwait(false); + var paramsWithMeta = new CallToolRequestParams + { + Name = requestParams.Name, + Arguments = requestParams.Arguments, + Meta = GetMetaWithTaskCapability(requestParams.Meta), + }; - // Convert GetTaskResult to McpTask - return new McpTask + JsonRpcRequest jsonRpcRequest = new() { - TaskId = result.TaskId, - Status = result.Status, - StatusMessage = result.StatusMessage, - CreatedAt = result.CreatedAt, - LastUpdatedAt = result.LastUpdatedAt, - TimeToLive = result.TimeToLive, - PollInterval = result.PollInterval + Method = RequestMethods.ToolsCall, + Params = JsonSerializer.SerializeToNode(paramsWithMeta, McpJsonUtilities.JsonContext.Default.CallToolRequestParams), }; + + JsonRpcResponse response = await SendRequestAsync(jsonRpcRequest, cancellationToken).ConfigureAwait(false); + + // Discriminate based on resultType field. + if (response.Result is JsonObject resultObj && + resultObj.TryGetPropertyValue("resultType", out var resultTypeNode) && + resultTypeNode?.GetValue() == "task") + { + var taskCreated = resultObj.Deserialize(McpJsonUtilities.JsonContext.Default.CreateTaskResult) + ?? throw new JsonException("Failed to deserialize CreateTaskResult from response."); + return new ResultOrCreatedTask(taskCreated); + } + + var callToolResult = JsonSerializer.Deserialize(response.Result, McpJsonUtilities.JsonContext.Default.CallToolResult) + ?? throw new JsonException("Failed to deserialize CallToolResult from response."); + return new ResultOrCreatedTask(callToolResult); } /// - /// Retrieves the result of a completed task, blocking until the task reaches a terminal state. + /// Sets the logging level for the server to control which log messages are sent to the client. /// - /// The unique identifier of the task whose result to retrieve. + /// The minimum severity level of log messages to receive from the server. /// Optional request options including metadata, serialization settings, and progress tracking. /// The to monitor for cancellation requests. The default is . - /// The raw JSON result of the task. - /// is . - /// is empty or composed entirely of whitespace. + /// A task representing the asynchronous operation. /// The request failed or the server returned an error response. - /// - /// This method sends a tasks/result request to the server, which will block until the task completes if it hasn't already. - /// The server handles all polling logic internally. - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public ValueTask GetTaskResultAsync( - string taskId, - RequestOptions? options = null, - CancellationToken cancellationToken = default) - { - Throw.IfNullOrWhiteSpace(taskId); - - return SendRequestAsync( - RequestMethods.TasksResult, - new GetTaskPayloadRequestParams { TaskId = taskId, Meta = options?.GetMetaForRequest() }, - McpJsonUtilities.JsonContext.Default.GetTaskPayloadRequestParams, - McpJsonUtilities.JsonContext.Default.JsonElement, - cancellationToken: cancellationToken); - } + public Task SetLoggingLevelAsync(LogLevel level, RequestOptions? options = null, CancellationToken cancellationToken = default) => + SetLoggingLevelAsync(McpServerImpl.ToLoggingLevel(level), options, cancellationToken); /// - /// Retrieves a list of all tasks from the server. + /// Sets the logging level for the server to control which log messages are sent to the client. /// + /// The minimum severity level of log messages to receive from the server. /// Optional request options including metadata, serialization settings, and progress tracking. /// The to monitor for cancellation requests. The default is . - /// A list of all tasks. + /// A task representing the asynchronous operation. /// The request failed or the server returned an error response. - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public async ValueTask> ListTasksAsync( - RequestOptions? options = null, - CancellationToken cancellationToken = default) + public Task SetLoggingLevelAsync(LoggingLevel level, RequestOptions? options = null, CancellationToken cancellationToken = default) { - ListTasksRequestParams requestParams = new() { Meta = options?.GetMetaForRequest() }; - List tasks = new(); - do - { - var taskResults = await ListTasksAsync(requestParams, cancellationToken).ConfigureAwait(false); - tasks.AddRange(taskResults.Tasks); - requestParams.Cursor = taskResults.NextCursor; - } - while (requestParams.Cursor is not null); - - return tasks; + return SetLoggingLevelAsync( + new SetLevelRequestParams + { + Level = level, + Meta = options?.GetMetaForRequest() + }, + cancellationToken); } /// - /// Retrieves a list of tasks from the server. + /// Sets the logging level for the server to control which log messages are sent to the client. /// /// The request parameters to send in the request. /// The to monitor for cancellation requests. The default is . - /// The result of the request as provided by the server. + /// The result of the request. /// is . /// The request failed or the server returned an error response. - /// - /// The overload retrieves all tasks by automatically handling pagination. - /// This overload works with the lower-level and , returning the raw result from the server. - /// Any pagination needs to be managed by the caller. - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public ValueTask ListTasksAsync( - ListTasksRequestParams requestParams, + public Task SetLoggingLevelAsync( + SetLevelRequestParams requestParams, CancellationToken cancellationToken = default) { Throw.IfNull(requestParams); return SendRequestAsync( - RequestMethods.TasksList, + RequestMethods.LoggingSetLevel, requestParams, - McpJsonUtilities.JsonContext.Default.ListTasksRequestParams, - McpJsonUtilities.JsonContext.Default.ListTasksResult, - cancellationToken: cancellationToken); + McpJsonUtilities.JsonContext.Default.SetLevelRequestParams, + McpJsonUtilities.JsonContext.Default.EmptyResult, + cancellationToken: cancellationToken).AsTask(); } /// - /// Cancels a running task on the server. + /// Retrieves the current state of a task from the server. /// - /// The unique identifier of the task to cancel. - /// Optional request options including metadata, serialization settings, and progress tracking. + /// The stable identifier of the task to retrieve. /// The to monitor for cancellation requests. The default is . - /// The updated state of the task after cancellation. + /// A subtype representing the current task state. /// is . - /// is empty or composed entirely of whitespace. /// The request failed or the server returned an error response. - /// - /// Cancelling a task requests that the server stop execution. The server may not immediately cancel the task, - /// and may choose to allow the task to complete if it's close to finishing. - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public async ValueTask CancelTaskAsync( + public ValueTask GetTaskAsync( string taskId, - RequestOptions? options = null, CancellationToken cancellationToken = default) { - Throw.IfNullOrWhiteSpace(taskId); - - var result = await SendRequestAsync( - RequestMethods.TasksCancel, - new CancelMcpTaskRequestParams { TaskId = taskId, Meta = options?.GetMetaForRequest() }, - McpJsonUtilities.JsonContext.Default.CancelMcpTaskRequestParams, - McpJsonUtilities.JsonContext.Default.CancelMcpTaskResult, - cancellationToken: cancellationToken).ConfigureAwait(false); + Throw.IfNull(taskId); - // Convert CancelMcpTaskResult to McpTask - return new McpTask - { - TaskId = result.TaskId, - Status = result.Status, - StatusMessage = result.StatusMessage, - CreatedAt = result.CreatedAt, - LastUpdatedAt = result.LastUpdatedAt, - TimeToLive = result.TimeToLive, - PollInterval = result.PollInterval - }; + return GetTaskAsync(new GetTaskRequestParams { TaskId = taskId }, cancellationToken); } /// - /// Polls a task until it reaches a terminal status (completed, failed, or cancelled). + /// Retrieves the current state of a task from the server. /// - /// The unique identifier of the task to poll. - /// Optional request options including metadata, serialization settings, and progress tracking. + /// The request parameters to send in the request. /// The to monitor for cancellation requests. The default is . - /// The task in its terminal state. - /// is . - /// is empty or composed entirely of whitespace. - /// - /// - /// This method repeatedly calls until the task reaches a terminal status. - /// It respects the returned by the server to determine how long - /// to wait between polling attempts. - /// - /// - /// For retrieving the actual result of a completed task, use . - /// - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public async ValueTask PollTaskUntilCompleteAsync( - string taskId, - RequestOptions? options = null, + /// A subtype representing the current task state. + /// is . + /// The request failed or the server returned an error response. + public ValueTask GetTaskAsync( + GetTaskRequestParams requestParams, CancellationToken cancellationToken = default) { - Throw.IfNullOrWhiteSpace(taskId); - - McpTask task; - do - { - task = await GetTaskAsync(taskId, options, cancellationToken).ConfigureAwait(false); - - // If task is in a terminal state, we're done - if (task.Status is McpTaskStatus.Completed or McpTaskStatus.Failed or McpTaskStatus.Cancelled) - { - break; - } - - // Wait for the poll interval before checking again (default to 1 second) - var pollInterval = task.PollInterval ?? TimeSpan.FromSeconds(1); - await Task.Delay(pollInterval, cancellationToken).ConfigureAwait(false); - } - while (true); + Throw.IfNull(requestParams); - return task; + return SendRequestAsync( + RequestMethods.TasksGet, + requestParams, + McpJsonUtilities.JsonContext.Default.GetTaskRequestParams, + McpJsonUtilities.JsonContext.Default.GetTaskResult, + cancellationToken: cancellationToken); } /// - /// Sets the logging level for the server to control which log messages are sent to the client. + /// Provides input responses to a task that is in the state. /// - /// The minimum severity level of log messages to receive from the server. - /// Optional request options including metadata, serialization settings, and progress tracking. + /// The request parameters containing the task ID and input responses. /// The to monitor for cancellation requests. The default is . - /// A task representing the asynchronous operation. + /// The result acknowledging the update. + /// is . /// The request failed or the server returned an error response. - public Task SetLoggingLevelAsync(LogLevel level, RequestOptions? options = null, CancellationToken cancellationToken = default) => - SetLoggingLevelAsync(McpServerImpl.ToLoggingLevel(level), options, cancellationToken); + public ValueTask UpdateTaskAsync( + UpdateTaskRequestParams requestParams, + CancellationToken cancellationToken = default) + { + Throw.IfNull(requestParams); + + return SendRequestAsync( + RequestMethods.TasksUpdate, + requestParams, + McpJsonUtilities.JsonContext.Default.UpdateTaskRequestParams, + McpJsonUtilities.JsonContext.Default.UpdateTaskResult, + cancellationToken: cancellationToken); + } /// - /// Sets the logging level for the server to control which log messages are sent to the client. + /// Requests cancellation of an in-progress task on the server. /// - /// The minimum severity level of log messages to receive from the server. - /// Optional request options including metadata, serialization settings, and progress tracking. + /// The stable identifier of the task to cancel. /// The to monitor for cancellation requests. The default is . - /// A task representing the asynchronous operation. + /// The result acknowledging the cancellation request. + /// is . /// The request failed or the server returned an error response. - public Task SetLoggingLevelAsync(LoggingLevel level, RequestOptions? options = null, CancellationToken cancellationToken = default) + public ValueTask CancelTaskAsync( + string taskId, + CancellationToken cancellationToken = default) { - return SetLoggingLevelAsync( - new SetLevelRequestParams - { - Level = level, - Meta = options?.GetMetaForRequest() - }, - cancellationToken); + Throw.IfNull(taskId); + + return CancelTaskAsync(new CancelTaskRequestParams { TaskId = taskId }, cancellationToken); } /// - /// Sets the logging level for the server to control which log messages are sent to the client. + /// Requests cancellation of an in-progress task on the server. /// /// The request parameters to send in the request. /// The to monitor for cancellation requests. The default is . - /// The result of the request. + /// The result acknowledging the cancellation request. /// is . /// The request failed or the server returned an error response. - public Task SetLoggingLevelAsync( - SetLevelRequestParams requestParams, + public ValueTask CancelTaskAsync( + CancelTaskRequestParams requestParams, CancellationToken cancellationToken = default) { Throw.IfNull(requestParams); return SendRequestAsync( - RequestMethods.LoggingSetLevel, + RequestMethods.TasksCancel, requestParams, - McpJsonUtilities.JsonContext.Default.SetLevelRequestParams, - McpJsonUtilities.JsonContext.Default.EmptyResult, - cancellationToken: cancellationToken).AsTask(); + McpJsonUtilities.JsonContext.Default.CancelTaskRequestParams, + McpJsonUtilities.JsonContext.Default.CancelTaskResult, + cancellationToken: cancellationToken); } /// Converts a dictionary with values to a dictionary with values. @@ -1363,4 +1271,13 @@ public Task SetLoggingLevelAsync( return result; } + + private static JsonObject GetMetaWithTaskCapability(JsonObject? existingMeta) + { + JsonObject meta = existingMeta is not null + ? (JsonObject)existingMeta.DeepClone() + : []; + meta.TryAdd(McpExtensions.Tasks, new JsonObject()); + return meta; + } } diff --git a/src/ModelContextProtocol.Core/Client/McpClient.cs b/src/ModelContextProtocol.Core/Client/McpClient.cs index 406969121..efef64aa0 100644 --- a/src/ModelContextProtocol.Core/Client/McpClient.cs +++ b/src/ModelContextProtocol.Core/Client/McpClient.cs @@ -1,4 +1,5 @@ using System.Diagnostics.CodeAnalysis; +using System.Text.Json; using ModelContextProtocol.Protocol; namespace ModelContextProtocol.Client; @@ -70,4 +71,17 @@ protected McpClient() /// /// public abstract Task Completion { get; } + + /// + /// Resolves input requests embedded in an by dispatching + /// each request to the appropriate registered handler. + /// + /// + /// The input requests from the task, keyed by request identifier. Each value is a JSON object + /// with method and params fields representing a server-to-client request. + /// + /// The to monitor for cancellation requests. + /// A dictionary of responses keyed by the same identifiers as the input requests. + private protected abstract ValueTask> ResolveInputRequestsAsync( + IDictionary inputRequests, CancellationToken cancellationToken); } diff --git a/src/ModelContextProtocol.Core/Client/McpClientHandlers.cs b/src/ModelContextProtocol.Core/Client/McpClientHandlers.cs index 2109555bc..0866e4aef 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientHandlers.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientHandlers.cs @@ -86,25 +86,4 @@ public sealed class McpClientHandlers /// /// public Func, CancellationToken, ValueTask>? SamplingHandler { get; set; } - - /// - /// Gets or sets the handler for processing notifications. - /// - /// - /// - /// This handler is called when the server sends a task status notification to inform the client - /// about changes to a task's state. These notifications are optional and clients MUST NOT rely - /// on receiving them. - /// - /// - /// The handler receives the updated object containing the current task state, - /// including its status, status message, and timestamps. - /// - /// - /// This handler is typically used to update UI or trigger actions based on task progress - /// without requiring explicit polling. - /// - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public Func? TaskStatusHandler { get; set; } } diff --git a/src/ModelContextProtocol.Core/Client/McpClientImpl.cs b/src/ModelContextProtocol.Core/Client/McpClientImpl.cs index 0d5803559..681b47a00 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientImpl.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientImpl.cs @@ -22,7 +22,6 @@ internal sealed partial class McpClientImpl : McpClient private readonly McpClientOptions _options; private readonly McpSessionHandler _sessionHandler; private readonly SemaphoreSlim _disposeLock = new(1, 1); - private readonly McpTaskCancellationTokenProvider? _taskCancellationTokenProvider; private readonly ConcurrentDictionary _toolCache = new(StringComparer.Ordinal); private ServerCapabilities? _serverCapabilities; @@ -49,12 +48,6 @@ internal McpClientImpl(ITransport transport, string endpointName, McpClientOptio _options = options; _logger = loggerFactory?.CreateLogger() ?? NullLogger.Instance; - // Only allocate the cancellation token provider if a task store is configured - if (options.TaskStore is not null) - { - _taskCancellationTokenProvider = new(); - } - var notificationHandlers = new NotificationHandlers(); var requestHandlers = new RequestHandlers(); @@ -83,89 +76,22 @@ private void RegisterHandlers(McpClientOptions options, NotificationHandlers not var samplingHandler = handlers.SamplingHandler; var rootsHandler = handlers.RootsHandler; var elicitationHandler = handlers.ElicitationHandler; - var taskStatusHandler = handlers.TaskStatusHandler; - var taskStore = options.TaskStore; if (notificationHandlersFromOptions is not null) { notificationHandlers.RegisterRange(notificationHandlersFromOptions); } - if (taskStatusHandler is not null) - { - notificationHandlers.Register( - NotificationMethods.TaskStatusNotification, - (notification, cancellationToken) => - { - if (JsonSerializer.Deserialize(notification.Params, McpJsonUtilities.JsonContext.Default.McpTaskStatusNotificationParams) is { } notificationParams) - { - var task = new McpTask - { - TaskId = notificationParams.TaskId, - Status = notificationParams.Status, - StatusMessage = notificationParams.StatusMessage, - CreatedAt = notificationParams.CreatedAt, - LastUpdatedAt = notificationParams.LastUpdatedAt, - TimeToLive = notificationParams.TimeToLive, - PollInterval = notificationParams.PollInterval - }; - return taskStatusHandler(task, cancellationToken); - } - - return default; - }); - } - if (samplingHandler is not null) { - // If task store is configured, wrap the handler to support task-augmented requests - if (taskStore is not null) - { - requestHandlers.Set( - RequestMethods.SamplingCreateMessage, - async (request, jsonRpcRequest, cancellationToken) => - { - // Check if this is a task-augmented request - if (request?.Task is { } taskMetadata) - { - // Create task in store and return immediately - return await ExecuteAsTaskAsync( - taskStore, - taskMetadata, - jsonRpcRequest, - async ct => - { - var result = await samplingHandler( - request, - request.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, - ct).ConfigureAwait(false); - return JsonSerializer.SerializeToElement(result, McpJsonUtilities.JsonContext.Default.CreateMessageResult); - }, - options.SendTaskStatusNotifications, - cancellationToken).ConfigureAwait(false); - } - - // Normal synchronous execution - serialize result to JsonElement - var samplingResult = await samplingHandler( - request, - request?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, - cancellationToken).ConfigureAwait(false); - return JsonSerializer.SerializeToElement(samplingResult, McpJsonUtilities.JsonContext.Default.CreateMessageResult); - }, - McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, - McpJsonUtilities.JsonContext.Default.JsonElement); // Return JsonElement to support both CreateMessageResult and CreateTaskResult - } - else - { - requestHandlers.Set( - RequestMethods.SamplingCreateMessage, - (request, _, cancellationToken) => samplingHandler( - request, - request?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, - cancellationToken), - McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, - McpJsonUtilities.JsonContext.Default.CreateMessageResult); - } + requestHandlers.Set( + RequestMethods.SamplingCreateMessage, + (request, _, cancellationToken) => samplingHandler( + request, + request?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, + cancellationToken), + McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, + McpJsonUtilities.JsonContext.Default.CreateMessageResult); _options.Capabilities ??= new(); _options.Capabilities.Sampling ??= new(); @@ -185,51 +111,15 @@ private void RegisterHandlers(McpClientOptions options, NotificationHandlers not if (elicitationHandler is not null) { - // If task store is configured, wrap the handler to support task-augmented requests - if (taskStore is not null) - { - requestHandlers.Set( - RequestMethods.ElicitationCreate, - async (request, jsonRpcRequest, cancellationToken) => - { - // Check if this is a task-augmented request - if (request?.Task is { } taskMetadata) - { - // Create task in store and return immediately - return await ExecuteAsTaskAsync( - taskStore, - taskMetadata, - jsonRpcRequest, - async ct => - { - var result = await elicitationHandler(request, ct).ConfigureAwait(false); - result = ElicitResult.WithDefaults(request, result); - return JsonSerializer.SerializeToElement(result, McpJsonUtilities.JsonContext.Default.ElicitResult); - }, - options.SendTaskStatusNotifications, - cancellationToken).ConfigureAwait(false); - } - - // Normal synchronous execution - serialize result to JsonElement - var elicitResult = await elicitationHandler(request, cancellationToken).ConfigureAwait(false); - elicitResult = ElicitResult.WithDefaults(request, elicitResult); - return JsonSerializer.SerializeToElement(elicitResult, McpJsonUtilities.JsonContext.Default.ElicitResult); - }, - McpJsonUtilities.JsonContext.Default.ElicitRequestParams, - McpJsonUtilities.JsonContext.Default.JsonElement); // Return JsonElement to support both ElicitResult and CreateTaskResult - } - else - { - requestHandlers.Set( - RequestMethods.ElicitationCreate, - async (request, _, cancellationToken) => - { - var result = await elicitationHandler(request, cancellationToken).ConfigureAwait(false); - return ElicitResult.WithDefaults(request, result); - }, - McpJsonUtilities.JsonContext.Default.ElicitRequestParams, - McpJsonUtilities.JsonContext.Default.ElicitResult); - } + requestHandlers.Set( + RequestMethods.ElicitationCreate, + async (request, _, cancellationToken) => + { + var result = await elicitationHandler(request, cancellationToken).ConfigureAwait(false); + return ElicitResult.WithDefaults(request, result); + }, + McpJsonUtilities.JsonContext.Default.ElicitRequestParams, + McpJsonUtilities.JsonContext.Default.ElicitResult); _options.Capabilities ??= new(); _options.Capabilities.Elicitation ??= new(); @@ -240,276 +130,6 @@ private void RegisterHandlers(McpClientOptions options, NotificationHandlers not _options.Capabilities.Elicitation.Form = new(); } } - - // Register task handlers if a task store is configured - if (taskStore is not null) - { - RegisterTaskHandlers(requestHandlers, taskStore); - } - } - - /// - /// Executes an operation as a task, creating the task immediately and running the operation asynchronously. - /// - private async ValueTask ExecuteAsTaskAsync( - IMcpTaskStore taskStore, - McpTaskMetadata taskMetadata, - JsonRpcRequest jsonRpcRequest, - Func> operation, - bool sendNotifications, - CancellationToken cancellationToken) - { - // Create the task in the store - var mcpTask = await taskStore.CreateTaskAsync( - taskMetadata, - jsonRpcRequest.Id, - jsonRpcRequest, - SessionId, - cancellationToken).ConfigureAwait(false); - - // Register the task for TTL-based cancellation - var taskCancellationToken = _taskCancellationTokenProvider!.RequestToken(mcpTask.TaskId, mcpTask.TimeToLive); - - // Execute the operation asynchronously in the background - _ = Task.Run(async () => - { - try - { - // Send notification if enabled - if (sendNotifications) - { - var workingTask = await taskStore.GetTaskAsync(mcpTask.TaskId, SessionId, CancellationToken.None).ConfigureAwait(false); - if (workingTask is not null) - { - _ = NotifyTaskStatusAsync(workingTask, CancellationToken.None); - } - } - - // Execute the operation with task-specific cancellation token - var result = await operation(taskCancellationToken).ConfigureAwait(false); - - // Store the result - var completedTask = await taskStore.StoreTaskResultAsync( - mcpTask.TaskId, - McpTaskStatus.Completed, - result, - SessionId, - CancellationToken.None).ConfigureAwait(false); - - // Send final notification if enabled - if (sendNotifications) - { - _ = NotifyTaskStatusAsync(completedTask, CancellationToken.None); - } - } - catch (OperationCanceledException) when (taskCancellationToken.IsCancellationRequested) - { - // Task was cancelled via TTL expiration or explicit cancellation. - // For TTL expiration, the task is deleted so no status update needed. - // For explicit cancellation, the cancel handler already updates the status. - } - catch (Exception ex) - { - // Store error result using a simple string message - try - { - var errorElement = JsonSerializer.SerializeToElement(ex.Message, McpJsonUtilities.JsonContext.Default.String); - await taskStore.StoreTaskResultAsync( - mcpTask.TaskId, - McpTaskStatus.Failed, - errorElement, - SessionId, - CancellationToken.None).ConfigureAwait(false); - - // Update task with error message - var failedTask = await taskStore.UpdateTaskStatusAsync( - mcpTask.TaskId, - McpTaskStatus.Failed, - ex.Message, - SessionId, - CancellationToken.None).ConfigureAwait(false); - - // Send failure notification if enabled - if (sendNotifications) - { - _ = NotifyTaskStatusAsync(failedTask, CancellationToken.None); - } - } - catch - { - // If we can't store the error result, there's not much we can do - } - } - finally - { - // Clean up task cancellation tracking - _taskCancellationTokenProvider!.Complete(mcpTask.TaskId); - } - }, CancellationToken.None); - - // Return the task result immediately - var createTaskResult = new CreateTaskResult { Task = mcpTask }; - return JsonSerializer.SerializeToElement(createTaskResult, McpJsonUtilities.JsonContext.Default.CreateTaskResult); - } - - /// - /// Sends a task status notification to the connected server. - /// - private Task NotifyTaskStatusAsync(McpTask task, CancellationToken cancellationToken) - { - var notificationParams = new McpTaskStatusNotificationParams - { - TaskId = task.TaskId, - Status = task.Status, - StatusMessage = task.StatusMessage, - CreatedAt = task.CreatedAt, - LastUpdatedAt = task.LastUpdatedAt, - TimeToLive = task.TimeToLive, - PollInterval = task.PollInterval - }; - - return this.SendNotificationAsync( - NotificationMethods.TaskStatusNotification, - notificationParams, - McpJsonUtilities.JsonContext.Default.McpTaskStatusNotificationParams, - cancellationToken); - } - - /// - /// Registers handlers for task-related requests from the server. - /// - private void RegisterTaskHandlers(RequestHandlers requestHandlers, IMcpTaskStore taskStore) - { - // tasks/get handler - Retrieve task status - requestHandlers.Set( - RequestMethods.TasksGet, - async (request, _, cancellationToken) => - { - if (request?.TaskId is not { } taskId) - { - throw new McpProtocolException("Missing required parameter 'taskId'", McpErrorCode.InvalidParams); - } - - var task = await taskStore.GetTaskAsync(taskId, SessionId, cancellationToken).ConfigureAwait(false); - if (task is null) - { - throw new McpProtocolException($"Task not found: '{taskId}'", McpErrorCode.InvalidParams); - } - - return new GetTaskResult - { - TaskId = task.TaskId, - Status = task.Status, - StatusMessage = task.StatusMessage, - CreatedAt = task.CreatedAt, - LastUpdatedAt = task.LastUpdatedAt, - TimeToLive = task.TimeToLive, - PollInterval = task.PollInterval - }; - }, - McpJsonUtilities.JsonContext.Default.GetTaskRequestParams, - McpJsonUtilities.JsonContext.Default.GetTaskResult); - - // tasks/result handler - Retrieve task result (blocking until terminal status) - requestHandlers.Set( - RequestMethods.TasksResult, - async (request, _, cancellationToken) => - { - if (request?.TaskId is not { } taskId) - { - throw new McpProtocolException("Missing required parameter 'taskId'", McpErrorCode.InvalidParams); - } - - // Poll until task reaches terminal status - while (true) - { - McpTask? task = await taskStore.GetTaskAsync(taskId, SessionId, cancellationToken).ConfigureAwait(false); - if (task is null) - { - throw new McpProtocolException($"Task not found: '{taskId}'", McpErrorCode.InvalidParams); - } - - // If terminal, break and retrieve result - if (task.Status is McpTaskStatus.Completed or McpTaskStatus.Failed or McpTaskStatus.Cancelled) - { - break; - } - - // Poll according to task's pollInterval (default 1 second) - var pollInterval = task.PollInterval ?? TimeSpan.FromSeconds(1); - await Task.Delay(pollInterval, cancellationToken).ConfigureAwait(false); - } - - // Retrieve the stored result - return await taskStore.GetTaskResultAsync(taskId, SessionId, cancellationToken).ConfigureAwait(false); - }, - McpJsonUtilities.JsonContext.Default.GetTaskPayloadRequestParams, - McpJsonUtilities.JsonContext.Default.JsonElement); - - // tasks/list handler - List tasks with pagination - requestHandlers.Set( - RequestMethods.TasksList, - async (request, _, cancellationToken) => - { - var cursor = request?.Cursor; - return await taskStore.ListTasksAsync(cursor, SessionId, cancellationToken).ConfigureAwait(false); - }, - McpJsonUtilities.JsonContext.Default.ListTasksRequestParams, - McpJsonUtilities.JsonContext.Default.ListTasksResult); - - // tasks/cancel handler - Cancel a task - requestHandlers.Set( - RequestMethods.TasksCancel, - async (request, _, cancellationToken) => - { - if (request?.TaskId is not { } taskId) - { - throw new McpProtocolException("Missing required parameter 'taskId'", McpErrorCode.InvalidParams); - } - - // Signal cancellation if task is still running - _taskCancellationTokenProvider!.Cancel(taskId); - - var task = await taskStore.CancelTaskAsync(taskId, SessionId, cancellationToken).ConfigureAwait(false); - if (task is null) - { - throw new McpProtocolException($"Task not found: '{taskId}'", McpErrorCode.InvalidParams); - } - - return new CancelMcpTaskResult - { - TaskId = task.TaskId, - Status = task.Status, - StatusMessage = task.StatusMessage, - CreatedAt = task.CreatedAt, - LastUpdatedAt = task.LastUpdatedAt, - TimeToLive = task.TimeToLive, - PollInterval = task.PollInterval - }; - }, - McpJsonUtilities.JsonContext.Default.CancelMcpTaskRequestParams, - McpJsonUtilities.JsonContext.Default.CancelMcpTaskResult); - - // Advertise task capabilities - _options.Capabilities ??= new(); - var tasksCapability = _options.Capabilities.Tasks ??= new McpTasksCapability(); - tasksCapability.List ??= new ListMcpTasksCapability(); - tasksCapability.Cancel ??= new CancelMcpTasksCapability(); - var requestsCapability = tasksCapability.Requests ??= new RequestMcpTasksCapability(); - - // Only advertise sampling tasks if sampling handler is present - if (_options.Handlers.SamplingHandler is not null) - { - var samplingCapability = requestsCapability.Sampling ??= new SamplingMcpTasksCapability(); - samplingCapability.CreateMessage ??= new CreateMessageMcpTasksCapability(); - } - - // Only advertise elicitation tasks if elicitation handler is present - if (_options.Handlers.ElicitationHandler is not null) - { - var elicitationCapability = requestsCapability.Elicitation ??= new ElicitationMcpTasksCapability(); - elicitationCapability.Create ??= new CreateElicitationMcpTasksCapability(); - } } /// @@ -637,6 +257,66 @@ internal void ResumeSession(ResumeClientSessionOptions resumeOptions) LogClientSessionResumed(_endpointName); } + /// + private protected override async ValueTask> ResolveInputRequestsAsync( + IDictionary inputRequests, CancellationToken cancellationToken) + { + var responses = new Dictionary(inputRequests.Count); + + foreach (var kvp in inputRequests) + { + var response = await ResolveInputRequestAsync(kvp.Value, cancellationToken).ConfigureAwait(false); + responses[kvp.Key] = response; + } + + return responses; + } + + private async Task ResolveInputRequestAsync(JsonElement requestElement, CancellationToken cancellationToken) + { + using var doc = JsonDocument.Parse(requestElement.GetRawText()); + var root = doc.RootElement; + + var method = root.GetProperty("method").GetString() + ?? throw new McpException("Input request is missing 'method' property."); + + JsonElement paramsElement = root.TryGetProperty("params", out var p) ? p : default; + + switch (method) + { + case RequestMethods.SamplingCreateMessage: + if (_options.Handlers.SamplingHandler is { } samplingHandler) + { + var samplingParams = JsonSerializer.Deserialize(paramsElement, McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams) + ?? throw new McpException("Failed to deserialize sampling parameters from input request."); + var result = await samplingHandler( + samplingParams, + samplingParams.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, + cancellationToken).ConfigureAwait(false); + return JsonSerializer.SerializeToElement(result, McpJsonUtilities.JsonContext.Default.CreateMessageResult); + } + + throw new InvalidOperationException( + $"Server sent a sampling input request, but no {nameof(McpClientHandlers.SamplingHandler)} is registered."); + + case RequestMethods.ElicitationCreate: + if (_options.Handlers.ElicitationHandler is { } elicitationHandler) + { + var elicitParams = JsonSerializer.Deserialize(paramsElement, McpJsonUtilities.JsonContext.Default.ElicitRequestParams) + ?? throw new McpException("Failed to deserialize elicitation parameters from input request."); + var result = await elicitationHandler(elicitParams, cancellationToken).ConfigureAwait(false); + result = ElicitResult.WithDefaults(elicitParams, result); + return JsonSerializer.SerializeToElement(result, McpJsonUtilities.JsonContext.Default.ElicitResult); + } + + throw new InvalidOperationException( + $"Server sent an elicitation input request, but no {nameof(McpClientHandlers.ElicitationHandler)} is registered."); + + default: + throw new NotSupportedException($"Unsupported input request method: '{method}'."); + } + } + /// public override Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) { @@ -676,7 +356,6 @@ public override async ValueTask DisposeAsync() _disposed = true; - _taskCancellationTokenProvider?.Dispose(); await _sessionHandler.DisposeAsync().ConfigureAwait(false); await _transport.DisposeAsync().ConfigureAwait(false); diff --git a/src/ModelContextProtocol.Core/Client/McpClientOptions.cs b/src/ModelContextProtocol.Core/Client/McpClientOptions.cs index 6d91f5b03..8a2364ca4 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientOptions.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientOptions.cs @@ -79,36 +79,4 @@ public McpClientHandlers Handlers field = value; } } - - /// - /// Gets or sets the task store for managing client-side tasks. - /// - /// - /// - /// When a task store is configured, the client will support task-augmented requests from the server. - /// This allows the server to request sampling or elicitation as tasks, which the client executes - /// asynchronously and allows the server to poll for status and results. - /// - /// - /// If not set, task-augmented requests will not be supported, and the client will not advertise - /// task capabilities to the server. - /// - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public IMcpTaskStore? TaskStore { get; set; } - - /// - /// Gets or sets a value indicating whether the client should send task status notifications to the server. - /// - /// - /// to send task status notifications; otherwise. - /// The default is . - /// - /// - /// When enabled and a is configured, the client will send optional - /// notifications/tasks/status notifications to inform the server of task state changes. - /// Servers MUST NOT rely on receiving these notifications and should continue polling via tasks/get. - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public bool SendTaskStatusNotifications { get; set; } = true; } diff --git a/src/ModelContextProtocol.Core/McpJsonUtilities.cs b/src/ModelContextProtocol.Core/McpJsonUtilities.cs index abb6d29df..eda78c5aa 100644 --- a/src/ModelContextProtocol.Core/McpJsonUtilities.cs +++ b/src/ModelContextProtocol.Core/McpJsonUtilities.cs @@ -108,12 +108,16 @@ internal static bool IsValidMcpToolSchema(JsonElement element) [JsonSerializable(typeof(ResourceUpdatedNotificationParams))] [JsonSerializable(typeof(RootsListChangedNotificationParams))] [JsonSerializable(typeof(ToolListChangedNotificationParams))] - [JsonSerializable(typeof(McpTaskStatusNotificationParams))] + [JsonSerializable(typeof(TaskStatusNotificationParams))] + [JsonSerializable(typeof(WorkingTaskNotificationParams))] + [JsonSerializable(typeof(CompletedTaskNotificationParams))] + [JsonSerializable(typeof(FailedTaskNotificationParams))] + [JsonSerializable(typeof(CancelledTaskNotificationParams))] + [JsonSerializable(typeof(InputRequiredTaskNotificationParams))] // MCP Request Params / Results [JsonSerializable(typeof(CallToolRequestParams))] [JsonSerializable(typeof(CallToolResult))] - [JsonSerializable(typeof(CreateTaskResult))] [JsonSerializable(typeof(CompleteRequestParams))] [JsonSerializable(typeof(CompleteResult))] [JsonSerializable(typeof(CreateMessageRequestParams))] @@ -143,22 +147,18 @@ internal static bool IsValidMcpToolSchema(JsonElement element) [JsonSerializable(typeof(SetLevelRequestParams))] [JsonSerializable(typeof(SubscribeRequestParams))] [JsonSerializable(typeof(UnsubscribeRequestParams))] - - // MCP Task Request Params / Results - [JsonSerializable(typeof(McpTask))] - [JsonSerializable(typeof(McpTaskStatus))] - [JsonSerializable(typeof(McpTaskMetadata))] [JsonSerializable(typeof(GetTaskRequestParams))] [JsonSerializable(typeof(GetTaskResult))] - [JsonSerializable(typeof(GetTaskPayloadRequestParams))] - [JsonSerializable(typeof(ListTasksRequestParams))] - [JsonSerializable(typeof(ListTasksResult))] - [JsonSerializable(typeof(CancelMcpTaskRequestParams))] - [JsonSerializable(typeof(CancelMcpTaskResult))] - [JsonSerializable(typeof(McpTasksCapability))] - [JsonSerializable(typeof(RequestMcpTasksCapability))] - [JsonSerializable(typeof(ToolExecution))] - [JsonSerializable(typeof(ToolTaskSupport))] + [JsonSerializable(typeof(WorkingTaskResult))] + [JsonSerializable(typeof(CompletedTaskResult))] + [JsonSerializable(typeof(FailedTaskResult))] + [JsonSerializable(typeof(CancelledTaskResult))] + [JsonSerializable(typeof(InputRequiredTaskResult))] + [JsonSerializable(typeof(UpdateTaskRequestParams))] + [JsonSerializable(typeof(UpdateTaskResult))] + [JsonSerializable(typeof(CancelTaskRequestParams))] + [JsonSerializable(typeof(CancelTaskResult))] + [JsonSerializable(typeof(CreateTaskResult))] // MCP Content [JsonSerializable(typeof(ContentBlock))] @@ -177,9 +177,9 @@ internal static bool IsValidMcpToolSchema(JsonElement element) [JsonSerializable(typeof(TextResourceContents))] // Other MCP Types + [JsonSerializable(typeof(IDictionary))] [JsonSerializable(typeof(IReadOnlyDictionary))] [JsonSerializable(typeof(ProgressToken))] - [JsonSerializable(typeof(JsonElement))] [JsonSerializable(typeof(ProtectedResourceMetadata))] [JsonSerializable(typeof(AuthorizationServerMetadata))] diff --git a/src/ModelContextProtocol.Core/McpTaskCancellationTokenProvider.cs b/src/ModelContextProtocol.Core/McpTaskCancellationTokenProvider.cs deleted file mode 100644 index 6ecfc4f4a..000000000 --- a/src/ModelContextProtocol.Core/McpTaskCancellationTokenProvider.cs +++ /dev/null @@ -1,127 +0,0 @@ -using System.Collections.Concurrent; - -namespace ModelContextProtocol; - -/// -/// Provides cancellation tokens for running MCP tasks, enabling TTL-based -/// automatic cancellation and explicit task cancellation. -/// -/// -/// -/// This class provides lifecycle management for instances -/// associated with running tasks. Each task gets its own CTS that can be: -/// -/// -/// Automatically cancelled when the task's TTL expires -/// Explicitly cancelled via the method -/// Cleaned up when the task completes via -/// -/// -/// Both McpClient and McpServer use this class to manage task cancellation -/// independently of request cancellation tokens. -/// -/// -internal sealed class McpTaskCancellationTokenProvider : IDisposable -{ - private readonly ConcurrentDictionary _runningTasks = new(); - private bool _disposed; - - /// - /// Registers a new task and returns a cancellation token for use during execution. - /// - /// The unique identifier of the task. - /// - /// Optional TTL duration. If specified, the returned token will be automatically - /// cancelled when the TTL expires. - /// - /// - /// A that will be cancelled when the TTL expires, - /// when is called, or when this provider is disposed. - /// - /// The provider has been disposed. - /// A task with the same ID is already registered. - public CancellationToken RequestToken(string taskId, TimeSpan? timeToLive) - { - if (_disposed) - { - throw new ObjectDisposedException(nameof(McpTaskCancellationTokenProvider)); - } - - Throw.IfNullOrWhiteSpace(taskId); - CancellationTokenSource cts = new(); - - if (timeToLive is { } ttl) - { - cts.CancelAfter(ttl); - } - - if (!_runningTasks.TryAdd(taskId, cts)) - { - cts.Dispose(); - throw new InvalidOperationException($"Task '{taskId}' is already registered."); - } - - return cts.Token; - } - - /// - /// Attempts to cancel a running task. - /// - /// The unique identifier of the task to cancel. - /// - /// This method signals cancellation but does not remove the task from tracking. - /// The task executor should call when it observes - /// the cancellation and finishes cleanup. - /// - public void Cancel(string taskId) - { - if (_runningTasks.TryGetValue(taskId, out var cts)) - { - cts.Cancel(); - } - } - - /// - /// Marks a task as complete and releases its associated resources. - /// - /// The unique identifier of the task that has completed. - /// - /// This method should be called from a finally block in the task execution - /// to ensure proper cleanup regardless of success, failure, or cancellation. - /// - public void Complete(string taskId) - { - if (_runningTasks.TryRemove(taskId, out var cts)) - { - cts.Dispose(); - } - } - - /// - /// Cancels all running tasks and releases all resources. - /// - public void Dispose() - { - if (_disposed) - { - return; - } - - _disposed = true; - - foreach (var kvp in _runningTasks) - { - try - { - kvp.Value.Cancel(); - kvp.Value.Dispose(); - } - catch - { - // Best effort cleanup - } - } - - _runningTasks.Clear(); - } -} diff --git a/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj b/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj index b6423b0c8..23045b317 100644 --- a/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj +++ b/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj @@ -7,7 +7,7 @@ ModelContextProtocol.Core Core .NET SDK for the Model Context Protocol (MCP) README.md - + $(NoWarn);MCPEXP001 @@ -35,6 +35,7 @@ + diff --git a/src/ModelContextProtocol.Core/Protocol/CallToolRequestParams.cs b/src/ModelContextProtocol.Core/Protocol/CallToolRequestParams.cs index 8267cd06f..d311c6b4f 100644 --- a/src/ModelContextProtocol.Core/Protocol/CallToolRequestParams.cs +++ b/src/ModelContextProtocol.Core/Protocol/CallToolRequestParams.cs @@ -26,24 +26,4 @@ public sealed class CallToolRequestParams : RequestParams /// [JsonPropertyName("arguments")] public IDictionary? Arguments { get; set; } - - /// - /// Gets or sets optional task metadata to augment this request with task execution. - /// - /// - /// When present, indicates that the requestor wants this operation executed as a task. - /// The receiver must support task augmentation for this specific request type. - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - [JsonIgnore] - public McpTaskMetadata? Task - { - get => TaskCore; - set => TaskCore = value; - } - - // See ExperimentalInternalPropertyTests.cs before modifying this property. - [JsonInclude] - [JsonPropertyName("task")] - internal McpTaskMetadata? TaskCore { get; set; } } diff --git a/src/ModelContextProtocol.Core/Protocol/CallToolResult.cs b/src/ModelContextProtocol.Core/Protocol/CallToolResult.cs index 35dba5b6e..b2fdb3d05 100644 --- a/src/ModelContextProtocol.Core/Protocol/CallToolResult.cs +++ b/src/ModelContextProtocol.Core/Protocol/CallToolResult.cs @@ -64,25 +64,4 @@ public sealed class CallToolResult : Result /// [JsonPropertyName("isError")] public bool? IsError { get; set; } - - /// - /// Gets or sets the task data for the newly created task. - /// - /// - /// This property is populated only for task-augmented tool calls. When present, the other properties - /// (, , ) may not be populated. - /// The actual tool result can be retrieved later via tasks/result. - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - [JsonIgnore] - public McpTask? Task - { - get => TaskCore; - set => TaskCore = value; - } - - // See ExperimentalInternalPropertyTests.cs before modifying this property. - [JsonInclude] - [JsonPropertyName("task")] - internal McpTask? TaskCore { get; set; } } diff --git a/src/ModelContextProtocol.Core/Protocol/CancelMcpTaskRequestParams.cs b/src/ModelContextProtocol.Core/Protocol/CancelMcpTaskRequestParams.cs deleted file mode 100644 index c4fb540b2..000000000 --- a/src/ModelContextProtocol.Core/Protocol/CancelMcpTaskRequestParams.cs +++ /dev/null @@ -1,84 +0,0 @@ -using System.Diagnostics.CodeAnalysis; -using System.Text.Json.Serialization; - -namespace ModelContextProtocol.Protocol; - -/// -/// Represents the parameters for a tasks/cancel request to explicitly cancel a task. -/// -/// -/// -/// Receivers must reject cancellation requests for tasks already in a terminal status -/// (, , or -/// ) with error code -32602 (Invalid params). -/// -/// -/// Upon receiving a valid cancellation request, receivers should attempt to stop the task -/// execution and must transition the task to status -/// before sending the response. -/// -/// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] -public sealed class CancelMcpTaskRequestParams : RequestParams -{ - /// - /// Gets or sets the unique identifier of the task to cancel. - /// - [JsonPropertyName("taskId")] - public required string TaskId { get; set; } -} - -/// -/// Represents the result of a tasks/cancel request. -/// -/// -/// The result contains the updated task state after cancellation. The task will be in -/// status if the cancellation was successful. -/// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] -public sealed class CancelMcpTaskResult : Result -{ - /// - /// Gets or sets the task ID. - /// - [JsonPropertyName("taskId")] - public required string TaskId { get; set; } - - /// - /// Gets or sets the current status of the task (should be ). - /// - [JsonPropertyName("status")] - public required McpTaskStatus Status { get; set; } - - /// - /// Gets or sets an optional message describing the cancellation. - /// - [JsonPropertyName("statusMessage")] - public string? StatusMessage { get; set; } - - /// - /// Gets or sets the ISO 8601 timestamp when the task was created. - /// - [JsonPropertyName("createdAt")] - public required DateTimeOffset CreatedAt { get; set; } - - /// - /// Gets or sets the ISO 8601 timestamp when the task status was last updated. - /// - [JsonPropertyName("lastUpdatedAt")] - public required DateTimeOffset LastUpdatedAt { get; set; } - - /// - /// Gets or sets the time to live (retention duration) from creation before the task may be deleted. - /// - [JsonPropertyName("ttl")] - [JsonConverter(typeof(TimeSpanMillisecondsConverter))] - public TimeSpan? TimeToLive { get; set; } - - /// - /// Gets or sets the suggested time between status checks. - /// - [JsonPropertyName("pollInterval")] - [JsonConverter(typeof(TimeSpanMillisecondsConverter))] - public TimeSpan? PollInterval { get; set; } -} diff --git a/src/ModelContextProtocol.Core/Protocol/CancelTaskRequestParams.cs b/src/ModelContextProtocol.Core/Protocol/CancelTaskRequestParams.cs new file mode 100644 index 000000000..ee458d064 --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/CancelTaskRequestParams.cs @@ -0,0 +1,29 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol; + +/// +/// Represents the parameters for a tasks/cancel request to signal intent to cancel an in-progress task. +/// +/// +/// +/// Cancellation is cooperative: the request signals intent, and the server decides whether and when to honor it. +/// A server is not obligated to actually stop the work; it is only obligated to acknowledge the request. +/// Eventual transition to is not guaranteed. +/// +/// +/// The notifications/cancelled notification must not be used for task cancellation. +/// +/// +/// See the SEP-2663 +/// specification for details. +/// +/// +public sealed class CancelTaskRequestParams : RequestParams +{ + /// + /// Gets or sets the identifier of the task to cancel. + /// + [JsonPropertyName("taskId")] + public required string TaskId { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Protocol/CancelTaskResult.cs b/src/ModelContextProtocol.Core/Protocol/CancelTaskResult.cs new file mode 100644 index 000000000..4d066862b --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/CancelTaskResult.cs @@ -0,0 +1,22 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol; + +/// +/// Represents the result of a tasks/cancel request. This is an empty acknowledgement. +/// +/// +/// +/// The server acknowledges the request with an empty result. Cancellation processing is +/// eventually consistent — the task's observable status may remain +/// after the ack, and may ultimately reach a terminal status other than +/// if the work finished before cancellation could take effect. +/// +/// +/// See the SEP-2663 +/// specification for details. +/// +/// +public sealed class CancelTaskResult : Result +{ +} diff --git a/src/ModelContextProtocol.Core/Protocol/ClientCapabilities.cs b/src/ModelContextProtocol.Core/Protocol/ClientCapabilities.cs index 77b2bef9f..f41f50fd8 100644 --- a/src/ModelContextProtocol.Core/Protocol/ClientCapabilities.cs +++ b/src/ModelContextProtocol.Core/Protocol/ClientCapabilities.cs @@ -68,32 +68,6 @@ public sealed class ClientCapabilities [JsonPropertyName("elicitation")] public ElicitationCapability? Elicitation { get; set; } - /// - /// Gets or sets the client's tasks capability for supporting task-augmented requests. - /// - /// - /// - /// The tasks capability enables servers to augment their requests with tasks for long-running - /// operations. When present, servers can request that certain operations (like sampling or - /// elicitation) execute asynchronously, with the ability to poll for status and retrieve results later. - /// - /// - /// See for details on configuring which operations support tasks. - /// - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - [JsonIgnore] - public McpTasksCapability? Tasks - { - get => TasksCore; - set => TasksCore = value; - } - - // See ExperimentalInternalPropertyTests.cs before modifying this property. - [JsonInclude] - [JsonPropertyName("tasks")] - internal McpTasksCapability? TasksCore { get; set; } - /// /// Gets or sets optional MCP extensions that the client supports. /// diff --git a/src/ModelContextProtocol.Core/Protocol/CreateMessageRequestParams.cs b/src/ModelContextProtocol.Core/Protocol/CreateMessageRequestParams.cs index ef5e57d2c..bb27d70fd 100644 --- a/src/ModelContextProtocol.Core/Protocol/CreateMessageRequestParams.cs +++ b/src/ModelContextProtocol.Core/Protocol/CreateMessageRequestParams.cs @@ -153,24 +153,4 @@ public sealed class CreateMessageRequestParams : RequestParams /// [JsonPropertyName("toolChoice")] public ToolChoice? ToolChoice { get; set; } - - /// - /// Gets or sets optional task metadata to augment this request with task execution. - /// - /// - /// When present, indicates that the requestor wants this operation executed as a task. - /// The receiver must support task augmentation for this specific request type. - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - [JsonIgnore] - public McpTaskMetadata? Task - { - get => TaskCore; - set => TaskCore = value; - } - - // See ExperimentalInternalPropertyTests.cs before modifying this property. - [JsonInclude] - [JsonPropertyName("task")] - internal McpTaskMetadata? TaskCore { get; set; } } diff --git a/src/ModelContextProtocol.Core/Protocol/CreateTaskResult.cs b/src/ModelContextProtocol.Core/Protocol/CreateTaskResult.cs index 166d05e49..2e5bc0c41 100644 --- a/src/ModelContextProtocol.Core/Protocol/CreateTaskResult.cs +++ b/src/ModelContextProtocol.Core/Protocol/CreateTaskResult.cs @@ -1,28 +1,68 @@ -using System.Diagnostics.CodeAnalysis; +using System.Text.Json; +using System.Text.Json.Nodes; using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol; /// -/// Represents the response to a task-augmented request. +/// Represents the result returned by a server when it creates a task in lieu of a standard result. /// /// /// -/// When a client sends a request with a task parameter, the server immediately returns -/// a containing the created task information instead of the -/// normal result type. The actual result can be retrieved later via tasks/result. +/// A server returns instead of the standard result shape (e.g., ) +/// to indicate that the request will be processed asynchronously. The client then uses +/// for subsequent tasks/get, tasks/update, and tasks/cancel calls. /// /// -/// This type is returned for any task-augmented request including tools/call, -/// sampling/createMessage, and elicitation/create. +/// A server must not return to a client that did not include the +/// io.modelcontextprotocol/tasks extension capability on its request. +/// +/// +/// See the SEP-2663 +/// specification for details. /// /// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] public sealed class CreateTaskResult : Result { /// - /// Gets or sets the task data for the newly created task. + /// Gets or sets the stable identifier for this task. + /// + [JsonPropertyName("taskId")] + public required string TaskId { get; set; } + + /// + /// Gets or sets the current task status. + /// + [JsonPropertyName("status")] + public required McpTaskStatus Status { get; set; } + + /// + /// Gets or sets an optional message describing the current task state. + /// + [JsonPropertyName("statusMessage")] + public string? StatusMessage { get; set; } + + /// + /// Gets or sets the ISO 8601 timestamp when the task was created. + /// + [JsonPropertyName("createdAt")] + public required DateTimeOffset CreatedAt { get; set; } + + /// + /// Gets or sets the ISO 8601 timestamp when the task was last updated. + /// + [JsonPropertyName("lastUpdatedAt")] + public required DateTimeOffset LastUpdatedAt { get; set; } + + /// + /// Gets or sets the time-to-live duration from creation in milliseconds, or for unlimited. + /// + [JsonPropertyName("ttlMs")] + public long? TtlMs { get; set; } + + /// + /// Gets or sets the suggested polling interval in milliseconds. /// - [JsonPropertyName("task")] - public McpTask Task { get; set; } = null!; + [JsonPropertyName("pollIntervalMs")] + public long? PollIntervalMs { get; set; } } diff --git a/src/ModelContextProtocol.Core/Protocol/ElicitRequestParams.cs b/src/ModelContextProtocol.Core/Protocol/ElicitRequestParams.cs index 39a5bd358..9dc1ac903 100644 --- a/src/ModelContextProtocol.Core/Protocol/ElicitRequestParams.cs +++ b/src/ModelContextProtocol.Core/Protocol/ElicitRequestParams.cs @@ -92,26 +92,6 @@ public string Mode [JsonPropertyName("requestedSchema")] public RequestSchema? RequestedSchema { get; set; } - /// - /// Gets or sets optional task metadata to augment this request with task execution. - /// - /// - /// When present, indicates that the requestor wants this operation executed as a task. - /// The receiver must support task augmentation for this specific request type. - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - [JsonIgnore] - public McpTaskMetadata? Task - { - get => TaskCore; - set => TaskCore = value; - } - - // See ExperimentalInternalPropertyTests.cs before modifying this property. - [JsonInclude] - [JsonPropertyName("task")] - internal McpTaskMetadata? TaskCore { get; set; } - /// Represents a request schema used in a form mode elicitation request. public sealed class RequestSchema { diff --git a/src/ModelContextProtocol.Core/Protocol/GetTaskPayloadRequestParams.cs b/src/ModelContextProtocol.Core/Protocol/GetTaskPayloadRequestParams.cs deleted file mode 100644 index d64a8b1f9..000000000 --- a/src/ModelContextProtocol.Core/Protocol/GetTaskPayloadRequestParams.cs +++ /dev/null @@ -1,27 +0,0 @@ -using System.Diagnostics.CodeAnalysis; -using System.Text.Json.Serialization; - -namespace ModelContextProtocol.Protocol; - -/// -/// Represents the parameters for a tasks/result request to retrieve the result of a completed task. -/// -/// -/// -/// This request blocks until the task reaches a terminal status (, -/// , or ). -/// -/// -/// The result structure matches the original request type (e.g., for tools/call). -/// This is distinct from the initial response, which contains only task data. -/// -/// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] -public sealed class GetTaskPayloadRequestParams : RequestParams -{ - /// - /// Gets or sets the unique identifier of the task whose result to retrieve. - /// - [JsonPropertyName("taskId")] - public required string TaskId { get; set; } -} diff --git a/src/ModelContextProtocol.Core/Protocol/GetTaskRequestParams.cs b/src/ModelContextProtocol.Core/Protocol/GetTaskRequestParams.cs index a8aaaea93..52b82d902 100644 --- a/src/ModelContextProtocol.Core/Protocol/GetTaskRequestParams.cs +++ b/src/ModelContextProtocol.Core/Protocol/GetTaskRequestParams.cs @@ -1,77 +1,26 @@ -using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol; /// -/// Represents the parameters for a tasks/get request to retrieve task status. +/// Represents the parameters for a tasks/get request to poll for task completion. /// /// -/// Requestors poll for task completion by sending tasks/get requests. They should -/// respect the provided in responses when determining -/// polling frequency. +/// +/// Clients poll for task completion by sending tasks/get requests. +/// Clients should respect the provided in responses +/// when determining polling frequency. +/// +/// +/// See the SEP-2663 +/// specification for details. +/// /// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] public sealed class GetTaskRequestParams : RequestParams { /// - /// Gets or sets the unique identifier of the task to retrieve. + /// Gets or sets the identifier of the task to query. /// [JsonPropertyName("taskId")] public required string TaskId { get; set; } } - -/// -/// Represents the result of a tasks/get request. -/// -/// -/// The result contains the current state of the task, including its status, timestamps, -/// and any status message. -/// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] -public sealed class GetTaskResult : Result -{ - /// - /// Gets or sets the task ID. - /// - [JsonPropertyName("taskId")] - public required string TaskId { get; set; } - - /// - /// Gets or sets the current status of the task. - /// - [JsonPropertyName("status")] - public required McpTaskStatus Status { get; set; } - - /// - /// Gets or sets an optional human-readable message describing the current state. - /// - [JsonPropertyName("statusMessage")] - public string? StatusMessage { get; set; } - - /// - /// Gets or sets the ISO 8601 timestamp when the task was created. - /// - [JsonPropertyName("createdAt")] - public required DateTimeOffset CreatedAt { get; set; } - - /// - /// Gets or sets the ISO 8601 timestamp when the task status was last updated. - /// - [JsonPropertyName("lastUpdatedAt")] - public required DateTimeOffset LastUpdatedAt { get; set; } - - /// - /// Gets or sets the time to live (retention duration) from creation before the task may be deleted. - /// - [JsonPropertyName("ttl")] - [JsonConverter(typeof(TimeSpanMillisecondsConverter))] - public TimeSpan? TimeToLive { get; set; } - - /// - /// Gets or sets the suggested time between status checks. - /// - [JsonPropertyName("pollInterval")] - [JsonConverter(typeof(TimeSpanMillisecondsConverter))] - public TimeSpan? PollInterval { get; set; } -} diff --git a/src/ModelContextProtocol.Core/Protocol/GetTaskResult.cs b/src/ModelContextProtocol.Core/Protocol/GetTaskResult.cs new file mode 100644 index 000000000..874551636 --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/GetTaskResult.cs @@ -0,0 +1,419 @@ +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol; + +/// +/// Represents the result of a tasks/get request, containing the full task state. +/// +/// +/// +/// This is the abstract base for status-specific task results. The concrete type returned depends on the +/// task's current : +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// See the SEP-2663 +/// specification for details. +/// +/// +[JsonConverter(typeof(Converter))] +public abstract class GetTaskResult : Result +{ + /// Prevent external derivations. + private protected GetTaskResult() + { + } + + /// + /// Gets or sets the stable identifier for this task. + /// + [JsonPropertyName("taskId")] + public required string TaskId { get; set; } + + /// + /// Gets or sets the current task status. + /// + [JsonPropertyName("status")] + public abstract McpTaskStatus Status { get; } + + /// + /// Gets or sets an optional message describing the current task state. + /// + [JsonPropertyName("statusMessage")] + public string? StatusMessage { get; set; } + + /// + /// Gets or sets the ISO 8601 timestamp when the task was created. + /// + [JsonPropertyName("createdAt")] + public required DateTimeOffset CreatedAt { get; set; } + + /// + /// Gets or sets the ISO 8601 timestamp when the task was last updated. + /// + [JsonPropertyName("lastUpdatedAt")] + public required DateTimeOffset LastUpdatedAt { get; set; } + + /// + /// Gets or sets the time-to-live duration from creation in milliseconds, or for unlimited. + /// + [JsonPropertyName("ttlMs")] + public long? TtlMs { get; set; } + + /// + /// Gets or sets the suggested polling interval in milliseconds. + /// + [JsonPropertyName("pollIntervalMs")] + public long? PollIntervalMs { get; set; } + + /// + /// JSON converter that deserializes to the appropriate concrete subtype + /// based on the status discriminator field. + /// + internal sealed class Converter : JsonConverter + { + public override GetTaskResult? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType != JsonTokenType.StartObject) + { + throw new JsonException("Expected StartObject token for GetTaskResult."); + } + + string? taskId = null; + string? statusString = null; + string? statusMessage = null; + DateTimeOffset? createdAt = null; + DateTimeOffset? lastUpdatedAt = null; + long? ttlMs = null; + long? pollIntervalMs = null; + string? resultType = null; + JsonObject? meta = null; + JsonElement? result = null; + JsonElement? error = null; + IDictionary? inputRequests = null; + + while (reader.Read()) + { + if (reader.TokenType == JsonTokenType.EndObject) + { + break; + } + + if (reader.TokenType != JsonTokenType.PropertyName) + { + throw new JsonException("Expected property name."); + } + + string propertyName = reader.GetString()!; + reader.Read(); + + switch (propertyName) + { + case "taskId": + taskId = reader.GetString(); + break; + case "status": + statusString = reader.GetString(); + break; + case "statusMessage": + statusMessage = reader.GetString(); + break; + case "createdAt": + createdAt = reader.GetDateTimeOffset(); + break; + case "lastUpdatedAt": + lastUpdatedAt = reader.GetDateTimeOffset(); + break; + case "ttlMs": + ttlMs = reader.GetInt64(); + break; + case "pollIntervalMs": + pollIntervalMs = reader.GetInt64(); + break; + case "resultType": + resultType = reader.GetString(); + break; + case "_meta": + meta = JsonSerializer.Deserialize(ref reader, options.GetTypeInfo()); + break; + case "result": + result = JsonElement.ParseValue(ref reader); + break; + case "error": + error = JsonElement.ParseValue(ref reader); + break; + case "inputRequests": + inputRequests = JsonSerializer.Deserialize(ref reader, options.GetTypeInfo>()); + break; + default: + reader.Skip(); + break; + } + } + + if (taskId is null) + { + throw new JsonException("Missing required 'taskId' property on GetTaskResult."); + } + + if (statusString is null) + { + throw new JsonException("Missing required 'status' property on GetTaskResult."); + } + + if (createdAt is null) + { + throw new JsonException("Missing required 'createdAt' property on GetTaskResult."); + } + + if (lastUpdatedAt is null) + { + throw new JsonException("Missing required 'lastUpdatedAt' property on GetTaskResult."); + } + + GetTaskResult taskResult = statusString switch + { + "working" => new WorkingTaskResult + { + TaskId = taskId, + CreatedAt = createdAt.Value, + LastUpdatedAt = lastUpdatedAt.Value, + }, + "completed" => result is not null + ? new CompletedTaskResult + { + TaskId = taskId, + CreatedAt = createdAt.Value, + LastUpdatedAt = lastUpdatedAt.Value, + TaskResult = result.Value, + } + : throw new JsonException("Completed task is missing required 'result' property."), + "failed" => error is not null + ? new FailedTaskResult + { + TaskId = taskId, + CreatedAt = createdAt.Value, + LastUpdatedAt = lastUpdatedAt.Value, + Error = error.Value, + } + : throw new JsonException("Failed task is missing required 'error' property."), + "cancelled" => new CancelledTaskResult + { + TaskId = taskId, + CreatedAt = createdAt.Value, + LastUpdatedAt = lastUpdatedAt.Value, + }, + "input_required" => inputRequests is not null + ? new InputRequiredTaskResult + { + TaskId = taskId, + CreatedAt = createdAt.Value, + LastUpdatedAt = lastUpdatedAt.Value, + InputRequests = inputRequests, + } + : throw new JsonException("Input-required task is missing required 'inputRequests' property."), + _ => throw new JsonException($"Unknown task status: '{statusString}'.") + }; + + taskResult.StatusMessage = statusMessage; + taskResult.TtlMs = ttlMs; + taskResult.PollIntervalMs = pollIntervalMs; + taskResult.ResultType = resultType; + taskResult.Meta = meta; + + return taskResult; + } + + public override void Write(Utf8JsonWriter writer, GetTaskResult value, JsonSerializerOptions options) + { + writer.WriteStartObject(); + + if (value.ResultType is not null) + { + writer.WriteString("resultType", value.ResultType); + } + + if (value.Meta is not null) + { + writer.WritePropertyName("_meta"); + JsonSerializer.Serialize(writer, value.Meta, options.GetTypeInfo()); + } + + writer.WriteString("taskId", value.TaskId); + writer.WriteString("status", value.Status switch + { + McpTaskStatus.Working => "working", + McpTaskStatus.Completed => "completed", + McpTaskStatus.Failed => "failed", + McpTaskStatus.Cancelled => "cancelled", + McpTaskStatus.InputRequired => "input_required", + _ => throw new JsonException($"Unknown McpTaskStatus: {value.Status}") + }); + + if (value.StatusMessage is not null) + { + writer.WriteString("statusMessage", value.StatusMessage); + } + + writer.WriteString("createdAt", value.CreatedAt); + writer.WriteString("lastUpdatedAt", value.LastUpdatedAt); + + if (value.TtlMs is not null) + { + writer.WriteNumber("ttlMs", value.TtlMs.Value); + } + + if (value.PollIntervalMs is not null) + { + writer.WriteNumber("pollIntervalMs", value.PollIntervalMs.Value); + } + + switch (value) + { + case CompletedTaskResult completed: + writer.WritePropertyName("result"); + completed.TaskResult.WriteTo(writer); + break; + case FailedTaskResult failed: + writer.WritePropertyName("error"); + failed.Error.WriteTo(writer); + break; + case InputRequiredTaskResult inputRequired: + writer.WritePropertyName("inputRequests"); + JsonSerializer.Serialize(writer, inputRequired.InputRequests, options.GetTypeInfo>()); + break; + } + + writer.WriteEndObject(); + } + } +} + +/// +/// Represents a task that is currently being processed by the server. +/// +/// +/// See the SEP-2663 +/// specification for details. +/// +public sealed class WorkingTaskResult : GetTaskResult +{ + /// + [JsonPropertyName("status")] + public override McpTaskStatus Status => McpTaskStatus.Working; +} + +/// +/// Represents a task that has completed successfully, carrying the final result. +/// +/// +/// +/// The field contains the result structure matching the original request type. +/// For example, a tools/call task would contain the structure. +/// This includes tool calls that returned results with isError: true. +/// +/// +/// See the SEP-2663 +/// specification for details. +/// +/// +public sealed class CompletedTaskResult : GetTaskResult +{ + /// + [JsonPropertyName("status")] + public override McpTaskStatus Status => McpTaskStatus.Completed; + + /// + /// Gets or sets the final result of the task as raw JSON. + /// + /// + /// The structure matches the result type of the original request. + /// + [JsonPropertyName("result")] + public required JsonElement TaskResult { get; set; } +} + +/// +/// Represents a task that failed due to a JSON-RPC error during execution. +/// +/// +/// +/// The field contains the JSON-RPC error object that caused the failure. +/// This status must not be used for non-JSON-RPC errors. +/// +/// +/// See the SEP-2663 +/// specification for details. +/// +/// +public sealed class FailedTaskResult : GetTaskResult +{ + /// + [JsonPropertyName("status")] + public override McpTaskStatus Status => McpTaskStatus.Failed; + + /// + /// Gets or sets the JSON-RPC error that caused the task to fail. + /// + [JsonPropertyName("error")] + public required JsonElement Error { get; set; } +} + +/// +/// Represents a task that was cancelled before completion. +/// +/// +/// See the SEP-2663 +/// specification for details. +/// +public sealed class CancelledTaskResult : GetTaskResult +{ + /// + [JsonPropertyName("status")] + public override McpTaskStatus Status => McpTaskStatus.Cancelled; +} + +/// +/// Represents a task that requires input from the client before it can proceed. +/// +/// +/// +/// The field contains outstanding server-to-client requests +/// that the client must fulfil. Each entry is keyed by an arbitrary identifier for matching +/// requests to responses, and the value is a JSON object with method and params fields. +/// +/// +/// Clients must treat each entry as they would the equivalent standalone server-to-client request. +/// Clients should deduplicate keys across consecutive polls to avoid presenting the same request +/// to the user or model more than once. +/// +/// +/// See the SEP-2663 +/// specification for details. +/// +/// +public sealed class InputRequiredTaskResult : GetTaskResult +{ + /// + [JsonPropertyName("status")] + public override McpTaskStatus Status => McpTaskStatus.InputRequired; + + /// + /// Gets or sets the server-to-client requests that need to be fulfilled. + /// + /// + /// Keys are arbitrary identifiers for matching requests to responses. + /// Each value is a JSON object with method and params fields representing + /// the server-to-client request (e.g., an elicitation request). + /// + [JsonPropertyName("inputRequests")] + public required IDictionary InputRequests { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Protocol/ListTasksRequestParams.cs b/src/ModelContextProtocol.Core/Protocol/ListTasksRequestParams.cs deleted file mode 100644 index 3036d977b..000000000 --- a/src/ModelContextProtocol.Core/Protocol/ListTasksRequestParams.cs +++ /dev/null @@ -1,34 +0,0 @@ -using System.Diagnostics.CodeAnalysis; -using System.Text.Json.Serialization; - -namespace ModelContextProtocol.Protocol; - -/// -/// Represents the parameters for a tasks/list request to retrieve a list of tasks. -/// -/// -/// This operation supports cursor-based pagination. Receivers should use cursor-based -/// pagination to limit the number of tasks returned in a single response. -/// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] -public sealed class ListTasksRequestParams : PaginatedRequestParams -{ - // Inherits Cursor property from PaginatedRequestParams -} - -/// -/// Represents the result of a tasks/list request. -/// -/// -/// The result contains an array of task objects and an optional cursor for pagination. -/// If is present, more tasks are available. -/// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] -public sealed class ListTasksResult : PaginatedResult -{ - /// - /// Gets or sets the list of tasks. - /// - [JsonPropertyName("tasks")] - public required IList Tasks { get; set; } -} diff --git a/src/ModelContextProtocol.Core/Protocol/McpExtensions.cs b/src/ModelContextProtocol.Core/Protocol/McpExtensions.cs new file mode 100644 index 000000000..a41e4a576 --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/McpExtensions.cs @@ -0,0 +1,18 @@ +namespace ModelContextProtocol.Protocol; + +/// +/// Provides constants for well-known MCP extension identifiers. +/// +public static class McpExtensions +{ + /// + /// The extension identifier for the MCP Tasks extension. + /// + /// + /// When included in client per-request capabilities, indicates the client can handle + /// in lieu of a standard result. + /// See the SEP-2663 + /// specification for details. + /// + public const string Tasks = "io.modelcontextprotocol/tasks"; +} diff --git a/src/ModelContextProtocol.Core/Protocol/McpTask.cs b/src/ModelContextProtocol.Core/Protocol/McpTask.cs deleted file mode 100644 index 2056c5890..000000000 --- a/src/ModelContextProtocol.Core/Protocol/McpTask.cs +++ /dev/null @@ -1,104 +0,0 @@ -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.Text.Json.Serialization; - -namespace ModelContextProtocol.Protocol; - -/// -/// Represents an MCP task, which is a durable state machine carrying information -/// about the underlying execution state of a request. -/// -/// -/// -/// Tasks are useful for representing expensive computations and batch processing requests. -/// Each task is uniquely identifiable by a receiver-generated task ID. -/// -/// -/// Tasks follow a defined lifecycle through the property. They begin -/// in the status and may transition through various states -/// before reaching a terminal status (, , -/// or ). -/// -/// -/// See the tasks specification for details. -/// -/// -[DebuggerDisplay("{DebuggerDisplay,nq}")] -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] -public sealed class McpTask -{ - /// - /// Gets or sets the unique identifier for the task. - /// - /// - /// Task IDs are generated by the receiver when creating a task and must be unique - /// among all tasks controlled by that receiver. - /// - [JsonPropertyName("taskId")] - public required string TaskId { get; set; } - - /// - /// Gets or sets the current state of the task execution. - /// - [JsonPropertyName("status")] - public required McpTaskStatus Status { get; set; } - - /// - /// Gets or sets an optional human-readable message describing the current state. - /// - /// - /// This message can be present for any status, including error details for failed tasks. - /// - [JsonPropertyName("statusMessage")] - public string? StatusMessage { get; set; } - - /// - /// Gets or sets the ISO 8601 timestamp when the task was created. - /// - /// - /// Receivers must include this timestamp in all task responses to indicate when - /// the task was created. - /// - [JsonPropertyName("createdAt")] - public required DateTimeOffset CreatedAt { get; set; } - - /// - /// Gets or sets the ISO 8601 timestamp when the task status was last updated. - /// - /// - /// Receivers must include this timestamp in all task responses to indicate when - /// the task was last updated. - /// - [JsonPropertyName("lastUpdatedAt")] - public required DateTimeOffset LastUpdatedAt { get; set; } - - /// - /// Gets or sets the time to live (retention duration) from creation before the task may be deleted. - /// - /// - /// - /// A null value indicates unlimited lifetime. After a task's TTL lifetime has elapsed, - /// receivers may delete the task and its results, regardless of the task status. - /// - /// - /// Receivers may override the requested TTL duration and must include the actual TTL - /// duration (or null for unlimited) in task responses. - /// - /// - [JsonPropertyName("ttl")] - [JsonConverter(typeof(TimeSpanMillisecondsConverter))] - public TimeSpan? TimeToLive { get; set; } - - /// - /// Gets or sets the suggested time between status checks. - /// - /// - /// Requestors should respect this value when provided to avoid excessive polling. - /// This value is optional and may not be present in all task responses. - /// - [JsonPropertyName("pollInterval")] - [JsonConverter(typeof(TimeSpanMillisecondsConverter))] - public TimeSpan? PollInterval { get; set; } - - private string DebuggerDisplay => $"Task {TaskId}: {Status}" + (StatusMessage != null ? $" - {StatusMessage}" : ""); -} diff --git a/src/ModelContextProtocol.Core/Protocol/McpTaskMetadata.cs b/src/ModelContextProtocol.Core/Protocol/McpTaskMetadata.cs deleted file mode 100644 index 72dea54f3..000000000 --- a/src/ModelContextProtocol.Core/Protocol/McpTaskMetadata.cs +++ /dev/null @@ -1,41 +0,0 @@ -using System.Diagnostics.CodeAnalysis; -using System.Text.Json.Serialization; - -namespace ModelContextProtocol.Protocol; - -/// -/// Represents metadata for augmenting a request with task execution. -/// -/// -/// -/// When included in a request's params, this metadata signals that the requestor -/// wants the receiver to execute the request as a task rather than synchronously. -/// The receiver will return a containing task data -/// instead of the actual operation result. -/// -/// -/// Requestors can specify a desired TTL (time-to-live) duration for the task, -/// though receivers may override this value based on their resource management policies. -/// -/// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] -public sealed class McpTaskMetadata -{ - /// - /// Gets or sets the requested time to live (retention duration) to retain the task from creation. - /// - /// - /// - /// This is a hint to the receiver about how long the requestor expects to need access - /// to the task data. Receivers may override this value based on their resource constraints - /// and policies. - /// - /// - /// A null value indicates no specific retention requirement. The actual TTL used by the - /// receiver will be returned in the property. - /// - /// - [JsonPropertyName("ttl")] - [JsonConverter(typeof(TimeSpanMillisecondsConverter))] - public TimeSpan? TimeToLive { get; set; } -} diff --git a/src/ModelContextProtocol.Core/Protocol/McpTaskStatus.cs b/src/ModelContextProtocol.Core/Protocol/McpTaskStatus.cs index 9cf8a2f66..3b705a947 100644 --- a/src/ModelContextProtocol.Core/Protocol/McpTaskStatus.cs +++ b/src/ModelContextProtocol.Core/Protocol/McpTaskStatus.cs @@ -1,4 +1,3 @@ -using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol; @@ -7,73 +6,44 @@ namespace ModelContextProtocol.Protocol; /// Represents the status of an MCP task. /// /// -/// -/// Tasks progress through a defined lifecycle: -/// -/// : The request is currently being processed. -/// : The receiver needs input from the requestor. -/// The requestor should call tasks/result to receive input requests. -/// : The request completed successfully and results are available. -/// : The request did not complete successfully. -/// : The request was cancelled before completion. -/// -/// -/// -/// Terminal states are , , and . -/// Once a task reaches a terminal state, it cannot transition to any other status. -/// +/// Tasks are durable state machines that carry information about the underlying execution state +/// of the request they augment. See the +/// SEP-2663 +/// specification for details. /// [JsonConverter(typeof(JsonStringEnumConverter))] -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] public enum McpTaskStatus { /// /// The request is currently being processed. /// - /// - /// Tasks begin in this status when created. From , tasks may transition - /// to , , , or . - /// [JsonStringEnumMemberName("working")] Working, /// - /// The receiver needs input from the requestor. + /// The server needs input from the client before the task can proceed. + /// The tasks/get response will include outstanding requests in the inputRequests field. /// - /// - /// The requestor should call tasks/result to receive input requests, even though the task - /// has not reached a terminal state. From , tasks may transition - /// to , , , or . - /// [JsonStringEnumMemberName("input_required")] InputRequired, /// /// The request completed successfully and results are available. + /// This includes tool calls that returned results with isError: true. /// - /// - /// This is a terminal status. Tasks in this status cannot transition to any other status. - /// [JsonStringEnumMemberName("completed")] Completed, /// - /// The associated request did not complete successfully. + /// The request was cancelled before completion. /// - /// - /// This is a terminal status. For tool calls specifically, this includes cases where - /// the tool call result has isError set to true. Tasks in this status cannot transition - /// to any other status. - /// - [JsonStringEnumMemberName("failed")] - Failed, + [JsonStringEnumMemberName("cancelled")] + Cancelled, /// - /// The request was cancelled before completion. + /// The request failed due to a JSON-RPC error during execution. + /// This status must not be used for non-JSON-RPC errors. /// - /// - /// This is a terminal status. Tasks in this status cannot transition to any other status. - /// - [JsonStringEnumMemberName("cancelled")] - Cancelled + [JsonStringEnumMemberName("failed")] + Failed, } diff --git a/src/ModelContextProtocol.Core/Protocol/McpTaskStatusNotificationParams.cs b/src/ModelContextProtocol.Core/Protocol/McpTaskStatusNotificationParams.cs deleted file mode 100644 index a9b536102..000000000 --- a/src/ModelContextProtocol.Core/Protocol/McpTaskStatusNotificationParams.cs +++ /dev/null @@ -1,67 +0,0 @@ -using System.Diagnostics.CodeAnalysis; -using System.Text.Json.Serialization; - -namespace ModelContextProtocol.Protocol; - -/// -/// Represents the parameters for a notifications/tasks/status notification. -/// -/// -/// -/// When a task status changes, receivers may send this notification to inform the -/// requestor of the change. This notification includes the full task state. -/// -/// -/// Requestors must not rely on receiving this notification, as it is optional. Receivers -/// are not required to send status notifications and may choose to only send them for -/// certain status transitions. Requestors should continue to poll via tasks/get to ensure -/// they receive status updates. -/// -/// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] -public sealed class McpTaskStatusNotificationParams : NotificationParams -{ - /// - /// Gets or sets the task ID. - /// - [JsonPropertyName("taskId")] - public required string TaskId { get; set; } - - /// - /// Gets or sets the current status of the task. - /// - [JsonPropertyName("status")] - public required McpTaskStatus Status { get; set; } - - /// - /// Gets or sets an optional human-readable message describing the current state. - /// - [JsonPropertyName("statusMessage")] - public string? StatusMessage { get; set; } - - /// - /// Gets or sets the ISO 8601 timestamp when the task was created. - /// - [JsonPropertyName("createdAt")] - public required DateTimeOffset CreatedAt { get; set; } - - /// - /// Gets or sets the ISO 8601 timestamp when the task status was last updated. - /// - [JsonPropertyName("lastUpdatedAt")] - public required DateTimeOffset LastUpdatedAt { get; set; } - - /// - /// Gets or sets the time to live (retention duration) from creation before the task may be deleted. - /// - [JsonPropertyName("ttl")] - [JsonConverter(typeof(TimeSpanMillisecondsConverter))] - public TimeSpan? TimeToLive { get; set; } - - /// - /// Gets or sets the suggested time between status checks. - /// - [JsonPropertyName("pollInterval")] - [JsonConverter(typeof(TimeSpanMillisecondsConverter))] - public TimeSpan? PollInterval { get; set; } -} diff --git a/src/ModelContextProtocol.Core/Protocol/McpTasksCapability.cs b/src/ModelContextProtocol.Core/Protocol/McpTasksCapability.cs deleted file mode 100644 index 1b3ccd9dd..000000000 --- a/src/ModelContextProtocol.Core/Protocol/McpTasksCapability.cs +++ /dev/null @@ -1,160 +0,0 @@ -using System.Diagnostics.CodeAnalysis; -using System.Text.Json.Serialization; - -namespace ModelContextProtocol.Protocol; - -/// -/// Represents the tasks capability configuration for servers and clients. -/// -/// -/// -/// The tasks capability enables requestors (clients or servers) to augment their requests with -/// tasks for long-running operations. Tasks are durable state machines that carry information -/// about the underlying execution state of requests. -/// -/// -/// During initialization, both parties exchange their tasks capabilities to establish which -/// operations support task-based execution. Requestors should only augment requests with a -/// task if the corresponding capability has been declared by the receiver. -/// -/// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] -public sealed class McpTasksCapability -{ - /// - /// Gets or sets whether this party supports the tasks/list operation. - /// - /// - /// When present, indicates support for listing all tasks. - /// - [JsonPropertyName("list")] - public ListMcpTasksCapability? List { get; set; } - - /// - /// Gets or sets whether this party supports the tasks/cancel operation. - /// - /// - /// When present, indicates support for cancelling tasks. - /// - [JsonPropertyName("cancel")] - public CancelMcpTasksCapability? Cancel { get; set; } - - /// - /// Gets or sets which request types support task augmentation. - /// - /// - /// - /// The set of capabilities in this property is exhaustive. If a request type is not present, - /// it does not support task augmentation. - /// - /// - /// For servers, this typically includes tools/call. For clients, this typically includes - /// sampling/createMessage and elicitation/create. - /// - /// - [JsonPropertyName("requests")] - public RequestMcpTasksCapability? Requests { get; set; } -} - -/// -/// Represents task support for tool-specific requests. -/// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] -public sealed class RequestMcpTasksCapability -{ - /// - /// Gets or sets task support for tool-related requests. - /// - [JsonPropertyName("tools")] - public ToolsMcpTasksCapability? Tools { get; set; } - - /// - /// Gets or sets task support for sampling-related requests. - /// - [JsonPropertyName("sampling")] - public SamplingMcpTasksCapability? Sampling { get; set; } - - /// - /// Gets or sets task support for elicitation-related requests. - /// - [JsonPropertyName("elicitation")] - public ElicitationMcpTasksCapability? Elicitation { get; set; } -} - -/// -/// Represents task support for tool-related requests. -/// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] -public sealed class ToolsMcpTasksCapability -{ - /// - /// Gets or sets whether tools/call requests support task augmentation. - /// - /// - /// When present, indicates that the server supports task-augmented tools/call requests. - /// - [JsonPropertyName("call")] - public CallToolMcpTasksCapability? Call { get; set; } -} - -/// -/// Represents task support for sampling-related requests. -/// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] -public sealed class SamplingMcpTasksCapability -{ - /// - /// Gets or sets whether sampling/createMessage requests support task augmentation. - /// - /// - /// When present, indicates that the client supports task-augmented sampling/createMessage requests. - /// - [JsonPropertyName("createMessage")] - public CreateMessageMcpTasksCapability? CreateMessage { get; set; } -} - -/// -/// Represents task support for elicitation-related requests. -/// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] -public sealed class ElicitationMcpTasksCapability -{ - /// - /// Gets or sets whether elicitation/create requests support task augmentation. - /// - /// - /// When present, indicates that the client supports task-augmented elicitation/create requests. - /// - [JsonPropertyName("create")] - public CreateElicitationMcpTasksCapability? Create { get; set; } -} - -/// -/// Represents the capability for listing tasks. -/// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] -public sealed class ListMcpTasksCapability; - -/// -/// Represents the capability for cancelling tasks. -/// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] -public sealed class CancelMcpTasksCapability; - -/// -/// Represents the capability for task-augmented tools/call requests. -/// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] -public sealed class CallToolMcpTasksCapability; - -/// -/// Represents the capability for task-augmented sampling/createMessage requests. -/// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] -public sealed class CreateMessageMcpTasksCapability; - -/// -/// Represents the capability for task-augmented elicitation/create requests. -/// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] -public sealed class CreateElicitationMcpTasksCapability; diff --git a/src/ModelContextProtocol.Core/Protocol/NotificationMethods.cs b/src/ModelContextProtocol.Core/Protocol/NotificationMethods.cs index 949361650..cab98a5bc 100644 --- a/src/ModelContextProtocol.Core/Protocol/NotificationMethods.cs +++ b/src/ModelContextProtocol.Core/Protocol/NotificationMethods.cs @@ -143,39 +143,12 @@ public static class NotificationMethods public const string CancelledNotification = "notifications/cancelled"; /// - /// The name of the notification sent when a task status changes. + /// The name of the notification sent by the server to push task status updates to subscribed clients. /// /// - /// - /// When a task status changes, receivers may send this notification to inform the requestor - /// of the change. This notification includes the full task state. - /// - /// - /// Requestors must not rely on receiving this notification, as it is optional. Receivers - /// are not required to send status notifications and may choose to only send them for - /// certain status transitions. Requestors should continue to poll via tasks/get to ensure - /// they receive status updates. - /// - /// - public const string TaskStatusNotification = "notifications/tasks/status"; - - /// - /// The metadata key used to associate requests, responses, and notifications with a task. - /// - /// - /// - /// This constant defines the key "io.modelcontextprotocol/related-task" used in the - /// _meta field to associate messages with their originating task across the entire - /// request lifecycle. - /// - /// - /// For example, an elicitation that a task-augmented tool call depends on must share the - /// same related task ID with that tool call's task. - /// - /// - /// For tasks/get, tasks/list, and tasks/cancel operations, this - /// metadata should not be included as the taskId is already present in the message structure. - /// + /// Part of the io.modelcontextprotocol/tasks extension. + /// Each notification carries a complete task state for the current status, identical to what + /// tasks/get would have returned at that moment. /// - public const string RelatedTaskMetaKey = "io.modelcontextprotocol/related-task"; + public const string TaskStatusNotification = "notifications/tasks"; } diff --git a/src/ModelContextProtocol.Core/Protocol/RequestMethods.cs b/src/ModelContextProtocol.Core/Protocol/RequestMethods.cs index e0118fa57..6967dd07d 100644 --- a/src/ModelContextProtocol.Core/Protocol/RequestMethods.cs +++ b/src/ModelContextProtocol.Core/Protocol/RequestMethods.cs @@ -123,30 +123,29 @@ public static class RequestMethods public const string Initialize = "initialize"; /// - /// The name of the request method to retrieve task status. + /// The name of the request method sent from the client to poll for task completion. /// /// - /// Requestors poll for task completion by sending tasks/get requests. They should respect - /// the pollInterval provided in responses when determining polling frequency. + /// Part of the io.modelcontextprotocol/tasks extension. + /// Clients poll for task status by sending this request with the task ID. /// public const string TasksGet = "tasks/get"; /// - /// The name of the request method to retrieve the result of a completed task. + /// The name of the request method sent from the client to provide input responses to a task. /// /// - /// This request blocks until the task reaches a terminal status (completed, failed, or cancelled). - /// The result structure matches the original request type (e.g., CallToolResult for tools/call). + /// Part of the io.modelcontextprotocol/tasks extension. + /// Used when a task has input_required status and the client needs to fulfill outstanding requests. /// - public const string TasksResult = "tasks/result"; + public const string TasksUpdate = "tasks/update"; /// - /// The name of the request method to retrieve a list of tasks with pagination support. - /// - public const string TasksList = "tasks/list"; - - /// - /// The name of the request method to explicitly cancel a task. + /// The name of the request method sent from the client to signal intent to cancel a task. /// + /// + /// Part of the io.modelcontextprotocol/tasks extension. + /// Cancellation is cooperative — the server decides whether and when to honor it. + /// public const string TasksCancel = "tasks/cancel"; } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Protocol/Result.cs b/src/ModelContextProtocol.Core/Protocol/Result.cs index 58b076ddb..d16c90fe8 100644 --- a/src/ModelContextProtocol.Core/Protocol/Result.cs +++ b/src/ModelContextProtocol.Core/Protocol/Result.cs @@ -21,4 +21,22 @@ private protected Result() /// [JsonPropertyName("_meta")] public JsonObject? Meta { get; set; } + + /// + /// Gets or sets the result type discriminator used to distinguish polymorphic results. + /// + /// + /// + /// Standard results use "complete" (or omit this field). When a server returns a task + /// in lieu of a standard result, it sets this to "task". + /// + /// + /// See SEP-2322 + /// for the introduction of this field, and + /// SEP-2663 + /// for the "task" discriminator value. + /// + /// + [JsonPropertyName("resultType")] + public string? ResultType { get; set; } } diff --git a/src/ModelContextProtocol.Core/Protocol/ResultOrCreatedTask.cs b/src/ModelContextProtocol.Core/Protocol/ResultOrCreatedTask.cs new file mode 100644 index 000000000..87b857470 --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/ResultOrCreatedTask.cs @@ -0,0 +1,76 @@ +namespace ModelContextProtocol.Protocol; + +/// +/// Represents the result of a request that supports task-augmented execution, which may be either +/// the standard result or a indicating asynchronous processing. +/// +/// The standard result type for the request (e.g., ). +/// +/// +/// When a server supports the io.modelcontextprotocol/tasks extension and the client declares +/// the extension capability on its request, the server may return a +/// instead of the standard result. This type represents that polymorphic response. +/// +/// +/// Use to determine which variant was returned, then access either +/// for the immediate result or for the task handle. +/// +/// +/// See the SEP-2663 +/// specification for details. +/// +/// +public class ResultOrCreatedTask where TResult : Result +{ + private readonly TResult? _result; + private readonly CreateTaskResult? _taskCreated; + + /// + /// Initializes a new instance of with an immediate result. + /// + /// The standard result returned by the server. + public ResultOrCreatedTask(TResult result) + { + Throw.IfNull(result); + _result = result; + } + + /// + /// Initializes a new instance of with a task handle. + /// + /// The task creation result returned by the server. + public ResultOrCreatedTask(CreateTaskResult taskCreated) + { + Throw.IfNull(taskCreated); + _taskCreated = taskCreated; + } + + /// + /// Gets a value indicating whether the server created a task instead of returning an immediate result. + /// + public bool IsTask => _taskCreated is not null; + + /// + /// Gets the immediate result, or if the server created a task. + /// + public TResult? Result => _result; + + /// + /// Gets the task creation result, or if the server returned an immediate result. + /// + public CreateTaskResult? TaskCreated => _taskCreated; + + /// + /// Implicitly converts a to a + /// wrapping the immediate result. + /// + /// The result to wrap. + public static implicit operator ResultOrCreatedTask(TResult result) => new(result); + + /// + /// Implicitly converts a to a + /// wrapping the task handle. + /// + /// The task creation result to wrap. + public static implicit operator ResultOrCreatedTask(CreateTaskResult taskCreated) => new(taskCreated); +} diff --git a/src/ModelContextProtocol.Core/Protocol/ServerCapabilities.cs b/src/ModelContextProtocol.Core/Protocol/ServerCapabilities.cs index d4e23a66f..92ffff424 100644 --- a/src/ModelContextProtocol.Core/Protocol/ServerCapabilities.cs +++ b/src/ModelContextProtocol.Core/Protocol/ServerCapabilities.cs @@ -67,32 +67,6 @@ public sealed class ServerCapabilities [JsonPropertyName("completions")] public CompletionsCapability? Completions { get; set; } - /// - /// Gets or sets a server's tasks capability for supporting task-augmented requests. - /// - /// - /// - /// The tasks capability enables clients to augment their requests with tasks for long-running - /// operations. When present, clients can request that certain operations (like tool calls) - /// execute asynchronously, with the ability to poll for status and retrieve results later. - /// - /// - /// See for details on configuring which operations support tasks. - /// - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - [JsonIgnore] - public McpTasksCapability? Tasks - { - get => TasksCore; - set => TasksCore = value; - } - - // See ExperimentalInternalPropertyTests.cs before modifying this property. - [JsonInclude] - [JsonPropertyName("tasks")] - internal McpTasksCapability? TasksCore { get; set; } - /// /// Gets or sets optional MCP extensions that the server supports. /// diff --git a/src/ModelContextProtocol.Core/Protocol/TaskStatusNotificationParams.cs b/src/ModelContextProtocol.Core/Protocol/TaskStatusNotificationParams.cs new file mode 100644 index 000000000..f1b99cea6 --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/TaskStatusNotificationParams.cs @@ -0,0 +1,358 @@ +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol; + +/// +/// Represents the parameters for a notifications/tasks notification sent by the server +/// to push task status updates to the client. +/// +/// +/// +/// Each notification carries a complete task state for the current status, identical to what +/// tasks/get would have returned at that moment. The concrete type depends on the task's +/// current status: +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// To receive task status notifications, clients send a subscriptions/listen request +/// including the task IDs they are interested in. +/// +/// +/// See the SEP-2663 +/// specification for details. +/// +/// +[JsonConverter(typeof(Converter))] +public abstract class TaskStatusNotificationParams : NotificationParams +{ + /// Prevent external derivations. + private protected TaskStatusNotificationParams() + { + } + + /// + /// Gets or sets the stable identifier for this task. + /// + [JsonPropertyName("taskId")] + public required string TaskId { get; set; } + + /// + /// Gets or sets the current task status. + /// + [JsonPropertyName("status")] + public abstract McpTaskStatus Status { get; } + + /// + /// Gets or sets an optional message describing the current task state. + /// + [JsonPropertyName("statusMessage")] + public string? StatusMessage { get; set; } + + /// + /// Gets or sets the ISO 8601 timestamp when the task was created. + /// + [JsonPropertyName("createdAt")] + public required DateTimeOffset CreatedAt { get; set; } + + /// + /// Gets or sets the ISO 8601 timestamp when the task was last updated. + /// + [JsonPropertyName("lastUpdatedAt")] + public required DateTimeOffset LastUpdatedAt { get; set; } + + /// + /// Gets or sets the time-to-live duration from creation in milliseconds, or for unlimited. + /// + [JsonPropertyName("ttlMs")] + public long? TtlMs { get; set; } + + /// + /// Gets or sets the suggested polling interval in milliseconds. + /// + [JsonPropertyName("pollIntervalMs")] + public long? PollIntervalMs { get; set; } + + /// + /// JSON converter that deserializes to the appropriate + /// concrete subtype based on the status discriminator field. + /// + internal sealed class Converter : JsonConverter + { + public override TaskStatusNotificationParams? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType != JsonTokenType.StartObject) + { + throw new JsonException("Expected StartObject token for TaskStatusNotificationParams."); + } + + string? taskId = null; + string? statusString = null; + string? statusMessage = null; + DateTimeOffset? createdAt = null; + DateTimeOffset? lastUpdatedAt = null; + long? ttlMs = null; + long? pollIntervalMs = null; + JsonObject? meta = null; + JsonElement? result = null; + JsonElement? error = null; + IDictionary? inputRequests = null; + + while (reader.Read()) + { + if (reader.TokenType == JsonTokenType.EndObject) + { + break; + } + + if (reader.TokenType != JsonTokenType.PropertyName) + { + throw new JsonException("Expected property name."); + } + + string propertyName = reader.GetString()!; + reader.Read(); + + switch (propertyName) + { + case "taskId": + taskId = reader.GetString(); + break; + case "status": + statusString = reader.GetString(); + break; + case "statusMessage": + statusMessage = reader.GetString(); + break; + case "createdAt": + createdAt = reader.GetDateTimeOffset(); + break; + case "lastUpdatedAt": + lastUpdatedAt = reader.GetDateTimeOffset(); + break; + case "ttlMs": + ttlMs = reader.GetInt64(); + break; + case "pollIntervalMs": + pollIntervalMs = reader.GetInt64(); + break; + case "_meta": + meta = JsonSerializer.Deserialize(ref reader, options.GetTypeInfo()); + break; + case "result": + result = JsonElement.ParseValue(ref reader); + break; + case "error": + error = JsonElement.ParseValue(ref reader); + break; + case "inputRequests": + inputRequests = JsonSerializer.Deserialize(ref reader, options.GetTypeInfo>()); + break; + default: + reader.Skip(); + break; + } + } + + if (taskId is null) + { + throw new JsonException("Missing required 'taskId' property on TaskStatusNotificationParams."); + } + + if (statusString is null) + { + throw new JsonException("Missing required 'status' property on TaskStatusNotificationParams."); + } + + if (createdAt is null) + { + throw new JsonException("Missing required 'createdAt' property on TaskStatusNotificationParams."); + } + + if (lastUpdatedAt is null) + { + throw new JsonException("Missing required 'lastUpdatedAt' property on TaskStatusNotificationParams."); + } + + TaskStatusNotificationParams notification = statusString switch + { + "working" => new WorkingTaskNotificationParams + { + TaskId = taskId, + CreatedAt = createdAt.Value, + LastUpdatedAt = lastUpdatedAt.Value, + }, + "completed" => result is not null + ? new CompletedTaskNotificationParams + { + TaskId = taskId, + CreatedAt = createdAt.Value, + LastUpdatedAt = lastUpdatedAt.Value, + TaskResult = result.Value, + } + : throw new JsonException("Completed task notification is missing required 'result' property."), + "failed" => error is not null + ? new FailedTaskNotificationParams + { + TaskId = taskId, + CreatedAt = createdAt.Value, + LastUpdatedAt = lastUpdatedAt.Value, + Error = error.Value, + } + : throw new JsonException("Failed task notification is missing required 'error' property."), + "cancelled" => new CancelledTaskNotificationParams + { + TaskId = taskId, + CreatedAt = createdAt.Value, + LastUpdatedAt = lastUpdatedAt.Value, + }, + "input_required" => inputRequests is not null + ? new InputRequiredTaskNotificationParams + { + TaskId = taskId, + CreatedAt = createdAt.Value, + LastUpdatedAt = lastUpdatedAt.Value, + InputRequests = inputRequests, + } + : throw new JsonException("Input-required task notification is missing required 'inputRequests' property."), + _ => throw new JsonException($"Unknown task status: '{statusString}'.") + }; + + notification.StatusMessage = statusMessage; + notification.TtlMs = ttlMs; + notification.PollIntervalMs = pollIntervalMs; + notification.Meta = meta; + + return notification; + } + + public override void Write(Utf8JsonWriter writer, TaskStatusNotificationParams value, JsonSerializerOptions options) + { + writer.WriteStartObject(); + + if (value.Meta is not null) + { + writer.WritePropertyName("_meta"); + JsonSerializer.Serialize(writer, value.Meta, options.GetTypeInfo()); + } + + writer.WriteString("taskId", value.TaskId); + writer.WriteString("status", value.Status switch + { + McpTaskStatus.Working => "working", + McpTaskStatus.Completed => "completed", + McpTaskStatus.Failed => "failed", + McpTaskStatus.Cancelled => "cancelled", + McpTaskStatus.InputRequired => "input_required", + _ => throw new JsonException($"Unknown McpTaskStatus: {value.Status}") + }); + + if (value.StatusMessage is not null) + { + writer.WriteString("statusMessage", value.StatusMessage); + } + + writer.WriteString("createdAt", value.CreatedAt); + writer.WriteString("lastUpdatedAt", value.LastUpdatedAt); + + if (value.TtlMs is not null) + { + writer.WriteNumber("ttlMs", value.TtlMs.Value); + } + + if (value.PollIntervalMs is not null) + { + writer.WriteNumber("pollIntervalMs", value.PollIntervalMs.Value); + } + + switch (value) + { + case CompletedTaskNotificationParams completed: + writer.WritePropertyName("result"); + completed.TaskResult.WriteTo(writer); + break; + case FailedTaskNotificationParams failed: + writer.WritePropertyName("error"); + failed.Error.WriteTo(writer); + break; + case InputRequiredTaskNotificationParams inputRequired: + writer.WritePropertyName("inputRequests"); + JsonSerializer.Serialize(writer, inputRequired.InputRequests, options.GetTypeInfo>()); + break; + } + + writer.WriteEndObject(); + } + } +} + +/// +/// Task notification for a task that is currently being processed. +/// +public sealed class WorkingTaskNotificationParams : TaskStatusNotificationParams +{ + /// + public override McpTaskStatus Status => McpTaskStatus.Working; +} + +/// +/// Task notification for a task that has completed successfully. +/// +public sealed class CompletedTaskNotificationParams : TaskStatusNotificationParams +{ + /// + public override McpTaskStatus Status => McpTaskStatus.Completed; + + /// + /// Gets or sets the final result of the task. + /// + [JsonPropertyName("result")] + public required JsonElement TaskResult { get; set; } +} + +/// +/// Task notification for a task that failed. +/// +public sealed class FailedTaskNotificationParams : TaskStatusNotificationParams +{ + /// + public override McpTaskStatus Status => McpTaskStatus.Failed; + + /// + /// Gets or sets the JSON-RPC error that caused the task to fail. + /// + [JsonPropertyName("error")] + public required JsonElement Error { get; set; } +} + +/// +/// Task notification for a task that was cancelled. +/// +public sealed class CancelledTaskNotificationParams : TaskStatusNotificationParams +{ + /// + public override McpTaskStatus Status => McpTaskStatus.Cancelled; +} + +/// +/// Task notification for a task that requires input from the client. +/// +public sealed class InputRequiredTaskNotificationParams : TaskStatusNotificationParams +{ + /// + public override McpTaskStatus Status => McpTaskStatus.InputRequired; + + /// + /// Gets or sets the server-to-client requests that need to be fulfilled. + /// + [JsonPropertyName("inputRequests")] + public required IDictionary InputRequests { get; set; } +} + diff --git a/src/ModelContextProtocol.Core/Protocol/TimeSpanMillisecondsConverter.cs b/src/ModelContextProtocol.Core/Protocol/TimeSpanMillisecondsConverter.cs deleted file mode 100644 index e789db186..000000000 --- a/src/ModelContextProtocol.Core/Protocol/TimeSpanMillisecondsConverter.cs +++ /dev/null @@ -1,41 +0,0 @@ -using System.ComponentModel; -using System.Text.Json; -using System.Text.Json.Serialization; - -namespace ModelContextProtocol.Protocol; - -/// -/// Provides a JSON converter for that serializes as integer milliseconds. -/// -/// -/// This converter serializes TimeSpan values as the total number of milliseconds (as an integer), -/// and deserializes integer millisecond values back to TimeSpan. System.Text.Json automatically -/// handles nullable TimeSpan properties using this converter. -/// -[EditorBrowsable(EditorBrowsableState.Never)] -public sealed class TimeSpanMillisecondsConverter : JsonConverter -{ - /// - public override TimeSpan Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) - { - if (reader.TokenType is JsonTokenType.Number) - { - if (reader.TryGetInt64(out long milliseconds)) - { - return TimeSpan.FromMilliseconds(milliseconds); - } - - // For non-integer values, convert from fractional milliseconds - double fractionalMilliseconds = reader.GetDouble(); - return TimeSpan.FromTicks((long)(fractionalMilliseconds * TimeSpan.TicksPerMillisecond)); - } - - throw new JsonException($"Unable to convert {reader.TokenType} to TimeSpan."); - } - - /// - public override void Write(Utf8JsonWriter writer, TimeSpan value, JsonSerializerOptions options) - { - writer.WriteNumberValue((long)value.TotalMilliseconds); - } -} diff --git a/src/ModelContextProtocol.Core/Protocol/Tool.cs b/src/ModelContextProtocol.Core/Protocol/Tool.cs index 8abbfd88c..9f61756f8 100644 --- a/src/ModelContextProtocol.Core/Protocol/Tool.cs +++ b/src/ModelContextProtocol.Core/Protocol/Tool.cs @@ -119,26 +119,6 @@ public JsonElement? OutputSchema [JsonPropertyName("annotations")] public ToolAnnotations? Annotations { get; set; } - /// - /// Gets or sets execution-related metadata for this tool. - /// - /// - /// This property provides hints about how the tool should be executed, particularly - /// regarding task augmentation support. See for details. - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - [JsonIgnore] - public ToolExecution? Execution - { - get => ExecutionCore; - set => ExecutionCore = value; - } - - // See ExperimentalInternalPropertyTests.cs before modifying this property. - [JsonInclude] - [JsonPropertyName("execution")] - internal ToolExecution? ExecutionCore { get; set; } - /// /// Gets or sets an optional list of icons for this tool. /// diff --git a/src/ModelContextProtocol.Core/Protocol/ToolExecution.cs b/src/ModelContextProtocol.Core/Protocol/ToolExecution.cs deleted file mode 100644 index 174298471..000000000 --- a/src/ModelContextProtocol.Core/Protocol/ToolExecution.cs +++ /dev/null @@ -1,85 +0,0 @@ -using System.Diagnostics.CodeAnalysis; -using System.Text.Json.Serialization; - -namespace ModelContextProtocol.Protocol; - -/// -/// Represents execution-related metadata for a tool. -/// -/// -/// This type provides hints about how a tool should be executed, particularly -/// regarding task augmentation support. -/// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] -public sealed class ToolExecution -{ - /// - /// Gets or sets the level of task augmentation support for this tool. - /// - /// - /// - /// This property declares whether a tool supports task-augmented execution: - /// - /// : Clients must not attempt to invoke - /// the tool as a task. This is the default behavior. - /// : Clients may invoke the tool as a task - /// or as a normal request. - /// : Clients must invoke the tool as a task. - /// - /// - /// - /// - /// This is a fine-grained layer in addition to server capabilities. Even if a server's capabilities - /// include tasks.requests.tools.call, this property controls whether each specific tool supports tasks. - /// - /// - [JsonPropertyName("taskSupport")] - public ToolTaskSupport? TaskSupport { get; set; } -} - -/// -/// Represents the level of task augmentation support for a tool. -/// -/// -/// -/// This enum defines how a tool interacts with the task augmentation system: -/// -/// : Task augmentation is not allowed (default) -/// : Task augmentation is supported but not required -/// : Task augmentation is mandatory -/// -/// -/// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] -[JsonConverter(typeof(JsonStringEnumConverter))] -public enum ToolTaskSupport -{ - /// - /// Clients must not attempt to invoke the tool as a task. - /// - /// - /// This is the default behavior. Servers should return a -32601 (Method not found) error - /// if a client attempts to invoke the tool as a task when this is set. - /// - [JsonStringEnumMemberName("forbidden")] - Forbidden, - - /// - /// Clients may invoke the tool as a task or as a normal request. - /// - /// - /// When this is set, clients can choose whether to use task augmentation based on their needs. - /// - [JsonStringEnumMemberName("optional")] - Optional, - - /// - /// Clients must invoke the tool as a task. - /// - /// - /// Servers must return a -32601 (Method not found) error if a client does not attempt - /// to invoke the tool as a task when this is set. - /// - [JsonStringEnumMemberName("required")] - Required -} diff --git a/src/ModelContextProtocol.Core/Protocol/UpdateTaskRequestParams.cs b/src/ModelContextProtocol.Core/Protocol/UpdateTaskRequestParams.cs new file mode 100644 index 000000000..ddee81519 --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/UpdateTaskRequestParams.cs @@ -0,0 +1,38 @@ +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol; + +/// +/// Represents the parameters for a tasks/update request to provide input responses +/// to outstanding server-to-client requests on a task. +/// +/// +/// +/// When a task requires input from the client (indicated by ), +/// the server includes outstanding requests in the inputRequests field of the tasks/get response. +/// The client provides responses via the field in tasks/update requests. +/// +/// +/// See the SEP-2663 +/// specification for details. +/// +/// +public sealed class UpdateTaskRequestParams : RequestParams +{ + /// + /// Gets or sets the identifier of the task to update. + /// + [JsonPropertyName("taskId")] + public required string TaskId { get; set; } + + /// + /// Gets or sets the responses to outstanding inputRequests previously surfaced by the server. + /// + /// + /// Each key must correspond to a currently-outstanding inputRequests key. + /// A server should ignore any responses mapped to a key that is not currently outstanding. + /// + [JsonPropertyName("inputResponses")] + public required IDictionary InputResponses { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Protocol/UpdateTaskResult.cs b/src/ModelContextProtocol.Core/Protocol/UpdateTaskResult.cs new file mode 100644 index 000000000..531039c23 --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/UpdateTaskResult.cs @@ -0,0 +1,21 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol; + +/// +/// Represents the result of a tasks/update request. This is an empty acknowledgement. +/// +/// +/// +/// On success, the server acknowledges the request with an empty result. +/// The acknowledgement is eventually consistent: the server may accept the responses and +/// return the ack before the task's observable status reflects them. +/// +/// +/// See the SEP-2663 +/// specification for details. +/// +/// +public sealed class UpdateTaskResult : Result +{ +} diff --git a/src/ModelContextProtocol.Core/RequestHandlers.cs b/src/ModelContextProtocol.Core/RequestHandlers.cs index 97e8b95df..a45efa2c9 100644 --- a/src/ModelContextProtocol.Core/RequestHandlers.cs +++ b/src/ModelContextProtocol.Core/RequestHandlers.cs @@ -45,4 +45,36 @@ public void Set( return JsonSerializer.SerializeToNode(result, responseTypeInfo); }; } + + /// + /// Registers a handler that may return either a standard result or a + /// for task-augmented execution. + /// + public void SetTaskAugmented( + string method, + Func>> handler, + JsonTypeInfo requestTypeInfo, + JsonTypeInfo responseTypeInfo, + JsonTypeInfo taskResultTypeInfo) + where TResult : Result + { + Throw.IfNull(method); + Throw.IfNull(handler); + Throw.IfNull(requestTypeInfo); + Throw.IfNull(responseTypeInfo); + Throw.IfNull(taskResultTypeInfo); + + this[method] = async (request, cancellationToken) => + { + TParams typedRequest = JsonSerializer.Deserialize(request.Params, requestTypeInfo)!; + var augmented = await handler(typedRequest, request, cancellationToken).ConfigureAwait(false); + + if (augmented.IsTask) + { + return JsonSerializer.SerializeToNode(augmented.TaskCreated!, taskResultTypeInfo); + } + + return JsonSerializer.SerializeToNode(augmented.Result!, responseTypeInfo); + }; + } } diff --git a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs index 961344c2c..715edb97f 100644 --- a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs @@ -154,23 +154,6 @@ options.OpenWorld is not null || tool.Meta = function.UnderlyingMethod is not null ? CreateMetaFromAttributes(function.UnderlyingMethod, options.Meta) : options.Meta; - - // Apply user-specified Execution settings if provided - if (options.Execution is not null) - { - tool.Execution = options.Execution; - } - } - - // Auto-detect async methods and mark with taskSupport = "optional" unless explicitly configured. - // This enables implicit task support for async tools: clients can choose to invoke them - // synchronously (wait for completion) or as a task (receive taskId, poll for result). - if (function.UnderlyingMethod is not null && - IsAsyncMethod(function.UnderlyingMethod) && - tool.Execution?.TaskSupport is null) - { - tool.Execution ??= new ToolExecution(); - tool.Execution.TaskSupport = ToolTaskSupport.Optional; } return new AIFunctionMcpServerTool(function, tool, options?.Services, structuredOutputRequiresWrapping, options?.Metadata ?? []); @@ -218,12 +201,6 @@ private static McpServerToolCreateOptions DeriveOptions(MethodInfo method, McpSe serializerOptions: newOptions.SerializerOptions ?? McpJsonUtilities.DefaultOptions, inferenceOptions: newOptions.SchemaCreateOptions); } - - if (toolAttr._taskSupport is { } taskSupport) - { - newOptions.Execution ??= new ToolExecution(); - newOptions.Execution.TaskSupport ??= taskSupport; - } } if (method.GetCustomAttribute() is { } descAttr) @@ -350,27 +327,27 @@ internal static string DeriveName(MethodInfo method, JsonNamingPolicy? policy = // Case the name based on the provided naming policy. return (policy ?? JsonNamingPolicy.SnakeCaseLower).ConvertName(name) ?? name; - } - private static bool IsAsyncMethod(MethodInfo method) - { - Type t = method.ReturnType; - - if (t == typeof(Task) || t == typeof(ValueTask)) + static bool IsAsyncMethod(MethodInfo method) { - return true; - } + Type t = method.ReturnType; - if (t.IsGenericType) - { - t = t.GetGenericTypeDefinition(); - if (t == typeof(Task<>) || t == typeof(ValueTask<>) || t == typeof(IAsyncEnumerable<>)) + if (t == typeof(Task) || t == typeof(ValueTask)) { return true; } - } - return false; + if (t.IsGenericType) + { + t = t.GetGenericTypeDefinition(); + if (t == typeof(Task<>) || t == typeof(ValueTask<>) || t == typeof(IAsyncEnumerable<>)) + { + return true; + } + } + + return false; + } } /// Creates metadata from attributes on the specified method and its declaring class, with the MethodInfo as the first item. diff --git a/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs b/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs index 957f58a51..bbaae7913 100644 --- a/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs +++ b/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs @@ -1,5 +1,7 @@ using ModelContextProtocol.Protocol; +using System.Collections.Concurrent; using System.Diagnostics; +using System.Text.Json; namespace ModelContextProtocol.Server; @@ -14,6 +16,7 @@ internal sealed class DestinationBoundMcpServer(McpServerImpl server, ITransport public override McpServerOptions ServerOptions => server.ServerOptions; public override IServiceProvider? Services => server.Services; public override LoggingLevel? LoggingLevel => server.LoggingLevel; + internal override ConcurrentDictionary<(string TaskId, string RequestId), TaskCompletionSource> TaskInputResponseWaiters => server.TaskInputResponseWaiters; public override ValueTask DisposeAsync() => server.DisposeAsync(); diff --git a/src/ModelContextProtocol.Core/Server/IMcpTaskStore.cs b/src/ModelContextProtocol.Core/Server/IMcpTaskStore.cs index d322d21ef..4f560132c 100644 --- a/src/ModelContextProtocol.Core/Server/IMcpTaskStore.cs +++ b/src/ModelContextProtocol.Core/Server/IMcpTaskStore.cs @@ -2,165 +2,117 @@ using System.Diagnostics.CodeAnalysis; using System.Text.Json; -namespace ModelContextProtocol; +namespace ModelContextProtocol.Server; /// -/// Provides an interface for pluggable task storage implementations in MCP servers. +/// Provides an interface for storing and managing the lifecycle of MCP tasks. /// /// /// -/// The task store is responsible for managing the lifecycle of tasks, including creation, -/// status updates, result storage, and retrieval. Implementations must be thread-safe and -/// may support session-based isolation for multi-session scenarios. +/// The task store manages the state of tasks created by the server's request handling pipeline. +/// When a client signals support for the io.modelcontextprotocol/tasks extension on a request, +/// the server creates a task in the store, executes the work in the background, and stores the result +/// upon completion. /// /// -/// TTL (Time To Live) Management: Implementations may override the requested TTL value in -/// to enforce resource limits. The actual TTL -/// used is returned in the property. A null TTL indicates -/// unlimited lifetime. Tasks may be deleted after their TTL expires, regardless of status. +/// Implementations must be thread-safe. The store also provides the backing implementation for +/// tasks/get, tasks/update, and tasks/cancel protocol methods. +/// +/// +/// See the SEP-2663 +/// specification for details on the tasks extension. /// /// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] +[Experimental(Experimentals.Extensions_DiagnosticId, UrlFormat = Experimentals.Extensions_Url)] public interface IMcpTaskStore { /// - /// Creates a new task for tracking an asynchronous operation. + /// Creates a new task for tracking an asynchronous execution. /// - /// Metadata for the task, including requested TTL. - /// The JSON-RPC request ID that initiated this task. - /// The original JSON-RPC request that triggered task creation. - /// Optional session identifier for multi-session isolation. /// Cancellation token for the operation. /// - /// A new with a unique task ID, initial status of , - /// and the actual TTL that will be used (which may differ from the requested TTL). + /// A with a unique task ID, initial status of , + /// and timing metadata (TTL, poll interval). /// /// - /// Implementations must generate a unique task ID and set the - /// and timestamps. The implementation may override the - /// requested TTL to enforce storage limits. + /// Implementations must generate a unique task ID and set appropriate timestamps. + /// The server infrastructure maps the returned to the appropriate + /// protocol response type when communicating with clients. /// - Task CreateTaskAsync( - McpTaskMetadata taskParams, - RequestId requestId, - JsonRpcRequest request, - string? sessionId = null, - CancellationToken cancellationToken = default); + Task CreateTaskAsync(CancellationToken cancellationToken = default); /// - /// Retrieves a task by its unique identifier. + /// Retrieves the current state of a task. /// /// The unique identifier of the task to retrieve. - /// Optional session identifier for access control. /// Cancellation token for the operation. /// - /// The if found and accessible, otherwise . + /// A representing the current task state, + /// or if the task does not exist. /// - /// - /// Returns null if the task does not exist or if session-based access control denies access. - /// - Task GetTaskAsync(string taskId, string? sessionId = null, CancellationToken cancellationToken = default); + Task GetTaskAsync(string taskId, CancellationToken cancellationToken = default); /// - /// Stores the final result of a task that has reached a terminal status. + /// Stores the result of a completed execution, transitioning the task to . /// /// The unique identifier of the task. - /// The terminal status: or . - /// The operation result to store as a JSON element. - /// Optional session identifier for access control. + /// The serialized result payload. /// Cancellation token for the operation. - /// The updated with the new status and result stored. - /// - /// - /// The must be either or - /// . This method updates the task status and stores - /// the result for later retrieval via . - /// - /// - /// Implementations should throw if called on a task - /// that is already in a terminal state, to prevent result overwrites. - /// - /// - Task StoreTaskResultAsync( - string taskId, - McpTaskStatus status, - JsonElement result, - string? sessionId = null, - CancellationToken cancellationToken = default); + /// A task representing the asynchronous operation. + Task SetCompletedAsync(string taskId, JsonElement result, CancellationToken cancellationToken = default); /// - /// Retrieves the stored result of a completed or failed task. + /// Marks a task as failed, transitioning it to . /// /// The unique identifier of the task. - /// Optional session identifier for access control. + /// The serialized error information. /// Cancellation token for the operation. - /// The stored operation result as a JSON element. - /// - /// This method should only be called on tasks in terminal states ( - /// or ). The result contains the JSON representation of the - /// original operation result (e.g., for tools/call). - /// - Task GetTaskResultAsync(string taskId, string? sessionId = null, CancellationToken cancellationToken = default); + /// A task representing the asynchronous operation. + Task SetFailedAsync(string taskId, JsonElement error, CancellationToken cancellationToken = default); /// - /// Updates the status and optional status message of a task. + /// Transitions the task to . /// - /// The unique identifier of the task. - /// The new status to set. - /// Optional diagnostic message describing the status change. - /// Optional session identifier for access control. + /// The unique identifier of the task to cancel. /// Cancellation token for the operation. - /// The updated with the new status applied. - /// - /// This method updates the task's , , - /// and properties. Common uses include transitioning to - /// , , or updating - /// progress messages while in status. - /// - Task UpdateTaskStatusAsync( - string taskId, - McpTaskStatus status, - string? statusMessage, - string? sessionId = null, - CancellationToken cancellationToken = default); + /// + /// if the task was successfully cancelled; + /// if the task does not exist or was already in a terminal state. + /// + Task SetCancelledAsync(string taskId, CancellationToken cancellationToken = default); /// - /// Lists tasks with pagination support. + /// Removes input requests that have been satisfied by the provided responses. /// - /// Optional cursor for pagination, from a previous call's nextCursor value. - /// Optional session identifier for filtering tasks by session. + /// The unique identifier of the task. + /// + /// The keys of input requests that have been satisfied. + /// Matched input requests are removed from the task's pending set. + /// /// Cancellation token for the operation. - /// A containing the tasks and an optional cursor for the next page. /// - /// When is provided, implementations should filter to only return - /// tasks associated with that session. The cursor format is implementation-specific. + /// After removing the satisfied requests, if no pending input requests remain the task + /// transitions back to . Otherwise it remains in + /// . /// - Task ListTasksAsync( - string? cursor = null, - string? sessionId = null, + Task ResolveInputRequestsAsync( + string taskId, + IEnumerable inputResponseKeys, CancellationToken cancellationToken = default); /// - /// Attempts to cancel a task, transitioning it to status. + /// Adds input requests to a task, transitioning it to . /// - /// The unique identifier of the task to cancel. - /// Optional session identifier for access control. + /// The unique identifier of the task. + /// + /// The input requests to add. Keys are arbitrary identifiers for matching requests to responses. + /// Each value is a JSON object representing the server-to-client request. + /// New requests are merged with any existing pending requests. + /// /// Cancellation token for the operation. - /// - /// The updated . If the task is already in a terminal state - /// (, , or - /// ), the task is returned unchanged. - /// - /// - /// - /// This method must be idempotent. If called on a task that is already in a terminal state, - /// it returns the current task without error. This behavior differs from the MCP specification - /// but ensures idempotency and avoids race conditions between cancellation and task completion. - /// - /// - /// For tasks not in a terminal state, the implementation should attempt to stop the underlying - /// operation and transition the task to status before returning. - /// - /// - Task CancelTaskAsync(string taskId, string? sessionId = null, CancellationToken cancellationToken = default); + /// A task representing the asynchronous operation. + Task SetInputRequestsAsync( + string taskId, + IDictionary inputRequests, + CancellationToken cancellationToken = default); } diff --git a/src/ModelContextProtocol.Core/Server/InMemoryMcpTaskStore.cs b/src/ModelContextProtocol.Core/Server/InMemoryMcpTaskStore.cs index b2f9b050d..c8fd4339e 100644 --- a/src/ModelContextProtocol.Core/Server/InMemoryMcpTaskStore.cs +++ b/src/ModelContextProtocol.Core/Server/InMemoryMcpTaskStore.cs @@ -1,543 +1,181 @@ using ModelContextProtocol.Protocol; using System.Collections.Concurrent; +using System.Collections.Immutable; using System.Diagnostics.CodeAnalysis; using System.Text.Json; -#if MCP_TEST_TIME_PROVIDER -namespace ModelContextProtocol.Tests.Internal; -#else -namespace ModelContextProtocol; -#endif +namespace ModelContextProtocol.Server; /// -/// Provides an in-memory implementation of for development and testing. +/// Provides an in-memory implementation of for development and testing scenarios. /// /// /// -/// This implementation uses thread-safe concurrent collections and is suitable for single-server -/// scenarios and testing. It is not recommended for production multi-server deployments as tasks -/// are stored only in memory and are lost on server restart. +/// This implementation stores all task state in memory using immutable snapshots and +/// compare-and-swap updates for thread safety without locks. +/// Tasks are not persisted across process restarts. /// /// -/// Features: -/// -/// Thread-safe operations using -/// Automatic TTL-based cleanup via background task -/// Session-based isolation when sessionId is provided -/// Configurable default TTL and maximum TTL limits -/// +/// For production scenarios requiring durability, session isolation, or TTL-based cleanup, +/// implement a custom . /// /// -[Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] -public sealed class InMemoryMcpTaskStore : IMcpTaskStore, IDisposable +[Experimental(Experimentals.Extensions_DiagnosticId, UrlFormat = Experimentals.Extensions_Url)] +public class InMemoryMcpTaskStore : IMcpTaskStore { - private readonly ConcurrentDictionary _tasks = new(); - private readonly TimeSpan? _defaultTtl; - private readonly TimeSpan? _maxTtl; - private readonly TimeSpan _pollInterval; -#if MCP_TEST_TIME_PROVIDER - private readonly ITimer? _cleanupTimer; -#else - private readonly Timer? _cleanupTimer; -#endif - private readonly int _pageSize; - private readonly int? _maxTasks; - private readonly int? _maxTasksPerSession; -#if MCP_TEST_TIME_PROVIDER - private readonly TimeProvider _timeProvider; -#endif + private readonly ConcurrentDictionary _tasks = new(); /// - /// Initializes a new instance of the class. + /// Gets or sets the default poll interval in milliseconds for new tasks. /// - /// - /// Default TTL to use when task creation does not specify a TTL. Null means unlimited. - /// - /// - /// Maximum TTL allowed. If a task requests a longer TTL, it will be capped to this value. - /// Null means no maximum limit. - /// - /// - /// Advertised polling interval for tasks. Default is 1 second. - /// This value is used when creating new tasks to indicate how frequently clients should poll for updates. - /// - /// - /// Interval for running background cleanup of expired tasks. Default is 1 minute. - /// Pass to disable automatic cleanup. - /// - /// - /// Maximum number of tasks to return per page in . Default is 100. - /// - /// - /// Maximum number of tasks allowed in the store globally. Null means unlimited. - /// When the limit is reached, will throw . - /// - /// - /// Maximum number of tasks allowed per session. Null means unlimited. - /// When the limit is reached for a session, will throw . - /// - public InMemoryMcpTaskStore( - TimeSpan? defaultTtl = null, - TimeSpan? maxTtl = null, - TimeSpan? pollInterval = null, - TimeSpan? cleanupInterval = null, - int pageSize = 100, - int? maxTasks = null, - int? maxTasksPerSession = null) - { - if (defaultTtl.HasValue && maxTtl.HasValue && defaultTtl.Value > maxTtl.Value) - { - throw new ArgumentException( - $"Default TTL ({defaultTtl.Value}) cannot exceed maximum TTL ({maxTtl.Value}).", - nameof(defaultTtl)); - } - - pollInterval ??= TimeSpan.FromSeconds(1); - if (pollInterval <= TimeSpan.Zero) - { - throw new ArgumentOutOfRangeException( - nameof(pollInterval), - pollInterval, - "Poll interval must be positive."); - } - - if (pageSize <= 0) - { - throw new ArgumentOutOfRangeException( - nameof(pageSize), - pageSize, - "Page size must be positive."); - } - - if (maxTasks is <= 0) - { - throw new ArgumentOutOfRangeException( - nameof(maxTasks), - maxTasks, - "Max tasks must be positive."); - } - - if (maxTasksPerSession is <= 0) - { - throw new ArgumentOutOfRangeException( - nameof(maxTasksPerSession), - maxTasksPerSession, - "Max tasks per session must be positive."); - } - - _defaultTtl = defaultTtl; - _maxTtl = maxTtl; - _pollInterval = pollInterval.Value; - _pageSize = pageSize; - _maxTasks = maxTasks; - _maxTasksPerSession = maxTasksPerSession; -#if MCP_TEST_TIME_PROVIDER - _timeProvider = TimeProvider.System; -#endif - - cleanupInterval ??= TimeSpan.FromMinutes(1); - if (cleanupInterval.Value != Timeout.InfiniteTimeSpan) - { -#if MCP_TEST_TIME_PROVIDER - _cleanupTimer = _timeProvider.CreateTimer(CleanupExpiredTasks, null, cleanupInterval.Value, cleanupInterval.Value); -#else - _cleanupTimer = new Timer(CleanupExpiredTasks, null, cleanupInterval.Value, cleanupInterval.Value); -#endif - } - } + /// The default is 1000 milliseconds. + public long DefaultPollIntervalMs { get; set; } = 1000; -#if MCP_TEST_TIME_PROVIDER /// - /// Initializes a new instance of the class with a custom time provider. - /// This constructor is only available for testing purposes. + /// Gets or sets the default time-to-live in milliseconds for new tasks, or for unlimited. /// - internal InMemoryMcpTaskStore( - TimeSpan? defaultTtl, - TimeSpan? maxTtl, - TimeSpan? pollInterval, - TimeSpan? cleanupInterval, - int pageSize, - int? maxTasks, - int? maxTasksPerSession, - TimeProvider timeProvider) - : this(defaultTtl, maxTtl, pollInterval, cleanupInterval, pageSize, maxTasks, maxTasksPerSession) - { - _timeProvider = timeProvider ?? TimeProvider.System; - } -#endif + public long? DefaultTtlMs { get; set; } /// - public Task CreateTaskAsync( - McpTaskMetadata taskParams, - RequestId requestId, - JsonRpcRequest request, - string? sessionId = null, - CancellationToken cancellationToken = default) + public Task CreateTaskAsync(CancellationToken cancellationToken = default) { - // Check global task limit - if (_maxTasks is { } maxTasks && _tasks.Count >= maxTasks) - { - throw new InvalidOperationException( - $"Maximum number of tasks ({maxTasks}) has been reached. Cannot create new task."); - } + var taskId = Guid.NewGuid().ToString("N"); + var now = DateTimeOffset.UtcNow; - // Check per-session task limit - if (_maxTasksPerSession is { } maxPerSession && sessionId is not null) - { - var sessionTaskCount = _tasks.Values.Count(e => e.SessionId == sessionId && !IsExpired(e)); - if (sessionTaskCount >= maxPerSession) - { - throw new InvalidOperationException( - $"Maximum number of tasks per session ({maxPerSession}) has been reached for session '{sessionId}'. Cannot create new task."); - } - } - - var taskId = GenerateTaskId(); - var now = GetUtcNow(); + var info = new McpTaskInfo(taskId, McpTaskStatus.Working, now, now, DefaultTtlMs, DefaultPollIntervalMs); + _tasks[taskId] = info; - // Determine TTL: use requested, fall back to default, respect max limit - var ttl = taskParams.TimeToLive ?? _defaultTtl; - if (ttl is { } ttlValue && _maxTtl is { } maxTtlValue && ttlValue > maxTtlValue) - { - ttl = maxTtlValue; - } - - TaskEntry entry = new() - { - TaskId = taskId, - Status = McpTaskStatus.Working, - CreatedAt = now, - LastUpdatedAt = now, - TimeToLive = ttl, - PollInterval = _pollInterval, - RequestId = requestId, - Request = request, - SessionId = sessionId - }; - - if (!_tasks.TryAdd(taskId, entry)) - { - // This should be extremely rare with GUID-based IDs - throw new InvalidOperationException($"Task ID collision: {taskId}"); - } - - return Task.FromResult(entry.ToMcpTask()); + return Task.FromResult(info); } /// - public Task GetTaskAsync(string taskId, string? sessionId = null, CancellationToken cancellationToken = default) + public Task GetTaskAsync(string taskId, CancellationToken cancellationToken = default) { - if (!_tasks.TryGetValue(taskId, out var entry)) - { - return Task.FromResult(null); - } + _tasks.TryGetValue(taskId, out var info); + return Task.FromResult(info); + } - // Enforce session isolation if sessionId is provided - if (sessionId != null && entry.SessionId != sessionId) + /// + public Task SetCompletedAsync(string taskId, JsonElement result, CancellationToken cancellationToken = default) + { + Update(taskId, entry => entry with { - return Task.FromResult(null); - } + Status = McpTaskStatus.Completed, + Result = result, + LastUpdatedAt = DateTimeOffset.UtcNow, + }); - return Task.FromResult(entry.ToMcpTask()); + return Task.CompletedTask; } /// - public Task StoreTaskResultAsync( - string taskId, - McpTaskStatus status, - JsonElement result, - string? sessionId = null, - CancellationToken cancellationToken = default) + public Task SetFailedAsync(string taskId, JsonElement error, CancellationToken cancellationToken = default) { - if (status is not (McpTaskStatus.Completed or McpTaskStatus.Failed)) + Update(taskId, entry => entry with { - throw new ArgumentException( - $"Status must be {nameof(McpTaskStatus.Completed)} or {nameof(McpTaskStatus.Failed)}.", - nameof(status)); - } - - // Retry loop for optimistic concurrency - while (true) - { - if (!_tasks.TryGetValue(taskId, out var entry)) - { - throw new InvalidOperationException($"Task not found: {taskId}"); - } - - // Enforce session isolation - if (sessionId != null && entry.SessionId != sessionId) - { - throw new InvalidOperationException($"Task not found: {taskId}"); - } - - // Prevent overwriting terminal state - if (IsTerminalStatus(entry.Status)) - { - throw new InvalidOperationException( - $"Cannot store result for task in terminal state: {entry.Status}"); - } - - var updatedEntry = new TaskEntry(entry) - { - Status = status, - LastUpdatedAt = GetUtcNow(), - StoredResult = result - }; - - if (_tasks.TryUpdate(taskId, updatedEntry, entry)) - { - return Task.FromResult(updatedEntry.ToMcpTask()); - } + Status = McpTaskStatus.Failed, + Error = error, + LastUpdatedAt = DateTimeOffset.UtcNow, + }); - // Entry was modified by another thread, retry - } + return Task.CompletedTask; } /// - public Task GetTaskResultAsync(string taskId, string? sessionId = null, CancellationToken cancellationToken = default) + public Task SetCancelledAsync(string taskId, CancellationToken cancellationToken = default) { if (!_tasks.TryGetValue(taskId, out var entry)) { - throw new InvalidOperationException($"Task not found: {taskId}"); + return Task.FromResult(false); } - // Enforce session isolation - if (sessionId != entry.SessionId) + if (entry.Status is McpTaskStatus.Completed or McpTaskStatus.Failed or McpTaskStatus.Cancelled) { - throw new InvalidOperationException($"Invalid sessionId: {sessionId} provided for {taskId}"); + return Task.FromResult(false); } - if (entry.StoredResult is not { } storedResult) - { - throw new InvalidOperationException($"No result stored for task: {taskId}"); - } + Update(taskId, e => e.Status is McpTaskStatus.Completed or McpTaskStatus.Failed or McpTaskStatus.Cancelled + ? e + : e with { Status = McpTaskStatus.Cancelled, LastUpdatedAt = DateTimeOffset.UtcNow }); - return Task.FromResult(storedResult); + return Task.FromResult(true); } /// - public Task UpdateTaskStatusAsync( + public Task ResolveInputRequestsAsync( string taskId, - McpTaskStatus status, - string? statusMessage, - string? sessionId = null, + IEnumerable inputResponseKeys, CancellationToken cancellationToken = default) { - // Retry loop for optimistic concurrency - while (true) + Update(taskId, entry => { - if (!_tasks.TryGetValue(taskId, out var entry)) - { - throw new InvalidOperationException($"Task not found: {taskId}"); - } + var requests = entry.InputRequests as ImmutableDictionary + ?? entry.InputRequests?.ToImmutableDictionary() + ?? ImmutableDictionary.Empty; - // Enforce session isolation - if (sessionId != null && entry.SessionId != sessionId) + foreach (var key in inputResponseKeys) { - throw new InvalidOperationException($"Task not found: {taskId}"); + requests = requests.Remove(key); } - var updatedEntry = new TaskEntry(entry) + var status = requests.IsEmpty ? McpTaskStatus.Working : entry.Status; + + return entry with { + InputRequests = requests, Status = status, - StatusMessage = statusMessage, - LastUpdatedAt = GetUtcNow(), + LastUpdatedAt = DateTimeOffset.UtcNow, }; + }); - if (_tasks.TryUpdate(taskId, updatedEntry, entry)) - { - return Task.FromResult(updatedEntry.ToMcpTask()); - } - - // Entry was modified by another thread, retry - } + return Task.CompletedTask; } /// - public Task ListTasksAsync( - string? cursor = null, - string? sessionId = null, + public Task SetInputRequestsAsync( + string taskId, + IDictionary inputRequests, CancellationToken cancellationToken = default) { - // Stream enumeration - filter by session, exclude expired, apply keyset pagination - var query = _tasks.Values - .Where(e => sessionId == null || e.SessionId == sessionId) - .Where(e => !IsExpired(e)); - - // Apply keyset filter if cursor provided: TaskId > cursor - // UUID v7 task IDs are monotonically increasing and inherently time-ordered - if (cursor != null) + Update(taskId, entry => { - query = query.Where(e => string.CompareOrdinal(e.TaskId, cursor) > 0); - } - - // Order by TaskId for stable, deterministic pagination - // UUID v7 task IDs sort chronologically due to embedded timestamp - var page = query - .OrderBy(e => e.TaskId, StringComparer.Ordinal) - .Take(_pageSize + 1) // Take one extra to check if there's a next page - .Select(e => e.ToMcpTask()) - .ToList(); + var requests = entry.InputRequests as ImmutableDictionary + ?? entry.InputRequests?.ToImmutableDictionary() + ?? ImmutableDictionary.Empty; - // Set nextCursor if we have more results - string? nextCursor; - if (page.Count > _pageSize) - { - var lastItemInPage = page[_pageSize - 1]; // Last item we'll actually return - nextCursor = lastItemInPage.TaskId; - page.RemoveAt(_pageSize); // Remove the extra item - } - else - { - nextCursor = null; - } + foreach (var kvp in inputRequests) + { + requests = requests.SetItem(kvp.Key, kvp.Value); + } - return Task.FromResult(new ListTasksResult - { - Tasks = page.ToArray(), - NextCursor = nextCursor + return entry with + { + InputRequests = requests, + Status = McpTaskStatus.InputRequired, + LastUpdatedAt = DateTimeOffset.UtcNow, + }; }); + + return Task.CompletedTask; } - /// - public Task CancelTaskAsync(string taskId, string? sessionId = null, CancellationToken cancellationToken = default) + private void Update(string taskId, Func transform) { - // Retry loop for optimistic concurrency + SpinWait spin = default; while (true) { - if (!_tasks.TryGetValue(taskId, out var entry)) - { - throw new InvalidOperationException($"Task not found: {taskId}"); - } - - // Enforce session isolation - if (sessionId != null && entry.SessionId != sessionId) + if (!_tasks.TryGetValue(taskId, out var current)) { - throw new InvalidOperationException($"Task not found: {taskId}"); + throw new InvalidOperationException($"Task '{taskId}' not found."); } - // If already in terminal state, return unchanged - if (IsTerminalStatus(entry.Status)) + var updated = transform(current); + if (ReferenceEquals(updated, current) || _tasks.TryUpdate(taskId, updated, current)) { - return Task.FromResult(entry.ToMcpTask()); + return; } - var updatedEntry = new TaskEntry(entry) - { - Status = McpTaskStatus.Cancelled, - LastUpdatedAt = GetUtcNow(), - }; - - if (_tasks.TryUpdate(taskId, updatedEntry, entry)) - { - return Task.FromResult(updatedEntry.ToMcpTask()); - } - - // Entry was modified by another thread, retry + spin.SpinOnce(); } } - - /// - /// Disposes the task store and stops background cleanup. - /// - public void Dispose() - { - _cleanupTimer?.Dispose(); - } - - private string GenerateTaskId() => - IdHelpers.CreateMonotonicId(GetUtcNow()); - - private static bool IsTerminalStatus(McpTaskStatus status) => - status is McpTaskStatus.Completed or McpTaskStatus.Failed or McpTaskStatus.Cancelled; - -#if MCP_TEST_TIME_PROVIDER - private DateTimeOffset GetUtcNow() => _timeProvider.GetUtcNow(); -#else - private static DateTimeOffset GetUtcNow() => DateTimeOffset.UtcNow; -#endif - -#if MCP_TEST_TIME_PROVIDER - private bool IsExpired(TaskEntry entry) -#else - private static bool IsExpired(TaskEntry entry) -#endif - { - if (entry.TimeToLive == null) - { - return false; // Unlimited lifetime - } - - var expirationTime = entry.CreatedAt + entry.TimeToLive.Value; - return GetUtcNow() >= expirationTime; - } - - private void CleanupExpiredTasks(object? state) - { - var expiredTaskIds = _tasks - .Where(kvp => IsExpired(kvp.Value)) - .Select(kvp => kvp.Key) - .ToList(); - - foreach (var taskId in expiredTaskIds) - { - _tasks.TryRemove(taskId, out _); - } - } - - private sealed class TaskEntry - { - // Flattened McpTask properties - public required string TaskId { get; init; } - public required McpTaskStatus Status { get; init; } - public string? StatusMessage { get; init; } - public required DateTimeOffset CreatedAt { get; init; } - public required DateTimeOffset LastUpdatedAt { get; init; } - public TimeSpan? TimeToLive { get; init; } - public TimeSpan? PollInterval { get; init; } - - // Request metadata - public required RequestId RequestId { get; init; } - public required JsonRpcRequest Request { get; init; } - public required string? SessionId { get; init; } - public JsonElement? StoredResult { get; init; } - - /// - /// Copy constructor for creating modified copies. - /// - [SetsRequiredMembers] - public TaskEntry(TaskEntry source) - { - TaskId = source.TaskId; - Status = source.Status; - StatusMessage = source.StatusMessage; - CreatedAt = source.CreatedAt; - LastUpdatedAt = source.LastUpdatedAt; - TimeToLive = source.TimeToLive; - PollInterval = source.PollInterval; - RequestId = source.RequestId; - Request = source.Request; - SessionId = source.SessionId; - StoredResult = source.StoredResult; - } - - /// - /// Default constructor for initial creation. - /// - public TaskEntry() { } - - /// - /// Converts this entry back to an McpTask for external consumption. - /// - public McpTask ToMcpTask() => new() - { - TaskId = TaskId, - Status = Status, - StatusMessage = StatusMessage, - CreatedAt = CreatedAt, - LastUpdatedAt = LastUpdatedAt, - TimeToLive = TimeToLive, - PollInterval = PollInterval - }; - } } diff --git a/src/ModelContextProtocol.Core/Server/McpRequestFilters.cs b/src/ModelContextProtocol.Core/Server/McpRequestFilters.cs index 5044f8928..e778d9d1b 100644 --- a/src/ModelContextProtocol.Core/Server/McpRequestFilters.cs +++ b/src/ModelContextProtocol.Core/Server/McpRequestFilters.cs @@ -36,9 +36,15 @@ public IList> ListTool /// Gets or sets the filters for the call-tool handler pipeline. /// /// + /// /// These filters wrap handlers that are invoked when a client makes a call to a tool that isn't found in the collection. /// The filters can modify, log, or perform additional operations on requests and responses for /// requests. The handler should implement logic to execute the requested tool and return appropriate results. + /// + /// + /// Cannot be used together with . If both are non-empty at configuration time, + /// an will be thrown. + /// /// public IList> CallToolFilters { @@ -50,6 +56,31 @@ public IList> CallToolFi } } + /// + /// Gets or sets the filters for the call-tool handler pipeline with task support. + /// + /// + /// + /// These filters wrap the task-augmented call-tool handler whose return type is + /// . Use these filters when the server's tool pipeline + /// supports returning either an immediate or a + /// for asynchronous execution. + /// + /// + /// Cannot be used together with . If both are non-empty at configuration time, + /// an will be thrown. + /// + /// + public IList>> CallToolWithTaskFilters + { + get => field ??= []; + set + { + Throw.IfNull(value); + field = value; + } + } + /// /// Gets or sets the filters for the list-prompts handler pipeline. /// diff --git a/src/ModelContextProtocol.Core/Server/McpServer.Methods.cs b/src/ModelContextProtocol.Core/Server/McpServer.Methods.cs index 3caaca5a6..4b4d93fd8 100644 --- a/src/ModelContextProtocol.Core/Server/McpServer.Methods.cs +++ b/src/ModelContextProtocol.Core/Server/McpServer.Methods.cs @@ -53,63 +53,29 @@ public static McpServer Create( /// is . /// The client does not support sampling. /// The request failed or the client returned an error response. - /// - /// When called during task-augmented tool execution, this method automatically updates the task - /// status to while waiting for the client response, - /// then returns to when the response is received. - /// - public async ValueTask SampleAsync( + public ValueTask SampleAsync( CreateMessageRequestParams requestParams, CancellationToken cancellationToken = default) { Throw.IfNull(requestParams); - ThrowIfSamplingUnsupported(); - return await SendRequestWithTaskStatusTrackingAsync( - RequestMethods.SamplingCreateMessage, - requestParams, - McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, - McpJsonUtilities.JsonContext.Default.CreateMessageResult, - "Waiting for sampling response", - cancellationToken).ConfigureAwait(false); - } + // If executing inside a background task, redirect sampling through the task store. + if (McpTaskExecutionContext.Current.Value is { } taskContext) + { + return SendRequestViaTaskAsync(taskContext, RequestMethods.SamplingCreateMessage, requestParams, + McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, + McpJsonUtilities.JsonContext.Default.CreateMessageResult, + cancellationToken); + } - /// - /// Requests to sample an LLM via the client as a task, allowing the server to poll for completion. - /// - /// The parameters for the sampling request. - /// The task metadata specifying TTL and other task-related options. - /// The to monitor for cancellation requests. - /// An representing the created task on the client. - /// or is . - /// The client does not support sampling or task-augmented sampling. - /// The request failed or the client returned an error response. - /// - /// Use to poll for task status and - /// (with ) to retrieve the final result when the task completes. - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public async ValueTask SampleAsTaskAsync( - CreateMessageRequestParams requestParams, - McpTaskMetadata taskMetadata, - CancellationToken cancellationToken = default) - { - Throw.IfNull(requestParams); - Throw.IfNull(taskMetadata); ThrowIfSamplingUnsupported(); - ThrowIfTasksUnsupportedForSampling(); - - // Set the task metadata on the request - requestParams.Task = taskMetadata; - var result = await SendRequestAsync( + return SendRequestAsync( RequestMethods.SamplingCreateMessage, requestParams, McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, - McpJsonUtilities.JsonContext.Default.CreateTaskResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - return result.Task; + McpJsonUtilities.JsonContext.Default.CreateMessageResult, + cancellationToken: cancellationToken); } /// @@ -278,6 +244,16 @@ public ValueTask RequestRootsAsync( CancellationToken cancellationToken = default) { Throw.IfNull(requestParams); + + // If executing inside a background task, redirect through the task store. + if (McpTaskExecutionContext.Current.Value is { } taskContext) + { + return SendRequestViaTaskAsync(taskContext, RequestMethods.RootsList, requestParams, + McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, + McpJsonUtilities.JsonContext.Default.ListRootsResult, + cancellationToken); + } + ThrowIfRootsUnsupported(); return SendRequestAsync( @@ -297,360 +273,32 @@ public ValueTask RequestRootsAsync( /// is . /// The client does not support elicitation. /// The request failed or the client returned an error response. - /// - /// When called during task-augmented tool execution, this method automatically updates the task - /// status to while waiting for user input, - /// then returns to when the response is received. - /// public async ValueTask ElicitAsync( ElicitRequestParams requestParams, CancellationToken cancellationToken = default) { Throw.IfNull(requestParams); - ThrowIfElicitationUnsupported(requestParams); - var result = await SendRequestWithTaskStatusTrackingAsync( - RequestMethods.ElicitationCreate, - requestParams, - McpJsonUtilities.JsonContext.Default.ElicitRequestParams, - McpJsonUtilities.JsonContext.Default.ElicitResult, - "Waiting for user input", - cancellationToken).ConfigureAwait(false); - - return ElicitResult.WithDefaults(requestParams, result); - } + // If executing inside a background task, redirect elicitation through the task store. + if (McpTaskExecutionContext.Current.Value is { } taskContext) + { + var taskResult = await SendRequestViaTaskAsync(taskContext, RequestMethods.ElicitationCreate, requestParams, + McpJsonUtilities.JsonContext.Default.ElicitRequestParams, + McpJsonUtilities.JsonContext.Default.ElicitResult, + cancellationToken).ConfigureAwait(false); + return taskResult ?? new ElicitResult { Action = "cancel" }; + } - /// - /// Requests additional information from the user via the client as a task, allowing the server to poll for completion. - /// - /// The parameters for the elicitation request. - /// The task metadata specifying TTL and other task-related options. - /// The to monitor for cancellation requests. - /// An representing the created task on the client. - /// or is . - /// The client does not support elicitation or task-augmented elicitation. - /// The request failed or the client returned an error response. - /// - /// Use to poll for task status and - /// (with ) to retrieve the final result when the task completes. - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public async ValueTask ElicitAsTaskAsync( - ElicitRequestParams requestParams, - McpTaskMetadata taskMetadata, - CancellationToken cancellationToken = default) - { - Throw.IfNull(requestParams); - Throw.IfNull(taskMetadata); ThrowIfElicitationUnsupported(requestParams); - ThrowIfTasksUnsupportedForElicitation(); - - // Set the task metadata on the request - requestParams.Task = taskMetadata; var result = await SendRequestAsync( RequestMethods.ElicitationCreate, requestParams, McpJsonUtilities.JsonContext.Default.ElicitRequestParams, - McpJsonUtilities.JsonContext.Default.CreateTaskResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - return result.Task; - } - - /// - /// Retrieves the current state of a specific task from the client. - /// - /// The unique identifier of the task to retrieve. - /// The to monitor for cancellation requests. The default is . - /// The current state of the task. - /// is . - /// is empty or composed entirely of whitespace. - /// The client does not support tasks. - /// The request failed or the client returned an error response. - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public async ValueTask GetTaskAsync( - string taskId, - CancellationToken cancellationToken = default) - { - Throw.IfNullOrWhiteSpace(taskId); - ThrowIfTasksUnsupported(); - - var result = await SendRequestAsync( - RequestMethods.TasksGet, - new GetTaskRequestParams { TaskId = taskId }, - McpJsonUtilities.JsonContext.Default.GetTaskRequestParams, - McpJsonUtilities.JsonContext.Default.GetTaskResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - // Convert GetTaskResult to McpTask - return new McpTask - { - TaskId = result.TaskId, - Status = result.Status, - StatusMessage = result.StatusMessage, - CreatedAt = result.CreatedAt, - LastUpdatedAt = result.LastUpdatedAt, - TimeToLive = result.TimeToLive, - PollInterval = result.PollInterval - }; - } - - /// - /// Retrieves the result of a completed task from the client, blocking until the task reaches a terminal state. - /// - /// The type to deserialize the task result into. - /// The unique identifier of the task whose result to retrieve. - /// Optional serializer options for deserializing the result. - /// The to monitor for cancellation requests. The default is . - /// The result of the task, deserialized into type . - /// is . - /// is empty or composed entirely of whitespace. - /// The client does not support tasks. - /// The request failed or the client returned an error response. - /// - /// - /// This method sends a tasks/result request to the client, which will block until the task completes if it hasn't already. - /// The client handles all polling logic internally. - /// - /// - /// For sampling tasks, use as . - /// For elicitation tasks, use as . - /// - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public async ValueTask GetTaskResultAsync( - string taskId, - JsonSerializerOptions? jsonSerializerOptions = null, - CancellationToken cancellationToken = default) - { - Throw.IfNullOrWhiteSpace(taskId); - ThrowIfTasksUnsupported(); - - var result = await SendRequestAsync( - RequestMethods.TasksResult, - new GetTaskPayloadRequestParams { TaskId = taskId }, - McpJsonUtilities.JsonContext.Default.GetTaskPayloadRequestParams, - McpJsonUtilities.JsonContext.Default.JsonElement, - cancellationToken: cancellationToken).ConfigureAwait(false); - - var serializerOptions = jsonSerializerOptions ?? McpJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); - - var typeInfo = serializerOptions.GetTypeInfo(); - return result.Deserialize(typeInfo); - } - - /// - /// Retrieves a list of all tasks from the client. - /// - /// The to monitor for cancellation requests. The default is . - /// A list of all tasks. - /// The client does not support tasks or task listing. - /// The request failed or the client returned an error response. - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public async ValueTask> ListTasksAsync( - CancellationToken cancellationToken = default) - { - ThrowIfTasksUnsupported(); - ThrowIfTaskListingUnsupported(); - - List? tasks = null; - ListTasksRequestParams requestParams = new(); - do - { - var taskResults = await ListTasksAsync(requestParams, cancellationToken).ConfigureAwait(false); - if (tasks is null) - { - tasks = new List(taskResults.Tasks.Count); - } - - foreach (var mcpTask in taskResults.Tasks) - { - tasks.Add(mcpTask); - } - - requestParams.Cursor = taskResults.NextCursor; - } - while (requestParams.Cursor is not null); - - return tasks; - } - - /// - /// Retrieves a list of tasks from the client. - /// - /// The request parameters to send in the request. - /// The to monitor for cancellation requests. The default is . - /// The result of the request as provided by the client. - /// is . - /// The client does not support tasks or task listing. - /// The request failed or the client returned an error response. - /// - /// The overload retrieves all tasks by automatically handling pagination. - /// This overload works with the lower-level and , returning the raw result from the client. - /// Any pagination needs to be managed by the caller. - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public ValueTask ListTasksAsync( - ListTasksRequestParams requestParams, - CancellationToken cancellationToken = default) - { - Throw.IfNull(requestParams); - ThrowIfTasksUnsupported(); - ThrowIfTaskListingUnsupported(); - - return SendRequestAsync( - RequestMethods.TasksList, - requestParams, - McpJsonUtilities.JsonContext.Default.ListTasksRequestParams, - McpJsonUtilities.JsonContext.Default.ListTasksResult, - cancellationToken: cancellationToken); - } - - /// - /// Cancels a running task on the client. - /// - /// The unique identifier of the task to cancel. - /// The to monitor for cancellation requests. The default is . - /// The updated state of the task after cancellation. - /// is . - /// is empty or composed entirely of whitespace. - /// The client does not support tasks or task cancellation. - /// The request failed or the client returned an error response. - /// - /// Cancelling a task requests that the client stop execution. The client may not immediately cancel the task, - /// and may choose to allow the task to complete if it's close to finishing. - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public async ValueTask CancelTaskAsync( - string taskId, - CancellationToken cancellationToken = default) - { - Throw.IfNullOrWhiteSpace(taskId); - ThrowIfTasksUnsupported(); - ThrowIfTaskCancellationUnsupported(); - - var result = await SendRequestAsync( - RequestMethods.TasksCancel, - new CancelMcpTaskRequestParams { TaskId = taskId }, - McpJsonUtilities.JsonContext.Default.CancelMcpTaskRequestParams, - McpJsonUtilities.JsonContext.Default.CancelMcpTaskResult, + McpJsonUtilities.JsonContext.Default.ElicitResult, cancellationToken: cancellationToken).ConfigureAwait(false); - // Convert CancelMcpTaskResult to McpTask - return new McpTask - { - TaskId = result.TaskId, - Status = result.Status, - StatusMessage = result.StatusMessage, - CreatedAt = result.CreatedAt, - LastUpdatedAt = result.LastUpdatedAt, - TimeToLive = result.TimeToLive, - PollInterval = result.PollInterval - }; - } - - /// - /// Polls a task on the client until it reaches a terminal state. - /// - /// The unique identifier of the task to poll. - /// The to monitor for cancellation requests. The default is . - /// The task in its terminal state. - /// is . - /// is empty or composed entirely of whitespace. - /// The client does not support tasks. - /// The request failed or the client returned an error response. - /// - /// - /// This method repeatedly calls until the task reaches a terminal status. - /// It respects the returned by the client to determine how long - /// to wait between polling attempts. - /// - /// - /// For retrieving the actual result of a completed task, use - /// or . - /// - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public async ValueTask PollTaskUntilCompleteAsync( - string taskId, - CancellationToken cancellationToken = default) - { - Throw.IfNullOrWhiteSpace(taskId); - - McpTask task; - do - { - task = await GetTaskAsync(taskId, cancellationToken).ConfigureAwait(false); - - // If task is in a terminal state, we're done - if (task.Status is McpTaskStatus.Completed or McpTaskStatus.Failed or McpTaskStatus.Cancelled) - { - break; - } - - // Wait for the poll interval before checking again (default to 1 second) - var pollInterval = task.PollInterval ?? TimeSpan.FromSeconds(1); - await Task.Delay(pollInterval, cancellationToken).ConfigureAwait(false); - } - while (true); - - return task; - } - - /// - /// Waits for a task on the client to complete and retrieves its result. - /// - /// The type to deserialize the task result into. - /// The unique identifier of the task whose result to retrieve. - /// Optional serializer options for deserializing the result. - /// The to monitor for cancellation requests. The default is . - /// A tuple containing the final task state and its result. - /// is . - /// is empty or composed entirely of whitespace. - /// The client does not support tasks. - /// The task failed or was cancelled. - /// - /// - /// This method combines and - /// to provide a convenient way to wait for a task to complete and retrieve its result in a single call. - /// - /// - /// If the task completes with a status of or , - /// an is thrown. - /// - /// - /// For sampling tasks, use as . - /// For elicitation tasks, use as . - /// - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public async ValueTask<(McpTask Task, TResult? Result)> WaitForTaskResultAsync( - string taskId, - JsonSerializerOptions? jsonSerializerOptions = null, - CancellationToken cancellationToken = default) - { - Throw.IfNullOrWhiteSpace(taskId); - - // Poll until task reaches terminal state - var task = await PollTaskUntilCompleteAsync(taskId, cancellationToken).ConfigureAwait(false); - - // Check for failure or cancellation - if (task.Status == McpTaskStatus.Failed) - { - throw new McpException($"Task '{taskId}' failed: {task.StatusMessage ?? "Unknown error"}"); - } - - if (task.Status == McpTaskStatus.Cancelled) - { - throw new McpException($"Task '{taskId}' was cancelled"); - } - - // Retrieve the result - var result = await GetTaskResultAsync(taskId, jsonSerializerOptions, cancellationToken).ConfigureAwait(false); - - return (task, result); + return ElicitResult.WithDefaults(requestParams, result); } /// @@ -715,6 +363,26 @@ public async ValueTask> ElicitAsync( return new ElicitResult { Action = raw.Action, Content = typed }; } + /// + /// Sends a task status notification to the connected client. + /// + /// The task status notification parameters to send. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous send operation. + /// is . + public Task SendTaskStatusNotificationAsync( + TaskStatusNotificationParams notificationParams, + CancellationToken cancellationToken = default) + { + Throw.IfNull(notificationParams); + + return SendNotificationAsync( + NotificationMethods.TaskStatusNotification, + notificationParams, + McpJsonUtilities.JsonContext.Default.TaskStatusNotificationParams, + cancellationToken); + } + /// /// Builds a request schema for elicitation based on the public serializable properties of . /// @@ -864,6 +532,74 @@ private void ThrowIfRootsUnsupported() } } + /// + /// Creates a scope that redirects server-initiated requests (elicitation, sampling, list roots) through + /// the task store as input requests for the duration of the scope. Use this when executing tool logic + /// in the background as a task, so that any server-to-client requests are surfaced to the client via + /// the task's state instead of direct JSON-RPC messages. + /// + /// The task ID in the store. + /// The task store to write input requests to. + /// An that restores the previous context when disposed. + [Experimental(Experimentals.Extensions_DiagnosticId, UrlFormat = Experimentals.Extensions_Url)] + public IDisposable CreateMcpTaskScope( + string taskId, + IMcpTaskStore store) + { + Throw.IfNull(taskId); + Throw.IfNull(store); + + var previous = McpTaskExecutionContext.Current.Value; + McpTaskExecutionContext.Current.Value = new McpTaskExecutionContext + { + TaskId = taskId, + Store = store, + }; + return new McpTaskExecutionContext.Scope(previous); + } + + /// + /// Sends a server-initiated request through the task store as an input request, then awaits the response. + /// + private async ValueTask SendRequestViaTaskAsync( + McpTaskExecutionContext taskContext, + string method, + TRequest request, + JsonTypeInfo requestTypeInfo, + JsonTypeInfo responseTypeInfo, + CancellationToken cancellationToken) + { + var requestId = Guid.NewGuid().ToString("N"); + var paramsJson = JsonSerializer.SerializeToElement(request, requestTypeInfo); + + // Wrap in a {method, params} envelope so the client can dispatch by method name. + var envelope = new JsonObject + { + ["method"] = method, + ["params"] = JsonNode.Parse(paramsJson.GetRawText()), + }; + var requestJson = JsonSerializer.SerializeToElement(envelope, McpJsonUtilities.JsonContext.Default.JsonObject); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + TaskInputResponseWaiters[(taskContext.TaskId, requestId)] = tcs; + + try + { + await taskContext.Store.SetInputRequestsAsync( + taskContext.TaskId, + new Dictionary { [requestId] = requestJson }, + cancellationToken).ConfigureAwait(false); + + var responseJson = await tcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false); + + return JsonSerializer.Deserialize(responseJson, responseTypeInfo)!; + } + finally + { + TaskInputResponseWaiters.TryRemove((taskContext.TaskId, requestId), out _); + } + } + private void ThrowIfElicitationUnsupported(ElicitRequestParams request) { if (ClientCapabilities is null) @@ -908,120 +644,6 @@ private void ThrowIfElicitationUnsupported(ElicitRequestParams request) } } - private void ThrowIfTasksUnsupportedForSampling() - { - if (ClientCapabilities?.Tasks?.Requests?.Sampling?.CreateMessage is null) - { - if (ClientCapabilities is null) - { - throw new InvalidOperationException("Task-augmented sampling is not supported in stateless mode."); - } - - throw new InvalidOperationException("Client does not support task-augmented sampling requests."); - } - } - - private void ThrowIfTasksUnsupportedForElicitation() - { - if (ClientCapabilities?.Tasks?.Requests?.Elicitation?.Create is null) - { - if (ClientCapabilities is null) - { - throw new InvalidOperationException("Task-augmented elicitation is not supported in stateless mode."); - } - - throw new InvalidOperationException("Client does not support task-augmented elicitation requests."); - } - } - - private void ThrowIfTasksUnsupported() - { - if (ClientCapabilities?.Tasks is null) - { - if (ClientCapabilities is null) - { - throw new InvalidOperationException("Tasks are not supported in stateless mode."); - } - - throw new InvalidOperationException("Client does not support tasks."); - } - } - - private void ThrowIfTaskListingUnsupported() - { - if (ClientCapabilities?.Tasks?.List is null) - { - throw new InvalidOperationException("Client does not support task listing."); - } - } - - private void ThrowIfTaskCancellationUnsupported() - { - if (ClientCapabilities?.Tasks?.Cancel is null) - { - throw new InvalidOperationException("Client does not support task cancellation."); - } - } - - /// - /// Sends a request to the client, automatically updating task status to InputRequired during - /// the request when called within a task execution context. - /// - private async ValueTask SendRequestWithTaskStatusTrackingAsync( - string method, - TParams requestParams, - JsonTypeInfo paramsTypeInfo, - JsonTypeInfo resultTypeInfo, - string inputRequiredMessage, - CancellationToken cancellationToken) - where TParams : RequestParams - where TResult : notnull - { - var taskContext = TaskExecutionContext.Current; - - // If we're not in a task execution context, just send the request normally - if (taskContext is null) - { - return await SendRequestAsync(method, requestParams, paramsTypeInfo, resultTypeInfo, cancellationToken: cancellationToken).ConfigureAwait(false); - } - - // Update task status to InputRequired - var inputRequiredTask = await taskContext.TaskStore.UpdateTaskStatusAsync( - taskContext.TaskId, - Protocol.McpTaskStatus.InputRequired, - inputRequiredMessage, - taskContext.SessionId, - CancellationToken.None).ConfigureAwait(false); - - // Send notification if enabled - if (taskContext.SendNotifications && taskContext.NotifyTaskStatusFunc is not null) - { - _ = taskContext.NotifyTaskStatusFunc(inputRequiredTask, CancellationToken.None); - } - - try - { - // Send the actual request - return await SendRequestAsync(method, requestParams, paramsTypeInfo, resultTypeInfo, cancellationToken: cancellationToken).ConfigureAwait(false); - } - finally - { - // Update task status back to Working - var workingTask = await taskContext.TaskStore.UpdateTaskStatusAsync( - taskContext.TaskId, - Protocol.McpTaskStatus.Working, - null, // Clear status message - taskContext.SessionId, - CancellationToken.None).ConfigureAwait(false); - - // Send notification if enabled - if (taskContext.SendNotifications && taskContext.NotifyTaskStatusFunc is not null) - { - _ = taskContext.NotifyTaskStatusFunc(workingTask, CancellationToken.None); - } - } - } - /// Provides an implementation that's implemented via client sampling. private sealed class SamplingChatClient(McpServer server, JsonSerializerOptions serializerOptions) : IChatClient { @@ -1059,50 +681,6 @@ async IAsyncEnumerable IChatClient.GetStreamingResponseAsync void IDisposable.Dispose() { } // nop } - /// - /// Sends a task status notification to the connected client. - /// - /// The task whose status changed. - /// The to monitor for cancellation requests. - /// A task representing the asynchronous notification operation. - /// is . - /// The notification failed or the client returned an error response. - /// - /// - /// This method sends an optional status notification to inform the client of task state changes. - /// According to the MCP specification, receivers MAY send this notification but are not required to. - /// Clients must not rely on receiving these notifications and should continue polling via tasks/get. - /// - /// - /// The notification is sent using the standard notifications/tasks/status method and includes - /// the full task state information. - /// - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public Task NotifyTaskStatusAsync( - McpTask task, - CancellationToken cancellationToken = default) - { - Throw.IfNull(task); - - var notificationParams = new McpTaskStatusNotificationParams - { - TaskId = task.TaskId, - Status = task.Status, - StatusMessage = task.StatusMessage, - CreatedAt = task.CreatedAt, - LastUpdatedAt = task.LastUpdatedAt, - TimeToLive = task.TimeToLive, - PollInterval = task.PollInterval - }; - - return SendNotificationAsync( - NotificationMethods.TaskStatusNotification, - notificationParams, - McpJsonUtilities.JsonContext.Default.McpTaskStatusNotificationParams, - cancellationToken); - } - /// /// Provides an implementation for creating loggers /// that send logging message notifications to the client for logged messages. diff --git a/src/ModelContextProtocol.Core/Server/McpServer.cs b/src/ModelContextProtocol.Core/Server/McpServer.cs index b8b41bdc3..bfea5056a 100644 --- a/src/ModelContextProtocol.Core/Server/McpServer.cs +++ b/src/ModelContextProtocol.Core/Server/McpServer.cs @@ -1,4 +1,6 @@ +using System.Collections.Concurrent; using System.Diagnostics.CodeAnalysis; +using System.Text.Json; using ModelContextProtocol.Protocol; namespace ModelContextProtocol.Server; @@ -8,6 +10,12 @@ namespace ModelContextProtocol.Server; /// public abstract partial class McpServer : McpSession { + /// + /// Waiters for task-based input responses. Keyed by (taskId, requestId), signaled when + /// input responses arrive via tasks/update. + /// + internal virtual ConcurrentDictionary<(string TaskId, string RequestId), TaskCompletionSource> TaskInputResponseWaiters { get; } = new(); + /// /// Initializes a new instance of the class. /// diff --git a/src/ModelContextProtocol.Core/Server/McpServerHandlers.cs b/src/ModelContextProtocol.Core/Server/McpServerHandlers.cs index 6dbcea8af..2e6567f7c 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerHandlers.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerHandlers.cs @@ -42,8 +42,53 @@ public sealed class McpServerHandlers /// /// This handler is invoked when a client makes a call to a tool that isn't found in the collection. /// The handler should implement logic to execute the requested tool and return appropriate results. + /// Use instead if the tool may return a + /// for asynchronous execution. /// - public McpRequestHandler? CallToolHandler { get; set; } + /// is already set. + public McpRequestHandler? CallToolHandler + { + get; + set + { + if (value is not null && CallToolWithTaskHandler is not null) + { + throw new InvalidOperationException( + $"Cannot set {nameof(CallToolHandler)} when {nameof(CallToolWithTaskHandler)} is already set. Only one call tool handler may be configured."); + } + + field = value; + } + } + + /// + /// Gets or sets the handler for requests with task support. + /// + /// + /// + /// This handler is invoked when a client makes a call to a tool, allowing the tool to return either + /// a for immediate results or a for + /// long-running asynchronous operations. + /// + /// + /// Cannot be set if is already set. + /// + /// + /// is already set. + public McpRequestHandler>? CallToolWithTaskHandler + { + get; + set + { + if (value is not null && CallToolHandler is not null) + { + throw new InvalidOperationException( + $"Cannot set {nameof(CallToolWithTaskHandler)} when {nameof(CallToolHandler)} is already set. Only one call tool handler may be configured."); + } + + field = value; + } + } /// /// Gets or sets the handler for requests. @@ -156,6 +201,33 @@ public sealed class McpServerHandlers /// public McpRequestHandler? SetLoggingLevelHandler { get; set; } + /// + /// Gets or sets the handler for requests. + /// + /// + /// This handler is invoked when a client polls for the current state of a task. + /// The handler should return the appropriate subtype + /// based on the task's current status. + /// + public McpRequestHandler? GetTaskHandler { get; set; } + + /// + /// Gets or sets the handler for requests. + /// + /// + /// This handler is invoked when a client provides input responses for a task + /// that is in the state. + /// + public McpRequestHandler? UpdateTaskHandler { get; set; } + + /// + /// Gets or sets the handler for requests. + /// + /// + /// This handler is invoked when a client requests cancellation of an in-progress task. + /// + public McpRequestHandler? CancelTaskHandler { get; set; } + /// Gets or sets notification handlers to register with the server. /// /// diff --git a/src/ModelContextProtocol.Core/Server/McpServerImpl.cs b/src/ModelContextProtocol.Core/Server/McpServerImpl.cs index 04d11e016..d27ecc989 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerImpl.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerImpl.cs @@ -2,14 +2,16 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Protocol; +using System.Collections.Concurrent; using System.Runtime.CompilerServices; using System.Text.Json; +using System.Text.Json.Nodes; using System.Text.Json.Serialization.Metadata; namespace ModelContextProtocol.Server; /// -#pragma warning disable MCPEXP002 +#pragma warning disable MCPEXP001, MCPEXP002 internal sealed partial class McpServerImpl : McpServer { internal static Implementation DefaultImplementation { get; } = new() @@ -26,7 +28,6 @@ internal sealed partial class McpServerImpl : McpServer private readonly RequestHandlers _requestHandlers; private readonly McpSessionHandler _sessionHandler; private readonly SemaphoreSlim _disposeLock = new(1, 1); - private readonly McpTaskCancellationTokenProvider? _taskCancellationTokenProvider; private ClientCapabilities? _clientCapabilities; private Implementation? _clientInfo; @@ -68,12 +69,6 @@ public McpServerImpl(ITransport transport, McpServerOptions options, ILoggerFact _servicesScopePerRequest = options.ScopeRequests; _logger = loggerFactory?.CreateLogger() ?? NullLogger.Instance; - // Only allocate the cancellation token provider if a task store is configured - if (options.TaskStore is not null) - { - _taskCancellationTokenProvider = new McpTaskCancellationTokenProvider(); - } - _clientInfo = options.KnownClientInfo; _clientCapabilities = options.KnownClientCapabilities; UpdateEndpointNameWithClientInfo(); @@ -87,9 +82,9 @@ public McpServerImpl(ITransport transport, McpServerOptions options, ILoggerFact ConfigureTools(options); ConfigurePrompts(options); ConfigureResources(options); - ConfigureTasks(options); ConfigureLogging(options); ConfigureCompletion(options); + ConfigureTasks(options); ConfigureExperimentalAndExtensions(options); // Register any notification handlers that were provided. @@ -210,7 +205,6 @@ public override async ValueTask DisposeAsync() _disposed = true; - _taskCancellationTokenProvider?.Dispose(); _disposables.ForEach(d => d()); await _sessionHandler.DisposeAsync().ConfigureAwait(false); } @@ -394,6 +388,84 @@ private void ConfigureCompletion(McpServerOptions options) return result; } + private void ConfigureTasks(McpServerOptions options) + { + var getTaskHandler = options.Handlers.GetTaskHandler; + var updateTaskHandler = options.Handlers.UpdateTaskHandler; + var cancelTaskHandler = options.Handlers.CancelTaskHandler; + var taskStore = options.TaskStore; + + // If a task store is provided, wire up handlers from it for any that aren't explicitly set. + if (taskStore is not null) + { + getTaskHandler ??= async (request, cancellationToken) => + { + var info = await taskStore.GetTaskAsync(request.Params!.TaskId, cancellationToken).ConfigureAwait(false); + return info is null + ? throw new McpProtocolException($"Unknown task: '{request.Params.TaskId}'", McpErrorCode.InvalidParams) + : ToGetTaskResult(info); + }; + + updateTaskHandler ??= async (request, cancellationToken) => + { + await taskStore.ResolveInputRequestsAsync(request.Params!.TaskId, request.Params.InputResponses.Keys, cancellationToken).ConfigureAwait(false); + + // Signal any waiters for the provided response keys. + foreach (var kvp in request.Params.InputResponses) + { + if (TaskInputResponseWaiters.TryRemove((request.Params.TaskId, kvp.Key), out var tcs)) + { + tcs.TrySetResult(kvp.Value); + } + } + + return new UpdateTaskResult(); + }; + + cancelTaskHandler ??= async (request, cancellationToken) => + { + var cancelled = await taskStore.SetCancelledAsync(request.Params!.TaskId, cancellationToken).ConfigureAwait(false); + if (!cancelled) + { + throw new McpProtocolException($"Task '{request.Params.TaskId}' could not be cancelled.", McpErrorCode.InvalidParams); + } + + return new CancelTaskResult(); + }; + } + + if (getTaskHandler is null && updateTaskHandler is null && cancelTaskHandler is null) + { + return; + } + + getTaskHandler ??= (static async (request, _) => throw new McpProtocolException($"Unknown task: '{request.Params?.TaskId}'", McpErrorCode.InvalidParams)); + updateTaskHandler ??= (static async (request, _) => throw new McpProtocolException($"Unknown task: '{request.Params?.TaskId}'", McpErrorCode.InvalidParams)); + cancelTaskHandler ??= (static async (request, _) => throw new McpProtocolException($"Unknown task: '{request.Params?.TaskId}'", McpErrorCode.InvalidParams)); + + // Advertise tasks extension in server capabilities. + ServerCapabilities.Extensions ??= new Dictionary(); + ServerCapabilities.Extensions[McpExtensions.Tasks] = new JsonObject(); + + SetHandler( + RequestMethods.TasksGet, + getTaskHandler, + McpJsonUtilities.JsonContext.Default.GetTaskRequestParams, + McpJsonUtilities.JsonContext.Default.GetTaskResult); + + SetHandler( + RequestMethods.TasksUpdate, + updateTaskHandler, + McpJsonUtilities.JsonContext.Default.UpdateTaskRequestParams, + McpJsonUtilities.JsonContext.Default.UpdateTaskResult); + + SetHandler( + RequestMethods.TasksCancel, + cancelTaskHandler, + McpJsonUtilities.JsonContext.Default.CancelTaskRequestParams, + McpJsonUtilities.JsonContext.Default.CancelTaskResult); + } + private void ConfigureExperimentalAndExtensions(McpServerOptions options) { ServerCapabilities.Experimental = options.Capabilities?.Experimental; @@ -663,10 +735,11 @@ private void ConfigureTools(McpServerOptions options) { var listToolsHandler = options.Handlers.ListToolsHandler; var callToolHandler = options.Handlers.CallToolHandler; + var callToolWithTaskHandler = options.Handlers.CallToolWithTaskHandler; var tools = options.ToolCollection; var toolsCapability = options.Capabilities?.Tools; - if (listToolsHandler is null && callToolHandler is null && tools is null && + if (listToolsHandler is null && callToolHandler is null && callToolWithTaskHandler is null && tools is null && toolsCapability is null) { return; @@ -675,10 +748,23 @@ private void ConfigureTools(McpServerOptions options) ServerCapabilities.Tools = new(); listToolsHandler ??= (static async (_, __) => new ListToolsResult()); - callToolHandler ??= (static async (request, _) => throw new McpProtocolException($"Unknown tool: '{request.Params?.Name}'", McpErrorCode.InvalidParams)); var listChanged = toolsCapability?.ListChanged; - // Handle tools provided via DI by augmenting the handlers to incorporate them. + var callToolFilters = options.Filters.Request.CallToolFilters; + var callToolWithTaskFilters = options.Filters.Request.CallToolWithTaskFilters; + + // Validate: cannot mix non-task filters/handler with task filters/handler. + bool hasNonTaskPath = callToolHandler is not null || callToolFilters.Count > 0; + bool hasTaskPath = callToolWithTaskHandler is not null || callToolWithTaskFilters.Count > 0; + + if (hasNonTaskPath && hasTaskPath) + { + throw new InvalidOperationException( + $"Cannot mix non-task ({nameof(McpServerHandlers.CallToolHandler)}/{nameof(McpRequestFilters.CallToolFilters)}) " + + $"with task-based ({nameof(McpServerHandlers.CallToolWithTaskHandler)}/{nameof(McpRequestFilters.CallToolWithTaskFilters)}). Use one style or the other."); + } + + // Handle tools provided via DI by augmenting the list handler. if (tools is not null) { var originalListToolsHandler = listToolsHandler; @@ -699,94 +785,104 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) return result; }; - var originalCallToolHandler = callToolHandler; - var taskStore = options.TaskStore; - var sendNotifications = options.SendTaskStatusNotifications; - callToolHandler = async (request, cancellationToken) => - { - if (request.MatchedPrimitive is McpServerTool tool) - { - var taskSupport = tool.ProtocolTool.Execution?.TaskSupport ?? ToolTaskSupport.Forbidden; + listChanged = true; + } - // Check if this is a task-augmented request - if (request.Params?.Task is { } taskMetadata) - { - // Validate tool-level task support - if (taskSupport is ToolTaskSupport.Forbidden) - { - throw new McpProtocolException( - $"Tool '{tool.ProtocolTool.Name}' does not support task-augmented execution.", - McpErrorCode.InvalidParams); - } + listToolsHandler = BuildFilterPipeline(listToolsHandler, options.Filters.Request.ListToolsFilters); - // Task augmentation requested - return CreateTaskResult - return await ExecuteToolAsTaskAsync(tool, request, taskMetadata, taskStore, sendNotifications, cancellationToken).ConfigureAwait(false); - } + // Build the unified task-augmented handler from one of the two paths. + if (hasTaskPath) + { + // Case 2: task filter + task handler + callToolWithTaskHandler ??= (static async (request, _) => throw new McpProtocolException($"Unknown tool: '{request.Params?.Name}'", McpErrorCode.InvalidParams)); - // Validate that required task support is satisfied - if (taskSupport is ToolTaskSupport.Required) + // Augment with DI tools. + if (tools is not null) + { + var originalHandler = callToolWithTaskHandler; + callToolWithTaskHandler = (request, cancellationToken) => + { + if (request.MatchedPrimitive is McpServerTool tool) { - throw new McpProtocolException( - $"Tool '{tool.ProtocolTool.Name}' requires task-augmented execution. " + - "Include a 'task' parameter with the request.", - McpErrorCode.InvalidParams); + return InvokeToolAsTask(tool, request, cancellationToken); } - // Normal synchronous execution - return await tool.InvokeAsync(request, cancellationToken).ConfigureAwait(false); - } - - return await originalCallToolHandler(request, cancellationToken).ConfigureAwait(false); - }; + return originalHandler(request, cancellationToken); + }; + } - listChanged = true; + callToolWithTaskHandler = BuildFilterPipeline(callToolWithTaskHandler, callToolWithTaskFilters, BuildInitialTaskToolFilter(tools)); } + else + { + // Case 1: non-task filter + non-task handler → apply filters, then convert to task-based + callToolHandler ??= (static async (request, _) => throw new McpProtocolException($"Unknown tool: '{request.Params?.Name}'", McpErrorCode.InvalidParams)); - listToolsHandler = BuildFilterPipeline(listToolsHandler, options.Filters.Request.ListToolsFilters); - callToolHandler = BuildFilterPipeline(callToolHandler, options.Filters.Request.CallToolFilters, handler => - async (request, cancellationToken) => + // Augment with DI tools. + if (tools is not null) { - // Initial handler that sets MatchedPrimitive - if (request.Params?.Name is { } toolName && tools is not null && - tools.TryGetPrimitive(toolName, out var tool)) + var originalHandler = callToolHandler; + callToolHandler = (request, cancellationToken) => { - request.MatchedPrimitive = tool; - } - - try - { - var result = await handler(request, cancellationToken).ConfigureAwait(false); - - // Don't log here for task-augmented calls; logging happens asynchronously - // in ExecuteToolAsTaskAsync when the tool actually completes. - if (result.Task is null) + if (request.MatchedPrimitive is McpServerTool tool) { - ToolCallCompleted(request.Params?.Name ?? string.Empty, result.IsError is true); + return tool.InvokeAsync(request, cancellationToken); } - return result; - } - catch (Exception e) - { - ToolCallError(request.Params?.Name ?? string.Empty, e); + return originalHandler(request, cancellationToken); + }; + } - if ((e is OperationCanceledException && cancellationToken.IsCancellationRequested) || e is McpProtocolException) - { - throw; - } + callToolHandler = BuildFilterPipeline(callToolHandler, callToolFilters, BuildInitialCallToolFilter(tools)); - return new() + // Convert to task-based. + var finalCallToolHandler = callToolHandler; + callToolWithTaskHandler = async (request, cancellationToken) => + await finalCallToolHandler(request, cancellationToken).ConfigureAwait(false); + } + + // If a task store is configured, wrap so that when the client signals task support + // the tool execution is offloaded to the background via the store. + if (options.TaskStore is { } taskStore) + { + var innerTaskHandler = callToolWithTaskHandler; + callToolWithTaskHandler = async (request, cancellationToken) => + { + if (request.Params?.Meta?.ContainsKey(McpExtensions.Tasks) is true) + { + var taskInfo = await taskStore.CreateTaskAsync(cancellationToken).ConfigureAwait(false); + var taskId = taskInfo.TaskId; + + _ = Task.Run(async () => { - IsError = true, - Content = [new TextContentBlock + using (CreateMcpTaskScope(taskId, taskStore)) { - Text = e is McpException ? - $"An error occurred invoking '{request.Params?.Name}': {e.Message}" : - $"An error occurred invoking '{request.Params?.Name}'.", - }], - }; + try + { + var augmented = await innerTaskHandler(request, CancellationToken.None).ConfigureAwait(false); + if (augmented.IsTask) + { + return; + } + + var resultJson = JsonSerializer.SerializeToElement(augmented.Result!, McpJsonUtilities.JsonContext.Default.CallToolResult); + await taskStore.SetCompletedAsync(taskId, resultJson).ConfigureAwait(false); + } + catch (Exception ex) + { + var escapedMessage = JsonSerializer.Serialize(ex.Message, McpJsonUtilities.JsonContext.Default.String); + var errorJson = JsonDocument.Parse($$$"""{{"message": {{{escapedMessage}}}}}""").RootElement; + await taskStore.SetFailedAsync(taskId, errorJson).ConfigureAwait(false); + } + } + }, CancellationToken.None); + + return ToCreateTaskResult(taskInfo); } - }); + + return await innerTaskHandler(request, cancellationToken).ConfigureAwait(false); + }; + } ServerCapabilities.Tools.ListChanged = listChanged; @@ -796,144 +892,174 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, McpJsonUtilities.JsonContext.Default.ListToolsResult); - SetHandler( + SetTaskAugmentedHandler( RequestMethods.ToolsCall, - callToolHandler, + callToolWithTaskHandler, McpJsonUtilities.JsonContext.Default.CallToolRequestParams, - McpJsonUtilities.JsonContext.Default.CallToolResult); + McpJsonUtilities.JsonContext.Default.CallToolResult, + McpJsonUtilities.JsonContext.Default.CreateTaskResult); } - private void ConfigureTasks(McpServerOptions options) + private static CreateTaskResult ToCreateTaskResult(McpTaskInfo info) => new() { - var taskStore = options.TaskStore; + TaskId = info.TaskId, + Status = info.Status, + CreatedAt = info.CreatedAt, + LastUpdatedAt = info.LastUpdatedAt, + TtlMs = info.TtlMs, + PollIntervalMs = info.PollIntervalMs, + StatusMessage = info.StatusMessage, + ResultType = "task", + }; - // If no task store is configured, tasks are not supported - if (taskStore is null) + private static GetTaskResult ToGetTaskResult(McpTaskInfo info) => info.Status switch + { + McpTaskStatus.Working => new WorkingTaskResult { - return; - } - - // Advertise task support in server capabilities - ServerCapabilities.Tasks = new McpTasksCapability + TaskId = info.TaskId, + CreatedAt = info.CreatedAt, + LastUpdatedAt = info.LastUpdatedAt, + TtlMs = info.TtlMs, + PollIntervalMs = info.PollIntervalMs, + StatusMessage = info.StatusMessage, + ResultType = "complete", + }, + McpTaskStatus.Completed => new CompletedTaskResult { - List = new ListMcpTasksCapability(), - Cancel = new CancelMcpTasksCapability(), - Requests = new RequestMcpTasksCapability - { - Tools = new ToolsMcpTasksCapability - { - Call = new CallToolMcpTasksCapability() - } - } - }; + TaskId = info.TaskId, + CreatedAt = info.CreatedAt, + LastUpdatedAt = info.LastUpdatedAt, + TtlMs = info.TtlMs, + PollIntervalMs = info.PollIntervalMs, + StatusMessage = info.StatusMessage, + TaskResult = info.Result!.Value, + ResultType = "complete", + }, + McpTaskStatus.Failed => new FailedTaskResult + { + TaskId = info.TaskId, + CreatedAt = info.CreatedAt, + LastUpdatedAt = info.LastUpdatedAt, + TtlMs = info.TtlMs, + PollIntervalMs = info.PollIntervalMs, + StatusMessage = info.StatusMessage, + Error = info.Error!.Value, + ResultType = "complete", + }, + McpTaskStatus.Cancelled => new CancelledTaskResult + { + TaskId = info.TaskId, + CreatedAt = info.CreatedAt, + LastUpdatedAt = info.LastUpdatedAt, + TtlMs = info.TtlMs, + PollIntervalMs = info.PollIntervalMs, + StatusMessage = info.StatusMessage, + ResultType = "complete", + }, + McpTaskStatus.InputRequired => new InputRequiredTaskResult + { + TaskId = info.TaskId, + CreatedAt = info.CreatedAt, + LastUpdatedAt = info.LastUpdatedAt, + TtlMs = info.TtlMs, + PollIntervalMs = info.PollIntervalMs, + StatusMessage = info.StatusMessage, + InputRequests = info.InputRequests is IDictionary dict + ? dict + : info.InputRequests?.ToDictionary(kvp => kvp.Key, kvp => kvp.Value) + ?? new Dictionary(), + ResultType = "complete", + }, + _ => throw new InvalidOperationException($"Unknown task status: {info.Status}"), + }; + + private static async ValueTask> InvokeToolAsTask( + McpServerTool tool, + RequestContext request, + CancellationToken cancellationToken) + { + return await tool.InvokeAsync(request, cancellationToken).ConfigureAwait(false); + } - // tasks/get handler - Retrieve task status - McpRequestHandler getTaskHandler = async (request, cancellationToken) => + private McpRequestFilter BuildInitialCallToolFilter( + McpServerPrimitiveCollection? tools) => handler => + async (request, cancellationToken) => { - if (request.Params?.TaskId is not { } taskId) + if (request.Params?.Name is { } toolName && tools is not null && + tools.TryGetPrimitive(toolName, out var tool)) { - throw new McpProtocolException("Missing required parameter 'taskId'", McpErrorCode.InvalidParams); + request.MatchedPrimitive = tool; } - var task = await taskStore.GetTaskAsync(taskId, SessionId, cancellationToken).ConfigureAwait(false); - if (task is null) + try { - throw new McpProtocolException($"Task not found: '{taskId}'", McpErrorCode.InvalidParams); + var result = await handler(request, cancellationToken).ConfigureAwait(false); + ToolCallCompleted(request.Params?.Name ?? string.Empty, result.IsError is true); + return result; } - - return task; - }; - - // tasks/result handler - Retrieve task result (blocking until terminal status) - McpRequestHandler getTaskResultHandler = (request, cancellationToken) => - { - return new ValueTask(GetTaskResultAsync(request, cancellationToken)); - - async Task GetTaskResultAsync(RequestContext request, CancellationToken cancellationToken) + catch (Exception e) { - if (request.Params?.TaskId is not { } taskId) + ToolCallError(request.Params?.Name ?? string.Empty, e); + + if ((e is OperationCanceledException && cancellationToken.IsCancellationRequested) || e is McpProtocolException) { - throw new McpProtocolException("Missing required parameter 'taskId'", McpErrorCode.InvalidParams); + throw; } - // Poll until task reaches terminal status - while (true) + return new() { - McpTask? task = await taskStore.GetTaskAsync(taskId, SessionId, cancellationToken).ConfigureAwait(false); - if (task is null) - { - throw new McpProtocolException($"Task not found: '{taskId}'", McpErrorCode.InvalidParams); - } - - // If terminal, break and retrieve result - if (task.Status is McpTaskStatus.Completed or McpTaskStatus.Failed or McpTaskStatus.Cancelled) + IsError = true, + Content = [new TextContentBlock { - break; - } - - // Poll according to task's pollInterval (default 1 second) - var pollInterval = task.PollInterval ?? TimeSpan.FromSeconds(1); - await Task.Delay(pollInterval, cancellationToken).ConfigureAwait(false); - } - - // Retrieve the stored result - already stored as JsonElement - return await taskStore.GetTaskResultAsync(taskId, SessionId, cancellationToken).ConfigureAwait(false); + Text = e is McpException ? + $"An error occurred invoking '{request.Params?.Name}': {e.Message}" : + $"An error occurred invoking '{request.Params?.Name}'.", + }], + }; } }; - // tasks/list handler - List tasks with pagination - McpRequestHandler listTasksHandler = async (request, cancellationToken) => - { - var cursor = request.Params?.Cursor; - return await taskStore.ListTasksAsync(cursor, SessionId, cancellationToken).ConfigureAwait(false); - }; - - // tasks/cancel handler - Cancel a task - McpRequestHandler cancelTaskHandler = async (request, cancellationToken) => + private McpRequestFilter> BuildInitialTaskToolFilter( + McpServerPrimitiveCollection? tools) => handler => + async (request, cancellationToken) => { - if (request.Params?.TaskId is not { } taskId) + if (request.Params?.Name is { } toolName && tools is not null && + tools.TryGetPrimitive(toolName, out var tool)) { - throw new McpProtocolException("Missing required parameter 'taskId'", McpErrorCode.InvalidParams); + request.MatchedPrimitive = tool; } - // Signal cancellation if task is still running - _taskCancellationTokenProvider!.Cancel(taskId); - - // Delegate to task store - it handles idempotent cancellation - var task = await taskStore.CancelTaskAsync(taskId, SessionId, cancellationToken).ConfigureAwait(false); - if (task is null) + try { - throw new McpProtocolException($"Task not found: '{taskId}'", McpErrorCode.InvalidParams); - } - - return task; - }; - - // Register handlers - SetHandler( - RequestMethods.TasksGet, - getTaskHandler, - McpJsonUtilities.JsonContext.Default.GetTaskRequestParams, - McpJsonUtilities.JsonContext.Default.McpTask); + var result = await handler(request, cancellationToken).ConfigureAwait(false); + if (!result.IsTask) + { + ToolCallCompleted(request.Params?.Name ?? string.Empty, result.Result!.IsError is true); + } - SetHandler( - RequestMethods.TasksResult, - getTaskResultHandler, - McpJsonUtilities.JsonContext.Default.GetTaskPayloadRequestParams, - McpJsonUtilities.JsonContext.Default.JsonElement); + return result; + } + catch (Exception e) + { + ToolCallError(request.Params?.Name ?? string.Empty, e); - SetHandler( - RequestMethods.TasksList, - listTasksHandler, - McpJsonUtilities.JsonContext.Default.ListTasksRequestParams, - McpJsonUtilities.JsonContext.Default.ListTasksResult); + if ((e is OperationCanceledException && cancellationToken.IsCancellationRequested) || e is McpProtocolException) + { + throw; + } - SetHandler( - RequestMethods.TasksCancel, - cancelTaskHandler, - McpJsonUtilities.JsonContext.Default.CancelMcpTaskRequestParams, - McpJsonUtilities.JsonContext.Default.McpTask); - } + return new CallToolResult + { + IsError = true, + Content = [new TextContentBlock + { + Text = e is McpException ? + $"An error occurred invoking '{request.Params?.Name}': {e.Message}" : + $"An error occurred invoking '{request.Params?.Name}'.", + }], + }; + } + }; private void ConfigureLogging(McpServerOptions options) { @@ -1024,6 +1150,20 @@ private void SetHandler( requestTypeInfo, responseTypeInfo); } + private void SetTaskAugmentedHandler( + string method, + McpRequestHandler> handler, + JsonTypeInfo requestTypeInfo, + JsonTypeInfo responseTypeInfo, + JsonTypeInfo taskResultTypeInfo) + where TResult : Result + { + _requestHandlers.SetTaskAugmented(method, + (request, jsonRpcRequest, cancellationToken) => + InvokeHandlerAsync(handler, request, jsonRpcRequest, cancellationToken), + requestTypeInfo, responseTypeInfo, taskResultTypeInfo); + } + private static McpRequestHandler BuildFilterPipeline( McpRequestHandler baseHandler, IList> filters, @@ -1117,160 +1257,4 @@ internal static LoggingLevel ToLoggingLevel(LogLevel level) => [LoggerMessage(Level = LogLevel.Information, Message = "ReadResource \"{ResourceUri}\" completed.")] private partial void ReadResourceCompleted(string resourceUri); - - /// - /// Executes a tool call as a task and returns a CallToolTaskResult immediately. - /// - private async ValueTask ExecuteToolAsTaskAsync( - McpServerTool tool, - RequestContext request, - McpTaskMetadata taskMetadata, - IMcpTaskStore? taskStore, - bool sendNotifications, - CancellationToken cancellationToken) - { - if (taskStore is null) - { - throw new McpProtocolException( - "Task-augmented requests are not supported. No task store configured.", - McpErrorCode.InvalidRequest); - } - - // Create the task in the task store - var mcpTask = await taskStore.CreateTaskAsync( - taskMetadata, - request.JsonRpcRequest.Id, - request.JsonRpcRequest, - SessionId, - cancellationToken).ConfigureAwait(false); - - // Register the task for TTL-based cancellation - var taskCancellationToken = _taskCancellationTokenProvider!.RequestToken(mcpTask.TaskId, mcpTask.TimeToLive); - - // Execute the tool asynchronously in the background - _ = Task.Run(async () => - { - // When per-request service scoping is enabled, InvokeHandlerAsync creates a new - // IServiceScope and disposes it once the handler returns. Since ExecuteToolAsTaskAsync - // returns immediately (before the tool runs), the scope is disposed before the tool - // gets a chance to resolve any DI services. Create a fresh scope here, tied to this - // background task's lifetime, so the tool's DI resolution uses a live provider. - var taskScope = _servicesScopePerRequest - ? Services?.GetService()?.CreateAsyncScope() - : null; - if (taskScope is not null) - { - request.Services = taskScope.Value.ServiceProvider; - } - - // Set up the task execution context for automatic input_required status tracking - TaskExecutionContext.Current = new TaskExecutionContext - { - TaskId = mcpTask.TaskId, - SessionId = SessionId, - TaskStore = taskStore, - SendNotifications = sendNotifications, - NotifyTaskStatusFunc = NotifyTaskStatusAsync - }; - - try - { - // Update task status to working - var workingTask = await taskStore.UpdateTaskStatusAsync( - mcpTask.TaskId, - McpTaskStatus.Working, - null, // statusMessage - SessionId, - CancellationToken.None).ConfigureAwait(false); - - // Send notification if enabled - if (sendNotifications) - { - _ = NotifyTaskStatusAsync(workingTask, CancellationToken.None); - } - - // Invoke the tool with task-specific cancellation token - var result = await tool.InvokeAsync(request, taskCancellationToken).ConfigureAwait(false); - ToolCallCompleted(request.Params?.Name ?? string.Empty, result.IsError is true); - - // Determine final status based on whether there was an error - var finalStatus = result.IsError is true ? McpTaskStatus.Failed : McpTaskStatus.Completed; - - // Store the result (serialize to JsonElement) - var resultElement = JsonSerializer.SerializeToElement(result, McpJsonUtilities.JsonContext.Default.CallToolResult); - var finalTask = await taskStore.StoreTaskResultAsync( - mcpTask.TaskId, - finalStatus, - resultElement, - SessionId, - CancellationToken.None).ConfigureAwait(false); - - // Send final notification if enabled - if (sendNotifications) - { - _ = NotifyTaskStatusAsync(finalTask, CancellationToken.None); - } - } - catch (OperationCanceledException) when (taskCancellationToken.IsCancellationRequested) - { - // Task was cancelled via TTL expiration or explicit cancellation. - // For TTL expiration, the task is deleted so no status update needed. - // For explicit cancellation, the cancel handler already updates the status. - } - catch (Exception ex) - { - // Log the error - ToolCallError(request.Params?.Name ?? string.Empty, ex); - - // Store error result - var errorResult = new CallToolResult - { - IsError = true, - Content = [new TextContentBlock { Text = $"Task execution failed: {ex.Message}" }], - }; - - try - { - var errorResultElement = JsonSerializer.SerializeToElement(errorResult, McpJsonUtilities.JsonContext.Default.CallToolResult); - var failedTask = await taskStore.StoreTaskResultAsync( - mcpTask.TaskId, - McpTaskStatus.Failed, - errorResultElement, - SessionId, - CancellationToken.None).ConfigureAwait(false); - - // Send failure notification if enabled - if (sendNotifications) - { - _ = NotifyTaskStatusAsync(failedTask, CancellationToken.None); - } - } - catch - { - // If we can't store the error result, there's not much we can do - // The task will remain in "working" status, which will eventually be cleaned up - } - } - finally - { - // Clean up task execution context - TaskExecutionContext.Current = null; - - // Clean up task cancellation tracking - _taskCancellationTokenProvider!.Complete(mcpTask.TaskId); - - // Dispose the per-task service scope (if one was created) - if (taskScope is not null) - { - await taskScope.Value.DisposeAsync().ConfigureAwait(false); - } - } - }, CancellationToken.None); - - // Return the task result immediately - return new CallToolResult - { - Task = mcpTask - }; - } } diff --git a/src/ModelContextProtocol.Core/Server/McpServerOptions.cs b/src/ModelContextProtocol.Core/Server/McpServerOptions.cs index 6da8bbfbe..32c13da27 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerOptions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerOptions.cs @@ -1,6 +1,8 @@ using ModelContextProtocol.Protocol; using System.Diagnostics.CodeAnalysis; +#pragma warning disable MCPEXP001 + namespace ModelContextProtocol.Server; /// @@ -188,54 +190,17 @@ public McpServerFilters Filters public int MaxSamplingOutputTokens { get; set; } = 1000; /// - /// Gets or sets the task store for managing asynchronous task execution. + /// Gets or sets the task store for managing asynchronous task executions. /// /// /// - /// When non-null, enables explicit task support with persistence, allowing clients to: - /// - /// Execute operations asynchronously by augmenting requests with task metadata - /// Poll for task status via tasks/get requests - /// Retrieve task results via tasks/result requests - /// List all tasks via tasks/list requests - /// Cancel tasks via tasks/cancel requests - /// - /// - /// - /// When null, implicit task support may still be available for async methods (returning or - /// ), but tasks will be ephemeral and not persisted. Use - /// for development/testing or implement for production scenarios. + /// When set, the server automatically enables the io.modelcontextprotocol/tasks extension + /// and wires up tasks/get, tasks/update, and tasks/cancel handlers backed by this store. + /// Tool executions from clients that signal task support will be wrapped in tasks via the store. /// /// - /// The server will automatically advertise task capabilities based on the presence of a task store - /// and the detection of async server primitives (tools, prompts, resources). + /// If explicit task handlers are also set on , the explicit handlers take precedence. /// /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] public IMcpTaskStore? TaskStore { get; set; } - - /// - /// Gets or sets whether to send task status notifications to clients. - /// - /// - /// to send optional notifications/tasks/status notifications when task status changes; - /// to not send notifications. The default is . - /// - /// - /// - /// When enabled, the server will send notifications/tasks/status notifications to inform clients - /// of task state changes. According to the MCP specification, these notifications are optional and - /// receivers MAY send them but are not required to. - /// - /// - /// Clients must not rely on receiving these notifications and should continue polling via tasks/get - /// requests to ensure they receive status updates. - /// - /// - /// Even when this is set to , notifications are only sent when - /// is configured, as task-augmented requests require a task store. - /// - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public bool SendTaskStatusNotifications { get; set; } } diff --git a/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs b/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs index d67bac18c..34e77e2b4 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs @@ -157,7 +157,6 @@ public sealed class McpServerToolAttribute : Attribute internal bool? _idempotent; internal bool? _openWorld; internal bool? _readOnly; - internal ToolTaskSupport? _taskSupport; /// /// Initializes a new instance of the class. @@ -300,29 +299,4 @@ public bool ReadOnly /// /// public string? IconSource { get; set; } - - /// - /// Gets or sets the task support configuration for the tool. - /// - /// - /// A value indicating how the tool supports task-based invocation. - /// The default value is . - /// - /// - /// - /// When set to , clients must not attempt to invoke the tool as a task. - /// When set to , clients may invoke the tool as a task or as a normal request. - /// When set to , clients must invoke the tool as a task. - /// - /// - /// If this property is not explicitly set on the attribute, the task support behavior will be determined - /// automatically based on the tool's characteristics (e.g., async methods default to ). - /// - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public ToolTaskSupport TaskSupport - { - get => _taskSupport ?? ToolTaskSupport.Forbidden; - set => _taskSupport = value; - } } diff --git a/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs b/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs index 88d718d13..b0b6b3de7 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs @@ -197,23 +197,6 @@ public sealed class McpServerToolCreateOptions /// public JsonObject? Meta { get; set; } - /// - /// Gets or sets the execution hints for this tool. - /// - /// - /// - /// Execution hints provide information about how the tool should be invoked, including - /// task support level (). - /// - /// - /// If , the tool's execution settings are determined automatically based on - /// the method signature (async methods get ; sync methods - /// get ). - /// - /// - [Experimental(Experimentals.Tasks_DiagnosticId, UrlFormat = Experimentals.Tasks_Url)] - public ToolExecution? Execution { get; set; } - /// /// Creates a shallow clone of the current instance. /// @@ -235,6 +218,5 @@ internal McpServerToolCreateOptions Clone() => Metadata = Metadata, Icons = Icons, Meta = Meta, - Execution = Execution, }; } diff --git a/src/ModelContextProtocol.Core/Server/McpTaskExecutionContext.cs b/src/ModelContextProtocol.Core/Server/McpTaskExecutionContext.cs new file mode 100644 index 000000000..6b536172a --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/McpTaskExecutionContext.cs @@ -0,0 +1,22 @@ +using System.Diagnostics.CodeAnalysis; + +namespace ModelContextProtocol.Server; + +/// +/// Provides ambient context when a tool is executing as a background task. +/// When established, calls to +/// are redirected through the task store as input requests rather than sent directly to the client. +/// +[Experimental(Experimentals.Extensions_DiagnosticId, UrlFormat = Experimentals.Extensions_Url)] +internal sealed class McpTaskExecutionContext +{ + internal static readonly AsyncLocal Current = new(); + + public required string TaskId { get; init; } + public required IMcpTaskStore Store { get; init; } + + internal sealed class Scope(McpTaskExecutionContext? previous) : IDisposable + { + public void Dispose() => Current.Value = previous; + } +} diff --git a/src/ModelContextProtocol.Core/Server/McpTaskInfo.cs b/src/ModelContextProtocol.Core/Server/McpTaskInfo.cs new file mode 100644 index 000000000..00b20537a --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/McpTaskInfo.cs @@ -0,0 +1,28 @@ +using ModelContextProtocol.Protocol; +using System.Diagnostics.CodeAnalysis; +using System.Text.Json; + +namespace ModelContextProtocol.Server; + +/// +/// Represents the state of a task in an . +/// +/// +/// +/// This is the store's representation of a task, decoupled from the MCP protocol wire types. +/// The server infrastructure maps to the appropriate protocol response +/// types (, ) when communicating with clients. +/// +/// +[Experimental(Experimentals.Extensions_DiagnosticId, UrlFormat = Experimentals.Extensions_Url)] +public sealed record McpTaskInfo( + string TaskId, + McpTaskStatus Status, + DateTimeOffset CreatedAt, + DateTimeOffset LastUpdatedAt, + long? TtlMs = null, + long? PollIntervalMs = null, + string? StatusMessage = null, + JsonElement? Result = null, + JsonElement? Error = null, + IReadOnlyDictionary? InputRequests = null); diff --git a/src/ModelContextProtocol.Core/Server/TaskExecutionContext.cs b/src/ModelContextProtocol.Core/Server/TaskExecutionContext.cs deleted file mode 100644 index fc45835c4..000000000 --- a/src/ModelContextProtocol.Core/Server/TaskExecutionContext.cs +++ /dev/null @@ -1,47 +0,0 @@ -namespace ModelContextProtocol.Server; - -/// -/// Represents the execution context for a task being executed by the server. -/// This context flows with async execution and enables automatic task status updates. -/// -internal sealed class TaskExecutionContext -{ - /// - /// Gets the AsyncLocal instance used to track the current task execution context. - /// - private static readonly AsyncLocal s_current = new(); - - /// - /// Gets or sets the current task execution context for the executing async flow. - /// - public static TaskExecutionContext? Current - { - get => s_current.Value; - set => s_current.Value = value; - } - - /// - /// Gets the task ID of the currently executing task. - /// - public required string TaskId { get; init; } - - /// - /// Gets the session ID associated with the task. - /// - public string? SessionId { get; init; } - - /// - /// Gets the task store used to persist task state. - /// - public required IMcpTaskStore TaskStore { get; init; } - - /// - /// Gets whether task status notifications should be sent. - /// - public bool SendNotifications { get; init; } - - /// - /// Gets or sets the function to call when sending a task status notification. - /// - public Func? NotifyTaskStatusFunc { get; init; } -} diff --git a/src/ModelContextProtocol/McpServerOptionsSetup.cs b/src/ModelContextProtocol/McpServerOptionsSetup.cs index 5977fae7e..c46854460 100644 --- a/src/ModelContextProtocol/McpServerOptionsSetup.cs +++ b/src/ModelContextProtocol/McpServerOptionsSetup.cs @@ -9,12 +9,10 @@ namespace ModelContextProtocol; /// The individually registered tools. /// The individually registered prompts. /// The individually registered resources. -/// The optional task store registered in DI. internal sealed class McpServerOptionsSetup( IEnumerable serverTools, IEnumerable serverPrompts, - IEnumerable serverResources, - IMcpTaskStore? taskStore = null) : IConfigureOptions + IEnumerable serverResources) : IConfigureOptions { /// /// Configures the given McpServerOptions instance by setting server information @@ -25,8 +23,6 @@ public void Configure(McpServerOptions options) { Throw.IfNull(options); - options.TaskStore ??= taskStore; - // Collect all of the provided tools into a tools collection. If the options already has // a collection, add to it, otherwise create a new one. We want to maintain the identity // of an existing collection in case someone has provided their own derived type, wants diff --git a/src/ModelContextProtocol/ModelContextProtocol.csproj b/src/ModelContextProtocol/ModelContextProtocol.csproj index 231eb073a..07167c438 100644 --- a/src/ModelContextProtocol/ModelContextProtocol.csproj +++ b/src/ModelContextProtocol/ModelContextProtocol.csproj @@ -8,7 +8,7 @@ .NET SDK for the Model Context Protocol (MCP) with hosting and dependency injection extensions. README.md True - + $(NoWarn);MCPEXP001 diff --git a/tests/Common/Utils/TestServerTransport.cs b/tests/Common/Utils/TestServerTransport.cs index 43cd5c262..ed9b6ee72 100644 --- a/tests/Common/Utils/TestServerTransport.cs +++ b/tests/Common/Utils/TestServerTransport.cs @@ -46,14 +46,6 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can await SamplingAsync(request, cancellationToken); else if (request.Method == RequestMethods.ElicitationCreate) await ElicitAsync(request, cancellationToken); - else if (request.Method == RequestMethods.TasksGet) - await TasksGetAsync(request, cancellationToken); - else if (request.Method == RequestMethods.TasksResult) - await TasksResultAsync(request, cancellationToken); - else if (request.Method == RequestMethods.TasksList) - await TasksListAsync(request, cancellationToken); - else if (request.Method == RequestMethods.TasksCancel) - await TasksCancelAsync(request, cancellationToken); else await WriteMessageAsync(request, cancellationToken); } @@ -79,161 +71,21 @@ await WriteMessageAsync(new JsonRpcResponse private async Task SamplingAsync(JsonRpcRequest request, CancellationToken cancellationToken) { - // Check if the request is task-augmented (has Task metadata) - var requestParams = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.DefaultOptions); - if (requestParams?.Task is not null && MockTask is not null) - { - // Return a task-augmented response - await WriteMessageAsync(new JsonRpcResponse - { - Id = request.Id, - Result = JsonSerializer.SerializeToNode(new CreateTaskResult { Task = MockTask }, McpJsonUtilities.DefaultOptions), - }, cancellationToken); - } - else - { - // Return a normal sampling response - await WriteMessageAsync(new JsonRpcResponse - { - Id = request.Id, - Result = JsonSerializer.SerializeToNode(new CreateMessageResult { Content = [new TextContentBlock { Text = "" }], Model = "model" }, McpJsonUtilities.DefaultOptions), - }, cancellationToken); - } - } - - private async Task ElicitAsync(JsonRpcRequest request, CancellationToken cancellationToken) - { - // Check if the request is task-augmented (has Task metadata) - var requestParams = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.DefaultOptions); - if (requestParams?.Task is not null && MockTask is not null) - { - // Return a task-augmented response - await WriteMessageAsync(new JsonRpcResponse - { - Id = request.Id, - Result = JsonSerializer.SerializeToNode(new CreateTaskResult { Task = MockTask }, McpJsonUtilities.DefaultOptions), - }, cancellationToken); - } - else - { - // Return a normal elicitation response - await WriteMessageAsync(new JsonRpcResponse - { - Id = request.Id, - Result = JsonSerializer.SerializeToNode(new ElicitResult { Action = "decline" }, McpJsonUtilities.DefaultOptions), - }, cancellationToken); - } - } - - /// - /// Gets or sets the task to return from tasks/get requests. - /// - public McpTask? MockTask { get; set; } - - /// - /// Gets or sets the result to return from tasks/result requests. - /// - public object? MockTaskResult { get; set; } - - /// - /// Gets or sets the list of tasks to return from tasks/list requests. - /// - public McpTask[]? MockTaskList { get; set; } - - private async Task TasksGetAsync(JsonRpcRequest request, CancellationToken cancellationToken) - { - var task = MockTask ?? new McpTask - { - TaskId = "test-task-id", - Status = McpTaskStatus.Completed, - CreatedAt = DateTimeOffset.UtcNow.AddMinutes(-5), - LastUpdatedAt = DateTimeOffset.UtcNow, - }; - + // Return a normal sampling response await WriteMessageAsync(new JsonRpcResponse { Id = request.Id, - Result = JsonSerializer.SerializeToNode(new GetTaskResult - { - TaskId = task.TaskId, - Status = task.Status, - StatusMessage = task.StatusMessage, - CreatedAt = task.CreatedAt, - LastUpdatedAt = task.LastUpdatedAt, - TimeToLive = task.TimeToLive, - PollInterval = task.PollInterval - }, McpJsonUtilities.DefaultOptions), + Result = JsonSerializer.SerializeToNode(new CreateMessageResult { Content = [new TextContentBlock { Text = "" }], Model = "model" }, McpJsonUtilities.DefaultOptions), }, cancellationToken); } - private async Task TasksResultAsync(JsonRpcRequest request, CancellationToken cancellationToken) - { - var result = MockTaskResult ?? new CreateMessageResult - { - Content = [new TextContentBlock { Text = "Task result" }], - Model = "test-model" - }; - - await WriteMessageAsync(new JsonRpcResponse - { - Id = request.Id, - Result = JsonSerializer.SerializeToNode(result, McpJsonUtilities.DefaultOptions), - }, cancellationToken); - } - - private async Task TasksListAsync(JsonRpcRequest request, CancellationToken cancellationToken) - { - var tasks = MockTaskList ?? [ - new McpTask - { - TaskId = "task-1", - Status = McpTaskStatus.Completed, - CreatedAt = DateTimeOffset.UtcNow.AddMinutes(-5), - LastUpdatedAt = DateTimeOffset.UtcNow, - }, - new McpTask - { - TaskId = "task-2", - Status = McpTaskStatus.Working, - CreatedAt = DateTimeOffset.UtcNow.AddMinutes(-3), - LastUpdatedAt = DateTimeOffset.UtcNow, - } - ]; - - await WriteMessageAsync(new JsonRpcResponse - { - Id = request.Id, - Result = JsonSerializer.SerializeToNode(new ListTasksResult - { - Tasks = tasks, - }, McpJsonUtilities.DefaultOptions), - }, cancellationToken); - } - - private async Task TasksCancelAsync(JsonRpcRequest request, CancellationToken cancellationToken) + private async Task ElicitAsync(JsonRpcRequest request, CancellationToken cancellationToken) { - var task = MockTask ?? new McpTask - { - TaskId = "test-task-id", - Status = McpTaskStatus.Cancelled, - StatusMessage = "Task cancelled by request", - CreatedAt = DateTimeOffset.UtcNow.AddMinutes(-5), - LastUpdatedAt = DateTimeOffset.UtcNow, - }; - + // Return a normal elicitation response await WriteMessageAsync(new JsonRpcResponse { Id = request.Id, - Result = JsonSerializer.SerializeToNode(new CancelMcpTaskResult - { - TaskId = task.TaskId, - Status = McpTaskStatus.Cancelled, - StatusMessage = task.StatusMessage ?? "Task cancelled", - CreatedAt = task.CreatedAt, - LastUpdatedAt = DateTimeOffset.UtcNow, - TimeToLive = task.TimeToLive, - PollInterval = task.PollInterval - }, McpJsonUtilities.DefaultOptions), + Result = JsonSerializer.SerializeToNode(new ElicitResult { Action = "decline" }, McpJsonUtilities.DefaultOptions), }, cancellationToken); } diff --git a/tests/Directory.Build.props b/tests/Directory.Build.props index 1071ec394..bc169333f 100644 --- a/tests/Directory.Build.props +++ b/tests/Directory.Build.props @@ -3,7 +3,7 @@ True - + $(NoWarn);MCPEXP001 $(NoWarn);MCP9004 diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/HttpTaskIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/HttpTaskIntegrationTests.cs deleted file mode 100644 index 2b74fcd14..000000000 --- a/tests/ModelContextProtocol.AspNetCore.Tests/HttpTaskIntegrationTests.cs +++ /dev/null @@ -1,342 +0,0 @@ -using Microsoft.AspNetCore.Builder; -using Microsoft.Extensions.DependencyInjection; -using ModelContextProtocol.AspNetCore.Tests.Utils; -using ModelContextProtocol.Client; -using ModelContextProtocol.Protocol; -using ModelContextProtocol.Server; -using System.ComponentModel; -using System.Text.Json; - -namespace ModelContextProtocol.AspNetCore.Tests; - -/// -/// Integration tests for MCP Tasks feature over HTTP transports. -/// Tests task creation, polling, cancellation, and result retrieval. -/// -public class HttpTaskIntegrationTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper) -{ - private readonly HttpClientTransportOptions DefaultTransportOptions = new() - { - Endpoint = new("http://localhost:5000/"), - Name = "In-memory Streamable HTTP Client", - }; - - private Task ConnectMcpClientAsync( - HttpClient? httpClient = null, - HttpClientTransportOptions? transportOptions = null, - McpClientOptions? clientOptions = null) - => McpClient.CreateAsync( - new HttpClientTransport(transportOptions ?? DefaultTransportOptions, httpClient ?? HttpClient, LoggerFactory), - clientOptions, - LoggerFactory, - TestContext.Current.CancellationToken); - - private static IDictionary CreateArguments(string key, object? value) - { - return new Dictionary - { - [key] = JsonSerializer.SerializeToElement(value, McpJsonUtilities.DefaultOptions) - }; - } - - [Fact] - public async Task CallToolAsTask_ReturnsTask_WhenServerSupportsTasksAsync() - { - // Arrange - var taskStore = new InMemoryMcpTaskStore(); - Builder.Services.AddMcpServer(options => - { - options.TaskStore = taskStore; - }) - .WithHttpTransport() - .WithTools(); - - await using var app = Builder.Build(); - app.MapMcp(); - await app.StartAsync(TestContext.Current.CancellationToken); - - await using var client = await ConnectMcpClientAsync(); - - // Act - Call tool with task augmentation - var result = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "long_running_operation", - Arguments = CreateArguments("durationMs", 100), - Task = new McpTaskMetadata() - }, - TestContext.Current.CancellationToken); - - // Assert - Response should indicate task was created - Assert.NotNull(result); - Assert.Null(result.IsError); - } - - [Fact] - public async Task GetTaskAsync_ReturnsTaskStatus_WhenTaskExistsAsync() - { - // Arrange - var taskStore = new InMemoryMcpTaskStore(); - Builder.Services.AddMcpServer(options => - { - options.TaskStore = taskStore; - }) - .WithHttpTransport() - .WithTools(); - - await using var app = Builder.Build(); - app.MapMcp(); - await app.StartAsync(TestContext.Current.CancellationToken); - - await using var client = await ConnectMcpClientAsync(); - - // First create a task by calling a tool with task augmentation - _ = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "long_running_operation", - Arguments = CreateArguments("durationMs", 500), - Task = new McpTaskMetadata() - }, - TestContext.Current.CancellationToken); - - // Get all tasks - var tasks = await client.ListTasksAsync(cancellationToken: TestContext.Current.CancellationToken); - Assert.NotEmpty(tasks); - - // Act - Get the task status - var task = await client.GetTaskAsync(tasks[0].TaskId, cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(task); - Assert.Equal(tasks[0].TaskId, task.TaskId); - } - - [Fact] - public async Task ListTasksAsync_ReturnsTasks_WhenTasksExistAsync() - { - // Arrange - var taskStore = new InMemoryMcpTaskStore(); - Builder.Services.AddMcpServer(options => - { - options.TaskStore = taskStore; - }) - .WithHttpTransport() - .WithTools(); - - await using var app = Builder.Build(); - app.MapMcp(); - await app.StartAsync(TestContext.Current.CancellationToken); - - await using var client = await ConnectMcpClientAsync(); - - // Create multiple tasks - for (int i = 0; i < 3; i++) - { - await client.CallToolAsync( - new CallToolRequestParams - { - Name = "long_running_operation", - Arguments = CreateArguments("durationMs", 1000), - Task = new McpTaskMetadata() - }, - TestContext.Current.CancellationToken); - } - - // Act - var tasks = await client.ListTasksAsync(cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(tasks); - Assert.Equal(3, tasks.Count); - } - - [Fact] - public async Task CancelTaskAsync_CancelsTask_WhenTaskIsRunningAsync() - { - // Arrange - var taskStore = new InMemoryMcpTaskStore(); - Builder.Services.AddMcpServer(options => - { - options.TaskStore = taskStore; - }) - .WithHttpTransport() - .WithTools(); - - await using var app = Builder.Build(); - app.MapMcp(); - await app.StartAsync(TestContext.Current.CancellationToken); - - await using var client = await ConnectMcpClientAsync(); - - // Create a long-running task - await client.CallToolAsync( - new CallToolRequestParams - { - Name = "long_running_operation", - Arguments = CreateArguments("durationMs", 10000), - Task = new McpTaskMetadata() - }, - TestContext.Current.CancellationToken); - - var tasks = await client.ListTasksAsync(cancellationToken: TestContext.Current.CancellationToken); - Assert.NotEmpty(tasks); - - // Act - Cancel the task - var cancelledTask = await client.CancelTaskAsync(tasks[0].TaskId, cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(cancelledTask); - Assert.Equal(McpTaskStatus.Cancelled, cancelledTask.Status); - } - - [Fact] - public async Task GetTaskResultAsync_ReturnsResult_WhenTaskCompletesAsync() - { - // Arrange - var taskStore = new InMemoryMcpTaskStore(); - Builder.Services.AddMcpServer(options => - { - options.TaskStore = taskStore; - }) - .WithHttpTransport() - .WithTools(); - - await using var app = Builder.Build(); - app.MapMcp(); - await app.StartAsync(TestContext.Current.CancellationToken); - - await using var client = await ConnectMcpClientAsync(); - - // Create a quick task - await client.CallToolAsync( - new CallToolRequestParams - { - Name = "long_running_operation", - Arguments = CreateArguments("durationMs", 50), - Task = new McpTaskMetadata() - }, - TestContext.Current.CancellationToken); - - var tasks = await client.ListTasksAsync(cancellationToken: TestContext.Current.CancellationToken); - Assert.NotEmpty(tasks); - - // Wait a bit for the task to complete - await Task.Delay(200, TestContext.Current.CancellationToken); - - // Act - Get the task result - var result = await client.GetTaskResultAsync(tasks[0].TaskId, cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Assert.NotEqual(default, result); - } - - [Fact] - public async Task TasksIsolated_BetweenSessions_WhenMultipleClientsConnectAsync() - { - // Arrange - var taskStore = new InMemoryMcpTaskStore(); - Builder.Services.AddMcpServer(options => - { - options.TaskStore = taskStore; - }) - .WithHttpTransport() - .WithTools(); - - await using var app = Builder.Build(); - app.MapMcp(); - await app.StartAsync(TestContext.Current.CancellationToken); - - // Connect two separate clients - await using var client1 = await ConnectMcpClientAsync(); - await using var client2 = await ConnectMcpClientAsync(); - - // Client 1 creates a task - await client1.CallToolAsync( - new CallToolRequestParams - { - Name = "long_running_operation", - Arguments = CreateArguments("durationMs", 1000), - Task = new McpTaskMetadata() - }, - TestContext.Current.CancellationToken); - - // Act - Both clients list tasks - var client1Tasks = await client1.ListTasksAsync(cancellationToken: TestContext.Current.CancellationToken); - var client2Tasks = await client2.ListTasksAsync(cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Tasks should be isolated by session - Assert.Single(client1Tasks); - Assert.Empty(client2Tasks); - } - - [Fact] - public async Task ServerCapabilities_IncludesTasks_WhenTaskStoreConfiguredAsync() - { - // Arrange - var taskStore = new InMemoryMcpTaskStore(); - Builder.Services.AddMcpServer(options => - { - options.TaskStore = taskStore; - }) - .WithHttpTransport() - .WithTools(); - - await using var app = Builder.Build(); - app.MapMcp(); - await app.StartAsync(TestContext.Current.CancellationToken); - - // Act - await using var client = await ConnectMcpClientAsync(); - - // Assert - Assert.NotNull(client.ServerCapabilities?.Tasks); - } - - [Fact] - public async Task ListTools_ShowsTaskSupport_WhenToolIsAsyncAsync() - { - // Arrange - var taskStore = new InMemoryMcpTaskStore(); - Builder.Services.AddMcpServer(options => - { - options.TaskStore = taskStore; - }) - .WithHttpTransport() - .WithTools(); - - await using var app = Builder.Build(); - app.MapMcp(); - await app.StartAsync(TestContext.Current.CancellationToken); - - await using var client = await ConnectMcpClientAsync(); - - // Act - var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); - - // Assert - var asyncTool = tools.FirstOrDefault(t => t.Name == "long_running_operation"); - Assert.NotNull(asyncTool); - Assert.NotNull(asyncTool.ProtocolTool.Execution); - Assert.Equal(ToolTaskSupport.Optional, asyncTool.ProtocolTool.Execution.TaskSupport); - } - - [McpServerToolType] - public sealed class LongRunningTools - { - [McpServerTool, Description("Simulates a long-running operation")] - public static async Task LongRunningOperation( - [Description("Duration of the operation in milliseconds")] int durationMs, - CancellationToken cancellationToken) - { - await Task.Delay(durationMs, cancellationToken); - return $"Operation completed after {durationMs}ms"; - } - - [McpServerTool, Description("A synchronous tool that does not support tasks")] - public static string SyncTool([Description("Input message")] string message) - { - return $"Sync result: {message}"; - } - } -} diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index 9cb963a96..0765c7450 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -162,27 +162,6 @@ private static void ConfigureTools(McpServerOptions options, string? cliArg) """), }, new Tool - { - Name = "longRunning", - Description = "Simulates a long-running operation that supports task-based execution.", - InputSchema = JsonElement.Parse(""" - { - "type": "object", - "properties": { - "durationMs": { - "type": "number", - "description": "Duration of the operation in milliseconds" - } - }, - "required": ["durationMs"] - } - """), - Execution = new ToolExecution - { - TaskSupport = ToolTaskSupport.Optional - } - }, - new Tool { Name = "crash", Description = "Terminates the server process with a specified exit code.", @@ -245,19 +224,6 @@ private static void ConfigureTools(McpServerOptions options, string? cliArg) Content = [new TextContentBlock { Text = cliArg ?? "null" }] }; } - else if (request.Params.Name == "longRunning") - { - if (request.Params.Arguments is null || !request.Params.Arguments.TryGetValue("durationMs", out var durationMsValue)) - { - throw new McpProtocolException("Missing required argument 'durationMs'", McpErrorCode.InvalidParams); - } - int durationMs = Convert.ToInt32(durationMsValue.GetRawText()); - await Task.Delay(durationMs, cancellationToken); - return new CallToolResult - { - Content = [new TextContentBlock { Text = $"Long-running operation completed after {durationMs}ms" }] - }; - } else if (request.Params.Name == "crash") { if (request.Params.Arguments is null || !request.Params.Arguments.TryGetValue("exitCode", out var exitCodeValue)) diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index a36a0a6e0..a6f37f2a6 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -146,27 +146,6 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st } """), }, - new Tool - { - Name = "longRunning", - Description = "Simulates a long-running operation that supports task-based execution.", - InputSchema = JsonElement.Parse(""" - { - "type": "object", - "properties": { - "durationMs": { - "type": "number", - "description": "Duration of the operation in milliseconds" - } - }, - "required": ["durationMs"] - } - """), - Execution = new ToolExecution - { - TaskSupport = ToolTaskSupport.Optional - } - } ] }; }, @@ -212,19 +191,6 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st Content = [new TextContentBlock { Text = $"LLM sampling result: {sampleResult.Content.OfType().FirstOrDefault()?.Text}" }] }; } - else if (request.Params.Name == "longRunning") - { - if (request.Params.Arguments is null || !request.Params.Arguments.TryGetValue("durationMs", out var durationMsValue)) - { - throw new McpProtocolException("Missing required argument 'durationMs'", McpErrorCode.InvalidParams); - } - int durationMs = Convert.ToInt32(durationMsValue.ToString()); - await Task.Delay(durationMs, cancellationToken); - return new CallToolResult - { - Content = [new TextContentBlock { Text = $"Long-running operation completed after {durationMs}ms" }] - }; - } else { throw new McpProtocolException($"Unknown tool: '{request.Params.Name}'", McpErrorCode.InvalidParams); diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientTaskMethodsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientTaskMethodsTests.cs deleted file mode 100644 index ada9970cf..000000000 --- a/tests/ModelContextProtocol.Tests/Client/McpClientTaskMethodsTests.cs +++ /dev/null @@ -1,261 +0,0 @@ -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Options; -using ModelContextProtocol.Client; -using ModelContextProtocol.Protocol; -using ModelContextProtocol.Server; -using System.Text.Json; - -namespace ModelContextProtocol.Tests.Client; - -public class McpClientTaskMethodsTests : ClientServerTestBase -{ - public McpClientTaskMethodsTests(ITestOutputHelper outputHelper) - : base(outputHelper) - { - } - - protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) - { - // Add task store for server-side task support - var taskStore = new InMemoryMcpTaskStore(); - services.AddSingleton(taskStore); - - // Configure server to use the task store directly - services.Configure(options => - { - options.TaskStore = taskStore; - }); - - // Add a simple tool for testing - mcpServerBuilder.WithTools([McpServerTool.Create( - async (string input, CancellationToken ct) => - { - await Task.Delay(50, ct); - return $"Processed: {input}"; - }, - new McpServerToolCreateOptions - { - Name = "test-tool", - Description = "A test tool" - })]); - } - - private static IDictionary CreateArguments(string key, object? value) - { - // For simple strings, just create a JsonElement from a string value - return new Dictionary - { - [key] = JsonDocument.Parse($"\"{value}\"").RootElement.Clone() - }; - } - - [Fact] - public async Task GetTaskAsync_ReturnsTaskStatus() - { - await using McpClient client = await CreateMcpClientForServer(); - - // Create a task by calling a tool with task metadata - var callResult = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "test-tool", - Arguments = CreateArguments("input", "test"), - Task = new McpTaskMetadata() - }, - cancellationToken: TestContext.Current.CancellationToken); - - // The response should contain task metadata - Assert.NotNull(callResult.Task); - - string taskId = callResult.Task.TaskId; - - // Now get the task status - var task = await client.GetTaskAsync(taskId, cancellationToken: TestContext.Current.CancellationToken); - - Assert.Equal(taskId, task.TaskId); - } - - [Fact] - public async Task GetTaskAsync_ThrowsForInvalidTaskId() - { - await using McpClient client = await CreateMcpClientForServer(); - - await Assert.ThrowsAsync(async () => - await client.GetTaskAsync("", cancellationToken: TestContext.Current.CancellationToken)); - } - - [Fact] - public async Task GetTaskResultAsync_ReturnsDeserializedResult() - { - await using McpClient client = await CreateMcpClientForServer(); - - // Create a task - var callResult = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "test-tool", - Arguments = CreateArguments("input", "hello"), - Task = new McpTaskMetadata() - }, - cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(callResult.Task); - string taskId = callResult.Task.TaskId; - - // Wait for task to complete and get the result - JsonElement result = await client.GetTaskResultAsync(taskId, cancellationToken: TestContext.Current.CancellationToken); - - // Verify the result has the expected CallToolResult shape - CallToolResult? toolResult = result.Deserialize(McpJsonUtilities.DefaultOptions); - Assert.NotNull(toolResult); - Assert.NotEmpty(toolResult.Content); - - TextContentBlock? textContent = toolResult.Content[0] as TextContentBlock; - Assert.NotNull(textContent); - Assert.Equal("Processed: hello", textContent.Text); - } - - [Fact] - public async Task GetTaskResultAsync_ThrowsForInvalidTaskId() - { - await using McpClient client = await CreateMcpClientForServer(); - - await Assert.ThrowsAsync(async () => - await client.GetTaskResultAsync("", cancellationToken: TestContext.Current.CancellationToken)); - } - - [Fact] - public async Task ListTasksAsync_ReturnsTasks() - { - await using McpClient client = await CreateMcpClientForServer(); - - // Create a task - var callResult = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "test-tool", - Arguments = CreateArguments("input", "test"), - Task = new McpTaskMetadata() - }, - cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(callResult.Task); - string taskId = callResult.Task.TaskId; - - // List all tasks - var tasks = await client.ListTasksAsync(cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(tasks); - Assert.Contains(tasks, t => t.TaskId == taskId); - } - - [Fact] - public async Task ListTasksAsync_HandlesEmptyResult() - { - await using McpClient client = await CreateMcpClientForServer(); - - // List tasks (may or may not be empty depending on state) - var tasks = await client.ListTasksAsync(cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(tasks); - } - - [Fact] - public async Task ListTasksAsync_LowLevel_ReturnsRawResult() - { - await using McpClient client = await CreateMcpClientForServer(); - - // Create a task first - await client.CallToolAsync( - new CallToolRequestParams - { - Name = "test-tool", - Arguments = CreateArguments("input", "task1"), - Task = new McpTaskMetadata() - }, - cancellationToken: TestContext.Current.CancellationToken); - - // Use low-level API - var result = await client.ListTasksAsync(new ListTasksRequestParams(), TestContext.Current.CancellationToken); - - Assert.NotNull(result); - Assert.NotNull(result.Tasks); - } - - [Fact] - public async Task ListTasksAsync_LowLevel_ThrowsForNullParams() - { - await using McpClient client = await CreateMcpClientForServer(); - - await Assert.ThrowsAsync(async () => - await client.ListTasksAsync((ListTasksRequestParams)null!, TestContext.Current.CancellationToken)); - } - - [Fact] - public async Task CancelTaskAsync_CancelsRunningTask() - { - await using McpClient client = await CreateMcpClientForServer(); - - // Create a task - var callResult = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "test-tool", - Arguments = CreateArguments("input", "test"), - Task = new McpTaskMetadata() - }, - cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(callResult.Task); - string taskId = callResult.Task.TaskId; - - // Cancel the task - var canceledTask = await client.CancelTaskAsync(taskId, cancellationToken: TestContext.Current.CancellationToken); - - Assert.Equal(taskId, canceledTask.TaskId); - } - - [Fact] - public async Task CancelTaskAsync_ThrowsForInvalidTaskId() - { - await using McpClient client = await CreateMcpClientForServer(); - - await Assert.ThrowsAsync(async () => - await client.CancelTaskAsync("", cancellationToken: TestContext.Current.CancellationToken)); - } - - [Fact] - public async Task ListTasksAsync_HandlesPagination() - { - await using McpClient client = await CreateMcpClientForServer(); - - // Create multiple tasks - var taskIds = new List(); - for (int i = 0; i < 3; i++) - { - var result = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "test-tool", - Arguments = CreateArguments("input", $"task-{i}"), - Task = new McpTaskMetadata() - }, - cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(result.Task); - taskIds.Add(result.Task.TaskId); - } - - // List all tasks (should handle pagination automatically if needed) - var tasks = await client.ListTasksAsync(cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(tasks); - Assert.True(tasks.Count >= taskIds.Count, "Should retrieve at least the tasks we created"); - - // Verify all our tasks are in the result - foreach (var taskId in taskIds) - { - Assert.Contains(tasks, t => t.TaskId == taskId); - } - } -} diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientTaskSamplingElicitationTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientTaskSamplingElicitationTests.cs deleted file mode 100644 index 906b4f491..000000000 --- a/tests/ModelContextProtocol.Tests/Client/McpClientTaskSamplingElicitationTests.cs +++ /dev/null @@ -1,867 +0,0 @@ -using Microsoft.Extensions.DependencyInjection; -using ModelContextProtocol.Client; -using ModelContextProtocol.Protocol; -using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Utils; -using System.Text.Json; - -namespace ModelContextProtocol.Tests.Client; - -/// -/// Integration tests for task-based sampling and elicitation on the client side. -/// Tests the client's ability to receive task-augmented requests from the server, -/// execute them as tasks, and report results. -/// -public class McpClientTaskSamplingElicitationTests : ClientServerTestBase -{ - public McpClientTaskSamplingElicitationTests(ITestOutputHelper outputHelper) - : base(outputHelper) - { - } - - protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) - { - // Add task store for server-side task support - var taskStore = new InMemoryMcpTaskStore(); - services.AddSingleton(taskStore); - - // Configure server to use the task store - services.Configure(options => - { - options.TaskStore = taskStore; - }); - - // Add a tool that uses sampling to generate responses - mcpServerBuilder.WithTools([McpServerTool.Create( - async (string prompt, McpServer server, CancellationToken ct) => - { - // This tool requests sampling from the client - var result = await server.SampleAsync(new CreateMessageRequestParams - { - Messages = [new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = prompt }] }], - MaxTokens = 100 - }, ct); - - return result.Content.OfType().FirstOrDefault()?.Text ?? "No response"; - }, - new McpServerToolCreateOptions - { - Name = "sample-tool", - Description = "A tool that uses sampling" - }), - McpServerTool.Create( - async (string message, McpServer server, CancellationToken ct) => - { - // This tool requests elicitation from the client - var result = await server.ElicitAsync(new ElicitRequestParams - { - Message = message, - RequestedSchema = new() - }, ct); - - return result.Action == "confirm" ? "Confirmed" : "Declined"; - }, - new McpServerToolCreateOptions - { - Name = "elicit-tool", - Description = "A tool that uses elicitation" - })]); - } - - private static IDictionary CreateArguments(string key, object? value) - { - return new Dictionary - { - [key] = JsonDocument.Parse($"\"{value}\"").RootElement.Clone() - }; - } - - #region Client Task-Based Sampling Tests - - [Fact] - public async Task Client_WithTaskStoreAndSamplingHandler_AdvertisesTaskAugmentedSamplingCapability() - { - // Arrange - Create client with task store and sampling handler - var taskStore = new InMemoryMcpTaskStore(); - var clientOptions = new McpClientOptions - { - TaskStore = taskStore, - Handlers = new McpClientHandlers - { - SamplingHandler = (request, progress, ct) => - { - return new ValueTask(new CreateMessageResult - { - Content = [new TextContentBlock { Text = "Sampled response" }], - Model = "test-model" - }); - } - } - }; - - await using McpClient client = await CreateMcpClientForServer(clientOptions); - - // The server should see the client's task capabilities - // We verify by checking server can use task-augmented requests - Assert.NotNull(Server.ClientCapabilities); - Assert.NotNull(Server.ClientCapabilities.Sampling); - Assert.NotNull(Server.ClientCapabilities.Tasks); - Assert.NotNull(Server.ClientCapabilities.Tasks.Requests?.Sampling?.CreateMessage); - } - - [Fact] - public async Task Client_WithoutTaskStore_DoesNotAdvertiseTaskAugmentedSamplingCapability() - { - // Arrange - Create client with sampling handler but NO task store - var clientOptions = new McpClientOptions - { - // No TaskStore configured - Handlers = new McpClientHandlers - { - SamplingHandler = (request, progress, ct) => - { - return new ValueTask(new CreateMessageResult - { - Content = [new TextContentBlock { Text = "Sampled response" }], - Model = "test-model" - }); - } - } - }; - - await using McpClient client = await CreateMcpClientForServer(clientOptions); - - // The server should see sampling capability but NOT task-augmented sampling - Assert.NotNull(Server.ClientCapabilities); - Assert.NotNull(Server.ClientCapabilities.Sampling); - - // Task capabilities should be null (no task store) - Assert.Null(Server.ClientCapabilities.Tasks); - } - - [Fact] - public async Task Server_SampleAsTaskAsync_FailsWhenClientDoesNotSupportTaskAugmentedSampling() - { - // Arrange - Client with sampling handler but NO task store - var clientOptions = new McpClientOptions - { - Handlers = new McpClientHandlers - { - SamplingHandler = (request, progress, ct) => - { - return new ValueTask(new CreateMessageResult - { - Content = [new TextContentBlock { Text = "Response" }], - Model = "model" - }); - } - } - }; - - await using McpClient client = await CreateMcpClientForServer(clientOptions); - - // Act & Assert - Server should throw when trying to use task-augmented sampling - var exception = await Assert.ThrowsAsync(async () => - { - await Server.SampleAsTaskAsync( - new CreateMessageRequestParams - { - Messages = [new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = "Test" }] }], - MaxTokens = 100 - }, - new McpTaskMetadata(), - TestContext.Current.CancellationToken); - }); - - Assert.Contains("task-augmented sampling", exception.Message, StringComparison.OrdinalIgnoreCase); - } - - [Fact] - public async Task Client_WithTaskStore_CanExecuteSamplingAsTask() - { - // Arrange - var taskStore = new InMemoryMcpTaskStore(); - var samplingCompleted = new TaskCompletionSource(); - - var clientOptions = new McpClientOptions - { - TaskStore = taskStore, - Handlers = new McpClientHandlers - { - SamplingHandler = async (request, progress, ct) => - { - // Simulate some work - await Task.Delay(50, ct); - samplingCompleted.TrySetResult(true); - return new CreateMessageResult - { - Content = [new TextContentBlock { Text = "Task-based sampling response" }], - Model = "test-model" - }; - } - } - }; - - await using McpClient client = await CreateMcpClientForServer(clientOptions); - - // Act - Server requests task-augmented sampling - var mcpTask = await Server.SampleAsTaskAsync( - new CreateMessageRequestParams - { - Messages = [new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = "Hello" }] }], - MaxTokens = 100 - }, - new McpTaskMetadata(), - TestContext.Current.CancellationToken); - - // Assert - Task was created - Assert.NotNull(mcpTask); - Assert.NotEmpty(mcpTask.TaskId); - Assert.Equal(McpTaskStatus.Working, mcpTask.Status); - - // Wait for sampling to complete - await samplingCompleted.Task.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken); - - // Poll until task is complete - McpTask taskStatus; - do - { - await Task.Delay(100, TestContext.Current.CancellationToken); - taskStatus = await Server.GetTaskAsync(mcpTask.TaskId, TestContext.Current.CancellationToken); - } - while (taskStatus.Status == McpTaskStatus.Working); - - Assert.Equal(McpTaskStatus.Completed, taskStatus.Status); - - // Get the result - var result = await Server.GetTaskResultAsync( - mcpTask.TaskId, cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(result); - var textContent = Assert.IsType(Assert.Single(result.Content)); - Assert.Equal("Task-based sampling response", textContent.Text); - } - - #endregion - - #region Client Task-Based Elicitation Tests - - [Fact] - public async Task Client_WithTaskStoreAndElicitationHandler_AdvertisesTaskAugmentedElicitationCapability() - { - // Arrange - Create client with task store and elicitation handler - var taskStore = new InMemoryMcpTaskStore(); - var clientOptions = new McpClientOptions - { - TaskStore = taskStore, - Handlers = new McpClientHandlers - { - ElicitationHandler = (request, ct) => - { - return new ValueTask(new ElicitResult { Action = "confirm" }); - } - } - }; - - await using McpClient client = await CreateMcpClientForServer(clientOptions); - - // Verify client advertised task-augmented elicitation - Assert.NotNull(Server.ClientCapabilities); - Assert.NotNull(Server.ClientCapabilities.Elicitation); - Assert.NotNull(Server.ClientCapabilities.Tasks); - Assert.NotNull(Server.ClientCapabilities.Tasks.Requests?.Elicitation?.Create); - } - - [Fact] - public async Task Client_WithoutTaskStore_DoesNotAdvertiseTaskAugmentedElicitationCapability() - { - // Arrange - Create client with elicitation handler but NO task store - var clientOptions = new McpClientOptions - { - // No TaskStore configured - Handlers = new McpClientHandlers - { - ElicitationHandler = (request, ct) => - { - return new ValueTask(new ElicitResult { Action = "confirm" }); - } - } - }; - - await using McpClient client = await CreateMcpClientForServer(clientOptions); - - // Verify elicitation is supported but NOT task-augmented - Assert.NotNull(Server.ClientCapabilities); - Assert.NotNull(Server.ClientCapabilities.Elicitation); - Assert.Null(Server.ClientCapabilities.Tasks); - } - - [Fact] - public async Task Server_ElicitAsTaskAsync_FailsWhenClientDoesNotSupportTaskAugmentedElicitation() - { - // Arrange - Client with elicitation handler but NO task store - var clientOptions = new McpClientOptions - { - Handlers = new McpClientHandlers - { - ElicitationHandler = (request, ct) => - { - return new ValueTask(new ElicitResult { Action = "confirm" }); - } - } - }; - - await using McpClient client = await CreateMcpClientForServer(clientOptions); - - // Act & Assert - Server should throw when trying to use task-augmented elicitation - var exception = await Assert.ThrowsAsync(async () => - { - await Server.ElicitAsTaskAsync( - new ElicitRequestParams - { - Message = "Please confirm", - RequestedSchema = new() - }, - new McpTaskMetadata(), - TestContext.Current.CancellationToken); - }); - - Assert.Contains("task-augmented elicitation", exception.Message, StringComparison.OrdinalIgnoreCase); - } - - [Fact] - public async Task Client_WithTaskStore_CanExecuteElicitationAsTask() - { - // Arrange - var taskStore = new InMemoryMcpTaskStore(); - var elicitationCompleted = new TaskCompletionSource(); - - var clientOptions = new McpClientOptions - { - TaskStore = taskStore, - Handlers = new McpClientHandlers - { - ElicitationHandler = async (request, ct) => - { - // Simulate user interaction time - await Task.Delay(50, ct); - elicitationCompleted.TrySetResult(true); - return new ElicitResult - { - Action = "accept", - Content = new Dictionary - { - ["answer"] = JsonDocument.Parse("\"yes\"").RootElement.Clone() - } - }; - } - } - }; - - await using McpClient client = await CreateMcpClientForServer(clientOptions); - - // Act - Server requests task-augmented elicitation - var mcpTask = await Server.ElicitAsTaskAsync( - new ElicitRequestParams - { - Message = "Do you want to proceed?", - RequestedSchema = new() - }, - new McpTaskMetadata(), - TestContext.Current.CancellationToken); - - // Assert - Task was created - Assert.NotNull(mcpTask); - Assert.NotEmpty(mcpTask.TaskId); - Assert.Equal(McpTaskStatus.Working, mcpTask.Status); - - // Wait for elicitation to complete - await elicitationCompleted.Task.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken); - - // Poll until task is complete - McpTask taskStatus; - do - { - await Task.Delay(100, TestContext.Current.CancellationToken); - taskStatus = await Server.GetTaskAsync(mcpTask.TaskId, TestContext.Current.CancellationToken); - } - while (taskStatus.Status == McpTaskStatus.Working); - - Assert.Equal(McpTaskStatus.Completed, taskStatus.Status); - - // Get the result - var result = await Server.GetTaskResultAsync( - mcpTask.TaskId, cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(result); - Assert.Equal("accept", result.Action); - } - - #endregion - - #region Client Task Reporting Tests - - [Fact] - public async Task Client_CanListOwnTasks() - { - // Arrange - var taskStore = new InMemoryMcpTaskStore(); - var clientOptions = new McpClientOptions - { - TaskStore = taskStore, - Handlers = new McpClientHandlers - { - SamplingHandler = async (request, progress, ct) => - { - await Task.Delay(50, ct); - return new CreateMessageResult - { - Content = [new TextContentBlock { Text = "Response" }], - Model = "model" - }; - } - } - }; - - await using McpClient client = await CreateMcpClientForServer(clientOptions); - - // Create multiple tasks - var task1 = await Server.SampleAsTaskAsync( - new CreateMessageRequestParams { Messages = [], MaxTokens = 100 }, - new McpTaskMetadata(), - TestContext.Current.CancellationToken); - - var task2 = await Server.SampleAsTaskAsync( - new CreateMessageRequestParams { Messages = [], MaxTokens = 100 }, - new McpTaskMetadata(), - TestContext.Current.CancellationToken); - - // Act - Server lists tasks from client - var tasks = await Server.ListTasksAsync(TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(tasks); - Assert.True(tasks.Count >= 2, "Should have at least 2 tasks"); - Assert.Contains(tasks, t => t.TaskId == task1.TaskId); - Assert.Contains(tasks, t => t.TaskId == task2.TaskId); - } - - [Fact] - public async Task Client_CanCancelTasks() - { - // Arrange - var taskStore = new InMemoryMcpTaskStore(); - var samplingStarted = new TaskCompletionSource(); - var allowCompletion = new TaskCompletionSource(); - - var clientOptions = new McpClientOptions - { - TaskStore = taskStore, - Handlers = new McpClientHandlers - { - SamplingHandler = async (request, progress, ct) => - { - samplingStarted.TrySetResult(true); - // Wait for either completion signal or cancellation - try - { - await allowCompletion.Task.WaitAsync(ct); - } - catch (OperationCanceledException) - { - throw; - } - return new CreateMessageResult - { - Content = [new TextContentBlock { Text = "Should not reach here" }], - Model = "model" - }; - } - } - }; - - await using McpClient client = await CreateMcpClientForServer(clientOptions); - - // Create a task that will be in progress - var mcpTask = await Server.SampleAsTaskAsync( - new CreateMessageRequestParams { Messages = [], MaxTokens = 100 }, - new McpTaskMetadata(), - TestContext.Current.CancellationToken); - - // Wait for sampling to start - await samplingStarted.Task.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken); - - // Act - Cancel the task - var cancelledTask = await Server.CancelTaskAsync(mcpTask.TaskId, TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(cancelledTask); - Assert.Equal(McpTaskStatus.Cancelled, cancelledTask.Status); - - // Allow completion to avoid hanging (the handler might still be running) - allowCompletion.TrySetResult(true); - } - - [Fact] - public async Task Client_TaskStatusNotifications_SentWhenEnabled() - { - // Arrange - var taskStore = new InMemoryMcpTaskStore(); - var workingNotificationReceived = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var completedNotificationReceived = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var notificationsReceived = new List(); - var notificationsLock = new object(); - string? expectedTaskId = null; - var expectedTaskIdLock = new object(); - - var clientOptions = new McpClientOptions - { - TaskStore = taskStore, - SendTaskStatusNotifications = true, - Handlers = new McpClientHandlers - { - SamplingHandler = async (request, progress, ct) => - { - await Task.Delay(100, ct); - return new CreateMessageResult - { - Content = [new TextContentBlock { Text = "Done" }], - Model = "model" - }; - } - } - }; - - // Register notification handler on the server BEFORE creating the client - var notificationHandler = Server.RegisterNotificationHandler( - NotificationMethods.TaskStatusNotification, - (notification, ct) => - { - if (notification.Params is not { } paramsNode) - { - return default; - } - - var taskNotification = JsonSerializer.Deserialize( - paramsNode, McpJsonUtilities.DefaultOptions); - if (taskNotification is null) - { - return default; - } - - // Only track notifications for our task - string? taskId; - lock (expectedTaskIdLock) - { - taskId = expectedTaskId; - } - if (taskId is not null && taskNotification.TaskId != taskId) - { - return default; - } - - lock (notificationsLock) - { - notificationsReceived.Add(new McpTask - { - TaskId = taskNotification.TaskId, - Status = taskNotification.Status, - CreatedAt = taskNotification.CreatedAt, - LastUpdatedAt = taskNotification.LastUpdatedAt - }); - } - - // Signal when we receive the Working and Completed notifications - if (taskNotification.Status == McpTaskStatus.Working) - { - workingNotificationReceived.TrySetResult(true); - } - else if (taskNotification.Status == McpTaskStatus.Completed) - { - completedNotificationReceived.TrySetResult(true); - } - - return default; - }); - - await using McpClient client = await CreateMcpClientForServer(clientOptions); - - // Act - Create a task - var mcpTask = await Server.SampleAsTaskAsync( - new CreateMessageRequestParams { Messages = [], MaxTokens = 100 }, - new McpTaskMetadata(), - TestContext.Current.CancellationToken); - - // Store the expected task ID for filtering - lock (expectedTaskIdLock) - { - expectedTaskId = mcpTask.TaskId; - } - - // Wait for both Working and Completed notifications to arrive - // The notifications are sent asynchronously so we need to wait for both - await Task.WhenAll( - workingNotificationReceived.Task.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken), - completedNotificationReceived.Task.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken)); - - // Assert - Should have received notifications for status transitions - await notificationHandler.DisposeAsync(); - - List notifications; - lock (notificationsLock) - { - notifications = [.. notificationsReceived]; - } - - Assert.NotEmpty(notifications); - Assert.Contains(notifications, t => t.Status == McpTaskStatus.Working); - Assert.Contains(notifications, t => t.Status == McpTaskStatus.Completed); - - // Verify all notifications are for the correct task - Assert.All(notifications, t => Assert.Equal(mcpTask.TaskId, t.TaskId)); - } - - #endregion - - #region Error Handling Tests - - [Fact] - public async Task Client_SamplingHandlerException_ResultsInFailedTask() - { - // Arrange - var taskStore = new InMemoryMcpTaskStore(); - var samplingAttempted = new TaskCompletionSource(); - - var clientOptions = new McpClientOptions - { - TaskStore = taskStore, - Handlers = new McpClientHandlers - { - SamplingHandler = (request, progress, ct) => - { - samplingAttempted.TrySetResult(true); - throw new InvalidOperationException("Sampling failed!"); - } - } - }; - - await using McpClient client = await CreateMcpClientForServer(clientOptions); - - // Act - var mcpTask = await Server.SampleAsTaskAsync( - new CreateMessageRequestParams { Messages = [], MaxTokens = 100 }, - new McpTaskMetadata(), - TestContext.Current.CancellationToken); - - // Wait for sampling attempt - await samplingAttempted.Task.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken); - - // Poll until task status changes - McpTask taskStatus; - do - { - await Task.Delay(100, TestContext.Current.CancellationToken); - taskStatus = await Server.GetTaskAsync(mcpTask.TaskId, TestContext.Current.CancellationToken); - } - while (taskStatus.Status == McpTaskStatus.Working); - - // Assert - Task should be in failed state - Assert.Equal(McpTaskStatus.Failed, taskStatus.Status); - Assert.NotNull(taskStatus.StatusMessage); - Assert.Contains("Sampling failed!", taskStatus.StatusMessage); - } - - [Fact] - public async Task Client_ElicitationHandlerException_ResultsInFailedTask() - { - // Arrange - var taskStore = new InMemoryMcpTaskStore(); - var elicitationAttempted = new TaskCompletionSource(); - - var clientOptions = new McpClientOptions - { - TaskStore = taskStore, - Handlers = new McpClientHandlers - { - ElicitationHandler = (request, ct) => - { - elicitationAttempted.TrySetResult(true); - throw new InvalidOperationException("Elicitation failed!"); - } - } - }; - - await using McpClient client = await CreateMcpClientForServer(clientOptions); - - // Act - var mcpTask = await Server.ElicitAsTaskAsync( - new ElicitRequestParams - { - Message = "Test", - RequestedSchema = new() - }, - new McpTaskMetadata(), - TestContext.Current.CancellationToken); - - // Wait for elicitation attempt - await elicitationAttempted.Task.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken); - - // Poll until task status changes - McpTask taskStatus; - do - { - await Task.Delay(100, TestContext.Current.CancellationToken); - taskStatus = await Server.GetTaskAsync(mcpTask.TaskId, TestContext.Current.CancellationToken); - } - while (taskStatus.Status == McpTaskStatus.Working); - - // Assert - Assert.Equal(McpTaskStatus.Failed, taskStatus.Status); - Assert.NotNull(taskStatus.StatusMessage); - Assert.Contains("Elicitation failed!", taskStatus.StatusMessage); - } - - #endregion - - #region Capability Validation Tests - - [Fact] - public async Task Client_WithOnlySamplingHandler_OnlyAdvertisesSamplingTasks() - { - // Arrange - Client with only sampling handler and task store - var taskStore = new InMemoryMcpTaskStore(); - var clientOptions = new McpClientOptions - { - TaskStore = taskStore, - Handlers = new McpClientHandlers - { - SamplingHandler = (request, progress, ct) => - { - return new ValueTask(new CreateMessageResult - { - Content = [new TextContentBlock { Text = "Response" }], - Model = "model" - }); - } - // No ElicitationHandler - } - }; - - await using McpClient client = await CreateMcpClientForServer(clientOptions); - - // Assert - Assert.NotNull(Server.ClientCapabilities); - Assert.NotNull(Server.ClientCapabilities.Tasks); - - // Should have sampling task capability - Assert.NotNull(Server.ClientCapabilities.Tasks.Requests?.Sampling?.CreateMessage); - - // Should NOT have elicitation task capability - Assert.Null(Server.ClientCapabilities.Tasks.Requests?.Elicitation); - } - - [Fact] - public async Task Client_WithOnlyElicitationHandler_OnlyAdvertisesElicitationTasks() - { - // Arrange - Client with only elicitation handler and task store - var taskStore = new InMemoryMcpTaskStore(); - var clientOptions = new McpClientOptions - { - TaskStore = taskStore, - Handlers = new McpClientHandlers - { - ElicitationHandler = (request, ct) => - { - return new ValueTask(new ElicitResult { Action = "confirm" }); - } - // No SamplingHandler - } - }; - - await using McpClient client = await CreateMcpClientForServer(clientOptions); - - // Assert - Assert.NotNull(Server.ClientCapabilities); - Assert.NotNull(Server.ClientCapabilities.Tasks); - - // Should have elicitation task capability - Assert.NotNull(Server.ClientCapabilities.Tasks.Requests?.Elicitation?.Create); - - // Should NOT have sampling task capability - Assert.Null(Server.ClientCapabilities.Tasks.Requests?.Sampling); - } - - [Fact] - public async Task Client_WithBothHandlers_AdvertisesBothTaskCapabilities() - { - // Arrange - Client with both handlers and task store - var taskStore = new InMemoryMcpTaskStore(); - var clientOptions = new McpClientOptions - { - TaskStore = taskStore, - Handlers = new McpClientHandlers - { - SamplingHandler = (request, progress, ct) => - { - return new ValueTask(new CreateMessageResult - { - Content = [new TextContentBlock { Text = "Response" }], - Model = "model" - }); - }, - ElicitationHandler = (request, ct) => - { - return new ValueTask(new ElicitResult { Action = "confirm" }); - } - } - }; - - await using McpClient client = await CreateMcpClientForServer(clientOptions); - - // Assert - Assert.NotNull(Server.ClientCapabilities); - Assert.NotNull(Server.ClientCapabilities.Tasks); - Assert.NotNull(Server.ClientCapabilities.Tasks.Requests); - - // Should have both capabilities - Assert.NotNull(Server.ClientCapabilities.Tasks.Requests.Sampling?.CreateMessage); - Assert.NotNull(Server.ClientCapabilities.Tasks.Requests.Elicitation?.Create); - - // Should also have list and cancel capabilities - Assert.NotNull(Server.ClientCapabilities.Tasks.List); - Assert.NotNull(Server.ClientCapabilities.Tasks.Cancel); - } - - [Fact] - public async Task Client_WithNoHandlers_DoesNotAdvertiseTaskCapabilities() - { - // Arrange - Client with task store but no handlers - var taskStore = new InMemoryMcpTaskStore(); - var clientOptions = new McpClientOptions - { - TaskStore = taskStore, - Handlers = new McpClientHandlers() - // No handlers configured - }; - - await using McpClient client = await CreateMcpClientForServer(clientOptions); - - // Assert - No capabilities should be advertised without handlers - Assert.NotNull(Server.ClientCapabilities); - - // Note: Tasks capability is advertised based on task store being present, - // but request types depend on specific handlers - if (Server.ClientCapabilities.Tasks is not null) - { - // If Tasks is present, requests should be null or have no request types - var requests = Server.ClientCapabilities.Tasks.Requests; - if (requests is not null) - { - Assert.Null(requests.Sampling); - Assert.Null(requests.Elicitation); - } - } - } - - #endregion -} diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerOptionsSetupTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerOptionsSetupTests.cs index 689aba9d0..dc2eaf805 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerOptionsSetupTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerOptionsSetupTests.cs @@ -284,57 +284,4 @@ public void Configure_WithCompleteHandler_CreatesCompletionsCapability() Assert.NotNull(options.Capabilities?.Completions); } #endregion - - #region TaskStore Tests - [Fact] - public void TaskStore_IsPopulatedFromDI_WhenNotExplicitlySet() - { - var services = new ServiceCollection(); - services.AddMcpServer(); - services.AddSingleton(); - - var options = services.BuildServiceProvider().GetRequiredService>().Value; - - Assert.IsType(options.TaskStore); - } - - [Fact] - public void TaskStore_ExplicitOption_TakesPrecedenceOverDI() - { - var explicitStore = new InMemoryMcpTaskStore(); - - var services = new ServiceCollection(); - services.AddMcpServer(options => options.TaskStore = explicitStore); - services.AddSingleton(); - - var options = services.BuildServiceProvider().GetRequiredService>().Value; - - Assert.Same(explicitStore, options.TaskStore); - } - - [Fact] - public void TaskStore_RemainsNull_WhenNothingIsRegistered() - { - var services = new ServiceCollection(); - services.AddMcpServer(); - - var options = services.BuildServiceProvider().GetRequiredService>().Value; - - Assert.Null(options.TaskStore); - } - - [Fact] - public void TaskStore_CanBeOverriddenToNull_AfterDIRegistration() - { - var services = new ServiceCollection(); - services.AddMcpServer(); - services.AddSingleton(); - - services.Configure(options => options.TaskStore = null); - - var options = services.BuildServiceProvider().GetRequiredService>().Value; - - Assert.Null(options.TaskStore); - } - #endregion } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/ExperimentalPropertySerializationTests.cs b/tests/ModelContextProtocol.Tests/ExperimentalPropertySerializationTests.cs index d68902ef5..866a59c61 100644 --- a/tests/ModelContextProtocol.Tests/ExperimentalPropertySerializationTests.cs +++ b/tests/ModelContextProtocol.Tests/ExperimentalPropertySerializationTests.cs @@ -1,4 +1,5 @@ using System.Text.Json; +using System.Text.Json.Nodes; using System.Text.Json.Serialization; using ModelContextProtocol.Protocol; @@ -10,13 +11,13 @@ namespace ModelContextProtocol.Tests; /// /// /// -/// Experimental properties (e.g. , ) +/// Experimental properties (e.g. , ) /// use an internal *Core property for serialization. A consumer's source-generated /// cannot see internal members, so experimental data is /// silently dropped unless the consumer chains the SDK's resolver into their options. /// /// -/// These tests depend on and +/// These tests depend on and /// being experimental. When those APIs stabilize, update these tests to reference whatever /// experimental properties exist at that time, or remove them entirely if no experimental /// APIs remain. @@ -32,36 +33,36 @@ public void ExperimentalProperties_Dropped_WithConsumerContextOnly() TypeInfoResolverChain = { ConsumerJsonContext.Default } }; - var tool = new Tool + var capabilities = new ServerCapabilities { - Name = "test-tool", - Execution = new ToolExecution { TaskSupport = ToolTaskSupport.Optional } + Tools = new ToolsCapability(), + Extensions = new Dictionary { ["io.test"] = new JsonObject { ["enabled"] = true } } }; - string json = JsonSerializer.Serialize(tool, options); - Assert.DoesNotContain("\"execution\"", json); - Assert.Contains("\"name\"", json); + string json = JsonSerializer.Serialize(capabilities, options); + Assert.DoesNotContain("\"extensions\"", json); + Assert.Contains("\"tools\"", json); } [Fact] public void ExperimentalProperties_IgnoredOnDeserialize_WithConsumerContextOnly() { string json = JsonSerializer.Serialize( - new Tool + new ServerCapabilities { - Name = "test-tool", - Execution = new ToolExecution { TaskSupport = ToolTaskSupport.Optional } + Tools = new ToolsCapability(), + Extensions = new Dictionary { ["io.test"] = new JsonObject { ["enabled"] = true } } }, McpJsonUtilities.DefaultOptions); - Assert.Contains("\"execution\"", json); + Assert.Contains("\"extensions\"", json); var options = new JsonSerializerOptions { TypeInfoResolverChain = { ConsumerJsonContext.Default } }; - var deserialized = JsonSerializer.Deserialize(json, options)!; - Assert.Equal("test-tool", deserialized.Name); - Assert.Null(deserialized.Execution); + var deserialized = JsonSerializer.Deserialize(json, options)!; + Assert.NotNull(deserialized.Tools); + Assert.Null(deserialized.Extensions); } [Fact] @@ -76,35 +77,36 @@ public void ExperimentalProperties_RoundTrip_WhenSdkResolverIsChained() } }; - var tool = new Tool + var capabilities = new ServerCapabilities { - Name = "test-tool", - Execution = new ToolExecution { TaskSupport = ToolTaskSupport.Optional } + Tools = new ToolsCapability(), + Extensions = new Dictionary { ["io.test"] = new JsonObject { ["enabled"] = true } } }; - string json = JsonSerializer.Serialize(tool, options); - Assert.Contains("\"execution\"", json); - Assert.Contains("\"name\"", json); + string json = JsonSerializer.Serialize(capabilities, options); + Assert.Contains("\"extensions\"", json); + Assert.Contains("\"tools\"", json); - var deserialized = JsonSerializer.Deserialize(json, options)!; - Assert.Equal("test-tool", deserialized.Name); - Assert.NotNull(deserialized.Execution); - Assert.Equal(ToolTaskSupport.Optional, deserialized.Execution.TaskSupport); + var deserialized = JsonSerializer.Deserialize(json, options)!; + Assert.NotNull(deserialized.Tools); + Assert.NotNull(deserialized.Extensions); + Assert.True(deserialized.Extensions.ContainsKey("io.test")); } [Fact] public void ExperimentalProperties_RoundTrip_WithDefaultOptions() { - var capabilities = new ServerCapabilities + var capabilities = new ClientCapabilities { - Tasks = new McpTasksCapability() + Extensions = new Dictionary { ["io.test"] = new JsonObject { ["enabled"] = true } } }; string json = JsonSerializer.Serialize(capabilities, McpJsonUtilities.DefaultOptions); - Assert.Contains("\"tasks\"", json); + Assert.Contains("\"extensions\"", json); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions)!; - Assert.NotNull(deserialized.Tasks); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions)!; + Assert.NotNull(deserialized.Extensions); + Assert.True(deserialized.Extensions.ContainsKey("io.test")); } } diff --git a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj index 7f7de2a41..a9b40a412 100644 --- a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj +++ b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj @@ -35,10 +35,6 @@ - - - - diff --git a/tests/ModelContextProtocol.Tests/Protocol/CallToolRequestParamsTests.cs b/tests/ModelContextProtocol.Tests/Protocol/CallToolRequestParamsTests.cs index d2f5a09ad..ec758120f 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/CallToolRequestParamsTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/CallToolRequestParamsTests.cs @@ -17,7 +17,6 @@ public static void CallToolRequestParams_SerializationRoundTrip_PreservesAllProp ["city"] = JsonDocument.Parse("\"Seattle\"").RootElement.Clone(), ["units"] = JsonDocument.Parse("\"metric\"").RootElement.Clone() }, - Task = new McpTaskMetadata { TimeToLive = TimeSpan.FromHours(1) }, Meta = new JsonObject { ["progressToken"] = "token-123" } }; @@ -30,8 +29,6 @@ public static void CallToolRequestParams_SerializationRoundTrip_PreservesAllProp Assert.Equal(2, deserialized.Arguments.Count); Assert.Equal("Seattle", deserialized.Arguments["city"].GetString()); Assert.Equal("metric", deserialized.Arguments["units"].GetString()); - Assert.NotNull(deserialized.Task); - Assert.Equal(original.Task.TimeToLive, deserialized.Task.TimeToLive); Assert.NotNull(deserialized.Meta); Assert.Equal("token-123", (string)deserialized.Meta["progressToken"]!); } @@ -50,7 +47,6 @@ public static void CallToolRequestParams_SerializationRoundTrip_WithMinimalPrope Assert.NotNull(deserialized); Assert.Equal(original.Name, deserialized.Name); Assert.Null(deserialized.Arguments); - Assert.Null(deserialized.Task); Assert.Null(deserialized.Meta); } } diff --git a/tests/ModelContextProtocol.Tests/Protocol/CallToolResultTests.cs b/tests/ModelContextProtocol.Tests/Protocol/CallToolResultTests.cs index d66e03b3f..b1ac90c9d 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/CallToolResultTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/CallToolResultTests.cs @@ -14,13 +14,6 @@ public static void CallToolResult_SerializationRoundTrip_PreservesAllProperties( Content = [new TextContentBlock { Text = "Result text" }], StructuredContent = JsonElement.Parse("""{"temperature":72}"""), IsError = false, - Task = new McpTask - { - TaskId = "task-1", - Status = McpTaskStatus.Completed, - CreatedAt = new DateTimeOffset(2025, 1, 1, 0, 0, 0, TimeSpan.Zero), - LastUpdatedAt = new DateTimeOffset(2025, 1, 1, 0, 0, 0, TimeSpan.Zero) - }, Meta = new JsonObject { ["key"] = "value" } }; @@ -34,8 +27,6 @@ public static void CallToolResult_SerializationRoundTrip_PreservesAllProperties( Assert.NotNull(deserialized.StructuredContent); Assert.Equal(72, deserialized.StructuredContent.Value.GetProperty("temperature").GetInt32()); Assert.False(deserialized.IsError); - Assert.NotNull(deserialized.Task); - Assert.Equal("task-1", deserialized.Task.TaskId); Assert.NotNull(deserialized.Meta); Assert.Equal("value", (string)deserialized.Meta["key"]!); } @@ -52,7 +43,6 @@ public static void CallToolResult_SerializationRoundTrip_WithMinimalProperties() Assert.Empty(deserialized.Content); Assert.Null(deserialized.StructuredContent); Assert.Null(deserialized.IsError); - Assert.Null(deserialized.Task); Assert.Null(deserialized.Meta); } } diff --git a/tests/ModelContextProtocol.Tests/Protocol/CancelMcpTaskRequestParamsTests.cs b/tests/ModelContextProtocol.Tests/Protocol/CancelMcpTaskRequestParamsTests.cs deleted file mode 100644 index a3b3b2ef6..000000000 --- a/tests/ModelContextProtocol.Tests/Protocol/CancelMcpTaskRequestParamsTests.cs +++ /dev/null @@ -1,25 +0,0 @@ -using ModelContextProtocol.Protocol; -using System.Text.Json; - -namespace ModelContextProtocol.Tests.Protocol; - -public static class CancelMcpTaskRequestParamsTests -{ - [Fact] - public static void CancelMcpTaskRequestParams_SerializationRoundTrip() - { - // Arrange - var original = new CancelMcpTaskRequestParams - { - TaskId = "cancel-task-456" - }; - - // Act - string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); - - // Assert - Assert.NotNull(deserialized); - Assert.Equal(original.TaskId, deserialized.TaskId); - } -} diff --git a/tests/ModelContextProtocol.Tests/Protocol/CancelMcpTaskResultTests.cs b/tests/ModelContextProtocol.Tests/Protocol/CancelMcpTaskResultTests.cs deleted file mode 100644 index 5cf628642..000000000 --- a/tests/ModelContextProtocol.Tests/Protocol/CancelMcpTaskResultTests.cs +++ /dev/null @@ -1,33 +0,0 @@ -using ModelContextProtocol.Protocol; -using System.Text.Json; - -namespace ModelContextProtocol.Tests.Protocol; - -public static class CancelMcpTaskResultTests -{ - [Fact] - public static void CancelMcpTaskResult_SerializationRoundTrip() - { - // Arrange - var original = new CancelMcpTaskResult - { - TaskId = "cancelled-789", - Status = McpTaskStatus.Cancelled, - StatusMessage = "Cancelled by user", - CreatedAt = DateTimeOffset.UtcNow, - LastUpdatedAt = DateTimeOffset.UtcNow, - TimeToLive = null, - PollInterval = null - }; - - // Act - string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); - - // Assert - Assert.NotNull(deserialized); - Assert.Equal(original.TaskId, deserialized.TaskId); - Assert.Equal(original.Status, deserialized.Status); - Assert.Equal(original.StatusMessage, deserialized.StatusMessage); - } -} diff --git a/tests/ModelContextProtocol.Tests/Protocol/ClientCapabilitiesTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ClientCapabilitiesTests.cs index cacb7e84e..82613dd53 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/ClientCapabilitiesTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/ClientCapabilitiesTests.cs @@ -21,7 +21,6 @@ public static void ClientCapabilities_SerializationRoundTrip_PreservesAllPropert Form = new FormElicitationCapability(), Url = new UrlElicitationCapability() }, - Tasks = new McpTasksCapability(), Extensions = new Dictionary { ["io.modelcontextprotocol/test"] = new object() @@ -40,7 +39,6 @@ public static void ClientCapabilities_SerializationRoundTrip_PreservesAllPropert Assert.NotNull(deserialized.Elicitation); Assert.NotNull(deserialized.Elicitation.Form); Assert.NotNull(deserialized.Elicitation.Url); - Assert.NotNull(deserialized.Tasks); Assert.NotNull(deserialized.Extensions); Assert.True(deserialized.Extensions.ContainsKey("io.modelcontextprotocol/test")); } @@ -58,7 +56,6 @@ public static void ClientCapabilities_SerializationRoundTrip_WithMinimalProperti Assert.Null(deserialized.Roots); Assert.Null(deserialized.Sampling); Assert.Null(deserialized.Elicitation); - Assert.Null(deserialized.Tasks); Assert.Null(deserialized.Extensions); } diff --git a/tests/ModelContextProtocol.Tests/Protocol/CreateTaskResultTests.cs b/tests/ModelContextProtocol.Tests/Protocol/CreateTaskResultTests.cs deleted file mode 100644 index 0252053cb..000000000 --- a/tests/ModelContextProtocol.Tests/Protocol/CreateTaskResultTests.cs +++ /dev/null @@ -1,41 +0,0 @@ -using ModelContextProtocol.Protocol; -using System.Text.Json; -using System.Text.Json.Nodes; - -namespace ModelContextProtocol.Tests.Protocol; - -public static class CreateTaskResultTests -{ - [Fact] - public static void CreateTaskResult_SerializationRoundTrip_PreservesAllProperties() - { - var original = new CreateTaskResult - { - Task = new McpTask - { - TaskId = "task-123", - Status = McpTaskStatus.Working, - StatusMessage = "Processing", - CreatedAt = new DateTimeOffset(2025, 6, 1, 12, 0, 0, TimeSpan.Zero), - LastUpdatedAt = new DateTimeOffset(2025, 6, 1, 12, 5, 0, TimeSpan.Zero), - TimeToLive = TimeSpan.FromHours(1), - PollInterval = TimeSpan.FromSeconds(5) - }, - Meta = new JsonObject { ["key"] = "value" } - }; - - string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); - - Assert.NotNull(deserialized); - Assert.Equal("task-123", deserialized.Task.TaskId); - Assert.Equal(McpTaskStatus.Working, deserialized.Task.Status); - Assert.Equal("Processing", deserialized.Task.StatusMessage); - Assert.Equal(original.Task.CreatedAt, deserialized.Task.CreatedAt); - Assert.Equal(original.Task.LastUpdatedAt, deserialized.Task.LastUpdatedAt); - Assert.Equal(original.Task.TimeToLive, deserialized.Task.TimeToLive); - Assert.Equal(original.Task.PollInterval, deserialized.Task.PollInterval); - Assert.NotNull(deserialized.Meta); - Assert.Equal("value", (string)deserialized.Meta["key"]!); - } -} diff --git a/tests/ModelContextProtocol.Tests/Protocol/ElicitRequestParamsTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ElicitRequestParamsTests.cs index 1d57f55ad..f8e2fedbf 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/ElicitRequestParamsTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/ElicitRequestParamsTests.cs @@ -23,7 +23,6 @@ public static void ElicitRequestParams_SerializationRoundTrip_PreservesAllProper ["age"] = new ElicitRequestParams.NumberSchema { Description = "Your age" } } }, - Task = new McpTaskMetadata { TimeToLive = TimeSpan.FromMinutes(10) }, Meta = new JsonObject { ["progressToken"] = "tok-1" } }; @@ -37,8 +36,6 @@ public static void ElicitRequestParams_SerializationRoundTrip_PreservesAllProper Assert.Equal("Please provide your details", deserialized.Message); Assert.NotNull(deserialized.RequestedSchema); Assert.Equal(2, deserialized.RequestedSchema.Properties.Count); - Assert.NotNull(deserialized.Task); - Assert.Equal(TimeSpan.FromMinutes(10), deserialized.Task.TimeToLive); Assert.NotNull(deserialized.Meta); Assert.Equal("tok-1", (string)deserialized.Meta["progressToken"]!); } @@ -63,7 +60,6 @@ public static void ElicitRequestParams_SerializationRoundTrip_UrlMode() Assert.Equal("https://example.com/auth", deserialized.Url); Assert.Equal("Please authenticate", deserialized.Message); Assert.Null(deserialized.RequestedSchema); - Assert.Null(deserialized.Task); } [Fact] @@ -83,7 +79,6 @@ public static void ElicitRequestParams_SerializationRoundTrip_WithMinimalPropert Assert.Null(deserialized.ElicitationId); Assert.Null(deserialized.Url); Assert.Null(deserialized.RequestedSchema); - Assert.Null(deserialized.Task); Assert.Null(deserialized.Meta); } } diff --git a/tests/ModelContextProtocol.Tests/Protocol/GetTaskPayloadRequestParamsTests.cs b/tests/ModelContextProtocol.Tests/Protocol/GetTaskPayloadRequestParamsTests.cs deleted file mode 100644 index 47f427259..000000000 --- a/tests/ModelContextProtocol.Tests/Protocol/GetTaskPayloadRequestParamsTests.cs +++ /dev/null @@ -1,25 +0,0 @@ -using ModelContextProtocol.Protocol; -using System.Text.Json; - -namespace ModelContextProtocol.Tests.Protocol; - -public static class GetTaskPayloadRequestParamsTests -{ - [Fact] - public static void GetTaskPayloadRequestParams_SerializationRoundTrip() - { - // Arrange - var original = new GetTaskPayloadRequestParams - { - TaskId = "payload-task-999" - }; - - // Act - string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); - - // Assert - Assert.NotNull(deserialized); - Assert.Equal(original.TaskId, deserialized.TaskId); - } -} diff --git a/tests/ModelContextProtocol.Tests/Protocol/GetTaskRequestParamsTests.cs b/tests/ModelContextProtocol.Tests/Protocol/GetTaskRequestParamsTests.cs deleted file mode 100644 index 9b3e7b1d5..000000000 --- a/tests/ModelContextProtocol.Tests/Protocol/GetTaskRequestParamsTests.cs +++ /dev/null @@ -1,25 +0,0 @@ -using ModelContextProtocol.Protocol; -using System.Text.Json; - -namespace ModelContextProtocol.Tests.Protocol; - -public static class GetTaskRequestParamsTests -{ - [Fact] - public static void GetTaskRequestParams_SerializationRoundTrip() - { - // Arrange - var original = new GetTaskRequestParams - { - TaskId = "get-task-123" - }; - - // Act - string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); - - // Assert - Assert.NotNull(deserialized); - Assert.Equal(original.TaskId, deserialized.TaskId); - } -} diff --git a/tests/ModelContextProtocol.Tests/Protocol/GetTaskResultTests.cs b/tests/ModelContextProtocol.Tests/Protocol/GetTaskResultTests.cs deleted file mode 100644 index ece58683f..000000000 --- a/tests/ModelContextProtocol.Tests/Protocol/GetTaskResultTests.cs +++ /dev/null @@ -1,37 +0,0 @@ -using ModelContextProtocol.Protocol; -using System.Text.Json; - -namespace ModelContextProtocol.Tests.Protocol; - -public static class GetTaskResultTests -{ - [Fact] - public static void GetTaskResult_SerializationRoundTrip() - { - // Arrange - var original = new GetTaskResult - { - TaskId = "result-123", - Status = McpTaskStatus.Completed, - StatusMessage = "Done", - CreatedAt = DateTimeOffset.UtcNow, - LastUpdatedAt = DateTimeOffset.UtcNow, - TimeToLive = TimeSpan.FromHours(1), - PollInterval = TimeSpan.FromSeconds(1) - }; - - // Act - string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); - - // Assert - Assert.NotNull(deserialized); - Assert.Equal(original.TaskId, deserialized.TaskId); - Assert.Equal(original.Status, deserialized.Status); - Assert.Equal(original.StatusMessage, deserialized.StatusMessage); - Assert.Equal(original.CreatedAt, deserialized.CreatedAt); - Assert.Equal(original.LastUpdatedAt, deserialized.LastUpdatedAt); - Assert.Equal(original.TimeToLive, deserialized.TimeToLive); - Assert.Equal(original.PollInterval, deserialized.PollInterval); - } -} diff --git a/tests/ModelContextProtocol.Tests/Protocol/ListTasksRequestParamsTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ListTasksRequestParamsTests.cs deleted file mode 100644 index 3e9022757..000000000 --- a/tests/ModelContextProtocol.Tests/Protocol/ListTasksRequestParamsTests.cs +++ /dev/null @@ -1,25 +0,0 @@ -using ModelContextProtocol.Protocol; -using System.Text.Json; - -namespace ModelContextProtocol.Tests.Protocol; - -public static class ListTasksRequestParamsTests -{ - [Fact] - public static void ListTasksRequestParams_SerializationRoundTrip() - { - // Arrange - var original = new ListTasksRequestParams - { - Cursor = "cursor-abc123" - }; - - // Act - string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); - - // Assert - Assert.NotNull(deserialized); - Assert.Equal(original.Cursor, deserialized.Cursor); - } -} diff --git a/tests/ModelContextProtocol.Tests/Protocol/ListTasksResultTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ListTasksResultTests.cs deleted file mode 100644 index 8d2fbd33b..000000000 --- a/tests/ModelContextProtocol.Tests/Protocol/ListTasksResultTests.cs +++ /dev/null @@ -1,46 +0,0 @@ -using ModelContextProtocol.Protocol; -using System.Text.Json; - -namespace ModelContextProtocol.Tests.Protocol; - -public static class ListTasksResultTests -{ - [Fact] - public static void ListTasksResult_SerializationRoundTrip() - { - // Arrange - var original = new ListTasksResult - { - Tasks = - [ - new McpTask - { - TaskId = "task-1", - Status = McpTaskStatus.Working, - CreatedAt = DateTimeOffset.UtcNow, - LastUpdatedAt = DateTimeOffset.UtcNow - }, - new McpTask - { - TaskId = "task-2", - Status = McpTaskStatus.Completed, - CreatedAt = DateTimeOffset.UtcNow, - LastUpdatedAt = DateTimeOffset.UtcNow - } - ], - NextCursor = "next-page-token" - }; - - // Act - string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); - - // Assert - Assert.NotNull(deserialized); - Assert.NotNull(deserialized.Tasks); - Assert.Equal(2, deserialized.Tasks.Count); - Assert.Equal(original.Tasks[0].TaskId, deserialized.Tasks[0].TaskId); - Assert.Equal(original.Tasks[1].TaskId, deserialized.Tasks[1].TaskId); - Assert.Equal(original.NextCursor, deserialized.NextCursor); - } -} diff --git a/tests/ModelContextProtocol.Tests/Protocol/McpTaskMetadataTests.cs b/tests/ModelContextProtocol.Tests/Protocol/McpTaskMetadataTests.cs deleted file mode 100644 index 82f33fbe7..000000000 --- a/tests/ModelContextProtocol.Tests/Protocol/McpTaskMetadataTests.cs +++ /dev/null @@ -1,53 +0,0 @@ -using ModelContextProtocol.Protocol; -using System.Text.Json; - -namespace ModelContextProtocol.Tests.Protocol; - -public static class McpTaskMetadataTests -{ - [Fact] - public static void McpTaskMetadata_SerializationRoundTrip_WithTimeToLive() - { - // Arrange - var original = new McpTaskMetadata - { - TimeToLive = TimeSpan.FromHours(2) - }; - - // Act - string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); - - // Assert - Assert.NotNull(deserialized); - Assert.Equal(original.TimeToLive, deserialized.TimeToLive); - } - - [Fact] - public static void McpTaskMetadata_SerializationRoundTrip_WithNullTimeToLive() - { - // Arrange - var original = new McpTaskMetadata(); - - // Act - string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); - - // Assert - Assert.NotNull(deserialized); - Assert.Null(deserialized.TimeToLive); - } - - [Fact] - public static void McpTaskMetadata_HasCorrectJsonPropertyNames() - { - var metadata = new McpTaskMetadata - { - TimeToLive = TimeSpan.FromMinutes(15) - }; - - string json = JsonSerializer.Serialize(metadata, McpJsonUtilities.DefaultOptions); - - Assert.Contains("\"ttl\":", json); - } -} diff --git a/tests/ModelContextProtocol.Tests/Protocol/McpTaskStatusNotificationParamsTests.cs b/tests/ModelContextProtocol.Tests/Protocol/McpTaskStatusNotificationParamsTests.cs deleted file mode 100644 index bf3cbbbf0..000000000 --- a/tests/ModelContextProtocol.Tests/Protocol/McpTaskStatusNotificationParamsTests.cs +++ /dev/null @@ -1,37 +0,0 @@ -using ModelContextProtocol.Protocol; -using System.Text.Json; - -namespace ModelContextProtocol.Tests.Protocol; - -public static class McpTaskStatusNotificationParamsTests -{ - [Fact] - public static void McpTaskStatusNotificationParams_SerializationRoundTrip() - { - // Arrange - var original = new McpTaskStatusNotificationParams - { - TaskId = "notification-task", - Status = McpTaskStatus.Completed, - StatusMessage = "Task completed successfully", - CreatedAt = new DateTimeOffset(2025, 12, 9, 10, 0, 0, TimeSpan.Zero), - LastUpdatedAt = new DateTimeOffset(2025, 12, 9, 10, 30, 0, TimeSpan.Zero), - TimeToLive = TimeSpan.FromHours(1), - PollInterval = TimeSpan.FromSeconds(2) - }; - - // Act - string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); - - // Assert - Assert.NotNull(deserialized); - Assert.Equal(original.TaskId, deserialized.TaskId); - Assert.Equal(original.Status, deserialized.Status); - Assert.Equal(original.StatusMessage, deserialized.StatusMessage); - Assert.Equal(original.CreatedAt, deserialized.CreatedAt); - Assert.Equal(original.LastUpdatedAt, deserialized.LastUpdatedAt); - Assert.Equal(original.TimeToLive, deserialized.TimeToLive); - Assert.Equal(original.PollInterval, deserialized.PollInterval); - } -} diff --git a/tests/ModelContextProtocol.Tests/Protocol/McpTaskTests.cs b/tests/ModelContextProtocol.Tests/Protocol/McpTaskTests.cs deleted file mode 100644 index 7919e408e..000000000 --- a/tests/ModelContextProtocol.Tests/Protocol/McpTaskTests.cs +++ /dev/null @@ -1,160 +0,0 @@ -using ModelContextProtocol.Protocol; -using System.Text.Json; - -namespace ModelContextProtocol.Tests.Protocol; - -public static class McpTaskTests -{ - [Fact] - public static void McpTask_SerializationRoundTrip_PreservesAllProperties() - { - // Arrange - var original = new McpTask - { - TaskId = "task-12345", - Status = McpTaskStatus.Working, - StatusMessage = "Processing request", - CreatedAt = new DateTimeOffset(2025, 12, 9, 10, 30, 0, TimeSpan.Zero), - LastUpdatedAt = new DateTimeOffset(2025, 12, 9, 10, 35, 0, TimeSpan.Zero), - TimeToLive = TimeSpan.FromHours(24), - PollInterval = TimeSpan.FromSeconds(5) - }; - - // Act - Serialize to JSON - string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); - - // Act - Deserialize back from JSON - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); - - // Assert - Assert.NotNull(deserialized); - Assert.Equal(original.TaskId, deserialized.TaskId); - Assert.Equal(original.Status, deserialized.Status); - Assert.Equal(original.StatusMessage, deserialized.StatusMessage); - Assert.Equal(original.CreatedAt, deserialized.CreatedAt); - Assert.Equal(original.LastUpdatedAt, deserialized.LastUpdatedAt); - Assert.Equal(original.TimeToLive, deserialized.TimeToLive); - Assert.Equal(original.PollInterval, deserialized.PollInterval); - } - - [Fact] - public static void McpTask_SerializationRoundTrip_WithMinimalProperties() - { - // Arrange - var original = new McpTask - { - TaskId = "task-minimal", - Status = McpTaskStatus.Completed, - CreatedAt = DateTimeOffset.UtcNow, - LastUpdatedAt = DateTimeOffset.UtcNow - }; - - // Act - Serialize to JSON - string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); - - // Act - Deserialize back from JSON - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); - - // Assert - Assert.NotNull(deserialized); - Assert.Equal(original.TaskId, deserialized.TaskId); - Assert.Equal(original.Status, deserialized.Status); - Assert.Null(deserialized.StatusMessage); - Assert.Equal(original.CreatedAt, deserialized.CreatedAt); - Assert.Equal(original.LastUpdatedAt, deserialized.LastUpdatedAt); - Assert.Null(deserialized.TimeToLive); - Assert.Null(deserialized.PollInterval); - } - - [Fact] - public static void McpTask_HasCorrectJsonPropertyNames() - { - var task = new McpTask - { - TaskId = "test-task", - Status = McpTaskStatus.Working, - StatusMessage = "Test message", - CreatedAt = DateTimeOffset.UtcNow, - LastUpdatedAt = DateTimeOffset.UtcNow, - TimeToLive = TimeSpan.FromMinutes(30), - PollInterval = TimeSpan.FromSeconds(1) - }; - - string json = JsonSerializer.Serialize(task, McpJsonUtilities.DefaultOptions); - - Assert.Contains("\"taskId\":", json); - Assert.Contains("\"status\":", json); - Assert.Contains("\"statusMessage\":", json); - Assert.Contains("\"createdAt\":", json); - Assert.Contains("\"lastUpdatedAt\":", json); - Assert.Contains("\"ttl\":", json); - Assert.Contains("\"pollInterval\":", json); - } - - [Fact] - public static void McpTask_TimeToLive_SerializesAsMilliseconds() - { - var task = new McpTask - { - TaskId = "test-ttl", - Status = McpTaskStatus.Working, - CreatedAt = DateTimeOffset.UtcNow, - LastUpdatedAt = DateTimeOffset.UtcNow, - TimeToLive = TimeSpan.FromSeconds(60) - }; - - string json = JsonSerializer.Serialize(task, McpJsonUtilities.DefaultOptions); - - Assert.Contains("\"ttl\":60000", json); - } - - [Theory] - [InlineData(McpTaskStatus.Working)] - [InlineData(McpTaskStatus.InputRequired)] - [InlineData(McpTaskStatus.Completed)] - [InlineData(McpTaskStatus.Failed)] - [InlineData(McpTaskStatus.Cancelled)] - public static void McpTaskStatus_SerializesCorrectly(McpTaskStatus status) - { - var task = new McpTask - { - TaskId = "status-test", - Status = status, - CreatedAt = DateTimeOffset.UtcNow, - LastUpdatedAt = DateTimeOffset.UtcNow - }; - - string json = JsonSerializer.Serialize(task, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); - - Assert.NotNull(deserialized); - Assert.Equal(status, deserialized.Status); - } - - [Fact] - public static void McpTaskStatus_HasCorrectJsonValues() - { - var statuses = new[] - { - (McpTaskStatus.Working, "working"), - (McpTaskStatus.InputRequired, "input_required"), - (McpTaskStatus.Completed, "completed"), - (McpTaskStatus.Failed, "failed"), - (McpTaskStatus.Cancelled, "cancelled") - }; - - foreach (var (status, expectedJson) in statuses) - { - var task = new McpTask - { - TaskId = "test", - Status = status, - CreatedAt = DateTimeOffset.UtcNow, - LastUpdatedAt = DateTimeOffset.UtcNow - }; - - string json = JsonSerializer.Serialize(task, McpJsonUtilities.DefaultOptions); - Assert.Contains($"\"status\":\"{expectedJson}\"", json); - } - } -} diff --git a/tests/ModelContextProtocol.Tests/Protocol/McpTasksCapabilityTests.cs b/tests/ModelContextProtocol.Tests/Protocol/McpTasksCapabilityTests.cs deleted file mode 100644 index 4e8caa740..000000000 --- a/tests/ModelContextProtocol.Tests/Protocol/McpTasksCapabilityTests.cs +++ /dev/null @@ -1,91 +0,0 @@ -using ModelContextProtocol.Protocol; -using System.Text.Json; - -namespace ModelContextProtocol.Tests.Protocol; - -public static class McpTasksCapabilityTests -{ - [Fact] - public static void McpTasksCapability_SerializationRoundTrip_WithAllProperties() - { - // Arrange - var original = new McpTasksCapability - { - List = new ListMcpTasksCapability(), - Cancel = new CancelMcpTasksCapability(), - Requests = new RequestMcpTasksCapability - { - Tools = new ToolsMcpTasksCapability - { - Call = new CallToolMcpTasksCapability() - }, - Sampling = new SamplingMcpTasksCapability - { - CreateMessage = new CreateMessageMcpTasksCapability() - }, - Elicitation = new ElicitationMcpTasksCapability - { - Create = new CreateElicitationMcpTasksCapability() - } - } - }; - - // Act - string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); - - // Assert - Assert.NotNull(deserialized); - Assert.NotNull(deserialized.List); - Assert.NotNull(deserialized.Cancel); - Assert.NotNull(deserialized.Requests); - Assert.NotNull(deserialized.Requests.Tools); - Assert.NotNull(deserialized.Requests.Tools.Call); - Assert.NotNull(deserialized.Requests.Sampling); - Assert.NotNull(deserialized.Requests.Sampling.CreateMessage); - Assert.NotNull(deserialized.Requests.Elicitation); - Assert.NotNull(deserialized.Requests.Elicitation.Create); - } - - [Fact] - public static void McpTasksCapability_SerializationRoundTrip_WithMinimalProperties() - { - // Arrange - var original = new McpTasksCapability(); - - // Act - string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); - - // Assert - Assert.NotNull(deserialized); - Assert.Null(deserialized.List); - Assert.Null(deserialized.Cancel); - Assert.Null(deserialized.Requests); - } - - [Fact] - public static void McpTasksCapability_HasCorrectJsonPropertyNames() - { - var capability = new McpTasksCapability - { - List = new ListMcpTasksCapability(), - Cancel = new CancelMcpTasksCapability(), - Requests = new RequestMcpTasksCapability - { - Tools = new ToolsMcpTasksCapability - { - Call = new CallToolMcpTasksCapability() - } - } - }; - - string json = JsonSerializer.Serialize(capability, McpJsonUtilities.DefaultOptions); - - Assert.Contains("\"list\":", json); - Assert.Contains("\"cancel\":", json); - Assert.Contains("\"requests\":", json); - Assert.Contains("\"tools\":", json); - Assert.Contains("\"call\":", json); - } -} diff --git a/tests/ModelContextProtocol.Tests/Protocol/RequestMcpTasksCapabilityTests.cs b/tests/ModelContextProtocol.Tests/Protocol/RequestMcpTasksCapabilityTests.cs deleted file mode 100644 index 8bfcb3be4..000000000 --- a/tests/ModelContextProtocol.Tests/Protocol/RequestMcpTasksCapabilityTests.cs +++ /dev/null @@ -1,108 +0,0 @@ -using ModelContextProtocol.Protocol; -using System.Text.Json; - -namespace ModelContextProtocol.Tests.Protocol; - -public static class RequestMcpTasksCapabilityTests -{ - [Fact] - public static void RequestMcpTasksCapability_SerializationRoundTrip_ToolsOnly() - { - // Arrange - var original = new RequestMcpTasksCapability - { - Tools = new ToolsMcpTasksCapability - { - Call = new CallToolMcpTasksCapability() - } - }; - - // Act - string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); - - // Assert - Assert.NotNull(deserialized); - Assert.NotNull(deserialized.Tools); - Assert.NotNull(deserialized.Tools.Call); - Assert.Null(deserialized.Sampling); - Assert.Null(deserialized.Elicitation); - } - - [Fact] - public static void RequestMcpTasksCapability_SerializationRoundTrip_SamplingOnly() - { - // Arrange - var original = new RequestMcpTasksCapability - { - Sampling = new SamplingMcpTasksCapability - { - CreateMessage = new CreateMessageMcpTasksCapability() - } - }; - - // Act - string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); - - // Assert - Assert.NotNull(deserialized); - Assert.Null(deserialized.Tools); - Assert.NotNull(deserialized.Sampling); - Assert.NotNull(deserialized.Sampling.CreateMessage); - Assert.Null(deserialized.Elicitation); - } - - [Fact] - public static void RequestMcpTasksCapability_SerializationRoundTrip_ElicitationOnly() - { - // Arrange - var original = new RequestMcpTasksCapability - { - Elicitation = new ElicitationMcpTasksCapability - { - Create = new CreateElicitationMcpTasksCapability() - } - }; - - // Act - string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); - - // Assert - Assert.NotNull(deserialized); - Assert.Null(deserialized.Tools); - Assert.Null(deserialized.Sampling); - Assert.NotNull(deserialized.Elicitation); - Assert.NotNull(deserialized.Elicitation.Create); - } - - [Fact] - public static void RequestMcpTasksCapability_HasCorrectJsonPropertyNames() - { - var capability = new RequestMcpTasksCapability - { - Tools = new ToolsMcpTasksCapability - { - Call = new CallToolMcpTasksCapability() - }, - Sampling = new SamplingMcpTasksCapability - { - CreateMessage = new CreateMessageMcpTasksCapability() - }, - Elicitation = new ElicitationMcpTasksCapability - { - Create = new CreateElicitationMcpTasksCapability() - } - }; - - string json = JsonSerializer.Serialize(capability, McpJsonUtilities.DefaultOptions); - - Assert.Contains("\"tools\":", json); - Assert.Contains("\"sampling\":", json); - Assert.Contains("\"elicitation\":", json); - Assert.Contains("\"call\":", json); - Assert.Contains("\"createMessage\":", json); - Assert.Contains("\"create\":", json); - } -} diff --git a/tests/ModelContextProtocol.Tests/Protocol/ServerCapabilitiesTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ServerCapabilitiesTests.cs index a6f8265f1..7b95e911b 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/ServerCapabilitiesTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/ServerCapabilitiesTests.cs @@ -15,7 +15,6 @@ public static void ServerCapabilities_SerializationRoundTrip_PreservesAllPropert Resources = new ResourcesCapability { Subscribe = true, ListChanged = true }, Tools = new ToolsCapability { ListChanged = false }, Completions = new CompletionsCapability(), - Tasks = new McpTasksCapability(), Extensions = new Dictionary { ["io.modelcontextprotocol/apps"] = new object() @@ -35,7 +34,6 @@ public static void ServerCapabilities_SerializationRoundTrip_PreservesAllPropert Assert.NotNull(deserialized.Tools); Assert.False(deserialized.Tools.ListChanged); Assert.NotNull(deserialized.Completions); - Assert.NotNull(deserialized.Tasks); Assert.NotNull(deserialized.Extensions); Assert.True(deserialized.Extensions.ContainsKey("io.modelcontextprotocol/apps")); } @@ -55,7 +53,6 @@ public static void ServerCapabilities_SerializationRoundTrip_WithMinimalProperti Assert.Null(deserialized.Resources); Assert.Null(deserialized.Tools); Assert.Null(deserialized.Completions); - Assert.Null(deserialized.Tasks); Assert.Null(deserialized.Extensions); } diff --git a/tests/ModelContextProtocol.Tests/Protocol/TaskSerializationTests.cs b/tests/ModelContextProtocol.Tests/Protocol/TaskSerializationTests.cs new file mode 100644 index 000000000..f97347705 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Protocol/TaskSerializationTests.cs @@ -0,0 +1,445 @@ +using ModelContextProtocol.Protocol; +using System.Text.Json; +using System.Text.Json.Nodes; + +namespace ModelContextProtocol.Tests.Protocol; + +/// +/// Serialization and deserialization tests for SEP-2663 task protocol types. +/// +public static class TaskSerializationTests +{ + #region CreateTaskResult + + [Fact] + public static void CreateTaskResult_SerializationRoundTrip_PreservesAllProperties() + { + var original = new CreateTaskResult + { + TaskId = "task-123", + Status = McpTaskStatus.Working, + StatusMessage = "Processing...", + CreatedAt = new DateTimeOffset(2025, 6, 1, 12, 0, 0, TimeSpan.Zero), + LastUpdatedAt = new DateTimeOffset(2025, 6, 1, 12, 5, 0, TimeSpan.Zero), + TtlMs = 3600000, + PollIntervalMs = 5000, + ResultType = "task", + Meta = new JsonObject { ["key"] = "value" } + }; + + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.NotNull(deserialized); + Assert.Equal("task-123", deserialized.TaskId); + Assert.Equal(McpTaskStatus.Working, deserialized.Status); + Assert.Equal("Processing...", deserialized.StatusMessage); + Assert.Equal(original.CreatedAt, deserialized.CreatedAt); + Assert.Equal(original.LastUpdatedAt, deserialized.LastUpdatedAt); + Assert.Equal(3600000, deserialized.TtlMs); + Assert.Equal(5000, deserialized.PollIntervalMs); + Assert.Equal("task", deserialized.ResultType); + Assert.NotNull(deserialized.Meta); + Assert.Equal("value", (string)deserialized.Meta["key"]!); + } + + [Fact] + public static void CreateTaskResult_UsesCorrectWireFieldNames() + { + var result = new CreateTaskResult + { + TaskId = "t1", + Status = McpTaskStatus.Working, + CreatedAt = DateTimeOffset.UtcNow, + LastUpdatedAt = DateTimeOffset.UtcNow, + TtlMs = 60000, + PollIntervalMs = 1000, + ResultType = "task", + }; + + string json = JsonSerializer.Serialize(result, McpJsonUtilities.DefaultOptions); + + // Must use camelCase wire names + Assert.Contains("\"ttlMs\":", json); + Assert.Contains("\"pollIntervalMs\":", json); + Assert.Contains("\"taskId\":", json); + Assert.Contains("\"resultType\":\"task\"", json); + + // Must NOT contain legacy field names + Assert.DoesNotContain("\"ttl\":", json); + Assert.DoesNotContain("\"pollInterval\":", json); + } + + [Fact] + public static void CreateTaskResult_ResultType_SerializesAsTask() + { + var result = new CreateTaskResult + { + TaskId = "t1", + Status = McpTaskStatus.Working, + CreatedAt = DateTimeOffset.UtcNow, + LastUpdatedAt = DateTimeOffset.UtcNow, + ResultType = "task", + }; + + string json = JsonSerializer.Serialize(result, McpJsonUtilities.DefaultOptions); + var node = JsonNode.Parse(json)!; + + Assert.Equal("task", (string)node["resultType"]!); + } + + #endregion + + #region GetTaskResult Subtypes + + [Fact] + public static void GetTaskResult_Working_RoundTrip() + { + var original = new WorkingTaskResult + { + TaskId = "w1", + CreatedAt = new DateTimeOffset(2025, 1, 1, 0, 0, 0, TimeSpan.Zero), + LastUpdatedAt = new DateTimeOffset(2025, 1, 1, 0, 1, 0, TimeSpan.Zero), + StatusMessage = "In progress", + PollIntervalMs = 2000, + }; + + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + var working = Assert.IsType(deserialized); + Assert.Equal("w1", working.TaskId); + Assert.Equal(McpTaskStatus.Working, working.Status); + Assert.Equal("In progress", working.StatusMessage); + Assert.Equal(2000, working.PollIntervalMs); + } + + [Fact] + public static void GetTaskResult_Completed_RoundTrip_IncludesResult() + { + var resultPayload = JsonSerializer.SerializeToElement(new { content = new[] { new { type = "text", text = "done" } } }); + var original = new CompletedTaskResult + { + TaskId = "c1", + CreatedAt = DateTimeOffset.UtcNow, + LastUpdatedAt = DateTimeOffset.UtcNow, + TaskResult = resultPayload, + }; + + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + var completed = Assert.IsType(deserialized); + Assert.Equal("c1", completed.TaskId); + Assert.Equal(McpTaskStatus.Completed, completed.Status); + Assert.Equal(JsonValueKind.Object, completed.TaskResult.ValueKind); + } + + [Fact] + public static void GetTaskResult_Failed_RoundTrip_IncludesError() + { + var errorPayload = JsonSerializer.SerializeToElement(new { code = -32000, message = "internal error" }); + var original = new FailedTaskResult + { + TaskId = "f1", + CreatedAt = DateTimeOffset.UtcNow, + LastUpdatedAt = DateTimeOffset.UtcNow, + Error = errorPayload, + }; + + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + var failed = Assert.IsType(deserialized); + Assert.Equal("f1", failed.TaskId); + Assert.Equal(McpTaskStatus.Failed, failed.Status); + Assert.Equal(-32000, failed.Error.GetProperty("code").GetInt32()); + } + + [Fact] + public static void GetTaskResult_Cancelled_RoundTrip() + { + var original = new CancelledTaskResult + { + TaskId = "x1", + CreatedAt = DateTimeOffset.UtcNow, + LastUpdatedAt = DateTimeOffset.UtcNow, + StatusMessage = "User cancelled", + }; + + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + var cancelled = Assert.IsType(deserialized); + Assert.Equal("x1", cancelled.TaskId); + Assert.Equal(McpTaskStatus.Cancelled, cancelled.Status); + Assert.Equal("User cancelled", cancelled.StatusMessage); + } + + [Fact] + public static void GetTaskResult_InputRequired_RoundTrip_IncludesInputRequests() + { + var inputRequests = new Dictionary + { + ["req-1"] = JsonSerializer.SerializeToElement(new { method = "elicitation/create", @params = new { message = "Confirm?" } }) + }; + var original = new InputRequiredTaskResult + { + TaskId = "i1", + CreatedAt = DateTimeOffset.UtcNow, + LastUpdatedAt = DateTimeOffset.UtcNow, + InputRequests = inputRequests, + }; + + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + var inputRequired = Assert.IsType(deserialized); + Assert.Equal("i1", inputRequired.TaskId); + Assert.Equal(McpTaskStatus.InputRequired, inputRequired.Status); + Assert.Single(inputRequired.InputRequests); + Assert.True(inputRequired.InputRequests.ContainsKey("req-1")); + } + + [Fact] + public static void GetTaskResult_Converter_DispatchesToCorrectSubtypeByStatus() + { + var statuses = new (string status, Type expectedType)[] + { + ("working", typeof(WorkingTaskResult)), + ("completed", typeof(CompletedTaskResult)), + ("failed", typeof(FailedTaskResult)), + ("cancelled", typeof(CancelledTaskResult)), + ("input_required", typeof(InputRequiredTaskResult)), + }; + + foreach (var (status, expectedType) in statuses) + { + var json = status switch + { + "completed" => """{"taskId":"t","status":"completed","createdAt":"2025-01-01T00:00:00Z","lastUpdatedAt":"2025-01-01T00:00:00Z","result":{}}""", + "failed" => """{"taskId":"t","status":"failed","createdAt":"2025-01-01T00:00:00Z","lastUpdatedAt":"2025-01-01T00:00:00Z","error":{}}""", + "input_required" => """{"taskId":"t","status":"input_required","createdAt":"2025-01-01T00:00:00Z","lastUpdatedAt":"2025-01-01T00:00:00Z","inputRequests":{}}""", + _ => $$$"""{"taskId":"t","status":"{{{status}}}","createdAt":"2025-01-01T00:00:00Z","lastUpdatedAt":"2025-01-01T00:00:00Z"}""", + }; + + var result = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + Assert.NotNull(result); + Assert.IsType(expectedType, result); + } + } + + [Fact] + public static void GetTaskResult_MissingTaskId_ThrowsJsonException() + { + var json = """{"status":"working","createdAt":"2025-01-01T00:00:00Z","lastUpdatedAt":"2025-01-01T00:00:00Z"}"""; + Assert.Throws(() => JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions)); + } + + [Fact] + public static void GetTaskResult_MissingStatus_ThrowsJsonException() + { + var json = """{"taskId":"t","createdAt":"2025-01-01T00:00:00Z","lastUpdatedAt":"2025-01-01T00:00:00Z"}"""; + Assert.Throws(() => JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions)); + } + + [Fact] + public static void GetTaskResult_UnknownStatus_ThrowsJsonException() + { + var json = """{"taskId":"t","status":"exploded","createdAt":"2025-01-01T00:00:00Z","lastUpdatedAt":"2025-01-01T00:00:00Z"}"""; + Assert.Throws(() => JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions)); + } + + [Fact] + public static void GetTaskResult_CompletedMissingResult_ThrowsJsonException() + { + var json = """{"taskId":"t","status":"completed","createdAt":"2025-01-01T00:00:00Z","lastUpdatedAt":"2025-01-01T00:00:00Z"}"""; + Assert.Throws(() => JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions)); + } + + [Fact] + public static void GetTaskResult_FailedMissingError_ThrowsJsonException() + { + var json = """{"taskId":"t","status":"failed","createdAt":"2025-01-01T00:00:00Z","lastUpdatedAt":"2025-01-01T00:00:00Z"}"""; + Assert.Throws(() => JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions)); + } + + [Fact] + public static void GetTaskResult_InputRequiredMissingInputRequests_ThrowsJsonException() + { + var json = """{"taskId":"t","status":"input_required","createdAt":"2025-01-01T00:00:00Z","lastUpdatedAt":"2025-01-01T00:00:00Z"}"""; + Assert.Throws(() => JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions)); + } + + #endregion + + #region McpTaskStatus Enum + + [Theory] + [InlineData(McpTaskStatus.Working, "working")] + [InlineData(McpTaskStatus.InputRequired, "input_required")] + [InlineData(McpTaskStatus.Completed, "completed")] + [InlineData(McpTaskStatus.Cancelled, "cancelled")] + [InlineData(McpTaskStatus.Failed, "failed")] + public static void McpTaskStatus_SerializesAsSnakeCase(McpTaskStatus status, string expectedWireValue) + { + string json = JsonSerializer.Serialize(status, McpJsonUtilities.DefaultOptions); + Assert.Equal($"\"{expectedWireValue}\"", json); + + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + Assert.Equal(status, deserialized); + } + + #endregion + + #region TaskStatusNotificationParams + + [Fact] + public static void TaskStatusNotificationParams_Working_RoundTrip() + { + var original = new WorkingTaskNotificationParams + { + TaskId = "n1", + CreatedAt = DateTimeOffset.UtcNow, + LastUpdatedAt = DateTimeOffset.UtcNow, + StatusMessage = "Working on it", + }; + + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + var working = Assert.IsType(deserialized); + Assert.Equal("n1", working.TaskId); + Assert.Equal("Working on it", working.StatusMessage); + } + + [Fact] + public static void TaskStatusNotificationParams_Completed_RoundTrip() + { + var resultPayload = JsonSerializer.SerializeToElement(new { text = "done" }); + var original = new CompletedTaskNotificationParams + { + TaskId = "n2", + CreatedAt = DateTimeOffset.UtcNow, + LastUpdatedAt = DateTimeOffset.UtcNow, + TaskResult = resultPayload, + }; + + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + var completed = Assert.IsType(deserialized); + Assert.Equal("n2", completed.TaskId); + Assert.Equal("done", completed.TaskResult.GetProperty("text").GetString()); + } + + [Fact] + public static void TaskStatusNotificationParams_Failed_RoundTrip() + { + var errorPayload = JsonSerializer.SerializeToElement(new { code = -1, message = "boom" }); + var original = new FailedTaskNotificationParams + { + TaskId = "n3", + CreatedAt = DateTimeOffset.UtcNow, + LastUpdatedAt = DateTimeOffset.UtcNow, + Error = errorPayload, + }; + + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + var failed = Assert.IsType(deserialized); + Assert.Equal("n3", failed.TaskId); + Assert.Equal("boom", failed.Error.GetProperty("message").GetString()); + } + + [Fact] + public static void TaskStatusNotificationParams_Cancelled_RoundTrip() + { + var original = new CancelledTaskNotificationParams + { + TaskId = "n4", + CreatedAt = DateTimeOffset.UtcNow, + LastUpdatedAt = DateTimeOffset.UtcNow, + }; + + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + Assert.IsType(deserialized); + } + + [Fact] + public static void TaskStatusNotificationParams_InputRequired_RoundTrip() + { + var inputRequests = new Dictionary + { + ["r1"] = JsonSerializer.SerializeToElement(new { method = "sampling/createMessage" }) + }; + var original = new InputRequiredTaskNotificationParams + { + TaskId = "n5", + CreatedAt = DateTimeOffset.UtcNow, + LastUpdatedAt = DateTimeOffset.UtcNow, + InputRequests = inputRequests, + }; + + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + + var inputRequired = Assert.IsType(deserialized); + Assert.Single(inputRequired.InputRequests); + } + + #endregion + + #region ResultOrCreatedTask + + [Fact] + public static void ResultOrCreatedTask_ImplicitConversion_FromResult() + { + CallToolResult callResult = new() { Content = [new TextContentBlock { Text = "hi" }] }; + + ResultOrCreatedTask augmented = callResult; + + Assert.False(augmented.IsTask); + Assert.Same(callResult, augmented.Result); + Assert.Null(augmented.TaskCreated); + } + + [Fact] + public static void ResultOrCreatedTask_ImplicitConversion_FromCreateTaskResult() + { + CreateTaskResult taskCreated = new() + { + TaskId = "t1", + Status = McpTaskStatus.Working, + CreatedAt = DateTimeOffset.UtcNow, + LastUpdatedAt = DateTimeOffset.UtcNow, + }; + + ResultOrCreatedTask augmented = taskCreated; + + Assert.True(augmented.IsTask); + Assert.Same(taskCreated, augmented.TaskCreated); + Assert.Null(augmented.Result); + } + + [Fact] + public static void ResultOrCreatedTask_IsTask_FalseForResult_TrueForTask() + { + var result = new ResultOrCreatedTask(new CallToolResult()); + var task = new ResultOrCreatedTask(new CreateTaskResult + { + TaskId = "t", + Status = McpTaskStatus.Working, + CreatedAt = DateTimeOffset.UtcNow, + LastUpdatedAt = DateTimeOffset.UtcNow, + }); + + Assert.False(result.IsTask); + Assert.True(task.IsTask); + } + + #endregion +} diff --git a/tests/ModelContextProtocol.Tests/Server/AutomaticInputRequiredStatusTests.cs b/tests/ModelContextProtocol.Tests/Server/AutomaticInputRequiredStatusTests.cs deleted file mode 100644 index 1f5c51c6c..000000000 --- a/tests/ModelContextProtocol.Tests/Server/AutomaticInputRequiredStatusTests.cs +++ /dev/null @@ -1,478 +0,0 @@ -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; -using ModelContextProtocol.Client; -using ModelContextProtocol.Protocol; -using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Utils; -using System.IO.Pipelines; - -namespace ModelContextProtocol.Tests.Server; - -/// -/// Tests for automatic InputRequired status tracking when server-to-client -/// requests (SampleAsync, ElicitAsync) are made during task-augmented tool execution. -/// -public class AutomaticInputRequiredStatusTests : LoggedTest -{ - public AutomaticInputRequiredStatusTests(ITestOutputHelper testOutputHelper) - : base(testOutputHelper) - { - } - -#pragma warning disable MCPEXP001 // Tasks feature is experimental - - [Fact] - public async Task TaskStatus_TransitionsToInputRequired_DuringSampleAsync() - { - // Arrange - var taskStore = new InMemoryMcpTaskStore(); - var statusesDuringSampling = new List(); - var samplingRequestReceived = new TaskCompletionSource(); - var continueSampling = new TaskCompletionSource(); - - await using var fixture = new InputRequiredTestFixture( - LoggerFactory, - configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => - { - options.TaskStore = taskStore; - options.SendTaskStatusNotifications = true; // Enable notifications - }); - - // Tool that calls SampleAsync during execution - builder.WithTools([McpServerTool.Create( - async (string prompt, McpServer server, CancellationToken ct) => - { - // Call SampleAsync - this should trigger InputRequired status - var result = await server.SampleAsync(new CreateMessageRequestParams - { - Messages = [new SamplingMessage - { - Role = Role.User, - Content = [new TextContentBlock { Text = prompt }] - }], - MaxTokens = 100 - }, ct); - - var textContent = result.Content.OfType().FirstOrDefault(); - return textContent?.Text ?? "No response"; - }, - new McpServerToolCreateOptions - { - Name = "sampling-tool", - Description = "A tool that uses sampling" - })]); - }, - configureClient: clientOptions => - { - clientOptions.Handlers = new McpClientHandlers - { - SamplingHandler = async (request, progress, ct) => - { - // Signal that we received the sampling request - samplingRequestReceived.TrySetResult(true); - - // Wait for permission to continue (so we can check status) - await continueSampling.Task.WaitAsync(ct); - - return new CreateMessageResult - { - Content = [new TextContentBlock { Text = "Sampled response" }], - Model = "test-model" - }; - } - }; - }); - - // Act - Call the tool as a task - var mcpTask = await fixture.Client.CallToolAsTaskAsync( - "sampling-tool", - arguments: new Dictionary { ["prompt"] = "Hello" }, - taskMetadata: new McpTaskMetadata(), - progress: null, - cancellationToken: TestContext.Current.CancellationToken); - - // Wait for the sampling request to be received by the client - await samplingRequestReceived.Task.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken); - - // Check the task status while sampling is in progress - var statusDuringSampling = await taskStore.GetTaskAsync( - mcpTask.TaskId, - cancellationToken: TestContext.Current.CancellationToken); - - if (statusDuringSampling is not null) - { - statusesDuringSampling.Add(statusDuringSampling.Status); - } - - // Allow sampling to complete - continueSampling.TrySetResult(true); - - // Wait for task to complete - McpTask? finalStatus = null; - int maxAttempts = 50; - do - { - await Task.Delay(100, TestContext.Current.CancellationToken); - finalStatus = await taskStore.GetTaskAsync(mcpTask.TaskId, cancellationToken: TestContext.Current.CancellationToken); - maxAttempts--; - } - while (finalStatus?.Status is not McpTaskStatus.Completed && maxAttempts > 0); - - // Assert - Status should have been InputRequired during sampling - Assert.Contains(McpTaskStatus.InputRequired, statusesDuringSampling); - - // Final status should be Completed - Assert.NotNull(finalStatus); - Assert.Equal(McpTaskStatus.Completed, finalStatus.Status); - } - - [Fact] - public async Task TaskStatus_TransitionsToInputRequired_DuringElicitAsync() - { - // Arrange - var taskStore = new InMemoryMcpTaskStore(); - var statusesDuringElicitation = new List(); - var elicitationRequestReceived = new TaskCompletionSource(); - var continueElicitation = new TaskCompletionSource(); - - await using var fixture = new InputRequiredTestFixture( - LoggerFactory, - configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => - { - options.TaskStore = taskStore; - options.SendTaskStatusNotifications = true; - }); - - // Tool that calls ElicitAsync during execution - builder.WithTools([McpServerTool.Create( - async (string message, McpServer server, CancellationToken ct) => - { - // Call ElicitAsync - this should trigger InputRequired status - var result = await server.ElicitAsync(new ElicitRequestParams - { - Message = message, - RequestedSchema = new() - }, ct); - - return result.Action == "confirm" ? "Confirmed" : "Declined"; - }, - new McpServerToolCreateOptions - { - Name = "elicitation-tool", - Description = "A tool that uses elicitation" - })]); - }, - configureClient: clientOptions => - { - clientOptions.Handlers = new McpClientHandlers - { - ElicitationHandler = async (request, ct) => - { - // Signal that we received the elicitation request - elicitationRequestReceived.TrySetResult(true); - - // Wait for permission to continue - await continueElicitation.Task.WaitAsync(ct); - - return new ElicitResult { Action = "confirm" }; - } - }; - }); - - // Act - Call the tool as a task - var mcpTask = await fixture.Client.CallToolAsTaskAsync( - "elicitation-tool", - arguments: new Dictionary { ["message"] = "Please confirm" }, - taskMetadata: new McpTaskMetadata(), - progress: null, - cancellationToken: TestContext.Current.CancellationToken); - - // Wait for the elicitation request to be received - await elicitationRequestReceived.Task.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken); - - // Check the task status while elicitation is in progress - var statusDuringElicitation = await taskStore.GetTaskAsync( - mcpTask.TaskId, - cancellationToken: TestContext.Current.CancellationToken); - - if (statusDuringElicitation is not null) - { - statusesDuringElicitation.Add(statusDuringElicitation.Status); - } - - // Allow elicitation to complete - continueElicitation.TrySetResult(true); - - // Wait for task to complete - McpTask? finalStatus = null; - int maxAttempts = 50; - do - { - await Task.Delay(100, TestContext.Current.CancellationToken); - finalStatus = await taskStore.GetTaskAsync(mcpTask.TaskId, cancellationToken: TestContext.Current.CancellationToken); - maxAttempts--; - } - while (finalStatus?.Status is not McpTaskStatus.Completed && maxAttempts > 0); - - // Assert - Status should have been InputRequired during elicitation - Assert.Contains(McpTaskStatus.InputRequired, statusesDuringElicitation); - - // Final status should be Completed - Assert.NotNull(finalStatus); - Assert.Equal(McpTaskStatus.Completed, finalStatus.Status); - } - - [Fact] - public async Task TaskStatus_ReturnsToWorking_AfterSamplingCompletes() - { - // Arrange - var taskStore = new InMemoryMcpTaskStore(); - var samplingCompleted = new TaskCompletionSource(); - var checkStatusAfterSampling = new TaskCompletionSource(); - - await using var fixture = new InputRequiredTestFixture( - LoggerFactory, - configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => - { - options.TaskStore = taskStore; - }); - - // Tool that calls SampleAsync and then waits - builder.WithTools([McpServerTool.Create( - async (string prompt, McpServer server, CancellationToken ct) => - { - // Call SampleAsync - var result = await server.SampleAsync(new CreateMessageRequestParams - { - Messages = [new SamplingMessage - { - Role = Role.User, - Content = [new TextContentBlock { Text = prompt }] - }], - MaxTokens = 100 - }, ct); - - // Signal that sampling completed - samplingCompleted.TrySetResult(true); - - // Wait so test can check status - await checkStatusAfterSampling.Task.WaitAsync(ct); - - var textContent = result.Content.OfType().FirstOrDefault(); - return textContent?.Text ?? "No response"; - }, - new McpServerToolCreateOptions - { - Name = "sampling-tool", - Description = "A tool that uses sampling" - })]); - }, - configureClient: clientOptions => - { - clientOptions.Handlers = new McpClientHandlers - { - SamplingHandler = (request, progress, ct) => - { - // Return immediately to let sampling complete - return new ValueTask(new CreateMessageResult - { - Content = [new TextContentBlock { Text = "Response" }], - Model = "test-model" - }); - } - }; - }); - - // Act - Call the tool as a task - var mcpTask = await fixture.Client.CallToolAsTaskAsync( - "sampling-tool", - arguments: new Dictionary { ["prompt"] = "Hello" }, - taskMetadata: new McpTaskMetadata(), - progress: null, - cancellationToken: TestContext.Current.CancellationToken); - - // Wait for sampling to complete inside the tool - await samplingCompleted.Task.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken); - - // Small delay to ensure status update is processed - await Task.Delay(50, TestContext.Current.CancellationToken); - - // Check status after sampling completed (should be back to Working) - var taskAfterSampling = await taskStore.GetTaskAsync(mcpTask.TaskId, cancellationToken: TestContext.Current.CancellationToken); - - // Allow tool to complete - checkStatusAfterSampling.TrySetResult(true); - - // Assert - Status should be Working after sampling completes (before tool completes) - Assert.NotNull(taskAfterSampling); - Assert.Equal(McpTaskStatus.Working, taskAfterSampling.Status); - } - - [Fact] - public async Task TaskStatus_DoesNotChangeToInputRequired_ForNonTaskExecution() - { - // Arrange - When a tool is NOT executed as a task, SampleAsync should not change any task status - var taskStore = new InMemoryMcpTaskStore(); - var samplingCompleted = new TaskCompletionSource(); - - await using var fixture = new InputRequiredTestFixture( - LoggerFactory, - configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => - { - options.TaskStore = taskStore; - }); - - // Tool that calls SampleAsync - note it doesn't have TaskSupport.Required so can be called directly - builder.WithTools([McpServerTool.Create( - async (string prompt, McpServer server, CancellationToken ct) => - { - var result = await server.SampleAsync(new CreateMessageRequestParams - { - Messages = [new SamplingMessage - { - Role = Role.User, - Content = [new TextContentBlock { Text = prompt }] - }], - MaxTokens = 100 - }, ct); - - samplingCompleted.TrySetResult(true); - var textContent = result.Content.OfType().FirstOrDefault(); - return textContent?.Text ?? "No response"; - }, - new McpServerToolCreateOptions - { - Name = "sampling-tool", - Description = "A tool that uses sampling", - Execution = new ToolExecution { TaskSupport = ToolTaskSupport.Optional } - })]); - }, - configureClient: clientOptions => - { - clientOptions.Handlers = new McpClientHandlers - { - SamplingHandler = (request, progress, ct) => - { - return new ValueTask(new CreateMessageResult - { - Content = [new TextContentBlock { Text = "Response" }], - Model = "test-model" - }); - } - }; - }); - - // Act - Call the tool DIRECTLY (not as a task) - var result = await fixture.Client.CallToolAsync( - "sampling-tool", - arguments: new Dictionary { ["prompt"] = "Hello" }, - cancellationToken: TestContext.Current.CancellationToken); - - await samplingCompleted.Task.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken); - - // Assert - No task should exist (tool was not called as a task) - var tasks = await taskStore.ListTasksAsync(cancellationToken: TestContext.Current.CancellationToken); - Assert.Empty(tasks.Tasks); - - // And the result should still work - Assert.NotNull(result); - } - -#pragma warning restore MCPEXP001 - - /// - /// Test fixture that supports both server and client configuration for InputRequired status tests. - /// - private sealed class InputRequiredTestFixture : IAsyncDisposable - { - private readonly Pipe _clientToServerPipe = new(); - private readonly Pipe _serverToClientPipe = new(); - private readonly IServiceProvider _serviceProvider; - private readonly McpServer _server; - private readonly Task _serverTask; - private readonly CancellationTokenSource _cts; - - public McpClient Client { get; } - public McpServer Server => _server; - - public InputRequiredTestFixture( - ILoggerFactory loggerFactory, - Action? configureServer = null, - Action? configureClient = null) - { - _cts = new CancellationTokenSource(); - - // Configure server - var services = new ServiceCollection(); - services.AddLogging(); - services.AddSingleton(loggerFactory); - - var builder = services - .AddMcpServer() - .WithStreamServerTransport( - _clientToServerPipe.Reader.AsStream(), - _serverToClientPipe.Writer.AsStream()); - - configureServer?.Invoke(services, builder); - - _serviceProvider = services.BuildServiceProvider(validateScopes: true); - _server = _serviceProvider.GetRequiredService(); - _serverTask = _server.RunAsync(_cts.Token); - - // Configure client - var clientOptions = new McpClientOptions(); - configureClient?.Invoke(clientOptions); - - // Create client synchronously (test code) - Client = McpClient.CreateAsync( - new StreamClientTransport( - serverInput: _clientToServerPipe.Writer.AsStream(), - _serverToClientPipe.Reader.AsStream(), - loggerFactory), - clientOptions: clientOptions, - loggerFactory: loggerFactory, - cancellationToken: TestContext.Current.CancellationToken).GetAwaiter().GetResult(); - } - - public async ValueTask DisposeAsync() - { - await Client.DisposeAsync(); - await _cts.CancelAsync(); - - _clientToServerPipe.Writer.Complete(); - _serverToClientPipe.Writer.Complete(); - - try - { - await _serverTask; - } - catch (OperationCanceledException) - { - // Expected - } - - if (_serviceProvider is IAsyncDisposable asyncDisposable) - { - await asyncDisposable.DisposeAsync(); - } - else if (_serviceProvider is IDisposable disposable) - { - disposable.Dispose(); - } - - _cts.Dispose(); - } - } -} diff --git a/tests/ModelContextProtocol.Tests/Server/InMemoryMcpTaskStoreTests.cs b/tests/ModelContextProtocol.Tests/Server/InMemoryMcpTaskStoreTests.cs index 7d2fc5596..33c19d457 100644 --- a/tests/ModelContextProtocol.Tests/Server/InMemoryMcpTaskStoreTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/InMemoryMcpTaskStoreTests.cs @@ -1,1231 +1,279 @@ -using Microsoft.Extensions.Time.Testing; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Utils; using System.Text.Json; -using TestInMemoryMcpTaskStore = ModelContextProtocol.Tests.Internal.InMemoryMcpTaskStore; + +#pragma warning disable MCPEXP001 namespace ModelContextProtocol.Tests.Server; -public class InMemoryMcpTaskStoreTests : LoggedTest +/// +/// Unit tests for . +/// +public class InMemoryMcpTaskStoreTests { - public InMemoryMcpTaskStoreTests(ITestOutputHelper testOutputHelper) - : base(testOutputHelper) - { - } + private CancellationToken CT => TestContext.Current.CancellationToken; [Fact] - public async Task CreateTaskAsync_CreatesTaskWithUniqueId() + public async Task CreateTaskAsync_ReturnsWorkingTaskWithUniqueId() { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var metadata = new McpTaskMetadata(); - var requestId = new RequestId("req-1"); - var request = new JsonRpcRequest { Method = "tools/call" }; + var store = new InMemoryMcpTaskStore(); - // Act - var task = await store.CreateTaskAsync(metadata, requestId, request, "session-1", TestContext.Current.CancellationToken); + var result = await store.CreateTaskAsync(CT); - // Assert - Assert.NotNull(task); - Assert.NotEmpty(task.TaskId); - Assert.Equal(McpTaskStatus.Working, task.Status); - Assert.NotEqual(default, task.CreatedAt); - Assert.NotEqual(default, task.LastUpdatedAt); + Assert.NotNull(result); + Assert.NotEmpty(result.TaskId); + Assert.Equal(McpTaskStatus.Working, result.Status); + Assert.NotEqual(default, result.CreatedAt); + Assert.NotEqual(default, result.LastUpdatedAt); } [Fact] - public async Task CreateTaskAsync_GeneratesUniqueTaskIds() + public async Task CreateTaskAsync_GeneratesUniqueIds() { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var metadata = new McpTaskMetadata(); + var store = new InMemoryMcpTaskStore(); - // Act - var task1 = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - var task2 = await store.CreateTaskAsync(metadata, new RequestId("req-2"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); + var task1 = await store.CreateTaskAsync(CT); + var task2 = await store.CreateTaskAsync(CT); - // Assert Assert.NotEqual(task1.TaskId, task2.TaskId); } [Fact] - public async Task CreateTaskAsync_AppliesTtlFromMetadata() + public async Task CreateTaskAsync_UsesDefaultPollInterval() { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var metadata = new McpTaskMetadata - { - TimeToLive = TimeSpan.FromSeconds(5) - }; + var store = new InMemoryMcpTaskStore { DefaultPollIntervalMs = 500 }; - // Act - var task = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); + var result = await store.CreateTaskAsync(CT); - // Assert - Assert.Equal(TimeSpan.FromSeconds(5), task.TimeToLive); + Assert.Equal(500, result.PollIntervalMs); } [Fact] - public async Task CreateTaskAsync_CapsMaxTtl() + public async Task CreateTaskAsync_UsesDefaultTtl() { - // Arrange - var maxTtl = TimeSpan.FromMinutes(5); - using var store = new InMemoryMcpTaskStore(maxTtl: maxTtl); - var metadata = new McpTaskMetadata - { - TimeToLive = TimeSpan.FromHours(1) // Request 1 hour - }; + var store = new InMemoryMcpTaskStore { DefaultTtlMs = 30000 }; - // Act - var task = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); + var result = await store.CreateTaskAsync(CT); - // Assert - Assert.Equal(maxTtl, task.TimeToLive); + Assert.Equal(30000, result.TtlMs); } [Fact] - public async Task GetTaskAsync_ReturnsTaskById() + public async Task GetTaskAsync_ReturnsWorkingTask() { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var metadata = new McpTaskMetadata(); - var created = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); + var store = new InMemoryMcpTaskStore(); + var created = await store.CreateTaskAsync(CT); - // Act - var retrieved = await store.GetTaskAsync(created.TaskId, null, TestContext.Current.CancellationToken); + var result = await store.GetTaskAsync(created.TaskId, CT); - // Assert - Assert.NotNull(retrieved); - Assert.Equal(created.TaskId, retrieved.TaskId); - Assert.Equal(created.Status, retrieved.Status); + Assert.NotNull(result); + Assert.Equal(McpTaskStatus.Working, result.Status); + Assert.Equal(created.TaskId, result.TaskId); } [Fact] - public async Task GetTaskAsync_ReturnsNullForNonexistentTask() + public async Task GetTaskAsync_ReturnsNullForUnknownId() { - // Arrange - using var store = new InMemoryMcpTaskStore(); + var store = new InMemoryMcpTaskStore(); - // Act - var task = await store.GetTaskAsync("nonexistent-id", null, TestContext.Current.CancellationToken); + var result = await store.GetTaskAsync("nonexistent", CT); - // Assert - Assert.Null(task); + Assert.Null(result); } [Fact] - public async Task GetTaskAsync_EnforcesSessionIsolation() + public async Task SetCompletedAsync_TransitionsToCompleted() { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var metadata = new McpTaskMetadata(); - var task = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, "session-1", TestContext.Current.CancellationToken); + var store = new InMemoryMcpTaskStore(); + var created = await store.CreateTaskAsync(CT); + var resultPayload = JsonDocument.Parse("""{"answer":42}""").RootElement.Clone(); - // Act - var sameSession = await store.GetTaskAsync(task.TaskId, "session-1", TestContext.Current.CancellationToken); - var differentSession = await store.GetTaskAsync(task.TaskId, "session-2", TestContext.Current.CancellationToken); + await store.SetCompletedAsync(created.TaskId, resultPayload, CT); - // Assert - Assert.NotNull(sameSession); - Assert.Null(differentSession); + var task = await store.GetTaskAsync(created.TaskId, CT); + Assert.NotNull(task); + Assert.Equal(McpTaskStatus.Completed, task.Status); + Assert.Equal(42, task.Result!.Value.GetProperty("answer").GetInt32()); } [Fact] - public async Task StoreTaskResultAsync_StoresResultForCompletedTask() + public async Task SetFailedAsync_TransitionsToFailed() { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var metadata = new McpTaskMetadata(); - var task = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - var result = new CallToolResult { Content = [new TextContentBlock { Text = "Success" }] }; - var resultElement = JsonSerializer.SerializeToElement(result, McpJsonUtilities.DefaultOptions); + var store = new InMemoryMcpTaskStore(); + var created = await store.CreateTaskAsync(CT); + var errorPayload = JsonDocument.Parse("""{"message":"boom"}""").RootElement.Clone(); - // Act - await store.StoreTaskResultAsync(task.TaskId, McpTaskStatus.Completed, resultElement, null, TestContext.Current.CancellationToken); + await store.SetFailedAsync(created.TaskId, errorPayload, CT); - // Assert - var retrieved = await store.GetTaskAsync(task.TaskId, null, TestContext.Current.CancellationToken); - Assert.Equal(McpTaskStatus.Completed, retrieved!.Status); + var task = await store.GetTaskAsync(created.TaskId, CT); + Assert.NotNull(task); + Assert.Equal(McpTaskStatus.Failed, task.Status); + Assert.Equal("boom", task.Error!.Value.GetProperty("message").GetString()); } [Fact] - public async Task StoreTaskResultAsync_EnforcesSessionIsolation() + public async Task SetCancelledAsync_TransitionsToCancelled() { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var metadata = new McpTaskMetadata(); - var task = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, "session-1", TestContext.Current.CancellationToken); - var result = new CallToolResult { Content = [new TextContentBlock { Text = "Success" }] }; - var resultElement = JsonSerializer.SerializeToElement(result, McpJsonUtilities.DefaultOptions); + var store = new InMemoryMcpTaskStore(); + var created = await store.CreateTaskAsync(CT); - // Act & Assert - await Assert.ThrowsAsync( - () => store.StoreTaskResultAsync(task.TaskId, McpTaskStatus.Completed, resultElement, "session-2", TestContext.Current.CancellationToken)); - } - - [Fact] - public async Task StoreTaskResultAsync_ThrowsForNonTerminalStatus() - { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var metadata = new McpTaskMetadata(); - var task = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - var result = new CallToolResult { Content = [new TextContentBlock { Text = "Success" }] }; - var resultElement = JsonSerializer.SerializeToElement(result, McpJsonUtilities.DefaultOptions); + var cancelled = await store.SetCancelledAsync(created.TaskId, CT); - // Act & Assert - await Assert.ThrowsAsync( - () => store.StoreTaskResultAsync(task.TaskId, McpTaskStatus.Working, resultElement, null, TestContext.Current.CancellationToken)); + Assert.True(cancelled); + var task = await store.GetTaskAsync(created.TaskId, CT); + Assert.NotNull(task); + Assert.Equal(McpTaskStatus.Cancelled, task.Status); } [Fact] - public async Task GetTaskResultAsync_ReturnsStoredResult() + public async Task SetCancelledAsync_ReturnsFalseForTerminalTask() { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var metadata = new McpTaskMetadata(); - var task = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - var result = new CallToolResult { Content = [new TextContentBlock { Text = "Success" }] }; - var resultElement = JsonSerializer.SerializeToElement(result, McpJsonUtilities.DefaultOptions); - await store.StoreTaskResultAsync(task.TaskId, McpTaskStatus.Completed, resultElement, null, TestContext.Current.CancellationToken); - - // Act - var retrieved = await store.GetTaskResultAsync(task.TaskId, null, TestContext.Current.CancellationToken); + var store = new InMemoryMcpTaskStore(); + var created = await store.CreateTaskAsync(CT); + await store.SetCompletedAsync(created.TaskId, JsonSerializer.SerializeToElement("done", McpJsonUtilities.DefaultOptions), CT); - // Assert - var callToolResult = retrieved.Deserialize(McpJsonUtilities.DefaultOptions)!; - Assert.Single(callToolResult.Content); - Assert.Equal("Success", ((TextContentBlock)callToolResult.Content[0]).Text); - } + var cancelled = await store.SetCancelledAsync(created.TaskId, CT); - [Fact] - public async Task GetTaskResultAsync_EnforcesSessionIsolation() - { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var metadata = new McpTaskMetadata(); - var task = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, "session-1", TestContext.Current.CancellationToken); - var result = new CallToolResult { Content = [new TextContentBlock { Text = "Success" }] }; - var resultElement = JsonSerializer.SerializeToElement(result, McpJsonUtilities.DefaultOptions); - await store.StoreTaskResultAsync(task.TaskId, McpTaskStatus.Completed, resultElement, "session-1", TestContext.Current.CancellationToken); - - // Act & Assert - await Assert.ThrowsAsync( - () => store.GetTaskResultAsync(task.TaskId, "session-2", TestContext.Current.CancellationToken)); + Assert.False(cancelled); + var task = await store.GetTaskAsync(created.TaskId, CT); + Assert.NotNull(task); + Assert.Equal(McpTaskStatus.Completed, task.Status); } [Fact] - public async Task UpdateTaskStatusAsync_UpdatesStatus() + public async Task SetCancelledAsync_ReturnsFalseForUnknownId() { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var metadata = new McpTaskMetadata(); - var task = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); + var store = new InMemoryMcpTaskStore(); - // Act - await store.UpdateTaskStatusAsync(task.TaskId, McpTaskStatus.Working, "Processing...", null, TestContext.Current.CancellationToken); + var cancelled = await store.SetCancelledAsync("nonexistent", CT); - // Assert - var updated = await store.GetTaskAsync(task.TaskId, null, TestContext.Current.CancellationToken); - Assert.Equal(McpTaskStatus.Working, updated!.Status); - Assert.Equal("Processing...", updated.StatusMessage); + Assert.False(cancelled); } [Fact] - public async Task UpdateTaskStatusAsync_UpdatesLastUpdatedAt() + public async Task SetInputRequestsAsync_TransitionsToInputRequired() { - // Arrange - Use FakeTimeProvider for deterministic testing - var fakeTime = new FakeTimeProvider(DateTimeOffset.UtcNow); - using var store = new TestInMemoryMcpTaskStore( - defaultTtl: null, - maxTtl: null, - pollInterval: null, - cleanupInterval: Timeout.InfiniteTimeSpan, - pageSize: 100, - maxTasks: null, - maxTasksPerSession: null, - timeProvider: fakeTime); - - var metadata = new McpTaskMetadata(); - var task = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - var originalTimestamp = task.LastUpdatedAt; - - // Advance time to ensure timestamp changes - fakeTime.Advance(TimeSpan.FromMilliseconds(10)); - - // Act - await store.UpdateTaskStatusAsync(task.TaskId, McpTaskStatus.Working, null, null, TestContext.Current.CancellationToken); - - // Assert - var updated = await store.GetTaskAsync(task.TaskId, null, TestContext.Current.CancellationToken); - Assert.True(updated!.LastUpdatedAt > originalTimestamp); - } - - #region Input Required Status Tests + var store = new InMemoryMcpTaskStore(); + var created = await store.CreateTaskAsync(CT); - // NOTE: The InputRequired status is automatically set by the server when a tool executing - // as a task calls SampleAsync() or ElicitAsync(). The status is set back to Working when - // the request completes. See TaskExecutionContext for implementation details. - // The tests below verify the store correctly handles status transitions. - - [Fact] - public async Task InputRequiredStatus_SerializesCorrectly() - { - // Verify the input_required status serializes as expected - var task = new McpTask + var requests = new Dictionary { - TaskId = "test-task", - Status = McpTaskStatus.InputRequired, - StatusMessage = "Waiting for user input", - CreatedAt = DateTimeOffset.UtcNow, - LastUpdatedAt = DateTimeOffset.UtcNow + ["req1"] = JsonDocument.Parse("""{"method":"elicitation/create","params":{"message":"hello"}}""").RootElement.Clone() }; + await store.SetInputRequestsAsync(created.TaskId, requests, CT); - string json = JsonSerializer.Serialize(task, McpJsonUtilities.DefaultOptions); - - Assert.Contains("\"status\":\"input_required\"", json); - } - - [Fact] - public async Task InputRequiredStatus_CanTransitionToWorking() - { - // Arrange - Spec: "From input_required: may move to working, completed, failed, or cancelled" - using var store = new InMemoryMcpTaskStore(); - var metadata = new McpTaskMetadata(); - var task = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - - // Transition to input_required (testing store's status transition capability) - var inputRequiredTask = await store.UpdateTaskStatusAsync( - task.TaskId, - McpTaskStatus.InputRequired, - "Waiting for user confirmation", - cancellationToken: TestContext.Current.CancellationToken); - - Assert.Equal(McpTaskStatus.InputRequired, inputRequiredTask.Status); - - // Act - Transition back to working - var workingTask = await store.UpdateTaskStatusAsync( - task.TaskId, - McpTaskStatus.Working, - "Processing resumed", - cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Assert.Equal(McpTaskStatus.Working, workingTask.Status); - } - - [Fact] - public async Task InputRequiredStatus_CanTransitionToCancelled() - { - // Arrange - Spec: Task transitions show input_required can go to terminal states - using var store = new InMemoryMcpTaskStore(); - var metadata = new McpTaskMetadata(); - var task = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - - // Transition to input_required - await store.UpdateTaskStatusAsync( - task.TaskId, - McpTaskStatus.InputRequired, - "Need input", - cancellationToken: TestContext.Current.CancellationToken); - - // Act - Transition to cancelled - var cancelledTask = await store.CancelTaskAsync( - task.TaskId, - cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Assert.Equal(McpTaskStatus.Cancelled, cancelledTask.Status); - } - - #endregion - - [Fact] - public async Task ListTasksAsync_ReturnsAllTasks() - { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var task1 = await store.CreateTaskAsync(new McpTaskMetadata(), new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - var task2 = await store.CreateTaskAsync(new McpTaskMetadata(), new RequestId("req-2"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - - // Act - var result = await store.ListTasksAsync(cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Assert.Equal(2, result.Tasks.Count); - Assert.Contains(result.Tasks, t => t.TaskId == task1.TaskId); - Assert.Contains(result.Tasks, t => t.TaskId == task2.TaskId); - Assert.Null(result.NextCursor); - } - - [Fact] - public async Task ListTasksAsync_FiltersBySession() - { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var task1 = await store.CreateTaskAsync(new McpTaskMetadata(), new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, "session-1", TestContext.Current.CancellationToken); - var task2 = await store.CreateTaskAsync(new McpTaskMetadata(), new RequestId("req-2"), new JsonRpcRequest { Method = "test" }, "session-2", TestContext.Current.CancellationToken); - - // Act - var session1Result = await store.ListTasksAsync(sessionId: "session-1", cancellationToken: TestContext.Current.CancellationToken); - var session2Result = await store.ListTasksAsync(sessionId: "session-2", cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Assert.Single(session1Result.Tasks); - Assert.Equal(task1.TaskId, session1Result.Tasks[0].TaskId); - Assert.Single(session2Result.Tasks); - Assert.Equal(task2.TaskId, session2Result.Tasks[0].TaskId); - } - - [Fact] - public async Task ListTasksAsync_SupportsPagination() - { - // Arrange - using var store = new InMemoryMcpTaskStore(); - - // Create 150 tasks (more than page size of 100) - for (int i = 0; i < 150; i++) - { - await store.CreateTaskAsync(new McpTaskMetadata(), new RequestId($"req-{i}"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - } - - // Act - First page - var firstPageResult = await store.ListTasksAsync(cancellationToken: TestContext.Current.CancellationToken); - - // Act - Second page - var secondPageResult = await store.ListTasksAsync(cursor: firstPageResult.NextCursor, cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Assert.Equal(100, firstPageResult.Tasks.Count); - Assert.NotNull(firstPageResult.NextCursor); - Assert.Equal(50, secondPageResult.Tasks.Count); - Assert.Null(secondPageResult.NextCursor); - } - - [Fact] - public async Task CancelTaskAsync_CancelsTask() - { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var metadata = new McpTaskMetadata(); - var task = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - - // Act - var cancelled = await store.CancelTaskAsync(task.TaskId, null, TestContext.Current.CancellationToken); - - // Assert - Assert.Equal(McpTaskStatus.Cancelled, cancelled.Status); - } - - [Fact] - public async Task CancelTaskAsync_IsIdempotent() - { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var metadata = new McpTaskMetadata(); - var task = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - - // First cancellation - await store.CancelTaskAsync(task.TaskId, null, TestContext.Current.CancellationToken); - - // Act - Second cancellation - var result = await store.CancelTaskAsync(task.TaskId, null, TestContext.Current.CancellationToken); - - // Assert - Should return unchanged task, not throw - Assert.Equal(McpTaskStatus.Cancelled, result.Status); - } - - [Fact] - public async Task CancelTaskAsync_DoesNotCancelCompletedTask() - { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var metadata = new McpTaskMetadata(); - var task = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - var result = new CallToolResult { Content = [new TextContentBlock { Text = "Success" }] }; - var resultElement = JsonSerializer.SerializeToElement(result, McpJsonUtilities.DefaultOptions); - await store.StoreTaskResultAsync(task.TaskId, McpTaskStatus.Completed, resultElement, null, TestContext.Current.CancellationToken); - - // Act - var cancelResult = await store.CancelTaskAsync(task.TaskId, null, TestContext.Current.CancellationToken); - - // Assert - Task remains completed - Assert.Equal(McpTaskStatus.Completed, cancelResult.Status); - } - - [Fact] - public async Task CancelTaskAsync_EnforcesSessionIsolation() - { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var metadata = new McpTaskMetadata(); - var task = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, "session-1", TestContext.Current.CancellationToken); - - // Act & Assert - await Assert.ThrowsAsync( - () => store.CancelTaskAsync(task.TaskId, "session-2", TestContext.Current.CancellationToken)); - } - - [Fact] - public async Task Dispose_StopsCleanupTimer() - { - // Arrange - Use FakeTimeProvider for deterministic testing - var fakeTime = new FakeTimeProvider(DateTimeOffset.UtcNow); - var cleanupInterval = TimeSpan.FromMilliseconds(100); - - var store = new TestInMemoryMcpTaskStore( - defaultTtl: null, - maxTtl: null, - pollInterval: null, - cleanupInterval: cleanupInterval, - pageSize: 100, - maxTasks: null, - maxTasksPerSession: null, - timeProvider: fakeTime); - - var metadata = new McpTaskMetadata { TimeToLive = TimeSpan.FromMilliseconds(100) }; - await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - - // Act - store.Dispose(); - - // Advance time - timer should not fire after dispose - fakeTime.Advance(TimeSpan.FromTicks(cleanupInterval.Ticks * 3)); - - // Assert - Store should still be accessible after dispose (no exceptions) - // The cleanup timer should have stopped - Assert.True(true); // If we get here without exceptions, dispose worked - } - - [Fact] - public async Task CleanupExpiredTasks_RemovesExpiredTasks() - { - // Arrange - Use FakeTimeProvider for deterministic testing - var fakeTime = new FakeTimeProvider(DateTimeOffset.UtcNow); - var cleanupInterval = TimeSpan.FromMilliseconds(50); - var ttl = TimeSpan.FromMilliseconds(100); - - using var store = new TestInMemoryMcpTaskStore( - defaultTtl: null, - maxTtl: null, - pollInterval: null, - cleanupInterval: cleanupInterval, - pageSize: 100, - maxTasks: null, - maxTasksPerSession: null, - timeProvider: fakeTime); - - var metadata = new McpTaskMetadata { TimeToLive = ttl }; - var task = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - - // Verify task exists initially - var resultBefore = await store.ListTasksAsync(cancellationToken: TestContext.Current.CancellationToken); - Assert.Single(resultBefore.Tasks); - - // Advance time past the TTL to make task expired - fakeTime.Advance(ttl + TimeSpan.FromMilliseconds(1)); - - // Trigger cleanup by advancing time past cleanup interval - fakeTime.Advance(cleanupInterval); - - // Act - List tasks to verify cleanup happened - var resultAfter = await store.ListTasksAsync(cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Assert.Empty(resultAfter.Tasks); // Task should be cleaned up by the timer - } - - [Fact] - public async Task DefaultTtl_AppliedWhenNoTtlSpecified() - { - // Arrange - var defaultTtl = TimeSpan.FromMinutes(10); - using var store = new InMemoryMcpTaskStore(defaultTtl: defaultTtl); - var metadata = new McpTaskMetadata(); // No TTL specified - - // Act - var task = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - - // Assert - Assert.Equal(defaultTtl, task.TimeToLive); + var task = await store.GetTaskAsync(created.TaskId, CT); + Assert.NotNull(task); + Assert.Equal(McpTaskStatus.InputRequired, task.Status); + Assert.NotNull(task.InputRequests); + Assert.Single(task.InputRequests); + Assert.True(task.InputRequests.ContainsKey("req1")); } [Fact] - public async Task MultipleOperations_ConcurrentAccess() + public async Task SetInputRequestsAsync_MergesMultipleRequests() { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var tasks = new List>(); + var store = new InMemoryMcpTaskStore(); + var created = await store.CreateTaskAsync(CT); - // Act - Create multiple tasks concurrently - for (int i = 0; i < 10; i++) + await store.SetInputRequestsAsync(created.TaskId, new Dictionary { - int taskNum = i; - tasks.Add(Task.Run(async () => - { - var metadata = new McpTaskMetadata(); - return await store.CreateTaskAsync(metadata, new RequestId($"req-{taskNum}"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - })); - } - - var createdTasks = await Task.WhenAll(tasks); - - // Assert - All tasks should be created with unique IDs - Assert.Equal(10, createdTasks.Length); - Assert.Equal(10, createdTasks.Select(t => t.TaskId).Distinct().Count()); - } - - [Fact] - public void Constructor_ThrowsWhenDefaultTtlExceedsMaxTtl() - { - // Arrange & Act & Assert - var exception = Assert.Throws(() => - new InMemoryMcpTaskStore( - defaultTtl: TimeSpan.FromHours(2), - maxTtl: TimeSpan.FromHours(1))); - - Assert.Equal("defaultTtl", exception.ParamName); - Assert.Contains("Default TTL", exception.Message); - Assert.Contains("cannot exceed maximum TTL", exception.Message); - } - - [Fact] - public async Task CreateTaskAsync_UsesConfiguredPollInterval() - { - // Arrange - using var store = new InMemoryMcpTaskStore(pollInterval: TimeSpan.FromMilliseconds(2500)); - var metadata = new McpTaskMetadata(); - - // Act - var task = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - - // Assert - Assert.Equal(TimeSpan.FromMilliseconds(2500), task.PollInterval); - } - - [Fact] - public void Constructor_ThrowsWhenPollIntervalIsZero() - { - // Arrange & Act & Assert - var exception = Assert.Throws(() => - new InMemoryMcpTaskStore(pollInterval: TimeSpan.Zero)); - - Assert.Equal("pollInterval", exception.ParamName); - Assert.Contains("Poll interval must be positive", exception.Message); - } - - [Fact] - public void Constructor_ThrowsWhenPollIntervalIsNegative() - { - // Arrange & Act & Assert - var exception = Assert.Throws(() => - new InMemoryMcpTaskStore(pollInterval: TimeSpan.FromMilliseconds(-100))); - - Assert.Equal("pollInterval", exception.ParamName); - Assert.Contains("Poll interval must be positive", exception.Message); - } - - [Fact] - public async Task GetTaskAsync_ReturnsDefensiveCopy() - { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var metadata = new McpTaskMetadata(); - var createdTask = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - - // Act - Get the task and modify the returned copy - var retrievedTask = await store.GetTaskAsync(createdTask.TaskId, null, TestContext.Current.CancellationToken); - var originalStatus = retrievedTask!.Status; - retrievedTask.Status = McpTaskStatus.Completed; - retrievedTask.StatusMessage = "Modified externally"; - - // Assert - Get the task again and verify the stored state wasn't affected - var taskAgain = await store.GetTaskAsync(createdTask.TaskId, null, TestContext.Current.CancellationToken); - Assert.Equal(originalStatus, taskAgain!.Status); - Assert.Null(taskAgain.StatusMessage); - } - - [Fact] - public async Task ListTasksAsync_ReturnsDefensiveCopies() - { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var metadata = new McpTaskMetadata(); - await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - - // Act - List tasks and modify the returned copies - var result = await store.ListTasksAsync(cancellationToken: TestContext.Current.CancellationToken); - var firstTask = result.Tasks[0]; - var originalTaskId = firstTask.TaskId; - firstTask.Status = McpTaskStatus.Failed; - firstTask.StatusMessage = "Modified in list"; - - // Assert - Get the task directly and verify the stored state wasn't affected - var directTask = await store.GetTaskAsync(originalTaskId, null, TestContext.Current.CancellationToken); - Assert.Equal(McpTaskStatus.Working, directTask!.Status); - Assert.Null(directTask.StatusMessage); - } - - [Fact] - public async Task CancelTaskAsync_ReturnsDefensiveCopy() - { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var metadata = new McpTaskMetadata(); - var createdTask = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - - // Act - Cancel the task and modify the returned copy - var cancelledTask = await store.CancelTaskAsync(createdTask.TaskId, null, TestContext.Current.CancellationToken); - cancelledTask.StatusMessage = "Modified after cancel"; - cancelledTask.Status = McpTaskStatus.Completed; - - // Assert - Get the task again and verify it's still cancelled with no message - var taskAgain = await store.GetTaskAsync(createdTask.TaskId, null, TestContext.Current.CancellationToken); - Assert.Equal(McpTaskStatus.Cancelled, taskAgain!.Status); - Assert.Null(taskAgain.StatusMessage); - } - - [Fact] - public async Task ConcurrentUpdates_HandlesContentionCorrectly() - { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var task = await store.CreateTaskAsync(new McpTaskMetadata(), new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - - // Act - Launch 100 concurrent updates to the same task - var updateTasks = Enumerable.Range(0, 100).Select(i => - Task.Run(async () => - { - try - { - await store.UpdateTaskStatusAsync(task.TaskId, McpTaskStatus.Working, $"Update {i}", null, TestContext.Current.CancellationToken); - return true; - } - catch - { - return false; - } - })); - - var results = await Task.WhenAll(updateTasks); - - // Assert - All updates should succeed (retry loop handles contention) - Assert.All(results, success => Assert.True(success)); - - // Verify task is still in valid state (one of the updates won) - var finalTask = await store.GetTaskAsync(task.TaskId, null, TestContext.Current.CancellationToken); - Assert.NotNull(finalTask); - Assert.Equal(McpTaskStatus.Working, finalTask.Status); - Assert.Matches(@"Update \d+", finalTask.StatusMessage!); - } - - [Fact] - public async Task ConcurrentStoreResult_OnlyFirstWins() - { - // Arrange - using var store = new InMemoryMcpTaskStore(); - var task = await store.CreateTaskAsync(new McpTaskMetadata(), new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - - // Act - Try to store results concurrently (only first should succeed) - var storeTasks = Enumerable.Range(0, 10).Select(i => - Task.Run(async () => - { - try - { - var result = new CallToolResult { Content = [new TextContentBlock { Text = $"Result {i}" }] }; - var resultElement = JsonSerializer.SerializeToElement(result, McpJsonUtilities.DefaultOptions); - await store.StoreTaskResultAsync( - task.TaskId, - McpTaskStatus.Completed, - resultElement, - null, - TestContext.Current.CancellationToken); - return i; - } - catch (InvalidOperationException) - { - // Expected: task already in terminal state - return -1; - } - })); - - var results = await Task.WhenAll(storeTasks); - var successfulUpdates = results.Where(r => r >= 0).ToList(); - - // Assert - Exactly one update should succeed, others should fail - Assert.Single(successfulUpdates); - - // Verify the winning result is stored - var finalTask = await store.GetTaskAsync(task.TaskId, null, TestContext.Current.CancellationToken); - Assert.Equal(McpTaskStatus.Completed, finalTask!.Status); - } - - [Fact] - public async Task ListTasksAsync_PaginationWithCustomPageSize() - { - // Arrange - Use small page size for testing - using var store = new InMemoryMcpTaskStore(pageSize: 10); - - // Create 25 tasks - for (int i = 0; i < 25; i++) + ["req1"] = JsonSerializer.SerializeToElement("first", McpJsonUtilities.DefaultOptions) + }, CT); + await store.SetInputRequestsAsync(created.TaskId, new Dictionary { - await store.CreateTaskAsync(new McpTaskMetadata(), new RequestId($"req-{i}"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - } - - // Act - Paginate through all tasks - var result1 = await store.ListTasksAsync(cancellationToken: TestContext.Current.CancellationToken); - var result2 = await store.ListTasksAsync(cursor: result1.NextCursor, cancellationToken: TestContext.Current.CancellationToken); - var result3 = await store.ListTasksAsync(cursor: result2.NextCursor, cancellationToken: TestContext.Current.CancellationToken); + ["req2"] = JsonSerializer.SerializeToElement("second", McpJsonUtilities.DefaultOptions) + }, CT); - // Assert - Assert.Equal(10, result1.Tasks.Count); - Assert.NotNull(result1.NextCursor); - Assert.Equal(10, result2.Tasks.Count); - Assert.NotNull(result2.NextCursor); - Assert.Equal(5, result3.Tasks.Count); - Assert.Null(result3.NextCursor); - - // Verify no duplicates across pages - var allTaskIds = result1.Tasks.Concat(result2.Tasks).Concat(result3.Tasks).Select(t => t.TaskId).ToList(); - Assert.Equal(25, allTaskIds.Distinct().Count()); + var task = await store.GetTaskAsync(created.TaskId, CT); + Assert.NotNull(task); + Assert.Equal(McpTaskStatus.InputRequired, task.Status); + Assert.NotNull(task.InputRequests); + Assert.Equal(2, task.InputRequests.Count); + Assert.True(task.InputRequests.ContainsKey("req1")); + Assert.True(task.InputRequests.ContainsKey("req2")); } [Fact] - public async Task ListTasksAsync_NoDuplicatesWithIdenticalTimestamps() + public async Task ResolveInputRequestsAsync_RemovesMatchedRequests() { - // Arrange - using var store = new InMemoryMcpTaskStore(pageSize: 5); - - // Create tasks with identical metadata to increase chance of timestamp collision - var createTasks = Enumerable.Range(0, 20).Select(i => - store.CreateTaskAsync(new McpTaskMetadata(), new RequestId($"req-{i}"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken)); - - await Task.WhenAll(createTasks); + var store = new InMemoryMcpTaskStore(); + var created = await store.CreateTaskAsync(CT); - // Act - Collect all tasks through pagination - var allTasks = new List(); - string? cursor = null; - do + await store.SetInputRequestsAsync(created.TaskId, new Dictionary { - var result = await store.ListTasksAsync(cursor: cursor, cancellationToken: TestContext.Current.CancellationToken); - allTasks.AddRange(result.Tasks); - cursor = result.NextCursor; - } while (cursor != null); + ["req1"] = JsonSerializer.SerializeToElement("request1", McpJsonUtilities.DefaultOptions), + ["req2"] = JsonSerializer.SerializeToElement("request2", McpJsonUtilities.DefaultOptions), + }, CT); - // Assert - No duplicates - var taskIds = allTasks.Select(t => t.TaskId).ToList(); - Assert.Equal(20, taskIds.Count); - Assert.Equal(20, taskIds.Distinct().Count()); + await store.ResolveInputRequestsAsync(created.TaskId, ["req1"], CT); - // Verify tasks are properly ordered - Assert.Equal(allTasks.OrderBy(t => t.CreatedAt).ThenBy(t => t.TaskId).Select(t => t.TaskId), taskIds); + var task = await store.GetTaskAsync(created.TaskId, CT); + Assert.NotNull(task); + Assert.Equal(McpTaskStatus.InputRequired, task.Status); + Assert.NotNull(task.InputRequests); + Assert.Single(task.InputRequests); + Assert.True(task.InputRequests.ContainsKey("req2")); } [Fact] - public async Task ListTasksAsync_ConsistentWithExpiredTasksRemovedBetweenPages() + public async Task ResolveInputRequestsAsync_TransitionsToWorkingWhenAllSatisfied() { - // Arrange - Use FakeTimeProvider for deterministic testing - var fakeTime = new FakeTimeProvider(DateTimeOffset.UtcNow); - var ttl = TimeSpan.FromSeconds(1); - using var store = new TestInMemoryMcpTaskStore( - defaultTtl: ttl, - maxTtl: null, - pollInterval: null, - cleanupInterval: Timeout.InfiniteTimeSpan, - pageSize: 5, - maxTasks: null, - maxTasksPerSession: null, - timeProvider: fakeTime); + var store = new InMemoryMcpTaskStore(); + var created = await store.CreateTaskAsync(CT); - // Create 15 tasks - for (int i = 0; i < 15; i++) + await store.SetInputRequestsAsync(created.TaskId, new Dictionary { - await store.CreateTaskAsync(new McpTaskMetadata(), new RequestId($"req-{i}"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - } - - // Act - Get first page immediately - var result1 = await store.ListTasksAsync(cancellationToken: TestContext.Current.CancellationToken); + ["req1"] = JsonSerializer.SerializeToElement("request1", McpJsonUtilities.DefaultOptions), + }, CT); - // Advance time past TTL to make tasks expire - fakeTime.Advance(ttl + TimeSpan.FromMilliseconds(500)); + await store.ResolveInputRequestsAsync(created.TaskId, ["req1"], CT); - // Get second page after expiration - var result2 = await store.ListTasksAsync(cursor: result1.NextCursor, cancellationToken: TestContext.Current.CancellationToken); - - // Assert - First page should have 5 tasks, second page should have 0 (all expired) - Assert.Equal(5, result1.Tasks.Count); - Assert.NotNull(result1.NextCursor); - Assert.Empty(result2.Tasks); - Assert.Null(result2.NextCursor); + var task = await store.GetTaskAsync(created.TaskId, CT); + Assert.NotNull(task); + Assert.Equal(McpTaskStatus.Working, task.Status); } [Fact] - public async Task ListTasksAsync_KeysetPaginationMaintainsConsistencyWithNewTasks() + public async Task SetCompletedAsync_ThrowsForUnknownTask() { - // Arrange - using var store = new InMemoryMcpTaskStore(pageSize: 5); - - // Create 10 initial tasks - for (int i = 0; i < 10; i++) - { - await store.CreateTaskAsync(new McpTaskMetadata(), new RequestId($"req-{i}"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - } - - // Get first page - var result1 = await store.ListTasksAsync(cancellationToken: TestContext.Current.CancellationToken); - Assert.Equal(5, result1.Tasks.Count); + var store = new InMemoryMcpTaskStore(); - // Add more tasks between pages (these should appear in later queries, not retroactively in page 2) - for (int i = 10; i < 15; i++) - { - await store.CreateTaskAsync(new McpTaskMetadata(), new RequestId($"req-{i}"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - } - - // Get second page using cursor from before new tasks were added - var result2 = await store.ListTasksAsync(cursor: result1.NextCursor, cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Second page should have 5 tasks from original set - Assert.Equal(5, result2.Tasks.Count); - Assert.NotNull(result2.NextCursor); - - // Verify no overlap between pages - var page1Ids = result1.Tasks.Select(t => t.TaskId).ToHashSet(); - var page2Ids = result2.Tasks.Select(t => t.TaskId).ToHashSet(); - Assert.Empty(page1Ids.Intersect(page2Ids)); + await Assert.ThrowsAsync( + () => store.SetCompletedAsync("nonexistent", JsonSerializer.SerializeToElement("x", McpJsonUtilities.DefaultOptions), CT)); } [Fact] - public async Task UpdateTaskStatusAsync_ConcurrentWithList_NoCorruption() + public async Task ConcurrentUpdates_DoNotLoseData() { - // Arrange - using var store = new InMemoryMcpTaskStore(pageSize: 10); + var store = new InMemoryMcpTaskStore(); + var created = await store.CreateTaskAsync(CT); - // Create 20 tasks - var tasks = new List(); - for (int i = 0; i < 20; i++) - { - var task = await store.CreateTaskAsync(new McpTaskMetadata(), new RequestId($"req-{i}"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - tasks.Add(task); - } - - // Act - Concurrently list and update tasks - var ct = TestContext.Current.CancellationToken; - var listTask = Task.Run(async () => - { - var allTasks = new List(); - string? cursor = null; - do + var tasks = Enumerable.Range(0, 50).Select(i => + store.SetInputRequestsAsync(created.TaskId, new Dictionary { - var result = await store.ListTasksAsync(cursor: cursor, cancellationToken: TestContext.Current.CancellationToken); - allTasks.AddRange(result.Tasks); - cursor = result.NextCursor; - await Task.Delay(10, ct); // Small delay to increase chance of interleaving - } while (cursor != null); - return allTasks; - }, ct); - - var updateTask = Task.Run(async () => - { - foreach (var task in tasks) - { - await store.UpdateTaskStatusAsync(task.TaskId, McpTaskStatus.Working, "Updated", null, TestContext.Current.CancellationToken); - await Task.Delay(5, ct); // Small delay - } - }, ct); - - await Task.WhenAll(listTask, updateTask); - var listedTasks = await listTask; - - // Assert - Should have listed all tasks without duplicates or corruption - Assert.Equal(20, listedTasks.Count); - Assert.Equal(20, listedTasks.Select(t => t.TaskId).Distinct().Count()); - } - - [Fact] - public void Constructor_ThrowsForInvalidMaxTasks() - { - // Assert - Assert.Throws(() => new InMemoryMcpTaskStore(maxTasks: 0)); - Assert.Throws(() => new InMemoryMcpTaskStore(maxTasks: -1)); - } - - [Fact] - public void Constructor_ThrowsForInvalidMaxTasksPerSession() - { - // Assert - Assert.Throws(() => new InMemoryMcpTaskStore(maxTasksPerSession: 0)); - Assert.Throws(() => new InMemoryMcpTaskStore(maxTasksPerSession: -1)); - } - - [Fact] - public async Task CreateTaskAsync_EnforcesMaxTasksLimit() - { - // Arrange - using var store = new InMemoryMcpTaskStore(maxTasks: 3); - var metadata = new McpTaskMetadata(); - - // Act - Create up to the limit - await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - await store.CreateTaskAsync(metadata, new RequestId("req-2"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - await store.CreateTaskAsync(metadata, new RequestId("req-3"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - - // Assert - Fourth task should throw - var ex = await Assert.ThrowsAsync(() => - store.CreateTaskAsync(metadata, new RequestId("req-4"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken)); - Assert.Contains("Maximum number of tasks (3) has been reached", ex.Message); - } - - [Fact] - public async Task CreateTaskAsync_EnforcesMaxTasksPerSessionLimit() - { - // Arrange - using var store = new InMemoryMcpTaskStore(maxTasksPerSession: 2); - var metadata = new McpTaskMetadata(); - - // Act - Create up to the limit for session-1 - await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, "session-1", TestContext.Current.CancellationToken); - await store.CreateTaskAsync(metadata, new RequestId("req-2"), new JsonRpcRequest { Method = "test" }, "session-1", TestContext.Current.CancellationToken); - - // Assert - Third task for session-1 should throw - var ex = await Assert.ThrowsAsync(() => - store.CreateTaskAsync(metadata, new RequestId("req-3"), new JsonRpcRequest { Method = "test" }, "session-1", TestContext.Current.CancellationToken)); - Assert.Contains("Maximum number of tasks per session (2) has been reached", ex.Message); - Assert.Contains("session-1", ex.Message); - } - - [Fact] - public async Task CreateTaskAsync_MaxTasksPerSession_AllowsDifferentSessions() - { - // Arrange - using var store = new InMemoryMcpTaskStore(maxTasksPerSession: 2); - var metadata = new McpTaskMetadata(); - - // Act - Create 2 tasks for session-1 - await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, "session-1", TestContext.Current.CancellationToken); - await store.CreateTaskAsync(metadata, new RequestId("req-2"), new JsonRpcRequest { Method = "test" }, "session-1", TestContext.Current.CancellationToken); - - // Should still be able to create tasks for session-2 - var task3 = await store.CreateTaskAsync(metadata, new RequestId("req-3"), new JsonRpcRequest { Method = "test" }, "session-2", TestContext.Current.CancellationToken); - var task4 = await store.CreateTaskAsync(metadata, new RequestId("req-4"), new JsonRpcRequest { Method = "test" }, "session-2", TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(task3); - Assert.NotNull(task4); - } - - [Fact] - public async Task CreateTaskAsync_MaxTasksPerSession_DoesNotApplyToNullSession() - { - // Arrange - using var store = new InMemoryMcpTaskStore(maxTasksPerSession: 1); - var metadata = new McpTaskMetadata(); - - // Act - Create multiple tasks with null session (should not be limited) - var task1 = await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - var task2 = await store.CreateTaskAsync(metadata, new RequestId("req-2"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - var task3 = await store.CreateTaskAsync(metadata, new RequestId("req-3"), new JsonRpcRequest { Method = "test" }, null, TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(task1); - Assert.NotNull(task2); - Assert.NotNull(task3); - } - - [Fact] - public async Task CreateTaskAsync_CombinesMaxTasksAndMaxTasksPerSession() - { - // Arrange - Global limit of 5, per-session limit of 2 - using var store = new InMemoryMcpTaskStore(maxTasks: 5, maxTasksPerSession: 2); - var metadata = new McpTaskMetadata(); - - // Create 2 tasks for session-1 (hits per-session limit) - await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, "session-1", TestContext.Current.CancellationToken); - await store.CreateTaskAsync(metadata, new RequestId("req-2"), new JsonRpcRequest { Method = "test" }, "session-1", TestContext.Current.CancellationToken); - - // session-1 is at its limit - await Assert.ThrowsAsync(() => - store.CreateTaskAsync(metadata, new RequestId("req-3"), new JsonRpcRequest { Method = "test" }, "session-1", TestContext.Current.CancellationToken)); + [$"req{i}"] = JsonSerializer.SerializeToElement($"value{i}", McpJsonUtilities.DefaultOptions) + }, CT)); - // But session-2 can still create tasks - await store.CreateTaskAsync(metadata, new RequestId("req-4"), new JsonRpcRequest { Method = "test" }, "session-2", TestContext.Current.CancellationToken); - await store.CreateTaskAsync(metadata, new RequestId("req-5"), new JsonRpcRequest { Method = "test" }, "session-2", TestContext.Current.CancellationToken); + await Task.WhenAll(tasks); - // Now global limit is reached (4 tasks total, but 5th would be 5) - // Wait, we have 4 tasks, should be able to create one more - await store.CreateTaskAsync(metadata, new RequestId("req-6"), new JsonRpcRequest { Method = "test" }, "session-3", TestContext.Current.CancellationToken); - - // Now at 5 tasks (global limit), should throw - var ex = await Assert.ThrowsAsync(() => - store.CreateTaskAsync(metadata, new RequestId("req-7"), new JsonRpcRequest { Method = "test" }, "session-3", TestContext.Current.CancellationToken)); - Assert.Contains("Maximum number of tasks (5) has been reached", ex.Message); - } - - [Fact] - public async Task CreateTaskAsync_MaxTasksPerSession_ExcludesExpiredTasks() - { - // Arrange - Use FakeTimeProvider for deterministic testing - var fakeTime = new FakeTimeProvider(DateTimeOffset.UtcNow); - var shortTtl = TimeSpan.FromMilliseconds(50); - using var store = new TestInMemoryMcpTaskStore( - defaultTtl: shortTtl, - maxTtl: null, - pollInterval: null, - cleanupInterval: Timeout.InfiniteTimeSpan, - pageSize: 100, - maxTasks: null, - maxTasksPerSession: 1, - timeProvider: fakeTime); - - var metadata = new McpTaskMetadata(); - - // Create first task - await store.CreateTaskAsync(metadata, new RequestId("req-1"), new JsonRpcRequest { Method = "test" }, "session-1", TestContext.Current.CancellationToken); - - // Advance time past TTL to make the first task expire - fakeTime.Advance(shortTtl + TimeSpan.FromMilliseconds(1)); - - // Should be able to create another task since the first one expired - var task2 = await store.CreateTaskAsync(metadata, new RequestId("req-2"), new JsonRpcRequest { Method = "test" }, "session-1", TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(task2); - } - - [Fact] - public async Task ListTasksAsync_KeysetPaginationWorksWithIdenticalTimestamps() - { - // Arrange - Use a fake time provider to create tasks with identical timestamps - var fakeTime = new FakeTimeProvider(DateTimeOffset.UtcNow); - using var store = new TestInMemoryMcpTaskStore( - defaultTtl: null, - maxTtl: null, - pollInterval: null, - cleanupInterval: Timeout.InfiniteTimeSpan, - pageSize: 5, - maxTasks: null, - maxTasksPerSession: null, - timeProvider: fakeTime); - - // Create 10 tasks - all with the EXACT same timestamp - var createdTasks = new List(); - for (int i = 0; i < 10; i++) - { - var task = await store.CreateTaskAsync( - new McpTaskMetadata(), - new RequestId($"req-{i}"), - new JsonRpcRequest { Method = "test" }, - null, - TestContext.Current.CancellationToken); - createdTasks.Add(task); - } - - // Verify all tasks have the same CreatedAt timestamp - var firstTimestamp = createdTasks[0].CreatedAt; - Assert.All(createdTasks, task => Assert.Equal(firstTimestamp, task.CreatedAt)); - - // Act - Get first page - var result1 = await store.ListTasksAsync(cancellationToken: TestContext.Current.CancellationToken); - - // Assert - First page should have 5 tasks - Assert.Equal(5, result1.Tasks.Count); - Assert.NotNull(result1.NextCursor); - - // Get second page using cursor - var result2 = await store.ListTasksAsync(cursor: result1.NextCursor, cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Second page should have 5 tasks - Assert.Equal(5, result2.Tasks.Count); - Assert.Null(result2.NextCursor); // No more pages - - // Verify no overlap between pages - var page1Ids = result1.Tasks.Select(t => t.TaskId).ToHashSet(); - var page2Ids = result2.Tasks.Select(t => t.TaskId).ToHashSet(); - Assert.Empty(page1Ids.Intersect(page2Ids)); - - // Verify we got all 10 tasks exactly once - var allReturnedIds = page1Ids.Union(page2Ids).ToHashSet(); - var allCreatedIds = createdTasks.Select(t => t.TaskId).ToHashSet(); - Assert.Equal(allCreatedIds, allReturnedIds); + var task = await store.GetTaskAsync(created.TaskId, CT); + Assert.NotNull(task); + Assert.Equal(McpTaskStatus.InputRequired, task.Status); + Assert.NotNull(task.InputRequests); + Assert.Equal(50, task.InputRequests.Count); } [Fact] - public async Task ListTasksAsync_TasksCreatedAfterFirstPageWithSameTimestampAppearInSecondPage() + public async Task ResolveInputRequestsAsync_ForExtraKeys_DoesNotThrow() { - // Arrange - Use a fake time provider so we can control timestamps precisely - var fakeTime = new FakeTimeProvider(DateTimeOffset.UtcNow); - using var store = new TestInMemoryMcpTaskStore( - defaultTtl: null, - maxTtl: null, - pollInterval: null, - cleanupInterval: Timeout.InfiniteTimeSpan, - pageSize: 5, - maxTasks: null, - maxTasksPerSession: null, - timeProvider: fakeTime); - - // Create initial 6 tasks - all with the same timestamp - // (6 so that first page has 5 and cursor points to task 5) - var initialTasks = new List(); - for (int i = 0; i < 6; i++) - { - var task = await store.CreateTaskAsync( - new McpTaskMetadata(), - new RequestId($"req-initial-{i}"), - new JsonRpcRequest { Method = "test" }, - null, - TestContext.Current.CancellationToken); - initialTasks.Add(task); - } + var store = new InMemoryMcpTaskStore(); + var created = await store.CreateTaskAsync(CT); - // Get first page - should have 5 tasks with a cursor - var result1 = await store.ListTasksAsync(cancellationToken: TestContext.Current.CancellationToken); - Assert.Equal(5, result1.Tasks.Count); - Assert.NotNull(result1.NextCursor); + await store.ResolveInputRequestsAsync(created.TaskId, ["unknown-key"], CT); - // Now create 5 more tasks AFTER we got the first page cursor - // These tasks have the SAME timestamp as the cursor (time hasn't moved) - // Due to monotonic UUID v7 with counter, they should sort AFTER the cursor - var laterTasks = new List(); - for (int i = 0; i < 5; i++) - { - var task = await store.CreateTaskAsync( - new McpTaskMetadata(), - new RequestId($"req-later-{i}"), - new JsonRpcRequest { Method = "test" }, - null, - TestContext.Current.CancellationToken); - laterTasks.Add(task); - } - - // Verify all tasks have the same timestamp - var allTasks = initialTasks.Concat(laterTasks).ToList(); - var firstTimestamp = allTasks[0].CreatedAt; - Assert.All(allTasks, task => Assert.Equal(firstTimestamp, task.CreatedAt)); - - // Get ALL remaining pages - var allSubsequentTasks = new List(); - string? cursor = result1.NextCursor; - while (cursor != null) - { - var result = await store.ListTasksAsync(cursor: cursor, cancellationToken: TestContext.Current.CancellationToken); - allSubsequentTasks.AddRange(result.Tasks); - cursor = result.NextCursor; - } - - // Verify no overlap between first page and subsequent - var page1Ids = result1.Tasks.Select(t => t.TaskId).ToHashSet(); - var subsequentIds = allSubsequentTasks.Select(t => t.TaskId).ToHashSet(); - Assert.Empty(page1Ids.Intersect(subsequentIds)); - - // Verify we got all tasks - var allReturnedIds = page1Ids.Union(subsequentIds).ToHashSet(); - var allCreatedIds = allTasks.Select(t => t.TaskId).ToHashSet(); - Assert.Equal(allCreatedIds, allReturnedIds); - - // Most importantly: verify ALL the later tasks (created after first page) are surfaced - // in the subsequent pages - var laterTaskIds = laterTasks.Select(t => t.TaskId).ToHashSet(); - Assert.Superset(laterTaskIds, subsequentIds); + var task = await store.GetTaskAsync(created.TaskId, CT); + Assert.NotNull(task); + Assert.Equal(McpTaskStatus.Working, task.Status); } } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTaskAugmentedValidationTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTaskAugmentedValidationTests.cs deleted file mode 100644 index 4c045cb21..000000000 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTaskAugmentedValidationTests.cs +++ /dev/null @@ -1,1012 +0,0 @@ -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; -using ModelContextProtocol.Client; -using ModelContextProtocol.Protocol; -using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Utils; -using System.Text.Json; - -namespace ModelContextProtocol.Tests.Server; - -/// -/// Tests for validation of task-augmented tool call requests. -/// -public class McpServerTaskAugmentedValidationTests : LoggedTest -{ - public McpServerTaskAugmentedValidationTests(ITestOutputHelper outputHelper) - : base(outputHelper) - { - } - - private static IDictionary CreateArguments(string key, object? value) - { - return new Dictionary - { - [key] = JsonDocument.Parse($"\"{value}\"").RootElement.Clone() - }; - } - - [Fact] - public async Task CallToolAsTask_ThrowsError_WhenNoTaskStoreConfigured() - { - // Arrange - Server WITHOUT task store, but with an async tool (auto-marked as taskSupport: optional) - await using var fixture = new ServerClientFixture(LoggerFactory, configureServer: (services, builder) => - { - // Note: NOT configuring a task store - builder.WithTools([McpServerTool.Create( - async (string input, CancellationToken ct) => - { - await Task.Delay(10, ct); - return $"Result: {input}"; - }, - new McpServerToolCreateOptions - { - Name = "async-tool", - Description = "An async tool" - })]); - }); - - await using var client = await fixture.CreateClientAsync(TestContext.Current.CancellationToken); - - // Act & Assert - Calling with task metadata should fail - var exception = await Assert.ThrowsAsync(async () => - await client.CallToolAsync( - new CallToolRequestParams - { - Name = "async-tool", - Arguments = CreateArguments("input", "test"), - Task = new McpTaskMetadata() - }, - TestContext.Current.CancellationToken)); - - Assert.Contains("not supported", exception.Message, StringComparison.OrdinalIgnoreCase); - } - - [Fact] - public async Task CallToolAsTask_ThrowsError_WhenToolHasForbiddenTaskSupport() - { - // Arrange - Server with task store, but tool has taskSupport: forbidden (sync tool) - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ServerClientFixture(LoggerFactory, configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => - { - options.TaskStore = taskStore; - }); - - // Create a synchronous tool - which will have taskSupport: forbidden (default) - builder.WithTools([McpServerTool.Create( - (string input) => $"Result: {input}", - new McpServerToolCreateOptions - { - Name = "sync-tool", - Description = "A synchronous tool that does not support tasks" - })]); - }); - - await using var client = await fixture.CreateClientAsync(TestContext.Current.CancellationToken); - - // Act & Assert - Calling with task metadata should fail because tool doesn't support it - var exception = await Assert.ThrowsAsync(async () => - await client.CallToolAsync( - new CallToolRequestParams - { - Name = "sync-tool", - Arguments = CreateArguments("input", "test"), - Task = new McpTaskMetadata() - }, - TestContext.Current.CancellationToken)); - - Assert.Contains("does not support task-augmented execution", exception.Message, StringComparison.OrdinalIgnoreCase); - Assert.Equal(McpErrorCode.InvalidParams, exception.ErrorCode); - } - - [Fact] - public async Task CallToolAsTask_Succeeds_WhenToolHasOptionalTaskSupport() - { - // Arrange - Server with task store and async tool (auto-marked as taskSupport: optional) - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ServerClientFixture(LoggerFactory, configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => - { - options.TaskStore = taskStore; - }); - - builder.WithTools([McpServerTool.Create( - async (string input, CancellationToken ct) => - { - await Task.Delay(10, ct); - return $"Result: {input}"; - }, - new McpServerToolCreateOptions - { - Name = "async-tool", - Description = "An async tool with optional task support" - })]); - }); - - await using var client = await fixture.CreateClientAsync(TestContext.Current.CancellationToken); - - // Act - Calling with task metadata should succeed - var result = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "async-tool", - Arguments = CreateArguments("input", "test"), - Task = new McpTaskMetadata() - }, - TestContext.Current.CancellationToken); - - // Assert - Should return a task - Assert.NotNull(result.Task); - Assert.NotNull(result.Task.TaskId); - } - - [Fact] - public async Task CallToolNormally_Succeeds_WhenToolHasForbiddenTaskSupport() - { - // Arrange - Server with task store, but calling without task metadata - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ServerClientFixture(LoggerFactory, configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => - { - options.TaskStore = taskStore; - }); - - builder.WithTools([McpServerTool.Create( - (string input) => $"Result: {input}", - new McpServerToolCreateOptions - { - Name = "sync-tool", - Description = "A synchronous tool" - })]); - }); - - await using var client = await fixture.CreateClientAsync(TestContext.Current.CancellationToken); - - // Act - Calling WITHOUT task metadata should succeed - var result = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "sync-tool", - Arguments = CreateArguments("input", "test"), - }, - TestContext.Current.CancellationToken); - - // Assert - Should return normal result - Assert.NotNull(result.Content); - Assert.Null(result.Task); - } - - [Fact] - public async Task CallToolNormally_ThrowsError_WhenToolHasRequiredTaskSupport() - { - // Arrange - Server with task store and tool with taskSupport: required - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ServerClientFixture(LoggerFactory, configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => - { - options.TaskStore = taskStore; - }); - - builder.WithTools([McpServerTool.Create( - async (string input, CancellationToken ct) => - { - await Task.Delay(100, ct); - return $"Result: {input}"; - }, - new McpServerToolCreateOptions - { - Name = "required-task-tool", - Description = "A tool that requires task-augmented execution", - Execution = new ToolExecution { TaskSupport = ToolTaskSupport.Required } - })]); - }); - - await using var client = await fixture.CreateClientAsync(TestContext.Current.CancellationToken); - - // Act & Assert - Calling WITHOUT task metadata should fail - var exception = await Assert.ThrowsAsync(async () => - await client.CallToolAsync( - new CallToolRequestParams - { - Name = "required-task-tool", - Arguments = CreateArguments("input", "test"), - }, - TestContext.Current.CancellationToken)); - - Assert.Contains("requires task-augmented execution", exception.Message, StringComparison.OrdinalIgnoreCase); - Assert.Equal(McpErrorCode.InvalidParams, exception.ErrorCode); - } - - [Fact] - public async Task CallToolAsTask_Succeeds_WhenToolHasRequiredTaskSupport() - { - // Arrange - Server with task store and tool with taskSupport: required - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ServerClientFixture(LoggerFactory, configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => - { - options.TaskStore = taskStore; - }); - - builder.WithTools([McpServerTool.Create( - async (string input, CancellationToken ct) => - { - await Task.Delay(10, ct); - return $"Result: {input}"; - }, - new McpServerToolCreateOptions - { - Name = "required-task-tool", - Description = "A tool that requires task-augmented execution", - Execution = new ToolExecution { TaskSupport = ToolTaskSupport.Required } - })]); - }); - - await using var client = await fixture.CreateClientAsync(TestContext.Current.CancellationToken); - - // Act - Calling WITH task metadata should succeed - var result = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "required-task-tool", - Arguments = CreateArguments("input", "test"), - Task = new McpTaskMetadata() - }, - TestContext.Current.CancellationToken); - - // Assert - Should return a task - Assert.NotNull(result.Task); - Assert.NotNull(result.Task.TaskId); - } - - [Fact] - public async Task CallToolAsTask_WithRequiredTaskSupport_CanResolveScopedServicesFromDI() - { - // Regression test for https://github.com/modelcontextprotocol/csharp-sdk/issues/1430: - // ExecuteToolAsTaskAsync fires Task.Run and returns immediately, so the request-scoped - // IServiceProvider owned by InvokeHandlerAsync is disposed before the background task - // calls tool.InvokeAsync. The fix creates a fresh scope inside the Task.Run body so the - // tool can resolve DI services without hitting ObjectDisposedException. - var taskStore = new InMemoryMcpTaskStore(); - string? capturedValue = null; - - await using var fixture = new ServerClientFixture(LoggerFactory, configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => options.TaskStore = taskStore); - - // Register a scoped service; resolving it through a disposed scope was the bug. - services.AddScoped(); - - // Register the tool via the factory pattern so that Services = sp is threaded - // through, enabling DI parameter binding at tool-creation time. - builder.Services.AddSingleton(sp => McpServerTool.Create( - async (ITaskToolDiService svc, CancellationToken ct) => - { - await Task.Delay(10, ct); - capturedValue = svc.GetValue(); - return capturedValue; - }, - new McpServerToolCreateOptions - { - Name = "di-required-task-tool", - Services = sp, - Execution = new ToolExecution { TaskSupport = ToolTaskSupport.Required } - })); - }); - - await using var client = await fixture.CreateClientAsync(TestContext.Current.CancellationToken); - - var result = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "di-required-task-tool", - Task = new McpTaskMetadata() - }, - TestContext.Current.CancellationToken); - - Assert.NotNull(result.Task); - string taskId = result.Task.TaskId; - - // Poll until the background task reaches a terminal state. - McpTask taskStatus; - int attempts = 0; - do - { - await Task.Delay(50, TestContext.Current.CancellationToken); - taskStatus = await client.GetTaskAsync(taskId, cancellationToken: TestContext.Current.CancellationToken); - attempts++; - } - while (taskStatus.Status == McpTaskStatus.Working && attempts < 50); - - // Without the fix, the background task would fail with ObjectDisposedException when - // resolving ITaskToolDiService, causing the task to reach McpTaskStatus.Failed. - Assert.Equal(McpTaskStatus.Completed, taskStatus.Status); - Assert.Equal("hello-from-di", capturedValue); - } - - [Fact] - public async Task CallToolAsTaskAsync_WithProgress_CreatesTaskSuccessfully() - { - // Arrange - Server with task store and a tool that reports progress - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ServerClientFixture(LoggerFactory, configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => - { - options.TaskStore = taskStore; - }); - - builder.WithTools([McpServerTool.Create( - async (IProgress progress, CancellationToken ct) => - { - // Report progress - progress.Report(new ProgressNotificationValue - { - Progress = 50, - Total = 100, - Message = "Halfway done" - }); - await Task.Delay(10, ct); - return "Completed with progress"; - }, - new McpServerToolCreateOptions - { - Name = "progress-task-tool", - Description = "A tool that reports progress during task execution" - })]); - }); - - await using var client = await fixture.CreateClientAsync(TestContext.Current.CancellationToken); - - // Track progress notifications received by client - var receivedProgressValues = new List(); - IProgress progress = new SynchronousProgress(value => - { - lock (receivedProgressValues) - { - receivedProgressValues.Add(value); - } - }); - - // Act - Call tool as task with progress tracking - var mcpTask = await client.CallToolAsTaskAsync( - "progress-task-tool", - arguments: null, - taskMetadata: new McpTaskMetadata(), - progress: progress, - cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Task was created successfully - Assert.NotNull(mcpTask); - Assert.NotEmpty(mcpTask.TaskId); - - // Note: Progress notifications may not be received for task-augmented calls - // because the notification handler is disposed when the task creation response returns. - // This test verifies the code path executes without errors. - } - - [Fact] - public async Task CallToolAsTaskAsync_WithoutProgress_DoesNotRequireProgressHandler() - { - // Arrange - Server with task store and a tool that reports progress - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ServerClientFixture(LoggerFactory, configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => - { - options.TaskStore = taskStore; - }); - - builder.WithTools([McpServerTool.Create( - async (IProgress progress, CancellationToken ct) => - { - // Tool reports progress but client doesn't listen - progress.Report(new ProgressNotificationValue { Progress = 50, Message = "Halfway" }); - await Task.Delay(10, ct); - return "Done"; - }, - new McpServerToolCreateOptions - { - Name = "progress-tool", - Description = "A tool that reports progress" - })]); - }); - - await using var client = await fixture.CreateClientAsync(TestContext.Current.CancellationToken); - - // Act - Call tool as task WITHOUT progress tracking (progress: null) - var mcpTask = await client.CallToolAsTaskAsync( - "progress-tool", - arguments: null, - taskMetadata: new McpTaskMetadata(), - progress: null, // No progress handler - cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Task was still created successfully - Assert.NotNull(mcpTask); - Assert.NotEmpty(mcpTask.TaskId); - } - - private sealed class SynchronousProgress(Action callback) : IProgress - { - public void Report(ProgressNotificationValue value) => callback(value); - } - - #region Error Code Tests for Invalid/Nonexistent TaskId - - [Fact] - public async Task GetTaskAsync_WithNonexistentTaskId_ReturnsInvalidParamsError() - { - // Arrange - Spec: "Invalid or nonexistent taskId in tasks/get: -32602 (Invalid params)" - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ServerClientFixture(LoggerFactory, configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => options.TaskStore = taskStore); - builder.WithTools([McpServerTool.Create( - async (CancellationToken ct) => { await Task.Delay(10, ct); return "ok"; }, - new McpServerToolCreateOptions { Name = "test-tool" })]); - }); - - await using var client = await fixture.CreateClientAsync(TestContext.Current.CancellationToken); - - // Act & Assert - var exception = await Assert.ThrowsAsync(async () => - await client.GetTaskAsync("nonexistent-task-id-12345", cancellationToken: TestContext.Current.CancellationToken)); - - Assert.Equal(McpErrorCode.InvalidParams, exception.ErrorCode); - Assert.Contains("not found", exception.Message, StringComparison.OrdinalIgnoreCase); - } - - [Fact] - public async Task GetTaskResultAsync_WithNonexistentTaskId_ReturnsInvalidParamsError() - { - // Arrange - Spec: "Invalid or nonexistent taskId in tasks/result: -32602 (Invalid params)" - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ServerClientFixture(LoggerFactory, configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => options.TaskStore = taskStore); - builder.WithTools([McpServerTool.Create( - async (CancellationToken ct) => { await Task.Delay(10, ct); return "ok"; }, - new McpServerToolCreateOptions { Name = "test-tool" })]); - }); - - await using var client = await fixture.CreateClientAsync(TestContext.Current.CancellationToken); - - // Act & Assert - var exception = await Assert.ThrowsAsync(async () => - await client.GetTaskResultAsync("nonexistent-task-id-12345", cancellationToken: TestContext.Current.CancellationToken)); - - Assert.Equal(McpErrorCode.InvalidParams, exception.ErrorCode); - Assert.Contains("not found", exception.Message, StringComparison.OrdinalIgnoreCase); - } - - [Fact] - public async Task CancelTaskAsync_WithNonexistentTaskId_ReturnsError() - { - // Arrange - Spec: "Invalid or nonexistent taskId in tasks/cancel: -32602 (Invalid params)" - // NOTE: Current implementation throws InternalError; this documents actual behavior - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ServerClientFixture(LoggerFactory, configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => options.TaskStore = taskStore); - builder.WithTools([McpServerTool.Create( - async (CancellationToken ct) => { await Task.Delay(10, ct); return "ok"; }, - new McpServerToolCreateOptions { Name = "test-tool" })]); - }); - - await using var client = await fixture.CreateClientAsync(TestContext.Current.CancellationToken); - - // Act & Assert - var exception = await Assert.ThrowsAsync(async () => - await client.CancelTaskAsync("nonexistent-task-id-12345", cancellationToken: TestContext.Current.CancellationToken)); - - Assert.NotNull(exception); - } - - [Fact] - public async Task ListTasksAsync_WithInvalidCursor_HandlesGracefully() - { - // Arrange - Spec says: "Invalid or nonexistent cursor in tasks/list: -32602 (Invalid params)" - // Current implementation ignores invalid cursors gracefully - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ServerClientFixture(LoggerFactory, configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => options.TaskStore = taskStore); - builder.WithTools([McpServerTool.Create( - async (CancellationToken ct) => { await Task.Delay(10, ct); return "ok"; }, - new McpServerToolCreateOptions { Name = "test-tool" })]); - }); - - await using var client = await fixture.CreateClientAsync(TestContext.Current.CancellationToken); - - // Act - Pass invalid cursor - var result = await client.ListTasksAsync( - new ListTasksRequestParams { Cursor = "invalid-cursor-that-does-not-exist" }, - TestContext.Current.CancellationToken); - - // Assert - Should return valid (possibly empty) result - Assert.NotNull(result.Tasks); - } - - #endregion - - #region Blocking Behavior Tests - - [Fact] - public async Task GetTaskResultAsync_ReturnsImmediately_WhenTaskAlreadyComplete() - { - // Arrange - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ServerClientFixture(LoggerFactory, configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => options.TaskStore = taskStore); - builder.WithTools([McpServerTool.Create( - async (CancellationToken ct) => { await Task.Delay(10, ct); return "quick result"; }, - new McpServerToolCreateOptions { Name = "quick-tool" })]); - }); - - await using var client = await fixture.CreateClientAsync(TestContext.Current.CancellationToken); - - // Create and wait for task to complete - var callResult = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "quick-tool", - Arguments = new Dictionary(), - Task = new McpTaskMetadata() - }, - cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(callResult.Task); - string taskId = callResult.Task.TaskId; - - // Wait for task to complete - McpTask taskStatus; - do - { - await Task.Delay(50, TestContext.Current.CancellationToken); - taskStatus = await client.GetTaskAsync(taskId, cancellationToken: TestContext.Current.CancellationToken); - } - while (taskStatus.Status == McpTaskStatus.Working); - - // Act - Get result (should return since task is complete) - var result = await client.GetTaskResultAsync(taskId, cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Should get valid result - Assert.NotEqual(default, result); - } - - [Fact] - public async Task GetTaskResultAsync_ForFailedTask_ReturnsErrorResult() - { - // Arrange - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ServerClientFixture(LoggerFactory, configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => options.TaskStore = taskStore); - builder.WithTools([McpServerTool.Create( - async (CancellationToken ct) => - { - await Task.Delay(10, ct); - throw new InvalidOperationException("Tool execution failed intentionally"); -#pragma warning disable CS0162 // Unreachable code detected - return "never"; -#pragma warning restore CS0162 - }, - new McpServerToolCreateOptions { Name = "failable-tool" })]); - }); - - await using var client = await fixture.CreateClientAsync(TestContext.Current.CancellationToken); - - // Create a failing task - var callResult = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "failable-tool", - Arguments = new Dictionary(), - Task = new McpTaskMetadata() - }, - cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(callResult.Task); - string taskId = callResult.Task.TaskId; - - // Wait for task to fail - McpTask taskStatus; - int attempts = 0; - do - { - await Task.Delay(50, TestContext.Current.CancellationToken); - taskStatus = await client.GetTaskAsync(taskId, cancellationToken: TestContext.Current.CancellationToken); - attempts++; - } - while (taskStatus.Status == McpTaskStatus.Working && attempts < 50); - - Assert.Equal(McpTaskStatus.Failed, taskStatus.Status); - - // Act - Get result for failed task - var result = await client.GetTaskResultAsync(taskId, cancellationToken: TestContext.Current.CancellationToken); - var toolResult = result.Deserialize(McpJsonUtilities.DefaultOptions); - - // Assert - Failed task should have isError=true - Assert.NotNull(toolResult); - Assert.True(toolResult.IsError, "Failed task should have isError=true in the result"); - } - - #endregion - - #region Task Consistency and Lifecycle Tests - - [Fact] - public async Task ListTasksAsync_ContainsAllTasksRetrievableByGet() - { - // Arrange - Spec: "If a task is retrievable via tasks/get, it MUST be retrievable via tasks/list" - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ServerClientFixture(LoggerFactory, configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => options.TaskStore = taskStore); - builder.WithTools([McpServerTool.Create( - async (string input, CancellationToken ct) => { await Task.Delay(10, ct); return $"Result: {input}"; }, - new McpServerToolCreateOptions { Name = "test-tool" })]); - }); - - await using var client = await fixture.CreateClientAsync(TestContext.Current.CancellationToken); - - // Create multiple tasks - var createdTaskIds = new List(); - for (int i = 0; i < 3; i++) - { - var result = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "test-tool", - Arguments = new Dictionary - { - ["input"] = JsonDocument.Parse($"\"task-{i}\"").RootElement.Clone() - }, - Task = new McpTaskMetadata() - }, - cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(result.Task); - createdTaskIds.Add(result.Task.TaskId); - } - - // Verify each task is retrievable via get - foreach (var taskId in createdTaskIds) - { - var task = await client.GetTaskAsync(taskId, cancellationToken: TestContext.Current.CancellationToken); - Assert.NotNull(task); - } - - // Act - List all tasks - var allTasks = await client.ListTasksAsync(cancellationToken: TestContext.Current.CancellationToken); - - // Assert - All tasks must be in the list - foreach (var taskId in createdTaskIds) - { - Assert.Contains(allTasks, t => t.TaskId == taskId); - } - } - - [Fact] - public async Task NewTask_StartsInWorkingStatus() - { - // Arrange - Spec: "Tasks MUST begin in the working status when created." - var taskStarted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var taskCanComplete = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ServerClientFixture(LoggerFactory, configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => options.TaskStore = taskStore); - builder.WithTools([McpServerTool.Create( - async (CancellationToken ct) => - { - taskStarted.TrySetResult(true); - await taskCanComplete.Task.WaitAsync(ct); - return "done"; - }, - new McpServerToolCreateOptions { Name = "controllable-tool" })]); - }); - - await using var client = await fixture.CreateClientAsync(TestContext.Current.CancellationToken); - - // Act - Create a task - var callResult = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "controllable-tool", - Arguments = new Dictionary(), - Task = new McpTaskMetadata() - }, - cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(callResult.Task); - Assert.Equal(McpTaskStatus.Working, callResult.Task.Status); - - // Cleanup - taskCanComplete.TrySetResult(true); - } - - [Fact] - public async Task Task_ContainsRequiredTimestamps() - { - // Arrange - Spec: "Receivers MUST include createdAt and lastUpdatedAt timestamps" - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ServerClientFixture(LoggerFactory, configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => options.TaskStore = taskStore); - builder.WithTools([McpServerTool.Create( - async (CancellationToken ct) => { await Task.Delay(10, ct); return "ok"; }, - new McpServerToolCreateOptions { Name = "test-tool" })]); - }); - - await using var client = await fixture.CreateClientAsync(TestContext.Current.CancellationToken); - - var beforeCreation = DateTimeOffset.UtcNow; - - // Act - var callResult = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "test-tool", - Arguments = new Dictionary(), - Task = new McpTaskMetadata() - }, - cancellationToken: TestContext.Current.CancellationToken); - - var afterCreation = DateTimeOffset.UtcNow; - - // Assert - Assert.NotNull(callResult.Task); - Assert.NotEqual(default, callResult.Task.CreatedAt); - Assert.NotEqual(default, callResult.Task.LastUpdatedAt); - Assert.True(callResult.Task.CreatedAt >= beforeCreation.AddSeconds(-1)); - Assert.True(callResult.Task.CreatedAt <= afterCreation.AddSeconds(1)); - } - - [Fact] - public async Task Task_IncludesTtlInResponse() - { - // Arrange - Spec: "Receivers MUST include the actual ttl duration in tasks/get responses." - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ServerClientFixture(LoggerFactory, configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => options.TaskStore = taskStore); - builder.WithTools([McpServerTool.Create( - async (CancellationToken ct) => { await Task.Delay(10, ct); return "ok"; }, - new McpServerToolCreateOptions { Name = "test-tool" })]); - }); - - await using var client = await fixture.CreateClientAsync(TestContext.Current.CancellationToken); - - // Act - var callResult = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "test-tool", - Arguments = new Dictionary(), - Task = new McpTaskMetadata { TimeToLive = TimeSpan.FromMinutes(30) } - }, - cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(callResult.Task); - Assert.NotNull(callResult.Task.TimeToLive); - - var taskStatus = await client.GetTaskAsync(callResult.Task.TaskId, cancellationToken: TestContext.Current.CancellationToken); - Assert.NotNull(taskStatus.TimeToLive); - } - - [Fact] - public async Task Task_IncludesPollIntervalInResponse() - { - // Arrange - Spec: "Receivers MAY include a pollInterval value" - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ServerClientFixture(LoggerFactory, configureServer: (services, builder) => - { - services.AddSingleton(taskStore); - services.Configure(options => options.TaskStore = taskStore); - builder.WithTools([McpServerTool.Create( - async (CancellationToken ct) => { await Task.Delay(10, ct); return "ok"; }, - new McpServerToolCreateOptions { Name = "test-tool" })]); - }); - - await using var client = await fixture.CreateClientAsync(TestContext.Current.CancellationToken); - - // Act - var callResult = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "test-tool", - Arguments = new Dictionary(), - Task = new McpTaskMetadata() - }, - cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(callResult.Task); - Assert.NotNull(callResult.Task.PollInterval); - } - - #endregion - - #region Server Without Tasks Capability Tests - - [Fact] - public async Task ServerCapabilities_DoNotIncludeTasks_WhenNoTaskStore() - { - // Arrange - Spec: "If capabilities.tasks is not defined, the peer SHOULD NOT attempt to create tasks" - await using var fixture = new ServerClientFixture(LoggerFactory, configureServer: (services, builder) => - { - // NOT configuring a task store - builder.WithTools([McpServerTool.Create( - async (CancellationToken ct) => { await Task.Delay(10, ct); return "ok"; }, - new McpServerToolCreateOptions { Name = "async-tool" })]); - }); - - await using var client = await fixture.CreateClientAsync(TestContext.Current.CancellationToken); - - // Assert - Assert.Null(client.ServerCapabilities?.Tasks); - } - - [Fact] - public async Task NormalRequest_Succeeds_WhenTasksNotSupported() - { - // Arrange - Normal requests should work without task support - await using var fixture = new ServerClientFixture(LoggerFactory, configureServer: (services, builder) => - { - builder.WithTools([McpServerTool.Create( - (string input) => $"Sync result: {input}", - new McpServerToolCreateOptions { Name = "sync-tool" })]); - }); - - await using var client = await fixture.CreateClientAsync(TestContext.Current.CancellationToken); - - // Act - var result = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "sync-tool", - Arguments = CreateArguments("input", "test") - }, - cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(result.Content); - Assert.Null(result.Task); - } - - #endregion - - private interface ITaskToolDiService - { - string GetValue(); - } - - private sealed class TaskToolDiService : ITaskToolDiService - { - public string GetValue() => "hello-from-di"; - } - - /// - /// Helper fixture for creating server-client pairs with custom configuration. - /// - private sealed class ServerClientFixture : IAsyncDisposable - { - private readonly System.IO.Pipelines.Pipe _clientToServerPipe = new(); - private readonly System.IO.Pipelines.Pipe _serverToClientPipe = new(); - private readonly IServiceProvider _serviceProvider; - private readonly McpServer _server; - private readonly Task _serverTask; - private readonly CancellationTokenSource _cts; - private readonly ILoggerFactory _loggerFactory; - - public ServerClientFixture( - ILoggerFactory loggerFactory, - Action? configureServer = null) - { - _loggerFactory = loggerFactory; - _cts = new CancellationTokenSource(); - - var services = new ServiceCollection(); - services.AddLogging(); - services.AddSingleton(loggerFactory); - - var builder = services - .AddMcpServer() - .WithStreamServerTransport( - _clientToServerPipe.Reader.AsStream(), - _serverToClientPipe.Writer.AsStream()); - - configureServer?.Invoke(services, builder); - - _serviceProvider = services.BuildServiceProvider(validateScopes: true); - _server = _serviceProvider.GetRequiredService(); - _serverTask = _server.RunAsync(_cts.Token); - } - - public async Task CreateClientAsync(CancellationToken cancellationToken) - { - return await McpClient.CreateAsync( - new StreamClientTransport( - serverInput: _clientToServerPipe.Writer.AsStream(), - _serverToClientPipe.Reader.AsStream(), - _loggerFactory), - loggerFactory: _loggerFactory, - cancellationToken: cancellationToken); - } - - public async ValueTask DisposeAsync() - { - await _cts.CancelAsync(); - - _clientToServerPipe.Writer.Complete(); - _serverToClientPipe.Writer.Complete(); - - try - { - await _serverTask; - } - catch (OperationCanceledException) - { - // Expected - } - - if (_serviceProvider is IAsyncDisposable asyncDisposable) - { - await asyncDisposable.DisposeAsync(); - } - else if (_serviceProvider is IDisposable disposable) - { - disposable.Dispose(); - } - - _cts.Dispose(); - } - } -} diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTaskMethodsTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTaskMethodsTests.cs deleted file mode 100644 index d908bbb7f..000000000 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTaskMethodsTests.cs +++ /dev/null @@ -1,762 +0,0 @@ -using ModelContextProtocol.Protocol; -using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Utils; -using System.Runtime.InteropServices; -using System.Text.Json; - -namespace ModelContextProtocol.Tests.Server; - -/// -/// Tests for McpServer methods that query tasks on the client (Phase 4 implementation). -/// -public class McpServerTaskMethodsTests : LoggedTest -{ - private readonly McpServerOptions _options; - - public McpServerTaskMethodsTests(ITestOutputHelper testOutputHelper) - : base(testOutputHelper) - { -#if !NET - Assert.SkipWhen(RuntimeInformation.IsOSPlatform(OSPlatform.Windows), "https://github.com/modelcontextprotocol/csharp-sdk/issues/587"); -#endif - _options = CreateOptions(); - } - - private static McpServerOptions CreateOptions(ServerCapabilities? capabilities = null) - { - return new McpServerOptions - { - ProtocolVersion = "2024", - InitializationTimeout = TimeSpan.FromSeconds(30), - Capabilities = capabilities, - }; - } - - #region SampleAsTaskAsync Tests - - [Fact] - public async Task SampleAsTaskAsync_ThrowsException_WhenClientDoesNotSupportSampling() - { - // Arrange - await using var transport = new TestServerTransport(); - await using var server = McpServer.Create(transport, _options, LoggerFactory); - var runTask = server.RunAsync(TestContext.Current.CancellationToken); - await InitializeServerAsync(transport, new ClientCapabilities(), TestContext.Current.CancellationToken); - - // Act & Assert - await Assert.ThrowsAsync(async () => - await server.SampleAsTaskAsync( - new CreateMessageRequestParams { Messages = [], MaxTokens = 1000 }, - new McpTaskMetadata(), - CancellationToken.None)); - - await transport.DisposeAsync(); - await runTask; - } - - [Fact] - public async Task SampleAsTaskAsync_ThrowsException_WhenClientDoesNotSupportTaskAugmentedSampling() - { - // Arrange - Client supports sampling but NOT task-augmented sampling - await using var transport = new TestServerTransport(); - await using var server = McpServer.Create(transport, _options, LoggerFactory); - var runTask = server.RunAsync(TestContext.Current.CancellationToken); - await InitializeServerAsync(transport, new ClientCapabilities - { - Sampling = new SamplingCapability(), - // Note: No Tasks capability - }, TestContext.Current.CancellationToken); - - // Act & Assert - await Assert.ThrowsAsync(async () => - await server.SampleAsTaskAsync( - new CreateMessageRequestParams { Messages = [], MaxTokens = 1000 }, - new McpTaskMetadata(), - CancellationToken.None)); - - await transport.DisposeAsync(); - await runTask; - } - - [Fact] - public async Task SampleAsTaskAsync_SendsRequest_WhenClientSupportsTaskAugmentedSampling() - { - // Arrange - await using var transport = new TestServerTransport(); - - // Configure transport to return a task result for sampling - transport.MockTask = new McpTask - { - TaskId = "sample-task-123", - Status = McpTaskStatus.Working, - CreatedAt = DateTimeOffset.UtcNow, - LastUpdatedAt = DateTimeOffset.UtcNow, - }; - - await using var server = McpServer.Create(transport, _options, LoggerFactory); - var runTask = server.RunAsync(TestContext.Current.CancellationToken); - await InitializeServerAsync(transport, new ClientCapabilities - { - Sampling = new SamplingCapability(), - Tasks = new McpTasksCapability - { - Requests = new RequestMcpTasksCapability - { - Sampling = new SamplingMcpTasksCapability - { - CreateMessage = new CreateMessageMcpTasksCapability() - } - } - } - }, TestContext.Current.CancellationToken); - - // Act - var task = await server.SampleAsTaskAsync( - new CreateMessageRequestParams { Messages = [], MaxTokens = 1000 }, - new McpTaskMetadata { TimeToLive = TimeSpan.FromMinutes(5) }, - TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(task); - Assert.Equal("sample-task-123", task.TaskId); - Assert.Equal(McpTaskStatus.Working, task.Status); - - // Verify the request was sent with task metadata - var samplingRequest = transport.SentMessages.OfType() - .FirstOrDefault(r => r.Method == RequestMethods.SamplingCreateMessage); - Assert.NotNull(samplingRequest); - var requestParams = JsonSerializer.Deserialize( - samplingRequest.Params, McpJsonUtilities.DefaultOptions); - Assert.NotNull(requestParams?.Task); - - await transport.DisposeAsync(); - await runTask; - } - - #endregion - - #region ElicitAsTaskAsync Tests - - [Fact] - public async Task ElicitAsTaskAsync_ThrowsException_WhenClientDoesNotSupportElicitation() - { - // Arrange - await using var transport = new TestServerTransport(); - await using var server = McpServer.Create(transport, _options, LoggerFactory); - var runTask = server.RunAsync(TestContext.Current.CancellationToken); - await InitializeServerAsync(transport, new ClientCapabilities(), TestContext.Current.CancellationToken); - - // Act & Assert - await Assert.ThrowsAsync(async () => - await server.ElicitAsTaskAsync( - new ElicitRequestParams { Message = "test", RequestedSchema = new() }, - new McpTaskMetadata(), - CancellationToken.None)); - - await transport.DisposeAsync(); - await runTask; - } - - [Fact] - public async Task ElicitAsTaskAsync_ThrowsException_WhenClientDoesNotSupportTaskAugmentedElicitation() - { - // Arrange - Client supports elicitation but NOT task-augmented elicitation - await using var transport = new TestServerTransport(); - await using var server = McpServer.Create(transport, _options, LoggerFactory); - var runTask = server.RunAsync(TestContext.Current.CancellationToken); - await InitializeServerAsync(transport, new ClientCapabilities - { - Elicitation = new ElicitationCapability { Form = new() }, - // Note: No Tasks capability - }, TestContext.Current.CancellationToken); - - // Act & Assert - await Assert.ThrowsAsync(async () => - await server.ElicitAsTaskAsync( - new ElicitRequestParams { Message = "test", RequestedSchema = new() }, - new McpTaskMetadata(), - CancellationToken.None)); - - await transport.DisposeAsync(); - await runTask; - } - - [Fact] - public async Task ElicitAsTaskAsync_SendsRequest_WhenClientSupportsTaskAugmentedElicitation() - { - // Arrange - await using var transport = new TestServerTransport(); - - // Configure transport to return a task result for elicitation - transport.MockTask = new McpTask - { - TaskId = "elicit-task-456", - Status = McpTaskStatus.Working, - CreatedAt = DateTimeOffset.UtcNow, - LastUpdatedAt = DateTimeOffset.UtcNow, - }; - - await using var server = McpServer.Create(transport, _options, LoggerFactory); - var runTask = server.RunAsync(TestContext.Current.CancellationToken); - await InitializeServerAsync(transport, new ClientCapabilities - { - Elicitation = new ElicitationCapability { Form = new() }, - Tasks = new McpTasksCapability - { - Requests = new RequestMcpTasksCapability - { - Elicitation = new ElicitationMcpTasksCapability - { - Create = new CreateElicitationMcpTasksCapability() - } - } - } - }, TestContext.Current.CancellationToken); - - // Act - var task = await server.ElicitAsTaskAsync( - new ElicitRequestParams { Message = "Please provide input", RequestedSchema = new() }, - new McpTaskMetadata { TimeToLive = TimeSpan.FromMinutes(10) }, - TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(task); - Assert.Equal("elicit-task-456", task.TaskId); - Assert.Equal(McpTaskStatus.Working, task.Status); - - await transport.DisposeAsync(); - await runTask; - } - - #endregion - - #region GetTaskAsync Tests - - [Fact] - public async Task GetTaskAsync_ThrowsException_WhenClientDoesNotSupportTasks() - { - // Arrange - await using var transport = new TestServerTransport(); - await using var server = McpServer.Create(transport, _options, LoggerFactory); - var runTask = server.RunAsync(TestContext.Current.CancellationToken); - await InitializeServerAsync(transport, new ClientCapabilities(), TestContext.Current.CancellationToken); - - // Act & Assert - await Assert.ThrowsAsync(async () => - await server.GetTaskAsync("task-id", CancellationToken.None)); - - await transport.DisposeAsync(); - await runTask; - } - - [Fact] - public async Task GetTaskAsync_SendsRequest_AndReturnsTask() - { - // Arrange - await using var transport = new TestServerTransport(); - transport.MockTask = new McpTask - { - TaskId = "client-task-789", - Status = McpTaskStatus.Completed, - StatusMessage = "Task completed successfully", - CreatedAt = DateTimeOffset.UtcNow.AddMinutes(-5), - LastUpdatedAt = DateTimeOffset.UtcNow, - }; - - await using var server = McpServer.Create(transport, _options, LoggerFactory); - var runTask = server.RunAsync(TestContext.Current.CancellationToken); - await InitializeServerAsync(transport, new ClientCapabilities - { - Tasks = new McpTasksCapability() - }, TestContext.Current.CancellationToken); - - // Act - var task = await server.GetTaskAsync("client-task-789", TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(task); - Assert.Equal("client-task-789", task.TaskId); - Assert.Equal(McpTaskStatus.Completed, task.Status); - - // Verify the request was sent - var taskRequest = transport.SentMessages.OfType() - .FirstOrDefault(r => r.Method == RequestMethods.TasksGet); - Assert.NotNull(taskRequest); - - await transport.DisposeAsync(); - await runTask; - } - - [Fact] - public async Task GetTaskAsync_ThrowsArgumentException_WhenTaskIdIsEmpty() - { - // Arrange - await using var transport = new TestServerTransport(); - await using var server = McpServer.Create(transport, _options, LoggerFactory); - var runTask = server.RunAsync(TestContext.Current.CancellationToken); - await InitializeServerAsync(transport, new ClientCapabilities - { - Tasks = new McpTasksCapability() - }, TestContext.Current.CancellationToken); - - // Act & Assert - await Assert.ThrowsAsync(async () => - await server.GetTaskAsync("", CancellationToken.None)); - - await transport.DisposeAsync(); - await runTask; - } - - #endregion - - #region GetTaskResultAsync Tests - - [Fact] - public async Task GetTaskResultAsync_ThrowsException_WhenClientDoesNotSupportTasks() - { - // Arrange - await using var transport = new TestServerTransport(); - await using var server = McpServer.Create(transport, _options, LoggerFactory); - var runTask = server.RunAsync(TestContext.Current.CancellationToken); - await InitializeServerAsync(transport, new ClientCapabilities(), TestContext.Current.CancellationToken); - - // Act & Assert - await Assert.ThrowsAsync(async () => - await server.GetTaskResultAsync("task-id", cancellationToken: CancellationToken.None)); - - await transport.DisposeAsync(); - await runTask; - } - - [Fact] - public async Task GetTaskResultAsync_ReturnsDeserializedResult() - { - // Arrange - await using var transport = new TestServerTransport(); - transport.MockTaskResult = new CreateMessageResult - { - Content = [new TextContentBlock { Text = "Hello from task result!" }], - Model = "gpt-4" - }; - - await using var server = McpServer.Create(transport, _options, LoggerFactory); - var runTask = server.RunAsync(TestContext.Current.CancellationToken); - await InitializeServerAsync(transport, new ClientCapabilities - { - Tasks = new McpTasksCapability() - }, TestContext.Current.CancellationToken); - - // Act - var result = await server.GetTaskResultAsync( - "task-id", cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(result); - Assert.Equal("gpt-4", result.Model); - Assert.Single(result.Content); - var textContent = Assert.IsType(result.Content[0]); - Assert.Equal("Hello from task result!", textContent.Text); - - await transport.DisposeAsync(); - await runTask; - } - - #endregion - - #region ListTasksAsync Tests - - [Fact] - public async Task ListTasksAsync_ThrowsException_WhenClientDoesNotSupportTasks() - { - // Arrange - await using var transport = new TestServerTransport(); - await using var server = McpServer.Create(transport, _options, LoggerFactory); - var runTask = server.RunAsync(TestContext.Current.CancellationToken); - await InitializeServerAsync(transport, new ClientCapabilities(), TestContext.Current.CancellationToken); - - // Act & Assert - await Assert.ThrowsAsync(async () => - await server.ListTasksAsync(CancellationToken.None)); - - await transport.DisposeAsync(); - await runTask; - } - - [Fact] - public async Task ListTasksAsync_ThrowsException_WhenClientDoesNotSupportTaskListing() - { - // Arrange - Client supports tasks but NOT task listing - await using var transport = new TestServerTransport(); - await using var server = McpServer.Create(transport, _options, LoggerFactory); - var runTask = server.RunAsync(TestContext.Current.CancellationToken); - await InitializeServerAsync(transport, new ClientCapabilities - { - Tasks = new McpTasksCapability - { - // Note: No List capability - } - }, TestContext.Current.CancellationToken); - - // Act & Assert - await Assert.ThrowsAsync(async () => - await server.ListTasksAsync(CancellationToken.None)); - - await transport.DisposeAsync(); - await runTask; - } - - [Fact] - public async Task ListTasksAsync_ReturnsTaskList() - { - // Arrange - await using var transport = new TestServerTransport(); - transport.MockTaskList = - [ - new McpTask - { - TaskId = "task-a", - Status = McpTaskStatus.Completed, - CreatedAt = DateTimeOffset.UtcNow.AddMinutes(-10), - LastUpdatedAt = DateTimeOffset.UtcNow, - }, - new McpTask - { - TaskId = "task-b", - Status = McpTaskStatus.Working, - CreatedAt = DateTimeOffset.UtcNow.AddMinutes(-5), - LastUpdatedAt = DateTimeOffset.UtcNow, - }, - new McpTask - { - TaskId = "task-c", - Status = McpTaskStatus.Failed, - StatusMessage = "Task failed", - CreatedAt = DateTimeOffset.UtcNow.AddMinutes(-2), - LastUpdatedAt = DateTimeOffset.UtcNow, - } - ]; - - await using var server = McpServer.Create(transport, _options, LoggerFactory); - var runTask = server.RunAsync(TestContext.Current.CancellationToken); - await InitializeServerAsync(transport, new ClientCapabilities - { - Tasks = new McpTasksCapability - { - List = new ListMcpTasksCapability() - } - }, TestContext.Current.CancellationToken); - - // Act - var tasks = await server.ListTasksAsync(TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(tasks); - Assert.Equal(3, tasks.Count); - Assert.Equal("task-a", tasks[0].TaskId); - Assert.Equal("task-b", tasks[1].TaskId); - Assert.Equal("task-c", tasks[2].TaskId); - - await transport.DisposeAsync(); - await runTask; - } - - #endregion - - #region CancelTaskAsync Tests - - [Fact] - public async Task CancelTaskAsync_ThrowsException_WhenClientDoesNotSupportTasks() - { - // Arrange - await using var transport = new TestServerTransport(); - await using var server = McpServer.Create(transport, _options, LoggerFactory); - var runTask = server.RunAsync(TestContext.Current.CancellationToken); - await InitializeServerAsync(transport, new ClientCapabilities(), TestContext.Current.CancellationToken); - - // Act & Assert - await Assert.ThrowsAsync(async () => - await server.CancelTaskAsync("task-id", CancellationToken.None)); - - await transport.DisposeAsync(); - await runTask; - } - - [Fact] - public async Task CancelTaskAsync_ThrowsException_WhenClientDoesNotSupportTaskCancellation() - { - // Arrange - Client supports tasks but NOT task cancellation - await using var transport = new TestServerTransport(); - await using var server = McpServer.Create(transport, _options, LoggerFactory); - var runTask = server.RunAsync(TestContext.Current.CancellationToken); - await InitializeServerAsync(transport, new ClientCapabilities - { - Tasks = new McpTasksCapability - { - // Note: No Cancel capability - } - }, TestContext.Current.CancellationToken); - - // Act & Assert - await Assert.ThrowsAsync(async () => - await server.CancelTaskAsync("task-id", CancellationToken.None)); - - await transport.DisposeAsync(); - await runTask; - } - - [Fact] - public async Task CancelTaskAsync_SendsRequest_AndReturnsCancelledTask() - { - // Arrange - await using var transport = new TestServerTransport(); - transport.MockTask = new McpTask - { - TaskId = "task-to-cancel", - Status = McpTaskStatus.Working, - CreatedAt = DateTimeOffset.UtcNow.AddMinutes(-3), - LastUpdatedAt = DateTimeOffset.UtcNow, - }; - - await using var server = McpServer.Create(transport, _options, LoggerFactory); - var runTask = server.RunAsync(TestContext.Current.CancellationToken); - await InitializeServerAsync(transport, new ClientCapabilities - { - Tasks = new McpTasksCapability - { - Cancel = new CancelMcpTasksCapability() - } - }, TestContext.Current.CancellationToken); - - // Act - var task = await server.CancelTaskAsync("task-to-cancel", TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(task); - Assert.Equal("task-to-cancel", task.TaskId); - Assert.Equal(McpTaskStatus.Cancelled, task.Status); - - // Verify the request was sent - var cancelRequest = transport.SentMessages.OfType() - .FirstOrDefault(r => r.Method == RequestMethods.TasksCancel); - Assert.NotNull(cancelRequest); - - await transport.DisposeAsync(); - await runTask; - } - - #endregion - - #region PollTaskUntilCompleteAsync Tests - - [Fact] - public async Task PollTaskUntilCompleteAsync_ReturnsImmediately_WhenTaskIsAlreadyComplete() - { - // Arrange - await using var transport = new TestServerTransport(); - transport.MockTask = new McpTask - { - TaskId = "completed-task", - Status = McpTaskStatus.Completed, - CreatedAt = DateTimeOffset.UtcNow.AddMinutes(-5), - LastUpdatedAt = DateTimeOffset.UtcNow, - }; - - await using var server = McpServer.Create(transport, _options, LoggerFactory); - var runTask = server.RunAsync(TestContext.Current.CancellationToken); - await InitializeServerAsync(transport, new ClientCapabilities - { - Tasks = new McpTasksCapability() - }, TestContext.Current.CancellationToken); - - // Act - var task = await server.PollTaskUntilCompleteAsync("completed-task", TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(task); - Assert.Equal("completed-task", task.TaskId); - Assert.Equal(McpTaskStatus.Completed, task.Status); - - await transport.DisposeAsync(); - await runTask; - } - - [Fact] - public async Task PollTaskUntilCompleteAsync_ReturnsTask_WhenTaskFails() - { - // Arrange - await using var transport = new TestServerTransport(); - transport.MockTask = new McpTask - { - TaskId = "failed-task", - Status = McpTaskStatus.Failed, - StatusMessage = "Task execution failed", - CreatedAt = DateTimeOffset.UtcNow.AddMinutes(-5), - LastUpdatedAt = DateTimeOffset.UtcNow, - }; - - await using var server = McpServer.Create(transport, _options, LoggerFactory); - var runTask = server.RunAsync(TestContext.Current.CancellationToken); - await InitializeServerAsync(transport, new ClientCapabilities - { - Tasks = new McpTasksCapability() - }, TestContext.Current.CancellationToken); - - // Act - var task = await server.PollTaskUntilCompleteAsync("failed-task", TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(task); - Assert.Equal("failed-task", task.TaskId); - Assert.Equal(McpTaskStatus.Failed, task.Status); - Assert.Equal("Task execution failed", task.StatusMessage); - - await transport.DisposeAsync(); - await runTask; - } - - #endregion - - #region WaitForTaskResultAsync Tests - - [Fact] - public async Task WaitForTaskResultAsync_ReturnsTaskAndResult_WhenTaskCompletes() - { - // Arrange - await using var transport = new TestServerTransport(); - transport.MockTask = new McpTask - { - TaskId = "task-with-result", - Status = McpTaskStatus.Completed, - CreatedAt = DateTimeOffset.UtcNow.AddMinutes(-5), - LastUpdatedAt = DateTimeOffset.UtcNow, - }; - transport.MockTaskResult = new CreateMessageResult - { - Content = [new TextContentBlock { Text = "Final result from task" }], - Model = "test-model" - }; - - await using var server = McpServer.Create(transport, _options, LoggerFactory); - var runTask = server.RunAsync(TestContext.Current.CancellationToken); - await InitializeServerAsync(transport, new ClientCapabilities - { - Tasks = new McpTasksCapability() - }, TestContext.Current.CancellationToken); - - // Act - var (task, result) = await server.WaitForTaskResultAsync( - "task-with-result", cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(task); - Assert.Equal("task-with-result", task.TaskId); - Assert.Equal(McpTaskStatus.Completed, task.Status); - - Assert.NotNull(result); - Assert.Equal("test-model", result.Model); - var textContent = Assert.IsType(Assert.Single(result.Content)); - Assert.Equal("Final result from task", textContent.Text); - - await transport.DisposeAsync(); - await runTask; - } - - [Fact] - public async Task WaitForTaskResultAsync_ThrowsException_WhenTaskFails() - { - // Arrange - await using var transport = new TestServerTransport(); - transport.MockTask = new McpTask - { - TaskId = "failed-task", - Status = McpTaskStatus.Failed, - StatusMessage = "Something went wrong", - CreatedAt = DateTimeOffset.UtcNow.AddMinutes(-5), - LastUpdatedAt = DateTimeOffset.UtcNow, - }; - - await using var server = McpServer.Create(transport, _options, LoggerFactory); - var runTask = server.RunAsync(TestContext.Current.CancellationToken); - await InitializeServerAsync(transport, new ClientCapabilities - { - Tasks = new McpTasksCapability() - }, TestContext.Current.CancellationToken); - - // Act & Assert - var ex = await Assert.ThrowsAsync(async () => - await server.WaitForTaskResultAsync( - "failed-task", cancellationToken: TestContext.Current.CancellationToken)); - - Assert.Contains("failed", ex.Message, StringComparison.OrdinalIgnoreCase); - Assert.Contains("Something went wrong", ex.Message); - - await transport.DisposeAsync(); - await runTask; - } - - [Fact] - public async Task WaitForTaskResultAsync_ThrowsException_WhenTaskIsCancelled() - { - // Arrange - await using var transport = new TestServerTransport(); - transport.MockTask = new McpTask - { - TaskId = "cancelled-task", - Status = McpTaskStatus.Cancelled, - CreatedAt = DateTimeOffset.UtcNow.AddMinutes(-5), - LastUpdatedAt = DateTimeOffset.UtcNow, - }; - - await using var server = McpServer.Create(transport, _options, LoggerFactory); - var runTask = server.RunAsync(TestContext.Current.CancellationToken); - await InitializeServerAsync(transport, new ClientCapabilities - { - Tasks = new McpTasksCapability() - }, TestContext.Current.CancellationToken); - - // Act & Assert - var ex = await Assert.ThrowsAsync(async () => - await server.WaitForTaskResultAsync( - "cancelled-task", cancellationToken: TestContext.Current.CancellationToken)); - - Assert.Contains("cancelled", ex.Message, StringComparison.OrdinalIgnoreCase); - - await transport.DisposeAsync(); - await runTask; - } - - #endregion - - #region Helper Methods - - private static async Task InitializeServerAsync(TestServerTransport transport, ClientCapabilities capabilities, CancellationToken cancellationToken = default) - { - var initializeRequest = new JsonRpcRequest - { - Id = new RequestId("init-1"), - Method = RequestMethods.Initialize, - Params = JsonSerializer.SerializeToNode(new InitializeRequestParams - { - ProtocolVersion = "2024-11-05", - Capabilities = capabilities, - ClientInfo = new Implementation { Name = "test-client", Version = "1.0.0" } - }, McpJsonUtilities.DefaultOptions) - }; - - var tcs = new TaskCompletionSource(); - transport.OnMessageSent = (message) => - { - if (message is JsonRpcResponse response && response.Id == initializeRequest.Id) - { - tcs.TrySetResult(true); - } - }; - - await transport.SendClientMessageAsync(initializeRequest, cancellationToken); - - // Wait for the initialize response to be sent - await tcs.Task.WaitAsync(TestConstants.DefaultTimeout, cancellationToken); - } - - #endregion -} diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTaskNotificationTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTaskNotificationTests.cs deleted file mode 100644 index aa8941864..000000000 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTaskNotificationTests.cs +++ /dev/null @@ -1,152 +0,0 @@ -using ModelContextProtocol.Client; -using ModelContextProtocol.Protocol; -using ModelContextProtocol.Server; -using System.Collections.Concurrent; -using System.Text.Json; - -namespace ModelContextProtocol.Tests.Server; - -/// -/// Tests for task status notification functionality in McpServer. -/// -public class McpServerTaskNotificationTests : ClientServerTestBase -{ - public McpServerTaskNotificationTests(ITestOutputHelper testOutputHelper) - : base(testOutputHelper) - { - } - - [Fact] - public async Task NotifyTaskStatusAsync_SendsNotificationWithTaskDetails() - { - // Arrange - var client = await CreateMcpClientForServer(); - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - await using var registration = client.RegisterNotificationHandler( - NotificationMethods.TaskStatusNotification, - (notification, cancellationToken) => - { - if (notification.Params is { } paramsNode) - { - var notificationParams = JsonSerializer.Deserialize(paramsNode, McpJsonUtilities.DefaultOptions); - if (notificationParams is not null) - { - tcs.TrySetResult(notificationParams); - } - } - return default; - }); - - var mcpTask = new McpTask - { - TaskId = "task-123", - Status = McpTaskStatus.Working, - StatusMessage = "Processing request", - CreatedAt = DateTimeOffset.UtcNow.AddMinutes(-1), - LastUpdatedAt = DateTimeOffset.UtcNow, - TimeToLive = TimeSpan.FromMinutes(10), - PollInterval = TimeSpan.FromSeconds(1) - }; - - // Act - await Server.NotifyTaskStatusAsync(mcpTask, TestContext.Current.CancellationToken); - var notification = await tcs.Task.WaitAsync(TestContext.Current.CancellationToken); - - // Assert - Assert.Equal(mcpTask.TaskId, notification.TaskId); - Assert.Equal(mcpTask.Status, notification.Status); - Assert.Equal(mcpTask.StatusMessage, notification.StatusMessage); - Assert.Equal(mcpTask.CreatedAt, notification.CreatedAt); - Assert.Equal(mcpTask.LastUpdatedAt, notification.LastUpdatedAt); - Assert.Equal(mcpTask.TimeToLive, notification.TimeToLive); - Assert.Equal(mcpTask.PollInterval, notification.PollInterval); - } - - [Fact] - public async Task NotifyTaskStatusAsync_ThrowsOnNullTask() - { - // Arrange - await CreateMcpClientForServer(); - - // Act & Assert - await Assert.ThrowsAsync( - () => Server.NotifyTaskStatusAsync(null!, TestContext.Current.CancellationToken)); - } - - [Fact] - public async Task NotifyTaskStatusAsync_SendsMultipleNotificationsForDifferentStatuses() - { - // Arrange - var client = await CreateMcpClientForServer(); - var receivedNotifications = new ConcurrentBag(); - int expectedCount = 3; - var allReceivedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - await using var registration = client.RegisterNotificationHandler( - NotificationMethods.TaskStatusNotification, - (notification, cancellationToken) => - { - if (notification.Params is { } paramsNode) - { - var notificationParams = JsonSerializer.Deserialize(paramsNode, McpJsonUtilities.DefaultOptions); - if (notificationParams is not null) - { - receivedNotifications.Add(notificationParams); - if (receivedNotifications.Count >= expectedCount) - { - allReceivedTcs.TrySetResult(true); - } - } - } - return default; - }); - - // Act - Send notifications for different statuses - var task1 = new McpTask - { - TaskId = "task-456", - Status = McpTaskStatus.Working, - StatusMessage = "Starting", - CreatedAt = DateTimeOffset.UtcNow.AddMinutes(-1), - LastUpdatedAt = DateTimeOffset.UtcNow, - TimeToLive = TimeSpan.FromMinutes(10), - PollInterval = TimeSpan.FromSeconds(1) - }; - - var task2 = new McpTask - { - TaskId = "task-456", - Status = McpTaskStatus.Working, - StatusMessage = "Processing", - CreatedAt = task1.CreatedAt, - LastUpdatedAt = DateTimeOffset.UtcNow, - TimeToLive = TimeSpan.FromMinutes(10), - PollInterval = TimeSpan.FromSeconds(1) - }; - - var task3 = new McpTask - { - TaskId = "task-456", - Status = McpTaskStatus.Completed, - StatusMessage = "Done", - CreatedAt = task1.CreatedAt, - LastUpdatedAt = DateTimeOffset.UtcNow, - TimeToLive = TimeSpan.FromMinutes(10), - PollInterval = TimeSpan.FromSeconds(1) - }; - - await Server.NotifyTaskStatusAsync(task1, TestContext.Current.CancellationToken); - await Server.NotifyTaskStatusAsync(task2, TestContext.Current.CancellationToken); - await Server.NotifyTaskStatusAsync(task3, TestContext.Current.CancellationToken); - - // Wait for all notifications to be received - await allReceivedTcs.Task.WaitAsync(TestContext.Current.CancellationToken); - - // Assert - Assert.Equal(3, receivedNotifications.Count); - Assert.Contains(receivedNotifications, n => n.Status == McpTaskStatus.Working && n.StatusMessage == "Starting"); - Assert.Contains(receivedNotifications, n => n.Status == McpTaskStatus.Working && n.StatusMessage == "Processing"); - Assert.Contains(receivedNotifications, n => n.Status == McpTaskStatus.Completed && n.StatusMessage == "Done"); - } -} diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTaskTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTaskTests.cs new file mode 100644 index 000000000..f4c9ff9d0 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTaskTests.cs @@ -0,0 +1,591 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using Microsoft.Extensions.DependencyInjection; +using System.Collections.Concurrent; +using System.Runtime.InteropServices; +using System.Text.Json; +using System.Text.Json.Nodes; + +namespace ModelContextProtocol.Tests.Server; + +/// +/// Tests for the MCP tasks extension (SEP-2663) end-to-end using a simple in-memory task store. +/// +public class McpServerTaskTests : ClientServerTestBase +{ + private readonly InMemoryTaskStore _taskStore = new(); + + public McpServerTaskTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) + { +#if !NET + Assert.SkipWhen(RuntimeInformation.IsOSPlatform(OSPlatform.Windows), "https://github.com/modelcontextprotocol/csharp-sdk/issues/587"); +#endif + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + services.AddSingleton(_taskStore); + + mcpServerBuilder.Services.Configure(options => + { + options.Capabilities ??= new ServerCapabilities(); + + options.Handlers.CallToolWithTaskHandler = async (context, cancellationToken) => + { + var store = context.Server.Services!.GetRequiredService(); + var toolName = context.Params!.Name; + + if (toolName == "immediate-tool") + { + return new CallToolResult() + { + Content = [new TextContentBlock { Text = "immediate result" }], + }; + } + + if (toolName == "async-tool") + { + var taskId = store.CreateTask(); + return new CreateTaskResult + { + TaskId = taskId, + Status = McpTaskStatus.Working, + CreatedAt = DateTimeOffset.UtcNow, + LastUpdatedAt = DateTimeOffset.UtcNow, + PollIntervalMs = 50, + ResultType = "task", + }; + } + + if (toolName == "input-required-tool") + { + var taskId = store.CreateTask(McpTaskStatus.InputRequired); + return new CreateTaskResult + { + TaskId = taskId, + Status = McpTaskStatus.InputRequired, + CreatedAt = DateTimeOffset.UtcNow, + LastUpdatedAt = DateTimeOffset.UtcNow, + PollIntervalMs = 50, + ResultType = "task", + }; + } + + throw new McpException($"Unknown tool: {toolName}"); + }; + + options.Handlers.GetTaskHandler = async (context, cancellationToken) => + { + var store = context.Server.Services!.GetRequiredService(); + var taskId = context.Params!.TaskId; + return store.GetTask(taskId); + }; + + options.Handlers.UpdateTaskHandler = async (context, cancellationToken) => + { + var store = context.Server.Services!.GetRequiredService(); + var taskId = context.Params!.TaskId; + store.ProvideInput(taskId, context.Params.InputResponses); + return new UpdateTaskResult(); + }; + + options.Handlers.CancelTaskHandler = async (context, cancellationToken) => + { + var store = context.Server.Services!.GetRequiredService(); + var taskId = context.Params!.TaskId; + store.CancelTask(taskId); + return new CancelTaskResult(); + }; + }); + } + + [Fact] + public async Task CallToolAsync_ImmediateResult_ReturnsDirectly() + { + await using var client = await CreateMcpClientForServer(); + + var result = await client.CallToolAsync( + new CallToolRequestParams { Name = "immediate-tool" }, + TestContext.Current.CancellationToken); + + Assert.NotNull(result); + Assert.Single(result.Content); + Assert.Equal("immediate result", Assert.IsType(result.Content[0]).Text); + } + + [Fact] + public async Task CallToolRawAsync_ImmediateResult_ReturnsResultNotTask() + { + await using var client = await CreateMcpClientForServer(); + + var augmented = await client.CallToolRawAsync( + new CallToolRequestParams { Name = "immediate-tool" }, + TestContext.Current.CancellationToken); + + Assert.False(augmented.IsTask); + Assert.NotNull(augmented.Result); + Assert.Null(augmented.TaskCreated); + Assert.Equal("immediate result", Assert.IsType(augmented.Result.Content[0]).Text); + } + + [Fact] + public async Task CallToolRawAsync_AsyncTool_ReturnsTaskCreated() + { + await using var client = await CreateMcpClientForServer(); + + var augmented = await client.CallToolRawAsync( + new CallToolRequestParams { Name = "async-tool" }, + TestContext.Current.CancellationToken); + + Assert.True(augmented.IsTask); + Assert.NotNull(augmented.TaskCreated); + Assert.Null(augmented.Result); + Assert.Equal(McpTaskStatus.Working, augmented.TaskCreated.Status); + Assert.Equal("task", augmented.TaskCreated.ResultType); + } + + [Fact] + public async Task CallToolAsync_AsyncTool_PollsUntilCompleted() + { + await using var client = await CreateMcpClientForServer(); + var ct = TestContext.Current.CancellationToken; + + // Complete the task after a brief delay so polling finds it. + _ = Task.Run(async () => + { + await Task.Delay(100, ct); + // The store should have exactly one task by now + var taskId = _taskStore.GetAllTaskIds().Single(); + _taskStore.CompleteTask(taskId, new CallToolResult + { + Content = [new TextContentBlock { Text = "async result" }], + }); + }, ct); + + var result = await client.CallToolAsync( + new CallToolRequestParams { Name = "async-tool" }, + ct); + + Assert.NotNull(result); + Assert.Single(result.Content); + Assert.Equal("async result", Assert.IsType(result.Content[0]).Text); + } + + [Fact] + public async Task CallToolAsync_AsyncTool_FailedTask_ThrowsMcpException() + { + await using var client = await CreateMcpClientForServer(); + var ct = TestContext.Current.CancellationToken; + + _ = Task.Run(async () => + { + await Task.Delay(100, ct); + var taskId = _taskStore.GetAllTaskIds().Single(); + _taskStore.FailTask(taskId, new { code = -32000, message = "something went wrong" }); + }, ct); + + await Assert.ThrowsAsync(async () => + await client.CallToolAsync( + new CallToolRequestParams { Name = "async-tool" }, + ct)); + } + + [Fact] + public async Task CallToolAsync_AsyncTool_CancelledTask_ThrowsOperationCancelled() + { + await using var client = await CreateMcpClientForServer(); + var ct = TestContext.Current.CancellationToken; + + _ = Task.Run(async () => + { + await Task.Delay(100, ct); + var taskId = _taskStore.GetAllTaskIds().Single(); + _taskStore.CancelTask(taskId); + }, ct); + + await Assert.ThrowsAsync(async () => + await client.CallToolAsync( + new CallToolRequestParams { Name = "async-tool" }, + ct)); + } + + [Fact] + public async Task GetTaskAsync_ReturnsCurrentState() + { + await using var client = await CreateMcpClientForServer(); + + // Create a task via CallToolRawAsync + var augmented = await client.CallToolRawAsync( + new CallToolRequestParams { Name = "async-tool" }, + TestContext.Current.CancellationToken); + + var taskId = augmented.TaskCreated!.TaskId; + + // Should be working + var taskResult = await client.GetTaskAsync(taskId, TestContext.Current.CancellationToken); + Assert.IsType(taskResult); + Assert.Equal(taskId, taskResult.TaskId); + Assert.Equal(McpTaskStatus.Working, taskResult.Status); + } + + [Fact] + public async Task CancelTaskAsync_CancelsTask() + { + await using var client = await CreateMcpClientForServer(); + + var augmented = await client.CallToolRawAsync( + new CallToolRequestParams { Name = "async-tool" }, + TestContext.Current.CancellationToken); + + var taskId = augmented.TaskCreated!.TaskId; + + // Cancel via client + var cancelResult = await client.CancelTaskAsync(taskId, TestContext.Current.CancellationToken); + Assert.NotNull(cancelResult); + + // Verify state changed + var taskResult = await client.GetTaskAsync(taskId, TestContext.Current.CancellationToken); + Assert.IsType(taskResult); + } + + [Fact] + public async Task ConfigureTasks_AdvertisesExtensionInCapabilities() + { + await using var client = await CreateMcpClientForServer(); + + // The server advertises the tasks extension during initialize. + // The client should see it in server capabilities after the handshake. + #pragma warning disable MCP_EXTENSIONS + var extensions = client.ServerCapabilities.Extensions; + #pragma warning restore MCP_EXTENSIONS + Assert.NotNull(extensions); + Assert.True(extensions.ContainsKey(McpExtensions.Tasks)); + } + + [Fact] + public async Task CreateTaskResult_HasResultTypeTask() + { + await using var client = await CreateMcpClientForServer(); + + var augmented = await client.CallToolRawAsync( + new CallToolRequestParams { Name = "async-tool" }, + TestContext.Current.CancellationToken); + + Assert.True(augmented.IsTask); + Assert.Equal("task", augmented.TaskCreated!.ResultType); + } + + [Fact] + public async Task GetTaskAsync_ImmediatelyAfterCreate_Resolves() + { + // Strong consistency: tasks/get immediately after CreateTaskResult must resolve. + await using var client = await CreateMcpClientForServer(); + + var augmented = await client.CallToolRawAsync( + new CallToolRequestParams { Name = "async-tool" }, + TestContext.Current.CancellationToken); + + var taskId = augmented.TaskCreated!.TaskId; + + // No delay — immediate get + var taskResult = await client.GetTaskAsync(taskId, TestContext.Current.CancellationToken); + Assert.NotNull(taskResult); + Assert.Equal(taskId, taskResult.TaskId); + } + + [Fact] + public async Task GetTaskAsync_UnknownTaskId_ThrowsWithInvalidParams() + { + await using var client = await CreateMcpClientForServer(); + + var ex = await Assert.ThrowsAsync(async () => + await client.GetTaskAsync("nonexistent-task-id-12345", TestContext.Current.CancellationToken)); + + // The server should reject with an error referencing the unknown task + Assert.Contains("Unknown task", ex.Message); + } + + [Fact] + public async Task CancelTask_AlreadyTerminal_AcknowledgesIdempotently() + { + await using var client = await CreateMcpClientForServer(); + var ct = TestContext.Current.CancellationToken; + + var augmented = await client.CallToolRawAsync( + new CallToolRequestParams { Name = "async-tool" }, ct); + var taskId = augmented.TaskCreated!.TaskId; + + // Cancel once + await client.CancelTaskAsync(taskId, ct); + + // Cancel again on terminal task — should not throw, returns ack + var ack = await client.CancelTaskAsync(taskId, ct); + Assert.NotNull(ack); + } + + [Fact] + public async Task UpdateTaskAsync_TransitionsFromInputRequired() + { + await using var client = await CreateMcpClientForServer(); + var ct = TestContext.Current.CancellationToken; + + // Create an input-required task + var augmented = await client.CallToolRawAsync( + new CallToolRequestParams { Name = "input-required-tool" }, ct); + + var taskId = augmented.TaskCreated!.TaskId; + + // Verify it's input_required + var taskResult = await client.GetTaskAsync(taskId, ct); + Assert.IsType(taskResult); + + // Provide input + var inputResponses = new Dictionary + { + ["resp-1"] = JsonSerializer.SerializeToElement(new { answer = "yes" }) + }; + await client.UpdateTaskAsync(new UpdateTaskRequestParams + { + TaskId = taskId, + InputResponses = inputResponses, + }, ct); + + // Verify it transitioned back to working + taskResult = await client.GetTaskAsync(taskId, ct); + Assert.IsType(taskResult); + } + + [Fact] + public async Task CallToolRawAsync_InjectsTaskCapabilityInMeta() + { + // Verify the server receives the task extension in _meta by intercepting + // the handler. The CallToolWithTaskHandler already receives the request, + // so we can observe the meta there. We test the client-side injection indirectly + // by confirming the server returns a task result (which requires the capability signal). + await using var client = await CreateMcpClientForServer(); + + var augmented = await client.CallToolRawAsync( + new CallToolRequestParams { Name = "async-tool" }, + TestContext.Current.CancellationToken); + + // If the capability wasn't injected, the server couldn't have returned a task + Assert.True(augmented.IsTask); + } + + [Fact] + public async Task CallToolRawAsync_PreservesExistingUserMeta() + { + // Verify that user-supplied meta fields are not clobbered + await using var client = await CreateMcpClientForServer(); + + var userMeta = new JsonObject { ["customKey"] = "customValue" }; + var augmented = await client.CallToolRawAsync( + new CallToolRequestParams + { + Name = "immediate-tool", + Meta = userMeta, + }, + TestContext.Current.CancellationToken); + + // Should still work — the meta was cloned, not destructively modified + Assert.False(augmented.IsTask); + Assert.Equal("immediate result", Assert.IsType(augmented.Result!.Content[0]).Text); + + // Original user meta should not be mutated + Assert.Single(userMeta); + Assert.Equal("customValue", (string)userMeta["customKey"]!); + } + + [Fact] + public async Task CallToolAsync_RespectsServerPollInterval() + { + await using var client = await CreateMcpClientForServer(); + var ct = TestContext.Current.CancellationToken; + + var startTime = DateTime.UtcNow; + + // Complete the task after a brief delay + _ = Task.Run(async () => + { + await Task.Delay(200, ct); + var taskId = _taskStore.GetAllTaskIds().Single(); + _taskStore.CompleteTask(taskId, new CallToolResult + { + Content = [new TextContentBlock { Text = "polled" }], + }); + }, ct); + + var result = await client.CallToolAsync( + new CallToolRequestParams { Name = "async-tool" }, ct); + + var elapsed = DateTime.UtcNow - startTime; + + // The server sets pollIntervalMs=50. The task completes after 200ms. + // So we expect at least 1 poll interval to have passed. + Assert.True(elapsed.TotalMilliseconds >= 50, $"Expected at least 50ms, got {elapsed.TotalMilliseconds}ms"); + Assert.Equal("polled", Assert.IsType(result.Content[0]).Text); + } + + [Fact] + public async Task CallToolWithTaskHandler_ImplicitConversion_ReturnCallToolResult() + { + // Verify that the implicit conversion from CallToolResult to ResultOrCreatedTask works + // in the handler context — this is already tested by "immediate-tool" working correctly. + await using var client = await CreateMcpClientForServer(); + + var result = await client.CallToolAsync( + new CallToolRequestParams { Name = "immediate-tool" }, + TestContext.Current.CancellationToken); + + Assert.Equal("immediate result", Assert.IsType(result.Content[0]).Text); + } + + [Fact] + public async Task CallToolHandler_And_CallToolWithTaskHandler_AreMutuallyExclusive() + { + var handlers = new McpServerHandlers(); + + handlers.CallToolWithTaskHandler = async (ctx, ct) => new CallToolResult(); + Assert.Throws(() => + handlers.CallToolHandler = async (ctx, ct) => new CallToolResult()); + + handlers = new McpServerHandlers(); + + handlers.CallToolHandler = async (ctx, ct) => new CallToolResult(); + Assert.Throws(() => + handlers.CallToolWithTaskHandler = async (ctx, ct) => new CallToolResult()); + } + + [Fact] + public async Task CallToolHandler_CanBeSetToNull_ThenOtherCanBeSet() + { + var handlers = new McpServerHandlers(); + + handlers.CallToolHandler = async (ctx, ct) => new CallToolResult(); + handlers.CallToolHandler = null; + + // Now setting the other should work + handlers.CallToolWithTaskHandler = async (ctx, ct) => new CallToolResult(); + Assert.NotNull(handlers.CallToolWithTaskHandler); + } + + /// + /// Simple in-memory task store for testing. + /// + private sealed class InMemoryTaskStore + { + private readonly ConcurrentDictionary _tasks = new(); + + public string CreateTask(McpTaskStatus initialStatus = McpTaskStatus.Working) + { + var taskId = Guid.NewGuid().ToString("N"); + _tasks[taskId] = new TaskEntry + { + Status = initialStatus, + CreatedAt = DateTimeOffset.UtcNow, + LastUpdatedAt = DateTimeOffset.UtcNow, + }; + return taskId; + } + + public IEnumerable GetAllTaskIds() => _tasks.Keys; + + public GetTaskResult GetTask(string taskId) + { + if (!_tasks.TryGetValue(taskId, out var entry)) + { + throw new McpException($"Unknown task: '{taskId}'"); + } + + return entry.Status switch + { + McpTaskStatus.Working => new WorkingTaskResult + { + TaskId = taskId, + CreatedAt = entry.CreatedAt, + LastUpdatedAt = entry.LastUpdatedAt, + PollIntervalMs = 50, + }, + McpTaskStatus.Completed => new CompletedTaskResult + { + TaskId = taskId, + CreatedAt = entry.CreatedAt, + LastUpdatedAt = entry.LastUpdatedAt, + TaskResult = JsonSerializer.SerializeToElement(entry.Result, McpJsonUtilities.DefaultOptions), + }, + McpTaskStatus.Failed => new FailedTaskResult + { + TaskId = taskId, + CreatedAt = entry.CreatedAt, + LastUpdatedAt = entry.LastUpdatedAt, + Error = JsonSerializer.SerializeToElement(entry.Error), + }, + McpTaskStatus.Cancelled => new CancelledTaskResult + { + TaskId = taskId, + CreatedAt = entry.CreatedAt, + LastUpdatedAt = entry.LastUpdatedAt, + }, + McpTaskStatus.InputRequired => new InputRequiredTaskResult + { + TaskId = taskId, + CreatedAt = entry.CreatedAt, + LastUpdatedAt = entry.LastUpdatedAt, + InputRequests = entry.InputRequests ?? new Dictionary(), + }, + _ => throw new InvalidOperationException($"Unexpected status: {entry.Status}") + }; + } + + public void CompleteTask(string taskId, CallToolResult result) + { + if (_tasks.TryGetValue(taskId, out var entry)) + { + entry.Status = McpTaskStatus.Completed; + entry.Result = result; + entry.LastUpdatedAt = DateTimeOffset.UtcNow; + } + } + + public void FailTask(string taskId, object error) + { + if (_tasks.TryGetValue(taskId, out var entry)) + { + entry.Status = McpTaskStatus.Failed; + entry.Error = error; + entry.LastUpdatedAt = DateTimeOffset.UtcNow; + } + } + + public void CancelTask(string taskId) + { + if (_tasks.TryGetValue(taskId, out var entry)) + { + entry.Status = McpTaskStatus.Cancelled; + entry.LastUpdatedAt = DateTimeOffset.UtcNow; + } + } + + public void ProvideInput(string taskId, IDictionary inputResponses) + { + if (_tasks.TryGetValue(taskId, out var entry)) + { + entry.InputResponses = inputResponses; + // Transition back to working after receiving input + entry.Status = McpTaskStatus.Working; + entry.LastUpdatedAt = DateTimeOffset.UtcNow; + } + } + + private sealed class TaskEntry + { + public McpTaskStatus Status { get; set; } + public DateTimeOffset CreatedAt { get; set; } + public DateTimeOffset LastUpdatedAt { get; set; } + public CallToolResult? Result { get; set; } + public object? Error { get; set; } + public IDictionary? InputRequests { get; set; } + public IDictionary? InputResponses { get; set; } + } + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index d9febd721..ff513c64e 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -343,18 +343,13 @@ public async Task Initialize_CopiesAllCapabilityProperties() Resources = new ResourcesCapability(), Tools = new ToolsCapability(), Completions = new CompletionsCapability(), - Tasks = new McpTasksCapability(), Extensions = new Dictionary { ["io.test"] = new JsonObject() }, }; await Can_Handle_Requests( serverCapabilities: inputCapabilities, method: RequestMethods.Initialize, - configureOptions: options => - { - // Tasks capability requires a TaskStore - options.TaskStore = new InMemoryMcpTaskStore(); - }, + configureOptions: _ => { }, assertResult: (_, response) => { var result = JsonSerializer.Deserialize(response, McpJsonUtilities.DefaultOptions); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index a283bf18c..01f18a631 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -1079,82 +1079,6 @@ public async Task EnablePollingAsync_ThrowsInvalidOperationException_WhenTranspo Assert.Contains("Streamable HTTP", exception.Message); } - [Fact] - public void AsyncTool_AutomaticallyMarkedWithTaskSupport() - { - // Async tools should automatically get TaskSupport = Optional - McpServerTool tool = McpServerTool.Create(AsyncToolReturningTask); - - Assert.NotNull(tool.ProtocolTool.Execution); - Assert.Equal(ToolTaskSupport.Optional, tool.ProtocolTool.Execution.TaskSupport); - } - - [Fact] - public void AsyncTool_ValueTask_AutomaticallyMarkedWithTaskSupport() - { - // Async tools returning ValueTask should also get TaskSupport = Optional - McpServerTool tool = McpServerTool.Create(AsyncToolReturningValueTask); - - Assert.NotNull(tool.ProtocolTool.Execution); - Assert.Equal(ToolTaskSupport.Optional, tool.ProtocolTool.Execution.TaskSupport); - } - - [Fact] - public void AsyncTool_TaskOfT_AutomaticallyMarkedWithTaskSupport() - { - // Async tools returning Task should get TaskSupport = Optional - McpServerTool tool = McpServerTool.Create(AsyncToolReturningTaskOfT); - - Assert.NotNull(tool.ProtocolTool.Execution); - Assert.Equal(ToolTaskSupport.Optional, tool.ProtocolTool.Execution.TaskSupport); - } - - [Fact] - public void AsyncTool_ValueTaskOfT_AutomaticallyMarkedWithTaskSupport() - { - // Async tools returning ValueTask should get TaskSupport = Optional - McpServerTool tool = McpServerTool.Create(AsyncToolReturningValueTaskOfT); - - Assert.NotNull(tool.ProtocolTool.Execution); - Assert.Equal(ToolTaskSupport.Optional, tool.ProtocolTool.Execution.TaskSupport); - } - - [Fact] - public void SyncTool_NotMarkedWithTaskSupport() - { - // Synchronous tools should not have TaskSupport set - McpServerTool tool = McpServerTool.Create(SyncTool); - - Assert.Null(tool.ProtocolTool.Execution); - } - - private static async Task AsyncToolReturningTask() - { - await Task.Yield(); - } - - private static async ValueTask AsyncToolReturningValueTask() - { - await Task.Yield(); - } - - private static async Task AsyncToolReturningTaskOfT() - { - await Task.Yield(); - return "result"; - } - - private static async ValueTask AsyncToolReturningValueTaskOfT() - { - await Task.Yield(); - return "result"; - } - - private static string SyncTool() - { - return "sync result"; - } - [Description("Tool that returns data.")] [return: Description("The computed result")] private static string ToolWithReturnDescription() => "result"; diff --git a/tests/ModelContextProtocol.Tests/Server/McpTaskStoreTests.cs b/tests/ModelContextProtocol.Tests/Server/McpTaskStoreTests.cs new file mode 100644 index 000000000..efa4a8ba3 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/McpTaskStoreTests.cs @@ -0,0 +1,321 @@ +using Microsoft.Extensions.AI; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using Microsoft.Extensions.DependencyInjection; +using System.Runtime.InteropServices; +using System.Text.Json; + +#pragma warning disable MCPEXP001 + +namespace ModelContextProtocol.Tests.Server; + +/// +/// Tests for the -based auto-wiring of tools/call into tasks. +/// Verifies that setting enables task support +/// for -based tools. +/// +public class McpTaskStoreTests : ClientServerTestBase +{ + public McpTaskStoreTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) + { +#if !NET + Assert.SkipWhen(RuntimeInformation.IsOSPlatform(OSPlatform.Windows), "https://github.com/modelcontextprotocol/csharp-sdk/issues/587"); +#endif + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + mcpServerBuilder.WithTools(); + + mcpServerBuilder.Services.Configure(options => + { + options.TaskStore = new InMemoryMcpTaskStore + { + DefaultPollIntervalMs = 50, + }; + }); + } + + [Fact] + public async Task CallToolRawAsync_WithTaskCapability_ReturnsCreateTaskResult() + { + await using var client = await CreateMcpClientForServer(); + + var augmented = await client.CallToolRawAsync( + new CallToolRequestParams { Name = "slow-tool" }, + TestContext.Current.CancellationToken); + + // Because the client signals task support and a TaskStore is configured, + // the server should wrap the tool execution in a task. + Assert.True(augmented.IsTask); + Assert.NotNull(augmented.TaskCreated); + Assert.Equal(McpTaskStatus.Working, augmented.TaskCreated.Status); + } + + [Fact] + public async Task CallToolAsync_WithTaskStore_PollsToCompletion() + { + await using var client = await CreateMcpClientForServer(); + + // CallToolAsync should poll until the background execution completes. + var result = await client.CallToolAsync( + new CallToolRequestParams { Name = "slow-tool" }, + TestContext.Current.CancellationToken); + + Assert.NotNull(result); + Assert.Single(result.Content); + Assert.Equal("slow result", Assert.IsType(result.Content[0]).Text); + } + + [Fact] + public async Task CallToolAsync_WithTaskStore_FastTool_StillCreatesTask() + { + await using var client = await CreateMcpClientForServer(); + + // Even a fast tool should go through the task store when the client signals capability. + var augmented = await client.CallToolRawAsync( + new CallToolRequestParams { Name = "fast-tool" }, + TestContext.Current.CancellationToken); + + Assert.True(augmented.IsTask); + } + + [Fact] + public async Task GetTaskAsync_ViaStore_ReturnsCompletedResult() + { + await using var client = await CreateMcpClientForServer(); + var ct = TestContext.Current.CancellationToken; + + var augmented = await client.CallToolRawAsync( + new CallToolRequestParams { Name = "fast-tool" }, ct); + + var taskId = augmented.TaskCreated!.TaskId; + + // The fast-tool returns immediately in the background, so poll briefly + GetTaskResult? taskResult = null; + for (int i = 0; i < 20; i++) + { + await Task.Delay(50, ct); + taskResult = await client.GetTaskAsync(taskId, ct); + if (taskResult is CompletedTaskResult) + { + break; + } + } + + Assert.IsType(taskResult); + } + + [Fact] + public async Task CancelTaskAsync_ViaStore_TransitionsToCancelled() + { + await using var client = await CreateMcpClientForServer(); + var ct = TestContext.Current.CancellationToken; + + // Create a slow task that won't complete on its own + var augmented = await client.CallToolRawAsync( + new CallToolRequestParams { Name = "slow-tool" }, ct); + + var taskId = augmented.TaskCreated!.TaskId; + + // Cancel it + await client.CancelTaskAsync(taskId, ct); + + // Verify state + var taskResult = await client.GetTaskAsync(taskId, ct); + Assert.IsType(taskResult); + } + + [Fact] + public async Task GetTaskAsync_UnknownId_ThrowsWithInvalidParams() + { + await using var client = await CreateMcpClientForServer(); + + var ex = await Assert.ThrowsAsync(async () => + await client.GetTaskAsync("nonexistent-id", TestContext.Current.CancellationToken)); + + Assert.Contains("Unknown task", ex.Message); + } + + [Fact] + public async Task ToolExecution_Failure_StoresAsCompletedWithError() + { + await using var client = await CreateMcpClientForServer(); + var ct = TestContext.Current.CancellationToken; + + var augmented = await client.CallToolRawAsync( + new CallToolRequestParams { Name = "failing-tool" }, ct); + + var taskId = augmented.TaskCreated!.TaskId; + + // Poll until completed (tool exceptions are wrapped as isError:true results) + GetTaskResult? taskResult = null; + for (int i = 0; i < 20; i++) + { + await Task.Delay(50, ct); + taskResult = await client.GetTaskAsync(taskId, ct); + if (taskResult is CompletedTaskResult) + { + break; + } + } + + var completed = Assert.IsType(taskResult); + // The tool result has isError: true + Assert.True(completed.TaskResult.GetProperty("isError").GetBoolean()); + } + + [Fact] + public async Task ElicitTool_ViaTask_RedirectsThroughStore() + { + await using var client = await CreateMcpClientForServer(new McpClientOptions + { + Handlers = new McpClientHandlers + { + ElicitationHandler = (request, ct) => + { + // Client responds to the elicitation + return new ValueTask(new ElicitResult { Action = "accept" }); + } + } + }); + var ct = TestContext.Current.CancellationToken; + + // CallToolAsync will poll and resolve input requests automatically. + var result = await client.CallToolAsync( + new CallToolRequestParams { Name = "elicit-tool" }, ct); + + Assert.NotNull(result); + Assert.Equal("accepted", Assert.IsType(result.Content[0]).Text); + } + + [Fact] + public async Task SampleTool_ViaTask_RedirectsThroughStore() + { + await using var client = await CreateMcpClientForServer(new McpClientOptions + { + Handlers = new McpClientHandlers + { + SamplingHandler = (request, progress, ct) => + { + return new ValueTask(new CreateMessageResult + { + Content = [new TextContentBlock { Text = "sampled response" }], + Model = "test-model", + }); + } + } + }); + var ct = TestContext.Current.CancellationToken; + + var result = await client.CallToolAsync( + new CallToolRequestParams { Name = "sample-tool" }, ct); + + Assert.NotNull(result); + Assert.Equal("sampled response", Assert.IsType(result.Content[0]).Text); + } + + [Fact] + public async Task ElicitTool_ViaTask_ClientDedups_InputRequests() + { + // This test verifies that the client doesn't re-resolve an input request + // that it has already responded to in a previous poll cycle. + int elicitCallCount = 0; + + await using var client = await CreateMcpClientForServer(new McpClientOptions + { + Handlers = new McpClientHandlers + { + ElicitationHandler = (request, ct) => + { + Interlocked.Increment(ref elicitCallCount); + return new ValueTask(new ElicitResult { Action = "accept" }); + } + } + }); + var ct = TestContext.Current.CancellationToken; + + var result = await client.CallToolAsync( + new CallToolRequestParams { Name = "elicit-tool" }, ct); + + // The handler should be called exactly once despite potential multiple polls + Assert.Equal(1, elicitCallCount); + Assert.Equal("accepted", Assert.IsType(result.Content[0]).Text); + } + + [Fact] + public async Task CallToolRawAsync_ElicitTool_ReturnsTask_ThenPollShowsInputRequired() + { + await using var client = await CreateMcpClientForServer(new McpClientOptions + { + Handlers = new McpClientHandlers + { + ElicitationHandler = (request, ct) => + new ValueTask(new ElicitResult { Action = "accept" }) + } + }); + var ct = TestContext.Current.CancellationToken; + + var augmented = await client.CallToolRawAsync( + new CallToolRequestParams { Name = "elicit-tool" }, ct); + + Assert.True(augmented.IsTask); + var taskId = augmented.TaskCreated!.TaskId; + + // Poll — eventually the task should be input_required (elicit-tool calls ElicitAsync) + GetTaskResult? taskResult = null; + for (int i = 0; i < 40; i++) + { + await Task.Delay(50, ct); + taskResult = await client.GetTaskAsync(taskId, ct); + if (taskResult is InputRequiredTaskResult) + { + break; + } + } + + Assert.IsType(taskResult); + } + + [McpServerToolType] + private sealed class TaskStoreTestTools + { + [McpServerTool(Name = "slow-tool"), System.ComponentModel.Description("A tool that takes time")] + public static async Task SlowTool(CancellationToken cancellationToken) + { + await Task.Delay(200, cancellationToken); + return "slow result"; + } + + [McpServerTool(Name = "fast-tool"), System.ComponentModel.Description("A fast tool")] + public static string FastTool() => "fast result"; + + [McpServerTool(Name = "failing-tool"), System.ComponentModel.Description("A tool that fails")] + public static string FailingTool() => throw new InvalidOperationException("intentional failure"); + + [McpServerTool(Name = "elicit-tool"), System.ComponentModel.Description("A tool that elicits")] + public static async Task ElicitTool(McpServer server, CancellationToken cancellationToken) + { + var result = await server.ElicitAsync(new ElicitRequestParams + { + Message = "What is your name?", + RequestedSchema = new(), + }, cancellationToken); + + return result.Action == "accept" ? "accepted" : "declined"; + } + + [McpServerTool(Name = "sample-tool"), System.ComponentModel.Description("A tool that samples")] + public static async Task SampleTool(McpServer server, CancellationToken cancellationToken) + { + var result = await server.SampleAsync(new CreateMessageRequestParams + { + Messages = [new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = "hello" }] }], + MaxTokens = 100, + }, cancellationToken); + + return result.Content.OfType().FirstOrDefault()?.Text ?? "no response"; + } + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/TaskCancellationIntegrationTests.cs b/tests/ModelContextProtocol.Tests/Server/TaskCancellationIntegrationTests.cs deleted file mode 100644 index cc075a676..000000000 --- a/tests/ModelContextProtocol.Tests/Server/TaskCancellationIntegrationTests.cs +++ /dev/null @@ -1,509 +0,0 @@ -using Microsoft.Extensions.DependencyInjection; -using ModelContextProtocol.Client; -using ModelContextProtocol.Protocol; -using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Utils; -using System.Text.Json; - -namespace ModelContextProtocol.Tests.Server; - -/// -/// Integration tests for task cancellation behavior, including TTL-based automatic -/// cancellation and explicit cancellation via tasks/cancel. -/// -public class TaskCancellationIntegrationTests : ClientServerTestBase -{ - private readonly TaskCompletionSource _toolCancellationFired = new(TaskCreationOptions.RunContinuationsAsynchronously); - private readonly TaskCompletionSource _toolStarted = new(TaskCreationOptions.RunContinuationsAsynchronously); - - public TaskCancellationIntegrationTests(ITestOutputHelper testOutputHelper) - : base(testOutputHelper) - { - } - - protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) - { - // Add task store for server-side task support - var taskStore = new InMemoryMcpTaskStore(); - services.AddSingleton(taskStore); - - services.Configure(options => - { - options.TaskStore = taskStore; - }); - - // Add a long-running tool that captures cancellation - mcpServerBuilder.WithTools([McpServerTool.Create( - async (CancellationToken ct) => - { - _toolStarted.TrySetResult(true); - try - { - // Wait indefinitely until cancelled - await Task.Delay(Timeout.Infinite, ct); - return "completed"; - } - catch (OperationCanceledException) - { - _toolCancellationFired.TrySetResult(true); - throw; - } - }, - new McpServerToolCreateOptions - { - Name = "long-running-tool", - Description = "A tool that runs until cancelled" - })]); - } - - private static IDictionary EmptyArguments() => new Dictionary(); - - [Fact] - public async Task TaskTool_CancellationToken_FiresWhenTtlExpires() - { - // Arrange - await using McpClient client = await CreateMcpClientForServer(); - - // Act - Call tool with short TTL (200ms) - var callResult = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "long-running-tool", - Arguments = EmptyArguments(), - // Use a TTL long enough that thread pool scheduling delays on loaded CI machines - // don't cause the CTS to fire before the tool lambda begins executing. - Task = new McpTaskMetadata { TimeToLive = TimeSpan.FromSeconds(5) } - }, - cancellationToken: TestContext.Current.CancellationToken); - - // Verify task was created - Assert.NotNull(callResult.Task); - - // Wait for the tool to start executing - await _toolStarted.Task.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken); - - // Assert - Wait for the cancellation to fire (should happen when TTL expires) - var cancelled = await _toolCancellationFired.Task.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken); - Assert.True(cancelled, "Tool's CancellationToken should have been triggered when TTL expired"); - - // Note: TTL-based expiration does not explicitly set task status to Cancelled. - // Instead, expired tasks are considered "dead" and will be cleaned up by the task store. - // The task may still be in Working status or may throw "not found" if already cleaned up. - } - - [Fact] - public async Task TaskTool_CancellationToken_FiresWhenExplicitlyCancelled() - { - // Arrange - await using McpClient client = await CreateMcpClientForServer(); - - // Start a long-running task with a long TTL - var callResult = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "long-running-tool", - Arguments = EmptyArguments(), - Task = new McpTaskMetadata { TimeToLive = TimeSpan.FromMinutes(10) } - }, - cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(callResult.Task); - string taskId = callResult.Task.TaskId; - - // Wait for the tool to start executing - await _toolStarted.Task.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken); - - // Act - Explicitly cancel the task - var cancelledTask = await client.CancelTaskAsync(taskId, cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Wait for the cancellation to propagate to the tool - var cancelled = await _toolCancellationFired.Task.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken); - Assert.True(cancelled, "Tool's CancellationToken should have been triggered by explicit cancellation"); - - // Verify task status - Assert.Equal(McpTaskStatus.Cancelled, cancelledTask.Status); - } - - [Fact] - public async Task TaskTool_CompletesSuccessfully_WhenNotCancelled() - { - // Arrange - Create a new test with a quick-completing tool - var quickToolCompleted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - var services = new ServiceCollection(); - services.AddLogging(); - var taskStore = new InMemoryMcpTaskStore(); - services.AddSingleton(taskStore); - - var builder = services - .AddMcpServer() - .WithStreamServerTransport( - new System.IO.Pipelines.Pipe().Reader.AsStream(), - new System.IO.Pipelines.Pipe().Writer.AsStream()); - - builder.WithTools([McpServerTool.Create( - async (string input, CancellationToken ct) => - { - await Task.Delay(50, ct); // Quick operation - var result = $"Result: {input}"; - quickToolCompleted.TrySetResult(result); - return result; - }, - new McpServerToolCreateOptions - { - Name = "quick-tool", - Description = "A tool that completes quickly" - })]); - - services.Configure(options => - { - options.TaskStore = taskStore; - }); - - await using var client = await CreateMcpClientForServer(); - - // Act - Call tool with long TTL - var callResult = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "long-running-tool", // Use the base class tool which will block - Arguments = EmptyArguments(), - Task = new McpTaskMetadata { TimeToLive = TimeSpan.FromMinutes(5) } - }, - cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(callResult.Task); - - // Verify task is in working state initially - var task = await client.GetTaskAsync(callResult.Task.TaskId, cancellationToken: TestContext.Current.CancellationToken); - Assert.Equal(McpTaskStatus.Working, task.Status); - } -} - -/// -/// Tests for task cancellation with multiple concurrent tasks. -/// -public class TaskCancellationConcurrencyTests : ClientServerTestBase -{ - private readonly Dictionary> _toolCancellations = new(); - private readonly Dictionary> _toolStarts = new(); - private readonly object _lock = new(); - - public TaskCancellationConcurrencyTests(ITestOutputHelper testOutputHelper) - : base(testOutputHelper) - { - } - - protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) - { - var taskStore = new InMemoryMcpTaskStore(); - services.AddSingleton(taskStore); - - services.Configure(options => - { - options.TaskStore = taskStore; - }); - - // Tool that tracks cancellation per-invocation using a marker argument - mcpServerBuilder.WithTools([McpServerTool.Create( - async (string marker, CancellationToken ct) => - { - TaskCompletionSource startTcs; - TaskCompletionSource cancelTcs; - - lock (_lock) - { - if (!_toolStarts.TryGetValue(marker, out startTcs!)) - { - startTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - _toolStarts[marker] = startTcs; - } - if (!_toolCancellations.TryGetValue(marker, out cancelTcs!)) - { - cancelTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - _toolCancellations[marker] = cancelTcs; - } - } - - startTcs.TrySetResult(true); - - try - { - await Task.Delay(Timeout.Infinite, ct); - return $"completed-{marker}"; - } - catch (OperationCanceledException) - { - cancelTcs.TrySetResult(true); - throw; - } - }, - new McpServerToolCreateOptions - { - Name = "trackable-tool", - Description = "A tool that can be tracked by marker" - })]); - } - - private void RegisterMarker(string marker) - { - lock (_lock) - { - _toolStarts[marker] = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - _toolCancellations[marker] = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - } - } - - private Task WaitForStart(string marker, CancellationToken ct) - { - lock (_lock) - { - return _toolStarts[marker].Task.WaitAsync(TestConstants.DefaultTimeout, ct); - } - } - - private Task WaitForCancellation(string marker, CancellationToken ct) - { - lock (_lock) - { - return _toolCancellations[marker].Task.WaitAsync(TestConstants.DefaultTimeout, ct); - } - } - - private static IDictionary CreateMarkerArgs(string marker) => - new Dictionary - { - ["marker"] = JsonDocument.Parse($"\"{marker}\"").RootElement.Clone() - }; - - [Fact] - public async Task CancelTask_OnlyCancelsTargetTask_NotOtherTasks() - { - // Arrange - await using McpClient client = await CreateMcpClientForServer(); - - RegisterMarker("task1"); - RegisterMarker("task2"); - - // Start two tasks - var result1 = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "trackable-tool", - Arguments = CreateMarkerArgs("task1"), - Task = new McpTaskMetadata { TimeToLive = TimeSpan.FromMinutes(10) } - }, - cancellationToken: TestContext.Current.CancellationToken); - - var result2 = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "trackable-tool", - Arguments = CreateMarkerArgs("task2"), - Task = new McpTaskMetadata { TimeToLive = TimeSpan.FromMinutes(10) } - }, - cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(result1.Task); - Assert.NotNull(result2.Task); - - // Wait for both tools to start - await WaitForStart("task1", TestContext.Current.CancellationToken); - await WaitForStart("task2", TestContext.Current.CancellationToken); - - // Act - Cancel only task1 - await client.CancelTaskAsync(result1.Task.TaskId, cancellationToken: TestContext.Current.CancellationToken); - - // Assert - task1 should be cancelled - var task1Cancelled = await WaitForCancellation("task1", TestContext.Current.CancellationToken); - Assert.True(task1Cancelled, "Task1 should have been cancelled"); - - // task2 should still be running (give it a moment to verify it wasn't cancelled) - var task2Status = await client.GetTaskAsync(result2.Task.TaskId, cancellationToken: TestContext.Current.CancellationToken); - Assert.Equal(McpTaskStatus.Working, task2Status.Status); - - // Clean up - cancel task2 - await client.CancelTaskAsync(result2.Task.TaskId, cancellationToken: TestContext.Current.CancellationToken); - } - - [Fact] - public async Task MultipleTasks_WithDifferentTtls_CancelIndependently() - { - // Arrange - await using McpClient client = await CreateMcpClientForServer(); - - RegisterMarker("short-ttl"); - RegisterMarker("long-ttl"); - - // Start task with short TTL. Use a TTL long enough that thread pool scheduling - // delays on loaded CI machines don't cause the CTS to fire before the tool - // lambda begins executing (CancelAfter starts counting at task creation, not - // when the tool's Task.Run is scheduled). - var shortTtlResult = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "trackable-tool", - Arguments = CreateMarkerArgs("short-ttl"), - Task = new McpTaskMetadata { TimeToLive = TimeSpan.FromSeconds(5) } - }, - cancellationToken: TestContext.Current.CancellationToken); - - // Start task with long TTL - var longTtlResult = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "trackable-tool", - Arguments = CreateMarkerArgs("long-ttl"), - Task = new McpTaskMetadata { TimeToLive = TimeSpan.FromMinutes(10) } - }, - cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(shortTtlResult.Task); - Assert.NotNull(longTtlResult.Task); - - // Wait for both to start - await WaitForStart("short-ttl", TestContext.Current.CancellationToken); - await WaitForStart("long-ttl", TestContext.Current.CancellationToken); - - // Assert - short TTL task should be cancelled automatically - var shortCancelled = await WaitForCancellation("short-ttl", TestContext.Current.CancellationToken); - Assert.True(shortCancelled, "Short TTL task should have been cancelled when TTL expired"); - - // Long TTL task should still be running - var longTtlStatus = await client.GetTaskAsync(longTtlResult.Task.TaskId, cancellationToken: TestContext.Current.CancellationToken); - Assert.Equal(McpTaskStatus.Working, longTtlStatus.Status); - - // Clean up - await client.CancelTaskAsync(longTtlResult.Task.TaskId, cancellationToken: TestContext.Current.CancellationToken); - } -} - -/// -/// Tests verifying that terminal task states (completed, failed, cancelled) cannot transition. -/// Per spec: "Tasks with a completed, failed, or cancelled status are in a terminal state -/// and MUST NOT transition to any other status" -/// -public class TerminalTaskStatusTransitionTests : ClientServerTestBase -{ - public TerminalTaskStatusTransitionTests(ITestOutputHelper testOutputHelper) - : base(testOutputHelper) - { - } - - protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) - { - var taskStore = new InMemoryMcpTaskStore(); - services.AddSingleton(taskStore); - - services.Configure(options => - { - options.TaskStore = taskStore; - }); - - mcpServerBuilder.WithTools([ - McpServerTool.Create( - async (CancellationToken ct) => - { - await Task.Delay(10, ct); - return "quick result"; - }, - new McpServerToolCreateOptions - { - Name = "quick-tool", - Description = "A tool that completes quickly" - }), - McpServerTool.Create( - async (CancellationToken ct) => - { - await Task.Delay(10, ct); - throw new InvalidOperationException("Intentional failure"); -#pragma warning disable CS0162 - return "never"; -#pragma warning restore CS0162 - }, - new McpServerToolCreateOptions - { - Name = "failing-tool", - Description = "A tool that always fails" - }) - ]); - } - - private static IDictionary EmptyArguments() => new Dictionary(); - - [Fact] - public async Task CompletedTask_CannotTransitionToOtherStatus() - { - // Arrange - await using McpClient client = await CreateMcpClientForServer(); - - var callResult = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "quick-tool", - Arguments = EmptyArguments(), - Task = new McpTaskMetadata() - }, - cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(callResult.Task); - string taskId = callResult.Task.TaskId; - - // Wait for completion - McpTask taskStatus; - do - { - await Task.Delay(50, TestContext.Current.CancellationToken); - taskStatus = await client.GetTaskAsync(taskId, cancellationToken: TestContext.Current.CancellationToken); - } - while (taskStatus.Status == McpTaskStatus.Working); - - Assert.Equal(McpTaskStatus.Completed, taskStatus.Status); - - // Act - Try to cancel a completed task (should be idempotent) - var cancelResult = await client.CancelTaskAsync(taskId, cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Status should still be completed (not cancelled) - Assert.Equal(McpTaskStatus.Completed, cancelResult.Status); - - // Verify via get - var verifyStatus = await client.GetTaskAsync(taskId, cancellationToken: TestContext.Current.CancellationToken); - Assert.Equal(McpTaskStatus.Completed, verifyStatus.Status); - } - - [Fact] - public async Task FailedTask_CannotTransitionToOtherStatus() - { - // Arrange - await using McpClient client = await CreateMcpClientForServer(); - - var callResult = await client.CallToolAsync( - new CallToolRequestParams - { - Name = "failing-tool", - Arguments = EmptyArguments(), - Task = new McpTaskMetadata() - }, - cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(callResult.Task); - string taskId = callResult.Task.TaskId; - - // Wait for failure - McpTask taskStatus; - do - { - await Task.Delay(50, TestContext.Current.CancellationToken); - taskStatus = await client.GetTaskAsync(taskId, cancellationToken: TestContext.Current.CancellationToken); - } - while (taskStatus.Status == McpTaskStatus.Working); - - Assert.Equal(McpTaskStatus.Failed, taskStatus.Status); - - // Act - Try to cancel a failed task (should be idempotent) - var cancelResult = await client.CancelTaskAsync(taskId, cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Status should still be failed - Assert.Equal(McpTaskStatus.Failed, cancelResult.Status); - } -} diff --git a/tests/ModelContextProtocol.Tests/Server/ToolTaskSupportTests.cs b/tests/ModelContextProtocol.Tests/Server/ToolTaskSupportTests.cs deleted file mode 100644 index 25db2b330..000000000 --- a/tests/ModelContextProtocol.Tests/Server/ToolTaskSupportTests.cs +++ /dev/null @@ -1,727 +0,0 @@ -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; -using ModelContextProtocol.Client; -using ModelContextProtocol.Protocol; -using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Utils; -using System.Text.Json; - -namespace ModelContextProtocol.Tests.Server; - -/// -/// Integration tests verifying that tools report correct ToolTaskSupport values -/// based on server configuration and method signatures. -/// -public class ToolTaskSupportTests : LoggedTest -{ - public ToolTaskSupportTests(ITestOutputHelper testOutputHelper) - : base(testOutputHelper) - { - } - - [Fact] - public async Task Tools_WithoutTaskStore_ReportForbiddenTaskSupport() - { - // Arrange - Server without a task store - await using var fixture = new ClientServerFixture( - LoggerFactory, - configureServer: builder => - { - builder.WithTools([ - McpServerTool.Create(async (string input, CancellationToken ct) => - { - await Task.Delay(10, ct); - return $"Async: {input}"; - }, - new McpServerToolCreateOptions { Name = "async-tool", Description = "An async tool" }), - - McpServerTool.Create((string input) => $"Sync: {input}", - new McpServerToolCreateOptions { Name = "sync-tool", Description = "A sync tool" }) - ]); - }); - - // Act - var tools = await fixture.Client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Both tools should have Forbidden task support when no task store is configured - Assert.Equal(2, tools.Count); - - var asyncTool = tools.Single(t => t.Name == "async-tool"); - var syncTool = tools.Single(t => t.Name == "sync-tool"); - - // Without a task store, async tools should still report Optional (their intrinsic capability) - // but the server won't have tasks in capabilities. The tool itself declares its support. - Assert.Equal(ToolTaskSupport.Optional, asyncTool.ProtocolTool.Execution?.TaskSupport); - - // Sync tools should have null Execution or Forbidden task support - Assert.True( - syncTool.ProtocolTool.Execution is null || - syncTool.ProtocolTool.Execution.TaskSupport is null || - syncTool.ProtocolTool.Execution.TaskSupport == ToolTaskSupport.Forbidden, - "Sync tools should not support task execution"); - } - - [Fact] - public async Task Tools_WithTaskStore_AsyncToolsReportOptionalTaskSupport() - { - // Arrange - Server with a task store - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ClientServerFixture( - LoggerFactory, - configureServer: builder => - { - builder.WithTools([ - McpServerTool.Create(async (string input, CancellationToken ct) => - { - await Task.Delay(10, ct); - return $"Async: {input}"; - }, - new McpServerToolCreateOptions { Name = "async-tool", Description = "An async tool" }), - - McpServerTool.Create((string input) => $"Sync: {input}", - new McpServerToolCreateOptions { Name = "sync-tool", Description = "A sync tool" }) - ]); - }, - configureServices: services => - { - services.AddSingleton(taskStore); - services.Configure(options => options.TaskStore = taskStore); - }); - - // Act - var tools = await fixture.Client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Assert.Equal(2, tools.Count); - - var asyncTool = tools.Single(t => t.Name == "async-tool"); - var syncTool = tools.Single(t => t.Name == "sync-tool"); - - // Async tools should report Optional task support - Assert.Equal(ToolTaskSupport.Optional, asyncTool.ProtocolTool.Execution?.TaskSupport); - - // Sync tools should have null Execution or Forbidden task support - Assert.True( - syncTool.ProtocolTool.Execution is null || - syncTool.ProtocolTool.Execution.TaskSupport is null || - syncTool.ProtocolTool.Execution.TaskSupport == ToolTaskSupport.Forbidden, - "Sync tools should not support task execution"); - } - - [Fact] - public async Task Tools_WithExplicitTaskSupport_ReportsConfiguredValue() - { - // Arrange - Server with explicit task support configured on tools - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ClientServerFixture( - LoggerFactory, - configureServer: builder => - { - builder.WithTools([ - McpServerTool.Create(async (string input, CancellationToken ct) => - { - await Task.Delay(10, ct); - return $"Async: {input}"; - }, - new McpServerToolCreateOptions - { - Name = "required-async-tool", - Description = "A tool that requires task execution", - Execution = new ToolExecution { TaskSupport = ToolTaskSupport.Required } - }), - - McpServerTool.Create((string input) => $"Sync: {input}", - new McpServerToolCreateOptions - { - Name = "forbidden-sync-tool", - Description = "A tool that forbids task execution", - Execution = new ToolExecution { TaskSupport = ToolTaskSupport.Forbidden } - }) - ]); - }, - configureServices: services => - { - services.AddSingleton(taskStore); - services.Configure(options => options.TaskStore = taskStore); - }); - - // Act - var tools = await fixture.Client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Assert.Equal(2, tools.Count); - - var requiredTool = tools.Single(t => t.Name == "required-async-tool"); - var forbiddenTool = tools.Single(t => t.Name == "forbidden-sync-tool"); - - Assert.Equal(ToolTaskSupport.Required, requiredTool.ProtocolTool.Execution?.TaskSupport); - Assert.Equal(ToolTaskSupport.Forbidden, forbiddenTool.ProtocolTool.Execution?.TaskSupport); - } - - [Fact] - public async Task ServerCapabilities_WithoutTaskStore_DoNotIncludeTasksCapability() - { - // Arrange - Server without a task store - await using var fixture = new ClientServerFixture( - LoggerFactory, - configureServer: builder => - { - builder.WithTools([ - McpServerTool.Create((string input) => $"Result: {input}", - new McpServerToolCreateOptions { Name = "test-tool" }) - ]); - }); - - // Assert - Server capabilities should not include tasks - Assert.Null(fixture.Client.ServerCapabilities?.Tasks); - } - - [Fact] - public async Task ServerCapabilities_WithTaskStore_IncludeTasksCapability() - { - // Arrange - Server with a task store - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ClientServerFixture( - LoggerFactory, - configureServer: builder => - { - builder.WithTools([ - McpServerTool.Create((string input) => $"Result: {input}", - new McpServerToolCreateOptions { Name = "test-tool" }) - ]); - }, - configureServices: services => - { - services.Configure(options => - { - options.TaskStore = taskStore; - }); - }); - - // Assert - Server capabilities should include tasks - Assert.NotNull(fixture.Client.ServerCapabilities?.Tasks); - Assert.NotNull(fixture.Client.ServerCapabilities.Tasks.List); - Assert.NotNull(fixture.Client.ServerCapabilities.Tasks.Cancel); - Assert.NotNull(fixture.Client.ServerCapabilities.Tasks.Requests?.Tools?.Call); - } - -#pragma warning disable MCPEXP001 // Tasks feature is experimental - [Fact] - public void McpServerToolAttribute_TaskSupport_CanBeSetOnAttribute() - { - // Test that the TaskSupport property can be set via the attribute - // and is correctly read when creating a tool - var tool = McpServerTool.Create(typeof(TaskSupportAttributeTestTools).GetMethod(nameof(TaskSupportAttributeTestTools.RequiredTaskTool))!); - Assert.NotNull(tool.ProtocolTool.Execution); - Assert.Equal(ToolTaskSupport.Required, tool.ProtocolTool.Execution.TaskSupport); - - var optionalTool = McpServerTool.Create(typeof(TaskSupportAttributeTestTools).GetMethod(nameof(TaskSupportAttributeTestTools.OptionalTaskTool))!); - Assert.NotNull(optionalTool.ProtocolTool.Execution); - Assert.Equal(ToolTaskSupport.Optional, optionalTool.ProtocolTool.Execution.TaskSupport); - - var forbiddenTool = McpServerTool.Create(typeof(TaskSupportAttributeTestTools).GetMethod(nameof(TaskSupportAttributeTestTools.ForbiddenTaskTool))!); - Assert.NotNull(forbiddenTool.ProtocolTool.Execution); - Assert.Equal(ToolTaskSupport.Forbidden, forbiddenTool.ProtocolTool.Execution.TaskSupport); - } - - [Fact] - public void McpServerToolAttribute_TaskSupport_WhenNotSet_AllowsAutoDetection() - { - // When TaskSupport is not set on the attribute, async tools should use auto-detection (Optional) - var asyncTool = McpServerTool.Create(typeof(TaskSupportAttributeTestTools).GetMethod(nameof(TaskSupportAttributeTestTools.AsyncToolWithoutTaskSupport))!); - Assert.NotNull(asyncTool.ProtocolTool.Execution); - Assert.Equal(ToolTaskSupport.Optional, asyncTool.ProtocolTool.Execution.TaskSupport); - - // Sync tools without TaskSupport set should have null Execution or Forbidden - var syncTool = McpServerTool.Create(typeof(TaskSupportAttributeTestTools).GetMethod(nameof(TaskSupportAttributeTestTools.SyncToolWithoutTaskSupport))!); - Assert.True( - syncTool.ProtocolTool.Execution is null || - syncTool.ProtocolTool.Execution.TaskSupport is null || - syncTool.ProtocolTool.Execution.TaskSupport == ToolTaskSupport.Forbidden, - "Sync tools without explicit TaskSupport should not support tasks"); - } - - [Fact] - public void McpServerToolAttribute_TaskSupport_ExplicitForbidden_OverridesAutoDetection() - { - // Verify that explicitly setting Forbidden overrides auto-detection for async methods - var forbiddenAsyncTool = McpServerTool.Create(typeof(TaskSupportAttributeTestTools).GetMethod(nameof(TaskSupportAttributeTestTools.ForbiddenAsyncTool))!); - Assert.NotNull(forbiddenAsyncTool.ProtocolTool.Execution); - Assert.Equal(ToolTaskSupport.Forbidden, forbiddenAsyncTool.ProtocolTool.Execution.TaskSupport); - } - - [Fact] - public void McpServerToolAttribute_TaskSupport_OptionalOnSyncMethod_IsAllowed() - { - // Setting Optional on a sync method is allowed - the tool will just execute very quickly - // This tests that the SDK doesn't prevent this configuration at tool creation time - var tool = McpServerTool.Create(typeof(TaskSupportAttributeTestTools).GetMethod(nameof(TaskSupportAttributeTestTools.OptionalTaskTool))!); - Assert.NotNull(tool.ProtocolTool.Execution); - Assert.Equal(ToolTaskSupport.Optional, tool.ProtocolTool.Execution.TaskSupport); - } - - [Fact] - public void McpServerToolAttribute_TaskSupport_RequiredOnSyncMethod_IsAllowed() - { - // Setting Required on a sync method is allowed - the tool will just execute very quickly - // This tests that the SDK doesn't prevent this configuration at tool creation time - var tool = McpServerTool.Create(typeof(TaskSupportAttributeTestTools).GetMethod(nameof(TaskSupportAttributeTestTools.RequiredTaskTool))!); - Assert.NotNull(tool.ProtocolTool.Execution); - Assert.Equal(ToolTaskSupport.Required, tool.ProtocolTool.Execution.TaskSupport); - } -#pragma warning restore MCPEXP001 - -#pragma warning disable MCPEXP001 // Tasks feature is experimental - [Fact] - public void McpServerToolAttribute_TaskSupport_WhenNotSet_DefaultsBasedOnMethodSignature() - { - // When TaskSupport is not set on the attribute, async tools should default to Optional - var asyncTool = McpServerTool.Create(typeof(TaskSupportAttributeTestTools).GetMethod(nameof(TaskSupportAttributeTestTools.AsyncToolWithoutTaskSupport))!); - Assert.NotNull(asyncTool.ProtocolTool.Execution); - Assert.Equal(ToolTaskSupport.Optional, asyncTool.ProtocolTool.Execution.TaskSupport); - - // Sync tools should have null or no Execution set - var syncTool = McpServerTool.Create(typeof(TaskSupportAttributeTestTools).GetMethod(nameof(TaskSupportAttributeTestTools.SyncToolWithoutTaskSupport))!); - Assert.True( - syncTool.ProtocolTool.Execution is null || - syncTool.ProtocolTool.Execution.TaskSupport is null || - syncTool.ProtocolTool.Execution.TaskSupport == ToolTaskSupport.Forbidden, - "Sync tools without explicit TaskSupport should not support tasks"); - } - - [Theory] - [InlineData(ToolTaskSupport.Forbidden, "\"forbidden\"")] - [InlineData(ToolTaskSupport.Optional, "\"optional\"")] - [InlineData(ToolTaskSupport.Required, "\"required\"")] - public void ToolTaskSupport_SerializesToJsonCorrectly(ToolTaskSupport value, string expectedJson) - { - var json = JsonSerializer.Serialize(value, McpJsonUtilities.DefaultOptions); - Assert.Equal(expectedJson, json); - } - - [Theory] - [InlineData("\"forbidden\"", ToolTaskSupport.Forbidden)] - [InlineData("\"optional\"", ToolTaskSupport.Optional)] - [InlineData("\"required\"", ToolTaskSupport.Required)] - public void ToolTaskSupport_DeserializesFromJsonCorrectly(string json, ToolTaskSupport expected) - { - var value = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); - Assert.Equal(expected, value); - } - - [Fact] - public void ToolExecution_TaskSupport_NullByDefault() - { - // Verify that ToolExecution.TaskSupport is null by default - var execution = new ToolExecution(); - Assert.Null(execution.TaskSupport); - - // When serialized with a value, it should appear correctly - var tool = new Tool - { - Name = "test", - Execution = new ToolExecution { TaskSupport = ToolTaskSupport.Optional } - }; - var toolJson = JsonSerializer.Serialize(tool, McpJsonUtilities.DefaultOptions); - Assert.Contains("\"optional\"", toolJson); - } - - [Fact] - public void McpServerToolCreateOptions_Execution_OverridesAutoDetection() - { - // When Execution is set via options, it should override auto-detection - var tool = McpServerTool.Create( - async (string input, CancellationToken ct) => - { - await Task.Delay(1, ct); - return input; - }, - new McpServerToolCreateOptions - { - Name = "test", - Execution = new ToolExecution { TaskSupport = ToolTaskSupport.Forbidden } - }); - - // Even though this is an async method, it should have Forbidden since it was explicitly set - Assert.NotNull(tool.ProtocolTool.Execution); - Assert.Equal(ToolTaskSupport.Forbidden, tool.ProtocolTool.Execution.TaskSupport); - } - - [Fact] - public void McpServerToolCreateOptions_Execution_Required_SetsCorrectly() - { - var tool = McpServerTool.Create( - (string input) => input, - new McpServerToolCreateOptions - { - Name = "test", - Execution = new ToolExecution { TaskSupport = ToolTaskSupport.Required } - }); - - Assert.NotNull(tool.ProtocolTool.Execution); - Assert.Equal(ToolTaskSupport.Required, tool.ProtocolTool.Execution.TaskSupport); - } - - [Fact] - public void ToolTaskSupport_EnumValues_AreCorrect() - { - // Verify enum values are as expected (Forbidden = 0) - Assert.Equal(0, (int)ToolTaskSupport.Forbidden); - Assert.Equal(1, (int)ToolTaskSupport.Optional); - Assert.Equal(2, (int)ToolTaskSupport.Required); - } - - [Fact] - public void McpServerToolAttribute_TaskSupport_PublicPropertyDefaultsToForbidden() - { - // Verify that the public property returns Forbidden when not set - var attr = new McpServerToolAttribute(); - Assert.Equal(ToolTaskSupport.Forbidden, attr.TaskSupport); - } -#pragma warning restore MCPEXP001 - - [McpServerToolType] - private static class TaskSupportAttributeTestTools - { -#pragma warning disable MCPEXP001 // Tasks feature is experimental - [McpServerTool(TaskSupport = ToolTaskSupport.Required)] - public static string RequiredTaskTool(string input) => $"Required: {input}"; - - [McpServerTool(TaskSupport = ToolTaskSupport.Optional)] - public static string OptionalTaskTool(string input) => $"Optional: {input}"; - - [McpServerTool(TaskSupport = ToolTaskSupport.Forbidden)] - public static string ForbiddenTaskTool(string input) => $"Forbidden: {input}"; - - [McpServerTool(TaskSupport = ToolTaskSupport.Forbidden)] - public static async Task ForbiddenAsyncTool(string input, CancellationToken ct) - { - await Task.Delay(1, ct); - return $"ForbiddenAsync: {input}"; - } -#pragma warning restore MCPEXP001 - - [McpServerTool] - public static async Task AsyncToolWithoutTaskSupport(string input, CancellationToken ct) - { - await Task.Delay(1, ct); - return $"Async: {input}"; - } - - [McpServerTool] - public static string SyncToolWithoutTaskSupport(string input) => $"Sync: {input}"; - } - - #region Sync Method with Optional/Required TaskSupport Integration Tests - -#pragma warning disable MCPEXP001 // Tasks feature is experimental - [Fact] - public async Task SyncTool_WithOptionalTaskSupport_CanBeCalledAsTask() - { - // Arrange - Server with task store and a sync tool with Optional task support - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ClientServerFixture( - LoggerFactory, - configureServer: builder => - { - builder.WithTools([McpServerTool.Create( - (string input) => $"Sync result: {input}", - new McpServerToolCreateOptions - { - Name = "optional-sync-tool", - Description = "A sync tool with optional task support", - Execution = new ToolExecution { TaskSupport = ToolTaskSupport.Optional } - })]); - }, - configureServices: services => - { - services.AddSingleton(taskStore); - services.Configure(options => options.TaskStore = taskStore); - }); - - // Act - Call the sync tool as a task - var mcpTask = await fixture.Client.CallToolAsTaskAsync( - "optional-sync-tool", - arguments: new Dictionary { ["input"] = "test" }, - taskMetadata: new McpTaskMetadata(), - progress: null, - cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Task was created successfully - Assert.NotNull(mcpTask); - Assert.NotEmpty(mcpTask.TaskId); - } - - [Fact] - public async Task SyncTool_WithRequiredTaskSupport_CanBeCalledAsTask() - { - // Arrange - Server with task store and a sync tool with Required task support - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ClientServerFixture( - LoggerFactory, - configureServer: builder => - { - builder.WithTools([McpServerTool.Create( - (string input) => $"Sync result: {input}", - new McpServerToolCreateOptions - { - Name = "required-sync-tool", - Description = "A sync tool with required task support", - Execution = new ToolExecution { TaskSupport = ToolTaskSupport.Required } - })]); - }, - configureServices: services => - { - services.AddSingleton(taskStore); - services.Configure(options => options.TaskStore = taskStore); - }); - - // Act - Call the sync tool as a task - var mcpTask = await fixture.Client.CallToolAsTaskAsync( - "required-sync-tool", - arguments: new Dictionary { ["input"] = "test" }, - taskMetadata: new McpTaskMetadata(), - progress: null, - cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Task was created successfully - Assert.NotNull(mcpTask); - Assert.NotEmpty(mcpTask.TaskId); - } - - [Fact] - public async Task SyncTool_WithRequiredTaskSupport_CannotBeCalledDirectly() - { - // Arrange - Server with task store and a sync tool with Required task support - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ClientServerFixture( - LoggerFactory, - configureServer: builder => - { - builder.WithTools([McpServerTool.Create( - (string input) => $"Sync result: {input}", - new McpServerToolCreateOptions - { - Name = "required-sync-tool", - Description = "A sync tool with required task support", - Execution = new ToolExecution { TaskSupport = ToolTaskSupport.Required } - })]); - }, - configureServices: services => - { - services.AddSingleton(taskStore); - services.Configure(options => options.TaskStore = taskStore); - }); - - // Act & Assert - Calling directly should fail because task execution is required - var exception = await Assert.ThrowsAsync(() => - fixture.Client.CallToolAsync( - "required-sync-tool", - arguments: new Dictionary { ["input"] = "test" }, - cancellationToken: TestContext.Current.CancellationToken).AsTask()); - - // The server returns InvalidParams because direct invocation is not allowed for required-task tools - Assert.Equal(McpErrorCode.InvalidParams, exception.ErrorCode); - Assert.Contains("task", exception.Message, StringComparison.OrdinalIgnoreCase); - } - - [Fact] - public async Task TaskPath_Logs_Tool_Name_On_Successful_Call() - { - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ClientServerFixture( - LoggerFactory, - configureServer: builder => - { - builder.WithTools([McpServerTool.Create( - (string input) => $"Result: {input}", - new McpServerToolCreateOptions - { - Name = "task-success-tool", - Execution = new ToolExecution { TaskSupport = ToolTaskSupport.Optional } - })]); - }, - configureServices: services => - { - services.AddSingleton(MockLoggerProvider); - services.AddSingleton(taskStore); - services.Configure(options => options.TaskStore = taskStore); - }); - - var mcpTask = await fixture.Client.CallToolAsTaskAsync( - "task-success-tool", - arguments: new Dictionary { ["input"] = "test" }, - taskMetadata: new McpTaskMetadata(), - progress: null, - cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(mcpTask); - - // Wait for the async task execution to complete - await fixture.Client.GetTaskResultAsync(mcpTask.TaskId, cancellationToken: TestContext.Current.CancellationToken); - - var infoLog = Assert.Single(MockLoggerProvider.LogMessages, m => m.Message == "\"task-success-tool\" completed. IsError = False."); - Assert.Equal(LogLevel.Information, infoLog.LogLevel); - } - - [Fact] - public async Task TaskPath_Logs_Tool_Name_With_IsError_When_Tool_Returns_Error() - { - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ClientServerFixture( - LoggerFactory, - configureServer: builder => - { - builder.WithTools([McpServerTool.Create( - () => new CallToolResult - { - IsError = true, - Content = [new TextContentBlock { Text = "Task tool error" }], - }, - new McpServerToolCreateOptions - { - Name = "task-error-result-tool", - Execution = new ToolExecution { TaskSupport = ToolTaskSupport.Optional } - })]); - }, - configureServices: services => - { - services.AddSingleton(MockLoggerProvider); - services.AddSingleton(taskStore); - services.Configure(options => options.TaskStore = taskStore); - }); - - var mcpTask = await fixture.Client.CallToolAsTaskAsync( - "task-error-result-tool", - taskMetadata: new McpTaskMetadata(), - progress: null, - cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(mcpTask); - - // Wait for the async task execution to complete - await fixture.Client.GetTaskResultAsync(mcpTask.TaskId, cancellationToken: TestContext.Current.CancellationToken); - - var infoLog = Assert.Single(MockLoggerProvider.LogMessages, m => m.Message == "\"task-error-result-tool\" completed. IsError = True."); - Assert.Equal(LogLevel.Information, infoLog.LogLevel); - } - - [Fact] - public async Task TaskPath_Logs_Error_When_Tool_Throws() - { - var taskStore = new InMemoryMcpTaskStore(); - - await using var fixture = new ClientServerFixture( - LoggerFactory, - configureServer: builder => - { - builder.WithTools([McpServerTool.Create( - string () => throw new InvalidOperationException("Task tool error"), - new McpServerToolCreateOptions - { - Name = "task-throw-tool", - Execution = new ToolExecution { TaskSupport = ToolTaskSupport.Optional } - })]); - }, - configureServices: services => - { - services.AddSingleton(MockLoggerProvider); - services.AddSingleton(taskStore); - services.Configure(options => options.TaskStore = taskStore); - }); - - var mcpTask = await fixture.Client.CallToolAsTaskAsync( - "task-throw-tool", - taskMetadata: new McpTaskMetadata(), - progress: null, - cancellationToken: TestContext.Current.CancellationToken); - - Assert.NotNull(mcpTask); - - // Wait for the async task execution to complete - await fixture.Client.GetTaskResultAsync(mcpTask.TaskId, cancellationToken: TestContext.Current.CancellationToken); - - var errorLog = Assert.Single(MockLoggerProvider.LogMessages, m => m.LogLevel == LogLevel.Error); - Assert.Equal("\"task-throw-tool\" threw an unhandled exception.", errorLog.Message); - Assert.IsType(errorLog.Exception); - } -#pragma warning restore MCPEXP001 - - #endregion - - /// - /// A fixture that creates a connected MCP client-server pair for testing. - /// - private sealed class ClientServerFixture : IAsyncDisposable - { - private readonly System.IO.Pipelines.Pipe _clientToServerPipe = new(); - private readonly System.IO.Pipelines.Pipe _serverToClientPipe = new(); - private readonly CancellationTokenSource _cts; - private readonly Task _serverTask; - private readonly IServiceProvider _serviceProvider; - - public McpClient Client { get; } - public McpServer Server { get; } - - public ClientServerFixture( - ILoggerFactory loggerFactory, - Action? configureServer, - Action? configureServices = null) - { - ServiceCollection sc = new(); - sc.AddLogging(); - - var builder = sc - .AddMcpServer() - .WithStreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream()); - - configureServer?.Invoke(builder); - configureServices?.Invoke(sc); - - _serviceProvider = sc.BuildServiceProvider(validateScopes: true); - _cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); - - Server = _serviceProvider.GetRequiredService(); - _serverTask = Server.RunAsync(_cts.Token); - - // Create client synchronously by blocking - this is test code - Client = McpClient.CreateAsync( - new StreamClientTransport( - serverInput: _clientToServerPipe.Writer.AsStream(), - _serverToClientPipe.Reader.AsStream(), - loggerFactory), - loggerFactory: loggerFactory, - cancellationToken: TestContext.Current.CancellationToken).GetAwaiter().GetResult(); - } - - public async ValueTask DisposeAsync() - { - await Client.DisposeAsync(); - await _cts.CancelAsync(); - - _clientToServerPipe.Writer.Complete(); - _serverToClientPipe.Writer.Complete(); - - await _serverTask; - - if (_serviceProvider is IAsyncDisposable asyncDisposable) - { - await asyncDisposable.DisposeAsync(); - } - else if (_serviceProvider is IDisposable disposable) - { - disposable.Dispose(); - } - - _cts.Dispose(); - } - } -}