diff --git a/ccproxy/plugins/codex/adapter.py b/ccproxy/plugins/codex/adapter.py index 6faf46ed..27ca606e 100644 --- a/ccproxy/plugins/codex/adapter.py +++ b/ccproxy/plugins/codex/adapter.py @@ -1,4 +1,5 @@ import contextlib +import copy import json import uuid from typing import Any, cast @@ -63,12 +64,7 @@ async def handle_request( headers = extract_request_headers(request) # Determine client streaming intent from body flag (fallback to False) - wants_stream = False - try: - data = json.loads(body.decode()) if body else {} - wants_stream = bool(data.get("stream", False)) - except Exception: # Malformed/missing JSON -> assume non-streaming - wants_stream = False + wants_stream = self._detect_streaming_intent(body, headers) logger.trace( "codex_adapter_request_intent", wants_stream=wants_stream, @@ -256,19 +252,18 @@ async def get_target_url(self, endpoint: str) -> str: async def prepare_provider_request( self, body: bytes, headers: dict[str, str], endpoint: str ) -> tuple[bytes, dict[str, str]]: - token_value = await self._resolve_access_token() + filtered_headers = await self.prepare_provider_headers(headers) - # Get profile to extract chatgpt_account_id - profile = await self.token_manager.get_profile_quick() - chatgpt_account_id = ( - getattr(profile, "chatgpt_account_id", None) if profile else None - ) + if self._request_body_is_encoded(headers): + return body, filtered_headers + + # Body will be re-serialized as plain JSON; drop stale encoding header + filtered_headers.pop("content-encoding", None) # Parse body (format conversion is now handled by format chain) body_data = json.loads(body.decode()) if body else {} + body_data = self._apply_request_template(body_data) - # Inject instructions mandatory for being allow to - # to used the Codex API endpoint # Fetch detected instructions from detection service instructions = self._get_instructions() @@ -299,8 +294,20 @@ async def prepare_provider_request( # Remove any prefixed metadata fields that shouldn't be sent to the API body_data = self._remove_metadata_fields(body_data) - # Filter and add headers + return json.dumps(body_data).encode(), filtered_headers + + async def prepare_provider_headers(self, headers: dict[str, str]) -> dict[str, str]: + token_value = await self._resolve_access_token() + + profile = await self.token_manager.get_profile_quick() + chatgpt_account_id = ( + getattr(profile, "chatgpt_account_id", None) if profile else None + ) + filtered_headers = filter_request_headers(headers, preserve_auth=False) + content_encoding = headers.get("content-encoding") + if content_encoding: + filtered_headers["content-encoding"] = content_encoding session_id = filtered_headers.get("session_id") or str(uuid.uuid4()) conversation_id = filtered_headers.get("conversation_id") or str(uuid.uuid4()) @@ -318,10 +325,10 @@ async def prepare_provider_request( filtered_headers.update(base_headers) cli_headers = self._collect_cli_headers() - if cli_headers: - filtered_headers.update(cli_headers) + for key, value in cli_headers.items(): + filtered_headers.setdefault(key, value) - return json.dumps(body_data).encode(), filtered_headers + return filtered_headers async def process_provider_response( self, response: httpx.Response, endpoint: str @@ -581,6 +588,70 @@ def _remove_metadata_fields(self, data: dict[str, Any]) -> dict[str, Any]: return cleaned_data + def _apply_request_template(self, data: dict[str, Any]) -> dict[str, Any]: + if not isinstance(data, dict): + return data + + template = self._get_request_template() + if not template: + return self._normalize_input_messages(data) + + merged = copy.deepcopy(data) + + for key in ("include", "parallel_tool_calls", "reasoning", "tool_choice"): + if key not in merged and key in template: + merged[key] = copy.deepcopy(template[key]) + + if not merged.get("tools") and isinstance(template.get("tools"), list): + merged["tools"] = copy.deepcopy(template["tools"]) + + if "prompt_cache_key" not in merged: + prompt_cache_key = template.get("prompt_cache_key") + if isinstance(prompt_cache_key, str) and prompt_cache_key: + merged["prompt_cache_key"] = str(uuid.uuid4()) + + return self._normalize_input_messages(merged) + + def _normalize_input_messages(self, data: dict[str, Any]) -> dict[str, Any]: + input_items = data.get("input") + if not isinstance(input_items, list): + return data + + normalized_items: list[Any] = [] + for item in input_items: + if ( + isinstance(item, dict) + and "type" not in item + and "role" in item + and "content" in item + ): + normalized_item = dict(item) + normalized_item["type"] = "message" + normalized_items.append(normalized_item) + continue + + normalized_items.append(item) + + result = dict(data) + result["input"] = normalized_items + return result + + def _request_body_is_encoded(self, headers: dict[str, str]) -> bool: + encoding = headers.get("content-encoding", "").strip().lower() + return bool(encoding and encoding != "identity") + + def _detect_streaming_intent(self, body: bytes, headers: dict[str, str]) -> bool: + if self._request_body_is_encoded(headers): + accept = headers.get("accept", "").lower() + return "text/event-stream" in accept + + try: + data = json.loads(body.decode()) if body else {} + return bool(data.get("stream", False)) + except Exception: + accept = headers.get("accept", "").lower() + return "text/event-stream" in accept + def _get_instructions(self) -> str: if not self.detection_service: return "" @@ -601,6 +672,16 @@ def _get_instructions(self) -> str: return "" + def _get_request_template(self) -> dict[str, Any]: + if not self.detection_service: + return {} + + prompts = self.detection_service.get_detected_prompts() + if isinstance(prompts.raw, dict) and prompts.raw: + return prompts.raw + + return {} + def adapt_error(self, error_body: dict[str, Any]) -> dict[str, Any]: """Convert Codex error format to appropriate API error format. diff --git a/ccproxy/plugins/codex/detection_service.py b/ccproxy/plugins/codex/detection_service.py index 9cb396b0..88885560 100644 --- a/ccproxy/plugins/codex/detection_service.py +++ b/ccproxy/plugins/codex/detection_service.py @@ -12,6 +12,7 @@ from typing import Any, cast from fastapi import FastAPI, Request, Response +from pydantic import ValidationError from ccproxy.config.settings import Settings from ccproxy.config.utils import get_ccproxy_cache_dir @@ -43,6 +44,7 @@ class CodexDetectionService: ignores_header: list[str] = [ "host", "content-length", + "content-encoding", "authorization", "x-api-key", "session_id", @@ -133,17 +135,46 @@ def get_detected_headers(self) -> DetectedHeaders: """Return cached headers as structured data.""" data = self.get_cached_data() - if not data: - return DetectedHeaders() - return data.headers + headers = data.headers if data else DetectedHeaders() + + required_headers = { + "accept", + "content-type", + "openai-beta", + "originator", + "version", + } + missing_required = [key for key in required_headers if not headers.get(key)] + if not missing_required: + return headers + + fallback = self._safe_fallback_data() + if fallback is None: + return headers + + merged_headers = fallback.headers.as_dict() + merged_headers.update( + {key: value for key, value in headers.as_dict().items() if value} + ) + return DetectedHeaders(merged_headers) def get_detected_prompts(self) -> DetectedPrompts: """Return cached prompt metadata as structured data.""" data = self.get_cached_data() - if not data: - return DetectedPrompts() - return data.prompts + prompts = data.prompts if data else DetectedPrompts() + if prompts.has_instructions() or prompts.has_system(): + return prompts + + fallback = self._safe_fallback_data() + if fallback is None: + return prompts + + fallback_prompts = fallback.prompts + if fallback_prompts.has_instructions() or fallback_prompts.has_system(): + return fallback_prompts + + return prompts def get_ignored_headers(self) -> list[str]: """Headers that should be ignored when forwarding CLI values.""" @@ -496,6 +527,16 @@ def _get_fallback_data(self) -> CodexCacheData: fallback_data_dict = json.load(f) return CodexCacheData.model_validate(fallback_data_dict) + def _safe_fallback_data(self) -> CodexCacheData | None: + """Best-effort fallback data loader for partial detection caches.""" + try: + return self._get_fallback_data() + except (OSError, json.JSONDecodeError, ValidationError): + logger.debug( + "safe_fallback_data_load_failed", exc_info=True, category="plugin" + ) + return None + def invalidate_cache(self) -> None: """Clear all cached detection data.""" # Clear the async cache for _get_codex_version diff --git a/ccproxy/plugins/codex/routes.py b/ccproxy/plugins/codex/routes.py index 28f59c93..5a8310f3 100644 --- a/ccproxy/plugins/codex/routes.py +++ b/ccproxy/plugins/codex/routes.py @@ -1,9 +1,18 @@ """Codex plugin routes.""" +import contextlib +import json +from collections import deque +from pathlib import Path +from time import time from typing import TYPE_CHECKING, Annotated, Any, cast +from urllib.parse import urlparse +from uuid import uuid4 -from fastapi import APIRouter, Depends, Request +import anyio +from fastapi import APIRouter, Depends, Request, WebSocket, WebSocketDisconnect from starlette.responses import Response, StreamingResponse +from starlette.websockets import WebSocketState from ccproxy.api.decorators import with_format_chain from ccproxy.api.dependencies import ( @@ -11,6 +20,7 @@ get_provider_config_dependency, ) from ccproxy.auth.dependencies import ConditionalAuthDep +from ccproxy.config.settings import Settings from ccproxy.core.constants import ( FORMAT_ANTHROPIC_MESSAGES, FORMAT_OPENAI_CHAT, @@ -19,13 +29,20 @@ UPSTREAM_ENDPOINT_OPENAI_CHAT_COMPLETIONS, UPSTREAM_ENDPOINT_OPENAI_RESPONSES, ) +from ccproxy.core.logging import get_plugin_logger +from ccproxy.core.plugins import PluginRegistry, ProviderPluginRuntime from ccproxy.streaming import DeferredStreaming +from ccproxy.streaming.sse_parser import SSEStreamParser from .config import CodexSettings if TYPE_CHECKING: - pass + from .adapter import CodexAdapter + +logger = get_plugin_logger() + +_MAX_LOCAL_RESPONSE_IDS = 256 CodexAdapterDep = Annotated[Any, Depends(get_plugin_adapter("codex"))] CodexConfigDep = Annotated[ @@ -54,6 +71,243 @@ async def _codex_responses_handler( return await handle_codex_request(request, adapter) +def _get_codex_websocket_adapter(websocket: WebSocket) -> "CodexAdapter": + if not hasattr(websocket.app.state, "plugin_registry"): + raise RuntimeError("Plugin registry not initialized") + + registry: PluginRegistry = websocket.app.state.plugin_registry + runtime = registry.get_runtime("codex") + + if not runtime or not isinstance(runtime, ProviderPluginRuntime): + raise RuntimeError("Codex plugin not initialized") + + if not runtime.adapter: + raise RuntimeError("Codex adapter not available") + + return cast("CodexAdapter", runtime.adapter) + + +def _prepare_websocket_headers(websocket: WebSocket) -> dict[str, str]: + headers = { + key.lower(): value + for key, value in websocket.headers.items() + if not key.lower().startswith("sec-websocket-") + } + headers["accept"] = "text/event-stream" + return headers + + +def _parse_websocket_request(raw_message: str) -> dict[str, Any]: + payload = json.loads(raw_message) + if not isinstance(payload, dict): + raise ValueError("Expected JSON object payload") + + if payload.get("type") != "response.create": + raise ValueError("Unsupported websocket message type") + + provider_payload = dict(payload) + provider_payload.pop("type", None) + return provider_payload + + +def _make_websocket_terminal_event( + provider_payload: dict[str, Any], + *, + error: dict[str, Any] | None = None, +) -> dict[str, Any]: + response_payload: dict[str, Any] = { + "id": f"resp_ws_{uuid4().hex}", + "object": "response", + "created_at": int(time()), + "status": "failed" if error else "completed", + "model": provider_payload.get("model"), + "output": [], + "parallel_tool_calls": False, + "error": error, + "incomplete_details": None, + } + return {"type": "response.completed", "response": response_payload} + + +def _is_websocket_warmup_request(provider_payload: dict[str, Any]) -> bool: + input_items = provider_payload.get("input") + return isinstance(input_items, list) and len(input_items) == 0 + + +async def _authenticate_websocket(websocket: WebSocket) -> None: + """Enforce bearer auth on WebSocket connections when auth is configured. + + Mirrors the ConditionalAuthDep logic: if security.auth_token is set, + the client must provide a matching Authorization header. Closes the + connection with 1008 (Policy Violation) on failure. + """ + container = getattr(websocket.app.state, "service_container", None) + settings: Settings | None = None + if container is not None: + with contextlib.suppress(ValueError): + settings = container.get_service(Settings) + if settings is None: + with contextlib.suppress(Exception): + settings = Settings() + + if settings is None or not settings.security.auth_token: + return + + expected = settings.security.auth_token.get_secret_value() + auth_header = websocket.headers.get("authorization", "") + scheme, _, credentials = auth_header.partition(" ") + if scheme.lower() == "bearer": + credentials = credentials.strip() + token = credentials.split()[0] if credentials else "" + else: + token = "" + + if token != expected: + await websocket.close(code=1008, reason="Authentication required") + raise WebSocketDisconnect(code=1008) + + +async def _sanitize_websocket_payload( + adapter: "CodexAdapter", provider_payload: dict[str, Any], headers: dict[str, str] +) -> tuple[dict[str, Any], dict[str, str]]: + """Run the same request normalization used by HTTP routes on a WS payload.""" + body_bytes = json.dumps(provider_payload).encode("utf-8") + prepared_body, prepared_headers = await adapter.prepare_provider_request( + body_bytes, headers, UPSTREAM_ENDPOINT_OPENAI_RESPONSES + ) + sanitized_payload = json.loads(prepared_body.decode("utf-8")) + return sanitized_payload, prepared_headers + + +def _serialize_codex_models(config: CodexSettings) -> list[dict[str, Any]]: + models: list[dict[str, Any]] = [] + for card in config.models_endpoint: + model_data = card.model_dump(mode="json") + slug = model_data.get("slug") or model_data.get("id") or model_data.get("root") + if isinstance(slug, str) and slug: + model_data.setdefault("slug", slug) + model_data.setdefault("display_name", slug) + models.append(model_data) + return models + + +async def _load_codex_cli_models_cache() -> list[dict[str, Any]]: + cache_path = anyio.Path(Path.home() / ".codex" / "models_cache.json") + if not await cache_path.exists(): + return [] + + try: + content = await cache_path.read_text() + payload = json.loads(content) + except Exception: + return [] + + models = payload.get("models") + if not isinstance(models, list): + return [] + + return [model for model in models if isinstance(model, dict)] + + +async def _serialize_codex_cli_models(config: CodexSettings) -> list[dict[str, Any]]: + configured_ids = { + card.id + for card in config.models_endpoint + if isinstance(getattr(card, "id", None), str) + } + configured_ids.update( + { + card.root + for card in config.models_endpoint + if isinstance(getattr(card, "root", None), str) and card.root + } + ) + + cached_models = await _load_codex_cli_models_cache() + if cached_models and configured_ids: + matched = [ + model + for model in cached_models + if model.get("slug") in configured_ids + or model.get("display_name") in configured_ids + ] + if matched: + return matched + + return _serialize_codex_models(config) + + +async def _stream_websocket_response( + websocket: WebSocket, + adapter: "CodexAdapter", + provider_payload: dict[str, Any], +) -> None: + request_headers = _prepare_websocket_headers(websocket) + provider_payload, provider_headers = await _sanitize_websocket_payload( + adapter, provider_payload, request_headers + ) + target_url = await adapter.get_target_url(UPSTREAM_ENDPOINT_OPENAI_RESPONSES) + parsed_url = urlparse(target_url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + client = await adapter.http_pool_manager.get_streaming_client(base_url=base_url) + parser = SSEStreamParser() + saw_terminal_event = False + + async with client.stream( + "POST", + target_url, + headers=provider_headers, + content=json.dumps(provider_payload).encode("utf-8"), + ) as upstream_response: + if upstream_response.status_code >= 400: + error_body = await upstream_response.aread() + try: + error_payload = json.loads(error_body.decode("utf-8")) + except Exception: + error_payload = { + "error": { + "type": "server_error", + "message": error_body.decode("utf-8", errors="replace") + or "Upstream Codex request failed", + } + } + await websocket.send_text( + json.dumps( + _make_websocket_terminal_event( + provider_payload, + error=error_payload.get("error", error_payload), + ), + separators=(",", ":"), + ) + ) + return + + async for chunk in upstream_response.aiter_bytes(): + for event in parser.feed(chunk): + if event.get("type") in {"response.completed", "response.failed"}: + saw_terminal_event = True + await websocket.send_text(json.dumps(event, separators=(",", ":"))) + + for event in parser.flush(): + if event.get("type") in {"response.completed", "response.failed"}: + saw_terminal_event = True + await websocket.send_text(json.dumps(event, separators=(",", ":"))) + + if not saw_terminal_event: + await websocket.send_text( + json.dumps( + _make_websocket_terminal_event( + provider_payload, + error={ + "type": "server_error", + "message": "WebSocket stream ended before response.completed", + }, + ), + separators=(",", ":"), + ) + ) + + @router.post("/v1/responses", response_model=None) @with_format_chain( [FORMAT_OPENAI_RESPONSES], endpoint=UPSTREAM_ENDPOINT_OPENAI_RESPONSES @@ -66,6 +320,51 @@ async def codex_responses( return await _codex_responses_handler(request, adapter) +@router.websocket("/v1/responses") +async def codex_responses_websocket(websocket: WebSocket) -> None: + await websocket.accept() + await _authenticate_websocket(websocket) + + try: + adapter = _get_codex_websocket_adapter(websocket) + local_response_ids: deque[str] = deque(maxlen=_MAX_LOCAL_RESPONSE_IDS) + logger.debug("websocket_connected", client=str(websocket.client)) + while True: + raw_message = await websocket.receive_text() + provider_payload = _parse_websocket_request(raw_message) + if _is_websocket_warmup_request(provider_payload): + warmup_event = _make_websocket_terminal_event(provider_payload) + response_id = warmup_event.get("response", {}).get("id") + if isinstance(response_id, str) and response_id: + local_response_ids.append(response_id) + await websocket.send_text( + json.dumps(warmup_event, separators=(",", ":")) + ) + logger.debug("websocket_warmup_handled", response_id=response_id) + continue + previous_response_id = provider_payload.get("previous_response_id") + if ( + isinstance(previous_response_id, str) + and previous_response_id in local_response_ids + ): + provider_payload.pop("previous_response_id", None) + logger.debug( + "websocket_streaming_request", model=provider_payload.get("model") + ) + await _stream_websocket_response(websocket, adapter, provider_payload) + except WebSocketDisconnect: + logger.debug("websocket_disconnected", client=str(websocket.client)) + return + except ValueError as exc: + logger.warning("websocket_value_error", error=str(exc)) + if websocket.client_state == WebSocketState.CONNECTED: + await websocket.close(code=1008, reason=str(exc)) + except Exception: + logger.warning("websocket_unexpected_error", exc_info=True) + if websocket.client_state == WebSocketState.CONNECTED: + await websocket.close(code=1011, reason="Internal server error") + + @router.post("/responses", response_model=None, include_in_schema=False) @with_format_chain( [FORMAT_OPENAI_RESPONSES], endpoint=UPSTREAM_ENDPOINT_OPENAI_RESPONSES @@ -78,6 +377,11 @@ async def codex_responses_legacy( return await _codex_responses_handler(request, adapter) +@router.websocket("/responses") +async def codex_responses_legacy_websocket(websocket: WebSocket) -> None: + await codex_responses_websocket(websocket) + + @router.post("/v1/chat/completions", response_model=None) @with_format_chain( [FORMAT_OPENAI_CHAT, FORMAT_OPENAI_RESPONSES], @@ -98,8 +402,9 @@ async def list_models( config: CodexConfigDep, ) -> dict[str, Any]: """List available Codex models.""" - models = [card.model_dump(mode="json") for card in config.models_endpoint] - return {"object": "list", "data": models} + openai_models = _serialize_codex_models(config) + codex_models = await _serialize_codex_cli_models(config) + return {"object": "list", "data": openai_models, "models": codex_models} @router.post("/v1/messages", response_model=None) diff --git a/tests/plugins/codex/integration/test_codex_basic.py b/tests/plugins/codex/integration/test_codex_basic.py index 876f668b..06da0b73 100644 --- a/tests/plugins/codex/integration/test_codex_basic.py +++ b/tests/plugins/codex/integration/test_codex_basic.py @@ -30,9 +30,15 @@ async def test_models_endpoint_available_when_enabled( data: dict[str, Any] = resp.json() assert data.get("object") == "list" models = data.get("data") + cli_models = data.get("models") assert isinstance(models, list) assert len(models) > 0 + assert isinstance(cli_models, list) + assert len(cli_models) > 0 assert {"id", "object", "created", "owned_by"}.issubset(models[0].keys()) + assert models[0].get("slug") == models[0]["id"] + assert models[0].get("display_name") == models[0]["id"] + assert cli_models[0].get("slug") @pytest.mark.asyncio diff --git a/tests/plugins/codex/unit/test_adapter.py b/tests/plugins/codex/unit/test_adapter.py index 40a4c18e..89705135 100644 --- a/tests/plugins/codex/unit/test_adapter.py +++ b/tests/plugins/codex/unit/test_adapter.py @@ -217,6 +217,122 @@ async def test_prepare_provider_request_sets_stream_true( result_data = json.loads(result_body.decode()) assert result_data["stream"] is True + @pytest.mark.asyncio + async def test_prepare_provider_request_removes_max_completion_tokens( + self, adapter: CodexAdapter + ) -> None: + body_dict = { + "messages": [{"role": "user", "content": "Hello"}], + "model": "gpt-4", + "max_completion_tokens": 321, + } + body = json.dumps(body_dict).encode() + + result_body, _ = await adapter.prepare_provider_request( + body, {"content-type": "application/json"}, "/responses" + ) + + result_data = json.loads(result_body.decode()) + assert "max_output_tokens" not in result_data + assert "max_completion_tokens" not in result_data + + @pytest.mark.asyncio + async def test_prepare_provider_request_preserves_encoded_body( + self, adapter: CodexAdapter + ) -> None: + """Encoded request bodies should pass through unchanged.""" + body = b"\x28\xb5\x2f\xfdcompressed-request" + headers = { + "content-type": "application/json", + "content-encoding": "zstd", + "accept": "application/json, text/event-stream", + "authorization": "Bearer old-token", + "session_id": "existing-session", + } + + result_body, result_headers = await adapter.prepare_provider_request( + body, headers, "/responses" + ) + + assert result_body == body + assert result_headers["content-encoding"] == "zstd" + assert result_headers["authorization"] == "Bearer test-token" + assert result_headers["session_id"] == "existing-session" + assert "conversation_id" in result_headers + + @pytest.mark.asyncio + async def test_prepare_provider_request_strips_content_encoding_for_plain_body( + self, adapter: CodexAdapter + ) -> None: + """When body is not encoded, content-encoding must not be forwarded.""" + body_dict = { + "input": [{"type": "message", "role": "user", "content": "Hello"}], + "model": "gpt-4", + } + body = json.dumps(body_dict).encode() + headers = { + "content-type": "application/json", + "content-encoding": "identity", + } + + result_body, result_headers = await adapter.prepare_provider_request( + body, headers, "/responses" + ) + + result_data = json.loads(result_body.decode()) + assert result_data["stream"] is True + assert "content-encoding" not in result_headers + + @pytest.mark.asyncio + async def test_prepare_provider_request_applies_codex_template_defaults( + self, + mock_detection_service: Mock, + mock_auth_manager: Mock, + mock_http_pool_manager: Mock, + ) -> None: + template = { + "instructions": "You are a Python expert.", + "include": ["reasoning.encrypted_content"], + "parallel_tool_calls": True, + "reasoning": {"effort": "medium"}, + "tool_choice": "auto", + "tools": [{"type": "function", "name": "exec_command"}], + "prompt_cache_key": "template-cache-key", + } + prompts = DetectedPrompts.from_body(template) + mock_detection_service.get_detected_prompts = Mock(return_value=prompts) + mock_detection_service.get_system_prompt = Mock( + return_value=prompts.instructions_payload() + ) + + mock_config = Mock() + mock_config.base_url = "https://chat.openai.com/backend-anon" + + adapter = CodexAdapter( + detection_service=mock_detection_service, + config=mock_config, + auth_manager=mock_auth_manager, + http_pool_manager=mock_http_pool_manager, + ) + + body = json.dumps( + { + "model": "gpt-5", + "input": [{"role": "user", "content": [{"type": "input_text"}]}], + } + ).encode() + + result_body, _ = await adapter.prepare_provider_request(body, {}, "/responses") + result_data = json.loads(result_body.decode()) + + assert result_data["include"] == ["reasoning.encrypted_content"] + assert result_data["parallel_tool_calls"] is True + assert result_data["reasoning"] == {"effort": "medium"} + assert result_data["tool_choice"] == "auto" + assert result_data["tools"] == [{"type": "function", "name": "exec_command"}] + assert result_data["prompt_cache_key"] != "template-cache-key" + assert result_data["input"][0]["type"] == "message" + @pytest.mark.asyncio async def test_process_provider_response(self, adapter: CodexAdapter) -> None: """Test response processing and format conversion.""" diff --git a/tests/unit/plugins/test_codex_detection.py b/tests/unit/plugins/test_codex_detection.py index 9c8ec326..f6864c63 100644 --- a/tests/unit/plugins/test_codex_detection.py +++ b/tests/unit/plugins/test_codex_detection.py @@ -48,3 +48,7 @@ async def test_codex_detection_falls_back_when_cli_missing(tmp_path: Path) -> No mock_save.assert_not_called() assert result is expected_fallback assert service.get_cached_data() is expected_fallback + + +def test_codex_detection_ignores_content_encoding_header() -> None: + assert "content-encoding" in CodexDetectionService.ignores_header