diff --git a/README.md b/README.md
index 46f4679..08a8765 100644
--- a/README.md
+++ b/README.md
@@ -31,6 +31,47 @@ We then extend `KernelBenchEnv` to support:
 - **Batching**: `KernelBenchEnvGroupBuilder` groups multiple rollouts for the same problem, enabling **GRPO-style** training where rewards are normalized within groups.
 - **Dataset Construction**: `KernelBenchDatasetBuilder` handles the iteration over KernelBench levels and problems, partitioning them into training and evaluation sets. You are welcome to extend it to support more problems beyond what is currently in KernelBench.
 
+### Multi-Turn RL
+
+We extend the single-turn pipeline with multi-turn iterative refinement, following the approach in [Kevin](https://arxiv.org/abs/2507.11948). Instead of generating one kernel per problem, the model generates a kernel, receives evaluation feedback (compilation errors, correctness failures, or speedup results), and refines its solution over multiple turns.
+
+`MultiTurnKernelBenchEnv` manages the multi-turn loop:
+- **History management**: Prior turns (prompt, response, feedback) are kept in context, with token-based truncation to stay within the context window.
+- **Evaluation feedback**: Structured feedback tells the model what went wrong (compilation error, incorrect output, or correct but slow) so it can fix specific issues.
+- **Early stopping**: Optionally stop the episode when the kernel passes all correctness tests.
+
+Training uses GRPO with discounted returns across turns (a sketch follows below):
+- Per-turn scores are computed as `S = 0.3 * correct + speedup`, where the speedup term counts only for correct kernels.
+- Discounted returns: `R_t = S_t + γ * R_{t+1}` (backward recursion, γ = 0.4 by default).
+- Advantages are normalized across all `group_size × max_turns` turn-level samples: `(R - mean) / (std + ε)`.
+- PPO with asymmetric clipping (Clip-Higher, ε_low = 0.2, ε_high = 0.28) and constant length normalization.
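+
+As a minimal sketch of the `"sum"`-mode return computation (illustrative helper, not the repo's API — the actual logic lives in `apply_discounted_returns_to_trajectories` in `training/multiturn.py`):
+
+```python
+def discounted_returns(scores: list[float], gamma: float = 0.4) -> list[float]:
+    """R_t = S_t + gamma * R_{t+1}, computed by backward recursion."""
+    returns = [0.0] * len(scores)
+    running = 0.0
+    for t in reversed(range(len(scores))):
+        running = scores[t] + gamma * running
+        returns[t] = running
+    return returns
+
+# Example: wrong kernel, then correct (1.0x), then correct with 1.5x speedup.
+# S = 0.3 * correct + speedup
+print(discounted_returns([0.0, 1.3, 1.8]))  # ≈ [0.808, 2.02, 1.8]
+```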
+
+Enable multi-turn via config:
+```yaml
+multiturn:
+  enabled: true
+  max_turns: 4        # Refinement turns per trajectory
+  gamma: 0.4          # Discount factor
+  aggregation: "sum"  # "sum" or "max"
+```
+
+Or via CLI:
+```bash
+uv run python -m kernelbench_tinker.scripts.train_kernel_rl \
+  --config src/kernelbench_tinker/config/rl_kernelbench.yaml \
+  multiturn.enabled=true \
+  log_path=./runs/my_multiturn_experiment
+```
+
+Multi-turn inference is also supported via the eval script:
+```bash
+uv run python -m kernelbench_tinker.scripts.eval_kernel_rl \
+  checkpoint_path= \
+  multiturn_enabled=true \
+  multiturn_max_turns=8 \
+  level=1
+```
+
 ### Directory Structure
 
 ```text
@@ -54,6 +95,7 @@ src/kernelbench_tinker/
   envs/
     kernelbench_client.py         # KernelBench Python API wrapper
     kernelbench_env.py            # Single-turn RL environment
+    multiturn_kernelbench_env.py  # Multi-turn RL environment
   training/
     models.py                     # Model/renderer configuration
     reward.py                     # Reward shaping
@@ -282,7 +324,6 @@ Note the scope of this repo is an open-source implementation of KernelBench-Tinker
 * More reward examples leveraging more fine-grained metrics
 * More reward hack checking
-* Multi-turn RL to have denser reward signal like [Kevin](https://arxiv.org/abs/2507.11948)
 * Improve Step time and training efficiency

diff --git a/src/kernelbench_tinker/config/configs.py b/src/kernelbench_tinker/config/configs.py
index 0144250..40ad97f 100644
--- a/src/kernelbench_tinker/config/configs.py
+++ b/src/kernelbench_tinker/config/configs.py
@@ -81,3 +81,57 @@ class DatasetConfig:
 
     # Train/test split
     test_fraction: float = 0.1
+
+
+@dataclass
+class MultiTurnConfig:
+    """
+    Configuration for multi-turn RL training.
+
+    Controls the iterative refinement loop where the model receives
+    evaluation feedback and can fix errors across multiple turns.
+    """
+
+    # Enable multi-turn mode (False = single-turn)
+    enabled: bool = False
+
+    # Maximum refinement turns per trajectory
+    max_turns: int = 4
+
+    # Discount factor for multi-turn returns: R_t = S_t + gamma * R_{t+1}
+    gamma: float = 0.4
+
+    # Return aggregation mode: "sum" or "max"
+    #   sum: R_t = Σ γ^(i-t) × S_i     (rewards turns leading to many good kernels)
+    #   max: R_t = max{ γ^(i-t) × S_i }  (rewards turns leading to one great kernel)
+    aggregation: str = "sum"
+
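+    # Worked example (illustrative): step scores S = [0.0, 0.3, 1.3] with gamma = 0.4
+    #   sum: R = [0.328, 0.82, 1.3]   since R_1 = 0.3 + 0.4 * 1.3 = 0.82
+    #   max: R = [0.208, 0.52, 1.3]   since R_1 = max(0.3, 0.4 * 1.3) = 0.52
+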
+    # Stop the episode early when the kernel is correct.
+    # Default False for training: the model needs post-correctness turns to
+    # learn speedup optimization. Set True at eval time if desired.
+    early_stop_on_correct: bool = False
+
+    # Optional: require this speedup before early stopping
+    speedup_threshold: float | None = None
+
+    # Prompt
+    prompt_max_tokens: int | None = None  # Token budget for history truncation (None = char fallback)
+    inject_think_token: bool = False      # Append "<think>\n" to generation prompts
+
+    # Generation
+    temperature: float = 0.9
+    top_p: float = 1.0
+    seed: int | None = None
+
+    # Response length extension mid-training (0 = disabled)
+    max_tokens_extended: int = 22000
+    max_tokens_extend_after_step: int = 30
+
+    # Training
+    loss_fn: str = "ppo"
+    max_grad_norm: float = 0.05
+    warmup_ratio: float = 0.03
+    clip_epsilon_low: float = 0.2
+    clip_epsilon_high: float = 0.28
+    constant_length_norm: int = 16384
+    num_substeps: int = 2

diff --git a/src/kernelbench_tinker/config/rl_kernelbench.yaml b/src/kernelbench_tinker/config/rl_kernelbench.yaml
index bda2995..38bc2f0 100644
--- a/src/kernelbench_tinker/config/rl_kernelbench.yaml
+++ b/src/kernelbench_tinker/config/rl_kernelbench.yaml
@@ -26,6 +26,33 @@
 learning_rate: 0.000002  # 2e-6 as explicit float
 max_tokens: 16384
 temperature: 1.0
 
+# =============================================================================
+# Multi-turn Configuration (disabled by default)
+# =============================================================================
+multiturn:
+  enabled: false                    # true to enable iterative refinement
+  max_turns: 4                      # Maximum refinement turns per trajectory
+  gamma: 0.4                        # Discount factor for multi-turn returns
+  aggregation: "sum"                # "sum" (reward many good kernels) or "max" (reward one great kernel)
+  early_stop_on_correct: false      # Stop episode when kernel passes all tests
+  speedup_threshold: null           # Required speedup before early stopping (null = any correct)
+  # Prompt
+  prompt_max_tokens: null           # Token budget for history truncation (null = char fallback)
+  inject_think_token: false         # Append "<think>\n" to generation prompts
+  # Generation
+  temperature: 0.9                  # Generation temperature
+  top_p: 1.0                        # Nucleus sampling (1.0 = disabled)
+  seed: null                        # Random seed for generation (null = random)
+  max_tokens_extended: 22000        # Extend max_tokens mid-training (0 = disabled)
+  max_tokens_extend_after_step: 30  # Step at which to switch
+  # Training
+  loss_fn: "ppo"                    # Loss function (single-turn uses top-level loss_fn)
+  max_grad_norm: 0.05               # Gradient clipping (0.0 = disabled)
+  warmup_ratio: 0.03                # Linear LR warmup fraction
+  clip_epsilon_low: 0.2             # PPO clip lower bound
+  clip_epsilon_high: 0.28           # PPO clip upper bound (Clip-Higher)
+  constant_length_norm: 16384       # GRPO constant length normalization (0 = disabled)
+
 # =============================================================================
 # Training Configuration
 # =============================================================================
@@ -57,6 +84,7 @@ dataset_builder:
   # Problem Selection
   # ---------------------------------------------------------------------------
   level: 1                     # KernelBench level (1, 2, 3, or 4)
+  levels: null                 # Train on multiple levels (e.g. [1, 2]); overrides level when set
   start_problem: null          # First problem ID (null = start from 1)
   end_problem: null            # Last problem ID (null = all problems)
   dataset_src: "huggingface"   # "huggingface" or "local"
@@ -107,6 +135,9 @@ dataset_builder:
   reward_correctness_weight: 0.3
   reward_speed_weight: 1.0
   reward_length_weight: 0.0
+  reward_speed_max_reward: 10.0  # Cap on speed reward component (set high to uncap)
+  reward_clip_min: null          # Lower bound on total reward (null = no clipping)
+  reward_clip_max: null          # Upper bound on total reward (null = no clipping)
 
   # ---------------------------------------------------------------------------
   # Reward Hacking Detection (Static Checker)

diff --git a/src/kernelbench_tinker/envs/env_utils.py b/src/kernelbench_tinker/envs/env_utils.py
new file mode 100644
index 0000000..562e20b
--- /dev/null
+++ b/src/kernelbench_tinker/envs/env_utils.py
@@ -0,0 +1,150 @@
+"""
+Shared utilities for KernelBench environments.
+
+Contains helpers used by both the single-turn and multi-turn environments:
+- System prompt construction
+- Step evaluation (parse → evaluate → reward → metrics)
+"""
+
+from __future__ import annotations
+
+import logging
+import time
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from tinker_cookbook import renderers
+from tinker_cookbook.rl.types import Action, Metrics
+
+from kernelbench_tinker.config.configs import EvalConfig
+from kernelbench_tinker.envs.kernelbench_client import (
+    KernelBenchProblem,
+    KernelEvalResult,
+    ParsedResponse,
+    evaluate_kernel_async,
+    parse_structured_response,
+)
+from kernelbench_tinker.training.reward import (
+    RewardConfig,
+    compute_reward,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class EvalStepResult:
+    """Result from evaluate_step(), shared by single-turn and multi-turn envs."""
+
+    parsed: ParsedResponse
+    eval_result: KernelEvalResult
+    format_ok: bool
+    kernel_code: str
+    reward: float
+    metrics: Metrics
+    response_text: str  # Raw response content from renderer (before structured parsing)
+
+
+def build_system_prompt(backend: str) -> str:
+    """Build a backend-specific system prompt for kernel generation.
+
+    Used by both single-turn and multi-turn environments.
+    """
+    return (
+        f"You are an expert GPU kernel developer. Your task is to optimize PyTorch "
+        f"operations by writing efficient custom {backend.upper()} kernels.\n"
+        f"\n"
+        f"When given a PyTorch model, write an optimized kernel implementation.\n"
+        f"\n"
+        f"Your solution must:\n"
+        f"- Be a drop-in replacement as a class named `ModelNew`\n"
+        f"- Use custom {backend.upper()} kernels, not just PyTorch operations\n"
+        f"- Be correct and produce the same results as the reference\n"
+        f"\n"
+        f"You MUST respond in exactly this format:\n"
+        f"\n"
+        f"<kernel>\n"
+        f"```python\n"
+        f"# Your complete optimized implementation here\n"
+        f"class ModelNew(nn.Module):\n"
+        f"    ...\n"
+        f"```\n"
+        f"</kernel>"
+    )
+
+
+async def evaluate_step(
+    problem: KernelBenchProblem,
+    renderer: renderers.Renderer,
+    action: Action,
+    eval_config: EvalConfig,
+    reward_config: RewardConfig,
+    step_start: float,
+) -> EvalStepResult:
+    """Parse, evaluate, and compute reward for a single action.
+
+    Shared by KernelBenchEnv.step() and MultiTurnKernelBenchEnv.step().
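+
+    Illustrative call site (a sketch; mirrors the actual call in the two
+    step() methods named above)::
+
+        r = await evaluate_step(
+            self.problem, self.renderer, action,
+            self.eval_config, self.reward_config, step_start,
+        )
+        # r.reward / r.metrics then feed the caller's StepResult.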
+ """ + message, _ = renderer.parse_response(action) + response_text = message.get("content", "") + + parsed = parse_structured_response(response_text) + kernel_code = parsed.kernel + format_ok = parsed.format_ok + + eval_start = time.perf_counter() + cfg = eval_config + eval_result = await evaluate_kernel_async( + level=problem.level, + problem_id=problem.problem_id, + backend=problem.backend, + kernel_code=kernel_code, + dataset_src=problem.dataset_src, + num_correct_trials=cfg.num_correct_trials, + measure_performance=cfg.measure_performance, + num_perf_trials=cfg.num_perf_trials, + timing_method=cfg.timing_method, + precision=cfg.precision, + check_for_excessive_speedup=cfg.check_for_excessive_speedup, + excessive_speedup_threshold=cfg.excessive_speedup_threshold, + timeout=cfg.modal_timeout, + ) + eval_time = time.perf_counter() - eval_start + + reward = compute_reward( + eval_result, + reward_config, + kernel_code=kernel_code, + backend=problem.backend, + ) + + metrics: Metrics = { + "level": problem.level, + "problem_id": problem.problem_id, + "format_ok": float(format_ok), + "compiled": float(eval_result["compiled"]), + "correctness": float(eval_result["correctness"]), + "tests_passed": eval_result["tests_passed"], + "tests_total": eval_result["tests_total"], + } + if eval_result.get("speedup") is not None: + metrics["speedup"] = eval_result["speedup"] + if eval_result.get("runtime_ms") is not None: + metrics["runtime_ms"] = eval_result["runtime_ms"] + metrics["time/eval"] = eval_time + timing_metadata = (eval_result.get("metadata") or {}).get("timings", {}) + if "reference_load_s" in timing_metadata: + metrics["time/ref_load"] = timing_metadata["reference_load_s"] + if "modal_eval_s" in timing_metadata: + metrics["time/modal_eval"] = timing_metadata["modal_eval_s"] + metrics["time/step_total"] = time.perf_counter() - step_start + + return EvalStepResult( + parsed=parsed, + eval_result=eval_result, + format_ok=format_ok, + kernel_code=kernel_code, + reward=reward, + metrics=metrics, + response_text=response_text, + ) diff --git a/src/kernelbench_tinker/envs/kernelbench_client.py b/src/kernelbench_tinker/envs/kernelbench_client.py index 495d44f..cc7d304 100644 --- a/src/kernelbench_tinker/envs/kernelbench_client.py +++ b/src/kernelbench_tinker/envs/kernelbench_client.py @@ -9,13 +9,13 @@ import functools import hashlib +import logging import os import re import sys import time from collections import OrderedDict from dataclasses import dataclass, field -import logging from typing import Any, TypedDict, cast logger = logging.getLogger(__name__) @@ -33,11 +33,18 @@ re.DOTALL | re.IGNORECASE ) +# Summary block pattern - reasoning summary inside ... 
+SUMMARY_BLOCK_PATTERN = re.compile(
+    r"<summary>(.*?)</summary>",
+    re.DOTALL | re.IGNORECASE
+)
+
 
 @dataclass
 class ParsedResponse:
     """Parsed model response with kernel blocks."""
 
     kernel: str       # Kernel code (from <kernel> block or extracted code block)
+    cot_summary: str  # Reasoning summary (from <summary> block)
     raw: str          # Original raw response
     format_ok: bool   # Whether we successfully extracted kernel code
 
@@ -94,8 +101,15 @@ def parse_structured_response(text: str) -> ParsedResponse:
     # Check if we got valid kernel code
     format_ok = bool(kernel) and ("class ModelNew" in kernel or "def forward" in kernel)
 
+    # Extract CoT summary from <summary> block
+    cot_summary = ""
+    summary_match = SUMMARY_BLOCK_PATTERN.search(text)
+    if summary_match:
+        cot_summary = summary_match.group(1).strip()
+
     return ParsedResponse(
         kernel=kernel,
+        cot_summary=cot_summary,
         raw=raw,
         format_ok=format_ok,
     )
@@ -487,6 +501,7 @@ class KernelBenchProblem:
     prompt_gpu_name: str | None = None
 
     _prompt: str | None = field(default=None, repr=False)
+    _base_prompt: str | None = field(default=None, repr=False)
 
     @property
     def prompt(self) -> str:
@@ -504,3 +519,23 @@ def prompt(self) -> str:
             )
         return self._prompt
 
+    @property
+    def base_prompt(self) -> str:
+        """Get the zero-shot prompt (no examples) for refinement turns.
+
+        In multi-turn training, the one-shot example is included only on the
+        first turn. Subsequent turns use this stripped-down prompt to save
+        context tokens.
+        """
+        if self._base_prompt is None:
+            self._base_prompt = get_prompt_for_problem(
+                self.level,
+                self.problem_id,
+                self.backend,
+                option="zero_shot",
+                dataset_src=self.dataset_src,
+                precision=self.prompt_precision,
+                include_hardware=self.prompt_include_hardware,
+                gpu_name=self.prompt_gpu_name,
+            )
+        return self._base_prompt

diff --git a/src/kernelbench_tinker/envs/kernelbench_env.py b/src/kernelbench_tinker/envs/kernelbench_env.py
index 209ef5b..f755f80 100644
--- a/src/kernelbench_tinker/envs/kernelbench_env.py
+++ b/src/kernelbench_tinker/envs/kernelbench_env.py
@@ -30,50 +30,27 @@
 )
 from tinker_cookbook.utils import logtree
 
+from kernelbench_tinker.config.configs import EvalConfig
+from kernelbench_tinker.envs.env_utils import (
+    EvalStepResult,
+    build_system_prompt,
+    evaluate_step,
+)
 from kernelbench_tinker.envs.kernelbench_client import (
     KernelBenchProblem,
     KernelEvalResult,
     ParsedResponse,
-    evaluate_kernel_async,
     get_problem_ids,
-    parse_structured_response,
 )
-from kernelbench_tinker.config.configs import EvalConfig
 from kernelbench_tinker.training.reward import (
-    compute_reward,
-    compute_reward_breakdown,
     RewardConfig,
+    compute_reward_breakdown,
 )
 from kernelbench_tinker.training.trace_logger import get_trace_logger
 
 logger = logging.getLogger(__name__)
 
 
-# Default system prompt for kernel generation (structured format)
-DEFAULT_SYSTEM_PROMPT = """You are an expert GPU kernel developer. Your task is to optimize PyTorch operations by writing efficient custom GPU kernels.
-
-When given a PyTorch model, you should:
-1. Analyze the operations being performed
-2. Write an optimized kernel implementation
-3. Return your solution as a Python class named `ModelNew` that implements the same interface
-
-Your kernel should:
-- Be functionally correct (produce the same outputs as the reference)
-- Be efficient (aim for speedup over the PyTorch baseline)
-- Handle edge cases properly
-- Use the specified backend (Triton, CUDA, etc.)
-
-You MUST respond in exactly this format:
-
-<kernel>
-```python
-# Your complete optimized implementation here
-class ModelNew(nn.Module):
-    ...
-```
-</kernel>
-"""
-
-
 class KernelBenchEnv(Env):
     """
     A single-turn RL environment for a KernelBench problem.
@@ -119,8 +96,7 @@ def _build_initial_messages(self) -> list[renderers.Message]:
         """Build the initial conversation for the problem."""
         messages: list[renderers.Message] = []
 
-        # Add system prompt if supported
-        messages.append({"role": "system", "content": DEFAULT_SYSTEM_PROMPT})
+        messages.append({"role": "system", "content": build_system_prompt(self.problem.backend)})
 
         # Add the problem prompt as user message
         messages.append({"role": "user", "content": self.problem.prompt})
@@ -151,97 +127,40 @@ async def step(self, action: Action) -> StepResult:
             StepResult with reward and episode done status
         """
         step_start = time.perf_counter()
-        # Parse the response to get text
-        message, _ = self.renderer.parse_response(action)
-        response_text = message.get("content", "")
-
-        # Parse structured response (extracts <kernel> block)
-        parsed = parse_structured_response(response_text)
-        kernel_code = parsed.kernel
-        # Check format validity
-        format_ok = parsed.format_ok
-
-        # Evaluate the kernel (Modal for isolated GPU execution)
-        eval_start = time.perf_counter()
-        cfg = self.eval_config
-        eval_result = await evaluate_kernel_async(
-            level=self.problem.level,
-            problem_id=self.problem.problem_id,
-            backend=self.problem.backend,
-            kernel_code=kernel_code,
-            dataset_src=self.problem.dataset_src,
-            num_correct_trials=cfg.num_correct_trials,
-            measure_performance=cfg.measure_performance,
-            num_perf_trials=cfg.num_perf_trials,
-            timing_method=cfg.timing_method,
-            precision=cfg.precision,
-            check_for_excessive_speedup=cfg.check_for_excessive_speedup,
-            excessive_speedup_threshold=cfg.excessive_speedup_threshold,
-            timeout=cfg.modal_timeout,
-        )
-        eval_time = time.perf_counter() - eval_start
-
-        # Compute reward (pass kernel_code for static checking)
-        reward = compute_reward(
-            eval_result,
-            self.reward_config,
-            kernel_code=kernel_code,
-            backend=self.problem.backend,
+        r = await evaluate_step(
+            self.problem, self.renderer, action,
+            self.eval_config, self.reward_config, step_start,
         )
 
         # Log the attempt
         logtree.log_text(f"Problem: Level {self.problem.level}, ID {self.problem.problem_id}")
-        logtree.log_text(f"Format OK: {'Yes' if format_ok else 'No'}")
-        logtree.log_text(f"Compiled: {'Yes' if eval_result['compiled'] else 'No'}")
+        logtree.log_text(f"Format OK: {'Yes' if r.format_ok else 'No'}")
+        logtree.log_text(f"Compiled: {'Yes' if r.eval_result['compiled'] else 'No'}")
         logtree.log_text(
-            f"Correctness: {eval_result['tests_passed']}/{eval_result['tests_total']}"
+            f"Correctness: {r.eval_result['tests_passed']}/{r.eval_result['tests_total']}"
         )
-        if eval_result.get("speedup"):
-            logtree.log_text(f"Speedup: {eval_result['speedup']:.2f}x")
-        logtree.log_text(f"Reward: {reward:.3f}")
-        error_message = eval_result.get("error_message")
+        if r.eval_result.get("speedup") is not None:
+            logtree.log_text(f"Speedup: {r.eval_result['speedup']:.2f}x")
+        logtree.log_text(f"Reward: {r.reward:.3f}")
+        error_message = r.eval_result.get("error_message")
         if error_message:
             logtree.log_text(f"Error: {error_message[:200]}")
 
-        # Build metrics
-        metrics: Metrics = {
-            "level": self.problem.level,
-            "problem_id": self.problem.problem_id,
-            "format_ok": float(format_ok),
-            "compiled": float(eval_result["compiled"]),
-            "correctness": float(eval_result["correctness"]),
-            "tests_passed": eval_result["tests_passed"],
-            "tests_total": eval_result["tests_total"],
-        }
-        if eval_result.get("speedup"):
-            metrics["speedup"] = eval_result["speedup"]
if eval_result.get("runtime_ms"): - metrics["runtime_ms"] = eval_result["runtime_ms"] - metrics["time/eval"] = eval_time - timing_metadata = (eval_result.get("metadata") or {}).get("timings", {}) - if "reference_load_s" in timing_metadata: - metrics["time/ref_load"] = timing_metadata["reference_load_s"] - if "modal_eval_s" in timing_metadata: - metrics["time/modal_eval"] = timing_metadata["modal_eval_s"] - metrics["time/step_total"] = time.perf_counter() - step_start - # Trace logging (prompt + response + eval) await self._log_trace( - parsed=parsed, - eval_result=eval_result, - format_ok=format_ok, - reward=reward, - metrics=metrics, + parsed=r.parsed, + eval_result=r.eval_result, + format_ok=r.format_ok, + reward=r.reward, + metrics=r.metrics, ) - episode_done = True - return StepResult( - reward=reward, - episode_done=episode_done, + reward=r.reward, + episode_done=True, next_observation=tinker.ModelInput.empty(), next_stop_condition=self.stop_condition, - metrics=metrics, + metrics=r.metrics, ) async def _log_trace( @@ -349,19 +268,6 @@ def __init__( shuffle: bool = True, num_epochs: int = 1, ): - """ - Initialize the RL dataset. - - Args: - problems: List of KernelBench problems - renderer: Tinker renderer for formatting - batch_size: Number of problems per batch - group_size: Number of rollouts per problem - eval_config: Configuration for kernel evaluation - reward_config: Reward configuration - shuffle: Whether to shuffle problems each epoch - num_epochs: Number of training epochs - """ self.problems = problems self.renderer = renderer self.batch_size = batch_size @@ -401,15 +307,13 @@ def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]: for i in range(start_idx, end_idx): problem_idx = self._problem_indices[i] problem = self.problems[problem_idx] - - builder = KernelBenchEnvGroupBuilder( + builders.append(KernelBenchEnvGroupBuilder( problem=problem, renderer=self.renderer, group_size=self.group_size, eval_config=self.eval_config, reward_config=self.reward_config, - ) - builders.append(builder) + )) return builders @@ -425,6 +329,7 @@ class KernelBenchDatasetBuilder(RLDatasetBuilder): # Problem selection level: int = 1 + levels: list[int] | None = None # Train on multiple levels (overrides level when set) start_problem: int | None = None end_problem: int | None = None backend: str = "triton" @@ -452,6 +357,11 @@ class KernelBenchDatasetBuilder(RLDatasetBuilder): reward_speed_weight: float = 1.0 reward_length_weight: float = 0.0 + # Reward clipping and speed cap + reward_clip_min: float | None = None + reward_clip_max: float | None = None + reward_speed_max_reward: float = 10.0 # Cap on speed reward component + # Reward hacking detection (static checker) reward_enable_static_checker: bool = True reward_static_checker_backend: str = "triton" @@ -464,6 +374,9 @@ class KernelBenchDatasetBuilder(RLDatasetBuilder): # Test split test_fraction: float = 0.1 + # Explicit holdout indices per level (overrides test_fraction when set) + # Format: {level: [problem_ids]} e.g. {1: [3,10,25], 2: [10,20,30]} + holdout_indices: dict[int, list[int]] | None = None # Prompt configuration prompt_option: str = "one_shot" # "zero_shot", "one_shot", "few_shot" @@ -481,33 +394,54 @@ async def __call__(self, tokenizer=None) -> tuple[RLDataset, RLDataset | None]: Args: tokenizer: The tokenizer to use for the renderer. Required for most renderers. 
""" - # Get problem IDs - problem_ids = get_problem_ids( - self.level, - start=self.start_problem, - end=self.end_problem, - dataset_src=self.dataset_src, - ) - - # Create problems - all_problems = [ - KernelBenchProblem( - level=self.level, - problem_id=pid, - backend=self.backend, + # Determine which levels to use + active_levels = self.levels if self.levels else [self.level] + + # Collect problems across all levels + all_problems: list[KernelBenchProblem] = [] + for lvl in active_levels: + # Get problem IDs + problem_ids = get_problem_ids( + lvl, + start=self.start_problem, + end=self.end_problem, dataset_src=self.dataset_src, - prompt_option=self.prompt_option, - prompt_precision=self.prompt_precision or self.precision, - prompt_include_hardware=self.prompt_include_hardware, - prompt_gpu_name=self.prompt_gpu_name or ( - self.modal_gpu_type if self.prompt_include_hardware else None - ), ) - for pid in problem_ids - ] + + # Create problems + all_problems.extend( + KernelBenchProblem( + level=lvl, + problem_id=pid, + backend=self.backend, + dataset_src=self.dataset_src, + prompt_option=self.prompt_option, + prompt_precision=self.prompt_precision or self.precision, + prompt_include_hardware=self.prompt_include_hardware, + prompt_gpu_name=self.prompt_gpu_name or ( + self.modal_gpu_type if self.prompt_include_hardware else None + ), + ) + for pid in problem_ids + ) # Split into train/test - if self.test_fraction > 0 and len(all_problems) > 1: + if self.holdout_indices: + # Explicit holdout: separate by (level, problem_id) membership + holdout_set = { + (lvl, pid) + for lvl, pids in self.holdout_indices.items() + for pid in pids + } + train_problems = [ + p for p in all_problems + if (p.level, p.problem_id) not in holdout_set + ] + test_problems = [ + p for p in all_problems + if (p.level, p.problem_id) in holdout_set + ] or None + elif self.test_fraction > 0 and len(all_problems) > 1: n_test = max(1, int(len(all_problems) * self.test_fraction)) # Use last N problems as test set for reproducibility train_problems = all_problems[:-n_test] @@ -539,6 +473,9 @@ async def __call__(self, tokenizer=None) -> tuple[RLDataset, RLDataset | None]: correctness_weight=self.reward_correctness_weight, speed_weight=self.reward_speed_weight, length_weight=self.reward_length_weight, + speed_max_reward=self.reward_speed_max_reward, + reward_clip_min=self.reward_clip_min, + reward_clip_max=self.reward_clip_max, enable_static_checker=self.reward_enable_static_checker, static_checker_backend=self.reward_static_checker_backend or self.backend, static_checker_precision=self.reward_static_checker_precision or self.precision, @@ -547,7 +484,11 @@ async def __call__(self, tokenizer=None) -> tuple[RLDataset, RLDataset | None]: ) # Configure Modal evaluator with the same config - from kernelbench_tinker.modal.evaluator import ModalEvaluatorConfig, set_modal_evaluator, ModalKernelEvaluator + from kernelbench_tinker.modal.evaluator import ( + ModalEvaluatorConfig, + ModalKernelEvaluator, + set_modal_evaluator, + ) modal_config = ModalEvaluatorConfig( enabled=True, gpu_type=eval_config.modal_gpu_type, @@ -583,4 +524,3 @@ async def __call__(self, tokenizer=None) -> tuple[RLDataset, RLDataset | None]: ) return train_dataset, test_dataset - diff --git a/src/kernelbench_tinker/envs/multiturn_kernelbench_env.py b/src/kernelbench_tinker/envs/multiturn_kernelbench_env.py new file mode 100644 index 0000000..6e084c5 --- /dev/null +++ b/src/kernelbench_tinker/envs/multiturn_kernelbench_env.py @@ -0,0 +1,532 @@ +""" +Multi-turn 
+
+Extends the single-turn KernelBenchEnv to support iterative kernel refinement.
+Each episode consists of up to T turns where the model receives evaluation
+feedback and can fix errors or improve performance.
+"""
+
+from __future__ import annotations
+
+import logging
+import time
+from dataclasses import dataclass, field
+from typing import Any, Sequence
+
+import tinker
+from tinker.types.model_input_chunk import EncodedTextChunk
+from tinker_cookbook import renderers
+from tinker_cookbook.completers import StopCondition
+from tinker_cookbook.rl.types import (
+    Action,
+    Env,
+    EnvGroupBuilder,
+    Metrics,
+    Observation,
+    StepResult,
+    Trajectory,
+)
+from tinker_cookbook.utils import logtree
+
+from kernelbench_tinker.config.configs import EvalConfig, MultiTurnConfig
+from kernelbench_tinker.envs.env_utils import build_system_prompt, evaluate_step
+from kernelbench_tinker.envs.kernelbench_client import (
+    KernelBenchProblem,
+    KernelEvalResult,
+    ParsedResponse,
+)
+from kernelbench_tinker.training.reward import (
+    RewardConfig,
+    compute_reward_breakdown,
+)
+from kernelbench_tinker.training.trace_logger import get_trace_logger
+
+logger = logging.getLogger(__name__)
+
+# Limit for feedback content included in refinement prompts (char-based fallback)
+MAX_HISTORY_CONTEXT_LEN = 8000
+
+
+def extract_raw_content(response_text: str, eos_token: str | None = None) -> str:
+    """Extract assistant content from a response, stripping the thinking block.
+
+    Keep text after ``</think>`` when present. Strip an unclosed ``<think>``
+    prefix if the model started thinking but didn't close the tag. If the
+    response contains no think block and the EOS token is missing, return the
+    EOS token as a null-response marker.
+    """
+    if "</think>" in response_text:
+        return response_text.split("</think>")[-1].lstrip('\n')
+    if "<think>" in response_text:
+        # Unclosed thinking block — strip it
+        return response_text.split("<think>")[0].strip()
+    if eos_token is not None and eos_token not in response_text:
+        return eos_token
+    return response_text
+
+
+def build_eval_feedback(eval_result: KernelEvalResult) -> str:
+    """Build feedback string from an evaluation result for the next refinement turn."""
+    error_msg = eval_result.get("error_message") or ""
+    metadata = eval_result.get("metadata") or {}
+    error_type = metadata.get("error_type", "")
+
+    if not eval_result["format_ok"] or error_type == "parsing_error":
+        resp = (
+            "Your previous answer failed to be parsed due to not adhering "
+            f"to the desired formatting. Here's the error message: {error_msg}.\n"
+        )
+    elif not eval_result["compiled"]:
+        resp = (
+            "Your previous answer failed to compile. "
+            f"Here's the error message: {error_msg}.\n"
+        )
+    elif error_type == "runtime_error" or (not eval_result["correctness"] and error_msg):
+        # Runtime error: compiled successfully but had runtime errors
+        resp = (
+            "Your previous answer compiled successfully but had runtime "
+            f"errors. Here's the error message: {error_msg}.\n"
+        )
+    elif not eval_result["correctness"]:
+        # Incorrect output
+        resp = (
+            "Your previous answer was incorrect. "
+            f"Here's the error message: {error_msg}.\n"
+        )
+    else:
+        speedup = eval_result.get("speedup") or 0.0
+        resp = (
+            "Your previous answer was correct but can be made faster. "
+            "Here's the speedup you achieved relative to the baseline: "
+            f"{speedup:.2f}.\n"
+        )
+
+    resp += "\nRestart your reasoning process and generate new, complete code."
+    return resp
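+
+# Example (illustrative): a turn that failed to compile yields feedback like:
+#   "Your previous answer failed to compile. Here's the error message:
+#    <compiler output>.
+#
+#    Restart your reasoning process and generate new, complete code."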
+
+
+# ---------------------------------------------------------------------------
+# Multi-turn state
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class MultiTurnState:
+    """Mutable state for a multi-turn kernel refinement episode."""
+
+    level: int
+    problem_id: int
+    backend: str
+    turn_idx: int
+    max_turns: int
+    history: list[dict]  # Per-turn: {raw_content, kernel, feedback, score}
+    step_scores: list[float]
+    done: bool
+    success: bool
+
+
+# ---------------------------------------------------------------------------
+# Multi-turn environment
+# ---------------------------------------------------------------------------
+
+
+class MultiTurnKernelBenchEnv(Env):
+    """
+    Multi-turn RL environment for KernelBench.
+
+    Each episode consists of up to T refinement steps:
+    1. Turn 0: problem prompt (same as single-turn)
+    2. Turn 1+: problem prompt + previous attempt feedback
+
+    The episode ends when the kernel is correct (early stopping) or
+    max_turns is reached.
+    """
+
+    def __init__(
+        self,
+        problem: KernelBenchProblem,
+        renderer: renderers.Renderer,
+        max_turns: int = 4,
+        eval_config: EvalConfig | None = None,
+        reward_config: RewardConfig | None = None,
+        system_prompt: str | None = None,
+        early_stop_on_correct: bool = False,
+        speedup_threshold: float | None = None,
+        tokenizer: Any | None = None,
+        prompt_max_tokens: int | None = None,
+        inject_think_token: bool = False,
+    ):
+        self.problem = problem
+        self.renderer = renderer
+        self.max_turns = max_turns
+        self.eval_config = eval_config or EvalConfig()
+        self.reward_config = reward_config or RewardConfig()
+        self.early_stop_on_correct = early_stop_on_correct
+        self.speedup_threshold = speedup_threshold
+        self.tokenizer = tokenizer
+        self.prompt_max_tokens = prompt_max_tokens
+        self.inject_think_token = inject_think_token
+
+        self._system_prompt = system_prompt or build_system_prompt(
+            problem.backend,
+        )
+
+        self._current_prompt_messages: list[renderers.Message] | None = None
+        self._state: MultiTurnState | None = None
+
+    @property
+    def stop_condition(self) -> StopCondition:
+        return self.renderer.get_stop_sequences()
+
+    def _append_think_token(self, observation: tinker.ModelInput) -> tinker.ModelInput:
+        """Append ``<think>\\n`` tokens after the chat template to force thinking mode."""
+        if not self.inject_think_token or self.tokenizer is None:
+            return observation
+        think_ids = self.tokenizer.encode("<think>\n", add_special_tokens=False)
+        return observation.append(EncodedTextChunk(tokens=think_ids))
+
+    @property
+    def state(self) -> MultiTurnState:
+        if self._state is None:
+            raise RuntimeError(
+                "Environment not initialized. Call initial_observation first."
+            )
+        return self._state
+
+    def _build_initial_messages(self) -> list[renderers.Message]:
+        messages: list[renderers.Message] = []
+        if self._system_prompt:
+            messages.append({"role": "system", "content": self._system_prompt})
+        messages.append({"role": "user", "content": self.problem.prompt})
+        return messages
+
+    def _count_message_tokens(self, messages: list[renderers.Message]) -> int:
+        """Count total tokens across all messages using the tokenizer."""
+        if self.tokenizer is None:
+            return 0
+        total = 0
+        for msg in messages:
+            content = msg.get("content", "")
+            total += len(self.tokenizer.encode(content))
+        return total
+
+    def _build_refinement_messages(self) -> list[renderers.Message]:
+        """Build refinement prompt with history as alternating assistant/user turns.
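+
+        Resulting layout (sketch):
+            [system]    kernel-developer system prompt
+            [user]      base problem prompt + "Here are your previous attempts:"
+            [assistant] attempt 1   → [user] feedback 1
+            [assistant] attempt 2   → [user] feedback 2, ...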
+
+        Uses token-based truncation (oldest-first) when tokenizer and
+        prompt_max_tokens are set, otherwise falls back to char-based.
+        """
+        messages: list[renderers.Message] = []
+        if self._system_prompt:
+            messages.append({"role": "system", "content": self._system_prompt})
+
+        # Initial user message: problem without one-shot example
+        base = self.problem.base_prompt
+        messages.append({
+            "role": "user",
+            "content": base + "Here are your previous attempts:\n",
+        })
+
+        if self.state.history:
+            history = list(self.state.history)
+
+            if self.tokenizer is not None and self.prompt_max_tokens is not None:
+                # Token-based truncation: count base tokens, fit history
+                # into remaining budget, keeping most recent entries.
+                base_tokens = self._count_message_tokens(messages)
+                # 32-token safety buffer for tokenizer boundary effects
+                budget = self.prompt_max_tokens - base_tokens - 32
+
+                # Walk history backwards, accumulating tokens
+                kept: list[dict] = []
+                used = 0
+                for entry in reversed(history):
+                    entry_text = entry["raw_content"] + entry["feedback"]
+                    entry_tokens = len(self.tokenizer.encode(entry_text))
+                    if used + entry_tokens > budget and kept:
+                        break
+                    kept.append(entry)
+                    used += entry_tokens
+                history = list(reversed(kept))
+            else:
+                # Char-based fallback
+                total_len = sum(
+                    len(e["raw_content"]) + len(e["feedback"]) for e in history
+                )
+                while total_len > MAX_HISTORY_CONTEXT_LEN and len(history) > 1:
+                    removed = history.pop(0)
+                    total_len -= len(removed["raw_content"]) + len(removed["feedback"])
+
+            # Add history as assistant/user turn pairs
+            for entry in history:
+                messages.append({"role": "assistant", "content": entry["raw_content"]})
+                messages.append({"role": "user", "content": entry["feedback"]})
+
+        return messages
+
+    async def initial_observation(self) -> tuple[Observation, StopCondition]:
+        self._state = MultiTurnState(
+            level=self.problem.level,
+            problem_id=self.problem.problem_id,
+            backend=self.problem.backend,
+            turn_idx=0,
+            max_turns=self.max_turns,
+            history=[],
+            step_scores=[],
+            done=False,
+            success=False,
+        )
+        messages = self._build_initial_messages()
+        observation = self.renderer.build_generation_prompt(messages)
+        observation = self._append_think_token(observation)
+        self._current_prompt_messages = messages
+        return observation, self.stop_condition
+
+    async def step(self, action: Action) -> StepResult:
+        step_start = time.perf_counter()
+        state = self.state
+
+        r = await evaluate_step(
+            self.problem, self.renderer, action,
+            self.eval_config, self.reward_config, step_start,
+        )
+        state.step_scores.append(r.reward)
+
+        # Extract content for history.
+        # Use CoT summary when available (strips full reasoning, keeps concise
+        # summary + kernel code), otherwise fall back to raw content with
+        # thinking block stripped.
+        eos_token = getattr(self.tokenizer, "eos_token", None) if self.tokenizer else None
+        raw_content = extract_raw_content(r.response_text, eos_token)
+        history_content = r.parsed.cot_summary if r.parsed.cot_summary else raw_content
+
+        # Build feedback and store in history
+        feedback = build_eval_feedback(r.eval_result)
+        state.history.append({
+            "raw_content": history_content,
+            "kernel": r.kernel_code,
+            "feedback": feedback,
+            "score": r.reward,
+        })
+
+        # Log
+        logtree.log_text(
+            f"Multi-turn: Level {state.level}, ID {state.problem_id}, "
+            f"Turn {state.turn_idx}"
+        )
+        logtree.log_text(f"Format OK: {'Yes' if r.format_ok else 'No'}")
+        logtree.log_text(
+            f"Compiled: {'Yes' if r.eval_result['compiled'] else 'No'}"
+        )
+        logtree.log_text(
+            f"Correctness: {r.eval_result['tests_passed']}/{r.eval_result['tests_total']}"
+        )
+        if r.eval_result.get("speedup") is not None:
+            logtree.log_text(f"Speedup: {r.eval_result['speedup']:.2f}x")
+        logtree.log_text(f"Step score: {r.reward:.3f}")
+
+        # Early stopping
+        is_correct = r.eval_result["correctness"]
+        meets_speedup = (
+            self.speedup_threshold is None
+            or (r.eval_result.get("speedup") or 0.0) >= self.speedup_threshold
+        )
+        if self.early_stop_on_correct and is_correct and meets_speedup:
+            state.done = True
+            state.success = True
+
+        state.turn_idx += 1
+        if state.turn_idx >= state.max_turns:
+            state.done = True
+
+        # Add multi-turn fields to shared metrics
+        metrics = r.metrics
+        metrics["turn"] = state.turn_idx - 1
+        metrics["step_score"] = r.reward
+        metrics["episode_done"] = float(state.done)
+        metrics["episode_success"] = float(state.success)
+
+        # Trace logging
+        await self._log_trace(
+            parsed=r.parsed,
+            eval_result=r.eval_result,
+            format_ok=r.format_ok,
+            reward=r.reward,
+            metrics=metrics,
+        )
+
+        # Next observation or done
+        if state.done:
+            next_observation = tinker.ModelInput.empty()
+        else:
+            messages = self._build_refinement_messages()
+            next_observation = self.renderer.build_generation_prompt(messages)
+            next_observation = self._append_think_token(next_observation)
+            self._current_prompt_messages = messages
+
+        return StepResult(
+            reward=r.reward,
+            episode_done=state.done,
+            next_observation=next_observation,
+            next_stop_condition=self.stop_condition,
+            metrics=metrics,
+        )
+
+    async def _log_trace(
+        self,
+        parsed: ParsedResponse,
+        eval_result: KernelEvalResult,
+        format_ok: bool,
+        reward: float,
+        metrics: Metrics,
+    ) -> None:
+        trace_logger = get_trace_logger()
+        if trace_logger is None:
+            return
+
+        trace_record = {
+            "mode": "multi_turn",
+            "level": self.problem.level,
+            "problem_id": self.problem.problem_id,
+            "backend": self.problem.backend,
+            "dataset_src": self.problem.dataset_src,
+            "prompt_option": self.problem.prompt_option,
+            "turn": self.state.turn_idx - 1,
+            "max_turns": self.state.max_turns,
+            "prompt_messages": self._current_prompt_messages,
+            "renderer": getattr(
+                self.renderer, "name", type(self.renderer).__name__
+            ),
+            "response": {
+                "raw": parsed.raw,
+                "kernel": parsed.kernel,
+                "cot_summary": parsed.cot_summary,
+                "format_ok": format_ok,
+            },
+            "eval_result": eval_result,
+            "reward": reward,
+            "reward_breakdown": compute_reward_breakdown(
+                eval_result,
+                self.reward_config,
+                kernel_code=parsed.kernel,
+                backend=self.problem.backend,
+            ),
+            "metrics": metrics,
+            "history": [
+                {
+                    "raw_content": entry["raw_content"],
+                    "kernel": entry["kernel"],
+                    "feedback": entry["feedback"],
+                    "score": entry["score"],
+                }
+                for entry in self.state.history
+            ],
+            "state": {
+                "turn_idx": self.state.turn_idx,
+                "done": self.state.done,
+                "success": self.state.success,
+                "step_scores": list(self.state.step_scores),
+            },
+            "timestamp": time.time(),
+            "stop_condition": str(self.stop_condition),
+        }
+
+        await trace_logger.log(trace_record)
+
+    def get_step_scores(self) -> list[float]:
+        """Return per-step scores for discounted return computation."""
+        return list(self.state.step_scores)
+
+
+# ---------------------------------------------------------------------------
+# Group builder, dataset, dataset builder (mirrors single-turn structure)
+# ---------------------------------------------------------------------------
+
+
+@dataclass(frozen=True)
+class MultiTurnKernelBenchEnvGroupBuilder(EnvGroupBuilder):
+    """Builder for groups of multi-turn KernelBench environments."""
+
+    problem: KernelBenchProblem
+    renderer: renderers.Renderer
+    group_size: int
+    max_turns: int = 4
+    eval_config: EvalConfig = field(default_factory=EvalConfig)
+    reward_config: RewardConfig = field(default_factory=RewardConfig)
+    system_prompt: str | None = None
+    early_stop_on_correct: bool = False
+    speedup_threshold: float | None = None
+    tokenizer: Any | None = field(default=None, hash=False, compare=False)
+    prompt_max_tokens: int | None = None
+    inject_think_token: bool = False
+
+    async def make_envs(self) -> Sequence[Env]:
+        return [
+            MultiTurnKernelBenchEnv(
+                problem=self.problem,
+                renderer=self.renderer,
+                max_turns=self.max_turns,
+                eval_config=self.eval_config,
+                reward_config=self.reward_config,
+                system_prompt=self.system_prompt,
+                early_stop_on_correct=self.early_stop_on_correct,
+                speedup_threshold=self.speedup_threshold,
+                tokenizer=self.tokenizer,
+                prompt_max_tokens=self.prompt_max_tokens,
+                inject_think_token=self.inject_think_token,
+            )
+            for _ in range(self.group_size)
+        ]
+
+    async def compute_group_rewards(
+        self,
+        trajectory_group: list[Trajectory],
+        env_group: Sequence[Env],
+    ) -> list[tuple[float, Metrics]]:
+        # No-op: real rewards are computed per-step inside env.step() and
+        # overwritten by apply_discounted_returns before advantage estimation.
+        return [(0.0, {}) for _ in trajectory_group]
+
+    def logging_tags(self) -> list[str]:
+        return [
+            f"level_{self.problem.level}",
+            f"problem_{self.problem.problem_id}",
+            "kernelbench",
+            "multiturn",
+        ]
+
+
+def wrap_builders_as_multiturn(
+    builders: Sequence[EnvGroupBuilder],
+    multiturn_cfg: MultiTurnConfig,
+    tokenizer: Any | None = None,
+) -> list[MultiTurnKernelBenchEnvGroupBuilder]:
+    """Wrap single-turn KernelBenchEnvGroupBuilders as multi-turn builders.
+
+    Called by the training loop when multiturn.enabled is True. Reads
+    problem/renderer/group_size/eval_config/reward_config from each
+    single-turn builder and creates MultiTurnKernelBenchEnvGroupBuilder
+    instances with the multi-turn config.
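+
+    Illustrative use (mirrors the call in training/loop.py)::
+
+        builders = train_dataset.get_batch(batch_idx)
+        if cfg.multiturn.enabled:
+            builders = wrap_builders_as_multiturn(builders, cfg.multiturn, tokenizer)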
+ """ + from kernelbench_tinker.envs.kernelbench_env import KernelBenchEnvGroupBuilder + + wrapped = [] + for b in builders: + if not isinstance(b, KernelBenchEnvGroupBuilder): + raise TypeError( + f"Expected KernelBenchEnvGroupBuilder, got {type(b).__name__}" + ) + wrapped.append(MultiTurnKernelBenchEnvGroupBuilder( + problem=b.problem, + renderer=b.renderer, + group_size=b.group_size, + max_turns=multiturn_cfg.max_turns, + eval_config=b.eval_config, + reward_config=b.reward_config, + system_prompt=build_system_prompt(b.problem.backend), + early_stop_on_correct=multiturn_cfg.early_stop_on_correct, + speedup_threshold=multiturn_cfg.speedup_threshold, + tokenizer=tokenizer, + prompt_max_tokens=multiturn_cfg.prompt_max_tokens, + inject_think_token=multiturn_cfg.inject_think_token, + )) + return wrapped diff --git a/src/kernelbench_tinker/evaluation/eval_kernelbench.py b/src/kernelbench_tinker/evaluation/eval_kernelbench.py index cf00836..f4adbd7 100644 --- a/src/kernelbench_tinker/evaluation/eval_kernelbench.py +++ b/src/kernelbench_tinker/evaluation/eval_kernelbench.py @@ -10,7 +10,7 @@ import json import logging import os -from dataclasses import dataclass, field, asdict +from dataclasses import asdict, dataclass, field from typing import Any import numpy as np diff --git a/src/kernelbench_tinker/modal/app.py b/src/kernelbench_tinker/modal/app.py index ce87066..825b987 100644 --- a/src/kernelbench_tinker/modal/app.py +++ b/src/kernelbench_tinker/modal/app.py @@ -28,7 +28,6 @@ import modal - # ============================================================================= # GPU Architecture Mapping # ============================================================================= @@ -189,9 +188,10 @@ def evaluate( """ import tempfile import time - + import modal.experimental import torch + from kernelbench.eval import eval_kernel_against_ref, get_torch_dtype_from_string from kernelbench.utils import set_gpu_arch diff --git a/src/kernelbench_tinker/scripts/eval_kernel_rl.py b/src/kernelbench_tinker/scripts/eval_kernel_rl.py index 07ba2a6..9816770 100644 --- a/src/kernelbench_tinker/scripts/eval_kernel_rl.py +++ b/src/kernelbench_tinker/scripts/eval_kernel_rl.py @@ -18,28 +18,57 @@ import json import logging import os -from dataclasses import dataclass, asdict +from dataclasses import asdict, dataclass from typing import Any import chz import tinker -from tqdm import tqdm - from tinker_cookbook import renderers, tokenizer_utils from tinker_cookbook.completers import TinkerTokenCompleter +from tqdm import tqdm from kernelbench_tinker.env import setup_environment +from kernelbench_tinker.envs.env_utils import build_system_prompt from kernelbench_tinker.envs.kernelbench_client import ( KernelBenchProblem, evaluate_kernel_async, get_problem_ids, parse_structured_response, ) -from kernelbench_tinker.training.models import get_renderer_name_for_model +from kernelbench_tinker.training.models import ( + KernelBenchTokenCompleter, + get_renderer_name_for_model, +) logger = logging.getLogger(__name__) +def pick_best_sample( + samples: list[dict[str, Any]], +) -> tuple[bool, bool, float | None]: + """Pick the best sample from a list of evaluation results. + + Returns (best_correct, best_compiled, best_speedup). 
+ """ + def speedup_value(sample: dict[str, Any]) -> float: + speedup = sample.get("speedup") + return float(speedup) if isinstance(speedup, (int, float)) else 0.0 + + correct_samples = [s for s in samples if s.get("correctness")] + if correct_samples: + best = max(correct_samples, key=speedup_value) + else: + compiled = [s for s in samples if s.get("compiled")] + best = compiled[0] if compiled else samples[0] + + best_speedup: float | None = None + speedup_obj = best.get("speedup") + if isinstance(speedup_obj, (int, float)): + best_speedup = float(speedup_obj) + + return bool(best.get("correctness")), bool(best.get("compiled")), best_speedup + + @chz.chz class EvalConfig: """Configuration for model evaluation.""" @@ -58,6 +87,8 @@ class EvalConfig: # Generation configuration max_tokens: int = 4096 temperature: float = 0.0 # Greedy for eval + top_p: float = 1.0 # Nucleus sampling (1.0 = disabled) + seed: int | None = None # Random seed for generation (null = random) num_samples: int = 1 # Samples per problem # Evaluation settings @@ -86,6 +117,12 @@ class EvalConfig: tensorboard_log_dir: str | None = None # If provided, log eval metrics to TensorBoard tensorboard_step: int = 0 # Step to log eval metrics at + # Multi-turn inference + multiturn_enabled: bool = False + multiturn_max_turns: int = 8 + inject_think_token: bool = False # Append \n to generation prompts + prompt_max_tokens: int | None = None # Token budget for history truncation + # Tinker API base_url: str | None = None @@ -109,9 +146,9 @@ async def generate_kernel( temperature: float, ) -> str: """Generate a kernel for a problem.""" - # Build prompt + # Build prompt (same system prompt as training env) messages = [ - {"role": "system", "content": "You are an expert GPU kernel developer."}, + {"role": "system", "content": build_system_prompt(problem.backend)}, {"role": "user", "content": problem.prompt}, ] observation = renderer.build_generation_prompt(messages) @@ -185,31 +222,101 @@ async def evaluate_problem( **eval_result, }) - def speedup_value(sample: dict[str, Any]) -> float: - speedup = sample.get("speedup") - return float(speedup) if isinstance(speedup, (int, float)) else 0.0 + best_correct, best_compiled, best_speedup = pick_best_sample(samples) - # Find best result - correct_samples = [s for s in samples if s.get("correctness")] - if correct_samples: - # Best by speedup - best = max(correct_samples, key=speedup_value) - else: - # Best by compilation - compiled = [s for s in samples if s.get("compiled")] - best = compiled[0] if compiled else samples[0] + return EvalResult( + level=problem.level, + problem_id=problem.problem_id, + samples=samples, + best_correct=best_correct, + best_compiled=best_compiled, + best_speedup=best_speedup, + ) - best_speedup: float | None = None - speedup_obj = best.get("speedup") - if isinstance(speedup_obj, (int, float)): - best_speedup = float(speedup_obj) + +async def evaluate_problem_multiturn( + sampling_client: tinker.SamplingClient, + problem: KernelBenchProblem, + renderer: renderers.Renderer, + cfg: EvalConfig, + tokenizer: object | None = None, +) -> EvalResult: + """Evaluate a single problem using multi-turn refinement. + + Reuses MultiTurnKernelBenchEnv so history truncation, feedback + construction, and think-token injection stay in one place. 
+ """ + from kernelbench_tinker.config.configs import EvalConfig as KernelEvalConfig + from kernelbench_tinker.envs.multiturn_kernelbench_env import MultiTurnKernelBenchEnv + + eval_config = KernelEvalConfig( + num_correct_trials=cfg.num_correct_trials, + measure_performance=cfg.measure_performance, + num_perf_trials=cfg.num_perf_trials, + timing_method=cfg.timing_method, + precision=cfg.precision, + check_for_excessive_speedup=cfg.check_for_excessive_speedup, + excessive_speedup_threshold=cfg.excessive_speedup_threshold, + modal_timeout=cfg.modal_timeout, + ) + + env = MultiTurnKernelBenchEnv( + problem=problem, + renderer=renderer, + max_turns=cfg.multiturn_max_turns, + eval_config=eval_config, + system_prompt=build_system_prompt(problem.backend), + tokenizer=tokenizer, + prompt_max_tokens=cfg.prompt_max_tokens, + inject_think_token=cfg.inject_think_token, + ) + + completer = KernelBenchTokenCompleter( + sampling_client, + max_tokens=cfg.max_tokens, + temperature=cfg.temperature if cfg.num_samples == 1 else 1.0, + top_p=cfg.top_p, + seed=cfg.seed, + ) + + observation, stop = await env.initial_observation() + samples: list[dict[str, Any]] = [] + + for turn in range(cfg.multiturn_max_turns): + result = await completer(observation, stop) + step_result = await env.step(result.tokens) + m = step_result.metrics + + kernel_code = env.state.history[-1]["kernel"] + if cfg.max_kernel_code_chars is not None and len(kernel_code) > cfg.max_kernel_code_chars: + kernel_code = kernel_code[: cfg.max_kernel_code_chars] + "..." + + samples.append({ + "sample_id": turn, + "turn": turn, + "kernel_code": kernel_code, + "format_ok": bool(m.get("format_ok")), + "compiled": bool(m.get("compiled")), + "correctness": bool(m.get("correctness")), + "tests_passed": m.get("tests_passed", 0), + "tests_total": m.get("tests_total", 0), + "speedup": m.get("speedup"), + "runtime_ms": m.get("runtime_ms"), + }) + + if step_result.episode_done: + break + observation = step_result.next_observation + stop = step_result.next_stop_condition + + best_correct, best_compiled, best_speedup = pick_best_sample(samples) return EvalResult( level=problem.level, problem_id=problem.problem_id, samples=samples, - best_correct=bool(best.get("correctness")), - best_compiled=bool(best.get("compiled")), + best_correct=best_correct, + best_compiled=best_compiled, best_speedup=best_speedup, ) @@ -274,6 +381,12 @@ async def run_evaluation(cfg: EvalConfig) -> dict[str, Any]: results = [] for problem in tqdm(problems, desc="Evaluating"): try: + if cfg.multiturn_enabled: + result = await evaluate_problem_multiturn( + sampling_client, problem, renderer, cfg, tokenizer=tokenizer + ) + results.append(result) + continue result = await evaluate_problem( sampling_client, problem, renderer, cfg ) @@ -353,8 +466,8 @@ def main(): # Log to TensorBoard if specified if cfg.tensorboard_log_dir: # Lazy import to avoid circular imports - from kernelbench_tinker.training.tensorboard_logger import create_tensorboard_logger from kernelbench_tinker.evaluation.eval_kernelbench import EvalResults, ProblemResult + from kernelbench_tinker.training.tensorboard_logger import create_tensorboard_logger tb_logger = create_tensorboard_logger(cfg.tensorboard_log_dir) diff --git a/src/kernelbench_tinker/scripts/train_kernel_rl.py b/src/kernelbench_tinker/scripts/train_kernel_rl.py index 28e96c6..fe2aedd 100644 --- a/src/kernelbench_tinker/scripts/train_kernel_rl.py +++ b/src/kernelbench_tinker/scripts/train_kernel_rl.py @@ -28,7 +28,8 @@ import yaml # type: ignore[import-untyped] 
 
 from kernelbench_tinker.env import setup_environment
-from kernelbench_tinker.training.loop import TrainingConfig, main as train_main
+from kernelbench_tinker.training.loop import TrainingConfig
+from kernelbench_tinker.training.loop import main as train_main
 
 logger = logging.getLogger(__name__)
 
@@ -105,11 +106,16 @@ def main():
     cfg = blueprint.make()
 
     logger.info("Starting KernelBench RL Training")
+    logger.info(f"Multi-turn: {'enabled' if cfg.multiturn.enabled else 'disabled'}")
     logger.info(f"Model: {cfg.model_name}")
     logger.info(f"Level: {cfg.dataset_builder.level}")
     logger.info(f"Batch size: {cfg.dataset_builder.batch_size}")
     logger.info(f"Group size: {cfg.dataset_builder.group_size}")
     logger.info(f"Log path: {cfg.log_path}")
+    if cfg.multiturn.enabled:
+        logger.info(f"Refinement turns per trajectory (n): {cfg.multiturn.max_turns}")
+        logger.info(f"Parallel trajectories (group_size): {cfg.dataset_builder.group_size}")
+        logger.info(f"Discount factor (gamma): {cfg.multiturn.gamma}")
 
     # Run training
     asyncio.run(train_main(cfg))

diff --git a/src/kernelbench_tinker/training/loop.py b/src/kernelbench_tinker/training/loop.py
index 33bd86c..80d9676 100644
--- a/src/kernelbench_tinker/training/loop.py
+++ b/src/kernelbench_tinker/training/loop.py
@@ -16,16 +16,15 @@
 import asyncio
 import logging
 import os
-from pathlib import Path
 import time
-from typing import Any
+from pathlib import Path
+from typing import Any, Sequence
 
 import chz
 import numpy as np
 import tinker
 import torch
 from tinker.types import LossFnType
-
 from tinker_cookbook import checkpoint_utils
 from tinker_cookbook.completers import TinkerTokenCompleter
 from tinker_cookbook.rl.data_processing import (
@@ -35,14 +34,28 @@
 )
 from tinker_cookbook.rl.rollouts import do_group_rollout
 from tinker_cookbook.rl.types import (
+    Env,
     EnvGroupBuilder,
     TrajectoryGroup,
 )
 from tinker_cookbook.utils import ml_log
 from tinker_cookbook.utils.misc_utils import timed
 
+from kernelbench_tinker.config.configs import MultiTurnConfig
 from kernelbench_tinker.envs.kernelbench_env import KernelBenchDatasetBuilder
-from kernelbench_tinker.training.models import get_adam_params
+from kernelbench_tinker.envs.multiturn_kernelbench_env import wrap_builders_as_multiturn
+from kernelbench_tinker.training.models import (
+    KernelBenchTokenCompleter,
+    build_loss_fn_config,
+    get_adam_params,
+)
+from kernelbench_tinker.training.multiturn import (
+    apply_discounted_returns_to_trajectories,
+    compute_multiturn_advantages,
+    compute_multiturn_trajectory_metrics,
+    do_multiturn_group_rollout_and_filter,
+    flatten_multiturn_trajectory_groups,
+)
 from kernelbench_tinker.training.tensorboard_logger import (
     TensorBoardConfig,
     TensorBoardLogger,
@@ -50,6 +63,8 @@
 )
 from kernelbench_tinker.training.trace_logger import TraceLogger, set_trace_logger
 
+logger = logging.getLogger(__name__)
+
 
 def remove_mask(datum: tinker.Datum) -> tinker.Datum:
     """Remove mask from datum loss_fn_inputs before sending to forward_backward.
@@ -62,8 +77,6 @@ def remove_mask(datum: tinker.Datum) -> tinker.Datum:
         loss_fn_inputs={k: v for k, v in datum.loss_fn_inputs.items() if k != "mask"},
     )
 
-logger = logging.getLogger(__name__)
-
 
 @chz.chz
 class TrainingConfig:
@@ -83,6 +96,9 @@ class TrainingConfig:
         default_factory=KernelBenchDatasetBuilder
     )
 
+    # Multi-turn specific config
+    multiturn: MultiTurnConfig = chz.field(default_factory=MultiTurnConfig)
+
     # Training configuration
     num_substeps: int = 1  # Optimizer steps per batch
     loss_fn: LossFnType = "importance_sampling"
@@ -154,7 +170,6 @@ async def do_group_rollout_and_filter(
 
     return trajectory_group
 
-
 def compute_trajectory_metrics(
     trajectory_groups: list[TrajectoryGroup],
     taglist: list[list[str]] | None = None,
@@ -240,6 +255,8 @@ async def train_step(
     learning_rate: float,
     num_substeps: int,
     loss_fn: LossFnType,
+    max_grad_norm: float = 0.0,
+    loss_fn_config: dict[str, float] | None = None,
 ) -> list[torch.Tensor]:
     """
     Perform a training step with gradient accumulation.
@@ -250,6 +267,8 @@ async def train_step(
         learning_rate: Learning rate
         num_substeps: Number of optimizer steps
         loss_fn: Loss function type
+        max_grad_norm: Maximum gradient norm for clipping (0.0 = no clipping)
+        loss_fn_config: Optional loss function config (e.g. PPO clip thresholds)
 
     Returns:
         List of training logprobs tensors
@@ -262,8 +281,11 @@ async def train_step(
         batch = data[i : i + substep_size]
 
         # Forward-backward pass (remove mask key from datums)
+        fwd_bwd_kwargs: dict[str, Any] = {"loss_fn": loss_fn}
+        if loss_fn_config is not None:
+            fwd_bwd_kwargs["loss_fn_config"] = loss_fn_config
         fwd_bwd_future = await training_client.forward_backward_async(
-            [remove_mask(d) for d in batch], loss_fn=loss_fn
+            [remove_mask(d) for d in batch], **fwd_bwd_kwargs
        )
         fwd_bwd_result = await fwd_bwd_future.result_async()
 
@@ -272,7 +294,7 @@ async def train_step(
             training_logprobs.append(output["logprobs"].to_torch())
 
         # Optimizer step
-        adam_params = get_adam_params(learning_rate)
+        adam_params = get_adam_params(learning_rate, max_grad_norm=max_grad_norm)
         optim_future = await training_client.optim_step_async(adam_params)
         await optim_future.result_async()
 
@@ -335,6 +357,13 @@ async def run_training_loop(
     Args:
         cfg: Training configuration
     """
+    is_multiturn = cfg.multiturn.enabled
+    if is_multiturn:
+        logger.info("Running in MULTI-TURN mode")
+        logger.info(f"  max_turns (refinement turns per trajectory): {cfg.multiturn.max_turns}")
+        logger.info(f"  group_size (parallel trajectories, m): {cfg.dataset_builder.group_size}")
+        logger.info(f"  gamma (discount factor): {cfg.multiturn.gamma}")
+
     # Setup logging
     os.makedirs(cfg.log_path, exist_ok=True)
     ml_logger = ml_log.setup_logging(
@@ -415,12 +444,20 @@ async def run_training_loop(
     # Create dataset (pass tokenizer for renderer)
     dataset_builder = cfg.dataset_builder
-    logger.info("Using KernelBenchDatasetBuilder")
+    if is_multiturn:
+        logger.info("Using KernelBenchDatasetBuilder (multi-turn, max_turns=%d)", cfg.multiturn.max_turns)
+    else:
+        logger.info("Using KernelBenchDatasetBuilder")
 
     train_dataset, test_dataset = await dataset_builder(tokenizer=tokenizer)
     num_batches = len(train_dataset)
     logger.info(f"Training on {num_batches} batches")
 
+    # Warmup schedule (multi-turn only)
+    warmup_batches = int(num_batches * cfg.multiturn.warmup_ratio) if is_multiturn else 0
+    if warmup_batches > 0:
+        logger.info(f"Linear LR warmup for {warmup_batches} batches")
+
     # Get initial sampling client
     sampling_client, _ = await save_checkpoint_and_get_sampling_client(
         training_client, start_batch,
cfg.log_path, cfg.save_every, start_batch @@ -435,56 +472,165 @@ async def run_training_loop( "optim/lr": cfg.learning_rate, } - # Get batch of env group builders + # Get batch of env group builders (always single-turn from dataset) env_group_builders = train_dataset.get_batch(batch_idx) - # Collect rollouts (single-turn) - with timed("rollout", metrics): - try: - results = await asyncio.gather(*[ - do_group_rollout_and_filter( - sampling_client, - builder, - max_tokens=cfg.max_tokens, - temperature=cfg.temperature, - do_remove_constant_reward_groups=cfg.remove_constant_reward_groups, + if is_multiturn: + # Wrap single-turn builders as multi-turn + env_group_builders = wrap_builders_as_multiturn( + env_group_builders, cfg.multiturn, tokenizer + ) + + # ----- Multi-turn rollouts ----- + # Response length extension (multi-turn only) + effective_max_tokens = cfg.max_tokens + if ( + cfg.multiturn.max_tokens_extended > 0 + and batch_idx >= cfg.multiturn.max_tokens_extend_after_step + ): + effective_max_tokens = cfg.multiturn.max_tokens_extended + if batch_idx == cfg.multiturn.max_tokens_extend_after_step: + logger.info( + f"Extending max_tokens from {cfg.max_tokens} to " + f"{cfg.multiturn.max_tokens_extended} at step {batch_idx}" ) - for builder in env_group_builders - ], return_exceptions=True) - except Exception: - logger.exception("Group rollout failed during gather") - raise - - # Filter out None (removed constant reward groups) and exceptions - trajectory_groups = [] - for tg in results: - if isinstance(tg, Exception): - logger.error("Group rollout failed", exc_info=tg) - elif tg is not None: - trajectory_groups.append(tg) - if len(trajectory_groups) == 0: - logger.warning(f"Batch {batch_idx}: All groups filtered out, skipping") - continue - - # Compute metrics - traj_metrics = compute_trajectory_metrics(trajectory_groups) - metrics.update(traj_metrics) - - # Compute advantages and assemble training data - with timed("assemble_data", metrics): - advantages = compute_advantages(trajectory_groups) - data, _metadata = assemble_training_data(trajectory_groups, advantages) - - # Training step - with timed("train", metrics): - await train_step( - data, - training_client, - cfg.learning_rate, - cfg.num_substeps, - cfg.loss_fn, + with timed("rollout", metrics): + try: + results = await asyncio.gather( + *[ + do_multiturn_group_rollout_and_filter( + sampling_client, + builder, + max_tokens=effective_max_tokens, + temperature=cfg.multiturn.temperature, + do_remove_constant_reward_groups=cfg.remove_constant_reward_groups, + top_p=cfg.multiturn.top_p, + seed=cfg.multiturn.seed, + ) + for builder in env_group_builders + ], + return_exceptions=True, + ) + except Exception: + logger.exception("Group rollout failed during gather") + raise + + trajectory_groups: list[TrajectoryGroup] = [] + env_groups: list[Sequence[Env]] = [] + for r in results: + if isinstance(r, BaseException): + logger.error("Group rollout failed", exc_info=r) + continue + tg, envs = r + if tg is not None and envs is not None: + trajectory_groups.append(tg) + env_groups.append(envs) + + if len(trajectory_groups) == 0: + logger.warning( + f"Batch {batch_idx}: All groups filtered out, skipping" + ) + continue + + with timed("discount_returns", metrics): + apply_discounted_returns_to_trajectories( + trajectory_groups, env_groups, + gamma=cfg.multiturn.gamma, + aggregation=cfg.multiturn.aggregation, + ) + + traj_metrics = compute_multiturn_trajectory_metrics( + trajectory_groups, env_groups ) + metrics.update(traj_metrics) + + # Flatten: 
each turn becomes its own single-transition trajectory + # so that advantage normalization is across all group_size * n turn-level samples + with timed("flatten", metrics): + trajectory_groups = flatten_multiturn_trajectory_groups( + trajectory_groups + ) + + # Compute advantages and assemble training data + with timed("assemble_data", metrics): + advantages = compute_multiturn_advantages(trajectory_groups) + + if cfg.multiturn.constant_length_norm > 0: + for i in range(len(advantages)): + advantages[i] = advantages[i] / cfg.multiturn.constant_length_norm + + data, _metadata = assemble_training_data(trajectory_groups, advantages) + + # Learning rate warmup (multi-turn only) + if warmup_batches > 0 and batch_idx < warmup_batches: + lr = cfg.learning_rate * (batch_idx + 1) / warmup_batches + else: + lr = cfg.learning_rate + metrics["optim/lr"] = lr + + # Training step with PPO clip and grad norm + with timed("train", metrics): + await train_step( + data, + training_client, + lr, + cfg.multiturn.num_substeps, + cfg.multiturn.loss_fn, # type: ignore[arg-type] + max_grad_norm=cfg.multiturn.max_grad_norm, + loss_fn_config=build_loss_fn_config( + cfg.multiturn.clip_epsilon_low, + cfg.multiturn.clip_epsilon_high, + ), + ) + else: + # Collect rollouts (single-turn) + with timed("rollout", metrics): + try: + st_results = await asyncio.gather(*[ + do_group_rollout_and_filter( + sampling_client, + builder, + max_tokens=cfg.max_tokens, + temperature=cfg.temperature, + do_remove_constant_reward_groups=cfg.remove_constant_reward_groups, + ) + for builder in env_group_builders + ], return_exceptions=True) + except Exception: + logger.exception("Group rollout failed during gather") + raise + + # Filter out None (removed constant reward groups) and exceptions + trajectory_groups = [] + for tg in st_results: + if isinstance(tg, Exception): + logger.error("Group rollout failed", exc_info=tg) + elif tg is not None: + trajectory_groups.append(tg) + + if len(trajectory_groups) == 0: + logger.warning(f"Batch {batch_idx}: All groups filtered out, skipping") + continue + + # Compute metrics + traj_metrics = compute_trajectory_metrics(trajectory_groups) + metrics.update(traj_metrics) + + # Compute advantages and assemble training data + with timed("assemble_data", metrics): + advantages = compute_advantages(trajectory_groups) + data, _metadata = assemble_training_data(trajectory_groups, advantages) + + # Training step + with timed("train", metrics): + await train_step( + data, + training_client, + cfg.learning_rate, + cfg.num_substeps, + cfg.loss_fn, + ) # Save checkpoint and get new sampling client sampling_client, checkpoint_metrics = await save_checkpoint_and_get_sampling_client( @@ -501,14 +647,25 @@ async def run_training_loop( tb_logger.log_training_metrics(metrics, batch_idx) tb_logger.log_trajectory_histograms(trajectory_groups, batch_idx) tb_logger.log_per_level_metrics(trajectory_groups, batch_idx) - tb_logger.log_advantage_statistics(advantages, batch_idx) - - logger.info( - f"Batch {batch_idx}/{num_batches}: " - f"reward={metrics.get('reward/mean', 0):.3f}, " - f"compile={metrics.get('kernel/compile_rate', 0):.1%}, " - f"correct={metrics.get('kernel/correct_rate', 0):.1%}" - ) + adv_arrays = [a.numpy() if isinstance(a, torch.Tensor) else a for a in advantages] + tb_logger.log_advantage_statistics(adv_arrays, batch_idx) + + if is_multiturn: + logger.info( + f"Batch {batch_idx}/{num_batches}: " + f"raw_score={metrics.get('multiturn/raw_score_mean', 0):.3f}, " + 
f"compile={metrics.get('multiturn/compile_rate', 0):.1%}, " + f"correct={metrics.get('multiturn/correct_rate', 0):.1%}, " + f"success={metrics.get('multiturn/success_rate', 0):.1%}, " + f"avg_turns={metrics.get('multiturn/avg_turns', 0):.1f}" + ) + else: + logger.info( + f"Batch {batch_idx}/{num_batches}: " + f"reward={metrics.get('reward/mean', 0):.3f}, " + f"compile={metrics.get('kernel/compile_rate', 0):.1%}, " + f"correct={metrics.get('kernel/correct_rate', 0):.1%}" + ) # Save final checkpoint if start_batch < num_batches: diff --git a/src/kernelbench_tinker/training/models.py b/src/kernelbench_tinker/training/models.py index f1d6ee3..a7ede69 100644 --- a/src/kernelbench_tinker/training/models.py +++ b/src/kernelbench_tinker/training/models.py @@ -1,10 +1,19 @@ """ -Minimal model helpers for KernelBench ↔ Tinker integration. +Model and completer helpers for KernelBench ↔ Tinker integration. """ from __future__ import annotations import tinker +from tinker_cookbook.completers import ( + StopCondition, + TokenCompleter, + TokensWithLogprobs, +) + +# --------------------------------------------------------------------------- +# Renderer helpers +# --------------------------------------------------------------------------- def get_renderer_name_for_model(model_name: str) -> str: @@ -28,11 +37,86 @@ def get_renderer_name_for_model(model_name: str) -> str: return "role_colon" -def get_adam_params(learning_rate: float) -> tinker.AdamParams: +# --------------------------------------------------------------------------- +# Optimizer helpers +# --------------------------------------------------------------------------- + + +def get_adam_params( + learning_rate: float, + max_grad_norm: float = 0.0, +) -> tinker.AdamParams: """Get Adam optimizer parameters.""" - return tinker.AdamParams( - learning_rate=learning_rate, - beta1=0.9, - beta2=0.95, - eps=1e-8, - ) + kwargs: dict = { + "learning_rate": learning_rate, + "beta1": 0.9, + "beta2": 0.95, + "eps": 1e-8, + } + if max_grad_norm > 0: + kwargs["grad_clip_norm"] = max_grad_norm + return tinker.AdamParams(**kwargs) + + +# --------------------------------------------------------------------------- +# Token completers +# --------------------------------------------------------------------------- + + +class KernelBenchTokenCompleter(TokenCompleter): + """Token completer with top_p and seed support. + + TinkerTokenCompleter only accepts temperature. This subclass adds top_p + and seed, which the multi-turn training loop and eval script need. 
+ """ + + def __init__( + self, + sampling_client: tinker.SamplingClient, + max_tokens: int, + temperature: float = 1.0, + top_p: float = 1.0, + seed: int | None = None, + ): + self.sampling_client = sampling_client + self.max_tokens = max_tokens + self.temperature = temperature + self.top_p = top_p + self.seed = seed + + async def __call__( + self, model_input: tinker.ModelInput, stop: StopCondition + ) -> TokensWithLogprobs: + sample_result = await self.sampling_client.sample_async( + prompt=model_input, + num_samples=1, + sampling_params=tinker.SamplingParams( + stop=stop, + max_tokens=self.max_tokens, + temperature=self.temperature, + top_p=self.top_p, + seed=self.seed, + ), + ) + sampled_tokens = sample_result.sequences[0].tokens + sampled_logprobs = sample_result.sequences[0].logprobs + assert sampled_logprobs is not None + return TokensWithLogprobs(tokens=sampled_tokens, maybe_logprobs=sampled_logprobs) + + +# --------------------------------------------------------------------------- +# Loss function helpers +# --------------------------------------------------------------------------- + + +def build_loss_fn_config( + clip_epsilon_low: float = 0.0, + clip_epsilon_high: float = 0.0, +) -> dict[str, float] | None: + """Build loss_fn_config for PPO clip thresholds (passed to forward_backward_async).""" + if clip_epsilon_low <= 0: + return None + return { + "clip_low_threshold": 1.0 - clip_epsilon_low, + "clip_high_threshold": 1.0 + clip_epsilon_high, + } diff --git a/src/kernelbench_tinker/training/multiturn.py b/src/kernelbench_tinker/training/multiturn.py new file mode 100644 index 0000000..af7b5c8 --- /dev/null +++ b/src/kernelbench_tinker/training/multiturn.py @@ -0,0 +1,281 @@ +""" +Multi-turn rollout, advantage estimation, and metrics for KernelBench RL. + +These helpers are used by the training loop when multiturn.enabled is True. +Single-turn training does not touch this module. +""" + +from __future__ import annotations + +import asyncio +import logging +from collections import defaultdict +from typing import Any, Sequence + +import numpy as np +import tinker +import torch +from tinker_cookbook.rl.data_processing import remove_constant_reward_groups +from tinker_cookbook.rl.rollouts import do_single_rollout +from tinker_cookbook.rl.types import ( + Env, + EnvGroupBuilder, + Trajectory, + TrajectoryGroup, +) + +from kernelbench_tinker.envs.multiturn_kernelbench_env import MultiTurnKernelBenchEnv +from kernelbench_tinker.training.models import KernelBenchTokenCompleter +from kernelbench_tinker.training.reward import compute_discounted_returns + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Rollouts +# --------------------------------------------------------------------------- + + +async def do_multiturn_group_rollout_and_filter( + sampling_client: tinker.SamplingClient, + env_group_builder: EnvGroupBuilder, + max_tokens: int, + temperature: float, + do_remove_constant_reward_groups: bool, + top_p: float = 1.0, + seed: int | None = None, +) -> tuple[TrajectoryGroup | None, Sequence[Env] | None]: + """Multi-turn rollout that returns (trajectory_group, envs). + + We can't use do_group_rollout here because it doesn't return the envs, + and we need env access to read per-step scores for discounted returns. 
+ """ + policy = KernelBenchTokenCompleter( + sampling_client, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + seed=seed, + ) + + envs = await env_group_builder.make_envs() + rollout_results = await asyncio.gather( + *[do_single_rollout(policy, env) for env in envs], + return_exceptions=True, + ) + + trajectories = [] + valid_envs: list[Env] = [] + for traj, env in zip(rollout_results, envs): + if isinstance(traj, Exception): + logger.warning(f"Rollout failed: {traj}") + else: + trajectories.append(traj) + valid_envs.append(env) + + if not trajectories: + logger.warning("All rollouts in group failed") + return None, None + + # Final rewards are [0.0] because multi-turn rewards live in + # transition.reward (set by env.step) and are later overwritten by + # apply_discounted_returns. TrajectoryGroup.get_total_rewards() sums + # transition.reward + final_reward, so final_reward must be zero. + trajectory_group = TrajectoryGroup( + trajectories, + [0.0] * len(trajectories), + [{}] * len(trajectories), + ) + + if do_remove_constant_reward_groups: + trajectory_groups = remove_constant_reward_groups([trajectory_group]) + if len(trajectory_groups) == 0: + return None, None + trajectory_group = trajectory_groups[0] + + return trajectory_group, valid_envs + + +# --------------------------------------------------------------------------- +# Discounted returns +# --------------------------------------------------------------------------- + + +def apply_discounted_returns_to_trajectories( + trajectory_groups: list[TrajectoryGroup], + env_groups: list[Sequence[Env]], + gamma: float, + aggregation: str = "sum", +) -> None: + """Replace per-step rewards with discounted returns for multi-turn training.""" + for tg, envs in zip(trajectory_groups, env_groups): + for traj, env in zip(tg.trajectories_G, envs): + if isinstance(env, MultiTurnKernelBenchEnv): + step_scores = env.get_step_scores() + else: + step_scores = [t.reward for t in traj.transitions] + + if not step_scores: + continue + + returns = compute_discounted_returns(step_scores, gamma, aggregation) + for i, trans in enumerate(traj.transitions): + if i < len(returns): + trans.reward = returns[i] + + +# --------------------------------------------------------------------------- +# Flatten and advantage estimation +# --------------------------------------------------------------------------- + + +def flatten_multiturn_trajectory_groups( + trajectory_groups: list[TrajectoryGroup], +) -> list[TrajectoryGroup]: + """Flatten multi-turn trajectories so each turn is its own single-transition trajectory.""" + flattened = [] + for tg in trajectory_groups: + new_trajectories = [] + for traj in tg.trajectories_G: + for trans in traj.transitions: + new_trajectories.append( + Trajectory(transitions=[trans], final_ob=tinker.ModelInput.empty()) + ) + + # final_rewards must be 0.0 because get_total_rewards() sums + # transition.reward + final_reward. The real rewards already live + # in transition.reward (set by apply_discounted_returns). + new_group = TrajectoryGroup( + new_trajectories, + [0.0] * len(new_trajectories), + [{}] * len(new_trajectories), + ) + flattened.append(new_group) + return flattened + + +def compute_multiturn_advantages( + trajectory_groups: list[TrajectoryGroup], +) -> list[torch.Tensor]: + """GRPO advantage with std normalization. + + Expects flattened trajectory groups (each "trajectory" = one turn). + Normalizes across all m*n samples per problem group. 
+ """ + advantages_P = [] + for tg in trajectory_groups: + rewards = torch.tensor(tg.get_total_rewards()) + mean = rewards.mean() + std = rewards.std() + advantages = (rewards - mean) / (std + 1e-9) + advantages_P.append(advantages) + return advantages_P + + +# --------------------------------------------------------------------------- +# Metrics +# --------------------------------------------------------------------------- + + +def compute_multiturn_trajectory_metrics( + trajectory_groups: list[TrajectoryGroup], + env_groups: list[Sequence[Env]], +) -> dict[str, Any]: + """Compute aggregate metrics for multi-turn trajectories.""" + metrics: dict[str, Any] = {} + + turn_compiled: dict[int, list[float]] = defaultdict(list) + turn_correct: dict[int, list[float]] = defaultdict(list) + + all_rewards = [] + all_num_turns = [] + all_success = [] + all_best_speedup = [] + + all_format_ok = [] + all_compiled = [] + all_correct = [] + all_step_scores = [] + all_eval_times = [] + all_step_times = [] + + for tg, envs in zip(trajectory_groups, env_groups): + rewards = tg.get_total_rewards() + all_rewards.extend(rewards) + + for traj, env in zip(tg.trajectories_G, envs): + traj_speedups = [] + + for trans in traj.transitions: + if trans.metrics: + turn = trans.metrics.get("turn", 0) + compiled = trans.metrics.get("compiled", 0) + correct = trans.metrics.get("correctness", 0) + turn_compiled[turn].append(compiled) + turn_correct[turn].append(correct) + + all_format_ok.append(trans.metrics.get("format_ok", 0)) + all_compiled.append(compiled) + all_correct.append(correct) + + if "step_score" in trans.metrics: + all_step_scores.append(trans.metrics["step_score"]) + if "time/eval" in trans.metrics: + all_eval_times.append(trans.metrics["time/eval"]) + if "time/step_total" in trans.metrics: + all_step_times.append(trans.metrics["time/step_total"]) + if "speedup" in trans.metrics: + traj_speedups.append(trans.metrics["speedup"]) + + if traj_speedups: + all_best_speedup.append(max(traj_speedups)) + + if isinstance(env, MultiTurnKernelBenchEnv): + all_success.append(float(env.state.success)) + all_num_turns.append(env.state.turn_idx) + + if all_rewards: + metrics["reward/discounted_mean"] = float(np.mean(all_rewards)) + metrics["reward/discounted_std"] = float(np.std(all_rewards)) + metrics["reward/discounted_min"] = float(np.min(all_rewards)) + metrics["reward/discounted_max"] = float(np.max(all_rewards)) + + if all_format_ok: + metrics["multiturn/format_rate"] = float(np.mean(all_format_ok)) + if all_compiled: + metrics["multiturn/compile_rate"] = float(np.mean(all_compiled)) + if all_correct: + metrics["multiturn/correct_rate"] = float(np.mean(all_correct)) + if all_format_ok: + failures = [1.0 - (f and c and r) for f, c, r in zip(all_format_ok, all_compiled, all_correct)] + metrics["multiturn/failure_rate"] = float(np.mean(failures)) + if all_step_scores: + metrics["multiturn/raw_score_mean"] = float(np.mean(all_step_scores)) + if all_success: + metrics["multiturn/success_rate"] = float(np.mean(all_success)) + if all_num_turns: + metrics["multiturn/avg_turns"] = float(np.mean(all_num_turns)) + if all_best_speedup: + metrics["multiturn/best_speedup_mean"] = float(np.mean(all_best_speedup)) + if all_eval_times: + metrics["time/eval_mean"] = float(np.mean(all_eval_times)) + if all_step_times: + metrics["time/step_mean"] = float(np.mean(all_step_times)) + + for turn in sorted(turn_compiled.keys()): + if turn_compiled[turn]: + metrics[f"multiturn/turn_{turn}/compile_rate"] = float( + np.mean(turn_compiled[turn]) 
+ ) + if turn_correct[turn]: + metrics[f"multiturn/turn_{turn}/correct_rate"] = float( + np.mean(turn_correct[turn]) + ) + + metrics["batch/num_groups"] = len(trajectory_groups) + metrics["batch/num_trajectories"] = sum( + len(tg.trajectories_G) for tg in trajectory_groups + ) + metrics["batch/total_steps"] = len(all_step_scores) + + return metrics diff --git a/src/kernelbench_tinker/training/reward.py b/src/kernelbench_tinker/training/reward.py index 4d7f9e0..de8df0a 100644 --- a/src/kernelbench_tinker/training/reward.py +++ b/src/kernelbench_tinker/training/reward.py @@ -10,7 +10,7 @@ from __future__ import annotations -import math +import logging import sys from dataclasses import dataclass from pathlib import Path @@ -30,6 +30,8 @@ if TYPE_CHECKING: from kernelbench_tinker.envs.kernelbench_client import KernelEvalResult +logger = logging.getLogger(__name__) + @dataclass class RewardConfig: @@ -58,7 +60,7 @@ class RewardConfig: # Speed reward configuration # Linear speedup: reward = T_baseline / T_kernel # ========================================================================== - speed_baseline: float = 1.0 # Speedup threshold (1.0 = same as baseline) + speed_baseline: float = 0.0 # Speedup threshold (0.0 = raw speedup as reward) speed_scale: float = 1.0 # Linear scaling (not log) speed_max_reward: float = 10.0 # Cap to prevent outliers @@ -98,6 +100,12 @@ class RewardConfig: # Default: all warning checks from static_checker.WARNING_CHECKS static_checker_warnings: list[str] | None = None # None = use defaults (all warning checks) + # ========================================================================== + # Reward clipping configuration + # ========================================================================== + reward_clip_min: float | None = None # Lower bound on total reward (None = no clipping) + reward_clip_max: float | None = None # Upper bound on total reward (None = no clipping) + def format_reward(eval_result: "KernelEvalResult", config: RewardConfig) -> float: """ @@ -189,13 +197,11 @@ def speed_reward( if speedup is None or speedup <= 0: return 0.0 - # Linear speedup, not log-scaled - # If speedup <= baseline (1.0), no speed bonus - if speedup <= config.speed_baseline: + # Linear speedup: reward = speedup for correct kernels + # With default speed_baseline=0.0: 2x speedup = 2.0 reward + if config.speed_baseline > 0 and speedup <= config.speed_baseline: return 0.0 - # Linear reward: speedup - 1.0 (so 2x speedup = 1.0 reward, 3x = 2.0, etc.) 
-    # This matches the default formula where reward = speedup for correct kernels
     reward = config.speed_scale * (speedup - config.speed_baseline)
 
     # Clamp to max to prevent outliers
@@ -337,16 +343,11 @@ def compute_reward(
         )
 
         # Log warnings (don't zero reward)
-        if warnings:
-            import logging
-            logger = logging.getLogger(__name__)
-            for warning in warnings:
-                logger.warning(f"Static checker warning: {warning}")
-
+        for warning in warnings:
+            logger.warning(f"Static checker warning: {warning}")
+
+        # Zero reward if errors detected
         if has_errors:
-            import logging
-            logger = logging.getLogger(__name__)
             for error in errors:
                 logger.error(f"Reward hacking detected (reward set to 0): {error}")
             return 0.0
@@ -401,6 +402,11 @@
         l_reward = length_reward(eval_result, config)
         total += config.length_weight * l_reward
 
+    # Reward clipping
+    if config.reward_clip_min is not None:
+        total = max(total, config.reward_clip_min)
+    if config.reward_clip_max is not None:
+        total = min(total, config.reward_clip_max)
     return total
 
 
@@ -446,3 +452,35 @@
         "static_checker_errors": static_checker_errors,
         "static_checker_warnings": static_checker_warnings,
     }
+
+
+def compute_discounted_returns(
+    step_scores: list[float],
+    gamma: float = 0.4,
+    aggregation: str = "sum",
+) -> list[float]:
+    """Compute discounted returns for multi-turn RL.
+
+    sum: R_t = S_t + gamma * R_{t+1}  (backward recursion)
+    max: R_t = max{ gamma^(i-t) * S_i }
+    """
+    if aggregation not in ("sum", "max"):
+        raise ValueError(f"Unknown aggregation mode: {aggregation!r}. Must be 'sum' or 'max'.")
+
+    if not step_scores:
+        return []
+
+    T = len(step_scores)
+
+    if aggregation == "sum":
+        returns = [0.0] * T
+        returns[-1] = step_scores[-1]
+        for t in range(T - 2, -1, -1):
+            returns[t] = step_scores[t] + gamma * returns[t + 1]
+        return returns
+
+    # aggregation == "max"
+    return [
+        max(gamma ** (i - t) * step_scores[i] for i in range(t, T))
+        for t in range(T)
+    ]
diff --git a/src/kernelbench_tinker/training/tensorboard_logger.py b/src/kernelbench_tinker/training/tensorboard_logger.py
index 891aa6b..7eea9b4 100644
--- a/src/kernelbench_tinker/training/tensorboard_logger.py
+++ b/src/kernelbench_tinker/training/tensorboard_logger.py
@@ -17,9 +17,8 @@
 from typing import Any, Sequence
 
 import numpy as np
-from torch.utils.tensorboard import SummaryWriter
-
 from tinker_cookbook.rl.types import TrajectoryGroup
+from torch.utils.tensorboard import SummaryWriter
 
 logger = logging.getLogger(__name__)
 
@@ -199,6 +198,41 @@ def log_training_metrics(
         if "kernel/mean_speedup" in metrics:
             self.writer.add_scalar("Kernel/MeanSpeedup", metrics["kernel/mean_speedup"], step)
 
+        # === Multi-turn Reward Metrics ===
+        # Multi-turn emits reward/discounted_{mean,std,min,max}
+        if "reward/discounted_mean" in metrics:
+            self.writer.add_scalar("Reward/DiscountedMean", metrics["reward/discounted_mean"], step)
+        if "reward/discounted_std" in metrics:
+            self.writer.add_scalar("Reward/DiscountedStdDev", metrics["reward/discounted_std"], step)
+        if "reward/discounted_min" in metrics:
+            self.writer.add_scalar("Reward/DiscountedMin", metrics["reward/discounted_min"], step)
+        if "reward/discounted_max" in metrics:
+            self.writer.add_scalar("Reward/DiscountedMax", metrics["reward/discounted_max"], step)
+
+        # === Multi-turn Kernel Quality Metrics ===
+        if "multiturn/format_rate" in metrics:
+            self.writer.add_scalar("MultiTurn/FormatRate", metrics["multiturn/format_rate"], step)
+        if "multiturn/compile_rate" in metrics:
+
self.writer.add_scalar("MultiTurn/CompileRate", metrics["multiturn/compile_rate"], step) + if "multiturn/correct_rate" in metrics: + self.writer.add_scalar("MultiTurn/CorrectRate", metrics["multiturn/correct_rate"], step) + + # === Multi-turn Failure Rate === + if "kernel/failure_rate" in metrics: + self.writer.add_scalar("Kernel/FailureRate", metrics["kernel/failure_rate"], step) + if "multiturn/failure_rate" in metrics: + self.writer.add_scalar("MultiTurn/FailureRate", metrics["multiturn/failure_rate"], step) + + # === Multi-turn Specific Metrics === + if "multiturn/raw_score_mean" in metrics: + self.writer.add_scalar("MultiTurn/RawScoreMean", metrics["multiturn/raw_score_mean"], step) + if "multiturn/success_rate" in metrics: + self.writer.add_scalar("MultiTurn/SuccessRate", metrics["multiturn/success_rate"], step) + if "multiturn/avg_turns" in metrics: + self.writer.add_scalar("MultiTurn/AvgTurns", metrics["multiturn/avg_turns"], step) + if "multiturn/best_speedup_mean" in metrics: + self.writer.add_scalar("MultiTurn/BestSpeedupMean", metrics["multiturn/best_speedup_mean"], step) + def log_trajectory_histograms( self, trajectory_groups: list[TrajectoryGroup], diff --git a/uv.lock b/uv.lock index 19e2d97..b5445b7 100644 --- a/uv.lock +++ b/uv.lock @@ -2,9 +2,12 @@ version = 1 revision = 3 requires-python = ">=3.11" resolution-markers = [ - "python_full_version >= '3.14'", - "python_full_version >= '3.12' and python_full_version < '3.14'", - "python_full_version < '3.12'", + "python_full_version >= '3.14' and sys_platform == 'linux'", + "python_full_version >= '3.14' and sys_platform != 'linux'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'linux'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform != 'linux'", + "python_full_version < '3.12' and sys_platform == 'linux'", + "python_full_version < '3.12' and sys_platform != 'linux'", ] [[package]] @@ -1054,13 +1057,39 @@ http = [ { name = "aiohttp" }, ] +[[package]] +name = "gitdb" +version = "4.0.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "smmap" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/94/63b0fc47eb32792c7ba1fe1b694daec9a63620db1e313033d18140c2320a/gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571", size = 394684, upload-time = "2025-01-02T07:20:46.413Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf", size = 62794, upload-time = "2025-01-02T07:20:43.624Z" }, +] + +[[package]] +name = "gitpython" +version = "3.1.46" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "gitdb" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/df/b5/59d16470a1f0dfe8c793f9ef56fd3826093fc52b3bd96d6b9d6c26c7e27b/gitpython-3.1.46.tar.gz", hash = "sha256:400124c7d0ef4ea03f7310ac2fbf7151e09ff97f2a3288d64a440c584a29c37f", size = 215371, upload-time = "2026-01-01T15:37:32.073Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6a/09/e21df6aef1e1ffc0c816f0522ddc3f6dcded766c3261813131c78a704470/gitpython-3.1.46-py3-none-any.whl", hash = "sha256:79812ed143d9d25b6d176a10bb511de0f9c67b1fa641d82097b0ab90398a2058", size = 208620, upload-time = "2026-01-01T15:37:30.574Z" }, +] + [[package]] name = "grpcio" version = "1.67.1" source 
= { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.12' and python_full_version < '3.14'", - "python_full_version < '3.12'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'linux'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform != 'linux'", + "python_full_version < '3.12' and sys_platform == 'linux'", + "python_full_version < '3.12' and sys_platform != 'linux'", ] sdist = { url = "https://files.pythonhosted.org/packages/20/53/d9282a66a5db45981499190b77790570617a604a38f3d103d0400974aeb5/grpcio-1.67.1.tar.gz", hash = "sha256:3dc2ed4cabea4dc14d5e708c2b426205956077cc5de419b4d4079315017e9732", size = 12580022, upload-time = "2024-10-29T06:30:07.787Z" } wheels = [ @@ -1098,7 +1127,8 @@ name = "grpcio" version = "1.76.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.14'", + "python_full_version >= '3.14' and sys_platform == 'linux'", + "python_full_version >= '3.14' and sys_platform != 'linux'", ] dependencies = [ { name = "typing-extensions", marker = "python_full_version >= '3.14'" }, @@ -1675,7 +1705,7 @@ wheels = [ [[package]] name = "kernelbench" version = "0.2.0.dev0" -source = { directory = "../KernelBench" } +source = { editable = "KernelBench" } dependencies = [ { name = "datasets" }, { name = "einops" }, @@ -1703,6 +1733,7 @@ requires-dist = [ { name = "litellm", extras = ["proxy"] }, { name = "modal" }, { name = "ninja" }, + { name = "nsight-python", marker = "extra == 'gpu'" }, { name = "numpy" }, { name = "nvidia-cutlass-dsl", marker = "extra == 'gpu'" }, { name = "openai" }, @@ -1724,7 +1755,7 @@ provides-extras = ["gpu", "dev"] [[package]] name = "kernelbench-tinker" -version = "0.1.0" +version = "0.2.0" source = { editable = "." 
} dependencies = [ { name = "chz" }, @@ -1741,6 +1772,7 @@ dependencies = [ { name = "tomli" }, { name = "torch" }, { name = "tqdm" }, + { name = "wandb" }, ] [package.optional-dependencies] @@ -1757,7 +1789,7 @@ requires-dist = [ { name = "chz" }, { name = "datasets" }, { name = "isort", marker = "extra == 'dev'" }, - { name = "kernelbench", directory = "../KernelBench" }, + { name = "kernelbench", editable = "KernelBench" }, { name = "litellm" }, { name = "modal", specifier = ">=0.64.0" }, { name = "mypy", marker = "extra == 'dev'" }, @@ -1771,6 +1803,7 @@ requires-dist = [ { name = "tomli" }, { name = "torch", specifier = ">=2.9.0" }, { name = "tqdm" }, + { name = "wandb" }, ] provides-extras = ["dev"] @@ -4097,6 +4130,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a6/24/4d91e05817e92e3a61c8a21e08fd0f390f5301f1c448b137c57c4bc6e543/semver-3.0.4-py3-none-any.whl", hash = "sha256:9c824d87ba7f7ab4a1890799cec8596f15c1241cb473404ea1cb0c55e4b04746", size = 17912, upload-time = "2025-01-24T13:19:24.949Z" }, ] +[[package]] +name = "sentry-sdk" +version = "2.53.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d3/06/66c8b705179bc54087845f28fd1b72f83751b6e9a195628e2e9af9926505/sentry_sdk-2.53.0.tar.gz", hash = "sha256:6520ef2c4acd823f28efc55e43eb6ce2e6d9f954a95a3aa96b6fd14871e92b77", size = 412369, upload-time = "2026-02-16T11:11:14.743Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/d4/2fdf854bc3b9c7f55219678f812600a20a138af2dd847d99004994eada8f/sentry_sdk-2.53.0-py2.py3-none-any.whl", hash = "sha256:46e1ed8d84355ae54406c924f6b290c3d61f4048625989a723fd622aab838899", size = 437908, upload-time = "2026-02-16T11:11:13.227Z" }, +] + [[package]] name = "setuptools" version = "80.9.0" @@ -4133,6 +4179,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, ] +[[package]] +name = "smmap" +version = "5.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/44/cd/a040c4b3119bbe532e5b0732286f805445375489fceaec1f48306068ee3b/smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5", size = 22329, upload-time = "2025-01-02T07:14:40.909Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/be/d09147ad1ec7934636ad912901c5fd7667e1c858e19d355237db0d0cd5e4/smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e", size = 24303, upload-time = "2025-01-02T07:14:38.724Z" }, +] + [[package]] name = "sniffio" version = "1.3.1" @@ -4741,6 +4796,35 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/63/9a/0962b05b308494e3202d3f794a6e85abe471fe3cafdbcf95c2e8c713aabd/uvloop-0.21.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a5c39f217ab3c663dc699c04cbd50c13813e31d917642d459fdcec07555cc553", size = 4660018, upload-time = "2024-10-14T23:38:10.888Z" }, ] +[[package]] +name = "wandb" +version = "0.25.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "gitpython" }, + { name = "packaging" }, + { name = "platformdirs" }, + { name = "protobuf" }, + { name = "pydantic" }, 
+ { name = "pyyaml" }, + { name = "requests" }, + { name = "sentry-sdk" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fd/60/d94952549920469524b689479c864c692ca47eca4b8c2fe3389b64a58778/wandb-0.25.0.tar.gz", hash = "sha256:45840495a288e34245d69d07b5a0b449220fbc5b032e6b51c4f92ec9026d2ad1", size = 43951335, upload-time = "2026-02-13T00:17:45.515Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/7d/0c131db3ec9deaabbd32263d90863cbfbe07659527e11c35a5c738cecdc5/wandb-0.25.0-py3-none-macosx_12_0_arm64.whl", hash = "sha256:5eecb3c7b5e60d1acfa4b056bfbaa0b79a482566a9db58c9f99724b3862bc8e5", size = 23287536, upload-time = "2026-02-13T00:17:20.265Z" }, + { url = "https://files.pythonhosted.org/packages/c3/95/31bb7f76a966ec87495e5a72ac7570685be162494c41757ac871768dbc4f/wandb-0.25.0-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:daeedaadb183dc466e634fba90ab2bab1d4e93000912be0dee95065a0624a3fd", size = 25196062, upload-time = "2026-02-13T00:17:23.356Z" }, + { url = "https://files.pythonhosted.org/packages/d9/a1/258cdedbf30cebc692198a774cf0ef945b7ed98ee64bdaf62621281c95d8/wandb-0.25.0-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:5e0127dbcef13eea48f4b84268da7004d34d3120ebc7b2fa9cefb72b49dbb825", size = 22799744, upload-time = "2026-02-13T00:17:26.437Z" }, + { url = "https://files.pythonhosted.org/packages/de/91/ec9465d014cfd199c5b2083d271d31b3c2aedeae66f3d8a0712f7f54bdf3/wandb-0.25.0-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:6c4c38077836f9b7569a35b0e1dcf1f0c43616fcd936d182f475edbfea063665", size = 25262839, upload-time = "2026-02-13T00:17:28.8Z" }, + { url = "https://files.pythonhosted.org/packages/c7/95/cb2d1c7143f534544147fb53fe87944508b8cb9a058bc5b6f8a94adbee15/wandb-0.25.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6edd8948d305cb73745bf564b807bd73da2ccbd47c548196b8a362f7df40aed8", size = 22853714, upload-time = "2026-02-13T00:17:31.68Z" }, + { url = "https://files.pythonhosted.org/packages/d7/94/68163f70c1669edcf130822aaaea782d8198b5df74443eca0085ec596774/wandb-0.25.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:ada6f08629bb014ad6e0a19d5dec478cdaa116431baa3f0a4bf4ab8d9893611f", size = 25358037, upload-time = "2026-02-13T00:17:34.676Z" }, + { url = "https://files.pythonhosted.org/packages/cc/fb/9578eed2c01b2fc6c8b693da110aa9c73a33d7bb556480f5cfc42e48c94e/wandb-0.25.0-py3-none-win32.whl", hash = "sha256:020b42ca4d76e347709d65f59b30d4623a115edc28f462af1c92681cb17eae7c", size = 24604118, upload-time = "2026-02-13T00:17:37.641Z" }, + { url = "https://files.pythonhosted.org/packages/25/97/460f6cb738aaa39b4eb2e6b4c630b2ae4321cdd70a79d5955ea75a878981/wandb-0.25.0-py3-none-win_amd64.whl", hash = "sha256:78307ac0b328f2dc334c8607bec772851215584b62c439eb320c4af4fb077a00", size = 24604122, upload-time = "2026-02-13T00:17:39.991Z" }, + { url = "https://files.pythonhosted.org/packages/27/6c/5847b4dda1dfd52630dac08711d4348c69ed657f0698fc2d949c7f7a6622/wandb-0.25.0-py3-none-win_arm64.whl", hash = "sha256:c6174401fd6fb726295e98d57b4231c100eca96bd17de51bfc64038a57230aaf", size = 21785298, upload-time = "2026-02-13T00:17:42.475Z" }, +] + [[package]] name = "watchfiles" version = "1.1.1"