From fa1861e15363414b0524b0944161573b197c61f6 Mon Sep 17 00:00:00 2001 From: Daniel Garnier-Moiroux Date: Tue, 10 Mar 2026 17:11:47 +0100 Subject: [PATCH 1/2] HttpClientStreamHttpTransport: add authorization error handler - Closes #240 Signed-off-by: Daniel Garnier-Moiroux --- .../HttpClientStreamableHttpTransport.java | 270 ++++++++------ .../McpHttpClientTransportException.java | 34 ++ ...cpHttpClientAuthorizationErrorHandler.java | 116 ++++++ ...tpClientAuthorizationErrorHandlerTest.java | 106 ++++++ ...eamableHttpTransportErrorHandlingTest.java | 335 +++++++++++++++++- 5 files changed, 730 insertions(+), 131 deletions(-) create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/client/transport/McpHttpClientTransportException.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandler.java create mode 100644 mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandlerTest.java diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java index d6b01e17f..5310433ba 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2025 the original author or authors. + * Copyright 2024-2026 the original author or authors. */ package io.modelcontextprotocol.client.transport; @@ -23,6 +23,7 @@ import io.modelcontextprotocol.client.McpAsyncClient; import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpHttpClientAuthorizationErrorHandler; import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.json.McpJsonDefaults; @@ -72,6 +73,7 @@ *

