diff --git a/Numeric/SpecFunctions.hs b/Numeric/SpecFunctions.hs
index f6f1621..c5591c9 100644
--- a/Numeric/SpecFunctions.hs
+++ b/Numeric/SpecFunctions.hs
@@ -32,6 +32,9 @@ module Numeric.SpecFunctions (
   , log1p
   , log1pmx
   , log2
+    -- * Log-sum-exp
+  , logSumExp
+  , logSumExpPair
     -- * Exponent
   , expm1
     -- * Factorial
diff --git a/Numeric/SpecFunctions/Internal.hs b/Numeric/SpecFunctions/Internal.hs
index 608d31c..2de5ef3 100644
--- a/Numeric/SpecFunctions/Internal.hs
+++ b/Numeric/SpecFunctions/Internal.hs
@@ -936,6 +936,48 @@ log1pmx x
   where
     ax = abs x
 
+-- | Compute log(sum(exp(x_i))) in a numerically stable way using
+-- the log-sum-exp trick. This is useful when working with log
+-- probabilities to avoid overflow and underflow.
+--
+-- Uses the identity:
+--
+-- \[
+-- \log \sum_i \exp(x_i) = m + \log \sum_i \exp(x_i - m)
+-- \]
+--
+-- where \(m = \max_i x_i\).
+--
+-- Returns @-Infinity@ for an empty vector and for a vector whose
+-- elements are all @-Infinity@, and @+Infinity@ when the largest
+-- element is @+Infinity@. Any @NaN@ element makes the result @NaN@.
+logSumExp :: U.Vector Double -> Double
+logSumExp xs
+  | U.null xs = m_neg_inf
+  | otherwise = m + log (U.sum (U.map term xs))
+  where
+    m = U.maximum xs
+    -- An element equal to the maximum contributes exactly 1 (= exp 0).
+    -- Checking equality before subtracting avoids the NaN produced by
+    -- @inf - inf@ when the maximum itself is infinite.
+    term x
+      | x == m    = 1
+      | otherwise = exp (x - m)
+
+-- | Compute @log(exp(a) + exp(b))@ in a numerically stable way.
+--
+-- This is a special case of 'logSumExp' for two arguments, useful
+-- when combining two log-probabilities.
+--
+-- Equal arguments give @a + log 2@; in particular two infinities of
+-- the same sign return that infinity. A @NaN@ argument gives @NaN@.
+logSumExpPair :: Double -> Double -> Double
+logSumExpPair a b
+  -- Equal arguments are handled first: @a - a@ is NaN for infinite @a@.
+  | a == b    = a + log1p 1
+  | a >  b    = a + log1p (exp (b - a))
+  | otherwise = b + log1p (exp (a - b))
+
 -- | /O(log n)/ Compute the logarithm in base 2 of the given value.
log2 :: Int -> Int log2 v0 diff --git a/tests/Tests/SpecFunctions.hs b/tests/Tests/SpecFunctions.hs index d51ca8f..9c44aed 100644 --- a/tests/Tests/SpecFunctions.hs +++ b/tests/Tests/SpecFunctions.hs @@ -91,6 +91,31 @@ tests = testGroup "Special functions" checkTabularPure 1 (show x) exact (log1p x) ] ---------------- + , testGroup "logSumExp" + [ testProperty "logSumExp [a] == a" $ \a -> + not (isNaN a) ==> logSumExp (U.singleton a) == a + , testProperty "logSumExpPair commutative" $ \a b -> + not (isNaN a) && not (isNaN b) ==> + logSumExpPair a b == logSumExpPair b a + , testProperty "logSumExpPair a a == a + log 2" $ \a -> + not (isNaN a) && not (isInfinite a) ==> + within 2 (logSumExpPair a a) (a + log 2) + , testProperty "logSumExp recovers log(sum(exp(x)))" $ \(getNonEmpty -> xs) -> + let v = U.fromList xs + naive = log (U.sum (U.map exp v)) + stable = logSumExp v + in not (any isNaN xs) && not (any isInfinite xs) && all (\x -> abs x < 300) xs ==> + within 4 naive stable + , testCase "logSumExp empty == -Infinity" $ + assertBool "should be -Infinity" (isInfinite (logSumExp U.empty) && logSumExp U.empty < 0) + , testCase "logSumExp with large values" $ do + let result = logSumExp (U.fromList [1000, 1001, 1002]) + assertBool "should be close to 1002.41" (within 4 result 1002.4076059644443) + , testCase "logSumExpPair log-probability combination" $ do + let result = logSumExpPair (-1000) (-1001) + assertBool "should be close to -999.687" (within 4 result (-999.6867383124818)) + ] + ---------------- , testGroup "gamma function" [ testCase "logGamma table [fractional points" $ forTable "tests/tables/loggamma.dat" $ \[x, exact] -> do