diff --git a/.gitignore b/.gitignore index a1b83bc4f..75c943a57 100644 --- a/.gitignore +++ b/.gitignore @@ -133,3 +133,4 @@ dist/ # IDE .idea/ +.cursor/ \ No newline at end of file diff --git a/examples/server/src/simpleStreamableHttp.ts b/examples/server/src/simpleStreamableHttp.ts index 7613e3786..0332ab81a 100644 --- a/examples/server/src/simpleStreamableHttp.ts +++ b/examples/server/src/simpleStreamableHttp.ts @@ -480,6 +480,7 @@ const getServer = () => { { async createTask({ duration }, { taskStore, taskRequestedTtl }) { // Create the task + if (!taskStore) throw new Error('Task store not found'); const task = await taskStore.createTask({ ttl: taskRequestedTtl }); @@ -503,10 +504,12 @@ const getServer = () => { }; }, async getTask(_args, { taskId, taskStore }) { - return await taskStore.getTask(taskId); + if (!taskStore) throw new Error('Task store not found'); + return await taskStore.getTask(taskId!); }, async getTaskResult(_args, { taskId, taskStore }) { - const result = await taskStore.getTaskResult(taskId); + if (!taskStore) throw new Error('Task store not found'); + const result = await taskStore.getTaskResult(taskId!); return result as CallToolResult; } } diff --git a/packages/core/src/experimental/tasks/interfaces.ts b/packages/core/src/experimental/tasks/interfaces.ts index c1901d70a..4bf11942c 100644 --- a/packages/core/src/experimental/tasks/interfaces.ts +++ b/packages/core/src/experimental/tasks/interfaces.ts @@ -3,7 +3,6 @@ * WARNING: These APIs are experimental and may change without notice. */ -import type { RequestHandlerExtra, RequestTaskStore } from '../../shared/protocol.js'; import type { JSONRPCErrorResponse, JSONRPCNotification, @@ -12,8 +11,6 @@ import type { Request, RequestId, Result, - ServerNotification, - ServerRequest, Task, ToolExecution } from '../../types/types.js'; @@ -22,23 +19,6 @@ import type { // Task Handler Types (for registerToolTask) // ============================================================================ -/** - * Extended handler extra with task store for task creation. - * @experimental - */ -export interface CreateTaskRequestHandlerExtra extends RequestHandlerExtra { - taskStore: RequestTaskStore; -} - -/** - * Extended handler extra with task ID and store for task operations. - * @experimental - */ -export interface TaskRequestHandlerExtra extends RequestHandlerExtra { - taskId: string; - taskStore: RequestTaskStore; -} - /** * Task-specific execution configuration. * taskSupport cannot be 'forbidden' for task-based tools. diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index 9c65015d1..bfbc45c03 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -237,6 +237,8 @@ export interface RequestTaskStore { /** * Extra data given to request handlers. + * + * @deprecated Use {@link ContextInterface} from {@link Context} instead. Future major versions will remove this type. */ export type RequestHandlerExtra = { /** @@ -718,43 +720,15 @@ export abstract class Protocol = { - signal: abortController.signal, - sessionId: capturedTransport?.sessionId, - _meta: request.params?._meta, - sendNotification: async notification => { - // Include related-task metadata if this request is part of a task - const notificationOptions: NotificationOptions = { relatedRequestId: request.id }; - if (relatedTaskId) { - notificationOptions.relatedTask = { taskId: relatedTaskId }; - } - await this.notification(notification, notificationOptions); - }, - sendRequest: async (r, resultSchema, options?) => { - // Include related-task metadata if this request is part of a task - const requestOptions: RequestOptions = { ...options, relatedRequestId: request.id }; - if (relatedTaskId && !requestOptions.relatedTask) { - requestOptions.relatedTask = { taskId: relatedTaskId }; - } - - // Set task status to input_required when sending a request within a task context - // Use the taskId from options (explicit) or fall back to relatedTaskId (inherited) - const effectiveTaskId = requestOptions.relatedTask?.taskId ?? relatedTaskId; - if (effectiveTaskId && taskStore) { - await taskStore.updateTaskStatus(effectiveTaskId, 'input_required'); - } - - return await this.request(r, resultSchema, requestOptions); - }, - authInfo: extra?.authInfo, - requestId: request.id, - requestInfo: extra?.requestInfo, - taskId: relatedTaskId, - taskStore: taskStore, - taskRequestedTtl: taskCreationParams?.ttl, - closeSSEStream: extra?.closeSSEStream, - closeStandaloneSSEStream: extra?.closeStandaloneSSEStream - }; + const fullExtra: RequestHandlerExtra = this.createRequestExtra({ + request, + taskStore, + relatedTaskId, + taskCreationParams, + abortController, + capturedTransport, + extra + }); // Starting with Promise.resolve() puts any synchronous errors into the monad as well. Promise.resolve() @@ -832,6 +806,60 @@ export abstract class Protocol { + const { request, taskStore, relatedTaskId, taskCreationParams, abortController, capturedTransport, extra } = args; + + return { + signal: abortController.signal, + sessionId: capturedTransport?.sessionId, + _meta: request.params?._meta, + sendNotification: async notification => { + // Include related-task metadata if this request is part of a task + const notificationOptions: NotificationOptions = { relatedRequestId: request.id }; + if (relatedTaskId) { + notificationOptions.relatedTask = { taskId: relatedTaskId }; + } + await this.notification(notification, notificationOptions); + }, + sendRequest: async (r, resultSchema, options?) => { + // Include related-task metadata if this request is part of a task + const requestOptions: RequestOptions = { ...options, relatedRequestId: request.id }; + if (relatedTaskId && !requestOptions.relatedTask) { + requestOptions.relatedTask = { taskId: relatedTaskId }; + } + + // Set task status to input_required when sending a request within a task context + // Use the taskId from options (explicit) or fall back to relatedTaskId (inherited) + const effectiveTaskId = requestOptions.relatedTask?.taskId ?? relatedTaskId; + if (effectiveTaskId && taskStore) { + await taskStore.updateTaskStatus(effectiveTaskId, 'input_required'); + } + + return await this.request(r, resultSchema, requestOptions); + }, + authInfo: extra?.authInfo, + requestId: request.id, + requestInfo: extra?.requestInfo, + taskId: relatedTaskId, + taskStore: taskStore, + taskRequestedTtl: taskCreationParams?.ttl, + closeSSEStream: extra?.closeSSEStream, + closeStandaloneSSEStream: extra?.closeStandaloneSSEStream + } as RequestHandlerExtra; + } + private _onprogress(notification: ProgressNotification): void { const { progressToken, ...params } = notification.params; const messageId = Number(progressToken); diff --git a/packages/core/test/shared/protocol.test.ts b/packages/core/test/shared/protocol.test.ts index b16a4453d..9f3ca111e 100644 --- a/packages/core/test/shared/protocol.test.ts +++ b/packages/core/test/shared/protocol.test.ts @@ -2095,7 +2095,7 @@ describe('Request Cancellation vs Task Cancellation', () => { let wasAborted = false; const TestRequestSchema = z.object({ method: z.literal('test/longRunning'), - params: z.optional(z.record(z.unknown())) + params: z.optional(z.record(z.string(), z.unknown())) }); protocol.setRequestHandler(TestRequestSchema, async (_request, extra) => { // Simulate a long-running operation @@ -2310,7 +2310,7 @@ describe('Request Cancellation vs Task Cancellation', () => { let requestCompleted = false; const TestMethodSchema = z.object({ method: z.literal('test/method'), - params: z.optional(z.record(z.unknown())) + params: z.optional(z.record(z.string(), z.unknown())) }); protocol.setRequestHandler(TestMethodSchema, async () => { await new Promise(resolve => setTimeout(resolve, 50)); @@ -3690,7 +3690,7 @@ describe('Message Interception', () => { method: z.literal('test/taskRequest'), params: z .object({ - _meta: z.optional(z.record(z.unknown())) + _meta: z.optional(z.record(z.string(), z.unknown())) }) .passthrough() }); @@ -3737,7 +3737,7 @@ describe('Message Interception', () => { method: z.literal('test/taskRequestError'), params: z .object({ - _meta: z.optional(z.record(z.unknown())) + _meta: z.optional(z.record(z.string(), z.unknown())) }) .passthrough() }); @@ -3817,7 +3817,7 @@ describe('Message Interception', () => { // Set up a request handler const TestRequestSchema = z.object({ method: z.literal('test/normalRequest'), - params: z.optional(z.record(z.unknown())) + params: z.optional(z.record(z.string(), z.unknown())) }); protocol.setRequestHandler(TestRequestSchema, async () => { diff --git a/packages/server/src/experimental/tasks/interfaces.ts b/packages/server/src/experimental/tasks/interfaces.ts index 0b32be213..574d256db 100644 --- a/packages/server/src/experimental/tasks/interfaces.ts +++ b/packages/server/src/experimental/tasks/interfaces.ts @@ -6,14 +6,15 @@ import type { AnySchema, CallToolResult, - CreateTaskRequestHandlerExtra, CreateTaskResult, GetTaskResult, Result, - TaskRequestHandlerExtra, + ServerNotification, + ServerRequest, ZodRawShapeCompat } from '@modelcontextprotocol/core'; +import type { ContextInterface } from '../../server/context.js'; import type { BaseToolCallback } from '../../server/mcp.js'; // ============================================================================ @@ -27,7 +28,7 @@ import type { BaseToolCallback } from '../../server/mcp.js'; export type CreateTaskRequestHandler< SendResultT extends Result, Args extends undefined | ZodRawShapeCompat | AnySchema = undefined -> = BaseToolCallback; +> = BaseToolCallback, Args>; /** * Handler for task operations (get, getResult). @@ -36,7 +37,7 @@ export type CreateTaskRequestHandler< export type TaskRequestHandler< SendResultT extends Result, Args extends undefined | ZodRawShapeCompat | AnySchema = undefined -> = BaseToolCallback; +> = BaseToolCallback, Args>; /** * Interface for task-based tool handlers. diff --git a/packages/server/src/index.ts b/packages/server/src/index.ts index 4b0c42053..e98b80007 100644 --- a/packages/server/src/index.ts +++ b/packages/server/src/index.ts @@ -1,4 +1,5 @@ export * from './server/completable.js'; +export * from './server/context.js'; export * from './server/express.js'; export * from './server/mcp.js'; export * from './server/server.js'; diff --git a/packages/server/src/server/context.ts b/packages/server/src/server/context.ts new file mode 100644 index 000000000..f133460e1 --- /dev/null +++ b/packages/server/src/server/context.ts @@ -0,0 +1,343 @@ +import type { + AnySchema, + AuthInfo, + CreateMessageRequest, + CreateMessageResult, + ElicitRequest, + ElicitResult, + JSONRPCRequest, + LoggingMessageNotification, + Notification, + Request, + RequestHandlerExtra, + RequestId, + RequestInfo, + RequestMeta, + RequestOptions, + RequestTaskStore, + Result, + SchemaOutput, + ServerNotification, + ServerRequest +} from '@modelcontextprotocol/core'; +import { ElicitResultSchema } from '@modelcontextprotocol/core'; + +import type { Server } from './server.js'; + +/** + * Interface for sending logging messages to the client via {@link LoggingMessageNotification}. + */ +export interface LoggingMessageNotificationSenderInterface { + /** + * Sends a logging message to the client. + */ + log(params: LoggingMessageNotification['params'], sessionId?: string): Promise; + /** + * Sends a debug log message to the client. + */ + debug(message: string, extraLogData?: Record, sessionId?: string): Promise; + /** + * Sends an info log message to the client. + */ + info(message: string, extraLogData?: Record, sessionId?: string): Promise; + /** + * Sends a warning log message to the client. + */ + warning(message: string, extraLogData?: Record, sessionId?: string): Promise; + /** + * Sends an error log message to the client. + */ + error(message: string, extraLogData?: Record, sessionId?: string): Promise; +} + +export class ServerLogger implements LoggingMessageNotificationSenderInterface { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + constructor(private readonly server: Server) {} + + /** + * Sends a logging message. + */ + public async log(params: LoggingMessageNotification['params'], sessionId?: string) { + await this.server.sendLoggingMessage(params, sessionId); + } + + /** + * Sends a debug log message. + */ + public async debug(message: string, extraLogData?: Record, sessionId?: string) { + await this.log( + { + level: 'debug', + data: { + ...extraLogData, + message + }, + logger: 'server' + }, + sessionId + ); + } + + /** + * Sends an info log message. + */ + public async info(message: string, extraLogData?: Record, sessionId?: string) { + await this.log( + { + level: 'info', + data: { + ...extraLogData, + message + }, + logger: 'server' + }, + sessionId + ); + } + + /** + * Sends a warning log message. + */ + public async warning(message: string, extraLogData?: Record, sessionId?: string) { + await this.log( + { + level: 'warning', + data: { + ...extraLogData, + message + }, + logger: 'server' + }, + sessionId + ); + } + + /** + * Sends an error log message. + */ + public async error(message: string, extraLogData?: Record, sessionId?: string) { + await this.log( + { + level: 'error', + data: { + ...extraLogData, + message + }, + logger: 'server' + }, + sessionId + ); + } +} + +export interface ContextInterface + extends RequestHandlerExtra { + elicitInput(params: ElicitRequest['params'], options?: RequestOptions): Promise; + requestSampling: (params: CreateMessageRequest['params'], options?: RequestOptions) => Promise; + loggingNotification: LoggingMessageNotificationSenderInterface; +} +/** + * A context object that is passed to request handlers. + * + * Implements the RequestHandlerExtra interface for backwards compatibility. + */ +export class Context + implements ContextInterface +{ + private readonly server: Server; + + /** + * The request context. + * A type-safe context that is passed to request handlers. + */ + private readonly requestCtx: RequestHandlerExtra; + + /** + * The MCP context - Contains information about the current MCP request and session. + */ + public readonly mcpContext: { + /** + * The JSON-RPC ID of the request being handled. + * This can be useful for tracking or logging purposes. + */ + requestId: RequestId; + /** + * The method of the request. + */ + method: string; + /** + * The metadata of the request. + */ + _meta?: RequestMeta; + /** + * The session ID of the request. + */ + sessionId?: string; + }; + + public readonly task: + | { + id: string | undefined; + store: RequestTaskStore | undefined; + requestedTtl: number | null | undefined; + } + | undefined; + + public readonly stream: + | { + /** + * Closes the SSE stream for this request, triggering client reconnection. + * Only available when using StreamableHTTPServerTransport with eventStore configured. + * Use this to implement polling behavior during long-running operations. + */ + closeSSEStream: (() => void) | undefined; + /** + * Closes the standalone GET SSE stream, triggering client reconnection. + * Only available when using StreamableHTTPServerTransport with eventStore configured. + * Use this to implement polling behavior for server-initiated notifications. + */ + closeStandaloneSSEStream: (() => void) | undefined; + } + | undefined; + + public readonly loggingNotification: LoggingMessageNotificationSenderInterface; + + constructor(args: { + server: Server; + request: JSONRPCRequest; + requestCtx: RequestHandlerExtra; + }) { + this.server = args.server; + this.requestCtx = args.requestCtx; + this.mcpContext = { + requestId: args.requestCtx.requestId, + method: args.request.method, + _meta: args.requestCtx._meta, + sessionId: args.requestCtx.sessionId + }; + + this.task = { + id: args.requestCtx.taskId, + store: args.requestCtx.taskStore, + requestedTtl: args.requestCtx.taskRequestedTtl + }; + + this.loggingNotification = new ServerLogger(args.server); + + this.stream = { + closeSSEStream: args.requestCtx.closeSSEStream, + closeStandaloneSSEStream: args.requestCtx.closeStandaloneSSEStream + }; + } + + /** + * The JSON-RPC ID of the request being handled. + * This can be useful for tracking or logging purposes. + * + * @deprecated Use {@link mcpContext.requestId} instead. + */ + public get requestId(): RequestId { + return this.requestCtx.requestId; + } + + public get signal(): AbortSignal { + return this.requestCtx.signal; + } + + public get authInfo(): AuthInfo | undefined { + return this.requestCtx.authInfo; + } + + public get requestInfo(): RequestInfo | undefined { + return this.requestCtx.requestInfo; + } + + /** + * @deprecated Use {@link mcpContext._meta} instead. + */ + public get _meta(): RequestMeta | undefined { + return this.requestCtx._meta; + } + + /** + * @deprecated Use {@link mcpContext.sessionId} instead. + */ + public get sessionId(): string | undefined { + return this.mcpContext.sessionId; + } + + /** + * @deprecated Use {@link task.id} instead. + */ + public get taskId(): string | undefined { + return this.requestCtx.taskId; + } + + /** + * @deprecated Use {@link task.store} instead. + */ + public get taskStore(): RequestTaskStore | undefined { + return this.requestCtx.taskStore; + } + + /** + * @deprecated Use {@link task.requestedTtl} instead. + */ + public get taskRequestedTtl(): number | undefined { + return this.requestCtx.taskRequestedTtl ?? undefined; + } + + /** + * @deprecated Use {@link stream.closeSSEStream} instead. + */ + public get closeSSEStream(): (() => void) | undefined { + return this.requestCtx.closeSSEStream; + } + + /** + * @deprecated Use {@link stream.closeStandaloneSSEStream} instead. + */ + public get closeStandaloneSSEStream(): (() => void) | undefined { + return this.requestCtx.closeStandaloneSSEStream; + } + + /** + * Sends a notification that relates to the current request being handled. + * + * This is used by certain transports to correctly associate related messages. + */ + public sendNotification = (notification: NotificationT | ServerNotification): Promise => { + return this.requestCtx.sendNotification(notification); + }; + + /** + * Sends a request that relates to the current request being handled. + * + * This is used by certain transports to correctly associate related messages. + */ + public sendRequest = ( + request: RequestT | ServerRequest, + resultSchema: U, + options?: RequestOptions + ): Promise> => { + return this.requestCtx.sendRequest(request, resultSchema, { ...options, relatedRequestId: this.requestId }); + }; + + /** + * Sends a request to sample an LLM via the client. + */ + public requestSampling(params: CreateMessageRequest['params'], options?: RequestOptions) { + return this.server.createMessage(params, options); + } + + /** + * Sends an elicitation request to the client. + */ + public async elicitInput(params: ElicitRequest['params'], options?: RequestOptions): Promise { + const request: ElicitRequest = { + method: 'elicitation/create', + params + }; + return await this.server.request(request, ElicitResultSchema, { ...options, relatedRequestId: this.requestId }); + } +} diff --git a/packages/server/src/server/mcp.ts b/packages/server/src/server/mcp.ts index 8564212c1..64cda5269 100644 --- a/packages/server/src/server/mcp.ts +++ b/packages/server/src/server/mcp.ts @@ -18,7 +18,6 @@ import type { PromptArgument, PromptReference, ReadResourceResult, - RequestHandlerExtra, Resource, ResourceTemplateReference, Result, @@ -63,6 +62,7 @@ import { ZodOptional } from 'zod'; import type { ToolTaskHandler } from '../experimental/tasks/interfaces.js'; import { ExperimentalMcpServerTasks } from '../experimental/tasks/mcp-server.js'; import { getCompleter, isCompletable } from './completable.js'; +import type { ContextInterface } from './context.js'; import type { ServerOptions } from './server.js'; import { Server } from './server.js'; @@ -326,7 +326,7 @@ export class McpServer { private async executeToolHandler( tool: RegisteredTool, args: unknown, - extra: RequestHandlerExtra + extra: ContextInterface ): Promise { const handler = tool.handler as AnyToolHandler; const isTaskHandler = 'createTask' in handler; @@ -340,7 +340,7 @@ export class McpServer { if (tool.inputSchema) { const typedHandler = handler as ToolTaskHandler; // eslint-disable-next-line @typescript-eslint/no-explicit-any - return await Promise.resolve(typedHandler.createTask(args as any, taskExtra)); + return await Promise.resolve(typedHandler.createTask(args as any, extra)); } else { const typedHandler = handler as ToolTaskHandler; // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -365,7 +365,7 @@ export class McpServer { private async handleAutomaticTaskPolling( tool: RegisteredTool, request: RequestT, - extra: RequestHandlerExtra + extra: ContextInterface ): Promise { if (!extra.taskStore) { throw new Error('No task store provided for task-capable tool.'); @@ -374,12 +374,11 @@ export class McpServer { // Validate input and create task const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); const handler = tool.handler as ToolTaskHandler; - const taskExtra = { ...extra, taskStore: extra.taskStore }; const createTaskResult: CreateTaskResult = args // undefined only if tool.inputSchema is undefined - ? await Promise.resolve((handler as ToolTaskHandler).createTask(args, taskExtra)) + ? await Promise.resolve((handler as ToolTaskHandler).createTask(args, extra)) : // eslint-disable-next-line @typescript-eslint/no-explicit-any - await Promise.resolve(((handler as ToolTaskHandler).createTask as any)(taskExtra)); + await Promise.resolve(((handler as ToolTaskHandler).createTask as any)(extra)); // Poll until completion const taskId = createTaskResult.task.taskId; @@ -1137,7 +1136,7 @@ export class McpServer { /** * Registers a prompt with a config object and callback. */ - registerPrompt( + registerPrompt( name: string, config: { title?: string; @@ -1272,7 +1271,7 @@ export class ResourceTemplate { export type BaseToolCallback< SendResultT extends Result, - Extra extends RequestHandlerExtra, + Extra extends ContextInterface, Args extends undefined | ZodRawShapeCompat | AnySchema > = Args extends ZodRawShapeCompat ? (args: ShapeOutput, extra: Extra) => SendResultT | Promise @@ -1292,7 +1291,7 @@ export type BaseToolCallback< */ export type ToolCallback = BaseToolCallback< CallToolResult, - RequestHandlerExtra, + ContextInterface, Args >; @@ -1411,7 +1410,7 @@ export type ResourceMetadata = Omit; * Callback to list all resources matching a given template. */ export type ListResourcesCallback = ( - extra: RequestHandlerExtra + extra: ContextInterface ) => ListResourcesResult | Promise; /** @@ -1419,7 +1418,7 @@ export type ListResourcesCallback = ( */ export type ReadResourceCallback = ( uri: URL, - extra: RequestHandlerExtra + extra: ContextInterface ) => ReadResourceResult | Promise; export type RegisteredResource = { @@ -1447,7 +1446,7 @@ export type RegisteredResource = { export type ReadResourceTemplateCallback = ( uri: URL, variables: Variables, - extra: RequestHandlerExtra + extra: ContextInterface ) => ReadResourceResult | Promise; export type RegisteredResourceTemplate = { @@ -1472,8 +1471,8 @@ export type RegisteredResourceTemplate = { type PromptArgsRawShape = ZodRawShapeCompat; export type PromptCallback = Args extends PromptArgsRawShape - ? (args: ShapeOutput, extra: RequestHandlerExtra) => GetPromptResult | Promise - : (extra: RequestHandlerExtra) => GetPromptResult | Promise; + ? (args: ShapeOutput, extra: ContextInterface) => GetPromptResult | Promise + : (extra: ContextInterface) => GetPromptResult | Promise; export type RegisteredPrompt = { title?: string; diff --git a/packages/server/src/server/server.ts b/packages/server/src/server/server.ts index 8132e342b..6cc4c104a 100644 --- a/packages/server/src/server/server.ts +++ b/packages/server/src/server/server.ts @@ -12,11 +12,13 @@ import type { Implementation, InitializeRequest, InitializeResult, + JSONRPCRequest, JsonSchemaType, jsonSchemaValidator, ListRootsRequest, LoggingLevel, LoggingMessageNotification, + MessageExtraInfo, Notification, NotificationOptions, ProtocolOptions, @@ -30,8 +32,11 @@ import type { ServerNotification, ServerRequest, ServerResult, + TaskCreationParams, + TaskStore, ToolResultContent, ToolUseContent, + Transport, ZodV3Internal, ZodV4Internal } from '@modelcontextprotocol/core'; @@ -63,6 +68,7 @@ import { } from '@modelcontextprotocol/core'; import { ExperimentalServerTasks } from '../experimental/tasks/server.js'; +import { Context } from './context.js'; export type ServerOptions = ProtocolOptions & { /** @@ -226,9 +232,31 @@ export class Server< requestSchema: T, handler: ( request: SchemaOutput, - extra: RequestHandlerExtra + extra: Context ) => ServerResult | ResultT | Promise ): void { + // Wrap the handler to ensure the extra is a Context and return a decorated handler that can be passed to the base implementation + + // Factory function to create a handler decorator that ensures the extra is a Context and returns a decorated handler that can be passed to the base implementation + const handlerDecoratorFactory = ( + innerHandler: ( + request: SchemaOutput, + extra: Context + ) => ServerResult | ResultT | Promise + ) => { + const decoratedHandler = ( + request: SchemaOutput, + extra: RequestHandlerExtra + ) => { + if (!this.isContextExtra(extra)) { + throw new Error('Internal error: Expected Context for request handler extra'); + } + return innerHandler(request, extra); + }; + + return decoratedHandler; + }; + const shape = getObjectShape(requestSchema); const methodSchema = shape?.method; if (!methodSchema) { @@ -266,7 +294,7 @@ export class Server< const { params } = validatedRequest.data; - const result = await Promise.resolve(handler(request, extra)); + const result = await Promise.resolve(handlerDecoratorFactory(handler)(request, extra)); // When task creation is requested, validate and return CreateTaskResult if (params.task) { @@ -293,11 +321,18 @@ export class Server< }; // Install the wrapped handler - return super.setRequestHandler(requestSchema, wrappedHandler as unknown as typeof handler); + return super.setRequestHandler(requestSchema, handlerDecoratorFactory(wrappedHandler)); } // Other handlers use default behavior - return super.setRequestHandler(requestSchema, handler); + return super.setRequestHandler(requestSchema, handlerDecoratorFactory(handler)); + } + + // Runtime type guard: ensure extra is our Context + private isContextExtra( + extra: RequestHandlerExtra + ): extra is Context { + return extra instanceof Context; } protected assertCapabilityForMethod(method: RequestT['method']): void { @@ -475,6 +510,25 @@ export class Server< return this._capabilities; } + protected override createRequestExtra(args: { + request: JSONRPCRequest; + taskStore: TaskStore | undefined; + relatedTaskId: string | undefined; + taskCreationParams: TaskCreationParams | undefined; + abortController: AbortController; + capturedTransport: Transport | undefined; + extra?: MessageExtraInfo; + }): RequestHandlerExtra { + const base = super.createRequestExtra(args) as RequestHandlerExtra; + + // Expose a Context instance to handlers, which implements RequestHandlerExtra + return new Context({ + server: this, + request: args.request, + requestCtx: base + }); + } + async ping() { return this.request({ method: 'ping' }, EmptyResultSchema); } diff --git a/test/integration/test/issues/test_1277_zod_v4_description.test.ts b/test/integration/test/issues/test_1277_zod_v4_description.test.ts index fe58cfcd5..75a61cb36 100644 --- a/test/integration/test/issues/test_1277_zod_v4_description.test.ts +++ b/test/integration/test/issues/test_1277_zod_v4_description.test.ts @@ -9,7 +9,8 @@ import { Client } from '@modelcontextprotocol/client'; import { InMemoryTransport, ListPromptsResultSchema } from '@modelcontextprotocol/core'; import { McpServer } from '@modelcontextprotocol/server'; -import { type ZodMatrixEntry, zodTestMatrix } from '@modelcontextprotocol/test-helpers'; +import type { ZodMatrixEntry } from '@modelcontextprotocol/test-helpers'; +import { zodTestMatrix } from '@modelcontextprotocol/test-helpers'; describe.each(zodTestMatrix)('Issue #1277: $zodVersionLabel', (entry: ZodMatrixEntry) => { const { z } = entry; diff --git a/test/integration/test/server/context.test.ts b/test/integration/test/server/context.test.ts new file mode 100644 index 000000000..10ecc0e98 --- /dev/null +++ b/test/integration/test/server/context.test.ts @@ -0,0 +1,273 @@ +import { Client } from '@modelcontextprotocol/client'; +import type { RequestHandlerExtra, ServerNotification, ServerRequest } from '@modelcontextprotocol/core'; +import { + CallToolResultSchema, + GetPromptResultSchema, + InMemoryTransport, + ListResourcesResultSchema, + LoggingMessageNotificationSchema, + ReadResourceResultSchema +} from '@modelcontextprotocol/core'; +import { Context, McpServer, ResourceTemplate } from '@modelcontextprotocol/server'; +import { z } from 'zod/v4'; + +describe('Context', () => { + /*** + * Test: `extra` provided to callbacks is Context (parameterized) + */ + type Seen = { isContext: boolean; hasRequestId: boolean }; + const contextCases: Array<[string, (mcpServer: McpServer, seen: Seen) => void | Promise, (client: Client) => Promise]> = + [ + [ + 'tool', + (mcpServer, seen) => { + mcpServer.registerTool( + 'ctx-tool', + { + inputSchema: z.object({ name: z.string() }) + }, + (_args: { name: string }, extra) => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { content: [{ type: 'text', text: 'ok' }] }; + } + ); + }, + client => + client.request( + { + method: 'tools/call', + params: { + name: 'ctx-tool', + arguments: { + name: 'ctx-tool-name' + } + } + }, + CallToolResultSchema + ) + ], + [ + 'resource', + (mcpServer, seen) => { + mcpServer.registerResource('ctx-resource', 'test://res/1', { title: 'ctx-resource' }, async (_uri, extra) => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { contents: [{ uri: 'test://res/1', mimeType: 'text/plain', text: 'hello' }] }; + }); + }, + client => client.request({ method: 'resources/read', params: { uri: 'test://res/1' } }, ReadResourceResultSchema) + ], + [ + 'resource template list', + (mcpServer, seen) => { + const template = new ResourceTemplate('test://items/{id}', { + list: async extra => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { resources: [] }; + } + }); + mcpServer.registerResource('ctx-template', template, { title: 'ctx-template' }, async (_uri, _vars, _extra) => ({ + contents: [] + })); + }, + client => client.request({ method: 'resources/list', params: {} }, ListResourcesResultSchema) + ], + [ + 'prompt', + (mcpServer, seen) => { + mcpServer.registerPrompt('ctx-prompt', {}, async extra => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { messages: [] }; + }); + }, + client => client.request({ method: 'prompts/get', params: { name: 'ctx-prompt', arguments: {} } }, GetPromptResultSchema) + ] + ]; + + test.each(contextCases)('should pass Context as extra to %s callbacks', async (_kind, register, trigger) => { + const mcpServer = new McpServer({ name: 'ctx-test', version: '1.0' }); + const client = new Client({ name: 'ctx-client', version: '1.0' }); + + const seen: Seen = { isContext: false, hasRequestId: false }; + + await register(mcpServer, seen); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await trigger(client); + + expect(seen.isContext).toBe(true); + expect(seen.hasRequestId).toBe(true); + }); + + const logLevelsThroughContext = ['debug', 'info', 'warning', 'error'] as const; + + //it.each for each log level, test that logging message is sent to client + it.each(logLevelsThroughContext)('should send logging message to client for %s level from Context', async level => { + const mcpServer = new McpServer( + { name: 'ctx-test', version: '1.0' }, + { + capabilities: { + logging: {} + } + } + ); + const client = new Client( + { name: 'ctx-client', version: '1.0' }, + { + capabilities: {} + } + ); + + let seen = 0; + + client.setNotificationHandler(LoggingMessageNotificationSchema, notification => { + seen++; + expect(notification.params.level).toBe(level); + expect(notification.params.data).toBe('Test message'); + expect(notification.params._meta?.test).toBe('test'); + expect(notification.params._meta?.sessionId).toBe('sample-session-id'); + return; + }); + + mcpServer.registerTool('ctx-log-test', { inputSchema: z.object({ name: z.string() }) }, async (_args: { name: string }, extra) => { + await extra.loggingNotification[level]('Test message', { test: 'test' }, 'sample-session-id'); + await extra.loggingNotification.log( + { + level, + data: 'Test message', + logger: 'test-logger-namespace', + _meta: { + test: 'test', + sessionId: 'sample-session-id' + } + }, + 'sample-session-id' + ); + return { content: [{ type: 'text', text: 'ok' }] }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'tools/call', + params: { name: 'ctx-log-test', arguments: { name: 'ctx-log-test-name' } } + }, + CallToolResultSchema + ); + + // two messages should have been sent - one from the .log method and one from the .debug/info/warning/error method + expect(seen).toBe(2); + + expect(result.content).toHaveLength(1); + expect(result.content[0]).toMatchObject({ + type: 'text', + text: 'ok' + }); + }); + describe('Legacy RequestHandlerExtra API', () => { + const contextCases: Array< + [string, (mcpServer: McpServer, seen: Seen) => void | Promise, (client: Client) => Promise] + > = [ + [ + 'tool', + (mcpServer, seen) => { + mcpServer.registerTool( + 'ctx-tool', + { + inputSchema: z.object({ name: z.string() }) + }, + // The test is to ensure that the extra is compatible with the RequestHandlerExtra type + (_args: { name: string }, extra: RequestHandlerExtra) => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { content: [{ type: 'text', text: 'ok' }] }; + } + ); + }, + client => + client.request( + { + method: 'tools/call', + params: { + name: 'ctx-tool', + arguments: { + name: 'ctx-tool-name' + } + } + }, + CallToolResultSchema + ) + ], + [ + 'resource', + (mcpServer, seen) => { + // The test is to ensure that the extra is compatible with the RequestHandlerExtra type + mcpServer.registerResource( + 'ctx-resource', + 'test://res/1', + { title: 'ctx-resource' }, + async (_uri, extra: RequestHandlerExtra) => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { contents: [{ uri: 'test://res/1', mimeType: 'text/plain', text: 'hello' }] }; + } + ); + }, + client => client.request({ method: 'resources/read', params: { uri: 'test://res/1' } }, ReadResourceResultSchema) + ], + [ + 'resource template list', + (mcpServer, seen) => { + // The test is to ensure that the extra is compatible with the RequestHandlerExtra type + const template = new ResourceTemplate('test://items/{id}', { + list: async (extra: RequestHandlerExtra) => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { resources: [] }; + } + }); + mcpServer.registerResource('ctx-template', template, { title: 'ctx-template' }, async (_uri, _vars, _extra) => ({ + contents: [] + })); + }, + client => client.request({ method: 'resources/list', params: {} }, ListResourcesResultSchema) + ], + [ + 'prompt', + (mcpServer, seen) => { + // The test is to ensure that the extra is compatible with the RequestHandlerExtra type + mcpServer.registerPrompt('ctx-prompt', {}, async (extra: RequestHandlerExtra) => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { messages: [] }; + }); + }, + client => client.request({ method: 'prompts/get', params: { name: 'ctx-prompt', arguments: {} } }, GetPromptResultSchema) + ] + ]; + + test.each(contextCases)('should pass Context as extra to %s callbacks', async (_kind, register, trigger) => { + const mcpServer = new McpServer({ name: 'ctx-test', version: '1.0' }); + const client = new Client({ name: 'ctx-client', version: '1.0' }); + + const seen: Seen = { isContext: false, hasRequestId: false }; + + await register(mcpServer, seen); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await trigger(client); + + expect(seen.isContext).toBe(true); + expect(seen.hasRequestId).toBe(true); + }); + }); +}); diff --git a/test/integration/test/server/mcp.test.ts b/test/integration/test/server/mcp.test.ts index f7bcececc..7e925b4ec 100644 --- a/test/integration/test/server/mcp.test.ts +++ b/test/integration/test/server/mcp.test.ts @@ -1,26 +1,27 @@ import { Client } from '@modelcontextprotocol/client'; -import { getDisplayName, InMemoryTaskStore, InMemoryTransport, UriTemplate } from '@modelcontextprotocol/core'; +import type { CallToolResult, Notification, ServerNotification, ServerRequest, TextContent } from '@modelcontextprotocol/core'; import { - type CallToolResult, CallToolResultSchema, CompleteResultSchema, ElicitRequestSchema, ErrorCode, + getDisplayName, GetPromptResultSchema, + InMemoryTaskStore, + InMemoryTransport, ListPromptsResultSchema, ListResourcesResultSchema, ListResourceTemplatesResultSchema, ListToolsResultSchema, LoggingMessageNotificationSchema, - type Notification, ReadResourceResultSchema, - type TextContent, + UriTemplate, UrlElicitationRequiredError } from '@modelcontextprotocol/core'; - -import { completable } from '../../../../packages/server/src/server/completable.js'; -import { McpServer, ResourceTemplate } from '../../../../packages/server/src/server/mcp.js'; -import { type ZodMatrixEntry, zodTestMatrix } from '../../../../packages/server/test/server/__fixtures__/zodTestMatrix.js'; +import type { ContextInterface } from '@modelcontextprotocol/server'; +import { completable, Context, McpServer, ResourceTemplate } from '@modelcontextprotocol/server'; +import type { ZodMatrixEntry } from '@modelcontextprotocol/test-helpers'; +import { zodTestMatrix } from '@modelcontextprotocol/test-helpers'; function createLatch() { let latch = false; @@ -241,7 +242,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { sendNotification: () => { throw new Error('Not implemented'); } - }); + } as unknown as ContextInterface); expect(result?.resources).toHaveLength(1); expect(list).toHaveBeenCalled(); }); @@ -1901,16 +1902,19 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }, { createTask: async (_args, extra) => { + if (!extra.taskStore) throw new Error('Task store not found'); const task = await extra.taskStore.createTask({ ttl: 60000 }); return { task }; }, getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); + if (!extra.taskStore) throw new Error('Task store not found'); + const task = await extra.taskStore.getTask(extra.taskId!); if (!task) throw new Error('Task not found'); return task; }, getTaskResult: async (_args, extra) => { - return (await extra.taskStore.getTaskResult(extra.taskId)) as CallToolResult; + if (!extra.taskStore) throw new Error('Task store not found'); + return (await extra.taskStore.getTaskResult(extra.taskId!)) as CallToolResult; } } ); @@ -1970,16 +1974,17 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }, { createTask: async (_args, extra) => { + if (!extra.taskStore) throw new Error('Task store not found'); const task = await extra.taskStore.createTask({ ttl: 60000 }); return { task }; }, getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); + const task = await extra.taskStore?.getTask(extra.taskId!); if (!task) throw new Error('Task not found'); return task; }, getTaskResult: async (_args, extra) => { - return (await extra.taskStore.getTaskResult(extra.taskId)) as CallToolResult; + return (await extra.taskStore?.getTaskResult(extra.taskId!)) as CallToolResult; } } ); @@ -4385,17 +4390,20 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }) } }, - async ({ department, name }) => ({ - messages: [ - { - role: 'assistant', - content: { - type: 'text', - text: `Hello ${name}, welcome to the ${department} team!` + async ({ department, name }, extra: ContextInterface) => { + expect(extra).toBeInstanceOf(Context); + return { + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `Hello ${name}, welcome to the ${department} team!` + } } - } - ] - }) + ] + }; + } ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -6300,6 +6308,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }, { createTask: async ({ input }, extra) => { + if (!extra.taskStore) throw new Error('Task store not found'); const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); // Capture taskStore for use in setTimeout @@ -6315,14 +6324,16 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { return { task }; }, getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); + if (!extra.taskStore) throw new Error('Task store not found'); + const task = await extra.taskStore.getTask(extra.taskId!); if (!task) { throw new Error('Task not found'); } return task; }, getTaskResult: async (_input, extra) => { - const result = await extra.taskStore.getTaskResult(extra.taskId); + if (!extra.taskStore) throw new Error('Task store not found'); + const result = await extra.taskStore.getTaskResult(extra.taskId!); return result as CallToolResult; } } @@ -6405,6 +6416,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }, { createTask: async ({ value }, extra) => { + if (!extra.taskStore) throw new Error('Task store not found'); const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); // Capture taskStore for use in setTimeout @@ -6421,14 +6433,16 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { return { task }; }, getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); + if (!extra.taskStore) throw new Error('Task store not found'); + const task = await extra.taskStore.getTask(extra.taskId!); if (!task) { throw new Error('Task not found'); } return task; }, getTaskResult: async (_value, extra) => { - const result = await extra.taskStore.getTaskResult(extra.taskId); + if (!extra.taskStore) throw new Error('Task store not found'); + const result = await extra.taskStore.getTaskResult(extra.taskId!); return result as CallToolResult; } } @@ -6513,6 +6527,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }, { createTask: async ({ data }, extra) => { + if (!extra.taskStore) throw new Error('Task store not found'); const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); // Capture taskStore for use in setTimeout @@ -6520,6 +6535,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { // Simulate async work setTimeout(async () => { + if (!store) throw new Error('Task store not found'); await store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text' as const, text: `Completed: ${data}` }] }); @@ -6529,14 +6545,16 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { return { task }; }, getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); + if (!extra.taskStore) throw new Error('Task store not found'); + const task = await extra.taskStore.getTask(extra.taskId!); if (!task) { throw new Error('Task not found'); } return task; }, getTaskResult: async (_data, extra) => { - const result = await extra.taskStore.getTaskResult(extra.taskId); + if (!extra.taskStore) throw new Error('Task store not found'); + const result = await extra.taskStore.getTaskResult(extra.taskId!); return result as CallToolResult; } } @@ -6630,6 +6648,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }, { createTask: async extra => { + if (!extra.taskStore) throw new Error('Task store not found'); const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); // Capture taskStore for use in setTimeout @@ -6736,6 +6755,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }, { createTask: async extra => { + if (!extra.taskStore) throw new Error('Task store not found'); const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); // Capture taskStore for use in setTimeout @@ -6750,14 +6770,16 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { return { task }; }, getTask: async extra => { - const task = await extra.taskStore.getTask(extra.taskId); + if (!extra.taskStore) throw new Error('Task store not found'); + const task = await extra.taskStore.getTask(extra.taskId!); if (!task) { throw new Error('Task not found'); } return task; }, getTaskResult: async extra => { - const result = await extra.taskStore.getTaskResult(extra.taskId); + if (!extra.taskStore) throw new Error('Task store not found'); + const result = await extra.taskStore.getTaskResult(extra.taskId!); return result as CallToolResult; } } @@ -6823,18 +6845,21 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }, { createTask: async (_args, extra) => { + if (!extra.taskStore) throw new Error('Task store not found'); const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); return { task }; }, getTask: async (_args, extra) => { - const task = await extra.taskStore.getTask(extra.taskId); + if (!extra.taskStore) throw new Error('Task store not found'); + const task = await extra.taskStore.getTask(extra.taskId!); if (!task) { throw new Error('Task not found'); } return task; }, getTaskResult: async (_args, extra) => { - const result = await extra.taskStore.getTaskResult(extra.taskId); + if (!extra.taskStore) throw new Error('Task store not found'); + const result = await extra.taskStore.getTaskResult(extra.taskId!); return result as CallToolResult; } } diff --git a/test/integration/test/stateManagementStreamableHttp.test.ts b/test/integration/test/stateManagementStreamableHttp.test.ts index c33100efa..6839cba6b 100644 --- a/test/integration/test/stateManagementStreamableHttp.test.ts +++ b/test/integration/test/stateManagementStreamableHttp.test.ts @@ -1,5 +1,6 @@ import { randomUUID } from 'node:crypto'; -import { createServer, type Server } from 'node:http'; +import type { Server } from 'node:http'; +import { createServer } from 'node:http'; import { Client, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; import { @@ -11,8 +12,8 @@ import { McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; -import { listenOnRandomPort } from '@modelcontextprotocol/test-helpers'; -import { type ZodMatrixEntry, zodTestMatrix } from '@modelcontextprotocol/test-helpers'; +import type { ZodMatrixEntry } from '@modelcontextprotocol/test-helpers'; +import { listenOnRandomPort, zodTestMatrix } from '@modelcontextprotocol/test-helpers'; describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { const { z } = entry; diff --git a/test/integration/test/taskResumability.test.ts b/test/integration/test/taskResumability.test.ts index 1e4d8a0fd..178a95202 100644 --- a/test/integration/test/taskResumability.test.ts +++ b/test/integration/test/taskResumability.test.ts @@ -2,13 +2,13 @@ import { randomUUID } from 'node:crypto'; import { createServer, type Server } from 'node:http'; import { Client, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; +import type { EventStore, JSONRPCMessage } from '@modelcontextprotocol/server'; import { CallToolResultSchema, LoggingMessageNotificationSchema, McpServer, StreamableHTTPServerTransport } from '@modelcontextprotocol/server'; -import type { EventStore, JSONRPCMessage } from '@modelcontextprotocol/server'; import type { ZodMatrixEntry } from '@modelcontextprotocol/test-helpers'; import { listenOnRandomPort, zodTestMatrix } from '@modelcontextprotocol/test-helpers';