diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java index d6b01e17f..57a27a3fd 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2025 the original author or authors. + * Copyright 2024-2026 the original author or authors. */ package io.modelcontextprotocol.client.transport; @@ -23,6 +23,7 @@ import io.modelcontextprotocol.client.McpAsyncClient; import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpHttpClientAuthorizationErrorHandler; import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.json.McpJsonDefaults; @@ -50,6 +51,7 @@ import reactor.core.publisher.Mono; import reactor.util.function.Tuple2; import reactor.util.function.Tuples; +import reactor.util.retry.Retry; /** * An implementation of the Streamable HTTP protocol as defined by the @@ -72,6 +74,7 @@ *

* * @author Christian Tzolov + * @author Daniel Garnier-Moiroux * @see Streamable * HTTP transport specification @@ -115,6 +118,8 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport { private final boolean openConnectionOnStartup; + private final McpHttpClientAuthorizationErrorHandler authorizationErrorHandler; + private final boolean resumableStreams; private final McpAsyncHttpClientRequestCustomizer httpRequestCustomizer; @@ -132,7 +137,7 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport { private HttpClientStreamableHttpTransport(McpJsonMapper jsonMapper, HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, String endpoint, boolean resumableStreams, boolean openConnectionOnStartup, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer, - List supportedProtocolVersions) { + McpHttpClientAuthorizationErrorHandler authorizationErrorHandler, List supportedProtocolVersions) { this.jsonMapper = jsonMapper; this.httpClient = httpClient; this.requestBuilder = requestBuilder; @@ -140,6 +145,7 @@ private HttpClientStreamableHttpTransport(McpJsonMapper jsonMapper, HttpClient h this.endpoint = endpoint; this.resumableStreams = resumableStreams; this.openConnectionOnStartup = openConnectionOnStartup; + this.authorizationErrorHandler = authorizationErrorHandler; this.activeSession.set(createTransportSession()); this.httpRequestCustomizer = httpRequestCustomizer; this.supportedProtocolVersions = Collections.unmodifiableList(supportedProtocolVersions); @@ -239,7 +245,6 @@ public Mono closeGracefully() { } private Mono reconnect(McpTransportStream stream) { - return Mono.deferContextual(ctx -> { if (stream != null) { @@ -275,121 +280,120 @@ private Mono reconnect(McpTransportStream stream) { var transportContext = connectionCtx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); return Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null, transportContext)); }) - .flatMapMany( - requestBuilder -> Flux.create( - sseSink -> this.httpClient - .sendAsync(requestBuilder.build(), - responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, - sseSink)) - .whenComplete((response, throwable) -> { - if (throwable != null) { - sseSink.error(throwable); - } - else { - logger.debug("SSE connection established successfully"); - } - })) - .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) - .flatMap(responseEvent -> { - int statusCode = responseEvent.responseInfo().statusCode(); - - if (statusCode >= 200 && statusCode < 300) { - - if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { - String data = responseEvent.sseEvent().data(); - // Per 2025-11-25 spec (SEP-1699), servers may - // send SSE events - // with empty data to prime the client for - // reconnection. - // Skip these events as they contain no JSON-RPC - // message. - if (data == null || data.isBlank()) { - logger.debug("Skipping SSE event with empty data (stream primer)"); - return Flux.empty(); - } - try { - // We don't support batching ATM and probably - // won't since the next version considers - // removing it. - McpSchema.JSONRPCMessage message = McpSchema - .deserializeJsonRpcMessage(this.jsonMapper, data); - - Tuple2, Iterable> idWithMessages = Tuples - .of(Optional.ofNullable(responseEvent.sseEvent().id()), - List.of(message)); - - McpTransportStream sessionStream = stream != null ? stream - : new DefaultMcpTransportStream<>(this.resumableStreams, - this::reconnect); - logger.debug("Connected stream {}", sessionStream.streamId()); - - return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); - - } - catch (IOException ioException) { - return Flux.error(new McpTransportException( - "Error parsing JSON-RPC message: " + responseEvent, ioException)); - } - } - else { - logger.debug("Received SSE event with type: {}", responseEvent.sseEvent()); - return Flux.empty(); - } - } - else if (statusCode == METHOD_NOT_ALLOWED) { // NotAllowed - logger - .debug("The server does not support SSE streams, using request-response mode."); + .flatMapMany(requestBuilder -> Flux.create(sseSink -> this.httpClient + .sendAsync(requestBuilder.build(), this.toSendMessageBodySubscriber(sseSink)) + .whenComplete((response, throwable) -> { + if (throwable != null) { + sseSink.error(throwable); + } + else { + logger.debug("SSE connection established successfully"); + } + })).flatMap(responseEvent -> { + int statusCode = responseEvent.responseInfo().statusCode(); + if (statusCode == 401 || statusCode == 403) { + logger.debug("Authorization error in reconnect with code {}", statusCode); + return Mono.error( + new McpHttpClientTransportAuthorizationException( + "Authorization error connecting to SSE stream", + responseEvent.responseInfo())); + } + + if (!(responseEvent instanceof ResponseSubscribers.SseResponseEvent sseResponseEvent)) { + return Flux.error(new McpTransportException( + "Unrecognized server error when connecting to SSE stream, status code: " + + statusCode)); + } + else if (statusCode >= 200 && statusCode < 300) { + if (MESSAGE_EVENT_TYPE.equals(sseResponseEvent.sseEvent().event())) { + String data = sseResponseEvent.sseEvent().data(); + // Per 2025-11-25 spec (SEP-1699), servers may + // send SSE events + // with empty data to prime the client for + // reconnection. + // Skip these events as they contain no JSON-RPC + // message. + if (data == null || data.isBlank()) { + logger.debug("Skipping SSE event with empty data (stream primer)"); return Flux.empty(); } - else if (statusCode == NOT_FOUND) { - - if (transportSession != null && transportSession.sessionId().isPresent()) { - // only if the request was sent with a session id - // and the response is 404, we consider it a - // session not found error. - logger.debug("Session not found for session ID: {}", - transportSession.sessionId().get()); - String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); - McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( - "Session not found for session ID: " + sessionIdRepresentation); - return Flux.error(exception); - } - return Flux.error( - new McpTransportException("Server Not Found. Status code:" + statusCode - + ", response-event:" + responseEvent)); - } - else if (statusCode == BAD_REQUEST) { - if (transportSession != null && transportSession.sessionId().isPresent()) { - // only if the request was sent with a session id - // and thre response is 404, we consider it a - // session not found error. - String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); - McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( - "Session not found for session ID: " + sessionIdRepresentation); - return Flux.error(exception); - } - return Flux.error( - new McpTransportException("Bad Request. Status code:" + statusCode - + ", response-event:" + responseEvent)); + try { + // We don't support batching ATM and probably + // won't since the next version considers + // removing it. + McpSchema.JSONRPCMessage message = McpSchema + .deserializeJsonRpcMessage(this.jsonMapper, data); - } + Tuple2, Iterable> idWithMessages = Tuples + .of(Optional.ofNullable(sseResponseEvent.sseEvent().id()), List.of(message)); + + McpTransportStream sessionStream = stream != null ? stream + : new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect); + logger.debug("Connected stream {}", sessionStream.streamId()); - return Flux.error(new McpTransportException( - "Received unrecognized SSE event type: " + responseEvent.sseEvent().event())); - }).flatMap( - jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) - .onErrorMap(CompletionException.class, t -> t.getCause()) - .onErrorComplete(t -> { - this.handleException(t); - return true; - }) - .doFinally(s -> { - Disposable ref = disposableRef.getAndSet(null); - if (ref != null) { - transportSession.removeConnection(ref); + return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); + + } + catch (IOException ioException) { + return Flux.error(new McpTransportException( + "Error parsing JSON-RPC message: " + responseEvent, ioException)); } - })) + } + else { + logger.debug("Received SSE event with type: {}", sseResponseEvent.sseEvent()); + return Flux.empty(); + } + } + else if (statusCode == METHOD_NOT_ALLOWED) { // NotAllowed + logger.debug("The server does not support SSE streams, using request-response mode."); + return Flux.empty(); + } + else if (statusCode == NOT_FOUND) { + + if (transportSession != null && transportSession.sessionId().isPresent()) { + // only if the request was sent with a session id + // and the response is 404, we consider it a + // session not found error. + logger.debug("Session not found for session ID: {}", + transportSession.sessionId().get()); + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } + return Flux.error( + new McpTransportException("Server Not Found. Status code:" + statusCode + + ", response-event:" + responseEvent)); + } + else if (statusCode == BAD_REQUEST) { + if (transportSession != null && transportSession.sessionId().isPresent()) { + // only if the request was sent with a session id + // and thre response is 404, we consider it a + // session not found error. + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } + return Flux.error(new McpTransportException( + "Bad Request. Status code:" + statusCode + ", response-event:" + responseEvent)); + } + return Flux.error(new McpTransportException( + "Received unrecognized SSE event type: " + sseResponseEvent.sseEvent().event())); + }) + .retryWhen(authorizationErrorRetrySpec()) + .flatMap(jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) + .onErrorMap(CompletionException.class, t -> t.getCause()) + .onErrorComplete(t -> { + this.handleException(t); + return true; + }) + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + })) .contextWrite(ctx) .subscribe(); @@ -400,6 +404,25 @@ else if (statusCode == BAD_REQUEST) { } + private Retry authorizationErrorRetrySpec() { + return Retry.from(companion -> companion.flatMap(retrySignal -> { + if (!(retrySignal.failure() instanceof McpHttpClientTransportAuthorizationException authException)) { + return Mono.error(retrySignal.failure()); + } + if (retrySignal.totalRetriesInARow() >= this.authorizationErrorHandler.maxRetries()) { + return Mono.error(retrySignal.failure()); + } + return Mono.deferContextual(ctx -> { + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono + .from(this.authorizationErrorHandler.handle(authException.getResponseInfo(), transportContext)) + .switchIfEmpty(Mono.just(false)) + .flatMap(shouldRetry -> shouldRetry ? Mono.just(retrySignal.totalRetries()) + : Mono.error(retrySignal.failure())); + }); + })); + } + private BodyHandler toSendMessageBodySubscriber(FluxSink sink) { BodyHandler responseBodyHandler = responseInfo -> { @@ -478,6 +501,13 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { })).onErrorMap(CompletionException.class, t -> t.getCause()).onErrorComplete().subscribe(); })).flatMap(responseEvent -> { + int statusCode = responseEvent.responseInfo().statusCode(); + if (statusCode == 401 || statusCode == 403) { + logger.debug("Authorization error in sendMessage with code {}", statusCode); + return Mono.error(new McpHttpClientTransportAuthorizationException( + "Authorization error when sending message", responseEvent.responseInfo())); + } + if (transportSession.markInitialized( responseEvent.responseInfo().headers().firstValue("mcp-session-id").orElseGet(() -> null))) { // Once we have a session, we try to open an async stream for @@ -488,8 +518,6 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { String sessionRepresentation = sessionIdOrPlaceholder(transportSession); - int statusCode = responseEvent.responseInfo().statusCode(); - if (statusCode >= 200 && statusCode < 300) { String contentType = responseEvent.responseInfo() @@ -605,6 +633,7 @@ else if (statusCode == BAD_REQUEST) { return Flux.error( new RuntimeException("Failed to send message: " + responseEvent)); }) + .retryWhen(authorizationErrorRetrySpec()) .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))) .onErrorMap(CompletionException.class, t -> t.getCause()) .onErrorComplete(t -> { @@ -664,6 +693,8 @@ public static class Builder { private List supportedProtocolVersions = List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26, ProtocolVersions.MCP_2025_06_18, ProtocolVersions.MCP_2025_11_25); + private McpHttpClientAuthorizationErrorHandler authorizationErrorHandler = McpHttpClientAuthorizationErrorHandler.NOOP; + /** * Creates a new builder with the specified base URI. * @param baseUri the base URI of the MCP server @@ -801,6 +832,17 @@ public Builder asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer as return this; } + /** + * Sets the handler to be used when the server responds with HTTP 401 or HTTP 403 + * when sending a message. + * @param authorizationErrorHandler the handler + * @return this builder + */ + public Builder authorizationErrorHandler(McpHttpClientAuthorizationErrorHandler authorizationErrorHandler) { + this.authorizationErrorHandler = authorizationErrorHandler; + return this; + } + /** * Sets the connection timeout for the HTTP client. * @param connectTimeout the connection timeout duration @@ -845,7 +887,7 @@ public HttpClientStreamableHttpTransport build() { HttpClient httpClient = this.clientBuilder.connectTimeout(this.connectTimeout).build(); return new HttpClientStreamableHttpTransport(jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, httpClient, requestBuilder, baseUri, endpoint, resumableStreams, openConnectionOnStartup, - httpRequestCustomizer, supportedProtocolVersions); + httpRequestCustomizer, authorizationErrorHandler, supportedProtocolVersions); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/McpHttpClientTransportAuthorizationException.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/McpHttpClientTransportAuthorizationException.java new file mode 100644 index 000000000..31e5ae95e --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/McpHttpClientTransportAuthorizationException.java @@ -0,0 +1,31 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import java.net.http.HttpResponse; + +import io.modelcontextprotocol.spec.McpTransportException; + +/** + * Thrown when the MCP server responds with an authorization error (HTTP 401 or HTTP 403). + * Subclass of {@link McpTransportException} for targeted retry handling in + * {@link HttpClientStreamableHttpTransport}. + * + * @author Daniel Garnier-Moiroux + */ +public class McpHttpClientTransportAuthorizationException extends McpTransportException { + + private final HttpResponse.ResponseInfo responseInfo; + + public McpHttpClientTransportAuthorizationException(String message, HttpResponse.ResponseInfo responseInfo) { + super(message); + this.responseInfo = responseInfo; + } + + public HttpResponse.ResponseInfo getResponseInfo() { + return responseInfo; + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandler.java new file mode 100644 index 000000000..c98fac61d --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandler.java @@ -0,0 +1,104 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.http.HttpResponse; + +import io.modelcontextprotocol.client.transport.McpHttpClientTransportAuthorizationException; +import io.modelcontextprotocol.common.McpTransportContext; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +/** + * Handle security-related errors in HTTP-client based transports. This class handles MCP + * server responses with status code 401 and 403. + * + * @see MCP + * Specification: Authorization + * @author Daniel Garnier-Moiroux + */ +public interface McpHttpClientAuthorizationErrorHandler { + + /** + * Handle authorization error (HTTP 401 or 403), and signal whether the HTTP request + * should be retried or not. If the publisher returns true, the original transport + * method (connect, sendMessage) will be replayed with the original arguments. + * Otherwise, the transport will throw an + * {@link McpHttpClientTransportAuthorizationException}, indicating the error status. + *

+ * If the returned {@link Publisher} errors, the error will be propagated to the + * calling method, to be handled by the caller. + *

+ * The number of retries is bounded by {@link #maxRetries()}. + * @param responseInfo the HTTP response information + * @param context the MCP client transport context + * @return {@link Publisher} emitting true if the original request should be replayed, + * false otherwise. + */ + Publisher handle(HttpResponse.ResponseInfo responseInfo, McpTransportContext context); + + /** + * Maximum number of authorization error retries the transport will attempt. When the + * handler signals a retry via {@link #handle}, the transport will replay the original + * request at most this many times. If the authorization error persists after + * exhausting all retries, the transport will propagate the + * {@link McpHttpClientTransportAuthorizationException}. + *

+ * Defaults to {@code 1}. + * @return the maximum number of retries + */ + default int maxRetries() { + return 1; + } + + /** + * A no-op handler, used in the default use-case. + */ + McpHttpClientAuthorizationErrorHandler NOOP = new Noop(); + + /** + * Create a {@link McpHttpClientAuthorizationErrorHandler} from a synchronous handler. + * Will be subscribed on {@link Schedulers#boundedElastic()}. The handler may be + * blocking. + * @param handler the synchronous handler + * @return an async handler + */ + static McpHttpClientAuthorizationErrorHandler fromSync(Sync handler) { + return (info, context) -> Mono.fromCallable(() -> handler.handle(info, context)) + .subscribeOn(Schedulers.boundedElastic()); + } + + /** + * Synchronous authorization error handler. + */ + interface Sync { + + /** + * Handle authorization error (HTTP 401 or 403), and signal whether the HTTP + * request should be retried or not. If the return value is true, the original + * transport method (connect, sendMessage) will be replayed with the original + * arguments. Otherwise, the transport will throw an + * {@link McpHttpClientTransportAuthorizationException}, indicating the error + * status. + * @param responseInfo the HTTP response information + * @param context the MCP client transport context + * @return true if the original request should be replayed, false otherwise. + */ + boolean handle(HttpResponse.ResponseInfo responseInfo, McpTransportContext context); + + } + + class Noop implements McpHttpClientAuthorizationErrorHandler { + + @Override + public Publisher handle(HttpResponse.ResponseInfo responseInfo, McpTransportContext context) { + return Mono.just(false); + } + + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandlerTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandlerTest.java new file mode 100644 index 000000000..2812522f5 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandlerTest.java @@ -0,0 +1,48 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.http.HttpResponse; + +import io.modelcontextprotocol.common.McpTransportContext; +import org.junit.jupiter.api.Test; +import reactor.test.StepVerifier; + +import static org.mockito.Mockito.mock; + +/** + * @author Daniel Garnier-Moiroux + */ +class McpHttpClientAuthorizationErrorHandlerTest { + + private final HttpResponse.ResponseInfo responseInfo = mock(HttpResponse.ResponseInfo.class); + + private final McpTransportContext context = McpTransportContext.EMPTY; + + @Test + void whenTrueThenRetry() { + McpHttpClientAuthorizationErrorHandler handler = McpHttpClientAuthorizationErrorHandler + .fromSync((info, ctx) -> true); + StepVerifier.create(handler.handle(responseInfo, context)).expectNext(true).verifyComplete(); + } + + @Test + void whenFalseThenError() { + McpHttpClientAuthorizationErrorHandler handler = McpHttpClientAuthorizationErrorHandler + .fromSync((info, ctx) -> false); + StepVerifier.create(handler.handle(responseInfo, context)).expectNext(false).verifyComplete(); + } + + @Test + void whenExceptionThenPropagate() { + McpHttpClientAuthorizationErrorHandler handler = McpHttpClientAuthorizationErrorHandler + .fromSync((info, ctx) -> { + throw new IllegalStateException("sync handler error"); + }); + StepVerifier.create(handler.handle(responseInfo, context)) + .expectErrorMatches(t -> t instanceof IllegalStateException && t.getMessage().equals("sync handler error")) + .verify(); + } + +} diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java index b82d6eb2c..c4857e5b4 100644 --- a/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java @@ -1,26 +1,24 @@ /* - * Copyright 2025-2025 the original author or authors. + * Copyright 2025-2026 the original author or authors. */ package io.modelcontextprotocol.client.transport; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; - import java.io.IOException; import java.net.InetSocketAddress; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.ArrayList; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; - -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; +import java.util.function.Predicate; import com.sun.net.httpserver.HttpServer; - +import io.modelcontextprotocol.client.transport.customizer.McpHttpClientAuthorizationErrorHandler; +import io.modelcontextprotocol.common.McpTransportContext; +import org.reactivestreams.Publisher; import io.modelcontextprotocol.server.transport.TomcatTestUtil; import io.modelcontextprotocol.spec.HttpHeaders; import io.modelcontextprotocol.spec.McpClientTransport; @@ -28,14 +26,30 @@ import io.modelcontextprotocol.spec.McpTransportException; import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import io.modelcontextprotocol.spec.ProtocolVersions; +import org.awaitility.Awaitility; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.InstanceOfAssertFactories.type; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + /** * Tests for error handling changes in HttpClientStreamableHttpTransport. Specifically * tests the distinction between session-related errors and general transport errors for * 404 and 400 status codes. * * @author Christian Tzolov + * @author Daniel Garnier-Moiroux */ @Timeout(15) public class HttpClientStreamableHttpTransportErrorHandlingTest { @@ -46,11 +60,17 @@ public class HttpClientStreamableHttpTransportErrorHandlingTest { private HttpServer server; - private AtomicReference serverResponseStatus = new AtomicReference<>(200); + private final AtomicInteger serverResponseStatus = new AtomicInteger(200); + + private final AtomicInteger serverSseResponseStatus = new AtomicInteger(200); - private AtomicReference currentServerSessionId = new AtomicReference<>(null); + private final AtomicReference currentServerSessionId = new AtomicReference<>(null); - private AtomicReference lastReceivedSessionId = new AtomicReference<>(null); + private final AtomicReference lastReceivedSessionId = new AtomicReference<>(null); + + private final AtomicInteger processedMessagesCount = new AtomicInteger(0); + + private final AtomicInteger processedSseConnectCount = new AtomicInteger(0); private McpClientTransport transport; @@ -88,6 +108,20 @@ else if ("POST".equals(httpExchange.getRequestMethod())) { else { httpExchange.sendResponseHeaders(status, 0); } + processedMessagesCount.incrementAndGet(); + } + else if ("GET".equals(httpExchange.getRequestMethod())) { + int status = serverSseResponseStatus.get(); + if (status == 200) { + httpExchange.getResponseHeaders().set("Content-Type", "text/event-stream"); + httpExchange.sendResponseHeaders(200, 0); + String sseData = "event: message\ndata: {\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{}}\n\n"; + httpExchange.getResponseBody().write(sseData.getBytes()); + } + else { + httpExchange.sendResponseHeaders(status, 0); + } + processedSseConnectCount.incrementAndGet(); } httpExchange.close(); }); @@ -103,6 +137,7 @@ void stopServer() { if (server != null) { server.stop(0); } + StepVerifier.create(transport.closeGracefully()).verifyComplete(); } /** @@ -334,6 +369,386 @@ else if (status == 404) { StepVerifier.create(transport.closeGracefully()).verifyComplete(); } + @Nested + class AuthorizationError { + + @Nested + class SendMessage { + + @ParameterizedTest + @ValueSource(ints = { 401, 403 }) + void invokeHandler(int httpStatus) { + serverResponseStatus.set(httpStatus); + + AtomicReference capturedResponseInfo = new AtomicReference<>(); + AtomicReference capturedContext = new AtomicReference<>(); + + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler((responseInfo, context) -> { + capturedResponseInfo.set(responseInfo); + capturedContext.set(context); + return Mono.just(false); + }) + .build(); + + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())) + .expectErrorMatches(authorizationError(httpStatus)) + .verify(); + assertThat(processedMessagesCount.get()).isEqualTo(1); + assertThat(capturedResponseInfo.get()).isNotNull(); + assertThat(capturedResponseInfo.get().statusCode()).isEqualTo(httpStatus); + assertThat(capturedContext.get()).isNotNull(); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void defaultHandler() { + serverResponseStatus.set(401); + + var authTransport = HttpClientStreamableHttpTransport.builder(HOST).build(); + + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())) + .expectErrorMatches(authorizationError(401)) + .verify(); + assertThat(processedMessagesCount.get()).isEqualTo(1); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void retry() { + serverResponseStatus.set(401); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler((responseInfo, context) -> { + serverResponseStatus.set(200); + return Mono.just(true); + }) + .build(); + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())).verifyComplete(); + // initial request + retry + assertThat(processedMessagesCount.get()).isEqualTo(2); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void retryAtMostOnce() { + serverResponseStatus.set(401); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler((responseInfo, context) -> Mono.just(true)) + .build(); + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())) + .expectErrorMatches(authorizationError(401)) + .verify(); + // initial request + 1 retry (maxRetries default is 1) + assertThat(processedMessagesCount.get()).isEqualTo(2); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void customMaxRetries() { + serverResponseStatus.set(401); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler(new McpHttpClientAuthorizationErrorHandler() { + @Override + public Publisher handle(HttpResponse.ResponseInfo responseInfo, + McpTransportContext context) { + return Mono.just(true); + } + + @Override + public int maxRetries() { + return 3; + } + }) + .build(); + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())) + .expectErrorMatches(authorizationError(401)) + .verify(); + // initial request + 3 retries + assertThat(processedMessagesCount.get()).isEqualTo(4); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void noRetry() { + serverResponseStatus.set(401); + + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler((responseInfo, context) -> Mono.just(false)) + .build(); + + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())) + .expectErrorMatches(authorizationError(401)) + .verify(); + assertThat(processedMessagesCount.get()).isEqualTo(1); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void propagateHandlerError() { + serverResponseStatus.set(401); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler( + (responseInfo, context) -> Mono.error(new IllegalStateException("handler error"))) + .build(); + + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())) + .expectErrorMatches(throwable -> throwable instanceof IllegalStateException + && throwable.getMessage().equals("handler error")) + .verify(); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void emptyHandler() { + serverResponseStatus.set(401); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler((responseInfo, context) -> Mono.empty()) + .build(); + + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())) + .expectErrorMatches(authorizationError(401)) + .verify(); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + } + + @Nested + class Connect { + + @ParameterizedTest + @ValueSource(ints = { 401, 403 }) + void invokeHandler(int httpStatus) { + serverSseResponseStatus.set(httpStatus); + @SuppressWarnings("unchecked") + AtomicReference capturedException = new AtomicReference<>(); + + AtomicReference capturedResponseInfo = new AtomicReference<>(); + AtomicReference capturedContext = new AtomicReference<>(); + + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler((responseInfo, context) -> { + capturedResponseInfo.set(responseInfo); + capturedContext.set(context); + return Mono.just(false); + }) + .openConnectionOnStartup(true) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + var messages = new ArrayList(); + StepVerifier.create(authTransport.connect(msg -> msg.doOnNext(messages::add))).verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(processedSseConnectCount.get()).isEqualTo(1)); + assertThat(messages).isEmpty(); + assertThat(capturedResponseInfo.get()).isNotNull(); + assertThat(capturedResponseInfo.get().statusCode()).isEqualTo(httpStatus); + assertThat(capturedContext.get()).isNotNull(); + assertThat(capturedException.get()).hasMessage("Authorization error connecting to SSE stream") + .asInstanceOf(type(McpHttpClientTransportAuthorizationException.class)) + .extracting(McpHttpClientTransportAuthorizationException::getResponseInfo) + .extracting(HttpResponse.ResponseInfo::statusCode) + .isEqualTo(httpStatus); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void defaultHandler() { + serverSseResponseStatus.set(401); + AtomicReference capturedException = new AtomicReference<>(); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .openConnectionOnStartup(true) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + StepVerifier.create(authTransport.connect(msg -> msg)).verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(processedSseConnectCount.get()).isEqualTo(1)); + assertThat(capturedException.get()).isInstanceOf(McpHttpClientTransportAuthorizationException.class); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void retry() { + serverSseResponseStatus.set(401); + AtomicReference capturedException = new AtomicReference<>(); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .openConnectionOnStartup(true) + .authorizationErrorHandler((responseInfo, context) -> { + serverSseResponseStatus.set(200); + return Mono.just(true); + }) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + var messages = new ArrayList(); + var messageHandlerClosed = new AtomicBoolean(false); + StepVerifier + .create(authTransport + .connect(msg -> msg.doOnNext(messages::add).doFinally(s -> messageHandlerClosed.set(true)))) + .verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(messageHandlerClosed).isTrue()); + assertThat(processedSseConnectCount.get()).isEqualTo(2); + assertThat(messages).hasSize(1); + assertThat(capturedException.get()).isNull(); + assertThat(messageHandlerClosed.get()).isTrue(); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void retryAtMostOnce() { + serverSseResponseStatus.set(401); + AtomicReference capturedException = new AtomicReference<>(); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .openConnectionOnStartup(true) + .authorizationErrorHandler((responseInfo, context) -> { + return Mono.just(true); + }) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + var messages = new ArrayList(); + StepVerifier.create(authTransport.connect(msg -> msg.doOnNext(messages::add))).verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(capturedException.get()).isNotNull()); + // initial request + 1 retry (maxRetries default is 1) + assertThat(processedSseConnectCount.get()).isEqualTo(2); + assertThat(messages).isEmpty(); + assertThat(capturedException.get()).isInstanceOf(McpHttpClientTransportAuthorizationException.class); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void customMaxRetries() { + serverSseResponseStatus.set(401); + AtomicReference capturedException = new AtomicReference<>(); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .openConnectionOnStartup(true) + .authorizationErrorHandler(new McpHttpClientAuthorizationErrorHandler() { + @Override + public Publisher handle(HttpResponse.ResponseInfo responseInfo, + McpTransportContext context) { + return Mono.just(true); + } + + @Override + public int maxRetries() { + return 3; + } + }) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + var messages = new ArrayList(); + StepVerifier.create(authTransport.connect(msg -> msg.doOnNext(messages::add))).verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(capturedException.get()).isNotNull()); + // initial request + 3 retries + assertThat(processedSseConnectCount.get()).isEqualTo(4); + assertThat(messages).isEmpty(); + assertThat(capturedException.get()).isInstanceOf(McpHttpClientTransportAuthorizationException.class); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void noRetry() { + serverSseResponseStatus.set(401); + AtomicReference capturedException = new AtomicReference<>(); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .openConnectionOnStartup(true) + .authorizationErrorHandler((responseInfo, context) -> { + // if there was a retry, the request would succeed. + serverSseResponseStatus.set(200); + return Mono.just(false); + }) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + var messages = new ArrayList(); + StepVerifier.create(authTransport.connect(msg -> msg.doOnNext(messages::add))).verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(processedSseConnectCount.get()).isEqualTo(1)); + assertThat(messages).isEmpty(); + assertThat(capturedException.get()).isInstanceOf(McpHttpClientTransportAuthorizationException.class); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void emptyHandler() { + serverSseResponseStatus.set(401); + AtomicReference capturedException = new AtomicReference<>(); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .openConnectionOnStartup(true) + .authorizationErrorHandler((responseInfo, context) -> Mono.empty()) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + var messages = new ArrayList(); + StepVerifier.create(authTransport.connect(msg -> msg.doOnNext(messages::add))).verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(processedSseConnectCount.get()).isEqualTo(1)); + assertThat(messages).isEmpty(); + assertThat(capturedException.get()).isInstanceOf(McpHttpClientTransportAuthorizationException.class); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void propagateHandlerError() { + serverSseResponseStatus.set(401); + AtomicReference capturedException = new AtomicReference<>(); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .openConnectionOnStartup(true) + .authorizationErrorHandler( + (responseInfo, context) -> Mono.error(new IllegalStateException("handler error"))) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + var messages = new ArrayList(); + StepVerifier.create(authTransport.connect(msg -> msg.doOnNext(messages::add))).verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(processedSseConnectCount.get()).isEqualTo(1)); + assertThat(messages).isEmpty(); + assertThat(capturedException.get()).isInstanceOf(IllegalStateException.class) + .hasMessage("handler error"); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + } + + private static Predicate authorizationError(int httpStatus) { + return throwable -> throwable instanceof McpHttpClientTransportAuthorizationException + && throwable.getMessage().contains("Authorization error") + && ((McpHttpClientTransportAuthorizationException) throwable).getResponseInfo() + .statusCode() == httpStatus; + } + + } + private McpSchema.JSONRPCRequest createTestRequestMessage() { var initializeRequest = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_03_26, McpSchema.ClientCapabilities.builder().roots(true).build(),