diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 957082a85..440246059 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Generic, Literal, Protocol, TypeVar +from typing import Any, Generic, Literal, Protocol, TypeVar from urllib.parse import parse_qs, urlencode, urlparse, urlunparse from pydantic import AnyUrl, BaseModel @@ -40,6 +40,8 @@ class AccessToken(BaseModel): scopes: list[str] expires_at: int | None = None resource: str | None = None # RFC 8707 resource indicator + subject: str | None = None # JWT sub claim — identifies the end-user + claims: dict[str, Any] | None = None # arbitrary JWT claims from the token RegistrationErrorCode = Literal[ diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index 1538adc7c..f38b44e09 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -218,6 +218,18 @@ def client_id(self) -> str | None: """Get the client ID if available.""" return self.request_context.meta.get("client_id") if self.request_context.meta else None # pragma: no cover + @property + def subject(self) -> str | None: + """Get the authenticated user's subject (JWT sub claim) if available. + + This returns the subject claim from the OAuth access token, identifying + the end-user on whose behalf the request is made. + """ + from mcp.server.auth.middleware.auth_context import get_access_token + + token = get_access_token() + return token.subject if token else None + @property def request_id(self) -> str: """Get the unique ID for this request.""" diff --git a/tests/server/auth/middleware/test_auth_context.py b/tests/server/auth/middleware/test_auth_context.py index 66481bcf7..607158ddb 100644 --- a/tests/server/auth/middleware/test_auth_context.py +++ b/tests/server/auth/middleware/test_auth_context.py @@ -41,6 +41,7 @@ def valid_access_token() -> AccessToken: client_id="test_client", scopes=["read", "write"], expires_at=int(time.time()) + 3600, # 1 hour from now + subject="user_123", ) @@ -83,6 +84,27 @@ async def send(message: Message) -> None: # pragma: no cover assert get_access_token() is None +@pytest.mark.anyio +async def test_auth_context_middleware_subject_preserved(valid_access_token: AccessToken): + """Test that subject field on AccessToken is available via get_access_token().""" + app = MockApp() + middleware = AuthContextMiddleware(app) + + user = AuthenticatedUser(valid_access_token) + scope: Scope = {"type": "http", "user": user} + + async def receive() -> Message: # pragma: no cover + return {"type": "http.request"} + + async def send(message: Message) -> None: # pragma: no cover + pass + + await middleware(scope, receive, send) + + assert app.access_token_during_call is not None + assert app.access_token_during_call.subject == "user_123" + + @pytest.mark.anyio async def test_auth_context_middleware_with_no_user(): """Test middleware with no user in scope.""" diff --git a/tests/server/auth/middleware/test_bearer_auth.py b/tests/server/auth/middleware/test_bearer_auth.py index bd14e294c..1e7bbb80a 100644 --- a/tests/server/auth/middleware/test_bearer_auth.py +++ b/tests/server/auth/middleware/test_bearer_auth.py @@ -77,6 +77,7 @@ def valid_access_token() -> AccessToken: client_id="test_client", scopes=["read", "write"], expires_at=int(time.time()) + 3600, # 1 hour from now + subject="user_123", )