diff --git a/.branch-placeholder b/.branch-placeholder deleted file mode 100644 index b3a4252..0000000 --- a/.branch-placeholder +++ /dev/null @@ -1 +0,0 @@ -placeholder \ No newline at end of file diff --git a/dev-suite/src/api/events.py b/dev-suite/src/api/events.py index e57ad5a..5f4bc80 100644 --- a/dev-suite/src/api/events.py +++ b/dev-suite/src/api/events.py @@ -1,6 +1,7 @@ """Async event bus for real-time SSE streaming to the dashboard. Issue #35: SSE Event System -- Real-Time Task Streaming +Issue #80: Added TOOL_CALL event type for agent tool usage tracking The EventBus is a singleton that LangGraph nodes publish events to. Connected SSE clients each get their own asyncio.Queue for fan-out. @@ -49,6 +50,7 @@ class EventType(str, Enum): MEMORY_ADDED = "memory_added" LOG_LINE = "log_line" QA_ESCALATION = "qa_escalation" + TOOL_CALL = "tool_call" class SSEEvent(BaseModel): diff --git a/dev-suite/src/api/runner.py b/dev-suite/src/api/runner.py index 70567d8..8765f7e 100644 --- a/dev-suite/src/api/runner.py +++ b/dev-suite/src/api/runner.py @@ -1,6 +1,7 @@ """Async task runner bridging the FastAPI API to the LangGraph orchestrator. Issue #48: StateManager <-> Orchestrator bridge +Issue #80: Tool binding -- tools_config initialization, TOOL_CALL SSE events Uses LangGraph's astream() to iterate node completions and emit SSE events in real time. Runs entirely on the async event loop -- no threading needed. @@ -28,6 +29,7 @@ GraphState, WorkflowStatus, build_graph, + init_tools_config, MAX_RETRIES, TOKEN_BUDGET, ) @@ -92,6 +94,8 @@ class TaskRunner: def __init__(self): self._tasks: dict[str, asyncio.Task] = {} + # Fix 6: Track per-task dev tool call baselines for per-pass counting + self._dev_tool_baselines: dict[str, int] = {} def submit(self, task_id: str, description: str) -> None: """Submit a task for background execution.""" @@ -122,6 +126,7 @@ async def shutdown(self) -> None: if self._tasks: await asyncio.gather(*self._tasks.values(), return_exceptions=True) self._tasks.clear() + self._dev_tool_baselines.clear() logger.info("TaskRunner shutdown complete") @property @@ -133,6 +138,8 @@ async def _run_task(self, task_id: str, description: str) -> None: from .state import state_manager start_time = time.time() + # Fix 6: Initialize per-task dev tool baseline + self._dev_tool_baselines[task_id] = 0 try: await self._emit_progress(task_id, "task_started", None, f"Task started: {description[:100]}") @@ -142,6 +149,14 @@ async def _run_task(self, task_id: str, description: str) -> None: graph = build_graph() workflow = graph.compile() + # Initialize tools config (issue #80) + tools_config = init_tools_config() + n_tools = len(tools_config.get("configurable", {}).get("tools", [])) + if n_tools > 0: + await self._emit_log(f"[orchestrator] {n_tools} tools loaded for agents") + else: + await self._emit_log("[orchestrator] No tools configured (single-shot mode)") + initial_state: GraphState = { "task_description": description, "blueprint": None, @@ -154,13 +169,35 @@ async def _run_task(self, task_id: str, description: str) -> None: "memory_context": [], "trace": [], "parsed_files": [], + "tool_calls_log": [], + } + + stream_config = { + "recursion_limit": 25, + **tools_config, } prev_node = None - async for event in workflow.astream(initial_state, config={"recursion_limit": 25}): + prev_tool_count = 0 + async for event in workflow.astream(initial_state, config=stream_config): for node_name, node_output in event.items(): if node_name.startswith("__"): continue + + # Emit tool_call SSE events for any new tool calls (issue #80) + tool_calls_log = node_output.get("tool_calls_log", []) + if len(tool_calls_log) > prev_tool_count: + new_calls = tool_calls_log[prev_tool_count:] + for tc in new_calls: + await self._emit_tool_call( + task_id, + tc.get("agent", "unknown"), + tc.get("tool", "unknown"), + tc.get("success", True), + tc.get("result_preview", ""), + ) + prev_tool_count = len(tool_calls_log) + await self._handle_node_completion( task_id, node_name, node_output, state_manager, prev_node, ) @@ -181,9 +218,7 @@ async def _run_task(self, task_id: str, description: str) -> None: if task: task.status = TaskStatus.CANCELLED task.completed_at = datetime.now(timezone.utc) - for agent_id in ("arch", "dev", "qa"): - await state_manager.update_agent_status(agent_id, AgentStatus.IDLE) - await self._emit_complete(task_id, "cancelled", "Task cancelled by user") + await self._emit_complete(task_id, "cancelled", "Task was cancelled") except Exception as e: logger.error("Task %s failed with exception: %s", task_id, e, exc_info=True) @@ -197,6 +232,10 @@ async def _run_task(self, task_id: str, description: str) -> None: await self._emit_complete(task_id, "failed", f"Task failed: {e}") await self._emit_log(f"[orchestrator] ERROR: {e}") + finally: + # Fix 6: Clean up per-task baseline + self._dev_tool_baselines.pop(task_id, None) + async def _handle_node_completion(self, task_id, node_name, node_output, state_manager, prev_node): """Process a completed node and emit appropriate SSE events.""" if prev_node and prev_node in NODE_TO_AGENT: @@ -303,12 +342,23 @@ async def _handle_architect(self, task_id, output, task, state_manager): async def _handle_developer(self, task_id, output, task, state_manager): code = output.get("generated_code", "") task.generated_code = code + + # Fix 6: Use per-pass baseline to count only new tool calls from this dev pass + tool_calls_log = output.get("tool_calls_log", []) + baseline = self._dev_tool_baselines.get(task_id, 0) + dev_tool_calls = [tc for tc in tool_calls_log[baseline:] if tc.get("agent") == "developer"] + self._dev_tool_baselines[task_id] = len(tool_calls_log) + if code: action = f"Code generated ({len(code):,} chars)" + if dev_tool_calls: + action += f" using {len(dev_tool_calls)} tool call(s)" event_type = "code" retry = task.budget.retries_used if retry > 0: action = f"Retry {retry}/{task.budget.max_retries} -- code regenerated ({len(code):,} chars)" + if dev_tool_calls: + action += f" using {len(dev_tool_calls)} tool call(s)" event_type = "retry" await self._emit_log("[sandbox:locked] E2B micro-VM started (dev-sandbox)") else: @@ -362,6 +412,22 @@ async def _emit_log(self, message): except Exception: logger.debug("Failed to emit log_line", exc_info=True) + async def _emit_tool_call(self, task_id, agent, tool_name, success, result_preview): + """Emit a TOOL_CALL SSE event for dashboard tool usage tracking (issue #80).""" + try: + await event_bus.publish(SSEEvent( + type=EventType.TOOL_CALL, + data={ + "task_id": task_id, + "agent": agent, + "tool": tool_name, + "success": success, + "result_preview": result_preview[:100] if result_preview else "", + }, + )) + except Exception: + logger.debug("Failed to emit tool_call", exc_info=True) + # -- Singleton -- diff --git a/dev-suite/src/orchestrator.py b/dev-suite/src/orchestrator.py index 1ef969a..e1f0feb 100644 --- a/dev-suite/src/orchestrator.py +++ b/dev-suite/src/orchestrator.py @@ -3,19 +3,31 @@ This is the main entry point for the agent workflow. Implements the state machine with retry logic, token budgets, structured Blueprint passing, human escalation, code application, -and memory write-back (flush_memory node with mini-summarizer). +tool binding (issue #80), and memory write-back. + +Issue #80: Agent tool binding -- Dev and QA agents can now use +workspace tools (filesystem_read, filesystem_write, etc.) via +LangChain's bind_tools() + iterative tool execution loop. +Tools are passed via RunnableConfig["configurable"]["tools"]. """ +import asyncio import json import logging import os +import re from enum import Enum from pathlib import Path from typing import Any, Literal, TypedDict from dotenv import load_dotenv from langchain_anthropic import ChatAnthropic -from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.messages import ( + AIMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) from langchain_google_genai import ChatGoogleGenerativeAI from langgraph.graph import END, START, StateGraph from pydantic import BaseModel @@ -43,37 +55,24 @@ logger = logging.getLogger(__name__) -# -- Configuration -- - def _safe_int(env_key: str, default: int) -> int: - """Parse an integer from an env var, falling back to default on error.""" raw = os.getenv(env_key, str(default)) try: return int(raw) except (ValueError, TypeError): - logger.warning( - "%s=%r is not a valid integer, using default %d", - env_key, raw, default, - ) + logger.warning("%s=%r is not a valid integer, using default %d", env_key, raw, default) return default MAX_RETRIES = _safe_int("MAX_RETRIES", 3) TOKEN_BUDGET = _safe_int("TOKEN_BUDGET", 50000) +MAX_TOOL_TURNS = _safe_int("MAX_TOOL_TURNS", 10) -# -- Workspace -- - def _get_workspace_root() -> Path: - """Get the workspace root directory. - - Reads WORKSPACE_ROOT env var, falling back to current working directory. - """ raw = os.getenv("WORKSPACE_ROOT", ".") return Path(raw).resolve() -# -- Workflow State -- - class WorkflowStatus(str, Enum): PLANNING = "planning" BUILDING = "building" @@ -84,13 +83,6 @@ class WorkflowStatus(str, Enum): class GraphState(TypedDict, total=False): - """State that flows through the LangGraph state machine. - - D1 fix: Uses TypedDict (not Pydantic BaseModel) for reliable dict-merge - semantics in LangGraph. Fields present in a node's return dict replace - the existing value; fields absent are left unchanged. - """ - task_description: str blueprint: Blueprint | None generated_code: str @@ -104,13 +96,10 @@ class GraphState(TypedDict, total=False): trace: list[str] sandbox_result: SandboxResult | None parsed_files: list[dict] + tool_calls_log: list[dict] class AgentState(BaseModel): - """Pydantic model used at the boundary -- for constructing the initial - state and wrapping the final result with validation and attribute access. - """ - task_description: str = "" blueprint: Blueprint | None = None generated_code: str = "" @@ -124,43 +113,22 @@ class AgentState(BaseModel): trace: list[str] = [] sandbox_result: SandboxResult | None = None parsed_files: list[dict] = [] + tool_calls_log: list[dict] = [] -# -- LLM Initialization -- - def _get_architect_llm(): - """Gemini for the Architect agent (large context, planning only).""" - return ChatGoogleGenerativeAI( - model=os.getenv("ARCHITECT_MODEL", "gemini-3-flash-preview"), - google_api_key=os.getenv("GOOGLE_API_KEY"), - temperature=0.2, - ) + return ChatGoogleGenerativeAI(model=os.getenv("ARCHITECT_MODEL", "gemini-3-flash-preview"), google_api_key=os.getenv("GOOGLE_API_KEY"), temperature=0.2) def _get_developer_llm(): - """Claude for the Lead Dev agent (code execution).""" - return ChatAnthropic( - model=os.getenv("DEVELOPER_MODEL", "claude-sonnet-4-20250514"), - api_key=os.getenv("ANTHROPIC_API_KEY"), - temperature=0.1, - max_tokens=8192, - ) + return ChatAnthropic(model=os.getenv("DEVELOPER_MODEL", "claude-sonnet-4-20250514"), api_key=os.getenv("ANTHROPIC_API_KEY"), temperature=0.1, max_tokens=8192) def _get_qa_llm(): - """Claude for the QA agent (review and testing).""" - return ChatAnthropic( - model=os.getenv("QA_MODEL", "claude-sonnet-4-20250514"), - api_key=os.getenv("ANTHROPIC_API_KEY"), - temperature=0.0, - max_tokens=4096, - ) - + return ChatAnthropic(model=os.getenv("QA_MODEL", "claude-sonnet-4-20250514"), api_key=os.getenv("ANTHROPIC_API_KEY"), temperature=0.0, max_tokens=4096) -# -- Helpers -- def _extract_text_content(content: Any) -> str: - """Extract text from an LLM response's content field.""" if isinstance(content, str): return content if isinstance(content, list): @@ -177,7 +145,6 @@ def _extract_text_content(content: Any) -> str: def _extract_json(raw: str) -> dict: - """Extract a JSON object from LLM output text.""" text = raw.strip() try: return json.loads(text) @@ -200,14 +167,10 @@ def _extract_json(raw: str) -> dict: return json.loads(text[first_brace:last_brace + 1]) except json.JSONDecodeError: pass - raise json.JSONDecodeError( - f"No valid JSON found in response ({len(text)} chars): {text[:200]}...", - text, 0, - ) + raise json.JSONDecodeError(f"No valid JSON found in response ({len(text)} chars): {text[:200]}...", text, 0) def _extract_token_count(response: Any) -> int: - """Extract total token count from an LLM response.""" meta = getattr(response, "usage_metadata", None) if not meta: return 0 @@ -223,12 +186,10 @@ def _extract_token_count(response: Any) -> int: def _get_memory_store() -> MemoryStore: - """Get the memory store via factory (respects MEMORY_BACKEND env var).""" return create_memory_store() def _fetch_memory_context(task_description: str) -> list[str]: - """Query memory for relevant context across all tiers.""" try: store = _get_memory_store() results = store.query(task_description, n_results=10) @@ -238,7 +199,6 @@ def _fetch_memory_context(task_description: str) -> list[str]: def _infer_module(target_files: list[str]) -> str: - """Infer module name from target file paths.""" if not target_files: return "global" first_file = target_files[0] @@ -250,27 +210,132 @@ def _infer_module(target_files: list[str]) -> str: return "global" +# -- Tool Binding (Issue #80) -- + +# Fix 2: Removed github_create_pr -- non-idempotent write should not be in +# an iterative tool loop. PR creation belongs in a post-QA-pass step. +DEV_TOOL_NAMES = {"filesystem_read", "filesystem_write", "filesystem_list", "github_read_diff"} +QA_TOOL_NAMES = {"filesystem_read", "filesystem_list", "github_read_diff"} + +# Fix 4: Secret pattern regexes for sanitizing tool call previews +_SECRET_PATTERNS = [ + re.compile(r'(?:sk|pk|api|key|token|secret|password|bearer)[_-]?\w{10,}', re.IGNORECASE), + re.compile(r'(?:ghp|gho|ghu|ghs|ghr)_[A-Za-z0-9_]{30,}'), + re.compile(r'(?:eyJ)[A-Za-z0-9_-]{20,}'), + re.compile(r'(?:AKIA|ASIA)[A-Z0-9]{16}'), +] + + +def _sanitize_preview(text: str, max_len: int = 200) -> str: + """Truncate and redact known secret patterns from tool call previews.""" + if not text: + return "" + sanitized = text + for pattern in _SECRET_PATTERNS: + sanitized = pattern.sub("[REDACTED]", sanitized) + if len(sanitized) > max_len: + sanitized = sanitized[:max_len] + "..." + return sanitized + + +def _get_agent_tools(config, allowed_names=None): + if not config: + return [] + configurable = config.get("configurable", {}) + tools = configurable.get("tools", []) + if not tools: + return [] + if allowed_names is None: + return list(tools) + return [t for t in tools if t.name in allowed_names] + + +async def _execute_tool_call(tool_call, tools): + """Execute a single tool call. Uses ainvoke/invoke public API (Fix 3).""" + tool_name = tool_call.get("name", "") + tool_args = tool_call.get("args", {}) + tool_id = tool_call.get("id", "unknown") + tool_map = {t.name: t for t in tools} + tool = tool_map.get(tool_name) + if not tool: + return ToolMessage(content=f"Error: Tool '{tool_name}' not found. Available: {list(tool_map.keys())}", tool_call_id=tool_id) + try: + # Fix 3: Use public ainvoke/invoke API instead of tool.coroutine. + # ainvoke handles input validation, callbacks, and config propagation. + if hasattr(tool, "ainvoke"): + result = await tool.ainvoke(tool_args) + else: + result = tool.invoke(tool_args) + return ToolMessage(content=str(result), tool_call_id=tool_id) + except Exception as e: + logger.warning("[TOOLS] Tool %s failed: %s", tool_name, e) + return ToolMessage(content=f"Error executing {tool_name}: {type(e).__name__}: {e}", tool_call_id=tool_id) + + +async def _run_tool_loop(llm_with_tools, messages, tools, max_turns=MAX_TOOL_TURNS, tokens_used=0, trace=None, agent_name="agent"): + if trace is None: + trace = [] + tool_calls_log = [] + current_messages = list(messages) + # Fix 5: Guard against max_turns <= 0 to prevent unbound response variable + if max_turns <= 0: + logger.warning("[%s] max_turns=%d, skipping tool loop", agent_name.upper(), max_turns) + trace.append(f"{agent_name}: tool loop skipped (max_turns={max_turns})") + last_msg = current_messages[-1] if current_messages else AIMessage(content="") + return last_msg, tokens_used, tool_calls_log + for turn in range(max_turns): + response = await llm_with_tools.ainvoke(current_messages) + tokens_used += _extract_token_count(response) + tool_calls = getattr(response, "tool_calls", None) or [] + if not tool_calls: + trace.append(f"{agent_name}: tool loop done after {turn} tool turn(s)") + return response, tokens_used, tool_calls_log + trace.append(f"{agent_name}: turn {turn + 1} -- {len(tool_calls)} tool call(s): {', '.join(tc.get('name', '?') for tc in tool_calls)}") + logger.info("[%s] Tool turn %d: %d calls", agent_name.upper(), turn + 1, len(tool_calls)) + current_messages.append(response) + for tc in tool_calls: + tool_msg = await _execute_tool_call(tc, tools) + current_messages.append(tool_msg) + # Fix 4: Sanitize previews before persisting to tool_calls_log + tool_calls_log.append({"agent": agent_name, "turn": turn + 1, "tool": tc.get("name", "unknown"), "args_preview": _sanitize_preview(str(tc.get("args", {}))), "result_preview": _sanitize_preview(str(tool_msg.content)), "success": not tool_msg.content.startswith("Error")}) + trace.append(f"{agent_name}: tool loop hit max turns ({max_turns})") + logger.warning("[%s] Hit max tool turns (%d)", agent_name.upper(), max_turns) + return response, tokens_used, tool_calls_log + + +def _run_async(coro): + """Run an async coroutine from sync context for the tool loop. + + Design note (re: CodeRabbit #8): This is intentional. developer_node and + qa_node are sync functions that use _run_async() to bridge into the async + tool loop. run_task() uses workflow.invoke() (sync). Converting everything + to async would cascade changes to CLI, tests, and callers with no benefit + since LangGraph handles sync nodes with internal async bridges fine. + """ + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop and loop.is_running(): + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + return pool.submit(asyncio.run, coro).result() + else: + return asyncio.run(coro) + + # -- Node Functions -- def architect_node(state: GraphState) -> dict: - """Architect: generates a structured Blueprint from the task description.""" trace = list(state.get("trace", [])) trace.append("architect: starting planning") - retry_count = state.get("retry_count", 0) tokens_used = state.get("tokens_used", 0) - logger.info("[ARCH] retry_count=%d, tokens_used=%d, status=%s", - retry_count, tokens_used, state.get("status", "unknown")) - + logger.info("[ARCH] retry_count=%d, tokens_used=%d, status=%s", retry_count, tokens_used, state.get("status", "unknown")) memory_context = _fetch_memory_context(state.get("task_description", "")) - memory_block = "" if memory_context: - memory_block = ( - "\n\nProject context from memory:\n" - + "\n".join(f"- {c}" for c in memory_context) - ) - + memory_block = "\n\nProject context from memory:\n" + "\n".join(f"- {c}" for c in memory_context) system_prompt = f"""You are the Architect agent. Your job is to create a structured Blueprint for a coding task. You NEVER write code yourself. @@ -284,7 +349,6 @@ def architect_node(state: GraphState) -> dict: }} Do not include any text before or after the JSON.{memory_block}""" - user_msg = state.get("task_description", "") failure_report = state.get("failure_report") if failure_report and failure_report.is_architectural: @@ -293,17 +357,9 @@ def architect_node(state: GraphState) -> dict: if failure_report.failed_files: user_msg += f"Failed files: {', '.join(failure_report.failed_files)}\n" user_msg += f"Recommendation: {failure_report.recommendation}\n" - user_msg += ( - "\nGenerate a COMPLETELY NEW Blueprint. Do not patch the old one. " - "The previous target_files or approach was wrong." - ) - + user_msg += "\nGenerate a COMPLETELY NEW Blueprint. Do not patch the old one. The previous target_files or approach was wrong." llm = _get_architect_llm() - response = llm.invoke([ - SystemMessage(content=system_prompt), - HumanMessage(content=user_msg), - ]) - + response = llm.invoke([SystemMessage(content=system_prompt), HumanMessage(content=user_msg)]) try: raw = _extract_text_content(response.content) blueprint_data = _extract_json(raw) @@ -311,47 +367,51 @@ def architect_node(state: GraphState) -> dict: except (json.JSONDecodeError, Exception) as e: trace.append(f"architect: failed to parse blueprint: {e}") logger.error("[ARCH] Blueprint parse failed: %s", e) - return { - "status": WorkflowStatus.FAILED, - "error_message": f"Architect failed to produce valid Blueprint: {e}", - "trace": trace, - "memory_context": memory_context, - } - + return {"status": WorkflowStatus.FAILED, "error_message": f"Architect failed to produce valid Blueprint: {e}", "trace": trace, "memory_context": memory_context} trace.append(f"architect: blueprint created for {len(blueprint.target_files)} files") tokens_used = tokens_used + _extract_token_count(response) logger.info("[ARCH] done. tokens_used now=%d", tokens_used) - - return { - "blueprint": blueprint, - "status": WorkflowStatus.BUILDING, - "tokens_used": tokens_used, - "trace": trace, - "memory_context": memory_context, - } + return {"blueprint": blueprint, "status": WorkflowStatus.BUILDING, "tokens_used": tokens_used, "trace": trace, "memory_context": memory_context} -def developer_node(state: GraphState) -> dict: - """Lead Dev: executes the Blueprint and generates code.""" +def developer_node(state: GraphState, config: dict | None = None) -> dict: + """Lead Dev: executes the Blueprint and generates code. Issue #80: tool binding support.""" trace = list(state.get("trace", [])) trace.append("developer: starting build") memory_writes = list(state.get("memory_writes", [])) - + tool_calls_log = list(state.get("tool_calls_log", [])) retry_count = state.get("retry_count", 0) tokens_used = state.get("tokens_used", 0) - logger.info("[DEV] retry_count=%d, tokens_used=%d, status=%s", - retry_count, tokens_used, state.get("status", "unknown")) - + logger.info("[DEV] retry_count=%d, tokens_used=%d, status=%s", retry_count, tokens_used, state.get("status", "unknown")) blueprint = state.get("blueprint") if not blueprint: trace.append("developer: no blueprint provided") - return { - "status": WorkflowStatus.FAILED, - "error_message": "Developer received no Blueprint", - "trace": trace, - } - - system_prompt = """You are the Lead Dev agent. You receive a structured Blueprint + return {"status": WorkflowStatus.FAILED, "error_message": "Developer received no Blueprint", "trace": trace} + tools = _get_agent_tools(config, DEV_TOOL_NAMES) + has_tools = len(tools) > 0 + if has_tools: + tool_names = [t.name for t in tools] + trace.append(f"developer: {len(tools)} tools available: {', '.join(tool_names)}") + system_prompt = """You are the Lead Dev agent. You receive a structured Blueprint and implement it using the tools available to you. + +WORKFLOW: +1. Use filesystem_read to examine the existing files listed in target_files. +2. Use filesystem_list to explore the project directory structure if needed. +3. Implement the changes described in the Blueprint. +4. Use filesystem_write to write each file. +5. After writing all files, provide a summary of what you implemented. + +IMPORTANT: +- Use filesystem_write for EACH file you need to create or modify. +- Follow the Blueprint exactly. Respect all constraints. +- Write clean, well-documented code. +- After completing all file writes, respond with a text summary. +- Also include the complete code in your final response using the format: + # --- FILE: path/to/file.py --- + (file contents)""" + else: + trace.append("developer: no tools available, using single-shot mode") + system_prompt = """You are the Lead Dev agent. You receive a structured Blueprint and write the code to implement it. Respond with the complete code implementation. Include file paths as comments @@ -360,9 +420,7 @@ def developer_node(state: GraphState) -> dict: Follow the Blueprint exactly. Respect all constraints. Write clean, well-documented code.""" - user_msg = f"Blueprint:\n{blueprint.model_dump_json(indent=2)}" - failure_report = state.get("failure_report") if failure_report and not failure_report.is_architectural: user_msg += "\n\nPREVIOUS ATTEMPT FAILED:\n" @@ -372,91 +430,47 @@ def developer_node(state: GraphState) -> dict: user_msg += f"Failed files: {', '.join(failure_report.failed_files)}\n" user_msg += f"Recommendation: {failure_report.recommendation}\n" user_msg += "\nFix the issues and regenerate the code." - + messages = [SystemMessage(content=system_prompt), HumanMessage(content=user_msg)] llm = _get_developer_llm() - response = llm.invoke([ - SystemMessage(content=system_prompt), - HumanMessage(content=user_msg), - ]) - + if has_tools: + llm_with_tools = llm.bind_tools(tools) + response, tokens_used, new_tool_log = _run_async(_run_tool_loop(llm_with_tools, messages, tools, max_turns=MAX_TOOL_TURNS, tokens_used=tokens_used, trace=trace, agent_name="developer")) + tool_calls_log.extend(new_tool_log) + else: + response = llm.invoke(messages) + tokens_used += _extract_token_count(response) content = _extract_text_content(response.content) trace.append(f"developer: code generated ({len(content)} chars)") - tokens_used = tokens_used + _extract_token_count(response) logger.info("[DEV] done. tokens_used now=%d", tokens_used) + memory_writes.append({"content": f"Implemented blueprint {blueprint.task_id}: {blueprint.instructions[:200]}", "tier": "l1", "module": _infer_module(blueprint.target_files), "source_agent": "developer", "confidence": 1.0, "sandbox_origin": "locked-down", "related_files": ",".join(blueprint.target_files), "task_id": blueprint.task_id}) + return {"generated_code": content, "status": WorkflowStatus.REVIEWING, "tokens_used": tokens_used, "trace": trace, "memory_writes": memory_writes, "tool_calls_log": tool_calls_log} - memory_writes.append({ - "content": f"Implemented blueprint {blueprint.task_id}: {blueprint.instructions[:200]}", - "tier": "l1", - "module": _infer_module(blueprint.target_files), - "source_agent": "developer", - "confidence": 1.0, - "sandbox_origin": "locked-down", - "related_files": ",".join(blueprint.target_files), - "task_id": blueprint.task_id, - }) - - return { - "generated_code": content, - "status": WorkflowStatus.REVIEWING, - "tokens_used": tokens_used, - "trace": trace, - "memory_writes": memory_writes, - } - - -# -- Code Application -- def apply_code_node(state: GraphState) -> dict: - """Parse generated code into files, write to workspace, prepare for sandbox. - - This node bridges the gap between the Dev agent's text output and the - filesystem. It: - 1. Parses generated_code using # --- FILE: path --- markers - 2. Validates paths for workspace containment (security) - 3. Writes each file to the workspace directory - 4. Stores the parsed file map in state for sandbox_validate to load - - The node never changes workflow status -- it's a pass-through that - enriches state with parsed_files. - """ trace = list(state.get("trace", [])) trace.append("apply_code: starting") - generated_code = state.get("generated_code", "") blueprint = state.get("blueprint") - if not generated_code: trace.append("apply_code: no generated_code -- skipping") return {"parsed_files": [], "trace": trace} - if not blueprint: trace.append("apply_code: no blueprint -- skipping") return {"parsed_files": [], "trace": trace} - - # Step 1: Parse generated code into individual files try: parsed = parse_generated_code(generated_code) except CodeParserError as e: logger.warning("[APPLY_CODE] Parse error: %s", e) trace.append(f"apply_code: parse error -- {e}") return {"parsed_files": [], "trace": trace} - if not parsed: trace.append("apply_code: parser returned no files") return {"parsed_files": [], "trace": trace} - - # Step 2: Validate paths for workspace containment workspace_root = _get_workspace_root() safe_files = validate_paths_for_workspace(parsed, workspace_root) - if len(safe_files) < len(parsed): skipped = len(parsed) - len(safe_files) - trace.append( - f"apply_code: WARNING -- {skipped} file(s) skipped " - f"due to path validation" - ) - - # Step 3: Write files to workspace + trace.append(f"apply_code: WARNING -- {skipped} file(s) skipped due to path validation") total_chars = 0 written_count = 0 for pf in safe_files: @@ -467,88 +481,39 @@ def apply_code_node(state: GraphState) -> dict: written_count += 1 total_chars += len(pf.content) except Exception as e: - logger.warning( - "[APPLY_CODE] Failed to write %s: %s", pf.path, e - ) + logger.warning("[APPLY_CODE] Failed to write %s: %s", pf.path, e) trace.append(f"apply_code: failed to write {pf.path} -- {e}") - - trace.append( - f"apply_code: wrote {written_count} files " - f"({total_chars:,} chars total) to workspace" - ) - logger.info( - "[APPLY_CODE] Wrote %d files (%d chars) to %s", - written_count, total_chars, workspace_root, - ) - - # Step 4: Store parsed files in state for sandbox loading - parsed_files_data = [ - {"path": pf.path, "content": pf.content} - for pf in safe_files - ] - + trace.append(f"apply_code: wrote {written_count} files ({total_chars:,} chars total) to workspace") + logger.info("[APPLY_CODE] Wrote %d files (%d chars) to %s", written_count, total_chars, workspace_root) + parsed_files_data = [{"path": pf.path, "content": pf.content} for pf in safe_files] return {"parsed_files": parsed_files_data, "trace": trace} -def qa_node(state: GraphState) -> dict: - """QA: reviews the generated code and produces a structured FailureReport.""" +def qa_node(state: GraphState, config: dict | None = None) -> dict: + """QA: reviews the generated code. Issue #80: read-only tool access.""" trace = list(state.get("trace", [])) trace.append("qa: starting review") memory_writes = list(state.get("memory_writes", [])) - + tool_calls_log = list(state.get("tool_calls_log", [])) retry_count = state.get("retry_count", 0) tokens_used = state.get("tokens_used", 0) - logger.info("[QA] retry_count=%d, tokens_used=%d, status=%s", - retry_count, tokens_used, state.get("status", "unknown")) - + logger.info("[QA] retry_count=%d, tokens_used=%d, status=%s", retry_count, tokens_used, state.get("status", "unknown")) generated_code = state.get("generated_code", "") blueprint = state.get("blueprint") if not generated_code or not blueprint: trace.append("qa: missing code or blueprint") - return { - "status": WorkflowStatus.FAILED, - "error_message": "QA received no code or blueprint to review", - "trace": trace, - } - - system_prompt = ( - "You are the QA agent. You review code against a Blueprint's " - "acceptance criteria.\n\n" - "Respond with ONLY a valid JSON object matching this schema:\n" - "{\n" - ' "task_id": "string (from the Blueprint)",\n' - ' "status": "pass" or "fail" or "escalate",\n' - ' "tests_passed": number,\n' - ' "tests_failed": number,\n' - ' "errors": ["list of specific error descriptions"],\n' - ' "failed_files": ["list of files with issues"],\n' - ' "is_architectural": true/false,\n' - ' "failure_type": "code" or "architectural" or null (if pass),\n' - ' "recommendation": "what to fix or why it should escalate"\n' - "}\n\n" - "FAILURE CLASSIFICATION (critical for correct routing):\n\n" - 'Set failure_type to "code" (status: "fail") when:\n' - "- Implementation has bugs, syntax errors, or type errors\n" - "- Tests fail due to logic errors in the code\n" - "- Code does not follow the Blueprint's constraints\n" - "- Missing error handling or edge cases\n" - "Action: Lead Dev will retry with the same Blueprint.\n\n" - 'Set failure_type to "architectural" (status: "escalate") when:\n' - "- Blueprint targets the WRONG files (code is in the wrong place)\n" - "- A required dependency or import is missing from the Blueprint\n" - "- The design approach is fundamentally flawed\n" - "- Acceptance criteria are impossible to meet with current targets\n" - "- The task requires files not listed in target_files\n" - "Action: Architect will generate a completely NEW Blueprint.\n\n" - "Be strict but fair. Only pass code that meets ALL acceptance " - "criteria.\n" - "Do not include any text before or after the JSON." - ) - + return {"status": WorkflowStatus.FAILED, "error_message": "QA received no code or blueprint to review", "trace": trace} + tools = _get_agent_tools(config, QA_TOOL_NAMES) + has_tools = len(tools) > 0 + if has_tools: + tool_names = [t.name for t in tools] + trace.append(f"qa: {len(tools)} tools available: {', '.join(tool_names)}") + system_prompt = "You are the QA agent. You review code against a Blueprint's acceptance criteria.\n\n" + if has_tools: + system_prompt += "You have tools to read files from the workspace. Use filesystem_read to inspect the actual files that were written, and filesystem_list to check the project structure.\n\n" + system_prompt += 'Respond with ONLY a valid JSON object matching this schema:\n{\n "task_id": "string (from the Blueprint)",\n "status": "pass" or "fail" or "escalate",\n "tests_passed": number,\n "tests_failed": number,\n "errors": ["list of specific error descriptions"],\n "failed_files": ["list of files with issues"],\n "is_architectural": true/false,\n "failure_type": "code" or "architectural" or null (if pass),\n "recommendation": "what to fix or why it should escalate"\n}\n\nFAILURE CLASSIFICATION (critical for correct routing):\n\nSet failure_type to "code" (status: "fail") when:\n- Implementation has bugs, syntax errors, or type errors\n- Tests fail due to logic errors in the code\n- Code does not follow the Blueprint\'s constraints\n- Missing error handling or edge cases\nAction: Lead Dev will retry with the same Blueprint.\n\nSet failure_type to "architectural" (status: "escalate") when:\n- Blueprint targets the WRONG files (code is in the wrong place)\n- A required dependency or import is missing from the Blueprint\n- The design approach is fundamentally flawed\n- Acceptance criteria are impossible to meet with current targets\n- The task requires files not listed in target_files\nAction: Architect will generate a completely NEW Blueprint.\n\nBe strict but fair. Only pass code that meets ALL acceptance criteria.\nDo not include any text before or after the JSON.' bp_json = blueprint.model_dump_json(indent=2) user_msg = f"Blueprint:\n{bp_json}\n\nGenerated Code:\n{generated_code}" - - # Include sandbox validation results if available sandbox_result = state.get("sandbox_result") if sandbox_result is not None: user_msg += "\n\nSandbox Validation Results:\n" @@ -561,194 +526,78 @@ def qa_node(state: GraphState) -> dict: user_msg += f" Errors: {', '.join(sandbox_result.errors)}\n" if sandbox_result.output_summary: user_msg += f" Output:\n{sandbox_result.output_summary}\n" - user_msg += ( - "\nUse these real test results to inform your review. " - "If sandbox tests passed, weigh that heavily in your verdict." - ) + user_msg += "\nUse these real test results to inform your review. If sandbox tests passed, weigh that heavily in your verdict." else: - user_msg += ( - "\n\nNote: Sandbox validation was not available for this review. " - "Evaluate the code based on the Blueprint criteria only." - ) - + user_msg += "\n\nNote: Sandbox validation was not available for this review. Evaluate the code based on the Blueprint criteria only." + messages = [SystemMessage(content=system_prompt), HumanMessage(content=user_msg)] llm = _get_qa_llm() - response = llm.invoke([ - SystemMessage(content=system_prompt), - HumanMessage(content=user_msg), - ]) - + if has_tools: + llm_with_tools = llm.bind_tools(tools) + response, tokens_used, new_tool_log = _run_async(_run_tool_loop(llm_with_tools, messages, tools, max_turns=5, tokens_used=tokens_used, trace=trace, agent_name="qa")) + tool_calls_log.extend(new_tool_log) + else: + response = llm.invoke(messages) + tokens_used += _extract_token_count(response) try: raw = _extract_text_content(response.content) report_data = _extract_json(raw) failure_report = FailureReport(**report_data) except (json.JSONDecodeError, Exception) as e: trace.append(f"qa: failed to parse report: {e}") - return { - "status": WorkflowStatus.FAILED, - "error_message": f"QA failed to produce valid report: {e}", - "trace": trace, - } - - trace.append( - f"qa: verdict={failure_report.status}, " - f"passed={failure_report.tests_passed}, " - f"failed={failure_report.tests_failed}" - ) - tokens_used = tokens_used + _extract_token_count(response) - + return {"status": WorkflowStatus.FAILED, "error_message": f"QA failed to produce valid report: {e}", "trace": trace} + trace.append(f"qa: verdict={failure_report.status}, passed={failure_report.tests_passed}, failed={failure_report.tests_failed}") if failure_report.status == "pass": status = WorkflowStatus.PASSED - memory_writes.append({ - "content": f"QA passed for {blueprint.task_id}: {failure_report.tests_passed} tests passed", - "tier": "l2", - "module": _infer_module(blueprint.target_files), - "source_agent": "qa", - "confidence": 1.0, - "sandbox_origin": "locked-down", - "related_files": ",".join(blueprint.target_files), - "task_id": blueprint.task_id, - }) + memory_writes.append({"content": f"QA passed for {blueprint.task_id}: {failure_report.tests_passed} tests passed", "tier": "l2", "module": _infer_module(blueprint.target_files), "source_agent": "qa", "confidence": 1.0, "sandbox_origin": "locked-down", "related_files": ",".join(blueprint.target_files), "task_id": blueprint.task_id}) elif failure_report.is_architectural: status = WorkflowStatus.ESCALATED - memory_writes.append({ - "content": f"Architectural issue in {blueprint.task_id}: {failure_report.recommendation}", - "tier": "l0-discovered", - "module": _infer_module(blueprint.target_files), - "source_agent": "qa", - "confidence": 0.85, - "sandbox_origin": "locked-down", - "related_files": ",".join(failure_report.failed_files), - "task_id": blueprint.task_id, - }) + memory_writes.append({"content": f"Architectural issue in {blueprint.task_id}: {failure_report.recommendation}", "tier": "l0-discovered", "module": _infer_module(blueprint.target_files), "source_agent": "qa", "confidence": 0.85, "sandbox_origin": "locked-down", "related_files": ",".join(failure_report.failed_files), "task_id": blueprint.task_id}) else: status = WorkflowStatus.REVIEWING - new_retry_count = retry_count + (1 if failure_report.status != "pass" else 0) - logger.info("[QA] verdict=%s, retry_count %d->%d, tokens_used=%d", - failure_report.status, retry_count, new_retry_count, tokens_used) - - return { - "failure_report": failure_report, - "status": status, - "tokens_used": tokens_used, - "retry_count": new_retry_count, - "trace": trace, - "memory_writes": memory_writes, - } - - -# -- Sandbox Validation -- - -def _run_sandbox_validation( - commands: list[str], - template: str | None, - generated_code: str, - parsed_files: list[dict] | None = None, - timeout: int = 120, -) -> SandboxResult | None: - """Execute validation commands in an E2B sandbox. - - Returns None if E2B_API_KEY is not configured (graceful skip). - Raises on unexpected errors so the caller can log them. - """ + logger.info("[QA] verdict=%s, retry_count %d->%d, tokens_used=%d", failure_report.status, retry_count, new_retry_count, tokens_used) + return {"failure_report": failure_report, "status": status, "tokens_used": tokens_used, "retry_count": new_retry_count, "trace": trace, "memory_writes": memory_writes, "tool_calls_log": tool_calls_log} + + +def _run_sandbox_validation(commands, template, generated_code, parsed_files=None, timeout=120): api_key = os.getenv("E2B_API_KEY") if not api_key: return None - runner = E2BRunner(api_key=api_key, default_timeout=timeout) - - # Convert parsed_files to the project_files dict format expected by run_tests project_files = None if parsed_files: project_files = {pf["path"]: pf["content"] for pf in parsed_files} - - # Build a compound command that runs all validations sequentially compound_cmd = " && ".join(commands) - - return runner.run_tests( - test_command=compound_cmd, - project_files=project_files, - timeout=timeout, - template=template, - ) + return runner.run_tests(test_command=compound_cmd, project_files=project_files, timeout=timeout, template=template) def sandbox_validate_node(state: GraphState) -> dict: - """Run sandbox validation on generated code before QA review. - - Selects the appropriate template and validation commands based on - the Blueprint's target_files, then executes them in an E2B sandbox. - - Now loads parsed_files into the sandbox before running commands, - so validation runs against real code instead of an empty project. - - Behavior: - - Optional: if E2B_API_KEY is not set, logs a warning and skips. - - Errors are caught and logged, never crash the workflow. - - SandboxResult is stored in state for QA to consume. - """ trace = list(state.get("trace", [])) trace.append("sandbox_validate: starting") - blueprint = state.get("blueprint") if not blueprint: trace.append("sandbox_validate: no blueprint -- skipping") return {"sandbox_result": None, "trace": trace} - - # Determine what to validate plan = get_validation_plan(blueprint.target_files) trace.append(f"sandbox_validate: {plan.description}") - if not plan.commands: trace.append("sandbox_validate: no code validation needed -- skipping") return {"sandbox_result": None, "trace": trace} - template_label = plan.template or "default" - trace.append( - f"sandbox_validate: template={template_label}, " - f"commands={len(plan.commands)}" - ) - + trace.append(f"sandbox_validate: template={template_label}, commands={len(plan.commands)}") generated_code = state.get("generated_code", "") parsed_files = state.get("parsed_files", []) - if parsed_files: - trace.append( - f"sandbox_validate: loading {len(parsed_files)} files into sandbox" - ) - + trace.append(f"sandbox_validate: loading {len(parsed_files)} files into sandbox") try: - result = _run_sandbox_validation( - commands=plan.commands, - template=plan.template, - generated_code=generated_code, - parsed_files=parsed_files if parsed_files else None, - ) - + result = _run_sandbox_validation(commands=plan.commands, template=plan.template, generated_code=generated_code, parsed_files=parsed_files if parsed_files else None) if result is None: - # No API key -- warn loudly - logger.warning( - "[SANDBOX] E2B_API_KEY not configured -- sandbox validation " - "SKIPPED. QA will review without real test results. " - "Set E2B_API_KEY in .env to enable sandbox validation." - ) - trace.append( - "sandbox_validate: WARNING -- E2B_API_KEY not configured, " - "sandbox validation skipped" - ) + logger.warning("[SANDBOX] E2B_API_KEY not configured -- sandbox validation SKIPPED.") + trace.append("sandbox_validate: WARNING -- E2B_API_KEY not configured, sandbox validation skipped") return {"sandbox_result": None, "trace": trace} - - trace.append( - f"sandbox_validate: exit_code={result.exit_code}, " - f"passed={result.tests_passed}, failed={result.tests_failed}" - ) - logger.info( - "[SANDBOX] Validation complete: exit=%d, passed=%s, failed=%s", - result.exit_code, result.tests_passed, result.tests_failed, - ) - + trace.append(f"sandbox_validate: exit_code={result.exit_code}, passed={result.tests_passed}, failed={result.tests_failed}") + logger.info("[SANDBOX] Validation complete: exit=%d, passed=%s, failed=%s", result.exit_code, result.tests_passed, result.tests_failed) return {"sandbox_result": result, "trace": trace} - except Exception as e: logger.warning("[SANDBOX] Validation failed with error: %s", e) trace.append(f"sandbox_validate: error -- {type(e).__name__}: {e}") @@ -756,24 +605,14 @@ def sandbox_validate_node(state: GraphState) -> dict: def flush_memory_node(state: GraphState) -> dict: - """Flush accumulated memory_writes to the memory store. - - Runs the mini-summarizer to deduplicate/compress writes, - then persists to Chroma. Gracefully degrades if store is unreachable. - """ trace = list(state.get("trace", [])) trace.append("flush_memory: starting") - memory_writes = state.get("memory_writes", []) if not memory_writes: trace.append("flush_memory: no writes to flush") return {"trace": trace} - consolidated = summarize_writes_sync(memory_writes) - trace.append( - f"flush_memory: summarizer {len(memory_writes)} -> {len(consolidated)} entries" - ) - + trace.append(f"flush_memory: summarizer {len(memory_writes)} -> {len(consolidated)} entries") try: store = _get_memory_store() written = 0 @@ -786,93 +625,47 @@ def flush_memory_node(state: GraphState) -> dict: sandbox_origin = entry.get("sandbox_origin", "none") related_files = entry.get("related_files", "") task_id = entry.get("task_id", "") - if tier == "l0-discovered": - store.add_l0_discovered( - content, module=module, source_agent=source_agent, - confidence=confidence, sandbox_origin=sandbox_origin, - related_files=related_files, task_id=task_id, - ) + store.add_l0_discovered(content, module=module, source_agent=source_agent, confidence=confidence, sandbox_origin=sandbox_origin, related_files=related_files, task_id=task_id) elif tier == "l2": - store.add_l2( - content, module=module, source_agent=source_agent, - related_files=related_files, task_id=task_id, - ) + store.add_l2(content, module=module, source_agent=source_agent, related_files=related_files, task_id=task_id) else: - store.add_l1( - content, module=module, source_agent=source_agent, - confidence=confidence, sandbox_origin=sandbox_origin, - related_files=related_files, task_id=task_id, - ) + store.add_l1(content, module=module, source_agent=source_agent, confidence=confidence, sandbox_origin=sandbox_origin, related_files=related_files, task_id=task_id) written += 1 - trace.append(f"flush_memory: wrote {written} entries to store") logger.info("[FLUSH] Wrote %d memory entries", written) except Exception as e: trace.append(f"flush_memory: store write failed: {e}") logger.warning("[FLUSH] Memory store write failed: %s", e) - return {"trace": trace} -# -- Routing Functions -- - def route_after_qa(state: GraphState) -> Literal["flush_memory", "developer", "architect", "__end__"]: - """Decide where to go after QA review.""" status = state.get("status", WorkflowStatus.FAILED) retry_count = state.get("retry_count", 0) tokens_used = state.get("tokens_used", 0) - - logger.info( - "[ROUTER] status=%s, retry_count=%d, tokens_used=%d, " - "max_retries=%d, token_budget=%d", - status, retry_count, tokens_used, MAX_RETRIES, TOKEN_BUDGET, - ) - + logger.info("[ROUTER] status=%s, retry_count=%d, tokens_used=%d, max_retries=%d, token_budget=%d", status, retry_count, tokens_used, MAX_RETRIES, TOKEN_BUDGET) if status == WorkflowStatus.PASSED: - logger.info("[ROUTER] -> flush_memory (passed)") return "flush_memory" - if status == WorkflowStatus.FAILED: - logger.info("[ROUTER] -> END (failed: %s)", state.get("error_message", "")) return END - if retry_count >= MAX_RETRIES: - logger.info("[ROUTER] -> flush_memory (max retries, saving what we have)") return "flush_memory" if tokens_used >= TOKEN_BUDGET: - logger.info("[ROUTER] -> flush_memory (token budget, saving what we have)") return "flush_memory" - if status == WorkflowStatus.ESCALATED: - logger.info("[ROUTER] -> architect (escalation)") return "architect" - else: - logger.info("[ROUTER] -> developer (retry)") - return "developer" - + return "developer" -# -- Graph Construction -- def build_graph() -> StateGraph: - """Build the LangGraph state machine. - - Flow: - START -> architect -> developer -> apply_code -> sandbox_validate -> qa -> (conditional) - -> pass: flush_memory -> END - -> fail: developer (retry) - -> escalate: architect (re-plan) - -> budget/retries exhausted: flush_memory -> END - """ graph = StateGraph(GraphState) - graph.add_node("architect", architect_node) graph.add_node("developer", developer_node) graph.add_node("apply_code", apply_code_node) graph.add_node("sandbox_validate", sandbox_validate_node) graph.add_node("qa", qa_node) graph.add_node("flush_memory", flush_memory_node) - graph.add_edge(START, "architect") graph.add_edge("architect", "developer") graph.add_edge("developer", "apply_code") @@ -880,86 +673,44 @@ def build_graph() -> StateGraph: graph.add_edge("sandbox_validate", "qa") graph.add_conditional_edges("qa", route_after_qa) graph.add_edge("flush_memory", END) - return graph def create_workflow(): - """Create and compile the workflow. Ready to invoke.""" - graph = build_graph() - return graph.compile() - - -# -- Entry Point -- - -def run_task( - task_description: str, - enable_tracing: bool = True, - session_id: str | None = None, - tags: list[str] | None = None, -) -> AgentState: - """Run a task through the full agent workflow. - - Args: - task_description: What to build. - enable_tracing: Whether to send traces to Langfuse. - session_id: Optional session ID for grouping related traces. - tags: Optional tags for filtering traces in Langfuse UI. - """ - trace_config = create_trace_config( - enabled=enable_tracing, - task_description=task_description, - session_id=session_id, - tags=tags or ["orchestrator"], - metadata={ - "max_retries": str(MAX_RETRIES), - "token_budget": str(TOKEN_BUDGET), - }, - ) + return build_graph().compile() + + +def init_tools_config(workspace_root=None): + if workspace_root is None: + workspace_root = _get_workspace_root() + try: + from .tools import create_provider, get_tools, load_mcp_config + config_path = Path(workspace_root) / "mcp-config.json" + if not config_path.is_file(): + logger.info("[TOOLS] No mcp-config.json found at %s, tools disabled", config_path) + return {"configurable": {"tools": []}} + mcp_config = load_mcp_config(str(config_path)) + provider = create_provider(mcp_config, workspace_root) + tools = get_tools(provider) + logger.info("[TOOLS] Loaded %d tools from provider", len(tools)) + return {"configurable": {"tools": tools}} + except Exception as e: + logger.warning("[TOOLS] Failed to initialize tools: %s", e) + return {"configurable": {"tools": []}} - workflow = create_workflow() - initial_state: GraphState = { - "task_description": task_description, - "blueprint": None, - "generated_code": "", - "failure_report": None, - "status": WorkflowStatus.PLANNING, - "retry_count": 0, - "tokens_used": 0, - "error_message": "", - "memory_context": [], - "memory_writes": [], - "trace": [], - "sandbox_result": None, - "parsed_files": [], - } - - invoke_config = { - "recursion_limit": 25, - } +def run_task(task_description, enable_tracing=True, session_id=None, tags=None): + trace_config = create_trace_config(enabled=enable_tracing, task_description=task_description, session_id=session_id, tags=tags or ["orchestrator"], metadata={"max_retries": str(MAX_RETRIES), "token_budget": str(TOKEN_BUDGET)}) + workflow = create_workflow() + tools_config = init_tools_config() + initial_state: GraphState = {"task_description": task_description, "blueprint": None, "generated_code": "", "failure_report": None, "status": WorkflowStatus.PLANNING, "retry_count": 0, "tokens_used": 0, "error_message": "", "memory_context": [], "memory_writes": [], "trace": [], "sandbox_result": None, "parsed_files": [], "tool_calls_log": []} + invoke_config = {"recursion_limit": 25, **tools_config} if trace_config.callbacks: invoke_config["callbacks"] = trace_config.callbacks - - # Wrap the entire workflow execution in propagation context so that - # session_id, tags, and metadata flow through to all Langfuse spans. with trace_config.propagation_context(): - add_trace_event(trace_config, "orchestrator_start", metadata={ - "task_preview": task_description[:200], - "max_retries": MAX_RETRIES, - "token_budget": TOKEN_BUDGET, - }) - + add_trace_event(trace_config, "orchestrator_start", metadata={"task_preview": task_description[:200], "max_retries": MAX_RETRIES, "token_budget": TOKEN_BUDGET, "tools_available": len(tools_config.get("configurable", {}).get("tools", []))}) result = workflow.invoke(initial_state, config=invoke_config) final_state = AgentState(**result) - - add_trace_event(trace_config, "orchestrator_complete", metadata={ - "status": final_state.status.value, - "tokens_used": final_state.tokens_used, - "retry_count": final_state.retry_count, - "memory_writes_count": len(final_state.memory_writes), - }) - + add_trace_event(trace_config, "orchestrator_complete", metadata={"status": final_state.status.value, "tokens_used": final_state.tokens_used, "retry_count": final_state.retry_count, "memory_writes_count": len(final_state.memory_writes), "tool_calls_count": len(final_state.tool_calls_log)}) trace_config.flush() - return final_state diff --git a/dev-suite/tests/test_e2e.py b/dev-suite/tests/test_e2e.py index be95f09..7b9aaa8 100644 --- a/dev-suite/tests/test_e2e.py +++ b/dev-suite/tests/test_e2e.py @@ -1,17 +1,17 @@ """End-to-end pipeline validation tests (Step 7). -These tests verify the full Architect → Lead Dev → QA pipeline +These tests verify the full Architect -> Lead Dev -> QA pipeline by mocking LLM responses with realistic payloads. They validate: -1. Happy path: Blueprint → code → QA pass → PASSED -2. Retry path: QA fail → retry developer → QA pass -3. Escalation path: QA escalate → re-plan architect → developer → QA pass -4. Budget exhaustion: tokens_used >= TOKEN_BUDGET → END -5. Max retries: retry_count >= MAX_RETRIES → END +1. Happy path: Blueprint -> code -> QA pass -> PASSED +2. Retry path: QA fail -> retry developer -> QA pass +3. Escalation path: QA escalate -> re-plan architect -> developer -> QA pass +4. Budget exhaustion: tokens_used >= TOKEN_BUDGET -> END +5. Max retries: retry_count >= MAX_RETRIES -> END 6. Memory integration: Chroma context is fetched and included 7. Tracing integration: TracingConfig wired through correctly -No API keys needed — all LLM calls are mocked. +No API keys needed -- all LLM calls are mocked. """ import json @@ -37,7 +37,7 @@ ) -# ── Fixtures ── +# -- Fixtures -- SAMPLE_BLUEPRINT = Blueprint( @@ -114,28 +114,21 @@ def _make_llm_response(content: str, total_tokens: int = 500) -> MagicMock: return resp -# ── Happy Path Tests ── +# -- Happy Path Tests -- class TestE2EHappyPath: - """Full pipeline: Architect → Lead Dev → QA → PASSED.""" @patch("src.orchestrator._fetch_memory_context") @patch("src.orchestrator._get_qa_llm") @patch("src.orchestrator._get_developer_llm") @patch("src.orchestrator._get_architect_llm") def test_full_pipeline_pass(self, mock_arch_llm, mock_dev_llm, mock_qa_llm, mock_memory): - """A task goes through the full loop and passes QA on the first try.""" mock_memory.return_value = ["Project uses Python 3.13", "Use type hints everywhere"] - arch_response = _make_llm_response(SAMPLE_BLUEPRINT.model_dump_json(), total_tokens=800) - mock_arch_llm.return_value.invoke.return_value = arch_response - dev_response = _make_llm_response(SAMPLE_CODE, total_tokens=1200) - mock_dev_llm.return_value.invoke.return_value = dev_response - qa_response = _make_llm_response(SAMPLE_QA_PASS.model_dump_json(), total_tokens=400) - mock_qa_llm.return_value.invoke.return_value = qa_response - + mock_arch_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_BLUEPRINT.model_dump_json(), total_tokens=800) + mock_dev_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_CODE, total_tokens=1200) + mock_qa_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_QA_PASS.model_dump_json(), total_tokens=400) result = run_task("Create a Python function that validates email addresses", enable_tracing=False) - assert result.status == WorkflowStatus.PASSED assert result.retry_count == 0 assert result.tokens_used == 800 + 1200 + 400 @@ -153,49 +146,33 @@ def test_full_pipeline_pass(self, mock_arch_llm, mock_dev_llm, mock_qa_llm, mock @patch("src.orchestrator._get_qa_llm") @patch("src.orchestrator._get_developer_llm") @patch("src.orchestrator._get_architect_llm") - def test_memory_context_included_in_architect_prompt( - self, mock_arch_llm, mock_dev_llm, mock_qa_llm, mock_memory - ): - """Memory context should be included in the Architect's system prompt.""" + def test_memory_context_included_in_architect_prompt(self, mock_arch_llm, mock_dev_llm, mock_qa_llm, mock_memory): mock_memory.return_value = ["Framework: SvelteKit", "Database: CosmosDB"] - arch_response = _make_llm_response(SAMPLE_BLUEPRINT.model_dump_json()) - mock_arch_llm.return_value.invoke.return_value = arch_response - dev_response = _make_llm_response(SAMPLE_CODE) - mock_dev_llm.return_value.invoke.return_value = dev_response - qa_response = _make_llm_response(SAMPLE_QA_PASS.model_dump_json()) - mock_qa_llm.return_value.invoke.return_value = qa_response - + mock_arch_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_BLUEPRINT.model_dump_json()) + mock_dev_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_CODE) + mock_qa_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_QA_PASS.model_dump_json()) result = run_task("Build a new endpoint", enable_tracing=False) - arch_call_args = mock_arch_llm.return_value.invoke.call_args[0][0] system_msg = arch_call_args[0].content assert "SvelteKit" in system_msg assert "CosmosDB" in system_msg -# ── Retry Path Tests ── +# -- Retry Path Tests -- class TestE2ERetryPath: - """QA fails → retry developer → eventually passes.""" @patch("src.orchestrator._fetch_memory_context") @patch("src.orchestrator._get_qa_llm") @patch("src.orchestrator._get_developer_llm") @patch("src.orchestrator._get_architect_llm") def test_retry_then_pass(self, mock_arch_llm, mock_dev_llm, mock_qa_llm, mock_memory): - """QA fails on first attempt, passes on retry.""" mock_memory.return_value = [] - arch_response = _make_llm_response(SAMPLE_BLUEPRINT.model_dump_json(), total_tokens=800) - mock_arch_llm.return_value.invoke.return_value = arch_response - dev_response = _make_llm_response(SAMPLE_CODE, total_tokens=1200) - mock_dev_llm.return_value.invoke.return_value = dev_response - qa_fail = _make_llm_response(SAMPLE_QA_FAIL.model_dump_json(), total_tokens=400) - qa_pass = _make_llm_response(SAMPLE_QA_PASS.model_dump_json(), total_tokens=400) - mock_qa_llm.return_value.invoke.side_effect = [qa_fail, qa_pass] - + mock_arch_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_BLUEPRINT.model_dump_json(), total_tokens=800) + mock_dev_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_CODE, total_tokens=1200) + mock_qa_llm.return_value.invoke.side_effect = [_make_llm_response(SAMPLE_QA_FAIL.model_dump_json(), total_tokens=400), _make_llm_response(SAMPLE_QA_PASS.model_dump_json(), total_tokens=400)] result = run_task("Create email validator", enable_tracing=False) - assert result.status == WorkflowStatus.PASSED assert result.retry_count == 1 assert mock_dev_llm.return_value.invoke.call_count == 2 @@ -205,21 +182,12 @@ def test_retry_then_pass(self, mock_arch_llm, mock_dev_llm, mock_qa_llm, mock_me @patch("src.orchestrator._get_qa_llm") @patch("src.orchestrator._get_developer_llm") @patch("src.orchestrator._get_architect_llm") - def test_failure_report_included_in_retry( - self, mock_arch_llm, mock_dev_llm, mock_qa_llm, mock_memory - ): - """Developer's retry prompt should include the QA failure report.""" + def test_failure_report_included_in_retry(self, mock_arch_llm, mock_dev_llm, mock_qa_llm, mock_memory): mock_memory.return_value = [] - arch_response = _make_llm_response(SAMPLE_BLUEPRINT.model_dump_json()) - mock_arch_llm.return_value.invoke.return_value = arch_response - dev_response = _make_llm_response(SAMPLE_CODE) - mock_dev_llm.return_value.invoke.return_value = dev_response - qa_fail = _make_llm_response(SAMPLE_QA_FAIL.model_dump_json()) - qa_pass = _make_llm_response(SAMPLE_QA_PASS.model_dump_json()) - mock_qa_llm.return_value.invoke.side_effect = [qa_fail, qa_pass] - + mock_arch_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_BLUEPRINT.model_dump_json()) + mock_dev_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_CODE) + mock_qa_llm.return_value.invoke.side_effect = [_make_llm_response(SAMPLE_QA_FAIL.model_dump_json()), _make_llm_response(SAMPLE_QA_PASS.model_dump_json())] result = run_task("Create email validator", enable_tracing=False) - dev_calls = mock_dev_llm.return_value.invoke.call_args_list assert len(dev_calls) == 2 retry_msg = dev_calls[1][0][0][1].content @@ -227,31 +195,21 @@ def test_failure_report_included_in_retry( assert "Missing edge case for empty string" in retry_msg -# ── Escalation Path Tests ── +# -- Escalation Path Tests -- class TestE2EEscalation: - """QA escalates architectural failure → re-plans with Architect.""" @patch("src.orchestrator._fetch_memory_context") @patch("src.orchestrator._get_qa_llm") @patch("src.orchestrator._get_developer_llm") @patch("src.orchestrator._get_architect_llm") - def test_escalation_to_architect( - self, mock_arch_llm, mock_dev_llm, mock_qa_llm, mock_memory - ): - """QA escalates → Architect re-plans → Developer retries → QA passes.""" + def test_escalation_to_architect(self, mock_arch_llm, mock_dev_llm, mock_qa_llm, mock_memory): mock_memory.return_value = [] - arch_response = _make_llm_response(SAMPLE_BLUEPRINT.model_dump_json(), total_tokens=800) - mock_arch_llm.return_value.invoke.return_value = arch_response - dev_response = _make_llm_response(SAMPLE_CODE, total_tokens=1200) - mock_dev_llm.return_value.invoke.return_value = dev_response - qa_escalate = _make_llm_response(SAMPLE_QA_ESCALATE.model_dump_json(), total_tokens=400) - qa_pass = _make_llm_response(SAMPLE_QA_PASS.model_dump_json(), total_tokens=400) - mock_qa_llm.return_value.invoke.side_effect = [qa_escalate, qa_pass] - + mock_arch_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_BLUEPRINT.model_dump_json(), total_tokens=800) + mock_dev_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_CODE, total_tokens=1200) + mock_qa_llm.return_value.invoke.side_effect = [_make_llm_response(SAMPLE_QA_ESCALATE.model_dump_json(), total_tokens=400), _make_llm_response(SAMPLE_QA_PASS.model_dump_json(), total_tokens=400)] result = run_task("Create email validator", enable_tracing=False) - assert result.status == WorkflowStatus.PASSED assert mock_arch_llm.return_value.invoke.call_count == 2 assert mock_dev_llm.return_value.invoke.call_count == 2 @@ -261,21 +219,12 @@ def test_escalation_to_architect( @patch("src.orchestrator._get_qa_llm") @patch("src.orchestrator._get_developer_llm") @patch("src.orchestrator._get_architect_llm") - def test_escalation_includes_failure_in_architect_prompt( - self, mock_arch_llm, mock_dev_llm, mock_qa_llm, mock_memory - ): - """Re-plan prompt should include the QA escalation reason.""" + def test_escalation_includes_failure_in_architect_prompt(self, mock_arch_llm, mock_dev_llm, mock_qa_llm, mock_memory): mock_memory.return_value = [] - arch_response = _make_llm_response(SAMPLE_BLUEPRINT.model_dump_json()) - mock_arch_llm.return_value.invoke.return_value = arch_response - dev_response = _make_llm_response(SAMPLE_CODE) - mock_dev_llm.return_value.invoke.return_value = dev_response - qa_escalate = _make_llm_response(SAMPLE_QA_ESCALATE.model_dump_json()) - qa_pass = _make_llm_response(SAMPLE_QA_PASS.model_dump_json()) - mock_qa_llm.return_value.invoke.side_effect = [qa_escalate, qa_pass] - + mock_arch_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_BLUEPRINT.model_dump_json()) + mock_dev_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_CODE) + mock_qa_llm.return_value.invoke.side_effect = [_make_llm_response(SAMPLE_QA_ESCALATE.model_dump_json()), _make_llm_response(SAMPLE_QA_PASS.model_dump_json())] result = run_task("Create email validator", enable_tracing=False) - arch_calls = mock_arch_llm.return_value.invoke.call_args_list assert len(arch_calls) == 2 replan_msg = arch_calls[1][0][0][1].content @@ -283,30 +232,21 @@ def test_escalation_includes_failure_in_architect_prompt( assert "architectural" in replan_msg.lower() or "parsing library" in replan_msg -# ── Budget & Limit Tests ── +# -- Budget & Limit Tests -- class TestE2EBudgetLimits: - """Pipeline respects token budget and retry limits.""" @patch("src.orchestrator._fetch_memory_context") @patch("src.orchestrator._get_qa_llm") @patch("src.orchestrator._get_developer_llm") @patch("src.orchestrator._get_architect_llm") - def test_max_retries_stops_pipeline( - self, mock_arch_llm, mock_dev_llm, mock_qa_llm, mock_memory - ): - """Pipeline stops after MAX_RETRIES failures.""" + def test_max_retries_stops_pipeline(self, mock_arch_llm, mock_dev_llm, mock_qa_llm, mock_memory): mock_memory.return_value = [] - arch_response = _make_llm_response(SAMPLE_BLUEPRINT.model_dump_json(), total_tokens=100) - mock_arch_llm.return_value.invoke.return_value = arch_response - dev_response = _make_llm_response(SAMPLE_CODE, total_tokens=100) - mock_dev_llm.return_value.invoke.return_value = dev_response - qa_fail = _make_llm_response(SAMPLE_QA_FAIL.model_dump_json(), total_tokens=100) - mock_qa_llm.return_value.invoke.return_value = qa_fail - + mock_arch_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_BLUEPRINT.model_dump_json(), total_tokens=100) + mock_dev_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_CODE, total_tokens=100) + mock_qa_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_QA_FAIL.model_dump_json(), total_tokens=100) result = run_task("Create email validator", enable_tracing=False) - assert result.retry_count >= MAX_RETRIES assert result.status != WorkflowStatus.PASSED assert mock_dev_llm.return_value.invoke.call_count == MAX_RETRIES @@ -316,41 +256,27 @@ def test_max_retries_stops_pipeline( @patch("src.orchestrator._get_qa_llm") @patch("src.orchestrator._get_developer_llm") @patch("src.orchestrator._get_architect_llm") - def test_token_budget_stops_pipeline( - self, mock_arch_llm, mock_dev_llm, mock_qa_llm, mock_memory - ): - """Pipeline stops when token budget is exhausted.""" + def test_token_budget_stops_pipeline(self, mock_arch_llm, mock_dev_llm, mock_qa_llm, mock_memory): mock_memory.return_value = [] - arch_response = _make_llm_response(SAMPLE_BLUEPRINT.model_dump_json(), total_tokens=400) - mock_arch_llm.return_value.invoke.return_value = arch_response - dev_response = _make_llm_response(SAMPLE_CODE, total_tokens=400) - mock_dev_llm.return_value.invoke.return_value = dev_response - qa_fail = _make_llm_response(SAMPLE_QA_FAIL.model_dump_json(), total_tokens=400) - mock_qa_llm.return_value.invoke.return_value = qa_fail - + mock_arch_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_BLUEPRINT.model_dump_json(), total_tokens=400) + mock_dev_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_CODE, total_tokens=400) + mock_qa_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_QA_FAIL.model_dump_json(), total_tokens=400) result = run_task("Create email validator", enable_tracing=False) - assert result.tokens_used >= 1000 -# ── Node-Level Tests ── +# -- Node-Level Tests -- class TestE2ENodeFunctions: - """Test individual node functions with realistic mock data.""" @patch("src.orchestrator._fetch_memory_context") @patch("src.orchestrator._get_architect_llm") def test_architect_node_produces_blueprint(self, mock_llm, mock_memory): - """architect_node should parse LLM output into a Blueprint.""" mock_memory.return_value = ["Use Python 3.13"] - mock_llm.return_value.invoke.return_value = _make_llm_response( - SAMPLE_BLUEPRINT.model_dump_json(), total_tokens=500 - ) - + mock_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_BLUEPRINT.model_dump_json(), total_tokens=500) state = {"task_description": "Create email validator", "trace": [], "tokens_used": 0, "retry_count": 0} result = architect_node(state) - assert result["status"] == WorkflowStatus.BUILDING assert result["blueprint"].task_id == "e2e-test-001" assert len(result["blueprint"].target_files) == 1 @@ -359,127 +285,68 @@ def test_architect_node_produces_blueprint(self, mock_llm, mock_memory): @patch("src.orchestrator._get_developer_llm") def test_developer_node_generates_code(self, mock_llm): - """developer_node should generate code from the Blueprint.""" - mock_llm.return_value.invoke.return_value = _make_llm_response( - SAMPLE_CODE, total_tokens=1200 - ) - - state = { - "task_description": "Create email validator", - "blueprint": SAMPLE_BLUEPRINT, - "status": WorkflowStatus.BUILDING, - "trace": [], - "tokens_used": 0, - "retry_count": 0, - } + mock_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_CODE, total_tokens=1200) + state = {"task_description": "Create email validator", "blueprint": SAMPLE_BLUEPRINT, "status": WorkflowStatus.BUILDING, "trace": [], "tokens_used": 0, "retry_count": 0} result = developer_node(state) - assert result["status"] == WorkflowStatus.REVIEWING assert "validate_email" in result["generated_code"] assert result["tokens_used"] == 1200 @patch("src.orchestrator._get_qa_llm") def test_qa_node_returns_pass(self, mock_llm): - """qa_node should parse a 'pass' verdict correctly.""" - mock_llm.return_value.invoke.return_value = _make_llm_response( - SAMPLE_QA_PASS.model_dump_json(), total_tokens=400 - ) - - state = { - "task_description": "Create email validator", - "blueprint": SAMPLE_BLUEPRINT, - "generated_code": SAMPLE_CODE, - "status": WorkflowStatus.REVIEWING, - "trace": [], - "tokens_used": 0, - "retry_count": 0, - } + mock_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_QA_PASS.model_dump_json(), total_tokens=400) + state = {"task_description": "Create email validator", "blueprint": SAMPLE_BLUEPRINT, "generated_code": SAMPLE_CODE, "status": WorkflowStatus.REVIEWING, "trace": [], "tokens_used": 0, "retry_count": 0} result = qa_node(state) - assert result["status"] == WorkflowStatus.PASSED assert result["failure_report"].status == "pass" assert result["failure_report"].tests_passed == 4 @patch("src.orchestrator._get_qa_llm") def test_qa_node_returns_fail(self, mock_llm): - """qa_node should parse a 'fail' verdict and increment retry count.""" - mock_llm.return_value.invoke.return_value = _make_llm_response( - SAMPLE_QA_FAIL.model_dump_json(), total_tokens=400 - ) - - state = { - "task_description": "Create email validator", - "blueprint": SAMPLE_BLUEPRINT, - "generated_code": SAMPLE_CODE, - "status": WorkflowStatus.REVIEWING, - "retry_count": 0, - "trace": [], - "tokens_used": 0, - } + mock_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_QA_FAIL.model_dump_json(), total_tokens=400) + state = {"task_description": "Create email validator", "blueprint": SAMPLE_BLUEPRINT, "generated_code": SAMPLE_CODE, "status": WorkflowStatus.REVIEWING, "retry_count": 0, "trace": [], "tokens_used": 0} result = qa_node(state) - assert result["status"] == WorkflowStatus.REVIEWING assert result["failure_report"].status == "fail" assert result["retry_count"] == 1 @patch("src.orchestrator._get_qa_llm") def test_qa_node_returns_escalate(self, mock_llm): - """qa_node should detect architectural failures and set ESCALATED status.""" - mock_llm.return_value.invoke.return_value = _make_llm_response( - SAMPLE_QA_ESCALATE.model_dump_json(), total_tokens=400 - ) - - state = { - "task_description": "Create email validator", - "blueprint": SAMPLE_BLUEPRINT, - "generated_code": SAMPLE_CODE, - "status": WorkflowStatus.REVIEWING, - "retry_count": 0, - "trace": [], - "tokens_used": 0, - } + mock_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_QA_ESCALATE.model_dump_json(), total_tokens=400) + state = {"task_description": "Create email validator", "blueprint": SAMPLE_BLUEPRINT, "generated_code": SAMPLE_CODE, "status": WorkflowStatus.REVIEWING, "retry_count": 0, "trace": [], "tokens_used": 0} result = qa_node(state) - assert result["status"] == WorkflowStatus.ESCALATED assert result["failure_report"].is_architectural is True @patch("src.orchestrator._get_architect_llm") @patch("src.orchestrator._fetch_memory_context") def test_architect_handles_json_in_code_fence(self, mock_memory, mock_llm): - """Architect should handle LLM wrapping JSON in ```json code fences.""" mock_memory.return_value = [] fenced_json = f"```json\n{SAMPLE_BLUEPRINT.model_dump_json()}\n```" mock_llm.return_value.invoke.return_value = _make_llm_response(fenced_json) - state = {"task_description": "Create email validator", "trace": [], "tokens_used": 0, "retry_count": 0} result = architect_node(state) - assert result["status"] == WorkflowStatus.BUILDING assert result["blueprint"].task_id == "e2e-test-001" -# ── Memory Integration Tests ── +# -- Memory Integration Tests -- class TestE2EMemoryIntegration: - """Test that the pipeline correctly queries and uses Chroma memory.""" def test_memory_query_with_real_chroma(self, tmp_path): - """Verify _fetch_memory_context queries Chroma and returns results.""" store = ChromaMemoryStore(persist_dir=str(tmp_path / "chroma"), collection_name="e2e_test") store.add_l0_core("Project language: Python 3.13") store.add_l0_core("Always use type hints") store.add_l1("Auth module uses JWT tokens", module="auth") - results = store.query("What programming language does this project use?") assert len(results) > 0 assert any("Python" in r.content for r in results) def test_memory_context_survives_pipeline(self, tmp_path): - """Memory context retrieved early in pipeline should be in final state.""" store = ChromaMemoryStore(persist_dir=str(tmp_path / "chroma"), collection_name="e2e_mem") store.add_l0_core("Test memory entry") - with patch("src.memory.chroma_store.ChromaMemoryStore") as MockStore: MockStore.return_value = store context = [] @@ -488,32 +355,23 @@ def test_memory_context_survives_pipeline(self, tmp_path): context = [r.content for r in results] except Exception: pass - assert len(context) >= 1 -# ── Tracing Integration Tests ── +# -- Tracing Integration Tests -- class TestE2ETracingIntegration: - """Test that tracing is properly wired through the pipeline.""" @patch("src.orchestrator._fetch_memory_context") @patch("src.orchestrator._get_qa_llm") @patch("src.orchestrator._get_developer_llm") @patch("src.orchestrator._get_architect_llm") - def test_tracing_disabled_runs_cleanly( - self, mock_arch_llm, mock_dev_llm, mock_qa_llm, mock_memory - ): - """Pipeline should work with tracing explicitly disabled.""" + def test_tracing_disabled_runs_cleanly(self, mock_arch_llm, mock_dev_llm, mock_qa_llm, mock_memory): mock_memory.return_value = [] - arch_response = _make_llm_response(SAMPLE_BLUEPRINT.model_dump_json()) - mock_arch_llm.return_value.invoke.return_value = arch_response - dev_response = _make_llm_response(SAMPLE_CODE) - mock_dev_llm.return_value.invoke.return_value = dev_response - qa_response = _make_llm_response(SAMPLE_QA_PASS.model_dump_json()) - mock_qa_llm.return_value.invoke.return_value = qa_response - + mock_arch_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_BLUEPRINT.model_dump_json()) + mock_dev_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_CODE) + mock_qa_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_QA_PASS.model_dump_json()) result = run_task("Create email validator", enable_tracing=False) assert result.status == WorkflowStatus.PASSED @@ -522,18 +380,11 @@ def test_tracing_disabled_runs_cleanly( @patch("src.orchestrator._get_qa_llm") @patch("src.orchestrator._get_developer_llm") @patch("src.orchestrator._get_architect_llm") - def test_tracing_enabled_calls_trace_config( - self, mock_arch_llm, mock_dev_llm, mock_qa_llm, mock_memory, mock_trace - ): - """When tracing is enabled, create_trace_config should be called.""" + def test_tracing_enabled_calls_trace_config(self, mock_arch_llm, mock_dev_llm, mock_qa_llm, mock_memory, mock_trace): mock_memory.return_value = [] mock_trace.return_value = MagicMock(callbacks=[], enabled=False, trace_id=None, flush=MagicMock()) - arch_response = _make_llm_response(SAMPLE_BLUEPRINT.model_dump_json()) - mock_arch_llm.return_value.invoke.return_value = arch_response - dev_response = _make_llm_response(SAMPLE_CODE) - mock_dev_llm.return_value.invoke.return_value = dev_response - qa_response = _make_llm_response(SAMPLE_QA_PASS.model_dump_json()) - mock_qa_llm.return_value.invoke.return_value = qa_response - + mock_arch_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_BLUEPRINT.model_dump_json()) + mock_dev_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_CODE) + mock_qa_llm.return_value.invoke.return_value = _make_llm_response(SAMPLE_QA_PASS.model_dump_json()) result = run_task("Create email validator", enable_tracing=True) mock_trace.assert_called_once() diff --git a/dev-suite/tests/test_tool_binding.py b/dev-suite/tests/test_tool_binding.py new file mode 100644 index 0000000..b4c6718 --- /dev/null +++ b/dev-suite/tests/test_tool_binding.py @@ -0,0 +1,420 @@ +"""Tests for issue #80: Agent tool binding. + +Tests cover: +- Tool extraction from RunnableConfig +- Tool filtering by agent role +- Tool execution (success and failure) +- Tool loop iteration and termination +- Developer node with tools (agentic mode) +- Developer node without tools (single-shot fallback) +- QA node with tools (read-only subset) +- QA node without tools (single-shot fallback) +- init_tools_config() factory +- TOOL_CALL event type in events.py +- GraphState tool_calls_log field +- _sanitize_preview helper (Fix 4) +- max_turns <= 0 guard (Fix 5) +""" + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +class FakeTool: + """Mock tool that supports both ainvoke (async) and invoke (sync) APIs.""" + + def __init__(self, name, result="ok", should_raise=False): + self.name = name + self._result = result + self._should_raise = should_raise + self.invoke_count = 0 + + async def ainvoke(self, args): + """Async public API (preferred by _execute_tool_call).""" + self.invoke_count += 1 + if self._should_raise: + raise RuntimeError(f"Tool {self.name} failed") + return self._result + + def invoke(self, args): + """Sync fallback API.""" + self.invoke_count += 1 + if self._should_raise: + raise RuntimeError(f"Tool {self.name} failed") + return self._result + + +class SyncOnlyFakeTool: + """Mock tool that only has invoke (no ainvoke) for testing sync fallback.""" + + def __init__(self, name, result="ok", should_raise=False): + self.name = name + self._result = result + self._should_raise = should_raise + + def invoke(self, args): + if self._should_raise: + raise RuntimeError(f"Tool {self.name} failed") + return self._result + + +class FakeResponse: + def __init__(self, content="test response", tool_calls=None, usage=None): + self.content = content + self.tool_calls = tool_calls or [] + self.usage_metadata = usage or {"input_tokens": 100, "output_tokens": 50} + + +def make_config(tools=None): + return {"configurable": {"tools": tools or []}} + + +class TestGetAgentTools: + def test_returns_empty_when_no_config(self): + from src.orchestrator import _get_agent_tools + assert _get_agent_tools(None) == [] + + def test_returns_empty_when_no_configurable(self): + from src.orchestrator import _get_agent_tools + assert _get_agent_tools({}) == [] + + def test_returns_empty_when_no_tools_key(self): + from src.orchestrator import _get_agent_tools + assert _get_agent_tools({"configurable": {}}) == [] + + def test_returns_all_tools_when_no_filter(self): + from src.orchestrator import _get_agent_tools + tools = [FakeTool("a"), FakeTool("b"), FakeTool("c")] + assert len(_get_agent_tools(make_config(tools))) == 3 + + def test_filters_by_allowed_names(self): + from src.orchestrator import _get_agent_tools + tools = [FakeTool("filesystem_read"), FakeTool("filesystem_write"), FakeTool("github_read_diff")] + result = _get_agent_tools(make_config(tools), {"filesystem_read", "github_read_diff"}) + assert {t.name for t in result} == {"filesystem_read", "github_read_diff"} + + def test_filters_returns_empty_when_no_match(self): + from src.orchestrator import _get_agent_tools + assert _get_agent_tools(make_config([FakeTool("filesystem_read")]), {"nonexistent"}) == [] + + def test_dev_tool_names_subset(self): + from src.orchestrator import DEV_TOOL_NAMES, QA_TOOL_NAMES + assert QA_TOOL_NAMES.issubset(DEV_TOOL_NAMES) + + def test_qa_has_no_write_tools(self): + from src.orchestrator import QA_TOOL_NAMES + for name in QA_TOOL_NAMES: + assert "write" not in name.lower() + assert "create" not in name.lower() + + def test_dev_has_no_pr_creation(self): + """Fix 2: github_create_pr removed from DEV_TOOL_NAMES.""" + from src.orchestrator import DEV_TOOL_NAMES + assert "github_create_pr" not in DEV_TOOL_NAMES + + +class TestExecuteToolCall: + @pytest.mark.asyncio + async def test_executes_tool_successfully(self): + from src.orchestrator import _execute_tool_call + tool = FakeTool("filesystem_read", result="file contents here") + result = await _execute_tool_call({"name": "filesystem_read", "args": {"path": "test.py"}, "id": "call_1"}, [tool]) + assert result.content == "file contents here" + assert result.tool_call_id == "call_1" + + @pytest.mark.asyncio + async def test_returns_error_for_unknown_tool(self): + from src.orchestrator import _execute_tool_call + result = await _execute_tool_call({"name": "nonexistent", "args": {}, "id": "call_2"}, [FakeTool("other")]) + assert "Error" in result.content + + @pytest.mark.asyncio + async def test_handles_tool_exception(self): + from src.orchestrator import _execute_tool_call + result = await _execute_tool_call({"name": "failing_tool", "args": {}, "id": "call_3"}, [FakeTool("failing_tool", should_raise=True)]) + assert "Error executing failing_tool" in result.content + + @pytest.mark.asyncio + async def test_prefers_async_ainvoke(self): + """Fix 3: _execute_tool_call uses ainvoke (public API) not tool.coroutine.""" + from src.orchestrator import _execute_tool_call + tool = FakeTool("async_tool", result="async result") + result = await _execute_tool_call({"name": "async_tool", "args": {}, "id": "call_4"}, [tool]) + assert result.content == "async result" + assert tool.invoke_count == 1 + + @pytest.mark.asyncio + async def test_falls_back_to_sync_invoke(self): + """Fix 3: Falls back to invoke() when ainvoke is not available.""" + from src.orchestrator import _execute_tool_call + tool = SyncOnlyFakeTool("sync_tool", result="sync result") + result = await _execute_tool_call({"name": "sync_tool", "args": {}, "id": "call_5"}, [tool]) + assert result.content == "sync result" + + +class TestRunToolLoop: + @pytest.mark.asyncio + async def test_no_tool_calls_returns_immediately(self): + from src.orchestrator import _run_tool_loop + mock_llm = AsyncMock() + mock_llm.ainvoke.return_value = FakeResponse(content="final answer") + response, tokens, log = await _run_tool_loop(mock_llm, [], []) + assert response.content == "final answer" + assert mock_llm.ainvoke.call_count == 1 + assert len(log) == 0 + + @pytest.mark.asyncio + async def test_single_tool_turn(self): + from src.orchestrator import _run_tool_loop + tool = FakeTool("filesystem_read", result="contents") + first_resp = FakeResponse(content="") + first_resp.tool_calls = [{"name": "filesystem_read", "args": {"path": "f.py"}, "id": "tc1"}] + mock_llm = AsyncMock() + mock_llm.ainvoke.side_effect = [first_resp, FakeResponse(content="done")] + response, tokens, log = await _run_tool_loop(mock_llm, [], [tool]) + assert response.content == "done" + assert len(log) == 1 + assert log[0]["success"] is True + + @pytest.mark.asyncio + async def test_max_turns_limit(self): + from src.orchestrator import _run_tool_loop + resp = FakeResponse(content="") + resp.tool_calls = [{"name": "filesystem_read", "args": {}, "id": "tc"}] + mock_llm = AsyncMock() + mock_llm.ainvoke.return_value = resp + _, _, log = await _run_tool_loop(mock_llm, [], [FakeTool("filesystem_read")], max_turns=3) + assert mock_llm.ainvoke.call_count == 3 + assert len(log) == 3 + + @pytest.mark.asyncio + async def test_accumulates_tokens(self): + from src.orchestrator import _run_tool_loop + mock_llm = AsyncMock() + mock_llm.ainvoke.return_value = FakeResponse(content="answer", usage={"input_tokens": 200, "output_tokens": 100}) + _, tokens, _ = await _run_tool_loop(mock_llm, [], [], tokens_used=500) + assert tokens == 800 + + @pytest.mark.asyncio + async def test_tool_failure_logged(self): + from src.orchestrator import _run_tool_loop + first_resp = FakeResponse(content="") + first_resp.tool_calls = [{"name": "bad_tool", "args": {}, "id": "tc1"}] + mock_llm = AsyncMock() + mock_llm.ainvoke.side_effect = [first_resp, FakeResponse(content="recovered")] + _, _, log = await _run_tool_loop(mock_llm, [], [FakeTool("bad_tool", should_raise=True)]) + assert len(log) == 1 + assert log[0]["success"] is False + + @pytest.mark.asyncio + async def test_max_turns_zero_returns_immediately(self): + """Fix 5: max_turns <= 0 should not crash with unbound response.""" + from src.orchestrator import _run_tool_loop + from langchain_core.messages import HumanMessage + mock_llm = AsyncMock() + msg = HumanMessage(content="test") + response, tokens, log = await _run_tool_loop(mock_llm, [msg], [], max_turns=0, tokens_used=100) + assert mock_llm.ainvoke.call_count == 0 + assert tokens == 100 + assert len(log) == 0 + + +class TestDeveloperNodeTools: + @pytest.fixture + def blueprint(self): + from src.agents.architect import Blueprint + return Blueprint(task_id="test-task", target_files=["src/main.py"], instructions="Write a main function", constraints=["Use type hints"], acceptance_criteria=["Function exists"]) + + @pytest.fixture + def base_state(self, blueprint): + return {"task_description": "Test task", "blueprint": blueprint, "generated_code": "", "failure_report": None, "status": "building", "retry_count": 0, "tokens_used": 0, "error_message": "", "memory_context": [], "memory_writes": [], "trace": [], "sandbox_result": None, "parsed_files": [], "tool_calls_log": []} + + def test_dev_with_tools_uses_bind_tools(self, base_state): + """When tools are in config, developer should bind them to the LLM.""" + from src.orchestrator import developer_node + tools = [FakeTool("filesystem_read"), FakeTool("filesystem_write")] + mock_llm = MagicMock() + mock_llm_bound = AsyncMock() + mock_llm.bind_tools.return_value = mock_llm_bound + mock_llm_bound.ainvoke.return_value = FakeResponse(content="# --- FILE: src/main.py ---\ndef main(): pass") + with patch("src.orchestrator._get_developer_llm", return_value=mock_llm): + result = developer_node(base_state, make_config(tools)) + mock_llm.bind_tools.assert_called_once() + assert result["generated_code"] != "" + assert result["status"].value == "reviewing" + + def test_dev_without_tools_single_shot(self, base_state): + from src.orchestrator import developer_node + mock_llm = MagicMock() + mock_llm.invoke.return_value = FakeResponse(content="# --- FILE: src/main.py ---\ndef main(): pass") + with patch("src.orchestrator._get_developer_llm", return_value=mock_llm): + result = developer_node(base_state, make_config([])) + mock_llm.invoke.assert_called_once() + mock_llm.bind_tools.assert_not_called() + + def test_dev_no_config_single_shot(self, base_state): + from src.orchestrator import developer_node + mock_llm = MagicMock() + mock_llm.invoke.return_value = FakeResponse(content="code here") + with patch("src.orchestrator._get_developer_llm", return_value=mock_llm): + result = developer_node(base_state, None) + mock_llm.invoke.assert_called_once() + + def test_dev_tool_calls_logged_in_state(self, base_state): + from src.orchestrator import developer_node + tool = FakeTool("filesystem_read", result="existing code") + first_resp = FakeResponse(content="") + first_resp.tool_calls = [{"name": "filesystem_read", "args": {"path": "x"}, "id": "tc1"}] + mock_llm = MagicMock() + mock_llm_bound = AsyncMock() + mock_llm.bind_tools.return_value = mock_llm_bound + mock_llm_bound.ainvoke.side_effect = [first_resp, FakeResponse(content="# --- FILE: src/main.py ---\ncode")] + with patch("src.orchestrator._get_developer_llm", return_value=mock_llm): + result = developer_node(base_state, make_config([tool])) + assert len(result["tool_calls_log"]) == 1 + assert result["tool_calls_log"][0]["tool"] == "filesystem_read" + assert result["tool_calls_log"][0]["agent"] == "developer" + + def test_dev_tools_filtered_to_dev_set(self, base_state): + """Fix 2: DEV_TOOL_NAMES no longer includes github_create_pr.""" + from src.orchestrator import DEV_TOOL_NAMES, developer_node + all_tools = [FakeTool(n) for n in ["filesystem_read", "filesystem_write", "filesystem_list", "github_read_diff", "github_create_pr", "unexpected_tool"]] + mock_llm = MagicMock() + mock_llm_bound = AsyncMock() + mock_llm.bind_tools.return_value = mock_llm_bound + mock_llm_bound.ainvoke.return_value = FakeResponse(content="code") + with patch("src.orchestrator._get_developer_llm", return_value=mock_llm): + developer_node(base_state, make_config(all_tools)) + bound_names = {t.name for t in mock_llm.bind_tools.call_args[0][0]} + assert bound_names == DEV_TOOL_NAMES + assert "github_create_pr" not in bound_names + + +class TestQANodeTools: + @pytest.fixture + def qa_state(self): + from src.agents.architect import Blueprint + bp = Blueprint(task_id="test-qa", target_files=["src/main.py"], instructions="Implement feature", constraints=[], acceptance_criteria=["Tests pass"]) + return {"task_description": "Test task", "blueprint": bp, "generated_code": "# --- FILE: src/main.py ---\ndef main(): pass", "failure_report": None, "status": "reviewing", "retry_count": 0, "tokens_used": 0, "error_message": "", "memory_context": [], "memory_writes": [], "trace": [], "sandbox_result": None, "parsed_files": [], "tool_calls_log": []} + + def test_qa_with_tools_gets_read_only(self, qa_state): + from src.orchestrator import QA_TOOL_NAMES, qa_node + all_tools = [FakeTool(n) for n in ["filesystem_read", "filesystem_write", "filesystem_list", "github_read_diff", "github_create_pr"]] + mock_llm = MagicMock() + mock_llm_bound = AsyncMock() + mock_llm.bind_tools.return_value = mock_llm_bound + qa_json = json.dumps({"task_id": "test-qa", "status": "pass", "tests_passed": 1, "tests_failed": 0, "errors": [], "failed_files": [], "is_architectural": False, "failure_type": None, "recommendation": "All good"}) + mock_llm_bound.ainvoke.return_value = FakeResponse(content=qa_json) + with patch("src.orchestrator._get_qa_llm", return_value=mock_llm): + qa_node(qa_state, make_config(all_tools)) + bound_names = {t.name for t in mock_llm.bind_tools.call_args[0][0]} + assert bound_names == QA_TOOL_NAMES + assert "filesystem_write" not in bound_names + + def test_qa_without_tools(self, qa_state): + from src.orchestrator import qa_node + mock_llm = MagicMock() + qa_json = json.dumps({"task_id": "test-qa", "status": "pass", "tests_passed": 1, "tests_failed": 0, "errors": [], "failed_files": [], "is_architectural": False, "failure_type": None, "recommendation": "All good"}) + mock_llm.invoke.return_value = FakeResponse(content=qa_json) + with patch("src.orchestrator._get_qa_llm", return_value=mock_llm): + result = qa_node(qa_state, make_config([])) + mock_llm.invoke.assert_called_once() + assert result["failure_report"].status == "pass" + + +class TestInitToolsConfig: + def test_returns_empty_when_no_config_file(self, tmp_path): + from src.orchestrator import init_tools_config + assert init_tools_config(workspace_root=tmp_path) == {"configurable": {"tools": []}} + + def test_returns_empty_on_exception(self, tmp_path): + from src.orchestrator import init_tools_config + (tmp_path / "mcp-config.json").write_text("not valid json") + assert init_tools_config(workspace_root=tmp_path) == {"configurable": {"tools": []}} + + def test_loads_tools_from_valid_config(self, tmp_path): + from src.orchestrator import init_tools_config + (tmp_path / "mcp-config.json").write_text(json.dumps({"servers": {"filesystem": {"version": "1.0"}}, "last_reviewed": "2026-03-01"})) + mock_tools = [FakeTool("filesystem_read")] + with patch("src.tools.load_mcp_config") as mock_load, patch("src.tools.create_provider") as mock_create, patch("src.tools.get_tools", return_value=mock_tools): + result = init_tools_config(workspace_root=tmp_path) + assert len(result["configurable"]["tools"]) == 1 + + +class TestSanitizePreview: + """Fix 4: _sanitize_preview helper tests.""" + + def test_empty_string(self): + from src.orchestrator import _sanitize_preview + assert _sanitize_preview("") == "" + + def test_normal_text_unchanged(self): + from src.orchestrator import _sanitize_preview + assert _sanitize_preview("normal tool output") == "normal tool output" + + def test_truncates_long_text(self): + from src.orchestrator import _sanitize_preview + result = _sanitize_preview("x" * 500) + assert len(result) <= 204 + + def test_redacts_api_key_pattern(self): + from src.orchestrator import _sanitize_preview + assert "[REDACTED]" in _sanitize_preview("key is sk_test_abc123def456ghi789") + + def test_redacts_github_token(self): + from src.orchestrator import _sanitize_preview + assert "[REDACTED]" in _sanitize_preview("token: ghp_abcdefghijklmnopqrstuvwxyz12345678") + + def test_redacts_jwt_pattern(self): + from src.orchestrator import _sanitize_preview + assert "[REDACTED]" in _sanitize_preview("bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9") + + +class TestToolCallEvent: + def test_tool_call_event_type_exists(self): + from src.api.events import EventType + assert hasattr(EventType, "TOOL_CALL") + assert EventType.TOOL_CALL.value == "tool_call" + + def test_can_create_tool_call_sse_event(self): + from src.api.events import EventType, SSEEvent + event = SSEEvent(type=EventType.TOOL_CALL, data={"task_id": "test", "agent": "developer", "tool": "filesystem_read", "success": True}) + assert event.type == EventType.TOOL_CALL + + @pytest.mark.asyncio + async def test_tool_call_event_published(self): + from src.api.events import EventBus, EventType, SSEEvent + bus = EventBus() + queue = await bus.subscribe() + await bus.publish(SSEEvent(type=EventType.TOOL_CALL, data={"agent": "dev", "tool": "filesystem_read"})) + received = queue.get_nowait() + assert received.type == EventType.TOOL_CALL + + +class TestGraphStateToolCallsLog: + def test_graph_state_has_tool_calls_log(self): + from src.orchestrator import GraphState + assert "tool_calls_log" in GraphState.__annotations__ + + def test_agent_state_has_tool_calls_log(self): + from src.orchestrator import AgentState + assert AgentState().tool_calls_log == [] + + def test_agent_state_with_tool_calls(self): + from src.orchestrator import AgentState + state = AgentState(tool_calls_log=[{"agent": "dev", "tool": "filesystem_read", "success": True}]) + assert len(state.tool_calls_log) == 1 + + +class TestMaxToolTurns: + def test_max_tool_turns_default(self): + from src.orchestrator import MAX_TOOL_TURNS + assert MAX_TOOL_TURNS == 10 + + def test_max_tool_turns_from_env(self, monkeypatch): + monkeypatch.setenv("MAX_TOOL_TURNS", "5") + from src.orchestrator import _safe_int + assert _safe_int("MAX_TOOL_TURNS", 10) == 5