Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/mcp/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,15 @@

from mcp.client._memory import InMemoryTransport
from mcp.client._transport import Transport
from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
from mcp.client.session import (
ClientSession,
ElicitationFnT,
ListChangedFnT,
ListRootsFnT,
LoggingFnT,
MessageHandlerFnT,
SamplingFnT,
)
from mcp.client.streamable_http import streamable_http_client
from mcp.server import Server
from mcp.server.mcpserver import MCPServer
Expand Down Expand Up @@ -95,6 +103,15 @@ async def main():
elicitation_callback: ElicitationFnT | None = None
"""Callback for handling elicitation requests."""

tools_list_changed_callback: ListChangedFnT | None = None
"""Callback invoked when the server sends a tools/list_changed notification."""

resources_list_changed_callback: ListChangedFnT | None = None
"""Callback invoked when the server sends a resources/list_changed notification."""

prompts_list_changed_callback: ListChangedFnT | None = None
"""Callback invoked when the server sends a prompts/list_changed notification."""

_session: ClientSession | None = field(init=False, default=None)
_exit_stack: AsyncExitStack | None = field(init=False, default=None)
_transport: Transport = field(init=False)
Expand Down Expand Up @@ -126,6 +143,9 @@ async def __aenter__(self) -> Client:
message_handler=self.message_handler,
client_info=self.client_info,
elicitation_callback=self.elicitation_callback,
tools_list_changed_callback=self.tools_list_changed_callback,
resources_list_changed_callback=self.resources_list_changed_callback,
prompts_list_changed_callback=self.prompts_list_changed_callback,
)
)

Expand Down
20 changes: 20 additions & 0 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ async def __call__(
) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch


class ListChangedFnT(Protocol):
async def __call__(self) -> None: ... # pragma: no branch


class LoggingFnT(Protocol):
async def __call__(self, params: types.LoggingMessageNotificationParams) -> None: ... # pragma: no branch

Expand Down Expand Up @@ -95,6 +99,10 @@ async def _default_logging_callback(
pass


async def _default_list_changed_callback() -> None:
pass


ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)


Expand All @@ -118,6 +126,9 @@ def __init__(
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
client_info: types.Implementation | None = None,
tools_list_changed_callback: ListChangedFnT | None = None,
resources_list_changed_callback: ListChangedFnT | None = None,
prompts_list_changed_callback: ListChangedFnT | None = None,
*,
sampling_capabilities: types.SamplingCapability | None = None,
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
Expand All @@ -130,6 +141,9 @@ def __init__(
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
self._logging_callback = logging_callback or _default_logging_callback
self._message_handler = message_handler or _default_message_handler
self._tools_list_changed_callback = tools_list_changed_callback or _default_list_changed_callback
self._resources_list_changed_callback = resources_list_changed_callback or _default_list_changed_callback
self._prompts_list_changed_callback = prompts_list_changed_callback or _default_list_changed_callback
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
self._server_capabilities: types.ServerCapabilities | None = None
self._experimental_features: ExperimentalClientFeatures | None = None
Expand Down Expand Up @@ -470,6 +484,12 @@ async def _received_notification(self, notification: types.ServerNotification) -
match notification:
case types.LoggingMessageNotification(params=params):
await self._logging_callback(params)
case types.ToolListChangedNotification():
await self._tools_list_changed_callback()
case types.ResourceListChangedNotification():
await self._resources_list_changed_callback()
case types.PromptListChangedNotification():
await self._prompts_list_changed_callback()
case types.ElicitCompleteNotification(params=params):
# Handle elicitation completion notification
# Clients MAY use this to retry requests or update UI
Expand Down
6 changes: 3 additions & 3 deletions src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,15 +474,15 @@ async def send_progress_notification(
related_request_id,
)

async def send_resource_list_changed(self) -> None: # pragma: no cover
async def send_resource_list_changed(self) -> None:
"""Send a resource list changed notification."""
await self.send_notification(types.ResourceListChangedNotification())

async def send_tool_list_changed(self) -> None: # pragma: no cover
async def send_tool_list_changed(self) -> None:
"""Send a tool list changed notification."""
await self.send_notification(types.ToolListChangedNotification())

async def send_prompt_list_changed(self) -> None: # pragma: no cover
async def send_prompt_list_changed(self) -> None:
"""Send a prompt list changed notification."""
await self.send_notification(types.PromptListChangedNotification())

Expand Down
135 changes: 135 additions & 0 deletions tests/client/test_list_changed_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""Tests for tools/resources/prompts list_changed notification callbacks."""

import anyio
import pytest

from mcp import Client, types
from mcp.server.mcpserver import Context, MCPServer
from mcp.shared.session import RequestResponder
from mcp.types import TextContent

pytestmark = pytest.mark.anyio


async def test_tools_list_changed_callback():
"""Verify that the client invokes the tools_list_changed_callback when
the server sends a notifications/tools/list_changed notification."""
server = MCPServer("test")
received = anyio.Event()

async def on_tools_list_changed() -> None:
received.set()

@server.tool("trigger_tool_change")
async def trigger_tool_change(ctx: Context) -> str:
await ctx.session.send_tool_list_changed()
return "triggered"

async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception): # pragma: no cover
raise message

async with Client(
server,
tools_list_changed_callback=on_tools_list_changed,
message_handler=message_handler,
) as client:
result = await client.call_tool("trigger_tool_change", {})
assert result.is_error is False
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "triggered"

with anyio.fail_after(5):
await received.wait()


async def test_resources_list_changed_callback():
"""Verify that the client invokes the resources_list_changed_callback when
the server sends a notifications/resources/list_changed notification."""
server = MCPServer("test")
received = anyio.Event()

async def on_resources_list_changed() -> None:
received.set()

@server.tool("trigger_resource_change")
async def trigger_resource_change(ctx: Context) -> str:
await ctx.session.send_resource_list_changed()
return "triggered"

async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception): # pragma: no cover
raise message

async with Client(
server,
resources_list_changed_callback=on_resources_list_changed,
message_handler=message_handler,
) as client:
result = await client.call_tool("trigger_resource_change", {})
assert result.is_error is False

with anyio.fail_after(5):
await received.wait()


async def test_prompts_list_changed_callback():
"""Verify that the client invokes the prompts_list_changed_callback when
the server sends a notifications/prompts/list_changed notification."""
server = MCPServer("test")
received = anyio.Event()

async def on_prompts_list_changed() -> None:
received.set()

@server.tool("trigger_prompt_change")
async def trigger_prompt_change(ctx: Context) -> str:
await ctx.session.send_prompt_list_changed()
return "triggered"

async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception): # pragma: no cover
raise message

async with Client(
server,
prompts_list_changed_callback=on_prompts_list_changed,
message_handler=message_handler,
) as client:
result = await client.call_tool("trigger_prompt_change", {})
assert result.is_error is False

with anyio.fail_after(5):
await received.wait()


async def test_list_changed_callbacks_not_called_without_notification():
"""Verify that list_changed callbacks are NOT invoked when
no list_changed notification is sent."""
server = MCPServer("test")
called = False

async def should_not_be_called() -> None:
nonlocal called
called = True # pragma: no cover

@server.tool("normal_tool")
async def normal_tool() -> str:
return "ok"

async with Client(
server,
tools_list_changed_callback=should_not_be_called,
resources_list_changed_callback=should_not_be_called,
prompts_list_changed_callback=should_not_be_called,
) as client:
result = await client.call_tool("normal_tool", {})
assert result.is_error is False

assert not called
Loading