Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion ccproxy/core/plugins/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class BaseProviderPluginFactory(ProviderPluginFactory):
cli_commands: list[CliCommandSpec] = []
cli_arguments: list[CliArgumentSpec] = []
tool_accumulator_class: type | None = None
use_mock_adapter_in_bypass_mode: bool = True

def __init__(self) -> None:
"""Initialize factory with manifest built from class attributes."""
Expand Down Expand Up @@ -231,7 +232,8 @@ async def create_adapter(self, context: PluginContext) -> BaseAdapter:
adapter=self.adapter_class.__name__,
category="lifecycle",
)
return MockAdapter(service_container.get_mock_handler())
if self.use_mock_adapter_in_bypass_mode:
return MockAdapter(service_container.get_mock_handler())

# Extract services from context (one-time extraction)
http_pool_manager: HTTPPoolManager | None = cast(
Expand Down Expand Up @@ -285,6 +287,8 @@ async def create_adapter(self, context: PluginContext) -> BaseAdapter:
if hasattr(context, "get")
else None,
}
if settings and getattr(settings.server, "bypass_mode", False):
adapter_kwargs["mock_handler"] = service_container.get_mock_handler()
if self.tool_accumulator_class:
adapter_kwargs["tool_accumulator_class"] = self.tool_accumulator_class

Expand Down Expand Up @@ -320,6 +324,9 @@ async def create_adapter(self, context: PluginContext) -> BaseAdapter:
"model_mapper": context.get("model_mapper")
if hasattr(context, "get")
else None,
"mock_handler": service_container.get_mock_handler()
if settings and getattr(settings.server, "bypass_mode", False)
else None,
}
if self.tool_accumulator_class:
non_http_adapter_kwargs["tool_accumulator_class"] = (
Expand Down
24 changes: 19 additions & 5 deletions ccproxy/plugins/codex/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from ccproxy.services.adapters.chain_composer import compose_from_chain
from ccproxy.services.adapters.http_adapter import BaseHTTPAdapter
from ccproxy.services.adapters.mock_adapter import MockAdapter
from ccproxy.services.handler_config import HandlerConfig
from ccproxy.streaming import DeferredStreaming, StreamingBufferService
from ccproxy.utils.headers import (
Expand Down Expand Up @@ -58,6 +59,9 @@ async def handle_request(
# Context + request info
ctx = request.state.context
self._ensure_tool_accumulator(ctx)
if self.mock_handler:
return await MockAdapter(self.mock_handler).handle_request(request)

endpoint = ctx.metadata.get("endpoint", "")
body = await request.body()
body = await self._map_request_model(ctx, body)
Expand Down Expand Up @@ -286,6 +290,13 @@ async def prepare_provider_request(
else:
body_data.pop("instructions", None)

body_data = self._sanitize_provider_body(body_data)

return json.dumps(body_data).encode(), filtered_headers

def _sanitize_provider_body(self, body_data: dict[str, Any]) -> dict[str, Any]:
"""Apply Codex-specific payload sanitization shared by all request paths."""

# Codex backend requires stream=true, always override
body_data["stream"] = True
body_data["store"] = False
Expand All @@ -305,11 +316,10 @@ async def prepare_provider_request(
input for input in list_input if input.get("type") != "item_reference"
]

#
# Remove any prefixed metadata fields that shouldn't be sent to the API
body_data = self._remove_metadata_fields(body_data)

return json.dumps(body_data).encode(), filtered_headers
return body_data

async def prepare_provider_headers(self, headers: dict[str, str]) -> dict[str, str]:
token_value = await self._resolve_access_token()
Expand Down Expand Up @@ -475,6 +485,10 @@ async def handle_streaming(
if not self.streaming_handler:
# Fallback to base behavior
return await super().handle_streaming(request, endpoint, **kwargs)
if self.mock_handler:
return await MockAdapter(self.mock_handler).handle_streaming(
request, endpoint, **kwargs
)

# Get context
ctx = request.state.context
Expand Down Expand Up @@ -655,9 +669,6 @@ def _request_body_is_encoded(self, headers: dict[str, str]) -> bool:
encoding = headers.get("content-encoding", "").strip().lower()
return bool(encoding and encoding != "identity")

def _should_apply_detection_payload(self) -> bool:
return bool(getattr(self.config, "inject_detection_payload", True))

def _detect_streaming_intent(self, body: bytes, headers: dict[str, str]) -> bool:
if self._request_body_is_encoded(headers):
accept = headers.get("accept", "").lower()
Expand All @@ -670,6 +681,9 @@ def _detect_streaming_intent(self, body: bytes, headers: dict[str, str]) -> bool
accept = headers.get("accept", "").lower()
return "text/event-stream" in accept

def _should_apply_detection_payload(self) -> bool:
return bool(getattr(self.config, "inject_detection_payload", True))

def _get_instructions(self) -> str:
if not self.detection_service:
return ""
Expand Down
28 changes: 21 additions & 7 deletions ccproxy/plugins/codex/detection_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,18 +163,12 @@ def get_detected_prompts(self) -> DetectedPrompts:

data = self.get_cached_data()
prompts = data.prompts if data else DetectedPrompts()
if prompts.has_instructions() or prompts.has_system():
return prompts

fallback = self._safe_fallback_data()
if fallback is None:
return prompts

fallback_prompts = fallback.prompts
if fallback_prompts.has_instructions() or fallback_prompts.has_system():
return fallback_prompts

return prompts
return self._merge_detected_prompts(prompts, fallback.prompts)

def get_ignored_headers(self) -> list[str]:
"""Headers that should be ignored when forwarding CLI values."""
Expand Down Expand Up @@ -537,6 +531,26 @@ def _safe_fallback_data(self) -> CodexCacheData | None:
)
return None

@staticmethod
def _merge_detected_prompts(
prompts: DetectedPrompts, fallback: DetectedPrompts
) -> DetectedPrompts:
"""Merge partial prompt caches with fallback defaults."""

prompt_raw = prompts.raw if isinstance(prompts.raw, dict) else {}
fallback_raw = fallback.raw if isinstance(fallback.raw, dict) else {}
merged_raw = dict(fallback_raw)
merged_raw.update(prompt_raw)

instructions = prompts.instructions or fallback.instructions
system = prompts.system if prompts.system is not None else fallback.system

return DetectedPrompts(
instructions=instructions,
system=system,
raw=merged_raw,
)

def invalidate_cache(self) -> None:
"""Clear all cached detection data."""
# Clear the async cache for _get_codex_version
Expand Down
1 change: 1 addition & 0 deletions ccproxy/plugins/codex/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ class CodexFactory(BaseProviderPluginFactory):
"""Factory for Codex provider plugin."""

cli_safe = False # Heavy provider plugin - not safe for CLI
use_mock_adapter_in_bypass_mode = False

# Plugin configuration via class attributes
plugin_name = "codex"
Expand Down
Loading
Loading