diff --git a/Cargo.toml b/Cargo.toml index 97d8505..8acc2ee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ name = "pymath" authors = ["Jeong, YunWon "] repository = "https://github.com/RustPython/pymath" description = "A binary representation compatible Rust implementation of Python's math library." -version = "0.1.5" +version = "0.2.0" edition = "2024" license = "PSF-2.0" diff --git a/src/cmath/exponential.rs b/src/cmath/exponential.rs index d11bf7b..d4a6bf9 100644 --- a/src/cmath/exponential.rs +++ b/src/cmath/exponential.rs @@ -177,40 +177,47 @@ pub(crate) fn ln(z: Complex64) -> Result { /// If base is Some(b), returns log(z) / log(b). #[inline] pub fn log(z: Complex64, base: Option) -> Result { - // c_log always returns a value, but sets errno for special cases. - // The error check happens at the end of cmath_log_impl. - // For log(z) without base: z=0 raises EDOM - // For log(z, base): z=0 doesn't raise because c_log(base) clears errno - let z_is_zero = z.re == 0.0 && z.im == 0.0; + let (log_z, mut err) = c_log(z); match base { - None => { - // No base: raise error if z=0 - if z_is_zero { - return Err(Error::EDOM); - } - ln(z) - } + None => err.map_or(Ok(log_z), Err), Some(b) => { - // With base: z=0 is allowed (second ln clears the "errno") - let log_z = ln(z)?; - let log_b = ln(b)?; - // Use _Py_c_quot-style division to preserve sign of zero - Ok(c_quot(log_z, log_b)) + // Like cmath_log_impl, the second c_log call overwrites + // any pending error from the first one. + let (log_b, base_err) = c_log(b); + err = base_err; + let (q, quot_err) = c_quot(log_z, log_b); + if let Some(e) = quot_err { + err = Some(e); + } + err.map_or(Ok(q), Err) } } } +/// c_log behavior: always returns a value, but reports EDOM for zero. +#[inline] +fn c_log(z: Complex64) -> (Complex64, Option) { + let r = ln(z).expect("ln handles special values without failing"); + if z.re == 0.0 && z.im == 0.0 { + (r, Some(Error::EDOM)) + } else { + (r, None) + } +} + /// Complex division following _Py_c_quot algorithm. /// This preserves the sign of zero correctly and recovers infinities /// from NaN results per C11 Annex G.5.2. #[inline] -fn c_quot(a: Complex64, b: Complex64) -> Complex64 { +fn c_quot(a: Complex64, b: Complex64) -> (Complex64, Option) { let abs_breal = m::fabs(b.re); let abs_bimag = m::fabs(b.im); + let mut err = None; let mut r = if abs_breal >= abs_bimag { if abs_breal == 0.0 { - Complex64::new(f64::NAN, f64::NAN) + err = Some(Error::EDOM); + Complex64::new(0.0, 0.0) } else { let ratio = b.im / b.re; let denom = b.re + b.im * ratio; @@ -244,7 +251,7 @@ fn c_quot(a: Complex64, b: Complex64) -> Complex64 { } } - r + (r, err) } /// Complex base-10 logarithm. @@ -325,6 +332,53 @@ mod tests { }); } + fn test_log_error(z: Complex64, base: Complex64) { + use pyo3::prelude::*; + + let rs_result = log(z, Some(base)); + + Python::attach(|py| { + let cmath = pyo3::types::PyModule::import(py, "cmath").unwrap(); + let py_z = pyo3::types::PyComplex::from_doubles(py, z.re, z.im); + let py_base = pyo3::types::PyComplex::from_doubles(py, base.re, base.im); + let py_result = cmath.getattr("log").unwrap().call1((py_z, py_base)); + + match py_result { + Ok(result) => { + use pyo3::types::PyComplexMethods; + let c = result.cast::().unwrap(); + panic!( + "log({}+{}j, {}+{}j): expected ValueError, got ({}, {})", + z.re, + z.im, + base.re, + base.im, + c.real(), + c.imag() + ); + } + Err(err) => { + assert!( + err.is_instance_of::(py), + "log({}+{}j, {}+{}j): expected ValueError, got {err:?}", + z.re, + z.im, + base.re, + base.im, + ); + assert!( + matches!(rs_result, Err(crate::Error::EDOM)), + "log({}+{}j, {}+{}j): expected Err(EDOM), got {rs_result:?}", + z.re, + z.im, + base.re, + base.im, + ); + } + } + }); + } + use crate::test::EDGE_VALUES; #[test] @@ -382,6 +436,29 @@ mod tests { } } + #[test] + fn regression_c_quot_zero_denominator_sets_edom() { + let (q, err) = c_quot(Complex64::new(2.0, -3.0), Complex64::new(0.0, 0.0)); + assert_eq!(err, Some(crate::Error::EDOM)); + assert_eq!(q.re.to_bits(), 0.0f64.to_bits()); + assert_eq!(q.im.to_bits(), 0.0f64.to_bits()); + } + + #[test] + fn regression_log_zero_quotient_denominator_raises_edom() { + let cases = [ + (Complex64::new(2.0, 0.0), Complex64::new(1.0, 0.0)), + (Complex64::new(1.0, 0.0), Complex64::new(1.0, 0.0)), + (Complex64::new(2.0, 0.0), Complex64::new(0.0, 0.0)), + (Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)), + (Complex64::new(0.0, 0.0), Complex64::new(0.0, 0.0)), + ]; + + for (z, base) in cases { + test_log_error(z, base); + } + } + proptest::proptest! { #[test] fn proptest_sqrt(re: f64, im: f64) { diff --git a/src/math.rs b/src/math.rs index d9bc17f..02e0a31 100644 --- a/src/math.rs +++ b/src/math.rs @@ -49,10 +49,26 @@ macro_rules! libm_simple { pub(crate) use libm_simple; -/// math_1: wrapper for 1-arg functions +/// Wrapper for 1-arg libm functions, corresponding to FUNC1/is_error in +/// mathmodule.c. +/// /// - isnan(r) && !isnan(x) -> domain error /// - isinf(r) && isfinite(x) -> overflow (can_overflow=true) or domain error (can_overflow=false) /// - isfinite(r) && errno -> check errno (unnecessary on most platforms) +/// +/// CPython's approach: clear errno, call libm, then inspect both the result +/// and errno to classify errors. We rely primarily on output inspection +/// (NaN/Inf checks) because: +/// +/// - On macOS and Windows, libm functions do not reliably set errno for +/// edge cases, so CPython's own is_error() skips the errno check there +/// too (it only uses it as a fallback on other Unixes). +/// - The NaN/Inf output checks are sufficient to detect all domain and +/// range errors on every platform we test against (verified by proptest +/// and edgetest against CPython via pyo3). +/// - The errno-only branch (finite result with errno set) is kept for +/// non-macOS/non-Windows Unixes where libm might signal an error +/// without producing a NaN/Inf result. #[inline] pub(crate) fn math_1(x: f64, func: fn(f64) -> f64, can_overflow: bool) -> crate::Result { crate::err::set_errno(0); @@ -75,9 +91,17 @@ pub(crate) fn math_1(x: f64, func: fn(f64) -> f64, can_overflow: bool) -> crate: Ok(r) } -/// math_2: wrapper for 2-arg functions +/// Wrapper for 2-arg libm functions, corresponding to FUNC2 in +/// mathmodule.c. +/// /// - isnan(r) && !isnan(x) && !isnan(y) -> domain error /// - isinf(r) && isfinite(x) && isfinite(y) -> range error +/// +/// Unlike math_1, this does not set/check errno at all. CPython's FUNC2 +/// does clear and check errno, but the NaN/Inf output checks already +/// cover all error cases for the 2-arg functions we wrap (atan2, fmod, +/// copysign, remainder, pow). This is verified by bit-exact proptest +/// and edgetest against CPython. #[inline] pub(crate) fn math_2(x: f64, y: f64, func: fn(f64, f64) -> f64) -> crate::Result { let r = func(x, y); diff --git a/src/math/aggregate.rs b/src/math/aggregate.rs index ad9219b..458de53 100644 --- a/src/math/aggregate.rs +++ b/src/math/aggregate.rs @@ -231,6 +231,11 @@ pub fn vector_norm(vec: &[f64], max: f64, found_nan: bool) -> f64 { /// /// The points are given as sequences of coordinates. /// Uses high-precision vector_norm algorithm. +/// +/// Panics if `p` and `q` have different lengths. CPython raises ValueError +/// for mismatched dimensions, but in this Rust API the caller is expected +/// to guarantee equal-length slices. A length mismatch is a programming +/// error, not a runtime condition. pub fn dist(p: &[f64], q: &[f64]) -> f64 { assert_eq!( p.len(), @@ -261,24 +266,52 @@ pub fn dist(p: &[f64], q: &[f64]) -> f64 { /// Return the sum of products of values from two sequences (float version). /// -/// Uses TripleLength arithmetic for high precision. -/// Equivalent to sum(p[i] * q[i] for i in range(len(p))). -pub fn sumprod(p: &[f64], q: &[f64]) -> f64 { - assert_eq!(p.len(), q.len(), "Inputs are not the same length"); +/// Uses TripleLength arithmetic for the fast path, then falls back to +/// ordinary floating-point multiply/add starting at the first unsupported +/// pair, matching Python's staged `math.sumprod` behavior for float inputs. +/// +/// CPython's math_sumprod_impl is a 3-stage state machine that handles +/// int/float/generic Python objects. This function only covers the float +/// path (`&[f64]`). The int accumulation and generic PyNumber fallback +/// stages are Python type-system concerns and should be handled by the +/// caller (e.g. RustPython) before delegating here. +/// +/// Returns EDOM if the inputs are not the same length. +pub fn sumprod(p: &[f64], q: &[f64]) -> crate::Result { + if p.len() != q.len() { + return Err(crate::Error::EDOM); + } + let mut total = 0.0; let mut flt_total = TL_ZERO; + let mut flt_path_enabled = true; + let mut i = 0; - for (&pi, &qi) in p.iter().zip(q.iter()) { - let new_flt_total = tl_fma(pi, qi, flt_total); - if new_flt_total.hi.is_finite() { - flt_total = new_flt_total; - } else { - // Overflow or special value, fall back to simple sum - return p.iter().zip(q.iter()).map(|(a, b)| a * b).sum(); + while i < p.len() { + let pi = p[i]; + let qi = q[i]; + + if flt_path_enabled { + let new_flt_total = tl_fma(pi, qi, flt_total); + if new_flt_total.hi.is_finite() { + flt_total = new_flt_total; + i += 1; + continue; + } + + flt_path_enabled = false; + total += tl_to_d(flt_total); } + + total += pi * qi; + i += 1; } - tl_to_d(flt_total) + Ok(if flt_path_enabled { + tl_to_d(flt_total) + } else { + total + }) } /// Return the sum of products of values from two sequences (integer version). @@ -427,14 +460,27 @@ mod tests { crate::test::with_py_math(|py, math| { let py_p = pyo3::types::PyList::new(py, p).unwrap(); let py_q = pyo3::types::PyList::new(py, q).unwrap(); - let py: f64 = math - .getattr("sumprod") - .unwrap() - .call1((py_p, py_q)) - .unwrap() - .extract() - .unwrap(); - crate::test::assert_f64_eq(py, rs, format_args!("sumprod({p:?}, {q:?})")); + let py_result = math.getattr("sumprod").unwrap().call1((py_p, py_q)); + match py_result { + Ok(py_val) => { + let py: f64 = py_val.extract().unwrap(); + let rs = rs.unwrap_or_else(|e| { + panic!("sumprod({p:?}, {q:?}): py={py} but rs returned error {e:?}") + }); + crate::test::assert_f64_eq(py, rs, format_args!("sumprod({p:?}, {q:?})")); + } + Err(e) => { + if e.is_instance_of::(py) { + assert_eq!( + rs.as_ref().err(), + Some(&crate::Error::EDOM), + "sumprod({p:?}, {q:?}): py raised ValueError but rs={rs:?}" + ); + } else { + panic!("sumprod({p:?}, {q:?}): py raised unexpected error {e}"); + } + } + } }); } @@ -444,6 +490,9 @@ mod tests { test_sumprod_impl(&[], &[]); test_sumprod_impl(&[1.0], &[2.0]); test_sumprod_impl(&[1e100, 1e100], &[1e100, -1e100]); + test_sumprod_impl(&[1.0, 1e308, -1e308], &[1.0, 2.0, 2.0]); + test_sumprod_impl(&[1e-16, 1e308, -1e308], &[1.0, 2.0, 2.0]); + test_sumprod_impl(&[1.0], &[]); } fn test_prod_impl(values: &[f64], start: Option) { diff --git a/src/math/bigint.rs b/src/math/bigint.rs index 1bb157f..71ad4d5 100644 --- a/src/math/bigint.rs +++ b/src/math/bigint.rs @@ -63,7 +63,14 @@ pub fn comb_bigint(n: &BigInt, k: u64) -> BigUint { /// - mantissa is in [0.5, 1.0) for positive n /// - n ~= mantissa * 2^exponent /// -/// See: _PyLong_Frexp in CPython longobject.c +/// `_PyLong_Frexp` extracts digits one-by-one into a fixed-size +/// accumulator and applies a `half_even_correction` lookup table for +/// rounding. We instead extract the top 55 bits via a single right +/// shift and use a sticky-bit to mark whether any discarded bits were +/// non-zero, then delegate to `BigInt::to_f64()` which performs +/// IEEE 754 round-half-to-even. The two approaches are equivalent +/// because the sticky bit preserves the same rounding information +/// that the digit-by-digit extraction would. fn frexp_bigint(n: &BigInt) -> (f64, i64) { let bits = n.bits(); if bits == 0 { @@ -81,10 +88,24 @@ fn frexp_bigint(n: &BigInt) -> (f64, i64) { return (m, e as i64); } - // For large integers, extract top ~53 bits - // Shift right to keep DBL_MANT_DIG + 2 = 55 bits for rounding + // For large integers, extract top DBL_MANT_DIG + 2 = 55 bits for rounding let shift = bits - 55; - let mantissa_int = n >> shift as u64; + let mut mantissa_int = n >> shift as u64; + + // Sticky bit: if any shifted-out bits were non-zero, set the LSB. + // This ensures correct IEEE round-half-to-even when converting to f64. + // + // `_PyLong_Frexp` checks the remainder from `v_rshift` first, then + // iterates shifted-out digits top-down. We use `trailing_zeros()` + // which scans digits bottom-up instead. The worst-case traversal + // order differs (e.g. exact powers of two), but for typical inputs + // both terminate in O(1). If you observe a performance regression + // from this, please file a bug report. + let tz = n.magnitude().trailing_zeros().unwrap(); // n != 0 here + if tz < shift as u64 { + mantissa_int |= BigInt::from(1); + } + let mut x = mantissa_int.to_f64().unwrap(); // x is now approximately n / 2^shift, with ~55 bits of precision @@ -119,7 +140,7 @@ pub fn log_bigint(n: &BigInt, base: Option) -> crate::Result { // Use frexp decomposition for large values // n ~= x * 2^e, so log(n) = log(x) + log(2) * e let (x, e) = frexp_bigint(n); - let log_n = crate::m::log(x) + std::f64::consts::LN_2 * (e as f64); + let log_n = crate::mul_add(crate::m::log(2.0), e as f64, crate::m::log(x)); match base { None => Ok(log_n), @@ -150,7 +171,11 @@ pub fn log2_bigint(n: &BigInt) -> crate::Result { // Use frexp decomposition for large values // n ~= x * 2^e, so log2(n) = log2(x) + e let (x, e) = frexp_bigint(n); - Ok(crate::m::log2(x) + (e as f64)) + Ok(crate::mul_add( + crate::m::log2(2.0), + e as f64, + crate::m::log2(x), + )) } /// Return the base-10 logarithm of a BigInt. @@ -171,7 +196,11 @@ pub fn log10_bigint(n: &BigInt) -> crate::Result { // Use frexp decomposition for large values // n ~= x * 2^e, so log10(n) = log10(x) + log10(2) * e let (x, e) = frexp_bigint(n); - Ok(crate::m::log10(x) + std::f64::consts::LOG10_2 * (e as f64)) + Ok(crate::mul_add( + crate::m::log10(2.0), + e as f64, + crate::m::log10(x), + )) } /// Compute ldexp(x, exp) where exp is a BigInt. @@ -333,6 +362,33 @@ mod tests { }); } + fn assert_exact_log_bigint_bits(n: &BigInt, func_name: &str, rs: crate::Result) { + crate::test::with_py_math(|py, math| { + let n_str = n.to_string(); + let builtins = pyo3::types::PyModule::import(py, "builtins").unwrap(); + let py_n = builtins + .getattr("int") + .unwrap() + .call1((n_str.as_str(),)) + .unwrap(); + let py_f: f64 = math + .getattr(func_name) + .unwrap() + .call1((py_n,)) + .unwrap() + .extract() + .unwrap(); + let rs_f = rs.unwrap(); + assert_eq!( + py_f.to_bits(), + rs_f.to_bits(), + "{func_name}({n}): py={py_f} ({:#x}) vs rs={rs_f} ({:#x})", + py_f.to_bits(), + rs_f.to_bits() + ); + }); + } + #[test] fn edgetest_log_bigint() { // Small values @@ -410,6 +466,45 @@ mod tests { } } + #[test] + fn regression_log_bigint_pow10_309() { + let n = BigInt::from(10).pow(309); + assert_exact_log_bigint_bits(&n, "log", log_bigint(&n, None)); + } + + #[test] + fn regression_log2_bigint_pow10_309() { + let n = BigInt::from(10).pow(309); + assert_exact_log_bigint_bits(&n, "log2", log2_bigint(&n)); + } + + #[test] + fn regression_log10_bigint_pow10_309() { + let n = BigInt::from(10).pow(309); + assert_exact_log_bigint_bits(&n, "log10", log10_bigint(&n)); + } + + #[test] + fn regression_log_bigint_sticky_bit_rounding() { + // If frexp_bigint drops the sticky bit, this rounds 1 ULP low. + let n = (BigInt::from(0x40000000000c02u64) << 970u32) + BigInt::from(1u8); + assert_exact_log_bigint_bits(&n, "log", log_bigint(&n, None)); + } + + #[test] + fn regression_log2_bigint_sticky_bit_rounding() { + // If frexp_bigint drops the sticky bit, this rounds 1 ULP low. + let n = (BigInt::from(0x400000000010a2u64) << 970u32) + BigInt::from(1u8); + assert_exact_log_bigint_bits(&n, "log2", log2_bigint(&n)); + } + + #[test] + fn regression_log10_bigint_sticky_bit_rounding() { + // If frexp_bigint drops the sticky bit, this rounds 1 ULP low. + let n = (BigInt::from(0x4000000000049au64) << 970u32) + BigInt::from(1u8); + assert_exact_log_bigint_bits(&n, "log10", log10_bigint(&n)); + } + // ldexp_bigint tests fn test_ldexp_bigint_impl(x: f64, exp: &BigInt) { diff --git a/src/math/exponential.rs b/src/math/exponential.rs index 1b38008..f8a585b 100644 --- a/src/math/exponential.rs +++ b/src/math/exponential.rs @@ -113,7 +113,10 @@ pub fn log(x: f64, base: Option) -> Result { if den.is_infinite() && b.is_finite() { return Err(crate::Error::EDOM); } - // log(x, 1) -> division by zero + // log(x, 1) -> division by zero. + // CPython raises ZeroDivisionError here (via PyNumber_TrueDivide), + // but we return EDOM since our error type has no ZeroDivisionError + // variant. The caller (e.g. RustPython) may remap this if needed. if den == 0.0 { return Err(crate::Error::EDOM); } diff --git a/src/math/integer.rs b/src/math/integer.rs index 87537b0..0a5aee0 100644 --- a/src/math/integer.rs +++ b/src/math/integer.rs @@ -730,8 +730,8 @@ mod tests { 128, // Table boundary + 1 // Powers of 2 and boundaries 64, - 63, // 2^6 - 1 - 65, // 2^6 + 1 + 63, // 2^6 - 1 + 65, // 2^6 + 1 1024, 65535, // 2^16 - 1 65536, // 2^16 diff --git a/src/math/misc.rs b/src/math/misc.rs index 4b96823..adbe40c 100644 --- a/src/math/misc.rs +++ b/src/math/misc.rs @@ -6,22 +6,69 @@ super::libm_simple!(@1 ceil, floor, trunc); /// Return the next floating-point value after x towards y. /// -/// If steps is provided, move that many steps towards y. -/// Steps must be non-negative. +/// If steps is provided, move that many steps towards y using O(1) bit +/// manipulation on the IEEE 754 representation. Steps that overshoot y +/// are clamped so the result never passes y. +/// +/// CPython's math_nextafter_impl accepts a Python integer for steps, +/// rejects negative values, and saturates overflows to UINT64_MAX. This +/// Rust API takes `Option`, so negative rejection and big-int +/// saturation are structurally unnecessary. The caller (e.g. RustPython) +/// should handle Python int conversion and negative checks before calling. +/// +/// See math_nextafter_impl in mathmodule.c. #[inline] pub fn nextafter(x: f64, y: f64, steps: Option) -> f64 { - match steps { - Some(n) => { - let mut result = x; - for _ in 0..n { - result = crate::m::nextafter(result, y); - if result == y { - break; - } - } - result + let usteps = match steps { + None => return crate::m::nextafter(x, y), + Some(n) => n, + }; + + if usteps == 0 || x.is_nan() { + return x; + } + if y.is_nan() { + return y; + } + + let mut ux = x.to_bits(); + let uy = y.to_bits(); + if ux == uy { + return x; + } + + const SIGN_BIT: u64 = 1u64 << 63; + let ax = ux & !SIGN_BIT; + let ay = uy & !SIGN_BIT; + + if (ux ^ uy) & SIGN_BIT != 0 { + // opposite signs — may need to cross zero + // ax + ay can never overflow because bit 63 is cleared in both + if ax + ay <= usteps { + y + } else if ax < usteps { + // cross zero: remaining steps land on y's side + f64::from_bits((uy & SIGN_BIT) | (usteps - ax)) + } else { + ux -= usteps; + f64::from_bits(ux) + } + } else if ax > ay { + // same sign, moving toward zero + if ax - ay >= usteps { + ux -= usteps; + f64::from_bits(ux) + } else { + y + } + } else { + // same sign, moving away from zero + if ay - ax >= usteps { + ux += usteps; + f64::from_bits(ux) + } else { + y } - None => crate::m::nextafter(x, y), } } @@ -178,6 +225,12 @@ pub fn fmod(x: f64, y: f64) -> Result { } /// Return the IEEE 754-style remainder of x with respect to y. +/// +/// CPython implements this from scratch using fmod (m_remainder in +/// mathmodule.c) rather than calling the C library's remainder(). +/// We delegate to libm's remainder() which is correct on all platforms +/// where it conforms to IEEE 754. If you find a platform where the +/// results differ from CPython, please file a bug. #[inline] pub fn remainder(x: f64, y: f64) -> Result { super::math_2(x, y, crate::m::remainder) @@ -349,6 +402,64 @@ mod tests { } } + #[test] + fn regression_remainder_halfway_even_cases() { + // These cases exercise the half-way branch in CPython's custom + // m_remainder implementation, where ties are resolved toward the + // even multiple of y. + let cases = [ + ((6.0, 4.0), 0xc000_0000_0000_0000_u64), // -2.0 + ((3.0, 2.0), 0xbff0_0000_0000_0000_u64), // -1.0 + ((5.0, 2.0), 0x3ff0_0000_0000_0000_u64), // 1.0 + ((-6.0, 4.0), 0x4000_0000_0000_0000_u64), // 2.0 + ((-5.0, 2.0), 0xbff0_0000_0000_0000_u64), // -1.0 + ((1.5, 1.0), 0xbfe0_0000_0000_0000_u64), // -0.5 + ((2.5, 1.0), 0x3fe0_0000_0000_0000_u64), // 0.5 + ((3.5, 1.0), 0xbfe0_0000_0000_0000_u64), // -0.5 + ((4.5, 1.0), 0x3fe0_0000_0000_0000_u64), // 0.5 + ((5.5, 1.0), 0xbfe0_0000_0000_0000_u64), // -0.5 + ]; + + for &((x, y), expected_bits) in &cases { + let r = remainder(x, y).unwrap(); + assert_eq!( + r.to_bits(), + expected_bits, + "remainder({x}, {y}) = {r} ({:#x}), expected {:#x}", + r.to_bits(), + expected_bits + ); + test_remainder(x, y); + } + } + + #[test] + fn regression_remainder_boundary_and_signed_zero_cases() { + // These cases pin the sign of zero and the behavior immediately + // around the half-way boundary. + let just_below_half = f64::from_bits(0x3fdf_ffff_ffff_ffff); + let just_above_half = f64::from_bits(0x3fe0_0000_0000_0001); + let cases = [ + ((4.0, 2.0), 0x0000_0000_0000_0000_u64), // +0.0 + ((-4.0, 2.0), 0x8000_0000_0000_0000_u64), // -0.0 + ((4.0, -2.0), 0x0000_0000_0000_0000_u64), // +0.0 + ((just_below_half, 1.0), 0x3fdf_ffff_ffff_ffff_u64), + ((just_above_half, 1.0), 0xbfdf_ffff_ffff_fffe_u64), + ]; + + for &((x, y), expected_bits) in &cases { + let r = remainder(x, y).unwrap(); + assert_eq!( + r.to_bits(), + expected_bits, + "remainder({x}, {y}) = {r} ({:#x}), expected {:#x}", + r.to_bits(), + expected_bits + ); + test_remainder(x, y); + } + } + #[test] fn edgetest_copysign() { for &x in crate::test::EDGE_VALUES { @@ -587,4 +698,146 @@ mod tests { test_fma_impl(x, y, z); } } + + fn test_nextafter(x: f64, y: f64) { + use pyo3::prelude::*; + + let rs = nextafter(x, y, None); + pyo3::Python::attach(|py| { + let math = pyo3::types::PyModule::import(py, "math").unwrap(); + let py_f: f64 = math + .getattr("nextafter") + .unwrap() + .call1((x, y)) + .unwrap() + .extract() + .unwrap(); + if py_f.is_nan() && rs.is_nan() { + return; + } + assert_eq!( + py_f.to_bits(), + rs.to_bits(), + "nextafter({x}, {y}): py={py_f} vs rs={rs}" + ); + }); + } + + fn test_nextafter_steps(x: f64, y: f64, steps: u64) { + use pyo3::prelude::*; + + let rs = nextafter(x, y, Some(steps)); + pyo3::Python::attach(|py| { + let math = pyo3::types::PyModule::import(py, "math").unwrap(); + let kwargs = pyo3::types::PyDict::new(py); + kwargs.set_item("steps", steps).unwrap(); + let py_f: f64 = math + .getattr("nextafter") + .unwrap() + .call((x, y), Some(&kwargs)) + .unwrap() + .extract() + .unwrap(); + if py_f.is_nan() && rs.is_nan() { + return; + } + assert_eq!( + py_f.to_bits(), + rs.to_bits(), + "nextafter({x}, {y}, steps={steps}): py={py_f} vs rs={rs}" + ); + }); + } + + #[test] + fn edgetest_nextafter() { + for &x in crate::test::EDGE_VALUES { + for &y in crate::test::EDGE_VALUES { + test_nextafter(x, y); + } + } + } + + #[test] + fn edgetest_nextafter_steps() { + let x_vals = [ + 0.0, + -0.0, + 1.0, + -1.0, + f64::INFINITY, + f64::NEG_INFINITY, + f64::NAN, + ]; + let y_vals = [ + 0.0, + -0.0, + 1.0, + -1.0, + f64::INFINITY, + f64::NEG_INFINITY, + f64::NAN, + ]; + let steps = [0, 1, 2, 10, 100, 1000, u64::MAX]; + + for &x in &x_vals { + for &y in &y_vals { + for &s in &steps { + test_nextafter_steps(x, y, s); + } + } + } + } + + #[test] + fn test_nextafter_steps_large() { + // Large steps should saturate to target + test_nextafter_steps(0.0, 1.0, u64::MAX); + test_nextafter_steps(0.0, f64::INFINITY, u64::MAX); + test_nextafter_steps(1.0, -1.0, u64::MAX); + test_nextafter_steps(-1.0, 1.0, u64::MAX); + + // Steps exactly reaching a value + // From 0.0 toward inf, 10 steps = 10 * 5e-324 + test_nextafter_steps(0.0, f64::INFINITY, 10); + test_nextafter_steps(0.0, f64::NEG_INFINITY, 10); + + // Crossing zero + test_nextafter_steps(5e-324, -5e-324, 1); + test_nextafter_steps(5e-324, -5e-324, 2); + test_nextafter_steps(5e-324, -5e-324, 3); + test_nextafter_steps(-5e-324, 5e-324, 1); + test_nextafter_steps(-5e-324, 5e-324, 2); + test_nextafter_steps(-5e-324, 5e-324, 3); + + // Extreme steps that would hang with O(n) loop + let extreme_steps: &[u64] = &[ + 10u64.pow(9), + 10u64.pow(15), + 10u64.pow(18), + u64::MAX / 2, + u64::MAX - 1, + u64::MAX, + ]; + for &s in extreme_steps { + test_nextafter_steps(0.0, 1.0, s); + test_nextafter_steps(0.0, f64::INFINITY, s); + test_nextafter_steps(1.0, 0.0, s); + test_nextafter_steps(-1.0, 1.0, s); + test_nextafter_steps(f64::MIN_POSITIVE, f64::MAX, s); + test_nextafter_steps(f64::MAX, f64::MIN_POSITIVE, s); + } + } + + proptest::proptest! { + #[test] + fn proptest_nextafter(x: f64, y: f64) { + test_nextafter(x, y); + } + + #[test] + fn proptest_nextafter_steps(x: f64, y: f64, steps: u64) { + test_nextafter_steps(x, y, steps); + } + } } diff --git a/src/test.rs b/src/test.rs index 8fa3c5b..7c236ef 100644 --- a/src/test.rs +++ b/src/test.rs @@ -56,7 +56,7 @@ pub(crate) const EDGE_VALUES: &[f64] = &[ -2.0, 1.5, -1.5, - 3.0, // for cbrt + 3.0, // for cbrt -3.0, // Values near 1.0 (log, expm1, log1p, acosh boundary) 1.0 - 1e-15, @@ -65,10 +65,10 @@ pub(crate) const EDGE_VALUES: &[f64] = &[ 1.0 - f64::EPSILON, 1.0 + f64::EPSILON, // asin/acos domain boundaries [-1, 1] - 1.0000000000000002, // just outside domain (1 + eps) + 1.0000000000000002, // just outside domain (1 + eps) -1.0000000000000002, // atanh domain boundaries (-1, 1) - 0.9999999999999999, // just inside domain + 0.9999999999999999, // just inside domain -0.9999999999999999, // log1p domain boundary (> -1) -0.9999999999999999, // just above -1 @@ -102,7 +102,6 @@ pub(crate) const EDGE_VALUES: &[f64] = &[ -0.50000000000000006, ]; - pub(crate) fn unwrap<'py>( py: Python<'py>, py_v: PyResult>,