Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 99 additions & 18 deletions ccproxy/plugins/codex/adapter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
import copy
import json
import uuid
from typing import Any, cast
Expand Down Expand Up @@ -63,12 +64,7 @@ async def handle_request(
headers = extract_request_headers(request)

# Determine client streaming intent from body flag (fallback to False)
wants_stream = False
try:
data = json.loads(body.decode()) if body else {}
wants_stream = bool(data.get("stream", False))
except Exception: # Malformed/missing JSON -> assume non-streaming
wants_stream = False
wants_stream = self._detect_streaming_intent(body, headers)
logger.trace(
"codex_adapter_request_intent",
wants_stream=wants_stream,
Expand Down Expand Up @@ -256,19 +252,18 @@ async def get_target_url(self, endpoint: str) -> str:
async def prepare_provider_request(
self, body: bytes, headers: dict[str, str], endpoint: str
) -> tuple[bytes, dict[str, str]]:
token_value = await self._resolve_access_token()
filtered_headers = await self.prepare_provider_headers(headers)

# Get profile to extract chatgpt_account_id
profile = await self.token_manager.get_profile_quick()
chatgpt_account_id = (
getattr(profile, "chatgpt_account_id", None) if profile else None
)
if self._request_body_is_encoded(headers):
return body, filtered_headers

# Body will be re-serialized as plain JSON; drop stale encoding header
filtered_headers.pop("content-encoding", None)

# Parse body (format conversion is now handled by format chain)
body_data = json.loads(body.decode()) if body else {}
body_data = self._apply_request_template(body_data)

# Inject instructions mandatory for being allow to
# to used the Codex API endpoint
# Fetch detected instructions from detection service
instructions = self._get_instructions()

Expand Down Expand Up @@ -299,8 +294,20 @@ async def prepare_provider_request(
# Remove any prefixed metadata fields that shouldn't be sent to the API
body_data = self._remove_metadata_fields(body_data)

# Filter and add headers
return json.dumps(body_data).encode(), filtered_headers

async def prepare_provider_headers(self, headers: dict[str, str]) -> dict[str, str]:
token_value = await self._resolve_access_token()

profile = await self.token_manager.get_profile_quick()
chatgpt_account_id = (
getattr(profile, "chatgpt_account_id", None) if profile else None
)

filtered_headers = filter_request_headers(headers, preserve_auth=False)
content_encoding = headers.get("content-encoding")
if content_encoding:
filtered_headers["content-encoding"] = content_encoding

session_id = filtered_headers.get("session_id") or str(uuid.uuid4())
conversation_id = filtered_headers.get("conversation_id") or str(uuid.uuid4())
Expand All @@ -318,10 +325,10 @@ async def prepare_provider_request(
filtered_headers.update(base_headers)

cli_headers = self._collect_cli_headers()
if cli_headers:
filtered_headers.update(cli_headers)
for key, value in cli_headers.items():
filtered_headers.setdefault(key, value)

return json.dumps(body_data).encode(), filtered_headers
return filtered_headers

async def process_provider_response(
self, response: httpx.Response, endpoint: str
Expand Down Expand Up @@ -581,6 +588,70 @@ def _remove_metadata_fields(self, data: dict[str, Any]) -> dict[str, Any]:

return cleaned_data

def _apply_request_template(self, data: dict[str, Any]) -> dict[str, Any]:
if not isinstance(data, dict):
return data

template = self._get_request_template()
if not template:
return self._normalize_input_messages(data)

merged = copy.deepcopy(data)

for key in ("include", "parallel_tool_calls", "reasoning", "tool_choice"):
if key not in merged and key in template:
merged[key] = copy.deepcopy(template[key])

if not merged.get("tools") and isinstance(template.get("tools"), list):
merged["tools"] = copy.deepcopy(template["tools"])

if "prompt_cache_key" not in merged:
prompt_cache_key = template.get("prompt_cache_key")
if isinstance(prompt_cache_key, str) and prompt_cache_key:
merged["prompt_cache_key"] = str(uuid.uuid4())

return self._normalize_input_messages(merged)

def _normalize_input_messages(self, data: dict[str, Any]) -> dict[str, Any]:
input_items = data.get("input")
if not isinstance(input_items, list):
return data

normalized_items: list[Any] = []
for item in input_items:
if (
isinstance(item, dict)
and "type" not in item
and "role" in item
and "content" in item
):
normalized_item = dict(item)
normalized_item["type"] = "message"
normalized_items.append(normalized_item)
continue

normalized_items.append(item)

result = dict(data)
result["input"] = normalized_items
return result

def _request_body_is_encoded(self, headers: dict[str, str]) -> bool:
encoding = headers.get("content-encoding", "").strip().lower()
return bool(encoding and encoding != "identity")

def _detect_streaming_intent(self, body: bytes, headers: dict[str, str]) -> bool:
if self._request_body_is_encoded(headers):
accept = headers.get("accept", "").lower()
return "text/event-stream" in accept

try:
data = json.loads(body.decode()) if body else {}
return bool(data.get("stream", False))
except Exception:
accept = headers.get("accept", "").lower()
return "text/event-stream" in accept

def _get_instructions(self) -> str:
if not self.detection_service:
return ""
Expand All @@ -601,6 +672,16 @@ def _get_instructions(self) -> str:

return ""

def _get_request_template(self) -> dict[str, Any]:
if not self.detection_service:
return {}

prompts = self.detection_service.get_detected_prompts()
if isinstance(prompts.raw, dict) and prompts.raw:
return prompts.raw

return {}

def adapt_error(self, error_body: dict[str, Any]) -> dict[str, Any]:
"""Convert Codex error format to appropriate API error format.

Expand Down
53 changes: 47 additions & 6 deletions ccproxy/plugins/codex/detection_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Any, cast

from fastapi import FastAPI, Request, Response
from pydantic import ValidationError

from ccproxy.config.settings import Settings
from ccproxy.config.utils import get_ccproxy_cache_dir
Expand Down Expand Up @@ -43,6 +44,7 @@ class CodexDetectionService:
ignores_header: list[str] = [
"host",
"content-length",
"content-encoding",
"authorization",
"x-api-key",
"session_id",
Expand Down Expand Up @@ -133,17 +135,46 @@ def get_detected_headers(self) -> DetectedHeaders:
"""Return cached headers as structured data."""

data = self.get_cached_data()
if not data:
return DetectedHeaders()
return data.headers
headers = data.headers if data else DetectedHeaders()

required_headers = {
"accept",
"content-type",
"openai-beta",
"originator",
"version",
}
missing_required = [key for key in required_headers if not headers.get(key)]
if not missing_required:
return headers

fallback = self._safe_fallback_data()
if fallback is None:
return headers

merged_headers = fallback.headers.as_dict()
merged_headers.update(
{key: value for key, value in headers.as_dict().items() if value}
)
return DetectedHeaders(merged_headers)

def get_detected_prompts(self) -> DetectedPrompts:
"""Return cached prompt metadata as structured data."""

data = self.get_cached_data()
if not data:
return DetectedPrompts()
return data.prompts
prompts = data.prompts if data else DetectedPrompts()
if prompts.has_instructions() or prompts.has_system():
return prompts

fallback = self._safe_fallback_data()
if fallback is None:
return prompts

fallback_prompts = fallback.prompts
if fallback_prompts.has_instructions() or fallback_prompts.has_system():
return fallback_prompts

return prompts

def get_ignored_headers(self) -> list[str]:
"""Headers that should be ignored when forwarding CLI values."""
Expand Down Expand Up @@ -496,6 +527,16 @@ def _get_fallback_data(self) -> CodexCacheData:
fallback_data_dict = json.load(f)
return CodexCacheData.model_validate(fallback_data_dict)

def _safe_fallback_data(self) -> CodexCacheData | None:
    """Best-effort fallback loader: returns None instead of raising.

    Partial or corrupt detection caches are an expected condition, so load
    failures (file I/O, bad JSON, schema mismatch) are logged at debug
    level and swallowed rather than propagated.
    """
    try:
        data = self._get_fallback_data()
    except (OSError, json.JSONDecodeError, ValidationError):
        logger.debug(
            "safe_fallback_data_load_failed", exc_info=True, category="plugin"
        )
        return None
    return data

def invalidate_cache(self) -> None:
"""Clear all cached detection data."""
# Clear the async cache for _get_codex_version
Expand Down
Loading
Loading