diff --git a/crates/vm/src/arch/testing/cpu.rs b/crates/vm/src/arch/testing/cpu.rs index 70c374968c..ae2e9e906d 100644 --- a/crates/vm/src/arch/testing/cpu.rs +++ b/crates/vm/src/arch/testing/cpu.rs @@ -37,8 +37,9 @@ use crate::{ testing::{ execution::air::ExecutionDummyAir, program::{air::ProgramDummyAir, ProgramTester}, - ExecutionTester, MemoryTester, TestBuilder, TestChipHarness, EXECUTION_BUS, MEMORY_BUS, - MEMORY_MERKLE_BUS, POSEIDON2_DIRECT_BUS, RANGE_CHECKER_BUS, READ_INSTRUCTION_BUS, + ExecutionTester, MemoryTester, TestBuilder, TestChipHarness, EXECUTION_BUS, HINT_BUS, + MEMORY_BUS, MEMORY_MERKLE_BUS, POSEIDON2_DIRECT_BUS, RANGE_CHECKER_BUS, + READ_INSTRUCTION_BUS, }, vm_poseidon2_config, Arena, ExecutionBridge, ExecutionBus, ExecutionState, MatrixRecordArena, MemoryConfig, PreflightExecutor, Streams, VmStateMut, @@ -46,7 +47,7 @@ use crate::{ system::{ memory::{ adapter::records::arena_size_bound, - offline_checker::{MemoryBridge, MemoryBus}, + offline_checker::{HintBridge, HintBus, MemoryBridge, MemoryBus}, online::TracingMemory, MemoryAirInventory, MemoryController, SharedMemoryHelper, CHUNK, }, @@ -258,10 +259,13 @@ impl VmChipTestBuilder { } pub fn system_port(&self) -> SystemPort { + let hint_bus = HintBus::new(HINT_BUS); + let hint_bridge = HintBridge::new(hint_bus); SystemPort { execution_bus: self.execution.bus, program_bus: self.program.bus, memory_bridge: self.memory_bridge(), + hint_bridge, } } diff --git a/crates/vm/src/arch/testing/cuda.rs b/crates/vm/src/arch/testing/cuda.rs index 0427f50671..e53ded66cc 100644 --- a/crates/vm/src/arch/testing/cuda.rs +++ b/crates/vm/src/arch/testing/cuda.rs @@ -50,7 +50,7 @@ use crate::{ execution::{air::ExecutionDummyAir, DeviceExecutionTester}, memory::DeviceMemoryTester, program::{air::ProgramDummyAir, DeviceProgramTester}, - TestBuilder, TestChipHarness, EXECUTION_BUS, MEMORY_BUS, MEMORY_MERKLE_BUS, + TestBuilder, TestChipHarness, EXECUTION_BUS, HINT_BUS, MEMORY_BUS, MEMORY_MERKLE_BUS, POSEIDON2_DIRECT_BUS, READ_INSTRUCTION_BUS, }, Arena, DenseRecordArena, ExecutionBridge, ExecutionBus, ExecutionState, MatrixRecordArena, @@ -59,7 +59,7 @@ use crate::{ system::{ cuda::{poseidon2::Poseidon2PeripheryChipGPU, DIGEST_WIDTH}, memory::{ - offline_checker::{MemoryBridge, MemoryBus}, + offline_checker::{HintBridge, HintBus, MemoryBridge, MemoryBus}, MemoryAirInventory, SharedMemoryHelper, }, poseidon2::air::Poseidon2PeripheryAir, @@ -393,8 +393,14 @@ impl GpuChipTestBuilder { execution_bus: self.execution_bus(), program_bus: self.program_bus(), memory_bridge: self.memory_bridge(), + hint_bridge: self.hint_bridge(), } } + + pub fn hint_bridge(&self) -> HintBridge { + let hint_bus = HintBus::new(HINT_BUS); + HintBridge::new(hint_bus) + } pub fn execution_bridge(&self) -> ExecutionBridge { ExecutionBridge::new(self.execution.bus(), self.program.bus()) } diff --git a/crates/vm/src/arch/testing/mod.rs b/crates/vm/src/arch/testing/mod.rs index 5293a0275a..52a9385e25 100644 --- a/crates/vm/src/arch/testing/mod.rs +++ b/crates/vm/src/arch/testing/mod.rs @@ -29,6 +29,7 @@ pub const BITWISE_OP_LOOKUP_BUS: BusIndex = 9; pub const BYTE_XOR_BUS: BusIndex = 10; pub const RANGE_TUPLE_CHECKER_BUS: BusIndex = 11; pub const MEMORY_MERKLE_BUS: BusIndex = 12; +pub const HINT_BUS: BusIndex = 13; pub const RANGE_CHECKER_BUS: BusIndex = 4; diff --git a/crates/vm/src/system/memory/offline_checker/bridge.rs b/crates/vm/src/system/memory/offline_checker/bridge.rs index 367e1344d7..867348fb57 100644 --- a/crates/vm/src/system/memory/offline_checker/bridge.rs +++ b/crates/vm/src/system/memory/offline_checker/bridge.rs @@ -10,7 +10,7 @@ use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::AirBuilder, p3_field::FieldAlgebra, }; -use super::bus::MemoryBus; +use super::bus::{HintBus, MemoryBus}; use crate::system::memory::{ offline_checker::columns::{ MemoryBaseAuxCols, MemoryReadAuxCols, MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols, @@ -326,3 +326,52 @@ impl MemoryOfflineChecker { .eval(builder, enabled); } } + +/// The [HintBridge] is used to constrain hint space lookups. +/// Consumer chips call `lookup` to verify that values they read from hint_space +/// match what was originally loaded via the hint bus lookup table. +#[derive(Clone, Copy, Debug)] +pub struct HintBridge { + hint_bus: HintBus, +} + +impl HintBridge { + /// Create a new [HintBridge] with the provided hint bus. + pub fn new(hint_bus: HintBus) -> Self { + Self { hint_bus } + } + + pub fn hint_bus(&self) -> HintBus { + self.hint_bus + } + + /// Perform a lookup on the hint bus for a single element. + /// + /// Constrains that `(hint_id, offset, value)` exists in the hint lookup table. + /// Caller must constrain that `enabled` is boolean. + pub fn lookup( + &self, + builder: &mut AB, + hint_id: impl Into, + offset: impl Into, + value: impl Into, + enabled: impl Into, + ) { + self.hint_bus.lookup(builder, hint_id, offset, value, enabled); + } + + /// Add a key to the hint lookup table. + /// + /// Provider chips call this to register that `(hint_id, offset, value)` is available. + pub fn provide( + &self, + builder: &mut AB, + hint_id: impl Into, + offset: impl Into, + value: impl Into, + num_lookups: impl Into, + ) { + self.hint_bus + .provide(builder, hint_id, offset, value, num_lookups); + } +} diff --git a/crates/vm/src/system/memory/offline_checker/bus.rs b/crates/vm/src/system/memory/offline_checker/bus.rs index d15f5798ea..4d5b9ef5cd 100644 --- a/crates/vm/src/system/memory/offline_checker/bus.rs +++ b/crates/vm/src/system/memory/offline_checker/bus.rs @@ -1,7 +1,7 @@ use std::iter; use openvm_stark_backend::{ - interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, + interaction::{BusIndex, InteractionBuilder, LookupBus, PermutationCheckBus}, p3_field::FieldAlgebra, }; @@ -101,3 +101,65 @@ impl MemoryBusInteraction { } } } + +/// Represents a hint bus identified by a unique bus index. +/// Used as a lookup table to constrain values read from hint space. +/// +/// Consumer chips (e.g. NativeSumcheck) perform lookups to verify that +/// hint_space values match what was originally loaded. +/// Provider chips (e.g. a hint loader) add keys to the lookup table. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct HintBus { + pub inner: LookupBus, +} + +impl HintBus { + pub const fn new(index: BusIndex) -> Self { + Self { + inner: LookupBus::new(index), + } + } + + #[inline(always)] + pub fn index(&self) -> BusIndex { + self.inner.index + } + + /// Performs a lookup on the hint bus. + /// + /// Asserts that `(hint_id, offset, value)` is present in the hint lookup table. + /// Caller must constrain that `enabled` is boolean. + pub fn lookup( + &self, + builder: &mut AB, + hint_id: impl Into, + offset: impl Into, + value: impl Into, + enabled: impl Into, + ) { + self.inner.lookup_key( + builder, + [hint_id.into(), offset.into(), value.into()], + enabled, + ); + } + + /// Adds a key to the hint lookup table. + /// + /// The `num_lookups` parameter should equal the number of enabled lookups performed + /// for this key. + pub fn provide( + &self, + builder: &mut AB, + hint_id: impl Into, + offset: impl Into, + value: impl Into, + num_lookups: impl Into, + ) { + self.inner.add_key_with_lookups( + builder, + [hint_id.into(), offset.into(), value.into()], + num_lookups, + ); + } +} diff --git a/crates/vm/src/system/memory/offline_checker/columns.rs b/crates/vm/src/system/memory/offline_checker/columns.rs index 63c114193d..f614a43026 100644 --- a/crates/vm/src/system/memory/offline_checker/columns.rs +++ b/crates/vm/src/system/memory/offline_checker/columns.rs @@ -122,4 +122,4 @@ impl AsMut> for MemoryReadOrImmediateAuxCols { fn as_mut(&mut self) -> &mut MemoryBaseAuxCols { &mut self.base } -} +} \ No newline at end of file diff --git a/crates/vm/src/system/mod.rs b/crates/vm/src/system/mod.rs index 8c0d2a3d37..5bb9a4be07 100644 --- a/crates/vm/src/system/mod.rs +++ b/crates/vm/src/system/mod.rs @@ -36,7 +36,7 @@ use crate::{ connector::VmConnectorChip, memory::{ interface::MemoryInterfaceAirs, - offline_checker::{MemoryBridge, MemoryBus}, + offline_checker::{HintBridge, HintBus, MemoryBridge, MemoryBus}, online::GuestMemory, MemoryAirInventory, MemoryController, TimestampedEquipartition, CHUNK, }, @@ -149,6 +149,7 @@ pub struct SystemPort { pub execution_bus: ExecutionBus, pub program_bus: ProgramBus, pub memory_bridge: MemoryBridge, + pub hint_bridge: HintBridge, } #[derive(Clone)] @@ -156,6 +157,7 @@ pub struct SystemAirInventory { pub program: ProgramAir, pub connector: VmConnectorAir, pub memory: MemoryAirInventory, + pub hint_bridge: HintBridge, /// Public values AIR exists if and only if continuations is disabled and `num_public_values` /// is greater than 0. pub public_values: Option, @@ -171,6 +173,7 @@ impl SystemAirInventory { execution_bus, program_bus, memory_bridge, + hint_bridge, } = port; let range_bus = memory_bridge.range_bus(); let program = ProgramAir::new(program_bus); @@ -212,6 +215,7 @@ impl SystemAirInventory { program, connector, memory, + hint_bridge, public_values, } } @@ -221,6 +225,7 @@ impl SystemAirInventory { memory_bridge: self.memory.bridge, program_bus: self.program.bus, execution_bus: self.connector.execution_bus, + hint_bridge: self.hint_bridge, } } @@ -300,10 +305,13 @@ impl VmCircuitConfig for SystemConfig { }; let memory_bridge = MemoryBridge::new(memory_bus, self.memory_config.timestamp_max_bits, range_bus); + let hint_bus = HintBus::new(bus_idx_mgr.new_bus_idx()); + let hint_bridge = HintBridge::new(hint_bus); let system_port = SystemPort { execution_bus, program_bus, memory_bridge, + hint_bridge, }; let system = SystemAirInventory::new(self, system_port, merkle_compression_buses); diff --git a/extensions/algebra/circuit/src/extension/fp2.rs b/extensions/algebra/circuit/src/extension/fp2.rs index 3081c88565..ce0938af4a 100644 --- a/extensions/algebra/circuit/src/extension/fp2.rs +++ b/extensions/algebra/circuit/src/extension/fp2.rs @@ -175,6 +175,7 @@ impl VmCircuitExtension for Fp2Extension { execution_bus, program_bus, memory_bridge, + hint_bridge: _, } = inventory.system().port(); let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); diff --git a/extensions/algebra/circuit/src/extension/modular.rs b/extensions/algebra/circuit/src/extension/modular.rs index 8946daa9c3..5c5be255fe 100644 --- a/extensions/algebra/circuit/src/extension/modular.rs +++ b/extensions/algebra/circuit/src/extension/modular.rs @@ -231,6 +231,7 @@ impl VmCircuitExtension for ModularExtension { execution_bus, program_bus, memory_bridge, + hint_bridge: _, } = inventory.system().port(); let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); diff --git a/extensions/bigint/circuit/src/extension/mod.rs b/extensions/bigint/circuit/src/extension/mod.rs index 1725a4860d..90e3688d0d 100644 --- a/extensions/bigint/circuit/src/extension/mod.rs +++ b/extensions/bigint/circuit/src/extension/mod.rs @@ -143,6 +143,7 @@ impl VmCircuitExtension for Int256 { execution_bus, program_bus, memory_bridge, + hint_bridge: _, } = inventory.system().port(); let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); diff --git a/extensions/ecc/circuit/src/extension/weierstrass.rs b/extensions/ecc/circuit/src/extension/weierstrass.rs index 5048584183..10181b1ea5 100644 --- a/extensions/ecc/circuit/src/extension/weierstrass.rs +++ b/extensions/ecc/circuit/src/extension/weierstrass.rs @@ -200,6 +200,7 @@ impl VmCircuitExtension for WeierstrassExtension { execution_bus, program_bus, memory_bridge, + hint_bridge: _, } = inventory.system().port(); let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); diff --git a/extensions/keccak256/circuit/src/extension/mod.rs b/extensions/keccak256/circuit/src/extension/mod.rs index 9f6e55a540..bbdd64cc06 100644 --- a/extensions/keccak256/circuit/src/extension/mod.rs +++ b/extensions/keccak256/circuit/src/extension/mod.rs @@ -148,6 +148,7 @@ impl VmCircuitExtension for Keccak256 { execution_bus, program_bus, memory_bridge, + hint_bridge: _, } = inventory.system().port(); let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); diff --git a/extensions/native/circuit/cuda/include/native/sumcheck.cuh b/extensions/native/circuit/cuda/include/native/sumcheck.cuh index a7d6eee536..3ba5a27195 100644 --- a/extensions/native/circuit/cuda/include/native/sumcheck.cuh +++ b/extensions/native/circuit/cuda/include/native/sumcheck.cuh @@ -82,7 +82,10 @@ template struct NativeSumcheckCols { T eval_acc[EXT_DEG]; - T is_hint_src_id; + T is_writeback; + + T prod_hint_id; + T logup_hint_id; T specific[COL_SPECIFIC_WIDTH]; }; diff --git a/extensions/native/circuit/cuda/src/sumcheck.cu b/extensions/native/circuit/cuda/src/sumcheck.cu index 139a56473f..224703517e 100644 --- a/extensions/native/circuit/cuda/src/sumcheck.cu +++ b/extensions/native/circuit/cuda/src/sumcheck.cu @@ -32,34 +32,63 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h ); } else if (row[COL_INDEX(NativeSumcheckCols, prod_row)] == Fp::one()) { if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { - mem_fill_base( - mem_helper, - start_timestamp, - specific.slice_from(COL_INDEX(ProdSpecificCols, ps_record.base)) - ); - mem_fill_base( - mem_helper, - start_timestamp + 1, - specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base)) - ); + if (row[COL_INDEX(NativeSumcheckCols, is_writeback)] == Fp::one()) { + // writeback p1, p2 + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(ProdSpecificCols, ps_record.base)) + ); + // write p_eval + mem_fill_base( + mem_helper, + start_timestamp + 1, + specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base)) + ); + } else { + // write p_eval only + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base)) + ); + } } } else if (row[COL_INDEX(NativeSumcheckCols, logup_row)] == Fp::one()) { if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { - mem_fill_base( - mem_helper, - start_timestamp, - specific.slice_from(COL_INDEX(LogupSpecificCols, pqs_record.base)) - ); - mem_fill_base( - mem_helper, - start_timestamp + 1, - specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base)) - ); - mem_fill_base( - mem_helper, - start_timestamp + 2, - specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base)) - ); + if (row[COL_INDEX(NativeSumcheckCols, is_writeback)] == Fp::one()) { + // writeback p1, p2, q1, q2 + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(LogupSpecificCols, pqs_record.base)) + ); + // write p_eval + mem_fill_base( + mem_helper, + start_timestamp + 1, + specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base)) + ); + // write q_eval + mem_fill_base( + mem_helper, + start_timestamp + 2, + specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base)) + ); + } else { + // write p_eval + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base)) + ); + // write q_eval + mem_fill_base( + mem_helper, + start_timestamp + 1, + specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base)) + ); + } } } } diff --git a/extensions/native/circuit/src/extension/cuda.rs b/extensions/native/circuit/src/extension/cuda.rs index 765ce8d6cc..50a9cba86d 100644 --- a/extensions/native/circuit/src/extension/cuda.rs +++ b/extensions/native/circuit/src/extension/cuda.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use openvm_circuit::{ arch::{ChipInventory, ChipInventoryError, DenseRecordArena, VmProverExtension}, system::cuda::extensions::get_inventory_range_checker, @@ -14,6 +16,7 @@ use crate::{ field_arithmetic::{FieldArithmeticAir, FieldArithmeticChipGpu}, field_extension::{FieldExtensionAir, FieldExtensionChipGpu}, fri::{FriReducedOpeningAir, FriReducedOpeningChipGpu}, + hint_space_provider::{cuda::HintSpaceProviderChipGpu, HintSpaceProviderAir, HintSpaceProviderChip}, jal_rangecheck::{JalRangeCheckAir, JalRangeCheckGpu}, loadstore::{NativeLoadStoreAir, NativeLoadStoreChipGpu}, poseidon2::{air::NativePoseidon2Air, NativePoseidon2ChipGpu}, @@ -76,8 +79,18 @@ impl VmProverExtension let poseidon2 = NativePoseidon2ChipGpu::<1>::new(range_checker.clone(), timestamp_max_bits); inventory.add_executor_chip(poseidon2); + let hint_air: &HintSpaceProviderAir = inventory.next_air::()?; + let cpu_chip = Arc::new(HintSpaceProviderChip::new( + hint_air.hint_bus, + range_checker.clone(), + timestamp_max_bits, + )); + let provider_gpu = HintSpaceProviderChipGpu::new(cpu_chip.clone()); + inventory.add_periphery_chip(provider_gpu); + inventory.next_air::()?; - let sumcheck = NativeSumcheckChipGpu::new(range_checker.clone(), timestamp_max_bits); + let sumcheck = + NativeSumcheckChipGpu::new(range_checker.clone(), timestamp_max_bits, cpu_chip); inventory.add_executor_chip(sumcheck); Ok(()) diff --git a/extensions/native/circuit/src/extension/mod.rs b/extensions/native/circuit/src/extension/mod.rs index 9f3e2035ad..924d4927e8 100644 --- a/extensions/native/circuit/src/extension/mod.rs +++ b/extensions/native/circuit/src/extension/mod.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use alu_native_adapter::{AluNativeAdapterAir, AluNativeAdapterExecutor}; use branch_native_adapter::{BranchNativeAdapterAir, BranchNativeAdapterExecutor}; use convert_adapter::{ConvertAdapterAir, ConvertAdapterExecutor}; @@ -12,6 +14,7 @@ use openvm_circuit::{ }, system::{memory::SharedMemoryHelper, SystemPort}, }; +use openvm_circuit_primitives::is_less_than::IsLtSubAir; use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor}; use openvm_instructions::{program::DEFAULT_PC_STEP, LocalOpcode, PhantomDiscriminant}; use openvm_native_compiler::{ @@ -49,6 +52,7 @@ use crate::{ FriReducedOpeningAir, FriReducedOpeningChip, FriReducedOpeningExecutor, FriReducedOpeningFiller, }, + hint_space_provider::{HintSpaceProviderAir, HintSpaceProviderChip}, jal_rangecheck::{ JalRangeCheckAir, JalRangeCheckExecutor, JalRangeCheckFiller, NativeJalRangeCheckChip, }, @@ -219,6 +223,7 @@ where execution_bus, program_bus, memory_bridge, + hint_bridge, } = inventory.system().port(); let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); let range_checker = inventory.range_checker().bus; @@ -269,12 +274,22 @@ where let verify_batch = NativePoseidon2Air::<_, 1>::new( exec_bridge, memory_bridge, + hint_bridge, VerifyBatchBus::new(inventory.new_bus_idx()), Poseidon2Config::default(), ); inventory.add_air(verify_batch); - let tower_evaluate = NativeSumcheckAir::new(exec_bridge, memory_bridge); + let hint_space_provider = HintSpaceProviderAir { + hint_bus: hint_bridge.hint_bus(), + lt_air: IsLtSubAir::new( + range_checker, + inventory.config().memory_config.timestamp_max_bits, + ), + }; + inventory.add_air(hint_space_provider); + + let tower_evaluate = NativeSumcheckAir::new(exec_bridge, memory_bridge, hint_bridge); inventory.add_air(tower_evaluate); Ok(()) @@ -357,7 +372,20 @@ where ); inventory.add_executor_chip(poseidon2); - let tower_verify = NativeSumcheckChip::new(NativeSumcheckFiller::new(), mem_helper.clone()); + let hint_bus = inventory.airs().system().hint_bridge.hint_bus(); + let hint_space_provider = Arc::new(HintSpaceProviderChip::new( + hint_bus, + range_checker.clone(), + timestamp_max_bits, + )); + + inventory.next_air::()?; + inventory.add_periphery_chip(hint_space_provider.clone()); + + let tower_verify = NativeSumcheckChip::new( + NativeSumcheckFiller::new(hint_space_provider), + mem_helper.clone(), + ); inventory.add_executor_chip(tower_verify); Ok(()) @@ -542,6 +570,7 @@ impl VmCircuitExtension for CastFExtension { execution_bus, program_bus, memory_bridge, + hint_bridge: _, } = inventory.system().port(); let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); let range_checker = inventory.range_checker().bus; diff --git a/extensions/native/circuit/src/hint_space_provider.rs b/extensions/native/circuit/src/hint_space_provider.rs new file mode 100644 index 0000000000..2c5fdf0652 --- /dev/null +++ b/extensions/native/circuit/src/hint_space_provider.rs @@ -0,0 +1,315 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + collections::HashMap, + mem::size_of, + sync::{Arc, Mutex}, +}; + +use openvm_circuit::system::memory::offline_checker::HintBus; +use openvm_circuit_primitives::{ + is_less_than::{IsLessThanIo, IsLtSubAir}, + var_range::SharedVariableRangeCheckerChip, + SubAir, TraceSubRowGenerator, +}; +use openvm_circuit_primitives_derive::AlignedBorrow; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + interaction::InteractionBuilder, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra, PrimeField32}, + p3_matrix::{dense::RowMajorMatrix, Matrix}, + prover::{cpu::CpuBackend, types::AirProvingContext}, + rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, + Chip, ChipUsageGetter, +}; +pub const HINT_ID_LT_AUX_LEN: usize = 2; + +#[derive(Default, AlignedBorrow, Copy, Clone)] +#[repr(C)] +pub struct HintSpaceProviderCols { + pub hint_id: T, + pub offset: T, + pub value: T, + pub multiplicity: T, + /// Inverse of multiplicity when nonzero; 0 for padding rows. + pub mult_inv: T, + /// Boolean: 1 if hint_id changes between this row and the next non-padding row. + pub hint_id_changed: T, + /// Auxiliary limbs for IsLtSubAir range check decomposition of (curr.hint_id < next.hint_id). + pub hint_id_lt_aux: [T; HINT_ID_LT_AUX_LEN], + /// Boolean: 1 if this row is not a padding row (multiplicity > 0). + pub curr_is_non_padding: T, + /// Boolean: 1 if the next row is not a padding row. + pub next_is_non_padding: T, + /// Boolean: curr_is_non_padding * next_is_non_padding. + pub both_non_padding: T, +} + +pub const NUM_HINT_SPACE_PROVIDER_COLS: usize = size_of::>(); + +#[derive(Clone, Debug)] +pub struct HintSpaceProviderAir { + pub hint_bus: HintBus, + pub lt_air: IsLtSubAir, +} + +impl BaseAirWithPublicValues for HintSpaceProviderAir {} +impl PartitionedBaseAir for HintSpaceProviderAir {} + +impl BaseAir for HintSpaceProviderAir { + fn width(&self) -> usize { + NUM_HINT_SPACE_PROVIDER_COLS + } +} + +impl Air for HintSpaceProviderAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let curr = main.row_slice(0); + let curr: &HintSpaceProviderCols = (*curr).borrow(); + let next = main.row_slice(1); + let next: &HintSpaceProviderCols = (*next).borrow(); + + // curr_is_non_padding is boolean and tied to multiplicity via mult_inv. + builder.assert_bool(curr.curr_is_non_padding); + builder.assert_eq( + curr.curr_is_non_padding, + curr.multiplicity * curr.mult_inv, + ); + // Padding rows must have multiplicity = 0. + builder.assert_zero( + (AB::Expr::ONE - curr.curr_is_non_padding) * curr.multiplicity, + ); + + builder.assert_bool(curr.hint_id_changed); + + // Tie next_is_non_padding and both_non_padding columns to their definitions. + builder + .when_transition() + .assert_eq(curr.next_is_non_padding, next.curr_is_non_padding); + builder.when_transition().assert_eq( + curr.both_non_padding, + curr.curr_is_non_padding * curr.next_is_non_padding, + ); + + // Non-padding rows must appear before padding rows (non-increasing). + builder + .when_transition() + .when(curr.next_is_non_padding) + .assert_one(curr.curr_is_non_padding); + + // Uniqueness of (hint_id, offset) among non-padding rows. + // Rows are sorted by (hint_id, offset). For consecutive non-padding rows: + // - Same hint_id (hint_id_changed=0): offset must increase by exactly 1. + // - Different hint_id (hint_id_changed=1): hint_id strictly increases + // (proven via IsLtSubAir range-check), and the new block starts at offset 0. + let d_id: AB::Expr = next.hint_id - curr.hint_id; + + // hint_id_changed = 0 => same hint_id, offset increases by 1 + builder + .when_transition() + .when(curr.both_non_padding) + .when_ne(curr.hint_id_changed, AB::Expr::ONE) + .assert_zero(d_id); + builder + .when_transition() + .when(curr.both_non_padding) + .when_ne(curr.hint_id_changed, AB::Expr::ONE) + .assert_eq(next.offset, curr.offset + AB::Expr::ONE); + + // hint_id_changed = 1 => hint_id strictly increases (curr.hint_id < next.hint_id) + let lt_count: AB::Expr = curr.hint_id_changed.into() * curr.both_non_padding.into(); + self.lt_air.eval( + builder, + ( + IsLessThanIo { + x: curr.hint_id.into(), + y: next.hint_id.into(), + out: curr.hint_id_changed.into(), + count: lt_count, + }, + &curr.hint_id_lt_aux, + ), + ); + + // hint_id_changed = 1 => new block starts at offset 0 + builder + .when_transition() + .when(curr.both_non_padding) + .when(curr.hint_id_changed) + .assert_zero(next.offset); + + self.hint_bus.provide( + builder, + curr.hint_id, + curr.offset, + curr.value, + curr.multiplicity, + ); + } +} + +pub struct HintSpaceProviderChip { + pub air: HintSpaceProviderAir, + range_checker: SharedVariableRangeCheckerChip, + /// Maps (hint_id, offset) -> (value, multiplicity). + /// Deduplicates keys and tracks how many times each is looked up. + data: Mutex>, +} + +pub type SharedHintSpaceProviderChip = Arc>; + +impl HintSpaceProviderChip { + pub fn new( + hint_bus: HintBus, + range_checker: SharedVariableRangeCheckerChip, + hint_id_max_bits: usize, + ) -> Self { + let lt_air = IsLtSubAir::new(range_checker.bus(), hint_id_max_bits); + assert_eq!( + lt_air.decomp_limbs, HINT_ID_LT_AUX_LEN, + "hint_id_max_bits={hint_id_max_bits} with range_max_bits={} requires {} limbs, but HINT_ID_LT_AUX_LEN={HINT_ID_LT_AUX_LEN}", + range_checker.range_max_bits(), + lt_air.decomp_limbs + ); + Self { + air: HintSpaceProviderAir { hint_bus, lt_air }, + range_checker, + data: Mutex::new(HashMap::new()), + } + } +} + +impl HintSpaceProviderChip { + /// Register a (hint_id, offset, value) triple for the provider trace. + /// Called by consumer chips during trace filling to match each lookup. + /// Deduplicates by (hint_id, offset) and increments the multiplicity counter. + pub fn request(&self, hint_id: F, offset: F, value: F) { + self.data + .lock() + .unwrap() + .entry((hint_id, offset)) + .and_modify(|(v, m)| { + debug_assert_eq!(*v, value, "conflicting values for same (hint_id, offset)"); + *m += F::ONE; + }) + .or_insert((value, F::ONE)); + } +} + +impl HintSpaceProviderChip { + pub fn generate_trace(&self) -> RowMajorMatrix { + let data = std::mem::take(&mut *self.data.lock().unwrap()); + // Collect into a Vec and sort by (hint_id, offset) to satisfy the AIR ordering constraints. + let mut entries: Vec<_> = data.into_iter().collect(); + entries.sort_by_key(|((h, o), _)| (h.as_canonical_u64(), o.as_canonical_u64())); + + let num_non_padding_rows = entries.len(); + let trace_height = num_non_padding_rows.next_power_of_two().max(2); + + let mut rows = F::zero_vec(trace_height * NUM_HINT_SPACE_PROVIDER_COLS); + for (n, ((hint_id, offset), (value, multiplicity))) in entries.iter().enumerate() { + let row = + &mut rows[n * NUM_HINT_SPACE_PROVIDER_COLS..(n + 1) * NUM_HINT_SPACE_PROVIDER_COLS]; + let cols: &mut HintSpaceProviderCols = row.borrow_mut(); + cols.hint_id = *hint_id; + cols.offset = *offset; + cols.value = *value; + cols.multiplicity = *multiplicity; + cols.mult_inv = multiplicity.try_inverse().unwrap(); + cols.curr_is_non_padding = F::ONE; + cols.next_is_non_padding = + if n + 1 < num_non_padding_rows { F::ONE } else { F::ZERO }; + cols.both_non_padding = + if n + 1 < num_non_padding_rows { F::ONE } else { F::ZERO }; + + // Fill auxiliary columns for the uniqueness constraint. + if n + 1 < num_non_padding_rows { + let next_hint_id = entries[n + 1].0 .0; + if next_hint_id != *hint_id { + // hint_id changes: fill IsLtSubAir aux columns + self.air.lt_air.generate_subrow( + ( + self.range_checker.as_ref(), + hint_id.as_canonical_u32(), + next_hint_id.as_canonical_u32(), + ), + (&mut cols.hint_id_lt_aux, &mut cols.hint_id_changed), + ); + } else { + debug_assert_eq!( + entries[n + 1].0 .1, + *offset + F::ONE, + "Offsets for hint_id {:?} are not consecutive: {:?} -> {:?}", + hint_id, + offset, + entries[n + 1].0 .1 + ); + // hint_id_changed = 0, hint_id_lt_aux = [0; ..] (defaults) + } + } + // Last non-padding row: aux columns stay zero (no next non-padding row to compare) + } + // padding rows are already zero (multiplicity = 0) + RowMajorMatrix::new(rows, NUM_HINT_SPACE_PROVIDER_COLS) + } +} + +impl Chip> for HintSpaceProviderChip> +where + Val: PrimeField32, +{ + fn generate_proving_ctx(&self, _: R) -> AirProvingContext> { + let trace = self.generate_trace(); + AirProvingContext::simple_no_pis(Arc::new(trace)) + } +} + +impl ChipUsageGetter for HintSpaceProviderChip { + fn air_name(&self) -> String { + get_air_name(&self.air) + } + fn constant_trace_height(&self) -> Option { + None + } + fn current_trace_height(&self) -> usize { + self.data.lock().unwrap().len().next_power_of_two().max(2) + } + fn trace_width(&self) -> usize { + NUM_HINT_SPACE_PROVIDER_COLS + } +} + +#[cfg(feature = "cuda")] +pub mod cuda { + use std::sync::Arc; + + use openvm_circuit::arch::DenseRecordArena; + use openvm_cuda_backend::{ + chip::cpu_proving_ctx_to_gpu, prover_backend::GpuBackend, types::F, types::SC, + }; + use openvm_stark_backend::{ + prover::{cpu::CpuBackend, types::AirProvingContext}, + Chip, + }; + + use super::HintSpaceProviderChip; + + pub struct HintSpaceProviderChipGpu { + pub cpu_chip: Arc>, + } + + impl HintSpaceProviderChipGpu { + pub fn new(cpu_chip: Arc>) -> Self { + Self { cpu_chip } + } + } + + impl Chip for HintSpaceProviderChipGpu { + fn generate_proving_ctx(&self, _: DenseRecordArena) -> AirProvingContext { + let cpu_ctx: AirProvingContext> = + AirProvingContext::simple_no_pis(Arc::new(self.cpu_chip.generate_trace())); + cpu_proving_ctx_to_gpu(cpu_ctx) + } + } +} diff --git a/extensions/native/circuit/src/lib.rs b/extensions/native/circuit/src/lib.rs index ce257c9c22..dc1c4dd050 100644 --- a/extensions/native/circuit/src/lib.rs +++ b/extensions/native/circuit/src/lib.rs @@ -39,6 +39,7 @@ mod castf; mod field_arithmetic; mod field_extension; mod fri; +pub(crate) mod hint_space_provider; mod jal_rangecheck; mod loadstore; mod poseidon2; diff --git a/extensions/native/circuit/src/poseidon2/air.rs b/extensions/native/circuit/src/poseidon2/air.rs index baf18b06a3..9e9cdf5ce8 100644 --- a/extensions/native/circuit/src/poseidon2/air.rs +++ b/extensions/native/circuit/src/poseidon2/air.rs @@ -3,7 +3,7 @@ use std::{array::from_fn, borrow::Borrow, sync::Arc}; use itertools::Itertools; use openvm_circuit::{ arch::{ExecutionBridge, ExecutionState}, - system::memory::{offline_checker::MemoryBridge, MemoryAddress, CHUNK}, + system::memory::{offline_checker::{HintBridge, MemoryBridge}, MemoryAddress, CHUNK}, }; use openvm_circuit_primitives::utils::not; use openvm_instructions::LocalOpcode; @@ -36,6 +36,7 @@ use crate::poseidon2::{ pub struct NativePoseidon2Air { pub execution_bridge: ExecutionBridge, pub memory_bridge: MemoryBridge, + pub hint_bridge: HintBridge, pub internal_bus: VerifyBatchBus, pub(crate) subair: Arc>, pub(crate) address_space: F, @@ -45,12 +46,14 @@ impl NativePoseidon2Air, ) -> Self { NativePoseidon2Air { execution_bridge, memory_bridge, + hint_bridge, internal_bus: verify_batch_bus, subair: Arc::new(Poseidon2SubAir::new(poseidon2_config.constants.into())), address_space: F::from_canonical_u32(AS::Native as u32), diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index c9bbf1279e..4207695486 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -3,7 +3,7 @@ use std::borrow::Borrow; use openvm_circuit::{ arch::{ExecutionBridge, ExecutionState}, system::memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols}, + offline_checker::{HintBridge, MemoryBridge}, MemoryAddress, }, }; @@ -26,20 +26,23 @@ use crate::{ }, }; -pub const NUM_RWS_FOR_PRODUCT: usize = 2; -pub const NUM_RWS_FOR_LOGUP: usize = 3; - #[derive(Clone, Debug)] pub struct NativeSumcheckAir { pub execution_bridge: ExecutionBridge, pub memory_bridge: MemoryBridge, + pub hint_bridge: HintBridge, } impl NativeSumcheckAir { - pub fn new(execution_bridge: ExecutionBridge, memory_bridge: MemoryBridge) -> Self { + pub fn new( + execution_bridge: ExecutionBridge, + memory_bridge: MemoryBridge, + hint_bridge: HintBridge, + ) -> Self { Self { execution_bridge, memory_bridge, + hint_bridge, } } } @@ -105,7 +108,9 @@ impl Air for NativeSumcheckAir { within_round_limit, should_acc, eval_acc, - is_hint_src_id, + is_writeback, + prod_hint_id, + logup_hint_id, specific, } = local; @@ -115,6 +120,7 @@ impl Air for NativeSumcheckAir { builder.assert_bool(prod_row); builder.assert_bool(logup_row); builder.assert_bool(within_round_limit); + builder.assert_bool(is_writeback); builder.assert_bool(prod_in_round_evaluation); builder.assert_bool(logup_in_round_evaluation); @@ -183,6 +189,15 @@ impl Air for NativeSumcheckAir { builder .when(next.prod_row + next.logup_row) .assert_eq(logup_nested_len, next.logup_nested_len); + builder + .when(next.prod_row + next.logup_row) + .assert_eq(is_writeback, next.is_writeback); + builder + .when(next.prod_row + next.logup_row) + .assert_eq(prod_hint_id, next.prod_hint_id); + builder + .when(next.prod_row + next.logup_row) + .assert_eq(logup_hint_id, next.logup_hint_id); //////////////////////////////////////////////////////////////// // Row transitions from current to next row @@ -235,21 +250,25 @@ impl Air for NativeSumcheckAir { next.start_timestamp, start_timestamp + AB::F::from_canonical_usize(8), ); + + // Prod row timestamp transition builder - .when(prod_row) - .when(next.prod_row + next.logup_row) + .when_transition() + .when_ne(is_end, AB::Expr::ONE) .assert_eq( - next.start_timestamp, - start_timestamp - + within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_PRODUCT), + prod_row * (next.start_timestamp - start_timestamp), + (prod_in_round_evaluation + prod_next_round_evaluation) + * (AB::Expr::ONE + is_writeback), ); + + // Logup row timestamp transition builder - .when(logup_row) - .when(next.prod_row + next.logup_row) + .when_transition() + .when_ne(is_end, AB::Expr::ONE) .assert_eq( - next.start_timestamp, - start_timestamp - + within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_LOGUP), + logup_row * (next.start_timestamp - start_timestamp), + (logup_in_round_evaluation + logup_next_round_evaluation) + * (AB::Expr::TWO + is_writeback), ); // Termination condition @@ -349,7 +368,7 @@ impl Air for NativeSumcheckAir { native_as, register_ptrs[0] + AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN), ), - [max_round, is_hint_src_id], + [max_round, is_writeback], first_timestamp + AB::F::from_canonical_usize(7), &header_row_specific.read_records[7], ) @@ -392,21 +411,6 @@ impl Air for NativeSumcheckAir { ); builder.assert_eq(prod_row * should_acc, prod_acc); - // Read p1, p2 from witness arrays - self.memory_bridge - .read( - MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr), - prod_row_specific.p, - start_timestamp, - &MemoryReadAuxCols { - base: prod_row_specific.ps_record.base, - }, - ) - .eval( - builder, - (prod_in_round_evaluation + prod_next_round_evaluation) * not(is_hint_src_id), - ); - // Obtain p1, p2 from hint space and write back to witness arrays self.memory_bridge .write( @@ -417,9 +421,21 @@ impl Air for NativeSumcheckAir { ) .eval( builder, - (prod_in_round_evaluation + prod_next_round_evaluation) * is_hint_src_id, + (prod_in_round_evaluation + prod_next_round_evaluation) * is_writeback, ); + // Lookup each element of p in the hint bus to constrain hint_space reads + let prod_enabled: AB::Expr = prod_in_round_evaluation + prod_next_round_evaluation; + for (j, &val) in prod_row_specific.p.iter().enumerate() { + self.hint_bridge.lookup( + builder, + prod_hint_id, + prod_row_specific.data_ptr + AB::F::from_canonical_usize(j), + val, + prod_enabled.clone(), + ); + } + let p1: [AB::Var; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().unwrap(); let p2: [AB::Var; EXT_DEG] = prod_row_specific.p[EXT_DEG..(EXT_DEG * 2)] .try_into() @@ -432,7 +448,7 @@ impl Air for NativeSumcheckAir { register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), ), prod_row_specific.p_evals, - start_timestamp + AB::F::ONE, + start_timestamp + is_writeback * AB::F::ONE, &prod_row_specific.write_record, ) .eval(builder, prod_row * within_round_limit); @@ -499,21 +515,6 @@ impl Air for NativeSumcheckAir { ); builder.assert_eq(logup_row * should_acc, logup_acc); - // Read p1, p2, q1, q2 from witness arrays - self.memory_bridge - .read( - MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr), - logup_row_specific.pq, - start_timestamp, - &MemoryReadAuxCols { - base: logup_row_specific.pqs_record.base, - }, - ) - .eval( - builder, - (logup_in_round_evaluation + logup_next_round_evaluation) * not(is_hint_src_id), - ); - // Obtain p1, p2, q1, q2 from hint space self.memory_bridge .write( @@ -524,8 +525,20 @@ impl Air for NativeSumcheckAir { ) .eval( builder, - (logup_in_round_evaluation + logup_next_round_evaluation) * is_hint_src_id, + (logup_in_round_evaluation + logup_next_round_evaluation) * is_writeback, ); + + // Lookup each element of pq in the hint bus to constrain hint_space reads + let logup_enabled: AB::Expr = logup_in_round_evaluation + logup_next_round_evaluation; + for (j, &val) in logup_row_specific.pq.iter().enumerate() { + self.hint_bridge.lookup( + builder, + logup_hint_id, + logup_row_specific.data_ptr + AB::F::from_canonical_usize(j), + val, + logup_enabled.clone(), + ); + } let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().unwrap(); let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)] .try_into() @@ -546,7 +559,7 @@ impl Air for NativeSumcheckAir { + (num_prod_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.p_evals, - start_timestamp + AB::F::ONE, + start_timestamp + is_writeback * AB::F::ONE, &logup_row_specific.write_records[0], ) .eval(builder, logup_row * within_round_limit); @@ -561,7 +574,7 @@ impl Air for NativeSumcheckAir { * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.q_evals, - start_timestamp + AB::F::TWO, + start_timestamp + is_writeback * AB::F::ONE + AB::F::ONE, &logup_row_specific.write_records[1], ) .eval(builder, logup_row * within_round_limit); diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index d4fbf2524d..df6a41d1ca 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -17,6 +17,7 @@ use openvm_stark_backend::p3_field::PrimeField32; use crate::{ field_extension::{FieldExtension, EXT_DEG}, fri::elem_to_ext, + hint_space_provider::SharedHintSpaceProviderChip, mem_fill_helper, sumcheck::columns::{ HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols, @@ -96,9 +97,11 @@ impl SizedRecord for NativeSumcheck pub struct NativeSumcheckExecutor; #[derive(derive_new::new)] -pub struct NativeSumcheckFiller; +pub struct NativeSumcheckFiller { + pub hint_space_provider: SharedHintSpaceProviderChip, +} -pub type NativeSumcheckChip = VmChipWrapper; +pub type NativeSumcheckChip = VmChipWrapper>; impl Default for NativeSumcheckExecutor { fn default() -> Self { @@ -207,7 +210,7 @@ where challenges_ptr.as_canonical_u32(), head_specific.read_records[6].as_mut(), ); - let [max_round, is_hint_src_id]: [F; 2] = tracing_read_native_helper( + let [max_round, is_writeback]: [F; 2] = tracing_read_native_helper( state.memory, ctx_ptr.as_canonical_u32() + CONTEXT_ARR_BASE_LEN as u32, head_specific.read_records[7].as_mut(), @@ -242,21 +245,15 @@ where row.register_ptrs[3] = logup_evals_ptr; row.register_ptrs[4] = r_evals_ptr; row.max_round = max_round; - row.is_hint_src_id = is_hint_src_id; + row.is_writeback = is_writeback; + row.prod_hint_id = prod_evals_id; + row.logup_hint_id = logup_evals_id; } - // Load hints if source is a ptr - let is_hint_src_id = is_hint_src_id > F::ZERO; let prod_evals_id = prod_evals_id.as_canonical_u32(); let logup_evals_id = logup_evals_id.as_canonical_u32(); - let (prod_evals, logup_evals) = if is_hint_src_id { - ( - state.streams.hint_space[prod_evals_id as usize].clone(), - state.streams.hint_space[logup_evals_id as usize].clone(), - ) - } else { - (Vec::new(), Vec::new()) - }; + let prod_evals = state.streams.hint_space[prod_evals_id as usize].clone(); + let logup_evals = state.streams.hint_space[logup_evals_id as usize].clone(); // product rows for (i, prod_row) in rows @@ -292,24 +289,16 @@ where prod_specific.data_ptr = F::from_canonical_u32(start); // read p1, p2 - let ps: [F; EXT_DEG * 2] = if is_hint_src_id { - prod_evals[(start as usize)..((start as usize) + EXT_DEG * 2)] - .try_into() - .unwrap() - } else { - tracing_read_native_helper( - state.memory, - prod_evals_ptr.as_canonical_u32() + start, - prod_specific.ps_record.as_mut(), - ) - }; + let ps: [F; EXT_DEG * 2] = prod_evals[(start as usize)..((start as usize) + EXT_DEG * 2)] + .try_into() + .unwrap(); let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = ps[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); prod_specific.p = ps; // If p values come from the hint stream, write back to the actual witness array - if is_hint_src_id { + if is_writeback != F::ZERO { tracing_write_native_inplace( state.memory, prod_evals_ptr.as_canonical_u32() + start, @@ -346,7 +335,7 @@ where eval, &mut prod_specific.write_record, ); - cur_timestamp += 2; // Either 1 read, 1 write (witness array input), or 2 writes (hint_ptr_id) + cur_timestamp += if is_writeback != F::ZERO { 2 } else { 1 }; // Only write back to the witness array when the is_writeback indicator is true let eval_rlc = FieldExtension::multiply(alpha_acc, eval); prod_specific.eval_rlc = eval_rlc; @@ -394,17 +383,9 @@ where logup_specific.data_ptr = F::from_canonical_u32(start); // read p1, p2, q1, q2 - let pqs: [F; EXT_DEG * 4] = if is_hint_src_id { - logup_evals[(start as usize)..(start as usize) + EXT_DEG * 4] - .try_into() - .unwrap() - } else { - tracing_read_native_helper( - state.memory, - logup_evals_ptr.as_canonical_u32() + start, - logup_specific.pqs_record.as_mut(), - ) - }; + let pqs: [F; EXT_DEG * 4] = logup_evals[(start as usize)..(start as usize) + EXT_DEG * 4] + .try_into() + .unwrap(); let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = pqs[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); let q1: [F; EXT_DEG] = pqs[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().unwrap(); @@ -413,7 +394,7 @@ where logup_specific.pq = pqs; // write pqs - if is_hint_src_id { + if is_writeback != F::ZERO { tracing_write_native_inplace( state.memory, logup_evals_ptr.as_canonical_u32() + start, @@ -472,7 +453,7 @@ where q_eval, &mut logup_specific.write_records[1], ); - cur_timestamp += 3; // 1 read, 2 writes (witness array case) or 3 writes (hint space ptr case) + cur_timestamp += if is_writeback != F::ZERO { 3 } else { 2 }; // Only write back to the witness array when the is_writeback indicator is true let eval_rlc = FieldExtension::add( FieldExtension::multiply(alpha_numerator, p_eval), @@ -541,7 +522,7 @@ where } } -impl TraceFiller for NativeSumcheckFiller { +impl TraceFiller for NativeSumcheckFiller { fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { let cols: &mut NativeSumcheckCols = row_slice.borrow_mut(); let start_timestamp = cols.start_timestamp.as_canonical_u32(); @@ -568,42 +549,84 @@ impl TraceFiller for NativeSumcheckFiller { cols.specific[..ProdSpecificCols::::width()].borrow_mut(); if cols.within_round_limit == F::ONE { - // obtain p1, p2 - mem_fill_helper( - mem_helper, - start_timestamp, - prod_row_specific.ps_record.as_mut(), - ); - // write p_eval - mem_fill_helper( - mem_helper, - start_timestamp + 1, - prod_row_specific.write_record.as_mut(), - ); + // Register each p element with the hint space provider for the lookup bus + for (j, &val) in prod_row_specific.p.iter().enumerate() { + self.hint_space_provider.request( + cols.prod_hint_id, + prod_row_specific.data_ptr + F::from_canonical_usize(j), + val, + ); + } + + if cols.is_writeback == F::ONE { + // writeback p1, p2 + mem_fill_helper( + mem_helper, + start_timestamp, + prod_row_specific.ps_record.as_mut(), + ); + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp + 1, + prod_row_specific.write_record.as_mut(), + ); + } else { + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp, + prod_row_specific.write_record.as_mut(), + ); + } } } else if cols.logup_row == F::ONE { let logup_row_specific: &mut LogupSpecificCols = cols.specific[..LogupSpecificCols::::width()].borrow_mut(); if cols.within_round_limit == F::ONE { - // obtain p1, p2, q1, q2 - mem_fill_helper( - mem_helper, - start_timestamp, - logup_row_specific.pqs_record.as_mut(), - ); - // write p_eval - mem_fill_helper( - mem_helper, - start_timestamp + 1, - logup_row_specific.write_records[0].as_mut(), - ); - // write q_eval - mem_fill_helper( - mem_helper, - start_timestamp + 2, - logup_row_specific.write_records[1].as_mut(), - ); + // Register each pq element with the hint space provider for the lookup bus + for (j, &val) in logup_row_specific.pq.iter().enumerate() { + self.hint_space_provider.request( + cols.logup_hint_id, + logup_row_specific.data_ptr + F::from_canonical_usize(j), + val, + ); + } + + if cols.is_writeback == F::ONE { + // writeback p1, p2, q1, q2 + mem_fill_helper( + mem_helper, + start_timestamp, + logup_row_specific.pqs_record.as_mut(), + ); + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp + 1, + logup_row_specific.write_records[0].as_mut(), + ); + // write q_eval + mem_fill_helper( + mem_helper, + start_timestamp + 2, + logup_row_specific.write_records[1].as_mut(), + ); + } else { + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp, + logup_row_specific.write_records[0].as_mut(), + ); + // write q_eval + mem_fill_helper( + mem_helper, + start_timestamp + 1, + logup_row_specific.write_records[1].as_mut(), + ); + } } } } diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs index f02f154cf2..ace9ad8a3d 100644 --- a/extensions/native/circuit/src/sumcheck/columns.rs +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -73,8 +73,12 @@ pub struct NativeSumcheckCols { // The current final evaluation accumulator. Extension element. pub eval_acc: [T; EXT_DEG], - // Indicator for an alternative source form of the inputs prod_evals/logup_evals - pub is_hint_src_id: T, + // Indicate whether the values read from hint slices should be written back to a witness array + pub is_writeback: T, + + // Hint space IDs for lookup bus interactions + pub prod_hint_id: T, + pub logup_hint_id: T, // /// 1. For header row, 5 registers, ctx, challenges // /// 2. For the rest: max_variables, p1, p2, q1, q2 diff --git a/extensions/native/circuit/src/sumcheck/cuda.rs b/extensions/native/circuit/src/sumcheck/cuda.rs index 60aba15b95..2dcecd5756 100644 --- a/extensions/native/circuit/src/sumcheck/cuda.rs +++ b/extensions/native/circuit/src/sumcheck/cuda.rs @@ -1,4 +1,4 @@ -use std::{mem::size_of, slice::from_raw_parts, sync::Arc}; +use std::{borrow::Borrow, mem::size_of, slice::from_raw_parts, sync::Arc}; use derive_new::new; use openvm_circuit::{arch::DenseRecordArena, utils::next_power_of_two_or_zero}; @@ -7,15 +7,71 @@ use openvm_cuda_backend::{ base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F, }; use openvm_cuda_common::copy::MemCopyH2D; -use openvm_stark_backend::{prover::types::AirProvingContext, Chip}; +use openvm_stark_backend::{p3_field::PrimeField32, prover::types::AirProvingContext, Chip}; -use super::columns::NativeSumcheckCols; -use crate::cuda_abi::sumcheck_cuda; +use super::columns::{LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}; +use crate::{ + cuda_abi::sumcheck_cuda, + hint_space_provider::SharedHintSpaceProviderChip, +}; +use p3_field::FieldAlgebra; #[derive(new)] pub struct NativeSumcheckChipGpu { pub range_checker: Arc, pub timestamp_max_bits: usize, + pub hint_space_provider: SharedHintSpaceProviderChip, +} + +impl NativeSumcheckChipGpu { + /// Scans execution records to populate the hint space provider with + /// (hint_id, offset, value) triples for each hint element referenced + /// by prod and logup rows. This bridges the gap between CPU execution + /// (which produces the records) and GPU trace generation. + fn populate_hint_provider(&self, records: &[u8]) { + let width = NativeSumcheckCols::::width(); + let record_size = width * size_of::(); + if records.len() % record_size != 0 { + return; + } + let num_rows = records.len() / record_size; + + let row_slice = unsafe { + let ptr = records.as_ptr() as *const F; + from_raw_parts(ptr, num_rows * width) + }; + + for i in 0..num_rows { + let row_data = &row_slice[i * width..(i + 1) * width]; + let cols: &NativeSumcheckCols = row_data.borrow(); + + if cols.within_round_limit != F::ONE { + continue; + } + + if cols.prod_row == F::ONE { + let prod_specific: &ProdSpecificCols = + cols.specific[..ProdSpecificCols::::width()].borrow(); + for (j, &val) in prod_specific.p.iter().enumerate() { + self.hint_space_provider.request( + cols.prod_hint_id, + prod_specific.data_ptr + F::from_canonical_usize(j), + val, + ); + } + } else if cols.logup_row == F::ONE { + let logup_specific: &LogupSpecificCols = + cols.specific[..LogupSpecificCols::::width()].borrow(); + for (j, &val) in logup_specific.pq.iter().enumerate() { + self.hint_space_provider.request( + cols.logup_hint_id, + logup_specific.data_ptr + F::from_canonical_usize(j), + val, + ); + } + } + } + } } impl Chip for NativeSumcheckChipGpu { @@ -25,6 +81,9 @@ impl Chip for NativeSumcheckChipGpu { return get_empty_air_proving_ctx::(); } + // Populate hint space provider from execution records before GPU upload. + self.populate_hint_provider(records); + let width = NativeSumcheckCols::::width(); let record_size = width * size_of::(); assert_eq!(records.len() % record_size, 0); diff --git a/extensions/native/circuit/src/sumcheck/execution.rs b/extensions/native/circuit/src/sumcheck/execution.rs index a311e634b8..7cbd1fe319 100644 --- a/extensions/native/circuit/src/sumcheck/execution.rs +++ b/extensions/native/circuit/src/sumcheck/execution.rs @@ -217,7 +217,7 @@ unsafe fn execute_e12_impl( ctx; let challenges: [F; EXT_DEG * 4] = exec_state.vm_read(NATIVE_AS, challenges_ptr.as_canonical_u32()); - let [max_round, is_hint_space_ids]: [u32; 2] = exec_state + let [max_round, is_writeback]: [u32; 2] = exec_state .vm_read(NATIVE_AS, ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32) .map(|x: F| x.as_canonical_u32()); let alpha: [F; EXT_DEG] = challenges[0..EXT_DEG].try_into().unwrap(); @@ -228,14 +228,8 @@ unsafe fn execute_e12_impl( let mut alpha_acc = elem_to_ext(F::ONE); let mut eval_acc = elem_to_ext(F::ZERO); - let (prod_evals, logup_evals) = if is_hint_space_ids > 0 { - ( - exec_state.streams.hint_space[prod_evals_id as usize].clone(), - exec_state.streams.hint_space[logup_evals_id as usize].clone(), - ) - } else { - (Vec::new(), Vec::new()) - }; + let prod_evals = exec_state.streams.hint_space[prod_evals_id as usize].clone(); + let logup_evals = exec_state.streams.hint_space[logup_evals_id as usize].clone(); for i in 0..num_prod_spec { let start = calculate_3d_ext_idx( @@ -247,16 +241,11 @@ unsafe fn execute_e12_impl( ); if round < max_round - 1 { - let ps: [F; EXT_DEG * 2] = if is_hint_space_ids > 0 { - prod_evals[(start as usize)..(start as usize) + EXT_DEG * 2].try_into().unwrap() - } else { - exec_state.vm_read::<_, { EXT_DEG * 2 }>(NATIVE_AS, prod_evals_ptr + start).try_into().unwrap() - }; - + let ps: [F; EXT_DEG * 2] = prod_evals[(start as usize)..(start as usize) + EXT_DEG * 2].try_into().unwrap(); let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = ps[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); - if is_hint_space_ids > 0 { + if is_writeback > 0 { exec_state.vm_write(NATIVE_AS, prod_evals_ptr + start, &ps); } @@ -297,17 +286,13 @@ unsafe fn execute_e12_impl( if round < max_round - 1 { // read logup_evals - let pqs: [F; EXT_DEG * 4] = if is_hint_space_ids > 0 { - logup_evals[(start as usize)..(start as usize) + EXT_DEG * 4].try_into().unwrap() - } else { - exec_state.vm_read::<_, { EXT_DEG * 4 }>(NATIVE_AS, logup_evals_ptr + start).try_into().unwrap() - }; + let pqs: [F; EXT_DEG * 4] = logup_evals[(start as usize)..(start as usize) + EXT_DEG * 4].try_into().unwrap(); let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = pqs[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); let q1: [F; EXT_DEG] = pqs[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap(); let q2: [F; EXT_DEG] = pqs[EXT_DEG * 3..EXT_DEG * 4].try_into().unwrap(); - if is_hint_space_ids > 0 { + if is_writeback > 0 { exec_state.vm_write(NATIVE_AS, logup_evals_ptr + start, &pqs); } diff --git a/extensions/rv32im/circuit/src/extension/mod.rs b/extensions/rv32im/circuit/src/extension/mod.rs index 8055c16b54..fc9cb9ab63 100644 --- a/extensions/rv32im/circuit/src/extension/mod.rs +++ b/extensions/rv32im/circuit/src/extension/mod.rs @@ -202,6 +202,7 @@ impl VmCircuitExtension for Rv32I { execution_bus, program_bus, memory_bridge, + hint_bridge: _, } = inventory.system().port(); let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); @@ -466,6 +467,7 @@ impl VmCircuitExtension for Rv32M { execution_bus, program_bus, memory_bridge, + hint_bridge: _, } = inventory.system().port(); let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); @@ -635,6 +637,7 @@ impl VmCircuitExtension for Rv32Io { execution_bus, program_bus, memory_bridge, + hint_bridge: _, } = inventory.system().port(); let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); diff --git a/extensions/sha256/circuit/src/sha256_chip/air.rs b/extensions/sha256/circuit/src/sha256_chip/air.rs index 2fe1cb26c0..fb5ce2070f 100644 --- a/extensions/sha256/circuit/src/sha256_chip/air.rs +++ b/extensions/sha256/circuit/src/sha256_chip/air.rs @@ -53,6 +53,7 @@ impl Sha256VmAir { execution_bus, program_bus, memory_bridge, + hint_bridge: _, }: SystemPort, bitwise_lookup_bus: BitwiseOperationLookupBus, ptr_max_bits: usize,