diff --git a/ccproxy/core/plugins/factories.py b/ccproxy/core/plugins/factories.py index 35476d25..c854707d 100644 --- a/ccproxy/core/plugins/factories.py +++ b/ccproxy/core/plugins/factories.py @@ -105,6 +105,7 @@ class BaseProviderPluginFactory(ProviderPluginFactory): cli_commands: list[CliCommandSpec] = [] cli_arguments: list[CliArgumentSpec] = [] tool_accumulator_class: type | None = None + use_mock_adapter_in_bypass_mode: bool = True def __init__(self) -> None: """Initialize factory with manifest built from class attributes.""" @@ -231,7 +232,8 @@ async def create_adapter(self, context: PluginContext) -> BaseAdapter: adapter=self.adapter_class.__name__, category="lifecycle", ) - return MockAdapter(service_container.get_mock_handler()) + if self.use_mock_adapter_in_bypass_mode: + return MockAdapter(service_container.get_mock_handler()) # Extract services from context (one-time extraction) http_pool_manager: HTTPPoolManager | None = cast( @@ -285,6 +287,8 @@ async def create_adapter(self, context: PluginContext) -> BaseAdapter: if hasattr(context, "get") else None, } + if settings and getattr(settings.server, "bypass_mode", False): + adapter_kwargs["mock_handler"] = service_container.get_mock_handler() if self.tool_accumulator_class: adapter_kwargs["tool_accumulator_class"] = self.tool_accumulator_class @@ -320,6 +324,9 @@ async def create_adapter(self, context: PluginContext) -> BaseAdapter: "model_mapper": context.get("model_mapper") if hasattr(context, "get") else None, + "mock_handler": service_container.get_mock_handler() + if settings and getattr(settings.server, "bypass_mode", False) + else None, } if self.tool_accumulator_class: non_http_adapter_kwargs["tool_accumulator_class"] = ( diff --git a/ccproxy/plugins/codex/adapter.py b/ccproxy/plugins/codex/adapter.py index 7c3af71f..7324d8fe 100644 --- a/ccproxy/plugins/codex/adapter.py +++ b/ccproxy/plugins/codex/adapter.py @@ -17,6 +17,7 @@ ) from ccproxy.services.adapters.chain_composer import 
compose_from_chain from ccproxy.services.adapters.http_adapter import BaseHTTPAdapter +from ccproxy.services.adapters.mock_adapter import MockAdapter from ccproxy.services.handler_config import HandlerConfig from ccproxy.streaming import DeferredStreaming, StreamingBufferService from ccproxy.utils.headers import ( @@ -58,6 +59,9 @@ async def handle_request( # Context + request info ctx = request.state.context self._ensure_tool_accumulator(ctx) + if self.mock_handler: + return await MockAdapter(self.mock_handler).handle_request(request) + endpoint = ctx.metadata.get("endpoint", "") body = await request.body() body = await self._map_request_model(ctx, body) @@ -286,6 +290,13 @@ async def prepare_provider_request( else: body_data.pop("instructions", None) + body_data = self._sanitize_provider_body(body_data) + + return json.dumps(body_data).encode(), filtered_headers + + def _sanitize_provider_body(self, body_data: dict[str, Any]) -> dict[str, Any]: + """Apply Codex-specific payload sanitization shared by all request paths.""" + # Codex backend requires stream=true, always override body_data["stream"] = True body_data["store"] = False @@ -305,11 +316,10 @@ async def prepare_provider_request( input for input in list_input if input.get("type") != "item_reference" ] - # # Remove any prefixed metadata fields that shouldn't be sent to the API body_data = self._remove_metadata_fields(body_data) - return json.dumps(body_data).encode(), filtered_headers + return body_data async def prepare_provider_headers(self, headers: dict[str, str]) -> dict[str, str]: token_value = await self._resolve_access_token() @@ -475,6 +485,10 @@ async def handle_streaming( if not self.streaming_handler: # Fallback to base behavior return await super().handle_streaming(request, endpoint, **kwargs) + if self.mock_handler: + return await MockAdapter(self.mock_handler).handle_streaming( + request, endpoint, **kwargs + ) # Get context ctx = request.state.context @@ -655,9 +669,6 @@ def 
_request_body_is_encoded(self, headers: dict[str, str]) -> bool: encoding = headers.get("content-encoding", "").strip().lower() return bool(encoding and encoding != "identity") - def _should_apply_detection_payload(self) -> bool: - return bool(getattr(self.config, "inject_detection_payload", True)) - def _detect_streaming_intent(self, body: bytes, headers: dict[str, str]) -> bool: if self._request_body_is_encoded(headers): accept = headers.get("accept", "").lower() @@ -670,6 +681,9 @@ def _detect_streaming_intent(self, body: bytes, headers: dict[str, str]) -> bool accept = headers.get("accept", "").lower() return "text/event-stream" in accept + def _should_apply_detection_payload(self) -> bool: + return bool(getattr(self.config, "inject_detection_payload", True)) + def _get_instructions(self) -> str: if not self.detection_service: return "" diff --git a/ccproxy/plugins/codex/detection_service.py b/ccproxy/plugins/codex/detection_service.py index 88885560..7738aab6 100644 --- a/ccproxy/plugins/codex/detection_service.py +++ b/ccproxy/plugins/codex/detection_service.py @@ -163,18 +163,12 @@ def get_detected_prompts(self) -> DetectedPrompts: data = self.get_cached_data() prompts = data.prompts if data else DetectedPrompts() - if prompts.has_instructions() or prompts.has_system(): - return prompts fallback = self._safe_fallback_data() if fallback is None: return prompts - fallback_prompts = fallback.prompts - if fallback_prompts.has_instructions() or fallback_prompts.has_system(): - return fallback_prompts - - return prompts + return self._merge_detected_prompts(prompts, fallback.prompts) def get_ignored_headers(self) -> list[str]: """Headers that should be ignored when forwarding CLI values.""" @@ -537,6 +531,26 @@ def _safe_fallback_data(self) -> CodexCacheData | None: ) return None + @staticmethod + def _merge_detected_prompts( + prompts: DetectedPrompts, fallback: DetectedPrompts + ) -> DetectedPrompts: + """Merge partial prompt caches with fallback defaults.""" + 
+ prompt_raw = prompts.raw if isinstance(prompts.raw, dict) else {} + fallback_raw = fallback.raw if isinstance(fallback.raw, dict) else {} + merged_raw = dict(fallback_raw) + merged_raw.update(prompt_raw) + + instructions = prompts.instructions or fallback.instructions + system = prompts.system if prompts.system is not None else fallback.system + + return DetectedPrompts( + instructions=instructions, + system=system, + raw=merged_raw, + ) + def invalidate_cache(self) -> None: """Clear all cached detection data.""" # Clear the async cache for _get_codex_version diff --git a/ccproxy/plugins/codex/plugin.py b/ccproxy/plugins/codex/plugin.py index 7ae806d2..1433f50b 100644 --- a/ccproxy/plugins/codex/plugin.py +++ b/ccproxy/plugins/codex/plugin.py @@ -222,6 +222,7 @@ class CodexFactory(BaseProviderPluginFactory): """Factory for Codex provider plugin.""" cli_safe = False # Heavy provider plugin - not safe for CLI + use_mock_adapter_in_bypass_mode = False # Plugin configuration via class attributes plugin_name = "codex" diff --git a/ccproxy/plugins/codex/routes.py b/ccproxy/plugins/codex/routes.py index 5a8310f3..cf605515 100644 --- a/ccproxy/plugins/codex/routes.py +++ b/ccproxy/plugins/codex/routes.py @@ -1,6 +1,5 @@ """Codex plugin routes.""" -import contextlib import json from collections import deque from pathlib import Path @@ -31,8 +30,10 @@ ) from ccproxy.core.logging import get_plugin_logger from ccproxy.core.plugins import PluginRegistry, ProviderPluginRuntime +from ccproxy.core.request_context import RequestContext from ccproxy.streaming import DeferredStreaming from ccproxy.streaming.sse_parser import SSEStreamParser +from ccproxy.utils.model_mapper import restore_model_aliases from .config import CodexSettings @@ -114,7 +115,9 @@ def _make_websocket_terminal_event( provider_payload: dict[str, Any], *, error: dict[str, Any] | None = None, + sequence_number: int = 0, ) -> dict[str, Any]: + event_type = "response.failed" if error else "response.completed" 
response_payload: dict[str, Any] = { "id": f"resp_ws_{uuid4().hex}", "object": "response", @@ -126,7 +129,11 @@ def _make_websocket_terminal_event( "error": error, "incomplete_details": None, } - return {"type": "response.completed", "response": response_payload} + return { + "type": event_type, + "sequence_number": sequence_number, + "response": response_payload, + } def _is_websocket_warmup_request(provider_payload: dict[str, Any]) -> bool: @@ -135,25 +142,13 @@ def _is_websocket_warmup_request(provider_payload: dict[str, Any]) -> bool: async def _authenticate_websocket(websocket: WebSocket) -> None: - """Enforce bearer auth on WebSocket connections when auth is configured. - - Mirrors the ConditionalAuthDep logic: if security.auth_token is set, - the client must provide a matching Authorization header. Closes the - connection with 1008 (Policy Violation) on failure. - """ - container = getattr(websocket.app.state, "service_container", None) - settings: Settings | None = None - if container is not None: - with contextlib.suppress(ValueError): - settings = container.get_service(Settings) - if settings is None: - with contextlib.suppress(Exception): - settings = Settings() - - if settings is None or not settings.security.auth_token: + """Enforce bearer auth on WebSocket connections when auth is configured.""" + settings = _get_websocket_settings(websocket) + expected_token = settings.security.auth_token + if expected_token is None: return - expected = settings.security.auth_token.get_secret_value() + expected = expected_token.get_secret_value() auth_header = websocket.headers.get("authorization", "") scheme, _, credentials = auth_header.partition(" ") if scheme.lower() == "bearer": @@ -162,16 +157,66 @@ async def _authenticate_websocket(websocket: WebSocket) -> None: else: token = "" + if not token: + await _deny_websocket_connection( + websocket, + status_code=401, + detail="Authentication required", + headers={"WWW-Authenticate": "Bearer"}, + ) + raise 
WebSocketDisconnect(code=1008) + if token != expected: - await websocket.close(code=1008, reason="Authentication required") + await _deny_websocket_connection( + websocket, + status_code=401, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) raise WebSocketDisconnect(code=1008) +def _get_websocket_settings(websocket: WebSocket) -> Settings: + app_settings = getattr(websocket.app.state, "settings", None) + if isinstance(app_settings, Settings): + return app_settings + + container = getattr(websocket.app.state, "service_container", None) + if container is None: + raise RuntimeError("Service container not initialized for websocket auth") + + try: + settings = container.get_service(Settings) + except ValueError as exc: + raise RuntimeError("Settings service unavailable for websocket auth") from exc + + if not isinstance(settings, Settings): + raise RuntimeError("Settings service returned invalid websocket auth settings") + + return settings + + +async def _deny_websocket_connection( + websocket: WebSocket, + *, + status_code: int, + detail: str, + headers: dict[str, str] | None = None, +) -> None: + await websocket.send_denial_response( + Response(content=detail, status_code=status_code, headers=headers) + ) + + async def _sanitize_websocket_payload( - adapter: "CodexAdapter", provider_payload: dict[str, Any], headers: dict[str, str] + adapter: "CodexAdapter", + provider_payload: dict[str, Any], + headers: dict[str, str], + request_context: RequestContext, ) -> tuple[dict[str, Any], dict[str, str]]: """Run the same request normalization used by HTTP routes on a WS payload.""" body_bytes = json.dumps(provider_payload).encode("utf-8") + body_bytes = await adapter._map_request_model(request_context, body_bytes) prepared_body, prepared_headers = await adapter.prepare_provider_request( body_bytes, headers, UPSTREAM_ENDPOINT_OPENAI_RESPONSES ) @@ -179,6 +224,73 @@ async def _sanitize_websocket_payload( return sanitized_payload, 
prepared_headers +def _new_websocket_request_context() -> RequestContext: + return RequestContext( + request_id=f"ws_{uuid4().hex}", + start_time=time(), + logger=logger, + metadata={}, + format_chain=[FORMAT_OPENAI_RESPONSES], + ) + + +def _restore_websocket_event_models( + event: dict[str, Any], request_context: RequestContext +) -> dict[str, Any]: + metadata = getattr(request_context, "metadata", None) + if isinstance(metadata, dict): + restore_model_aliases(event, metadata) + return event + + +async def _prepare_mock_websocket_payload( + adapter: "CodexAdapter", + provider_payload: dict[str, Any], + request_context: RequestContext, +) -> dict[str, Any]: + body_bytes = json.dumps(provider_payload).encode("utf-8") + body_bytes = await adapter._map_request_model(request_context, body_bytes) + payload = json.loads(body_bytes.decode("utf-8")) + + if adapter._should_apply_detection_payload(): + payload = adapter._apply_request_template(payload) + detected_instructions = adapter._get_instructions() + else: + payload = adapter._normalize_input_messages(payload) + detected_instructions = "" + + existing_instructions = payload.get("instructions") + if isinstance(existing_instructions, str) and existing_instructions: + instructions = ( + f"{detected_instructions}\n{existing_instructions}" + if detected_instructions + else existing_instructions + ) + else: + instructions = detected_instructions + + if instructions: + payload["instructions"] = instructions + else: + payload.pop("instructions", None) + + payload = adapter._sanitize_provider_body(payload) + return payload + + +async def _send_websocket_event( + websocket: WebSocket, + event: dict[str, Any], + request_context: RequestContext, +) -> None: + await websocket.send_text( + json.dumps( + _restore_websocket_event_models(event, request_context), + separators=(",", ":"), + ) + ) + + def _serialize_codex_models(config: CodexSettings) -> list[dict[str, Any]]: models: list[dict[str, Any]] = [] for card in 
config.models_endpoint: @@ -243,9 +355,19 @@ async def _stream_websocket_response( provider_payload: dict[str, Any], ) -> None: request_headers = _prepare_websocket_headers(websocket) + request_context = _new_websocket_request_context() + if adapter.mock_handler: + provider_payload = await _prepare_mock_websocket_payload( + adapter, provider_payload, request_context + ) + await _stream_websocket_mock_response( + websocket, adapter, provider_payload, request_context + ) + return provider_payload, provider_headers = await _sanitize_websocket_payload( - adapter, provider_payload, request_headers + adapter, provider_payload, request_headers, request_context ) + target_url = await adapter.get_target_url(UPSTREAM_ENDPOINT_OPENAI_RESPONSES) parsed_url = urlparse(target_url) base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" @@ -273,9 +395,12 @@ async def _stream_websocket_response( } await websocket.send_text( json.dumps( - _make_websocket_terminal_event( - provider_payload, - error=error_payload.get("error", error_payload), + _restore_websocket_event_models( + _make_websocket_terminal_event( + provider_payload, + error=error_payload.get("error", error_payload), + ), + request_context, ), separators=(",", ":"), ) @@ -286,28 +411,74 @@ async def _stream_websocket_response( for event in parser.feed(chunk): if event.get("type") in {"response.completed", "response.failed"}: saw_terminal_event = True - await websocket.send_text(json.dumps(event, separators=(",", ":"))) + await _send_websocket_event(websocket, event, request_context) for event in parser.flush(): if event.get("type") in {"response.completed", "response.failed"}: saw_terminal_event = True - await websocket.send_text(json.dumps(event, separators=(",", ":"))) + await _send_websocket_event(websocket, event, request_context) if not saw_terminal_event: await websocket.send_text( json.dumps( - _make_websocket_terminal_event( - provider_payload, - error={ - "type": "server_error", - "message": "WebSocket stream 
ended before response.completed", - }, + _restore_websocket_event_models( + _make_websocket_terminal_event( + provider_payload, + error={ + "type": "server_error", + "message": "WebSocket stream ended before response.completed", + }, + ), + request_context, ), separators=(",", ":"), ) ) +async def _stream_websocket_mock_response( + websocket: WebSocket, + adapter: "CodexAdapter", + provider_payload: dict[str, Any], + request_context: RequestContext, +) -> None: + body = json.dumps(provider_payload).encode("utf-8") + parser = SSEStreamParser() + saw_terminal_event = False + + stream_response = await adapter.mock_handler.generate_streaming_response( + provider_payload.get("model"), + FORMAT_OPENAI_RESPONSES, + request_context, + adapter.mock_handler.extract_message_type(body), + adapter.mock_handler.extract_prompt_text(body), + ) + + async for chunk in stream_response.body_iterator: + for event in parser.feed(chunk): + if event.get("type") in {"response.completed", "response.failed"}: + saw_terminal_event = True + await _send_websocket_event(websocket, event, request_context) + + for event in parser.flush(): + if event.get("type") in {"response.completed", "response.failed"}: + saw_terminal_event = True + await _send_websocket_event(websocket, event, request_context) + + if not saw_terminal_event: + await _send_websocket_event( + websocket, + _make_websocket_terminal_event( + provider_payload, + error={ + "type": "server_error", + "message": "Mock WebSocket stream ended before response.completed", + }, + ), + request_context, + ) + + @router.post("/v1/responses", response_model=None) @with_format_chain( [FORMAT_OPENAI_RESPONSES], endpoint=UPSTREAM_ENDPOINT_OPENAI_RESPONSES @@ -322,10 +493,9 @@ async def codex_responses( @router.websocket("/v1/responses") async def codex_responses_websocket(websocket: WebSocket) -> None: - await websocket.accept() - await _authenticate_websocket(websocket) - try: + await _authenticate_websocket(websocket) + await websocket.accept() 
adapter = _get_codex_websocket_adapter(websocket) local_response_ids: deque[str] = deque(maxlen=_MAX_LOCAL_RESPONSE_IDS) logger.debug("websocket_connected", client=str(websocket.client)) diff --git a/ccproxy/services/adapters/base.py b/ccproxy/services/adapters/base.py index 4df4a139..8e38e8c4 100644 --- a/ccproxy/services/adapters/base.py +++ b/ccproxy/services/adapters/base.py @@ -20,6 +20,7 @@ def __init__(self, config: Any, **kwargs: Any) -> None: **kwargs: Additional keyword arguments for subclasses """ self.config = config + self.mock_handler = kwargs.pop("mock_handler", None) self.tool_accumulator_class = kwargs.pop("tool_accumulator_class", None) @abstractmethod diff --git a/ccproxy/services/adapters/mock_adapter.py b/ccproxy/services/adapters/mock_adapter.py index 3a9903c7..67271e47 100644 --- a/ccproxy/services/adapters/mock_adapter.py +++ b/ccproxy/services/adapters/mock_adapter.py @@ -2,7 +2,7 @@ import json import time -from typing import Any +from typing import Any, cast import structlog from fastapi import Request @@ -116,8 +116,11 @@ async def handle_request( ) if self._extract_stream_flag(body): - return await self.mock_handler.generate_streaming_response( - model, target_format, ctx, message_type, prompt_text + return cast( + StreamingResponse | DeferredStreaming, + await self.mock_handler.generate_streaming_response( + model, target_format, ctx, message_type, prompt_text + ), ) else: ( @@ -155,6 +158,9 @@ async def handle_streaming( logger=structlog.get_logger(__name__), ) - return await self.mock_handler.generate_streaming_response( - model, target_format, ctx, message_type, prompt_text + return cast( + StreamingResponse, + await self.mock_handler.generate_streaming_response( + model, target_format, ctx, message_type, prompt_text + ), ) diff --git a/tests/helpers/e2e_validation.py b/tests/helpers/e2e_validation.py index 89ccc6b1..84fb66ed 100644 --- a/tests/helpers/e2e_validation.py +++ b/tests/helpers/e2e_validation.py @@ -281,6 +281,123 @@ def 
get_validation_model_for_format( return None +# --- WebSocket validation helpers --- + + +def validate_ws_codex_event_sequence( + events: list[dict[str, Any]], +) -> tuple[bool, list[str]]: + """Validate that a Codex WebSocket event sequence is well-formed. + + Checks: + - At least one event received + - Terminal event (response.completed or response.failed) is present + - Terminal event is last + - response.completed carries required fields + """ + errors: list[str] = [] + + if not events: + errors.append("No WebSocket events received") + return False, errors + + terminal_types = {"response.completed", "response.failed"} + event_types = [e.get("type") for e in events] + + has_terminal = any(t in terminal_types for t in event_types) + if not has_terminal: + errors.append(f"No terminal event found; got types: {event_types}") + + last_type = event_types[-1] + if last_type not in terminal_types: + errors.append(f"Last event should be terminal, got: {last_type}") + + terminal_event = events[-1] + response_obj = terminal_event.get("response") + if not isinstance(response_obj, dict): + errors.append("Terminal event missing 'response' object") + else: + for field in ("id", "object", "status"): + if field not in response_obj: + errors.append(f"Terminal response missing field: {field}") + + return len(errors) == 0, errors + + +def validate_ws_codex_streaming_content( + events: list[dict[str, Any]], +) -> tuple[str, list[str]]: + """Extract and validate text content from a Codex WebSocket event stream. 
+ + Returns: + Tuple of (assembled_text, errors) + """ + errors: list[str] = [] + deltas: list[str] = [] + + for event in events: + if event.get("type") == "response.output_text.delta": + delta = event.get("delta") + if isinstance(delta, str): + deltas.append(delta) + else: + errors.append(f"Delta event has non-string delta: {type(delta)}") + + text = "".join(deltas) + + done_events = [e for e in events if e.get("type") == "response.output_text.done"] + if done_events: + done_text = done_events[-1].get("text", "") + if done_text and done_text != text: + errors.append( + f"Assembled deltas ({text!r}) differ from done text ({done_text!r})" + ) + + return text, errors + + +def validate_ws_codex_warmup_response(event: dict[str, Any]) -> tuple[bool, list[str]]: + """Validate a warmup (empty input) response event.""" + errors: list[str] = [] + + if event.get("type") != "response.completed": + errors.append(f"Expected response.completed, got: {event.get('type')}") + + response_obj = event.get("response", {}) + if response_obj.get("status") != "completed": + errors.append(f"Expected status=completed, got: {response_obj.get('status')}") + + if response_obj.get("output") != []: + errors.append( + f"Warmup output should be empty list, got: {response_obj.get('output')}" + ) + + if not isinstance(response_obj.get("id"), str) or not response_obj["id"]: + errors.append("Warmup response missing id") + + return len(errors) == 0, errors + + +def validate_ws_codex_error_response(event: dict[str, Any]) -> tuple[bool, list[str]]: + """Validate an error terminal event from WebSocket.""" + errors: list[str] = [] + + if event.get("type") != "response.failed": + errors.append(f"Expected response.failed, got: {event.get('type')}") + + response_obj = event.get("response", {}) + if response_obj.get("status") != "failed": + errors.append(f"Expected status=failed, got: {response_obj.get('status')}") + + error_obj = response_obj.get("error") + if not isinstance(error_obj, dict): + 
errors.append("Error response missing 'error' object") + elif "type" not in error_obj: + errors.append("Error object missing 'type' field") + + return len(errors) == 0, errors + + # Format normalization helper def _normalize_format(format_type: str) -> str: alias_map = { diff --git a/tests/helpers/test_data.py b/tests/helpers/test_data.py index 6f0c7772..3211ccfd 100644 --- a/tests/helpers/test_data.py +++ b/tests/helpers/test_data.py @@ -286,6 +286,64 @@ def normalize_format(format_type: str) -> str: ] +# WebSocket Endpoint Test Data +WS_ENDPOINT_CONFIGURATIONS = [ + { + "name": "codex_ws_responses_stream", + "endpoint": "/codex/v1/responses", + "model": "gpt-5", + "description": "Codex WebSocket responses streaming", + }, + { + "name": "codex_ws_responses_legacy_stream", + "endpoint": "/codex/responses", + "model": "gpt-5", + "description": "Codex WebSocket responses legacy streaming", + }, +] + + +def create_ws_codex_request( + content: str = "Hello", + model: str = "gpt-5", + **kwargs: Any, +) -> dict[str, Any]: + """Create a Codex WebSocket request payload (response.create envelope).""" + request: dict[str, Any] = { + "type": "response.create", + "model": model, + "input": [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": content}], + } + ], + } + request.update(kwargs) + return request + + +def create_ws_codex_warmup_request(model: str = "gpt-5") -> dict[str, Any]: + """Create a Codex WebSocket warmup request (empty input).""" + return { + "type": "response.create", + "model": model, + "input": [], + } + + +# Expected WebSocket event types for Codex streaming +CODEX_WS_STREAMING_EVENT_TYPES = [ + "response.created", + "response.output_text.delta", + "response.output_text.done", + "response.completed", +] + +CODEX_WS_TERMINAL_EVENT_TYPES = {"response.completed", "response.failed"} + + def create_openai_request( content: str = "Hello", model: str = CLAUDE_SONNET_MODEL, diff --git a/tests/integration/test_websocket_e2e.py 
b/tests/integration/test_websocket_e2e.py new file mode 100644 index 00000000..815fe786 --- /dev/null +++ b/tests/integration/test_websocket_e2e.py @@ -0,0 +1,450 @@ +"""End-to-end integration tests for CCProxy WebSocket endpoints. + +Follows the same parameterized pattern as test_endpoint_e2e.py, covering +WebSocket transport for Codex responses (v1 and legacy paths). + +Tests validate: +- WebSocket configuration structure +- Request builder correctness +- Event sequence validation helpers +- Warmup, streaming, error, and multi-message flows +- Live server WebSocket flows (when CCPROXY_BASE_URL is set) +""" + +import asyncio +import json +import os +from collections.abc import Generator +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock, patch + +import pytest +from fastapi.testclient import TestClient + +from ccproxy.api.app import create_app, initialize_plugins_startup +from ccproxy.api.bootstrap import create_service_container +from ccproxy.config.settings import Settings +from ccproxy.core.logging import setup_logging +from ccproxy.models.detection import DetectedHeaders, DetectedPrompts +from ccproxy.plugins.codex.models import CodexCacheData +from tests.helpers.e2e_validation import ( + validate_ws_codex_error_response, + validate_ws_codex_event_sequence, + validate_ws_codex_streaming_content, + validate_ws_codex_warmup_response, +) +from tests.helpers.test_data import ( + CODEX_WS_TERMINAL_EVENT_TYPES, + WS_ENDPOINT_CONFIGURATIONS, + create_ws_codex_request, + create_ws_codex_warmup_request, +) + + +pytestmark = [pytest.mark.integration, pytest.mark.e2e] + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _build_detection_data() -> CodexCacheData: + prompts = DetectedPrompts.from_body( + {"instructions": "You are a helpful coding 
assistant."} + ) + return CodexCacheData( + codex_version="fallback", + headers=DetectedHeaders({}), + prompts=prompts, + body_json=prompts.raw, + method="POST", + url="https://chatgpt.com/backend-codex/responses", + path="/api/backend-codex/responses", + query_params={}, + ) + + +@pytest.fixture +def codex_ws_app() -> Generator[TestClient, None, None]: + """Create a fully-initialised Codex app wrapped in a sync TestClient. + + Patches OAuth credentials and detection so no real providers are needed. + """ + setup_logging(json_logs=False, log_level_name="ERROR") + settings = Settings( + enable_plugins=True, + plugins={ + "codex": {"enabled": True}, + "oauth_codex": {"enabled": True}, + "duckdb_storage": {"enabled": False}, + "analytics": {"enabled": False}, + "metrics": {"enabled": False}, + }, + enabled_plugins=["codex", "oauth_codex"], + plugins_disable_local_discovery=False, + ) + service_container = create_service_container(settings) + app = create_app(service_container) + + credentials_stub = SimpleNamespace( + access_token="test-codex-access-token", + expires_at=datetime.now(UTC) + timedelta(hours=1), + ) + profile_stub = SimpleNamespace(chatgpt_account_id="test-account-id") + detection_data = _build_detection_data() + + async def init_detection_stub(self: Any) -> CodexCacheData: + self._cached_data = detection_data + return detection_data + + with ( + patch( + "ccproxy.plugins.oauth_codex.manager.CodexTokenManager.load_credentials", + new=AsyncMock(return_value=credentials_stub), + ), + patch( + "ccproxy.plugins.oauth_codex.manager.CodexTokenManager.get_profile_quick", + new=AsyncMock(return_value=profile_stub), + ), + patch( + "ccproxy.plugins.codex.detection_service.CodexDetectionService.initialize_detection", + new=init_detection_stub, + ), + ): + asyncio.run(initialize_plugins_startup(app, settings)) + with TestClient(app) as client: + yield client + + +# --------------------------------------------------------------------------- +# Configuration 
structure tests (no app needed, mirrors test_endpoint_e2e.py) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_ws_endpoint_configurations_structure() -> None: + """Verify all WebSocket endpoint configs have required fields.""" + assert len(WS_ENDPOINT_CONFIGURATIONS) > 0 + + for config in WS_ENDPOINT_CONFIGURATIONS: + required_fields = ["name", "endpoint", "model", "description"] + assert all(field in config for field in required_fields), ( + f"Config {config.get('name')} missing fields" + ) + + endpoint = config["endpoint"] + assert isinstance(endpoint, str) + assert endpoint.startswith("/") + assert isinstance(config["model"], str) + assert len(config["model"]) > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("config", WS_ENDPOINT_CONFIGURATIONS) +async def test_ws_request_creation_for_each_endpoint( + config: dict[str, Any], +) -> None: + """Verify request builders produce valid payloads for each config.""" + model = config["model"] + + request_data = create_ws_codex_request( + content="Test WebSocket message", model=model + ) + assert request_data["type"] == "response.create" + assert request_data["model"] == model + assert isinstance(request_data["input"], list) + assert len(request_data["input"]) > 0 + + warmup = create_ws_codex_warmup_request(model=model) + assert warmup["type"] == "response.create" + assert warmup["input"] == [] + + +@pytest.mark.asyncio +async def test_ws_validation_helpers_work() -> None: + """Verify validation helpers detect good and bad event sequences.""" + good_events: list[dict[str, Any]] = [ + {"type": "response.created", "response": {"id": "r1", "object": "response"}}, + {"type": "response.output_text.delta", "delta": "Hello"}, + { + "type": "response.completed", + "response": { + "id": "r1", + "object": "response", + "status": "completed", + "output": [], + }, + }, + ] + is_valid, errors = validate_ws_codex_event_sequence(good_events) + assert is_valid, 
errors + + text, text_errors = validate_ws_codex_streaming_content(good_events) + assert text == "Hello" + assert not text_errors + + # Empty events should fail + is_valid, errors = validate_ws_codex_event_sequence([]) + assert not is_valid + + # Missing terminal event should fail + is_valid, errors = validate_ws_codex_event_sequence([{"type": "response.created"}]) + assert not is_valid + + # Warmup validation + warmup_event = { + "type": "response.completed", + "response": { + "id": "w1", + "object": "response", + "status": "completed", + "output": [], + }, + } + is_valid, errors = validate_ws_codex_warmup_response(warmup_event) + assert is_valid, errors + + # Error validation + error_event = { + "type": "response.failed", + "response": { + "id": "e1", + "object": "response", + "status": "failed", + "error": {"type": "invalid_request_error", "message": "bad"}, + }, + } + is_valid, errors = validate_ws_codex_error_response(error_event) + assert is_valid, errors + + +# --------------------------------------------------------------------------- +# Live WebSocket tests (require codex_ws_app + external API mocks) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("config", WS_ENDPOINT_CONFIGURATIONS, ids=lambda c: c["name"]) +def test_ws_warmup_request( + codex_ws_app: TestClient, + config: dict[str, Any], +) -> None: + """Empty-input warmup should return a completed terminal event immediately.""" + warmup = create_ws_codex_warmup_request(model=config["model"]) + + with codex_ws_app.websocket_connect(config["endpoint"]) as ws: + ws.send_json(warmup) + event = ws.receive_json() + ws.close() + + is_valid, errors = validate_ws_codex_warmup_response(event) + assert is_valid, errors + + +@pytest.mark.parametrize("config", WS_ENDPOINT_CONFIGURATIONS, ids=lambda c: c["name"]) +def test_ws_streaming_response( + codex_ws_app: TestClient, + mock_external_openai_codex_api_streaming: Any, + config: dict[str, Any], +) -> 
None: + """A real request should stream events ending with a terminal event.""" + request = create_ws_codex_request( + content="Reply with exactly OK", model=config["model"] + ) + + events: list[dict[str, Any]] = [] + with codex_ws_app.websocket_connect(config["endpoint"]) as ws: + ws.send_json(request) + while True: + event = ws.receive_json() + events.append(event) + if event.get("type") in CODEX_WS_TERMINAL_EVENT_TYPES: + ws.close() + break + + is_valid, errors = validate_ws_codex_event_sequence(events) + assert is_valid, errors + + text, text_errors = validate_ws_codex_streaming_content(events) + assert not text_errors, text_errors + assert len(text) > 0, "Expected non-empty streamed text" + + +@pytest.mark.parametrize("config", WS_ENDPOINT_CONFIGURATIONS, ids=lambda c: c["name"]) +def test_ws_upstream_error( + codex_ws_app: TestClient, + mock_external_openai_codex_api_error: Any, + config: dict[str, Any], +) -> None: + """Upstream errors should produce a failed terminal event.""" + request = create_ws_codex_request( + content="Reply with exactly OK", model=config["model"] + ) + + with codex_ws_app.websocket_connect(config["endpoint"]) as ws: + ws.send_json(request) + event = ws.receive_json() + ws.close() + + is_valid, errors = validate_ws_codex_error_response(event) + assert is_valid, errors + + +@pytest.mark.parametrize("config", WS_ENDPOINT_CONFIGURATIONS, ids=lambda c: c["name"]) +def test_ws_warmup_then_real_request( + codex_ws_app: TestClient, + mock_external_openai_codex_api_streaming: Any, + config: dict[str, Any], +) -> None: + """Warmup followed by real request on same connection should both succeed.""" + warmup = create_ws_codex_warmup_request(model=config["model"]) + request = create_ws_codex_request( + content="Reply with exactly OK", model=config["model"] + ) + + with codex_ws_app.websocket_connect(config["endpoint"]) as ws: + # Warmup + ws.send_json(warmup) + warmup_event = ws.receive_json() + + # Attach the synthetic warmup id as previous_response_id (to be stripped) + warmup_id 
= warmup_event.get("response", {}).get("id") + request["previous_response_id"] = warmup_id + + # Real request + ws.send_json(request) + events: list[dict[str, Any]] = [] + while True: + event = ws.receive_json() + events.append(event) + if event.get("type") in CODEX_WS_TERMINAL_EVENT_TYPES: + ws.close() + break + + # Validate warmup + is_valid, errors = validate_ws_codex_warmup_response(warmup_event) + assert is_valid, errors + + # Validate streaming + is_valid, errors = validate_ws_codex_event_sequence(events) + assert is_valid, errors + + text, text_errors = validate_ws_codex_streaming_content(events) + assert not text_errors, text_errors + assert len(text) > 0 + + +# --------------------------------------------------------------------------- +# Live server tests (require `make dev` + real credentials) +# +# Run with: CCPROXY_BASE_URL=http://127.0.0.1:8000 pytest -m real_api -k websocket +# --------------------------------------------------------------------------- + +_LIVE_BASE_URL = os.environ.get("CCPROXY_BASE_URL", "").rstrip("/") +_skip_no_live = pytest.mark.skipif( + not _LIVE_BASE_URL, + reason="CCPROXY_BASE_URL not set; skipping live WebSocket tests", +) + + +def _ws_url(http_base: str, path: str) -> str: + """Convert http(s) base URL + path to a ws(s) URL.""" + return http_base.replace("https://", "wss://").replace("http://", "ws://") + path + + +@_skip_no_live +@pytest.mark.real_api +@pytest.mark.parametrize("config", WS_ENDPOINT_CONFIGURATIONS, ids=lambda c: c["name"]) +@pytest.mark.asyncio +async def test_live_ws_warmup(config: dict[str, Any]) -> None: + """Send a warmup to the live server and validate the response.""" + import websockets + + warmup = create_ws_codex_warmup_request(model=config["model"]) + url = _ws_url(_LIVE_BASE_URL, config["endpoint"]) + + async with websockets.connect(url) as ws: + await ws.send(json.dumps(warmup)) + raw = await asyncio.wait_for(ws.recv(), timeout=10) + event = json.loads(raw) + + is_valid, errors = 
validate_ws_codex_warmup_response(event) + assert is_valid, errors + + +@_skip_no_live +@pytest.mark.real_api +@pytest.mark.parametrize("config", WS_ENDPOINT_CONFIGURATIONS, ids=lambda c: c["name"]) +@pytest.mark.asyncio +async def test_live_ws_streaming(config: dict[str, Any]) -> None: + """Send a real request to the live server and collect streaming events.""" + import websockets + + request = create_ws_codex_request( + content="Reply with exactly OK", model=config["model"] + ) + url = _ws_url(_LIVE_BASE_URL, config["endpoint"]) + + events: list[dict[str, Any]] = [] + async with websockets.connect(url) as ws: + await ws.send(json.dumps(request)) + while True: + raw = await asyncio.wait_for(ws.recv(), timeout=60) + event = json.loads(raw) + events.append(event) + if event.get("type") in CODEX_WS_TERMINAL_EVENT_TYPES: + break + + is_valid, errors = validate_ws_codex_event_sequence(events) + assert is_valid, errors + + text, text_errors = validate_ws_codex_streaming_content(events) + assert not text_errors, text_errors + assert len(text) > 0, "Expected non-empty response from live server" + + +@_skip_no_live +@pytest.mark.real_api +@pytest.mark.parametrize("config", WS_ENDPOINT_CONFIGURATIONS, ids=lambda c: c["name"]) +@pytest.mark.asyncio +async def test_live_ws_warmup_then_request(config: dict[str, Any]) -> None: + """Warmup followed by real request on a single live WebSocket connection.""" + import websockets + + warmup = create_ws_codex_warmup_request(model=config["model"]) + request = create_ws_codex_request( + content="Reply with exactly OK", model=config["model"] + ) + url = _ws_url(_LIVE_BASE_URL, config["endpoint"]) + + async with websockets.connect(url) as ws: + # Warmup + await ws.send(json.dumps(warmup)) + raw = await asyncio.wait_for(ws.recv(), timeout=10) + warmup_event = json.loads(raw) + + is_valid, errors = validate_ws_codex_warmup_response(warmup_event) + assert is_valid, errors + + # Attach previous_response_id from warmup + warmup_id = 
warmup_event.get("response", {}).get("id") + request["previous_response_id"] = warmup_id + + # Real request + await ws.send(json.dumps(request)) + events: list[dict[str, Any]] = [] + while True: + raw = await asyncio.wait_for(ws.recv(), timeout=60) + event = json.loads(raw) + events.append(event) + if event.get("type") in CODEX_WS_TERMINAL_EVENT_TYPES: + break + + is_valid, errors = validate_ws_codex_event_sequence(events) + assert is_valid, errors + + text, text_errors = validate_ws_codex_streaming_content(events) + assert not text_errors, text_errors + assert len(text) > 0 diff --git a/tests/plugins/codex/integration/test_codex_websocket.py b/tests/plugins/codex/integration/test_codex_websocket.py new file mode 100644 index 00000000..b8617e03 --- /dev/null +++ b/tests/plugins/codex/integration/test_codex_websocket.py @@ -0,0 +1,384 @@ +import asyncio +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock, patch + +import pytest +from fastapi.testclient import TestClient +from pydantic import SecretStr +from starlette.testclient import WebSocketDenialResponse + +from ccproxy.api.app import create_app, initialize_plugins_startup +from ccproxy.api.bootstrap import create_service_container +from ccproxy.config.core import ServerSettings +from ccproxy.config.security import SecuritySettings +from ccproxy.config.settings import Settings +from ccproxy.core.logging import setup_logging +from ccproxy.models.detection import DetectedHeaders, DetectedPrompts +from ccproxy.plugins.codex.models import CodexCacheData + + +@pytest.fixture +def codex_ws_client() -> Any: + setup_logging(json_logs=False, log_level_name="ERROR") + settings = Settings( + enable_plugins=True, + plugins={ + "codex": {"enabled": True}, + "oauth_codex": {"enabled": True}, + "duckdb_storage": {"enabled": False}, + "analytics": {"enabled": False}, + "metrics": {"enabled": False}, + }, + enabled_plugins=["codex", 
"oauth_codex"], + plugins_disable_local_discovery=False, + ) + service_container = create_service_container(settings) + app = create_app(service_container) + + credentials_stub = SimpleNamespace( + access_token="test-codex-access-token", + expires_at=datetime.now(UTC) + timedelta(hours=1), + ) + profile_stub = SimpleNamespace(chatgpt_account_id="test-account-id") + prompts = DetectedPrompts.from_body( + {"instructions": "You are a helpful coding assistant."} + ) + detection_data = CodexCacheData( + codex_version="fallback", + headers=DetectedHeaders({}), + prompts=prompts, + body_json=prompts.raw, + method="POST", + url="https://chatgpt.com/backend-codex/responses", + path="/api/backend-codex/responses", + query_params={}, + ) + + async def init_detection_stub(self): # type: ignore[no-untyped-def] + self._cached_data = detection_data + return detection_data + + load_patch = patch( + "ccproxy.plugins.oauth_codex.manager.CodexTokenManager.load_credentials", + new=AsyncMock(return_value=credentials_stub), + ) + profile_patch = patch( + "ccproxy.plugins.oauth_codex.manager.CodexTokenManager.get_profile_quick", + new=AsyncMock(return_value=profile_stub), + ) + detection_patch = patch( + "ccproxy.plugins.codex.detection_service.CodexDetectionService.initialize_detection", + new=init_detection_stub, + ) + + with load_patch, profile_patch, detection_patch: + asyncio.run(initialize_plugins_startup(app, settings)) + with TestClient(app) as client: + yield client + + +@pytest.fixture +def codex_ws_bypass_client() -> Any: + setup_logging(json_logs=False, log_level_name="ERROR") + settings = Settings( + enable_plugins=True, + server=ServerSettings(bypass_mode=True), + plugins={ + "codex": {"enabled": True}, + "oauth_codex": {"enabled": True}, + "duckdb_storage": {"enabled": False}, + "analytics": {"enabled": False}, + "metrics": {"enabled": False}, + }, + enabled_plugins=["codex", "oauth_codex"], + plugins_disable_local_discovery=False, + ) + service_container = 
create_service_container(settings) + app = create_app(service_container) + + prompts = DetectedPrompts.from_body( + {"instructions": "You are a helpful coding assistant."} + ) + detection_data = CodexCacheData( + codex_version="fallback", + headers=DetectedHeaders({}), + prompts=prompts, + body_json=prompts.raw, + method="POST", + url="https://chatgpt.com/backend-codex/responses", + path="/api/backend-codex/responses", + query_params={}, + ) + + async def init_detection_stub(self): # type: ignore[no-untyped-def] + self._cached_data = detection_data + return detection_data + + detection_patch = patch( + "ccproxy.plugins.codex.detection_service.CodexDetectionService.initialize_detection", + new=init_detection_stub, + ) + + with detection_patch: + asyncio.run(initialize_plugins_startup(app, settings)) + with TestClient(app) as client: + yield client + + +@pytest.fixture +def codex_ws_auth_client() -> Any: + setup_logging(json_logs=False, log_level_name="ERROR") + settings = Settings( + enable_plugins=True, + security=SecuritySettings(auth_token=SecretStr("test-auth-token")), + plugins={ + "codex": {"enabled": True}, + "oauth_codex": {"enabled": True}, + "duckdb_storage": {"enabled": False}, + "analytics": {"enabled": False}, + "metrics": {"enabled": False}, + }, + enabled_plugins=["codex", "oauth_codex"], + plugins_disable_local_discovery=False, + ) + service_container = create_service_container(settings) + app = create_app(service_container) + + credentials_stub = SimpleNamespace( + access_token="test-codex-access-token", + expires_at=datetime.now(UTC) + timedelta(hours=1), + ) + profile_stub = SimpleNamespace(chatgpt_account_id="test-account-id") + prompts = DetectedPrompts.from_body( + {"instructions": "You are a helpful coding assistant."} + ) + detection_data = CodexCacheData( + codex_version="fallback", + headers=DetectedHeaders({}), + prompts=prompts, + body_json=prompts.raw, + method="POST", + url="https://chatgpt.com/backend-codex/responses", + 
path="/api/backend-codex/responses", + query_params={}, + ) + + async def init_detection_stub(self): # type: ignore[no-untyped-def] + self._cached_data = detection_data + return detection_data + + with ( + patch( + "ccproxy.plugins.oauth_codex.manager.CodexTokenManager.load_credentials", + new=AsyncMock(return_value=credentials_stub), + ), + patch( + "ccproxy.plugins.oauth_codex.manager.CodexTokenManager.get_profile_quick", + new=AsyncMock(return_value=profile_stub), + ), + patch( + "ccproxy.plugins.codex.detection_service.CodexDetectionService.initialize_detection", + new=init_detection_stub, + ), + ): + asyncio.run(initialize_plugins_startup(app, settings)) + with TestClient(app) as client: + yield client + + +@pytest.mark.integration +@pytest.mark.codex +def test_codex_websocket_responses_streaming( + codex_ws_client: TestClient, + mock_external_openai_codex_api_streaming: Any, +) -> None: + request_payload = { + "type": "response.create", + "model": "gpt-5", + "stream": True, + "instructions": "Reply with exactly OK", + "input": [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "Reply with exactly OK"}], + } + ], + } + + with codex_ws_client.websocket_connect( + "/codex/v1/responses", + headers={ + "authorization": "Bearer ignored-client-token", + "chatgpt-account-id": "test-account-id", + "openai-beta": "responses_websockets=2026-02-06", + "originator": "Codex Desktop", + "session_id": "test-session", + "version": "0.114.0", + "x-codex-beta-features": "multi_agent", + "x-codex-turn-metadata": '{"turn_id":"","sandbox":"seatbelt"}', + }, + ) as websocket: + websocket.send_json(request_payload) + + events: list[dict[str, Any]] = [] + while True: + try: + events.append(websocket.receive_json()) + if events[-1].get("type") == "response.completed": + websocket.close() + break + except Exception: + break + + event_types = [event.get("type") for event in events] + assert event_types == [ + "response.created", + 
"response.output_text.delta", + "response.output_text.delta", + "response.output_text.done", + "response.completed", + ] + assert events[-1]["response"]["output"][0]["content"][0]["text"] == "Hello Codex!" + + +@pytest.mark.integration +@pytest.mark.codex +def test_codex_websocket_returns_terminal_event_on_upstream_error( + codex_ws_client: TestClient, + mock_external_openai_codex_api_error: Any, +) -> None: + request_payload = { + "type": "response.create", + "model": "gpt-5", + "input": [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "Reply with exactly OK"}], + } + ], + } + + with codex_ws_client.websocket_connect("/codex/v1/responses") as websocket: + websocket.send_json(request_payload) + event = websocket.receive_json() + websocket.close() + + assert event["type"] == "response.failed" + assert event["response"]["status"] == "failed" + assert event["response"]["error"]["type"] == "invalid_request_error" + + +@pytest.mark.integration +@pytest.mark.codex +def test_codex_websocket_short_circuits_empty_warmup_request( + codex_ws_client: TestClient, +) -> None: + request_payload = { + "type": "response.create", + "model": "gpt-5", + "input": [], + } + + with codex_ws_client.websocket_connect("/codex/v1/responses") as websocket: + websocket.send_json(request_payload) + event = websocket.receive_json() + websocket.close() + + assert event["type"] == "response.completed" + assert event["response"]["status"] == "completed" + assert event["response"]["output"] == [] + + +@pytest.mark.integration +@pytest.mark.codex +def test_codex_websocket_warmup_then_real_request_same_connection( + codex_ws_client: TestClient, + mock_external_openai_codex_api_streaming: Any, +) -> None: + warmup_payload = { + "type": "response.create", + "model": "gpt-5", + "input": [], + } + request_payload = { + "type": "response.create", + "model": "gpt-5", + "input": [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": 
"Reply with exactly OK"}], + } + ], + } + + with codex_ws_client.websocket_connect("/codex/v1/responses") as websocket: + websocket.send_json(warmup_payload) + warmup_event = websocket.receive_json() + request_payload["previous_response_id"] = warmup_event["response"]["id"] + + websocket.send_json(request_payload) + + events: list[dict[str, Any]] = [] + while True: + event = websocket.receive_json() + events.append(event) + if event.get("type") == "response.completed": + websocket.close() + break + + assert warmup_event["type"] == "response.completed" + assert events[-1]["type"] == "response.completed" + assert events[-1]["response"]["output"][0]["content"][0]["text"] == "Hello Codex!" + + +@pytest.mark.integration +@pytest.mark.codex +def test_codex_websocket_bypass_mode_streams_mock_events( + codex_ws_bypass_client: TestClient, +) -> None: + request_payload = { + "type": "response.create", + "model": "gpt-5", + "input": [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "Reply with exactly OK"}], + } + ], + } + + with codex_ws_bypass_client.websocket_connect("/codex/v1/responses") as websocket: + websocket.send_json(request_payload) + + events: list[dict[str, Any]] = [] + while True: + event = websocket.receive_json() + events.append(event) + if event.get("type") == "response.completed": + websocket.close() + break + + assert events[0]["type"] == "response.created" + assert events[-1]["type"] == "response.completed" + assert events[-1]["response"]["status"] == "completed" + + +@pytest.mark.integration +@pytest.mark.codex +def test_codex_websocket_denies_unauthorized_handshake( + codex_ws_auth_client: TestClient, +) -> None: + with ( + pytest.raises(WebSocketDenialResponse) as exc_info, + codex_ws_auth_client.websocket_connect("/codex/v1/responses"), + ): + pass + + assert exc_info.value.status_code == 401 + assert exc_info.value.text == "Authentication required" diff --git a/tests/plugins/codex/unit/test_routes.py 
b/tests/plugins/codex/unit/test_routes.py new file mode 100644 index 00000000..e6d0011e --- /dev/null +++ b/tests/plugins/codex/unit/test_routes.py @@ -0,0 +1,123 @@ +"""Unit tests for Codex websocket route helpers.""" + +from types import SimpleNamespace +from unittest.mock import AsyncMock, Mock + +import pytest + +from ccproxy.config.settings import Settings +from ccproxy.core.request_context import RequestContext +from ccproxy.models.detection import DetectedHeaders, DetectedPrompts +from ccproxy.models.provider import ModelMappingRule +from ccproxy.plugins.codex.adapter import CodexAdapter +from ccproxy.plugins.codex.detection_service import CodexDetectionService +from ccproxy.plugins.codex.routes import ( + _get_websocket_settings, + _restore_websocket_event_models, + _sanitize_websocket_payload, +) + + +@pytest.fixture +def codex_adapter() -> CodexAdapter: + detection_service = Mock(spec=CodexDetectionService) + prompts = DetectedPrompts.from_body({"instructions": "Detected instructions"}) + detection_service.get_detected_prompts = Mock(return_value=prompts) + detection_service.get_system_prompt = Mock( + return_value=prompts.instructions_payload() + ) + detection_service.get_detected_headers = Mock(return_value=DetectedHeaders({})) + detection_service.get_ignored_headers = Mock(return_value=[]) + detection_service.get_redacted_headers = Mock(return_value=[]) + + auth_manager = Mock() + auth_manager.get_access_token = AsyncMock(return_value="test-token") + auth_manager.get_access_token_with_refresh = AsyncMock(return_value="test-token") + auth_manager.load_credentials = AsyncMock(return_value=Mock(access_token="token")) + auth_manager.should_refresh = Mock(return_value=False) + auth_manager.get_token_snapshot = AsyncMock(return_value=None) + auth_manager.get_profile_quick = AsyncMock(return_value=None) + + config = Mock() + config.base_url = "https://chatgpt.com/backend-codex" + config.model_mappings = [ + ModelMappingRule( + match="alias-model", + 
target="gpt-5.3-codex", + kind="exact", + ) + ] + + return CodexAdapter( + detection_service=detection_service, + config=config, + auth_manager=auth_manager, + http_pool_manager=Mock(), + ) + + +@pytest.mark.asyncio +async def test_sanitize_websocket_payload_applies_model_mapping( + codex_adapter: CodexAdapter, +) -> None: + request_context = RequestContext( + request_id="ws-test", + start_time=0.0, + logger=Mock(), + metadata={}, + format_chain=["openai.responses"], + ) + + payload, headers = await _sanitize_websocket_payload( + codex_adapter, + {"model": "alias-model", "input": []}, + {"content-type": "application/json"}, + request_context, + ) + + assert payload["model"] == "gpt-5.3-codex" + assert headers["authorization"] == "Bearer test-token" + assert request_context.metadata["_model_alias_map"] == { + "gpt-5.3-codex": "alias-model" + } + + +def test_restore_websocket_event_models_uses_client_alias() -> None: + request_context = RequestContext( + request_id="ws-test", + start_time=0.0, + logger=Mock(), + metadata={"_model_alias_map": {"gpt-5.3-codex": "alias-model"}}, + ) + event = { + "type": "response.completed", + "response": { + "model": "gpt-5.3-codex", + "output": [{"type": "message", "model": "gpt-5.3-codex"}], + }, + } + + restored = _restore_websocket_event_models(event, request_context) + + assert restored["response"]["model"] == "alias-model" + assert restored["response"]["output"][0]["model"] == "alias-model" + + +def test_get_websocket_settings_prefers_app_state_settings() -> None: + settings = Settings() + websocket = Mock() + websocket.app.state = SimpleNamespace(settings=settings) + + resolved = _get_websocket_settings(websocket) + + assert resolved is settings + + +def test_get_websocket_settings_raises_when_container_cannot_provide_settings() -> None: + container = Mock() + container.get_service.side_effect = ValueError("missing settings") + websocket = Mock() + websocket.app.state = SimpleNamespace(service_container=container) + + with 
pytest.raises(RuntimeError, match="Settings service unavailable"): + _get_websocket_settings(websocket) diff --git a/tests/unit/core/test_provider_factory_bypass.py b/tests/unit/core/test_provider_factory_bypass.py index b0d6f840..e5d4d666 100644 --- a/tests/unit/core/test_provider_factory_bypass.py +++ b/tests/unit/core/test_provider_factory_bypass.py @@ -6,8 +6,8 @@ import pytest from ccproxy.core.plugins import factories as plugin_factories +from ccproxy.plugins.codex.adapter import CodexAdapter from ccproxy.plugins.codex.plugin import CodexFactory -from ccproxy.services.adapters.mock_adapter import MockAdapter @pytest.mark.asyncio @@ -16,15 +16,20 @@ async def test_create_adapter_logs_warning_in_bypass_mode() -> None: mock_handler = MagicMock() service_container = MagicMock() service_container.get_mock_handler.return_value = mock_handler + service_container.get_adapter_dependencies.return_value = {"format_registry": None} context = { "settings": SimpleNamespace(server=SimpleNamespace(bypass_mode=True)), "service_container": service_container, + "config": SimpleNamespace(base_url="https://chatgpt.com/backend-codex"), + "http_pool_manager": MagicMock(), + "detection_service": MagicMock(), + "credentials_manager": MagicMock(), } with patch.object(plugin_factories.logger, "warning") as warning: adapter = await factory.create_adapter(context) # type: ignore[arg-type] - assert isinstance(adapter, MockAdapter) + assert isinstance(adapter, CodexAdapter) assert adapter.mock_handler is mock_handler warning.assert_called_once_with( "plugin_bypass_mode_enabled", diff --git a/tests/unit/plugins/test_codex_detection.py b/tests/unit/plugins/test_codex_detection.py index f6864c63..36f1b5cd 100644 --- a/tests/unit/plugins/test_codex_detection.py +++ b/tests/unit/plugins/test_codex_detection.py @@ -7,6 +7,7 @@ import pytest from ccproxy.config.settings import Settings +from ccproxy.models.detection import DetectedPrompts from ccproxy.plugins.codex.detection_service import 
CodexDetectionService @@ -52,3 +53,39 @@ async def test_codex_detection_falls_back_when_cli_missing(tmp_path: Path) -> No def test_codex_detection_ignores_content_encoding_header() -> None: assert "content-encoding" in CodexDetectionService.ignores_header + + +def test_codex_detection_merges_partial_prompt_cache_with_fallback() -> None: + settings = MagicMock(spec=Settings) + cli_service = MagicMock() + service = CodexDetectionService(settings=settings, cli_service=cli_service) + + cached_prompts = DetectedPrompts.from_body( + {"tools": [{"type": "function", "name": "exec_command"}]} + ) + fallback_prompts = DetectedPrompts.from_body( + { + "instructions": "Fallback instructions", + "include": ["reasoning.encrypted_content"], + "tool_choice": "auto", + } + ) + + with ( + patch.object( + service, + "get_cached_data", + return_value=SimpleNamespace(prompts=cached_prompts), + ), + patch.object( + service, + "_safe_fallback_data", + return_value=SimpleNamespace(prompts=fallback_prompts), + ), + ): + prompts = service.get_detected_prompts() + + assert prompts.instructions == "Fallback instructions" + assert prompts.raw["tools"] == [{"type": "function", "name": "exec_command"}] + assert prompts.raw["include"] == ["reasoning.encrypted_content"] + assert prompts.raw["tool_choice"] == "auto"