* * @author Christian Tzolov + * @author Daniel Garnier-Moiroux * @see Streamable * HTTP transport specification @@ -115,6 +117,8 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport { private final boolean openConnectionOnStartup; + private final McpHttpClientAuthorizationErrorHandler authorizationErrorHandler; + private final boolean resumableStreams; private final McpAsyncHttpClientRequestCustomizer httpRequestCustomizer; @@ -132,7 +136,7 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport { private HttpClientStreamableHttpTransport(McpJsonMapper jsonMapper, HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, String endpoint, boolean resumableStreams, boolean openConnectionOnStartup, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer, - List supportedProtocolVersions) { + McpHttpClientAuthorizationErrorHandler authorizationErrorHandler, List supportedProtocolVersions) { this.jsonMapper = jsonMapper; this.httpClient = httpClient; this.requestBuilder = requestBuilder; @@ -140,6 +144,7 @@ private HttpClientStreamableHttpTransport(McpJsonMapper jsonMapper, HttpClient h this.endpoint = endpoint; this.resumableStreams = resumableStreams; this.openConnectionOnStartup = openConnectionOnStartup; + this.authorizationErrorHandler = authorizationErrorHandler; this.activeSession.set(createTransportSession()); this.httpRequestCustomizer = httpRequestCustomizer; this.supportedProtocolVersions = Collections.unmodifiableList(supportedProtocolVersions); @@ -239,7 +244,6 @@ public Mono closeGracefully() { } private Mono reconnect(McpTransportStream stream) { - return Mono.deferContextual(ctx -> { if (stream != null) { @@ -275,121 +279,128 @@ private Mono reconnect(McpTransportStream stream) { var transportContext = connectionCtx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); return Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null, transportContext)); }) - .flatMapMany( - requestBuilder -> Flux.create( - sseSink -> this.httpClient - .sendAsync(requestBuilder.build(), - responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, - sseSink)) - .whenComplete((response, throwable) -> { - if (throwable != null) { - sseSink.error(throwable); - } - else { - logger.debug("SSE connection established successfully"); - } - })) - .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) - .flatMap(responseEvent -> { - int statusCode = responseEvent.responseInfo().statusCode(); - - if (statusCode >= 200 && statusCode < 300) { - - if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { - String data = responseEvent.sseEvent().data(); - // Per 2025-11-25 spec (SEP-1699), servers may - // send SSE events - // with empty data to prime the client for - // reconnection. - // Skip these events as they contain no JSON-RPC - // message. - if (data == null || data.isBlank()) { - logger.debug("Skipping SSE event with empty data (stream primer)"); - return Flux.empty(); - } - try { - // We don't support batching ATM and probably - // won't since the next version considers - // removing it. - McpSchema.JSONRPCMessage message = McpSchema - .deserializeJsonRpcMessage(this.jsonMapper, data); - - Tuple2, Iterable> idWithMessages = Tuples - .of(Optional.ofNullable(responseEvent.sseEvent().id()), - List.of(message)); - - McpTransportStream sessionStream = stream != null ? stream - : new DefaultMcpTransportStream<>(this.resumableStreams, - this::reconnect); - logger.debug("Connected stream {}", sessionStream.streamId()); - - return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); - - } - catch (IOException ioException) { - return Flux.error(new McpTransportException( - "Error parsing JSON-RPC message: " + responseEvent, ioException)); - } - } - else { - logger.debug("Received SSE event with type: {}", responseEvent.sseEvent()); - return Flux.empty(); - } - } - else if (statusCode == METHOD_NOT_ALLOWED) { // NotAllowed - logger - .debug("The server does not support SSE streams, using request-response mode."); + .flatMapMany(requestBuilder -> Flux.create(sseSink -> this.httpClient + .sendAsync(requestBuilder.build(), this.toSendMessageBodySubscriber(sseSink)) + .whenComplete((response, throwable) -> { + if (throwable != null) { + sseSink.error(throwable); + } + else { + logger.debug("SSE connection established successfully"); + } + })).flatMap(responseEvent -> { + int statusCode = responseEvent.responseInfo().statusCode(); + if (statusCode == 401 || statusCode == 403) { + logger.debug("Authorization error in sendMessage with code {}", statusCode); + return Mono.deferContextual(innerCtx -> { + var transportContext = innerCtx.getOrDefault(McpTransportContext.KEY, + McpTransportContext.EMPTY); + return Mono.from(this.authorizationErrorHandler.onAuthorizationError( + responseEvent.responseInfo(), transportContext, Mono.defer(() -> { + logger.debug("Authorization error handled, retrying original request"); + return this.reconnect(stream).then(); + }), + Mono.error(new McpHttpClientTransportException( + "Authorization error connecting to SSE stream", + responseEvent.responseInfo())))) + .then(Mono.empty()); + }); + } + + if (!(responseEvent instanceof ResponseSubscribers.SseResponseEvent sseResponseEvent)) { + return Flux.error(new McpHttpClientTransportException( + "Unrecognized server error when connecting to SSE stream", + responseEvent.responseInfo())); + } + else if (statusCode >= 200 && statusCode < 300) { + if (MESSAGE_EVENT_TYPE.equals(sseResponseEvent.sseEvent().event())) { + String data = sseResponseEvent.sseEvent().data(); + // Per 2025-11-25 spec (SEP-1699), servers may + // send SSE events + // with empty data to prime the client for + // reconnection. + // Skip these events as they contain no JSON-RPC + // message. + if (data == null || data.isBlank()) { + logger.debug("Skipping SSE event with empty data (stream primer)"); return Flux.empty(); } - else if (statusCode == NOT_FOUND) { - - if (transportSession != null && transportSession.sessionId().isPresent()) { - // only if the request was sent with a session id - // and the response is 404, we consider it a - // session not found error. - logger.debug("Session not found for session ID: {}", - transportSession.sessionId().get()); - String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); - McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( - "Session not found for session ID: " + sessionIdRepresentation); - return Flux.error(exception); - } - return Flux.error( - new McpTransportException("Server Not Found. Status code:" + statusCode - + ", response-event:" + responseEvent)); - } - else if (statusCode == BAD_REQUEST) { - if (transportSession != null && transportSession.sessionId().isPresent()) { - // only if the request was sent with a session id - // and thre response is 404, we consider it a - // session not found error. - String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); - McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( - "Session not found for session ID: " + sessionIdRepresentation); - return Flux.error(exception); - } - return Flux.error( - new McpTransportException("Bad Request. Status code:" + statusCode - + ", response-event:" + responseEvent)); + try { + // We don't support batching ATM and probably + // won't since the next version considers + // removing it. + McpSchema.JSONRPCMessage message = McpSchema + .deserializeJsonRpcMessage(this.jsonMapper, data); - } + Tuple2, Iterable> idWithMessages = Tuples + .of(Optional.ofNullable(sseResponseEvent.sseEvent().id()), List.of(message)); + + McpTransportStream sessionStream = stream != null ? stream + : new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect); + logger.debug("Connected stream {}", sessionStream.streamId()); - return Flux.error(new McpTransportException( - "Received unrecognized SSE event type: " + responseEvent.sseEvent().event())); - }).flatMap( - jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) - .onErrorMap(CompletionException.class, t -> t.getCause()) - .onErrorComplete(t -> { - this.handleException(t); - return true; - }) - .doFinally(s -> { - Disposable ref = disposableRef.getAndSet(null); - if (ref != null) { - transportSession.removeConnection(ref); + return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); + + } + catch (IOException ioException) { + return Flux.error(new McpTransportException( + "Error parsing JSON-RPC message: " + responseEvent, ioException)); } - })) + } + else { + logger.debug("Received SSE event with type: {}", sseResponseEvent.sseEvent()); + return Flux.empty(); + } + } + else if (statusCode == METHOD_NOT_ALLOWED) { // NotAllowed + logger.debug("The server does not support SSE streams, using request-response mode."); + return Flux.empty(); + } + else if (statusCode == NOT_FOUND) { + + if (transportSession != null && transportSession.sessionId().isPresent()) { + // only if the request was sent with a session id + // and the response is 404, we consider it a + // session not found error. + logger.debug("Session not found for session ID: {}", + transportSession.sessionId().get()); + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } + return Flux.error( + new McpTransportException("Server Not Found. Status code:" + statusCode + + ", response-event:" + responseEvent)); + } + else if (statusCode == BAD_REQUEST) { + if (transportSession != null && transportSession.sessionId().isPresent()) { + // only if the request was sent with a session id + // and thre response is 404, we consider it a + // session not found error. + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } + return Flux.error(new McpTransportException( + "Bad Request. Status code:" + statusCode + ", response-event:" + responseEvent)); + } + return Flux.error(new McpTransportException( + "Received unrecognized SSE event type: " + sseResponseEvent.sseEvent().event())); + }) + .flatMap(jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) + .onErrorMap(CompletionException.class, t -> t.getCause()) + .onErrorComplete(t -> { + this.handleException(t); + return true; + }) + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + })) .contextWrite(ctx) .subscribe(); @@ -478,6 +489,22 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { })).onErrorMap(CompletionException.class, t -> t.getCause()).onErrorComplete().subscribe(); })).flatMap(responseEvent -> { + int statusCode = responseEvent.responseInfo().statusCode(); + if (statusCode == 401 || statusCode == 403) { + logger.debug("Authorization error in sendMessage with code {}", statusCode); + return Mono.deferContextual(ctx -> { + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono.from(this.authorizationErrorHandler + .onAuthorizationError(responseEvent.responseInfo(), transportContext, Mono.defer(() -> { + logger.debug("Authorization error handled, retrying original request"); + return this.sendMessage(sentMessage); + }), Mono.error(new McpHttpClientTransportException( + "Authorization error when sending message", responseEvent.responseInfo())))) + .doOnSuccess(s -> deliveredSink.success()) + .then(Mono.empty()); + }); + } + if (transportSession.markInitialized( responseEvent.responseInfo().headers().firstValue("mcp-session-id").orElseGet(() -> null))) { // Once we have a session, we try to open an async stream for @@ -488,8 +515,6 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { String sessionRepresentation = sessionIdOrPlaceholder(transportSession); - int statusCode = responseEvent.responseInfo().statusCode(); - if (statusCode >= 200 && statusCode < 300) { String contentType = responseEvent.responseInfo() @@ -664,6 +689,8 @@ public static class Builder { private List supportedProtocolVersions = List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26, ProtocolVersions.MCP_2025_06_18, ProtocolVersions.MCP_2025_11_25); + private McpHttpClientAuthorizationErrorHandler authorizationErrorHandler = McpHttpClientAuthorizationErrorHandler.NOOP; + /** * Creates a new builder with the specified base URI. * @param baseUri the base URI of the MCP server @@ -801,6 +828,17 @@ public Builder asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer as return this; } + /** + * Sets the handler to be used when the server responds with HTTP 401 or HTTP 403 + * when sending a message. + * @param authorizationErrorHandler the handler + * @return this builder + */ + public Builder authorizationErrorHandler(McpHttpClientAuthorizationErrorHandler authorizationErrorHandler) { + this.authorizationErrorHandler = authorizationErrorHandler; + return this; + } + /** * Sets the connection timeout for the HTTP client. * @param connectTimeout the connection timeout duration @@ -845,7 +883,7 @@ public HttpClientStreamableHttpTransport build() { HttpClient httpClient = this.clientBuilder.connectTimeout(this.connectTimeout).build(); return new HttpClientStreamableHttpTransport(jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, httpClient, requestBuilder, baseUri, endpoint, resumableStreams, openConnectionOnStartup, - httpRequestCustomizer, supportedProtocolVersions); + httpRequestCustomizer, authorizationErrorHandler, supportedProtocolVersions); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/McpHttpClientTransportException.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/McpHttpClientTransportException.java new file mode 100644 index 000000000..c4b082b7a --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/McpHttpClientTransportException.java @@ -0,0 +1,34 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import java.net.http.HttpResponse; + +import io.modelcontextprotocol.spec.McpTransportException; + +/** + * Authorization-related exception for {@link java.net.http.HttpClient}-based client + * transport. Thrown when the server responds with HTTP 401 or HTTP 403. Wraps the + * response info for further inspection of the headers and the status code. + * + * @see MCP + * Specification: Authorization + * @author Daniel Garnier-Moiroux + */ +public class McpHttpClientTransportException extends McpTransportException { + + private final HttpResponse.ResponseInfo responseInfo; + + public McpHttpClientTransportException(String message, HttpResponse.ResponseInfo responseInfo) { + super(message); + this.responseInfo = responseInfo; + } + + public HttpResponse.ResponseInfo getResponseInfo() { + return responseInfo; + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandler.java new file mode 100644 index 000000000..f544bc233 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandler.java @@ -0,0 +1,116 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.http.HttpResponse; + +import io.modelcontextprotocol.client.transport.McpHttpClientTransportException; +import io.modelcontextprotocol.common.McpTransportContext; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +/** + * Handle security-related errors in HTTP-client based transports. This class handles MCP + * server responses with status code 401 and 403. + * + * @see MCP + * Specification: Authorization + * @author Daniel Garnier-Moiroux + */ +public interface McpHttpClientAuthorizationErrorHandler { + + /** + * Handle authorization error (HTTP 401 or 403), and signal whether the HTTP request + * should be retried or not. If the publisher returns true, the original transport + * method (connect, sendMessage) will be replayed with the original arguments. + * Otherwise, the transport will throw an {@link McpHttpClientTransportException}, + * indicating the error status. + *

+ * If the returned {@link Publisher} errors, the error will be propagated to the + * calling method, to be handled by the caller. + *

+ * The caller is responsible for bounding the number of retries. + * @param responseInfo the HTTP response information + * @param context the MCP client transport context + * @return {@link Publisher} emitting true if the original request should be replayed, + * false otherwise. + */ + Publisher handle(HttpResponse.ResponseInfo responseInfo, McpTransportContext context); + + /** + * A no-op handler, used in the default use-case. + */ + McpHttpClientAuthorizationErrorHandler NOOP = new Noop(); + + /** + * Handle authorization error (HTTP 401 or 403), and optionally retry the HTTP + * request, or trigger a transport error. To retry, use the {@code retryAction} + * publisher. To emit the default transport error, use the {@code defaultError} + * publisher. + *

+ * Optionally, the returned {@link Publisher} may error to trigger an out-of-band + * action. In that case, the error will be propagated to the calling method, to be + * handled by the caller. + *

+ * Defaults to {@link #handle(HttpResponse.ResponseInfo, McpTransportContext)}, and + * uses the boolean from the return value to decide whether it should retry the + * request. + * @param responseInfo the HTTP response information + * @param context the MCP client transport context + * @param retryAction handler to retry the original request + * @param defaultError handler to emit an error + * @return a {@link Publisher} to signal either an error or a retry + */ + default Publisher onAuthorizationError(HttpResponse.ResponseInfo responseInfo, McpTransportContext context, + Publisher retryAction, Publisher defaultError) { + return Mono.from(this.handle(responseInfo, context)) + .switchIfEmpty(Mono.just(false)) + .flatMap(shouldRetry -> shouldRetry != null && shouldRetry ? Mono.from(retryAction) + : Mono.from(defaultError)); + } + + /** + * Create a {@link McpHttpClientAuthorizationErrorHandler} from a synchronous handler. + * Will be subscribed on {@link Schedulers#boundedElastic()}. The handler may be + * blocking. + * @param handler the synchronous handler + * @return an async handler + */ + static McpHttpClientAuthorizationErrorHandler fromSync(Sync handler) { + return (info, context) -> Mono.fromCallable(() -> handler.handle(info, context)) + .subscribeOn(Schedulers.boundedElastic()); + } + + /** + * Synchronous authorization error handler. + */ + interface Sync { + + /** + * Handle authorization error (HTTP 401 or 403), and signal whether the HTTP + * request should be retried or not. If the return value is true, the original + * transport method (connect, sendMessage) will be replayed with the original + * arguments. Otherwise, the transport will throw an + * {@link McpHttpClientTransportException}, indicating the error status. + * @param responseInfo the HTTP response information + * @param context the MCP client transport context + * @return true if the original request should be replayed, false otherwise. + */ + boolean handle(HttpResponse.ResponseInfo responseInfo, McpTransportContext context); + + } + + class Noop implements McpHttpClientAuthorizationErrorHandler { + + @Override + public Publisher handle(HttpResponse.ResponseInfo responseInfo, McpTransportContext context) { + return Mono.just(false); + } + + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandlerTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandlerTest.java new file mode 100644 index 000000000..b935f95a5 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandlerTest.java @@ -0,0 +1,106 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.http.HttpResponse; + +import io.modelcontextprotocol.common.McpTransportContext; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.mockito.Mockito.mock; + +/** + * @author Daniel Garnier-Moiroux + */ +class McpHttpClientAuthorizationErrorHandlerTest { + + private final HttpResponse.ResponseInfo responseInfo = mock(HttpResponse.ResponseInfo.class); + + private final McpTransportContext context = McpTransportContext.EMPTY; + + @Nested + class OnAuthorizationError { + + @Test + void whenTrueThenRetry() { + McpHttpClientAuthorizationErrorHandler handler = (info, ctx) -> Mono.just(true); + Mono retryAction = Mono.empty(); + Mono defaultError = Mono.error(new RuntimeException("should not be called")); + + StepVerifier.create(handler.onAuthorizationError(responseInfo, context, retryAction, defaultError)) + .verifyComplete(); + } + + @Test + void whenFalseThenError() { + McpHttpClientAuthorizationErrorHandler handler = (info, ctx) -> Mono.just(false); + Mono retryAction = Mono.error(new RuntimeException("should not be called")); + Mono defaultError = Mono.error(new RuntimeException("authorization error")); + + StepVerifier.create(handler.onAuthorizationError(responseInfo, context, retryAction, defaultError)) + .expectErrorMatches(t -> t instanceof RuntimeException && t.getMessage().equals("authorization error")) + .verify(); + } + + @Test + void whenErrorThenPropagate() { + McpHttpClientAuthorizationErrorHandler handler = (info, ctx) -> Mono + .error(new IllegalStateException("handler error")); + Mono retryAction = Mono.error(new RuntimeException("should not be called")); + Mono defaultError = Mono.error(new RuntimeException("should not be called")); + + StepVerifier.create(handler.onAuthorizationError(responseInfo, context, retryAction, defaultError)) + .expectErrorMatches(t -> t instanceof IllegalStateException && t.getMessage().equals("handler error")) + .verify(); + } + + } + + @Nested + class FromSync { + + @Test + void whenTrueThenRetry() { + McpHttpClientAuthorizationErrorHandler handler = McpHttpClientAuthorizationErrorHandler + .fromSync((info, ctx) -> true); + Mono retryAction = Mono.empty(); + Mono defaultError = Mono.error(new RuntimeException("should not be called")); + + StepVerifier.create(handler.onAuthorizationError(responseInfo, context, retryAction, defaultError)) + .verifyComplete(); + } + + @Test + void whenFalseThenError() { + McpHttpClientAuthorizationErrorHandler handler = McpHttpClientAuthorizationErrorHandler + .fromSync((info, ctx) -> false); + Mono retryAction = Mono.error(new RuntimeException("should not be called")); + Mono defaultError = Mono.error(new RuntimeException("authorization error")); + + StepVerifier.create(handler.onAuthorizationError(responseInfo, context, retryAction, defaultError)) + .expectErrorMatches(t -> t instanceof RuntimeException && t.getMessage().equals("authorization error")) + .verify(); + } + + @Test + void whenExceptionThenPropagate() { + McpHttpClientAuthorizationErrorHandler handler = McpHttpClientAuthorizationErrorHandler + .fromSync((info, ctx) -> { + throw new IllegalStateException("sync handler error"); + }); + Mono retryAction = Mono.error(new RuntimeException("should not be called")); + Mono defaultError = Mono.error(new RuntimeException("should not be called")); + + StepVerifier.create(handler.onAuthorizationError(responseInfo, context, retryAction, defaultError)) + .expectErrorMatches( + t -> t instanceof IllegalStateException && t.getMessage().equals("sync handler error")) + .verify(); + } + + } + +} diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java index b82d6eb2c..20ae10dd3 100644 --- a/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java @@ -1,26 +1,21 @@ /* - * Copyright 2025-2025 the original author or authors. + * Copyright 2025-2026 the original author or authors. */ package io.modelcontextprotocol.client.transport; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; - import java.io.IOException; import java.net.InetSocketAddress; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.ArrayList; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; - -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; +import java.util.function.Predicate; import com.sun.net.httpserver.HttpServer; - +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.transport.TomcatTestUtil; import io.modelcontextprotocol.spec.HttpHeaders; import io.modelcontextprotocol.spec.McpClientTransport; @@ -28,14 +23,30 @@ import io.modelcontextprotocol.spec.McpTransportException; import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import io.modelcontextprotocol.spec.ProtocolVersions; +import org.awaitility.Awaitility; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.InstanceOfAssertFactories.type; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + /** * Tests for error handling changes in HttpClientStreamableHttpTransport. Specifically * tests the distinction between session-related errors and general transport errors for * 404 and 400 status codes. * * @author Christian Tzolov + * @author Daniel Garnier-Moiroux */ @Timeout(15) public class HttpClientStreamableHttpTransportErrorHandlingTest { @@ -46,11 +57,17 @@ public class HttpClientStreamableHttpTransportErrorHandlingTest { private HttpServer server; - private AtomicReference serverResponseStatus = new AtomicReference<>(200); + private final AtomicInteger serverResponseStatus = new AtomicInteger(200); + + private final AtomicInteger serverSseResponseStatus = new AtomicInteger(200); + + private final AtomicReference currentServerSessionId = new AtomicReference<>(null); - private AtomicReference currentServerSessionId = new AtomicReference<>(null); + private final AtomicReference lastReceivedSessionId = new AtomicReference<>(null); - private AtomicReference lastReceivedSessionId = new AtomicReference<>(null); + private final AtomicInteger processedMessagesCount = new AtomicInteger(0); + + private final AtomicInteger processedSseConnectCount = new AtomicInteger(0); private McpClientTransport transport; @@ -88,6 +105,20 @@ else if ("POST".equals(httpExchange.getRequestMethod())) { else { httpExchange.sendResponseHeaders(status, 0); } + processedMessagesCount.incrementAndGet(); + } + else if ("GET".equals(httpExchange.getRequestMethod())) { + int status = serverSseResponseStatus.get(); + if (status == 200) { + httpExchange.getResponseHeaders().set("Content-Type", "text/event-stream"); + httpExchange.sendResponseHeaders(200, 0); + String sseData = "event: message\ndata: {\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{}}\n\n"; + httpExchange.getResponseBody().write(sseData.getBytes()); + } + else { + httpExchange.sendResponseHeaders(status, 0); + } + processedSseConnectCount.incrementAndGet(); } httpExchange.close(); }); @@ -103,6 +134,7 @@ void stopServer() { if (server != null) { server.stop(0); } + StepVerifier.create(transport.closeGracefully()).verifyComplete(); } /** @@ -334,6 +366,279 @@ else if (status == 404) { StepVerifier.create(transport.closeGracefully()).verifyComplete(); } + @Nested + class AuthorizationError { + + @Nested + class SendMessage { + + @ParameterizedTest + @ValueSource(ints = { 401, 403 }) + void invokeHandler(int httpStatus) { + serverResponseStatus.set(httpStatus); + + AtomicReference capturedResponseInfo = new AtomicReference<>(); + AtomicReference capturedContext = new AtomicReference<>(); + + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler((responseInfo, context) -> { + capturedResponseInfo.set(responseInfo); + capturedContext.set(context); + return Mono.just(false); + }) + .build(); + + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())) + .expectErrorMatches(authorizationError(httpStatus)) + .verify(); + assertThat(processedMessagesCount.get()).isEqualTo(1); + assertThat(capturedResponseInfo.get()).isNotNull(); + assertThat(capturedResponseInfo.get().statusCode()).isEqualTo(httpStatus); + assertThat(capturedContext.get()).isNotNull(); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void defaultHandler() { + serverResponseStatus.set(401); + + var authTransport = HttpClientStreamableHttpTransport.builder(HOST).build(); + + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())) + .expectErrorMatches(authorizationError(401)) + .verify(); + assertThat(processedMessagesCount.get()).isEqualTo(1); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void retry() { + serverResponseStatus.set(401); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler((responseInfo, context) -> { + serverResponseStatus.set(200); + return Mono.just(true); + }) + .build(); + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())).verifyComplete(); + // initial request + retry + assertThat(processedMessagesCount.get()).isEqualTo(2); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void noRetry() { + serverResponseStatus.set(401); + + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler((responseInfo, context) -> Mono.just(false)) + .build(); + + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())) + .expectErrorMatches(authorizationError(401)) + .verify(); + assertThat(processedMessagesCount.get()).isEqualTo(1); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void propagateHandlerError() { + serverResponseStatus.set(401); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler( + (responseInfo, context) -> Mono.error(new IllegalStateException("handler error"))) + .build(); + + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())) + .expectErrorMatches(throwable -> throwable instanceof IllegalStateException + && throwable.getMessage().equals("handler error")) + .verify(); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void emptyHandler() { + serverResponseStatus.set(401); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler((responseInfo, context) -> Mono.empty()) + .build(); + + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())) + .expectErrorMatches(authorizationError(401)) + .verify(); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + } + + @Nested + class Connect { + + @ParameterizedTest + @ValueSource(ints = { 401, 403 }) + void invokeHandler(int httpStatus) { + serverSseResponseStatus.set(httpStatus); + @SuppressWarnings("unchecked") + AtomicReference capturedException = new AtomicReference<>(); + + AtomicReference capturedResponseInfo = new AtomicReference<>(); + AtomicReference capturedContext = new AtomicReference<>(); + + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler((responseInfo, context) -> { + capturedResponseInfo.set(responseInfo); + capturedContext.set(context); + return Mono.just(false); + }) + .openConnectionOnStartup(true) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + var messages = new ArrayList(); + StepVerifier.create(authTransport.connect(msg -> msg.doOnNext(messages::add))).verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(processedSseConnectCount.get()).isEqualTo(1)); + assertThat(messages).isEmpty(); + assertThat(capturedResponseInfo.get()).isNotNull(); + assertThat(capturedResponseInfo.get().statusCode()).isEqualTo(httpStatus); + assertThat(capturedContext.get()).isNotNull(); + assertThat(capturedException.get()).hasMessage("Authorization error connecting to SSE stream") + .asInstanceOf(type(McpHttpClientTransportException.class)) + .extracting(McpHttpClientTransportException::getResponseInfo) + .extracting(HttpResponse.ResponseInfo::statusCode) + .isEqualTo(httpStatus); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void defaultHandler() { + serverSseResponseStatus.set(401); + AtomicReference capturedException = new AtomicReference<>(); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .openConnectionOnStartup(true) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + StepVerifier.create(authTransport.connect(msg -> msg)).verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(processedSseConnectCount.get()).isEqualTo(1)); + assertThat(capturedException.get()).isInstanceOf(McpHttpClientTransportException.class); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void retry() { + serverSseResponseStatus.set(401); + AtomicReference capturedException = new AtomicReference<>(); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .openConnectionOnStartup(true) + .authorizationErrorHandler((responseInfo, context) -> { + serverSseResponseStatus.set(200); + return Mono.just(true); + }) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + var messages = new ArrayList(); + StepVerifier.create(authTransport.connect(msg -> msg.doOnNext(messages::add))).verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(processedSseConnectCount.get()).isEqualTo(2)); + assertThat(messages).hasSize(1); + assertThat(capturedException.get()).isNull(); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void noRetry() { + serverSseResponseStatus.set(401); + AtomicReference capturedException = new AtomicReference<>(); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .openConnectionOnStartup(true) + .authorizationErrorHandler((responseInfo, context) -> { + // if there was a retry, the request would succeed. + serverSseResponseStatus.set(200); + return Mono.just(false); + }) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + var messages = new ArrayList(); + StepVerifier.create(authTransport.connect(msg -> msg.doOnNext(messages::add))).verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(processedSseConnectCount.get()).isEqualTo(1)); + assertThat(messages).isEmpty(); + assertThat(capturedException.get()).isInstanceOf(McpHttpClientTransportException.class); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void emptyHandler() { + serverSseResponseStatus.set(401); + AtomicReference capturedException = new AtomicReference<>(); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .openConnectionOnStartup(true) + .authorizationErrorHandler((responseInfo, context) -> Mono.empty()) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + var messages = new ArrayList(); + StepVerifier.create(authTransport.connect(msg -> msg.doOnNext(messages::add))).verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(processedSseConnectCount.get()).isEqualTo(1)); + assertThat(messages).isEmpty(); + assertThat(capturedException.get()).isInstanceOf(McpHttpClientTransportException.class); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void propagateHandlerError() { + serverSseResponseStatus.set(401); + AtomicReference capturedException = new AtomicReference<>(); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .openConnectionOnStartup(true) + .authorizationErrorHandler( + (responseInfo, context) -> Mono.error(new IllegalStateException("handler error"))) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + var messages = new ArrayList(); + StepVerifier.create(authTransport.connect(msg -> msg.doOnNext(messages::add))).verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(processedSseConnectCount.get()).isEqualTo(1)); + assertThat(messages).isEmpty(); + assertThat(capturedException.get()).isInstanceOf(IllegalStateException.class) + .hasMessage("handler error"); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + } + + private static Predicate authorizationError(int httpStatus) { + return throwable -> throwable instanceof McpHttpClientTransportException + && throwable.getMessage().contains("Authorization error") + && ((McpHttpClientTransportException) throwable).getResponseInfo().statusCode() == httpStatus; + } + + } + private McpSchema.JSONRPCRequest createTestRequestMessage() { var initializeRequest = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_03_26, McpSchema.ClientCapabilities.builder().roots(true).build(), From 3f45fafe6bc256122467d6803bec25ceea079401 Mon Sep 17 00:00:00 2001 From: Daniel Garnier-Moiroux Date: Thu, 12 Mar 2026 22:53:21 +0100 Subject: [PATCH 2/2] Introduce retry handling Signed-off-by: Daniel Garnier-Moiroux --- .../HttpClientStreamableHttpTransport.java | 60 ++++---- ...ClientTransportAuthorizationException.java | 31 +++++ .../McpHttpClientTransportException.java | 34 ----- ...cpHttpClientAuthorizationErrorHandler.java | 50 +++---- ...tpClientAuthorizationErrorHandlerTest.java | 98 +++----------- ...eamableHttpTransportErrorHandlingTest.java | 128 ++++++++++++++++-- 6 files changed, 221 insertions(+), 180 deletions(-) create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/client/transport/McpHttpClientTransportAuthorizationException.java delete mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/client/transport/McpHttpClientTransportException.java diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java index 5310433ba..57a27a3fd 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -51,6 +51,7 @@ import reactor.core.publisher.Mono; import reactor.util.function.Tuple2; import reactor.util.function.Tuples; +import reactor.util.retry.Retry; /** * An implementation of the Streamable HTTP protocol as defined by the @@ -291,26 +292,17 @@ private Mono reconnect(McpTransportStream stream) { })).flatMap(responseEvent -> { int statusCode = responseEvent.responseInfo().statusCode(); if (statusCode == 401 || statusCode == 403) { - logger.debug("Authorization error in sendMessage with code {}", statusCode); - return Mono.deferContextual(innerCtx -> { - var transportContext = innerCtx.getOrDefault(McpTransportContext.KEY, - McpTransportContext.EMPTY); - return Mono.from(this.authorizationErrorHandler.onAuthorizationError( - responseEvent.responseInfo(), transportContext, Mono.defer(() -> { - logger.debug("Authorization error handled, retrying original request"); - return this.reconnect(stream).then(); - }), - Mono.error(new McpHttpClientTransportException( - "Authorization error connecting to SSE stream", - responseEvent.responseInfo())))) - .then(Mono.empty()); - }); + logger.debug("Authorization error in reconnect with code {}", statusCode); + return Mono.error( + new McpHttpClientTransportAuthorizationException( + "Authorization error connecting to SSE stream", + responseEvent.responseInfo())); } if (!(responseEvent instanceof ResponseSubscribers.SseResponseEvent sseResponseEvent)) { - return Flux.error(new McpHttpClientTransportException( - "Unrecognized server error when connecting to SSE stream", - responseEvent.responseInfo())); + return Flux.error(new McpTransportException( + "Unrecognized server error when connecting to SSE stream, status code: " + + statusCode)); } else if (statusCode >= 200 && statusCode < 300) { if (MESSAGE_EVENT_TYPE.equals(sseResponseEvent.sseEvent().event())) { @@ -389,6 +381,7 @@ else if (statusCode == BAD_REQUEST) { return Flux.error(new McpTransportException( "Received unrecognized SSE event type: " + sseResponseEvent.sseEvent().event())); }) + .retryWhen(authorizationErrorRetrySpec()) .flatMap(jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) .onErrorMap(CompletionException.class, t -> t.getCause()) .onErrorComplete(t -> { @@ -411,6 +404,25 @@ else if (statusCode == BAD_REQUEST) { } + private Retry authorizationErrorRetrySpec() { + return Retry.from(companion -> companion.flatMap(retrySignal -> { + if (!(retrySignal.failure() instanceof McpHttpClientTransportAuthorizationException authException)) { + return Mono.error(retrySignal.failure()); + } + if (retrySignal.totalRetriesInARow() >= this.authorizationErrorHandler.maxRetries()) { + return Mono.error(retrySignal.failure()); + } + return Mono.deferContextual(ctx -> { + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono + .from(this.authorizationErrorHandler.handle(authException.getResponseInfo(), transportContext)) + .switchIfEmpty(Mono.just(false)) + .flatMap(shouldRetry -> shouldRetry ? Mono.just(retrySignal.totalRetries()) + : Mono.error(retrySignal.failure())); + }); + })); + } + private BodyHandler toSendMessageBodySubscriber(FluxSink sink) { BodyHandler responseBodyHandler = responseInfo -> { @@ -492,17 +504,8 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { int statusCode = responseEvent.responseInfo().statusCode(); if (statusCode == 401 || statusCode == 403) { logger.debug("Authorization error in sendMessage with code {}", statusCode); - return Mono.deferContextual(ctx -> { - var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); - return Mono.from(this.authorizationErrorHandler - .onAuthorizationError(responseEvent.responseInfo(), transportContext, Mono.defer(() -> { - logger.debug("Authorization error handled, retrying original request"); - return this.sendMessage(sentMessage); - }), Mono.error(new McpHttpClientTransportException( - "Authorization error when sending message", responseEvent.responseInfo())))) - .doOnSuccess(s -> deliveredSink.success()) - .then(Mono.empty()); - }); + return Mono.error(new McpHttpClientTransportAuthorizationException( + "Authorization error when sending message", responseEvent.responseInfo())); } if (transportSession.markInitialized( @@ -630,6 +633,7 @@ else if (statusCode == BAD_REQUEST) { return Flux.error( new RuntimeException("Failed to send message: " + responseEvent)); }) + .retryWhen(authorizationErrorRetrySpec()) .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))) .onErrorMap(CompletionException.class, t -> t.getCause()) .onErrorComplete(t -> { diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/McpHttpClientTransportAuthorizationException.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/McpHttpClientTransportAuthorizationException.java new file mode 100644 index 000000000..31e5ae95e --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/McpHttpClientTransportAuthorizationException.java @@ -0,0 +1,31 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import java.net.http.HttpResponse; + +import io.modelcontextprotocol.spec.McpTransportException; + +/** + * Thrown when the MCP server responds with an authorization error (HTTP 401 or HTTP 403). + * Subclass of {@link McpTransportException} for targeted retry handling in + * {@link HttpClientStreamableHttpTransport}. + * + * @author Daniel Garnier-Moiroux + */ +public class McpHttpClientTransportAuthorizationException extends McpTransportException { + + private final HttpResponse.ResponseInfo responseInfo; + + public McpHttpClientTransportAuthorizationException(String message, HttpResponse.ResponseInfo responseInfo) { + super(message); + this.responseInfo = responseInfo; + } + + public HttpResponse.ResponseInfo getResponseInfo() { + return responseInfo; + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/McpHttpClientTransportException.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/McpHttpClientTransportException.java deleted file mode 100644 index c4b082b7a..000000000 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/McpHttpClientTransportException.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright 2026-2026 the original author or authors. - */ - -package io.modelcontextprotocol.client.transport; - -import java.net.http.HttpResponse; - -import io.modelcontextprotocol.spec.McpTransportException; - -/** - * Authorization-related exception for {@link java.net.http.HttpClient}-based client - * transport. Thrown when the server responds with HTTP 401 or HTTP 403. Wraps the - * response info for further inspection of the headers and the status code. - * - * @see MCP - * Specification: Authorization - * @author Daniel Garnier-Moiroux - */ -public class McpHttpClientTransportException extends McpTransportException { - - private final HttpResponse.ResponseInfo responseInfo; - - public McpHttpClientTransportException(String message, HttpResponse.ResponseInfo responseInfo) { - super(message); - this.responseInfo = responseInfo; - } - - public HttpResponse.ResponseInfo getResponseInfo() { - return responseInfo; - } - -} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandler.java index f544bc233..c98fac61d 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandler.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandler.java @@ -6,7 +6,7 @@ import java.net.http.HttpResponse; -import io.modelcontextprotocol.client.transport.McpHttpClientTransportException; +import io.modelcontextprotocol.client.transport.McpHttpClientTransportAuthorizationException; import io.modelcontextprotocol.common.McpTransportContext; import org.reactivestreams.Publisher; import reactor.core.publisher.Mono; @@ -27,13 +27,13 @@ public interface McpHttpClientAuthorizationErrorHandler { * Handle authorization error (HTTP 401 or 403), and signal whether the HTTP request * should be retried or not. If the publisher returns true, the original transport * method (connect, sendMessage) will be replayed with the original arguments. - * Otherwise, the transport will throw an {@link McpHttpClientTransportException}, - * indicating the error status. + * Otherwise, the transport will throw an + * {@link McpHttpClientTransportAuthorizationException}, indicating the error status. *

* If the returned {@link Publisher} errors, the error will be propagated to the * calling method, to be handled by the caller. *

- * The caller is responsible for bounding the number of retries. + * The number of retries is bounded by {@link #maxRetries()}. * @param responseInfo the HTTP response information * @param context the MCP client transport context * @return {@link Publisher} emitting true if the original request should be replayed, @@ -42,36 +42,23 @@ public interface McpHttpClientAuthorizationErrorHandler { Publisher handle(HttpResponse.ResponseInfo responseInfo, McpTransportContext context); /** - * A no-op handler, used in the default use-case. + * Maximum number of authorization error retries the transport will attempt. When the + * handler signals a retry via {@link #handle}, the transport will replay the original + * request at most this many times. If the authorization error persists after + * exhausting all retries, the transport will propagate the + * {@link McpHttpClientTransportAuthorizationException}. + *

+ * Defaults to {@code 1}. + * @return the maximum number of retries */ - McpHttpClientAuthorizationErrorHandler NOOP = new Noop(); + default int maxRetries() { + return 1; + } /** - * Handle authorization error (HTTP 401 or 403), and optionally retry the HTTP - * request, or trigger a transport error. To retry, use the {@code retryAction} - * publisher. To emit the default transport error, use the {@code defaultError} - * publisher. - *

- * Optionally, the returned {@link Publisher} may error to trigger an out-of-band - * action. In that case, the error will be propagated to the calling method, to be - * handled by the caller. - *

- * Defaults to {@link #handle(HttpResponse.ResponseInfo, McpTransportContext)}, and - * uses the boolean from the return value to decide whether it should retry the - * request. - * @param responseInfo the HTTP response information - * @param context the MCP client transport context - * @param retryAction handler to retry the original request - * @param defaultError handler to emit an error - * @return a {@link Publisher} to signal either an error or a retry + * A no-op handler, used in the default use-case. */ - default Publisher onAuthorizationError(HttpResponse.ResponseInfo responseInfo, McpTransportContext context, - Publisher retryAction, Publisher defaultError) { - return Mono.from(this.handle(responseInfo, context)) - .switchIfEmpty(Mono.just(false)) - .flatMap(shouldRetry -> shouldRetry != null && shouldRetry ? Mono.from(retryAction) - : Mono.from(defaultError)); - } + McpHttpClientAuthorizationErrorHandler NOOP = new Noop(); /** * Create a {@link McpHttpClientAuthorizationErrorHandler} from a synchronous handler. @@ -95,7 +82,8 @@ interface Sync { * request should be retried or not. If the return value is true, the original * transport method (connect, sendMessage) will be replayed with the original * arguments. Otherwise, the transport will throw an - * {@link McpHttpClientTransportException}, indicating the error status. + * {@link McpHttpClientTransportAuthorizationException}, indicating the error + * status. * @param responseInfo the HTTP response information * @param context the MCP client transport context * @return true if the original request should be replayed, false otherwise. diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandlerTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandlerTest.java index b935f95a5..2812522f5 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandlerTest.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandlerTest.java @@ -6,9 +6,7 @@ import java.net.http.HttpResponse; import io.modelcontextprotocol.common.McpTransportContext; -import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; -import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import static org.mockito.Mockito.mock; @@ -22,85 +20,29 @@ class McpHttpClientAuthorizationErrorHandlerTest { private final McpTransportContext context = McpTransportContext.EMPTY; - @Nested - class OnAuthorizationError { - - @Test - void whenTrueThenRetry() { - McpHttpClientAuthorizationErrorHandler handler = (info, ctx) -> Mono.just(true); - Mono retryAction = Mono.empty(); - Mono defaultError = Mono.error(new RuntimeException("should not be called")); - - StepVerifier.create(handler.onAuthorizationError(responseInfo, context, retryAction, defaultError)) - .verifyComplete(); - } - - @Test - void whenFalseThenError() { - McpHttpClientAuthorizationErrorHandler handler = (info, ctx) -> Mono.just(false); - Mono retryAction = Mono.error(new RuntimeException("should not be called")); - Mono defaultError = Mono.error(new RuntimeException("authorization error")); - - StepVerifier.create(handler.onAuthorizationError(responseInfo, context, retryAction, defaultError)) - .expectErrorMatches(t -> t instanceof RuntimeException && t.getMessage().equals("authorization error")) - .verify(); - } - - @Test - void whenErrorThenPropagate() { - McpHttpClientAuthorizationErrorHandler handler = (info, ctx) -> Mono - .error(new IllegalStateException("handler error")); - Mono retryAction = Mono.error(new RuntimeException("should not be called")); - Mono defaultError = Mono.error(new RuntimeException("should not be called")); - - StepVerifier.create(handler.onAuthorizationError(responseInfo, context, retryAction, defaultError)) - .expectErrorMatches(t -> t instanceof IllegalStateException && t.getMessage().equals("handler error")) - .verify(); - } - + @Test + void whenTrueThenRetry() { + McpHttpClientAuthorizationErrorHandler handler = McpHttpClientAuthorizationErrorHandler + .fromSync((info, ctx) -> true); + StepVerifier.create(handler.handle(responseInfo, context)).expectNext(true).verifyComplete(); } - @Nested - class FromSync { - - @Test - void whenTrueThenRetry() { - McpHttpClientAuthorizationErrorHandler handler = McpHttpClientAuthorizationErrorHandler - .fromSync((info, ctx) -> true); - Mono retryAction = Mono.empty(); - Mono defaultError = Mono.error(new RuntimeException("should not be called")); - - StepVerifier.create(handler.onAuthorizationError(responseInfo, context, retryAction, defaultError)) - .verifyComplete(); - } - - @Test - void whenFalseThenError() { - McpHttpClientAuthorizationErrorHandler handler = McpHttpClientAuthorizationErrorHandler - .fromSync((info, ctx) -> false); - Mono retryAction = Mono.error(new RuntimeException("should not be called")); - Mono defaultError = Mono.error(new RuntimeException("authorization error")); - - StepVerifier.create(handler.onAuthorizationError(responseInfo, context, retryAction, defaultError)) - .expectErrorMatches(t -> t instanceof RuntimeException && t.getMessage().equals("authorization error")) - .verify(); - } - - @Test - void whenExceptionThenPropagate() { - McpHttpClientAuthorizationErrorHandler handler = McpHttpClientAuthorizationErrorHandler - .fromSync((info, ctx) -> { - throw new IllegalStateException("sync handler error"); - }); - Mono retryAction = Mono.error(new RuntimeException("should not be called")); - Mono defaultError = Mono.error(new RuntimeException("should not be called")); - - StepVerifier.create(handler.onAuthorizationError(responseInfo, context, retryAction, defaultError)) - .expectErrorMatches( - t -> t instanceof IllegalStateException && t.getMessage().equals("sync handler error")) - .verify(); - } + @Test + void whenFalseThenError() { + McpHttpClientAuthorizationErrorHandler handler = McpHttpClientAuthorizationErrorHandler + .fromSync((info, ctx) -> false); + StepVerifier.create(handler.handle(responseInfo, context)).expectNext(false).verifyComplete(); + } + @Test + void whenExceptionThenPropagate() { + McpHttpClientAuthorizationErrorHandler handler = McpHttpClientAuthorizationErrorHandler + .fromSync((info, ctx) -> { + throw new IllegalStateException("sync handler error"); + }); + StepVerifier.create(handler.handle(responseInfo, context)) + .expectErrorMatches(t -> t instanceof IllegalStateException && t.getMessage().equals("sync handler error")) + .verify(); } } diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java index 20ae10dd3..c4857e5b4 100644 --- a/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java @@ -9,13 +9,16 @@ import java.net.http.HttpResponse; import java.time.Duration; import java.util.ArrayList; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Predicate; import com.sun.net.httpserver.HttpServer; +import io.modelcontextprotocol.client.transport.customizer.McpHttpClientAuthorizationErrorHandler; import io.modelcontextprotocol.common.McpTransportContext; +import org.reactivestreams.Publisher; import io.modelcontextprotocol.server.transport.TomcatTestUtil; import io.modelcontextprotocol.spec.HttpHeaders; import io.modelcontextprotocol.spec.McpClientTransport; @@ -429,6 +432,47 @@ void retry() { StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); } + @Test + void retryAtMostOnce() { + serverResponseStatus.set(401); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler((responseInfo, context) -> Mono.just(true)) + .build(); + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())) + .expectErrorMatches(authorizationError(401)) + .verify(); + // initial request + 1 retry (maxRetries default is 1) + assertThat(processedMessagesCount.get()).isEqualTo(2); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void customMaxRetries() { + serverResponseStatus.set(401); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler(new McpHttpClientAuthorizationErrorHandler() { + @Override + public Publisher handle(HttpResponse.ResponseInfo responseInfo, + McpTransportContext context) { + return Mono.just(true); + } + + @Override + public int maxRetries() { + return 3; + } + }) + .build(); + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())) + .expectErrorMatches(authorizationError(401)) + .verify(); + // initial request + 3 retries + assertThat(processedMessagesCount.get()).isEqualTo(4); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + @Test void noRetry() { serverResponseStatus.set(401); @@ -510,8 +554,8 @@ void invokeHandler(int httpStatus) { assertThat(capturedResponseInfo.get().statusCode()).isEqualTo(httpStatus); assertThat(capturedContext.get()).isNotNull(); assertThat(capturedException.get()).hasMessage("Authorization error connecting to SSE stream") - .asInstanceOf(type(McpHttpClientTransportException.class)) - .extracting(McpHttpClientTransportException::getResponseInfo) + .asInstanceOf(type(McpHttpClientTransportAuthorizationException.class)) + .extracting(McpHttpClientTransportAuthorizationException::getResponseInfo) .extracting(HttpResponse.ResponseInfo::statusCode) .isEqualTo(httpStatus); @@ -531,7 +575,7 @@ void defaultHandler() { Awaitility.await() .atMost(Duration.ofSeconds(1)) .untilAsserted(() -> assertThat(processedSseConnectCount.get()).isEqualTo(1)); - assertThat(capturedException.get()).isInstanceOf(McpHttpClientTransportException.class); + assertThat(capturedException.get()).isInstanceOf(McpHttpClientTransportAuthorizationException.class); StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); } @@ -550,12 +594,77 @@ void retry() { authTransport.setExceptionHandler(capturedException::set); var messages = new ArrayList(); - StepVerifier.create(authTransport.connect(msg -> msg.doOnNext(messages::add))).verifyComplete(); + var messageHandlerClosed = new AtomicBoolean(false); + StepVerifier + .create(authTransport + .connect(msg -> msg.doOnNext(messages::add).doFinally(s -> messageHandlerClosed.set(true)))) + .verifyComplete(); Awaitility.await() .atMost(Duration.ofSeconds(1)) - .untilAsserted(() -> assertThat(processedSseConnectCount.get()).isEqualTo(2)); + .untilAsserted(() -> assertThat(messageHandlerClosed).isTrue()); + assertThat(processedSseConnectCount.get()).isEqualTo(2); assertThat(messages).hasSize(1); assertThat(capturedException.get()).isNull(); + assertThat(messageHandlerClosed.get()).isTrue(); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void retryAtMostOnce() { + serverSseResponseStatus.set(401); + AtomicReference capturedException = new AtomicReference<>(); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .openConnectionOnStartup(true) + .authorizationErrorHandler((responseInfo, context) -> { + return Mono.just(true); + }) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + var messages = new ArrayList(); + StepVerifier.create(authTransport.connect(msg -> msg.doOnNext(messages::add))).verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(capturedException.get()).isNotNull()); + // initial request + 1 retry (maxRetries default is 1) + assertThat(processedSseConnectCount.get()).isEqualTo(2); + assertThat(messages).isEmpty(); + assertThat(capturedException.get()).isInstanceOf(McpHttpClientTransportAuthorizationException.class); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void customMaxRetries() { + serverSseResponseStatus.set(401); + AtomicReference capturedException = new AtomicReference<>(); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .openConnectionOnStartup(true) + .authorizationErrorHandler(new McpHttpClientAuthorizationErrorHandler() { + @Override + public Publisher handle(HttpResponse.ResponseInfo responseInfo, + McpTransportContext context) { + return Mono.just(true); + } + + @Override + public int maxRetries() { + return 3; + } + }) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + var messages = new ArrayList(); + StepVerifier.create(authTransport.connect(msg -> msg.doOnNext(messages::add))).verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(capturedException.get()).isNotNull()); + // initial request + 3 retries + assertThat(processedSseConnectCount.get()).isEqualTo(4); + assertThat(messages).isEmpty(); + assertThat(capturedException.get()).isInstanceOf(McpHttpClientTransportAuthorizationException.class); StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); } @@ -580,7 +689,7 @@ void noRetry() { .atMost(Duration.ofSeconds(1)) .untilAsserted(() -> assertThat(processedSseConnectCount.get()).isEqualTo(1)); assertThat(messages).isEmpty(); - assertThat(capturedException.get()).isInstanceOf(McpHttpClientTransportException.class); + assertThat(capturedException.get()).isInstanceOf(McpHttpClientTransportAuthorizationException.class); StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); } @@ -601,7 +710,7 @@ void emptyHandler() { .atMost(Duration.ofSeconds(1)) .untilAsserted(() -> assertThat(processedSseConnectCount.get()).isEqualTo(1)); assertThat(messages).isEmpty(); - assertThat(capturedException.get()).isInstanceOf(McpHttpClientTransportException.class); + assertThat(capturedException.get()).isInstanceOf(McpHttpClientTransportAuthorizationException.class); StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); } @@ -632,9 +741,10 @@ void propagateHandlerError() { } private static Predicate authorizationError(int httpStatus) { - return throwable -> throwable instanceof McpHttpClientTransportException + return throwable -> throwable instanceof McpHttpClientTransportAuthorizationException && throwable.getMessage().contains("Authorization error") - && ((McpHttpClientTransportException) throwable).getResponseInfo().statusCode() == httpStatus; + && ((McpHttpClientTransportAuthorizationException) throwable).getResponseInfo() + .statusCode() == httpStatus; } }