diff --git a/.claude/SKILL.md b/.claude/SKILL.md index 2c7e3f984..40de68a6b 100644 --- a/.claude/SKILL.md +++ b/.claude/SKILL.md @@ -304,6 +304,26 @@ Always use explicit multiplication and parentheses in Big-O expressions for clar // Space: O(n) ``` +### For Loop Body on Its Own Line + +Always place the body of a `for` loop on its own line, even for single statements. +This improves readability, especially in nested loops: + +```java +// ✗ BAD — body on same line as for +for (int j = 0; j < n; j++) augmented[i][j] = matrix[i][j]; + +// ✓ GOOD — body on its own line +for (int j = 0; j < n; j++) + augmented[i][j] = matrix[i][j]; + +// ✓ GOOD — nested for loops, each level on its own line +for (int i = 0; i < n; i++) + for (int j = 0; j < n; j++) + for (int k = 0; k < n; k++) + result[i][j] += m1[i][k] * m2[k][j]; +``` + ### Avoid Java Streams Streams hurt readability for learners. Use plain loops instead: diff --git a/README.md b/README.md index 6d09a7cd9..f997bc418 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,8 @@ Algorithms and data structures are fundamental to efficient code and good software design. Creating and designing excellent algorithms is required for being an exemplary programmer. This repository's goal is to demonstrate how to correctly implement common data structures and algorithms in the simplest and most elegant ways. +🎬 Many of the algorithms and data structures in this repo have companion video explanations on the [William Fiset YouTube channel](https://www.youtube.com/@WilliamFiset-videos) — so if the code alone doesn't click, grab some popcorn and watch the videos! + # Running an algorithm implementation To compile and run any of the algorithms here, you need at least JDK version 8 and [Bazel](https://bazel.build/) diff --git a/src/main/java/com/williamfiset/algorithms/datastructures/segmenttree/Node.java b/src/main/java/com/williamfiset/algorithms/datastructures/segmenttree/Node.java index 00279ac4f..59bab0eb8 100644 --- a/src/main/java/com/williamfiset/algorithms/datastructures/segmenttree/Node.java +++ b/src/main/java/com/williamfiset/algorithms/datastructures/segmenttree/Node.java @@ -1,23 +1,49 @@ +package com.williamfiset.algorithms.datastructures.segmenttree; + /** - * Segment Trees are an extremely useful data structure when dealing with ranges or intervals. They - * take O(n) time and space to construct, but they can do range updates or queries in O(log(n)) - * time. This data structure is quite flexible; although the code below supports minimum and sum - * queries, these could be modified to perform other types of queries. This implementation uses lazy - * propagation (which allows for O(log(n)) range updates instead of O(n)). It should also be noted - * that this implementation could easily be modified to support coordinate compression (you should - * only have to change a few lines in the constructor). + * Pointer-Based Segment Tree with Lazy Propagation + * + * A segment tree built with explicit left/right child pointers (rather than + * a flat array). Supports range sum queries, range min queries, and range + * updates (add a value to every element in an interval), all in O(log(n)). + * + * Lazy propagation defers updates to child nodes until they are actually + * needed, keeping range updates at O(log(n)) instead of O(n). + * + * Each node covers a half-open interval [minPos, maxPos). Leaves cover a + * single element [i, i+1). The combine function computes both sum and min + * simultaneously as values propagate up. + * + * Use cases: + * - Range sum / min queries with range add updates + * - Problems requiring coordinate compression (easy to adapt constructor) + * + * Time: O(n) construction, O(log(n)) per query and update + * Space: O(n) * * @author Micah Stairs */ -package com.williamfiset.algorithms.datastructures.segmenttree; - public class Node { - static final int INF = Integer.MAX_VALUE; + private static final int INF = Integer.MAX_VALUE; + + private Node left, right; + + // This node covers the half-open interval [minPos, maxPos) + private int minPos, maxPos; - Node left, right; - int minPos, maxPos, min = 0, sum = 0, lazy = 0; + // Aggregate values for this node's range + private int min = 0, sum = 0; + // Pending update that hasn't been pushed to children yet + private int lazy = 0; + + /** + * Creates a segment tree from an array of values. + * + * @param values the initial values for the leaves + * @throws IllegalArgumentException if values is null + */ public Node(int[] values) { if (values == null) throw new IllegalArgumentException("Null input to segment tree."); buildTree(0, values.length); @@ -26,113 +52,110 @@ public Node(int[] values) { } } + /** + * Creates an empty segment tree of the given size, with all values at 0. + * + * @param sz the number of elements + */ public Node(int sz) { buildTree(0, sz); } - private Node(int l, int r) { - buildTree(l, r); - } - - // Recursive method that builds the segment tree + // Recursively builds the tree structure for the range [l, r). + // Leaves cover [i, i+1); internal nodes split at the midpoint. private void buildTree(int l, int r) { - if (l < 0 || r < 0 || r < l) throw new IllegalArgumentException("Illegal range: (" + l + "," + r + ")"); minPos = l; maxPos = r; - // Reached leaf - if (l == r - 1) { - left = right = null; - - // Add children - } else { + // Internal node — split at midpoint + if (r - l > 1) { int mid = (l + r) / 2; left = new Node(l, mid); right = new Node(mid, r); } } - // Adjust all values in the interval [l, r) by a particular amount - public void update(int l, int r, int change) { + private Node(int l, int r) { + buildTree(l, r); + } - // Do lazy updates to children + /** + * Adds {@code change} to every element in the half-open interval [l, r). + * + * @param l left endpoint (inclusive) + * @param r right endpoint (exclusive) + * @param change the value to add to each element in [l, r) + * + * Time: O(log(n)) + */ + public void update(int l, int r, int change) { propagate(); - // Node's range fits inside query range if (l <= minPos && maxPos <= r) { - + // Fully inside — apply update directly sum += change * (maxPos - minPos); min += change; - - // Lazily propagate update to children + // Lazily defer to children if (left != null) left.lazy += change; if (right != null) right.lazy += change; - - // Ranges do not overlap } else if (r <= minPos || l >= maxPos) { - - // Do nothing - - // Ranges partially overlap + // No overlap } else { - - if (left != null) left.update(l, r, change); - if (right != null) right.update(l, r, change); - sum = (left == null ? 0 : left.sum) + (right == null ? 0 : right.sum); - min = Math.min((left == null ? INF : left.min), (right == null ? INF : right.min)); + // Partial overlap — recurse into children. + // Partial overlap only happens at internal nodes (leaves always + // fully match or fully miss), so left and right are never null here. + left.update(l, r, change); + right.update(l, r, change); + sum = left.sum + right.sum; + min = Math.min(left.min, right.min); } } - // Get the sum in the interval [l, r) + /** + * Returns the sum of elements in the half-open interval [l, r). + * + * @param l left endpoint (inclusive) + * @param r right endpoint (exclusive) + * @return the sum of all elements in [l, r) + * + * Time: O(log(n)) + */ public int sum(int l, int r) { - - // Do lazy updates to children propagate(); - - // Node's range fits inside query range if (l <= minPos && maxPos <= r) return sum; - - // Ranges do not overlap - else if (r <= minPos || l >= maxPos) return 0; - - // Ranges partially overlap - else return (left == null ? 0 : left.sum(l, r)) + (right == null ? 0 : right.sum(l, r)); + if (r <= minPos || l >= maxPos) return 0; + return left.sum(l, r) + right.sum(l, r); } - // Get the minimum value in the interval [l, r) + /** + * Returns the minimum element in the half-open interval [l, r). + * + * @param l left endpoint (inclusive) + * @param r right endpoint (exclusive) + * @return the minimum value in [l, r) + * + * Time: O(log(n)) + */ public int min(int l, int r) { - - // Do lazy updates to children propagate(); - - // Node's range fits inside query range if (l <= minPos && maxPos <= r) return min; - - // Ranges do not overlap - else if (r <= minPos || l >= maxPos) return INF; - - // Ranges partially overlap - else - return Math.min( - (left == null ? INF : left.min(l, r)), (right == null ? INF : right.min(l, r))); + if (r <= minPos || l >= maxPos) return INF; + return Math.min(left.min(l, r), right.min(l, r)); } - // Does any updates to this node that haven't been done yet, and lazily updates its children - // NOTE: This method must be called before updating or accessing a node + /** + * Applies any pending lazy update to this node and defers it to children. + * Must be called before reading or modifying a node's values. + */ private void propagate() { - if (lazy != 0) { - sum += lazy * (maxPos - minPos); min += lazy; - - // Lazily propagate updates to children if (left != null) left.lazy += lazy; if (right != null) right.lazy += lazy; - lazy = 0; } } diff --git a/src/main/java/com/williamfiset/algorithms/dp/CoinChange.java b/src/main/java/com/williamfiset/algorithms/dp/CoinChange.java index ba4a7f211..abeab4098 100644 --- a/src/main/java/com/williamfiset/algorithms/dp/CoinChange.java +++ b/src/main/java/com/williamfiset/algorithms/dp/CoinChange.java @@ -1,36 +1,58 @@ +package com.williamfiset.algorithms.dp; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; + /** - * The coin change problem is an unbounded knapsack problem variant. The problem asks you to find - * the minimum number of coins required for a certain amount of change given the coin denominations. - * You may use each coin denomination as many times as you please. + * Coin Change Problem (Unbounded Knapsack Variant) + * + * Given a set of coin denominations and a target amount, find the minimum + * number of coins needed to make that amount. Each coin denomination may + * be used unlimited times. + * + * Three implementations are provided: * - *

Tested against: https://leetcode.com/problems/coin-change + * 1. coinChange() — 2D DP table, O(m*n) time/space, recovers selected coins + * 2. coinChangeSpaceEfficient() — 1D DP array, O(m*n) time, O(n) space, recovers selected coins + * 3. coinChangeRecursive() — top-down with memoization, skips unreachable states * - *

Run locally: + * Where m = number of coin denominations, n = target amount. * - *

bazel run //src/main/java/com/williamfiset/algorithms/dp:CoinChange + * Tested against: https://leetcode.com/problems/coin-change * * @author William Fiset, william.alexandre.fiset@gmail.com */ -package com.williamfiset.algorithms.dp; - -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; - public class CoinChange { + /** + * Holds the result of a coin change computation: the minimum number of coins + * (if a solution exists) and the actual coins selected. + */ public static class Solution { - // Contains the minimum number of coins to make a certain amount, if a solution exists. + /** The minimum number of coins to make the target amount, or empty if impossible. */ Optional minCoins = Optional.empty(); - // The coins selected as part of the optimal solution. - List selectedCoins = new ArrayList(); + /** The coins selected as part of the optimal solution. */ + List selectedCoins = new ArrayList<>(); } - // TODO(william): setting an explicit infinity could lead to a wrong answer for - // very large values. Prefer to use null instead. - private static final int INF = Integer.MAX_VALUE / 2; - + // ==================== Implementation 1: 2D DP table ==================== + + /** + * Solves coin change using a 2D DP table. + * + * dp[i][j] = minimum coins needed to make amount j using the first i coin types. + * After computing the table, backtracks to recover which coins were selected. + * + * @param coins array of coin denominations (all positive) + * @param n the target amount + * @return a Solution containing the min coin count and selected coins + * + * Time: O(m*n) where m = coins.length + * Space: O(m*n) + */ public static Solution coinChange(int[] coins, final int n) { if (coins == null) throw new IllegalArgumentException("Coins array is null"); if (coins.length == 0) throw new IllegalArgumentException("No coin values :/"); @@ -41,40 +63,41 @@ public static Solution coinChange(int[] coins, final int n) { } final int m = coins.length; - // Initialize table and set first row to be infinity - int[][] dp = new int[m + 1][n + 1]; - java.util.Arrays.fill(dp[0], INF); - dp[1][0] = 0; - // Iterate through all the coins + // dp[i][j] = min coins using first i denominations to make amount j. + // Row 0 is a sentinel: no coins available, so everything is impossible (null). + // Column 0 is always 0: it takes 0 coins to make amount 0. + Integer[][] dp = new Integer[m + 1][n + 1]; + for (int i = 0; i <= m; i++) + dp[i][0] = 0; + for (int i = 1; i <= m; i++) { int coinValue = coins[i - 1]; for (int j = 1; j <= n; j++) { - // Consider not selecting this coin + // Option 1: don't use coin i — carry forward from previous row dp[i][j] = dp[i - 1][j]; - // Try selecting this coin if it's better - if (j - coinValue >= 0 && dp[i][j - coinValue] + 1 < dp[i][j]) { - dp[i][j] = dp[i][j - coinValue] + 1; + // Option 2: use coin i if it fits and yields fewer coins + if (j - coinValue >= 0 && dp[i][j - coinValue] != null) { + int withCoin = dp[i][j - coinValue] + 1; + if (dp[i][j] == null || withCoin < dp[i][j]) { + dp[i][j] = withCoin; + } } } } - // p(dp); - Solution solution = new Solution(); - if (dp[m][n] != INF) { - solution.minCoins = Optional.of(dp[m][n]); - } else { - return solution; - } + if (dp[m][n] == null) return solution; + solution.minCoins = Optional.of(dp[m][n]); + // Backtrack to recover selected coins for (int change = n, coinIndex = m; coinIndex > 0; ) { int coinValue = coins[coinIndex - 1]; - boolean canSelectCoin = change - coinValue >= 0; - if (canSelectCoin && dp[coinIndex][change - coinValue] < dp[coinIndex][change]) { + boolean canSelect = change - coinValue >= 0 && dp[coinIndex][change - coinValue] != null; + if (canSelect && dp[coinIndex][change - coinValue] < dp[coinIndex][change]) { solution.selectedCoins.add(coinValue); change -= coinValue; } else { @@ -85,55 +108,82 @@ public static Solution coinChange(int[] coins, final int n) { return solution; } + // ==================== Implementation 2: Space-efficient 1D DP ==================== + + /** + * Solves coin change using a space-efficient 1D DP array. + * + * dp[j] = minimum coins needed to make amount j using any denomination. + * After computing, backtracks greedily to recover selected coins. + * + * Compare with coinChange(): same time complexity but uses O(n) space + * instead of O(m*n) by collapsing the coin dimension. + * + * @param coins array of coin denominations (all positive) + * @param n the target amount + * @return a Solution containing the min coin count and selected coins + * + * Time: O(m*n) + * Space: O(n) + */ public static Solution coinChangeSpaceEfficient(int[] coins, int n) { if (coins == null) throw new IllegalArgumentException("Coins array is null"); - // Initialize table and set everything to infinity except first cell - int[] dp = new int[n + 1]; - java.util.Arrays.fill(dp, INF); + // dp[j] = min coins to make amount j, null means impossible + Integer[] dp = new Integer[n + 1]; dp[0] = 0; for (int i = 1; i <= n; i++) { for (int coin : coins) { - if (i - coin < 0) { - continue; - } - if (dp[i - coin] + 1 < dp[i]) { - dp[i] = dp[i - coin] + 1; + if (i - coin >= 0 && dp[i - coin] != null) { + int withCoin = dp[i - coin] + 1; + if (dp[i] == null || withCoin < dp[i]) { + dp[i] = withCoin; + } } } } Solution solution = new Solution(); - if (dp[n] != INF) { - solution.minCoins = Optional.of(dp[n]); - } else { - return solution; - } + if (dp[n] == null) return solution; + solution.minCoins = Optional.of(dp[n]); + // Backtrack greedily: at each amount, pick the coin that leads to the fewest coins for (int i = n; i > 0; ) { - int selectedCoinValue = INF; - int cellWithFewestCoins = dp[i]; + int bestCoin = -1; + int bestCount = dp[i]; for (int coin : coins) { - if (i - coin < 0) { - continue; - } - if (dp[i - coin] < cellWithFewestCoins) { - cellWithFewestCoins = dp[i - coin]; - selectedCoinValue = coin; + if (i - coin >= 0 && dp[i - coin] != null && dp[i - coin] < bestCount) { + bestCount = dp[i - coin]; + bestCoin = coin; } } - solution.selectedCoins.add(selectedCoinValue); - i -= selectedCoinValue; + solution.selectedCoins.add(bestCoin); + i -= bestCoin; } - // Return the minimum number of coins needed return solution; } - // The recursive approach has the advantage that it does not have to visit - // all possible states like the tabular approach does. This can speedup - // things especially if the coin denominations are large. + // ==================== Implementation 3: Top-down recursive with memoization ==================== + + /** + * Solves coin change using top-down recursion with memoization. + * + * Unlike the two bottom-up implementations above, the recursive approach + * only visits states reachable from the target amount. This can be faster + * when coin denominations are large (many states are skipped). + * + * Note: returns -1 instead of Optional.empty() for impossible cases, + * and does not recover the selected coins. + * + * @param coins array of coin denominations (all positive) + * @param n the target amount + * @return the minimum number of coins, or -1 if impossible + * + * Time: O(m*n) + * Space: O(n) + */ public static int coinChangeRecursive(int[] coins, int n) { if (coins == null) throw new IllegalArgumentException("Coins array is null"); if (n < 0) return -1; @@ -142,77 +192,62 @@ public static int coinChangeRecursive(int[] coins, int n) { return coinChangeRecursive(n, coins, dp); } - // Private helper method to actually go the recursion private static int coinChangeRecursive(int n, int[] coins, int[] dp) { if (n < 0) return -1; if (n == 0) return 0; if (dp[n] != 0) return dp[n]; - int minCoins = INF; - for (int coinValue : coins) { - int value = coinChangeRecursive(n - coinValue, coins, dp); - if (value != -1 && value < minCoins) minCoins = value + 1; + int minCoins = Integer.MAX_VALUE; + for (int coin : coins) { + int value = coinChangeRecursive(n - coin, coins, dp); + if (value != -1 && value < minCoins) + minCoins = value + 1; } - // If we weren't able to find some coins to make our - // amount then cache -1 as the answer. - return dp[n] = (minCoins == INF) ? -1 : minCoins; - } - - // DP table print function. Used for debugging. - private static void p(int[][] dp) { - for (int[] r : dp) { - for (int v : r) { - System.out.printf("%4d, ", v == INF ? -1 : v); - } - System.out.println(); - } - } - - private static void p(int[] dp) { - for (int v : dp) { - System.out.printf("%4d, ", v == INF ? -1 : v); - } - System.out.println(); + // Cache -1 if no combination of coins can make this amount + return dp[n] = (minCoins == Integer.MAX_VALUE) ? -1 : minCoins; } public static void main(String[] args) { - // example1(); - // example2(); - // example3(); + example1(); + example2(); + example3(); example4(); } - private static void example4() { - int n = 11; - int[] coins = {2, 4, 1}; - // System.out.println(coinChange(coins, n).minCoins); - System.out.println(coinChangeSpaceEfficient(coins, n)); - // System.out.println(coinChangeRecursive(coins, n)); - // System.out.println(coinChange(coins, n).selectedCoins); - } - private static void example1() { int[] coins = {2, 6, 1}; - System.out.println(coinChange(coins, 17).minCoins); - System.out.println(coinChange(coins, 17).selectedCoins); - System.out.println(coinChangeSpaceEfficient(coins, 17)); - System.out.println(coinChangeRecursive(coins, 17)); + System.out.println("--- coins={2,6,1}, amount=17 ---"); + System.out.println("2D DP: " + coinChange(coins, 17).minCoins); // Optional[4] + System.out.println(" selected: " + coinChange(coins, 17).selectedCoins); // [6, 6, 2, 2, 1] + System.out.println("1D DP: " + coinChangeSpaceEfficient(coins, 17).minCoins); // Optional[4] + System.out.println("Recursive: " + coinChangeRecursive(coins, 17)); // 4 } private static void example2() { int[] coins = {2, 3, 5}; - System.out.println(coinChange(coins, 12).minCoins); - System.out.println(coinChange(coins, 12).selectedCoins); - System.out.println(coinChangeSpaceEfficient(coins, 12)); - System.out.println(coinChangeRecursive(coins, 12)); + System.out.println("--- coins={2,3,5}, amount=12 ---"); + System.out.println("2D DP: " + coinChange(coins, 12).minCoins); // Optional[3] + System.out.println(" selected: " + coinChange(coins, 12).selectedCoins); // [5, 5, 2] + System.out.println("1D DP: " + coinChangeSpaceEfficient(coins, 12).minCoins); // Optional[3] + System.out.println("Recursive: " + coinChangeRecursive(coins, 12)); // 3 } private static void example3() { int[] coins = {3, 4, 7}; - System.out.println(coinChange(coins, 17).minCoins); - System.out.println(coinChange(coins, 17).selectedCoins); - System.out.println(coinChangeSpaceEfficient(coins, 17)); - System.out.println(coinChangeRecursive(coins, 17)); + System.out.println("--- coins={3,4,7}, amount=17 ---"); + System.out.println("2D DP: " + coinChange(coins, 17).minCoins); // Optional[3] + System.out.println(" selected: " + coinChange(coins, 17).selectedCoins); // [7, 7, 3] + System.out.println("1D DP: " + coinChangeSpaceEfficient(coins, 17).minCoins); // Optional[3] + System.out.println("Recursive: " + coinChangeRecursive(coins, 17)); // 3 + } + + private static void example4() { + int[] coins = {2, 4, 1}; + System.out.println("--- coins={2,4,1}, amount=11 ---"); + System.out.println("2D DP: " + coinChange(coins, 11).minCoins); // Optional[4] + System.out.println(" selected: " + coinChange(coins, 11).selectedCoins); // [4, 4, 2, 1] + System.out.println("1D DP: " + coinChangeSpaceEfficient(coins, 11).minCoins); // Optional[4] + System.out.println("Recursive: " + coinChangeRecursive(coins, 11)); // 4 } } diff --git a/src/main/java/com/williamfiset/algorithms/dp/MinimumWeightPerfectMatching.java b/src/main/java/com/williamfiset/algorithms/dp/MinimumWeightPerfectMatching.java index 64932209a..9bf76eff4 100644 --- a/src/main/java/com/williamfiset/algorithms/dp/MinimumWeightPerfectMatching.java +++ b/src/main/java/com/williamfiset/algorithms/dp/MinimumWeightPerfectMatching.java @@ -1,26 +1,43 @@ +package com.williamfiset.algorithms.dp; + +import java.awt.geom.Point2D; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + /** - * Implementation of the Minimum Weight Perfect Matching (MWPM) problem. In this problem you are - * given a distance matrix which gives the distance from each node to every other node, and you want - * to pair up all the nodes to one another minimizing the overall cost. + * Minimum Weight Perfect Matching (MWPM) + * + * Given n nodes and a symmetric n x n cost matrix, the goal is to find a + * perfect matching — a set of n/2 pairs that covers every node exactly once — + * such that the sum of the edge costs of the chosen pairs is minimized. + * + * For example, given 4 nodes with cost matrix: + * cost = {{0,2,1,2}, {2,0,2,1}, {1,2,0,2}, {2,1,2,0}} + * the optimal matching is (0,2) and (1,3) with total cost 1 + 1 = 2. + * + * Uses bitmask DP where each bit in the state represents whether a node has + * been matched. Two solvers are provided: * - *

Tested against: UVA 10911 - Forming Quiz Teams + * - solveRecursive(): top-down with memoization, skips unreachable states + * - solveIterative(): bottom-up, builds matchings from pairs upward * - *

To Run: bazel run //src/main/java/com/williamfiset/algorithms/dp:MinimumWeightPerfectMatching + * Requires n to be even (otherwise no perfect matching exists) and n <= 32 + * (bitmask representation limit). * - *

Time Complexity: O(n * 2^n) + * Tested against: UVA 10911 - Forming Quiz Teams + * + * Time: O(n^2*2^n) + * Space: O(2^n) * * @author William Fiset */ -package com.williamfiset.algorithms.dp; - -import java.awt.geom.*; -import java.util.*; - public class MinimumWeightPerfectMatching { // Inputs private final int n; - private double[][] cost; + private final double[][] cost; // Internal private final int END_STATE; @@ -30,7 +47,13 @@ public class MinimumWeightPerfectMatching { private double minWeightCost; private int[] matching; - // The cost matrix should be a symmetric (i.e cost[i][j] = cost[j][i]) + /** + * Creates a MWPM solver for the given cost matrix. + * + * @param cost symmetric n x n distance matrix (cost[i][j] = cost[j][i]) + * + * @throws IllegalArgumentException if matrix is null, empty, odd-sized, or too large + */ public MinimumWeightPerfectMatching(double[][] cost) { if (cost == null) throw new IllegalArgumentException("Input cannot be null"); n = cost.length; @@ -45,8 +68,12 @@ public MinimumWeightPerfectMatching(double[][] cost) { this.cost = cost; } + /** + * Returns the minimum total cost of a perfect matching. + * Lazily solves using the recursive solver if neither solver has run yet. + */ public double getMinWeightCost() { - solveRecursive(); + if (!solved) solveRecursive(); return minWeightCost; } @@ -68,11 +95,21 @@ public double getMinWeightCost() { * } */ public int[] getMinWeightCostMatching() { - solveRecursive(); + if (!solved) solveRecursive(); return matching; } - // Recursive impl + // ==================== Solver 1: Top-down (recursive with memoization) ==================== + + /** + * Solves using top-down recursion with memoization. Starting from the full set + * of nodes, it finds the lowest-numbered unmatched node and tries pairing it + * with every other unmatched node, recursing on the reduced state. + * + * This approach naturally skips unreachable states (states that can't be formed + * by removing pairs from the full set), so it often visits fewer states than + * the iterative solver. + */ public void solveRecursive() { if (solved) return; Double[] dp = new Double[1 << n]; @@ -83,24 +120,17 @@ public void solveRecursive() { } private double f(int state, Double[] dp, int[] history) { - if (dp[state] != null) { - return dp[state]; - } - if (state == 0) { - return 0; - } - int p1, p2; - // Seek to find active bit position (p1) - for (p1 = 0; p1 < n; p1++) { - if ((state & (1 << p1)) > 0) { - break; - } - } + if (dp[state] != null) return dp[state]; + if (state == 0) return 0; + + // Find the lowest set bit position (p1) — always pair this node first + int p1 = Integer.numberOfTrailingZeros(state); + int bestState = -1; double minimum = Double.MAX_VALUE; - for (p2 = p1 + 1; p2 < n; p2++) { - // Position `p2` is on. Try matching the pair (p1, p2) together. + // Try pairing p1 with every other set bit + for (int p2 = p1 + 1; p2 < n; p2++) { if ((state & (1 << p2)) > 0) { int reducedState = state ^ (1 << p1) ^ (1 << p2); double matchCost = f(reducedState, dp, history) + cost[p1][p2]; @@ -114,7 +144,17 @@ private double f(int state, Double[] dp, int[] history) { return dp[state] = minimum; } - public void solve() { + // ==================== Solver 2: Bottom-up (iterative) ==================== + + /** + * Solves using bottom-up iterative DP. Pre-computes all n*(n-1)/2 pair states, + * then iterates over all bitmask states in ascending order, extending each + * valid matching by adding a non-overlapping pair. + * + * This approach visits all 2^n states systematically. It avoids recursion + * overhead and stack depth limits, making it better suited for larger n. + */ + public void solveIterative() { if (solved) return; // The DP state is encoded as a bitmask where the i'th bit is flipped on if the i'th node is @@ -152,9 +192,8 @@ public void solve() { for (int state = 0b11; state < (1 << n); state++) { // O(2^n) // Skip states with an odd number of bits (nodes). It's easier (and faster) to // check dp[state] instead of calling `Integer.bitCount` for the bit count. - if (dp[state] == null) { - continue; - } + if (dp[state] == null) continue; + for (int i = 0; i < numPairs; i++) { // O(n^2) int pair = pairStates[i]; // Ignore states which overlap @@ -178,65 +217,46 @@ public void solve() { solved = true; } - // Populates the `matching` array with a sorted deterministic matching sorted by lowest node - // index. For example, if the perfect matching consists of the pairs (3, 4), (1, 5), (0, 2). - // The matching is sorted such that the pairs appear in the ordering: (0, 2), (1, 5), (3, 4). - // Furthermore, it is guaranteed that for any pair (a, b) that a < b. + /** + * Populates the {@code matching} array with a sorted deterministic matching. + * For example, if the perfect matching consists of the pairs (3, 4), (1, 5), (0, 2), + * the output is sorted as: (0, 2), (1, 5), (3, 4). + * For any pair (a, b), it is guaranteed that a < b. + */ private void reconstructMatching(int[] history) { - // A map between pairs of nodes that were matched together. int[] map = new int[n]; int[] leftNodes = new int[n / 2]; - // Reconstruct the matching of pairs of nodes working backwards through computed states. + // Walk backwards through computed states to recover matched pairs for (int i = 0, state = END_STATE; state != 0; state = history[state]) { - // Isolate the pair used by xoring the state with the state used to generate it. int pairUsed = state ^ history[state]; - int leftNode = getBitPosition(Integer.lowestOneBit(pairUsed)); - int rightNode = getBitPosition(Integer.highestOneBit(pairUsed)); + int leftNode = Integer.numberOfTrailingZeros(Integer.lowestOneBit(pairUsed)); + int rightNode = Integer.numberOfTrailingZeros(Integer.highestOneBit(pairUsed)); leftNodes[i++] = leftNode; map[leftNode] = rightNode; } - // Sort the left nodes in ascending order. - java.util.Arrays.sort(leftNodes); + Arrays.sort(leftNodes); matching = new int[n]; for (int i = 0; i < n / 2; i++) { matching[2 * i] = leftNodes[i]; - int rightNode = map[leftNodes[i]]; - matching[2 * i + 1] = rightNode; + matching[2 * i + 1] = map[leftNodes[i]]; } } - // Gets the zero base index position of the 1 bit in `k`. `k` must be a power of 2, so there is - // only ever 1 bit in the binary representation of k. - private int getBitPosition(int k) { - int count = -1; - while (k > 0) { - count++; - k >>= 1; - } - return count; - } - - /* Example */ - public static void main(String[] args) { - // test1(); - // for (int i = 0; i < 50; i++) { - // if (include(i)) System.out.printf("%2d %7s\n", i, Integer.toBinaryString(i)); - // } - } - - private static boolean include(int i) { - boolean toInclude = Integer.bitCount(i) >= 2 && Integer.bitCount(i) % 2 == 0; - return toInclude; + test1(); + test2(); } + // Example 1: Uses the RECURSIVE solver. + // Generates 2D points that form vertical pairs, shuffles them, and verifies + // the MWPM correctly matches each pair (cost = 1 per pair, total = n/2). private static void test1() { - // int n = 18; + System.out.println("=== Recursive solver ==="); int n = 6; List pts = new ArrayList<>(); @@ -248,13 +268,13 @@ private static void test1() { Collections.shuffle(pts); double[][] cost = new double[n][n]; - for (int i = 0; i < n; i++) { - for (int j = 0; j < n; j++) { + for (int i = 0; i < n; i++) + for (int j = 0; j < n; j++) cost[i][j] = pts.get(i).distance(pts.get(j)); - } - } MinimumWeightPerfectMatching mwpm = new MinimumWeightPerfectMatching(cost); + mwpm.solveRecursive(); + double minCost = mwpm.getMinWeightCost(); if (minCost != n / 2) { System.out.printf("MWPM cost is wrong! Got: %.5f But wanted: %d\n", minCost, n / 2); @@ -275,7 +295,10 @@ private static void test1() { } } + // Example 2: Uses the ITERATIVE solver. + // Simple 4-node symmetric matrix where the optimal matching costs 2.0. private static void test2() { + System.out.println("=== Iterative solver ==="); double[][] costMatrix = { {0, 2, 1, 2}, {2, 0, 2, 1}, @@ -284,12 +307,12 @@ private static void test2() { }; MinimumWeightPerfectMatching mwpm = new MinimumWeightPerfectMatching(costMatrix); + mwpm.solveIterative(); + double cost = mwpm.getMinWeightCost(); if (cost != 2.0) { System.out.println("error cost not 2"); } - System.out.println(cost); - // System.out.println(mwpm.solve2()); - + System.out.println(cost); // 2.0 } } diff --git a/src/main/java/com/williamfiset/algorithms/linearalgebra/FreivaldsAlgorithm.java b/src/main/java/com/williamfiset/algorithms/linearalgebra/FreivaldsAlgorithm.java index 4d62c6ba2..d81845942 100644 --- a/src/main/java/com/williamfiset/algorithms/linearalgebra/FreivaldsAlgorithm.java +++ b/src/main/java/com/williamfiset/algorithms/linearalgebra/FreivaldsAlgorithm.java @@ -1,56 +1,72 @@ +package com.williamfiset.algorithms.linearalgebra; + +import java.util.Arrays; + /** - * Freivalds' algorithm is a probabilistic randomized algorithm used to verify matrix - * multiplication. Given three n x n matrices, Freivalds' algorithm determines in O(kn^2) whether - * the matrices are equal for a chosen k value with a probability of failure less than 2^-k. + * Freivald's Algorithm for Probabilistic Matrix Multiplication Verification + * + * Given three n x n matrices A, B, and C, determines whether A * B = C + * without computing the full product. Instead, it picks a random binary + * vector r and checks if A * (B * r) = C * r, which takes O(n^2) per trial. + * + * Repeating k independent trials gives a failure probability less than 2^(-k). * - *

Time Complexity: O(kn^2) + * Use cases: + * - Verifying matrix multiplication faster than recomputing it + * - Randomized algorithms and Monte Carlo methods + * + * Time: O(k*n^2) + * Space: O(n) * * @author William Fiset, william.alexandre.fiset@gmail.com */ -package com.williamfiset.algorithms.linearalgebra; - public class FreivaldsAlgorithm { - // Randomly sets the values in the vector to either 0 or 1 - private static void randomizeVector(int[] vector) { - for (int i = 0; i < vector.length; i++) { - vector[i] = (Math.random() < 0.5) ? 0 : 1; - } - } - - // Compute the product of a vector with a matrix. - private static int[] product(int[] v, int[][] matrix) { - - int N = matrix.length; - int[] vector = new int[N]; - - for (int i = 0; i < N; i++) for (int j = 0; j < N; j++) vector[i] += v[j] * matrix[i][j]; - - return vector; - } - - // Freivalds' algorithm is a probabilistic randomized algorithm used to verify - // matrix multiplication. Given three n x n matrices, Freivalds' algorithm - // determines in O(kn^2) whether the matrices are equal for a chosen k value - // with a probability of failure less than 2^-k. + /** + * Verifies whether A * B = C using k independent random trials. + * + * @param A first n x n matrix + * @param B second n x n matrix + * @param C the alleged product matrix (A * B) + * @param k number of trials (failure probability < 2^(-k)) + * @return true if the test passes all k rounds (likely equal), false if definitely not equal + * + * Time: O(k*n^2) + */ public static boolean freivalds(int[][] A, int[][] B, int[][] C, int k) { - + if (k <= 0) throw new IllegalArgumentException("k must be positive"); final int n = A.length; if (A[0].length != n || B.length != n || B[0].length != n || C.length != n || C[0].length != n) throw new IllegalArgumentException("Input must be three nxn matrices"); int[] v = new int[n]; - do { - + for (int trial = 0; trial < k; trial++) { randomizeVector(v); + // Compare C*v against A*(B*v) — both are O(n^2) matrix-vector products int[] expected = product(v, C); int[] result = product(product(v, B), A); - if (!java.util.Arrays.equals(expected, result)) return false; - - } while (--k > 0); + if (!Arrays.equals(expected, result)) return false; + } return true; } + + /** Randomly sets each element of the vector to 0 or 1. */ + private static void randomizeVector(int[] vector) { + for (int i = 0; i < vector.length; i++) { + vector[i] = (Math.random() < 0.5) ? 0 : 1; + } + } + + /** Computes the product of a vector with a matrix, O(n^2). */ + private static int[] product(int[] v, int[][] matrix) { + int n = matrix.length; + int[] result = new int[n]; + for (int i = 0; i < n; i++) + for (int j = 0; j < n; j++) + result[i] += v[j] * matrix[i][j]; + return result; + } } diff --git a/src/main/java/com/williamfiset/algorithms/linearalgebra/GaussianElimination.java b/src/main/java/com/williamfiset/algorithms/linearalgebra/GaussianElimination.java index 2ea82f33a..c723ded65 100644 --- a/src/main/java/com/williamfiset/algorithms/linearalgebra/GaussianElimination.java +++ b/src/main/java/com/williamfiset/algorithms/linearalgebra/GaussianElimination.java @@ -1,26 +1,43 @@ +package com.williamfiset.algorithms.linearalgebra; + /** - * Solve a system of linear equations using Gaussian elimination. To work with this code the linear - * equations must be specified as a matrix augmented with the constants as the right-most column. + * Gaussian Elimination for Solving Linear Systems + * + * Solves a system of linear equations by reducing an augmented matrix + * [A | b] to reduced row echelon form (RREF). After reduction, the + * solution (if unique) appears in the rightmost column. + * + * The algorithm uses partial pivoting (row swaps) for numerical stability + * and eliminates both above and below each pivot to produce RREF directly. * - *

Time Complexity: O(c*r^2) + * Use cases: + * - Solving systems of linear equations + * - Checking for inconsistency or infinite solutions + * + * Time: O(r^2*c) where r = rows, c = columns + * Space: O(1) (in-place) + * + * @author William Fiset, william.alexandre.fiset@gmail.com */ -package com.williamfiset.algorithms.linearalgebra; - class GaussianElimination { - // Define a small value of epsilon to compare double values - static final double EPS = 0.00000001; + private static final double EPS = 0.00000001; - // Solves a system of linear equations as an augmented matrix - // with the rightmost column containing the constants. The answers - // will be stored on the rightmost column after the algorithm is done. - // NOTE: make sure your matrix is consistent and does not have multiple - // solutions when you solve the system if you want a unique valid answer. - // Time Complexity: O(r²c) + /** + * Reduces an augmented matrix to RREF in-place. After solving, the + * rightmost column contains the solution values (if the system is + * consistent with a unique solution). + * + * @param augmentedMatrix the [A | b] matrix to reduce + * + * Time: O(r^2*c) + */ static void solve(double[][] augmentedMatrix) { int nRows = augmentedMatrix.length, nCols = augmentedMatrix[0].length, lead = 0; for (int r = 0; r < nRows; r++) { if (lead >= nCols) break; + + // Find pivot: row with non-zero entry in the lead column int i = r; while (Math.abs(augmentedMatrix[i][lead]) < EPS) { if (++i == nRows) { @@ -28,22 +45,33 @@ static void solve(double[][] augmentedMatrix) { if (++lead == nCols) return; } } + + // Swap pivot row into position double[] temp = augmentedMatrix[r]; augmentedMatrix[r] = augmentedMatrix[i]; augmentedMatrix[i] = temp; + + // Scale pivot row so leading entry becomes 1 double lv = augmentedMatrix[r][lead]; - for (int j = 0; j < nCols; j++) augmentedMatrix[r][j] /= lv; + for (int j = 0; j < nCols; j++) + augmentedMatrix[r][j] /= lv; + + // Eliminate all other rows in this column for (i = 0; i < nRows; i++) { if (i != r) { lv = augmentedMatrix[i][lead]; - for (int j = 0; j < nCols; j++) augmentedMatrix[i][j] -= lv * augmentedMatrix[r][j]; + for (int j = 0; j < nCols; j++) + augmentedMatrix[i][j] -= lv * augmentedMatrix[r][j]; } } lead++; } } - // Checks if the matrix is inconsistent + /** + * Checks if the reduced matrix represents an inconsistent system + * (a row of all zeros on the left with a non-zero constant on the right). + */ static boolean isInconsistent(double[][] arr) { int nCols = arr[0].length; outer: @@ -56,7 +84,10 @@ static boolean isInconsistent(double[][] arr) { return false; } - // Make sure your matrix is consistent as well + /** + * Checks if the reduced matrix has more unknowns than non-empty rows, + * indicating infinitely many solutions. Call after verifying consistency. + */ static boolean hasMultipleSolutions(double[][] arr) { int nCols = arr[0].length, nEmptyRows = 0; outer: @@ -68,7 +99,6 @@ static boolean hasMultipleSolutions(double[][] arr) { } public static void main(String[] args) { - // Suppose we want to solve the following system for // the variables x, y, z: // @@ -76,7 +106,6 @@ public static void main(String[] args) { // x + 2y - z = 18 // 6x - y + 0 = 12 // Then we would setup the following augment matrix: - double[][] augmentedMatrix = { {2, -3, 5, 10}, {1, 2, -1, 18}, @@ -86,11 +115,9 @@ public static void main(String[] args) { solve(augmentedMatrix); if (!hasMultipleSolutions(augmentedMatrix) && !isInconsistent(augmentedMatrix)) { - double x = augmentedMatrix[0][3]; double y = augmentedMatrix[1][3]; double z = augmentedMatrix[2][3]; - // x ~ 3.755, y ~ 10.531, z ~ 6.816 System.out.printf("x = %.3f, y = %.3f, z = %.3f\n", x, y, z); } diff --git a/src/main/java/com/williamfiset/algorithms/linearalgebra/MatrixDeterminantLaplaceExpansion.java b/src/main/java/com/williamfiset/algorithms/linearalgebra/MatrixDeterminantLaplaceExpansion.java index b0c933f13..027d3251e 100644 --- a/src/main/java/com/williamfiset/algorithms/linearalgebra/MatrixDeterminantLaplaceExpansion.java +++ b/src/main/java/com/williamfiset/algorithms/linearalgebra/MatrixDeterminantLaplaceExpansion.java @@ -1,18 +1,84 @@ +package com.williamfiset.algorithms.linearalgebra; + /** - * This is an implementation of finding the determinant of an nxn matrix using Laplace/cofactor - * expansion. Although this method is mathematically beautiful, it is computationally intensive and - * not practical for matrices beyond the size of 7-8. + * Matrix Determinant via Laplace (Cofactor) Expansion + * + * Computes the determinant of an n x n matrix by recursively expanding + * along the first row. Each expansion reduces the matrix size by 1, + * creating n subproblems of size (n-1) x (n-1). * - *

Time Complexity: ~O((n+2)!) + * Mathematically elegant but computationally expensive — not practical + * for matrices larger than about 7-8. For larger matrices, use + * Gaussian elimination (O(n^3)) instead. + * + * Includes optimized closed-form formulas for 1x1, 2x2, and 3x3 bases. + * + * Time: ~O((n+2)!) + * Space: O(n^2*n!) due to recursive submatrix allocation * * @author William Fiset, william.alexandre.fiset@gmail.com */ -package com.williamfiset.algorithms.linearalgebra; - public class MatrixDeterminantLaplaceExpansion { - // Define a small value of epsilon to compare double values - static final double EPS = 0.00000001; + private static final double EPS = 0.00000001; + + /** + * Computes the determinant of an n x n matrix. + * + * @param matrix the square matrix + * @return the determinant value + * + * Time: ~O((n+2)!) + */ + public static double determinant(double[][] matrix) { + final int n = matrix.length; + if (n == 1) return matrix[0][0]; + if (n == 2) return matrix[0][0] * matrix[1][1] - matrix[0][1] * matrix[1][0]; + return laplace(matrix); + } + + /** + * Recursive cofactor expansion along the first row. + * Base case is the 3x3 Sarrus formula. + */ + private static double laplace(double[][] m) { + final int n = m.length; + if (n == 3) { + double a = m[0][0], b = m[0][1], c = m[0][2]; + double d = m[1][0], e = m[1][1], f = m[1][2]; + double g = m[2][0], h = m[2][1], i = m[2][2]; + return a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g); + } + int det = 0; + for (int i = 0; i < n; i++) { + double c = m[0][i]; + if (c > EPS) { + int sign = ((i & 1) == 0) ? +1 : -1; + det += sign * m[0][i] * laplace(constructMinor(m, 0, i)); + } + } + return det; + } + + /** + * Constructs the (n-1) x (n-1) minor matrix by excluding the + * specified row and column from the input matrix. + */ + private static double[][] constructMinor(double[][] mat, int excludingRow, int excludingCol) { + int n = mat.length; + double[][] minor = new double[n - 1][n - 1]; + int rPtr = -1; + for (int i = 0; i < n; i++) { + if (i == excludingRow) continue; + ++rPtr; + int cPtr = -1; + for (int j = 0; j < n; j++) { + if (j == excludingCol) continue; + minor[rPtr][++cPtr] = mat[i][j]; + } + } + return minor; + } public static void main(String[] args) { @@ -148,71 +214,4 @@ public static void main(String[] args) { }; // determinant(mat0) = 17265530 (1.726553E7) System.out.println(determinant(m)); } - - // Given an n*n matrix, this method finds the determinant using Laplace/cofactor expansion. - // Time Complexity: ~O((n+2)!) - public static double determinant(double[][] matrix) { - - final int n = matrix.length; - - // Use closed form for 1x1 determinant - if (n == 1) return matrix[0][0]; - - // Use closed form for 2x2 determinant - if (n == 2) return matrix[0][0] * matrix[1][1] - matrix[0][1] * matrix[1][0]; - - // For 3x3 matrices and up use Laplace/cofactor expansion - return laplace(matrix); - } - - // This method uses cofactor expansion to compute the determinant - // of a matrix. Unfortunately, this method is very slow and uses - // A LOT of memory, hence it is not too practical for large matrices. - private static double laplace(double[][] m) { - - final int n = m.length; - - // Base case is 3x3 determinant - if (n == 3) { - /* - * Used as a temporary variables to make calculation easy - * | a b c | - * | d e f | - * | g h i | - */ - double a = m[0][0], b = m[0][1], c = m[0][2]; - double d = m[1][0], e = m[1][1], f = m[1][2]; - double g = m[2][0], h = m[2][1], i = m[2][2]; - return a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g); - } - int det = 0; - for (int i = 0; i < n; i++) { - double c = m[0][i]; - if (c > EPS) { - int sign = ((i & 1) == 0) ? +1 : -1; - det += sign * m[0][i] * laplace(constructMatrix(m, 0, i)); - } - } - return det; - } - - // Constructs a matrix one dimension smaller than the last by - // excluding the top row and some selected column. This - // method ends up consuming a lot of space we called recursively multiple times - // since it allocates memory for a new matrix. - private static double[][] constructMatrix(double[][] mat, int excludingRow, int excludingCol) { - int n = mat.length; - double[][] newMatrix = new double[n - 1][n - 1]; - int rPtr = -1; - for (int i = 0; i < n; i++) { - if (i == excludingRow) continue; - ++rPtr; - int cPtr = -1; - for (int j = 0; j < n; j++) { - if (j == excludingCol) continue; - newMatrix[rPtr][++cPtr] = mat[i][j]; - } // end of inner loop - } // end of outer loop - return newMatrix; - } // end of createSubMatrix } diff --git a/src/main/java/com/williamfiset/algorithms/linearalgebra/MatrixInverse.java b/src/main/java/com/williamfiset/algorithms/linearalgebra/MatrixInverse.java index e773b82ac..f1f7a6695 100644 --- a/src/main/java/com/williamfiset/algorithms/linearalgebra/MatrixInverse.java +++ b/src/main/java/com/williamfiset/algorithms/linearalgebra/MatrixInverse.java @@ -1,37 +1,55 @@ +package com.williamfiset.algorithms.linearalgebra; + +import java.util.Arrays; + /** - * Use Gaussian elimination on an augmented matrix to find the inverse of a matrix. + * Matrix Inverse via Gaussian Elimination * - *

Time Complexity: O(n^3) + * Finds the inverse of an n x n matrix by augmenting it with the identity + * matrix [A | I] and reducing to RREF. If A is invertible, the result is + * [I | A^(-1)]. + * + * Time: O(n^3) + * Space: O(n^2) + * + * @author William Fiset, william.alexandre.fiset@gmail.com */ -package com.williamfiset.algorithms.linearalgebra; - class MatrixInverse { - // Define a small value of epsilon to compare double values - static final double EPS = 0.00000001; + private static final double EPS = 0.00000001; - // Invert the specified matrix. Assumes invertibility. Time Complexity: O(r²c) + /** + * Computes the inverse of a square matrix using Gaussian elimination. + * + * @param matrix the n x n matrix to invert + * @return the inverse matrix, or null if the matrix is not square + * + * Time: O(n^3) + */ static double[][] inverse(double[][] matrix) { if (matrix.length != matrix[0].length) return null; int n = matrix.length; + + // Build augmented matrix [A | I] double[][] augmented = new double[n][n * 2]; for (int i = 0; i < n; i++) { - for (int j = 0; j < n; j++) augmented[i][j] = matrix[i][j]; + for (int j = 0; j < n; j++) + augmented[i][j] = matrix[i][j]; augmented[i][i + n] = 1; } + solve(augmented); + + // Extract the inverse from the right half double[][] inv = new double[n][n]; - for (int i = 0; i < n; i++) for (int j = 0; j < n; j++) inv[i][j] = augmented[i][j + n]; + for (int i = 0; i < n; i++) + for (int j = 0; j < n; j++) + inv[i][j] = augmented[i][j + n]; return inv; } - // Solves a system of linear equations as an augmented matrix - // with the rightmost column containing the constants. The answers - // will be stored on the rightmost column after the algorithm is done. - // NOTE: make sure your matrix is consistent and does not have multiple - // solutions before you solve the system if you want a unique valid answer. - // Time Complexity: O(r²c) - static void solve(double[][] augmentedMatrix) { + /** Reduces an augmented matrix to RREF in-place. */ + private static void solve(double[][] augmentedMatrix) { int nRows = augmentedMatrix.length, nCols = augmentedMatrix[0].length, lead = 0; for (int r = 0; r < nRows; r++) { if (lead >= nCols) break; @@ -46,51 +64,58 @@ static void solve(double[][] augmentedMatrix) { augmentedMatrix[r] = augmentedMatrix[i]; augmentedMatrix[i] = temp; double lv = augmentedMatrix[r][lead]; - for (int j = 0; j < nCols; j++) augmentedMatrix[r][j] /= lv; + for (int j = 0; j < nCols; j++) + augmentedMatrix[r][j] /= lv; for (i = 0; i < nRows; i++) { if (i != r) { lv = augmentedMatrix[i][lead]; - for (int j = 0; j < nCols; j++) augmentedMatrix[i][j] -= lv * augmentedMatrix[r][j]; + for (int j = 0; j < nCols; j++) + augmentedMatrix[i][j] -= lv * augmentedMatrix[r][j]; } } lead++; } } - // Checks if the matrix is inconsistent + /** + * Checks if the reduced matrix represents an inconsistent system + * (a row of all zeros on the left with a non-zero constant on the right). + */ static boolean isInconsistent(double[][] arr) { int nCols = arr[0].length; outer: for (int y = 0; y < arr.length; y++) { if (Math.abs(arr[y][nCols - 1]) > EPS) { - for (int x = 0; x < nCols - 1; x++) if (Math.abs(arr[y][x]) > EPS) continue outer; + for (int x = 0; x < nCols - 1; x++) + if (Math.abs(arr[y][x]) > EPS) continue outer; return true; } } return false; } - // Make sure your matrix is consistent as well + /** + * Checks if the reduced matrix has more unknowns than non-empty rows, + * indicating infinitely many solutions. Call after verifying consistency. + */ static boolean hasMultipleSolutions(double[][] arr) { int nCols = arr[0].length, nEmptyRows = 0; outer: for (int y = 0; y < arr.length; y++) { - for (int x = 0; x < nCols; x++) if (Math.abs(arr[y][x]) > EPS) continue outer; + for (int x = 0; x < nCols; x++) + if (Math.abs(arr[y][x]) > EPS) continue outer; nEmptyRows++; } return nCols - 1 > arr.length - nEmptyRows; } public static void main(String[] args) { - - // Check this matrix is invertable double[][] matrix = { {2, -4, 0}, {0, 6, 0}, {2, 2, -2} }; - double[][] inv = inverse(matrix); - for (double[] row : inv) System.out.println(java.util.Arrays.toString(row)); + for (double[] row : inv) System.out.println(Arrays.toString(row)); } } diff --git a/src/main/java/com/williamfiset/algorithms/linearalgebra/MatrixMultiplication.java b/src/main/java/com/williamfiset/algorithms/linearalgebra/MatrixMultiplication.java index 4c417fe73..1980aab32 100644 --- a/src/main/java/com/williamfiset/algorithms/linearalgebra/MatrixMultiplication.java +++ b/src/main/java/com/williamfiset/algorithms/linearalgebra/MatrixMultiplication.java @@ -1,42 +1,56 @@ +package com.williamfiset.algorithms.linearalgebra; + +import java.util.Arrays; + /** - * Multiply two matrices together and get their product + * Standard Matrix Multiplication * - *

Time Complexity: O(n^3) + * Computes the product C = A * B using the naive triple-loop algorithm. + * Matrix A has dimensions (aRows x aCols) and B has (bRows x bCols); + * multiplication is only valid when aCols == bRows, producing a + * (aRows x bCols) result. + * + * Time: O(n^3) for n x n matrices + * Space: O(n^2) for the result matrix * * @author Micah Stairs */ -package com.williamfiset.algorithms.linearalgebra; - class MatrixMultiplication { - // Returns the result of the product of the matrices 'a' and 'b' - // or null if the matrices are the wrong dimensions + /** + * Returns the product of matrices a and b, or null if dimensions are incompatible. + * + * @param a the left matrix (aRows x aCols) + * @param b the right matrix (bRows x bCols), requires aCols == bRows + * @return the product matrix (aRows x bCols), or null if aCols != bRows + * + * Time: O(aRows * bCols * aCols) + */ static double[][] multiply(double[][] a, double[][] b) { int aRows = a.length, aCols = a[0].length; int bRows = b.length, bCols = b[0].length; if (aCols != bRows) return null; double[][] c = new double[aRows][bCols]; for (int i = 0; i < aRows; i++) - for (int j = 0; j < bCols; j++) for (int k = 0; k < aCols; k++) c[i][j] += a[i][k] * b[k][j]; + for (int j = 0; j < bCols; j++) + for (int k = 0; k < aCols; k++) + c[i][j] += a[i][k] * b[k][j]; return c; } public static void main(String[] args) { - double[][] a = { {1, 2, 3, 4}, {4, 3, 2, 1}, {1, 2, 2, 1} }; - double[][] b = { {1, 0}, {2, 1}, {0, 3}, {0, 0} }; - double[][] c = multiply(a, b); - for (double[] row : c) System.out.println(java.util.Arrays.toString(row)); + for (double[] row : c) System.out.println(Arrays.toString(row)); } } diff --git a/src/main/java/com/williamfiset/algorithms/linearalgebra/MatrixPower.java b/src/main/java/com/williamfiset/algorithms/linearalgebra/MatrixPower.java index ce21e254e..857ee940c 100644 --- a/src/main/java/com/williamfiset/algorithms/linearalgebra/MatrixPower.java +++ b/src/main/java/com/williamfiset/algorithms/linearalgebra/MatrixPower.java @@ -1,69 +1,81 @@ +package com.williamfiset.algorithms.linearalgebra; + +import java.util.Arrays; + /** - * Raise an nxn square matrix to a certain power p. + * Matrix Exponentiation (Binary Exponentiation) * - *

Time Complexity: O(n^3log(p)) + * Raises an n x n square matrix to the power p using repeated squaring + * (binary exponentiation). Instead of multiplying the matrix p times + * (O(n^3*p)), this decomposes p into binary and squares the matrix at + * each bit, achieving O(n^3*log(p)). + * + * Use cases: + * - Computing Fibonacci numbers in O(log(n)) + * - Solving linear recurrences efficiently + * - Graph path counting (A^k gives the number of k-length paths) + * + * Time: O(n^3*log(p)) + * Space: O(n^2) * * @author William Fiset, william.alexandre.fiset@gmail.com */ -package com.williamfiset.algorithms.linearalgebra; - public class MatrixPower { - static long[][] matrixDeepCopy(long[][] M) { - final int N = M.length; - long[][] newMatrix = new long[N][N]; - for (int i = 0; i < N; i++) newMatrix[i] = M[i].clone(); - return newMatrix; - } - - // Perform matrix multiplication, O(n^3) - static long[][] squareMatrixMult(long[][] m1, long[][] m2) { - - final int N = m1.length; - long[][] newMatrix = new long[N][N]; - - for (int i = 0; i < N; i++) - for (int j = 0; j < N; j++) - for (int k = 0; k < N; k++) - // Overflow can happen here, watch out! - newMatrix[i][j] = newMatrix[i][j] + m1[i][k] * m2[k][j]; - - return newMatrix; - } - - // Raise a matrix to the pth power. If p is negative - // return null and if p is zero return the identity. - // NOTE: Make sure the matrix is a square matrix and - // also watch out for overflow as the numbers climb quickly! + /** + * Raises a square matrix to the power p using binary exponentiation. + * + * @param matrix the n x n matrix to exponentiate + * @param p the exponent (returns identity for p=0, null for p<0) + * @return matrix^p, or null if p is negative + * + * Time: O(n^3*log(p)) + */ static long[][] matrixPower(long[][] matrix, long p) { - if (p < 0) return null; - final int N = matrix.length; - long[][] newMatrix = null; + final int n = matrix.length; - // Return identity matrix + // p = 0 → return identity matrix if (p == 0) { - newMatrix = new long[N][N]; - for (int i = 0; i < N; i++) newMatrix[i][i] = 1L; - } else { - - long[][] P = matrixDeepCopy(matrix); - - while (p > 0) { + long[][] identity = new long[n][n]; + for (int i = 0; i < n; i++) + identity[i][i] = 1L; + return identity; + } - if ((p & 1L) == 1L) { - if (newMatrix == null) newMatrix = matrixDeepCopy(P); - else newMatrix = squareMatrixMult(newMatrix, P); - } + long[][] result = null; + long[][] base = matrixDeepCopy(matrix); - // Repeatedly square P every loop, O(n^3) - P = squareMatrixMult(P, P); - p >>= 1L; + // Binary exponentiation: decompose p into bits + while (p > 0) { + if ((p & 1L) == 1L) { + result = (result == null) ? matrixDeepCopy(base) : squareMatrixMult(result, base); } + base = squareMatrixMult(base, base); + p >>= 1L; } - return newMatrix; + return result; + } + + /** Standard O(n^3) matrix multiplication for square matrices. */ + private static long[][] squareMatrixMult(long[][] m1, long[][] m2) { + final int n = m1.length; + long[][] result = new long[n][n]; + for (int i = 0; i < n; i++) + for (int j = 0; j < n; j++) + for (int k = 0; k < n; k++) + result[i][j] += m1[i][k] * m2[k][j]; + return result; + } + + private static long[][] matrixDeepCopy(long[][] m) { + final int n = m.length; + long[][] copy = new long[n][n]; + for (int i = 0; i < n; i++) + copy[i] = m[i].clone(); + return copy; } public static void main(String[] args) { @@ -113,11 +125,11 @@ public static void main(String[] args) { // [0, 0, 0, 1, 0, 0] // [0, 0, 0, 0, 1, 0] // [0, 0, 0, 0, 0, 1] - } static void print2DMatrix(long[][] M) { - for (long[] m : M) System.out.println(java.util.Arrays.toString(m)); + for (long[] m : M) + System.out.println(Arrays.toString(m)); System.out.println(); } } diff --git a/src/main/java/com/williamfiset/algorithms/linearalgebra/ModularLinearAlgebra.java b/src/main/java/com/williamfiset/algorithms/linearalgebra/ModularLinearAlgebra.java index e2bfa92a5..8d969013d 100644 --- a/src/main/java/com/williamfiset/algorithms/linearalgebra/ModularLinearAlgebra.java +++ b/src/main/java/com/williamfiset/algorithms/linearalgebra/ModularLinearAlgebra.java @@ -1,19 +1,39 @@ -/** Solve a system of linear equations in a finite field Time Complexity: O(r^2c) */ package com.williamfiset.algorithms.linearalgebra; +/** + * Linear Algebra in Finite Fields (Modular Arithmetic) + * + * Solves systems of linear equations and computes matrix inverses in + * Z_p (the integers modulo a prime p). All arithmetic is performed + * modulo p, using precomputed multiplicative inverses. + * + * Includes: + * - Reduced row echelon form (RREF) in Z_p + * - Matrix inverse in Z_p + * - Consistency and uniqueness checks + * - Extended GCD for computing modular inverses + * + * Time: O(r^2*c) for RREF, O(n^3) for matrix inverse + * Space: O(n^2) + * + * @author William Fiset, william.alexandre.fiset@gmail.com + */ class ModularLinearAlgebra { - // Takes an augmented matrix as input along with a prime - // number as the order of the finite field on which the - // calculations are being performed. The inv[] array is - // the multiplicative inverse of each element in this - // finite field. After running this method, the input - // matrix arr[] will be in reduced row echelon form - // Time Complexity: O(r^2c) + /** + * Reduces an augmented matrix to RREF in the finite field Z_p. + * + * @param arr the augmented matrix [A | b] + * @param prime the prime modulus p + * @param inv precomputed multiplicative inverses in Z_p (inv[i] = i^(-1) mod p) + * + * Time: O(r^2*c) + */ static void rref(int[][] arr, int prime, int[] inv) { int n = arr.length, m = arr[0].length; int r = 0; for (int i = 0; i < m - 1 && r < n; i++) { + // Find pivot row if (arr[r][i] == 0) { for (int k = r + 1; k < n; k++) { if (arr[k][i] != 0) { @@ -24,34 +44,49 @@ static void rref(int[][] arr, int prime, int[] inv) { } } } - if (arr[r][i] == 0) { - continue; - } + if (arr[r][i] == 0) continue; + + // Scale pivot row so leading entry becomes 1 int inverse = inv[arr[r][i]]; - for (int k = i; k < m; k++) arr[r][k] = (arr[r][k] * inverse) % prime; + for (int k = i; k < m; k++) + arr[r][k] = (arr[r][k] * inverse) % prime; + + // Eliminate all other rows in this column for (int j = 0; j < n; j++) { int c = arr[j][i]; if (j == r || c == 0) continue; arr[j][i] = 0; - for (int k = i + 1; k < m; k++) arr[j][k] = (arr[j][k] - c * arr[r][k] + c * prime) % prime; + for (int k = i + 1; k < m; k++) + arr[j][k] = (arr[j][k] - c * arr[r][k] + c * prime) % prime; } r++; } } - // Finds the inverse of a non-augmented matrix in the finite field - // with order equal to the given prime. + /** + * Computes the inverse of a matrix in Z_p. + * + * @param arr the n x n matrix to invert + * @param prime the prime modulus p + * @param modInv precomputed multiplicative inverses in Z_p + * @return the inverse matrix, or null if not invertible + * + * Time: O(n^3) + */ static int[][] inverse(int[][] arr, int prime, int[] modInv) { if (arr.length != arr[0].length) return null; int n = arr.length; + + // Build augmented matrix [A | I] int[][] augmented = new int[n][n * 2]; for (int i = 0; i < n; i++) { - for (int j = 0; j < n; j++) { + for (int j = 0; j < n; j++) augmented[i][j] = arr[i][j]; - } augmented[i][i + n] = 1; } rref(augmented, prime, modInv); + + // Verify left half is identity; extract right half int[][] inv = new int[n][n]; for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { @@ -63,39 +98,33 @@ static int[][] inverse(int[][] arr, int prime, int[] modInv) { return inv; } - // To be checked after the augmented matrix has been - // row reduced to reduced row echelon form + /** Checks if the reduced matrix is inconsistent (no solution). */ static boolean isInconsistent(int[][] arr) { int nCols = arr[0].length; outer: for (int y = 0; y < arr.length; y++) { if (arr[y][nCols - 1] != 0) { - for (int x = 0; x < nCols - 1; x++) { + for (int x = 0; x < nCols - 1; x++) if (arr[y][x] != 0) continue outer; - } return true; } } return false; } - // To be checked after the augmented matrix has been - // row reduced to reduced row echelon form and checked - // for consistency + /** Checks if the system has multiple solutions (underdetermined). */ static boolean hasMultipleSolutions(int[][] arr) { - int nCols = arr[0].length; - int nEmptyRows = 0; + int nCols = arr[0].length, nEmptyRows = 0; outer: for (int y = 0; y < arr.length; y++) { - for (int x = 0; x < nCols; x++) { + for (int x = 0; x < nCols; x++) if (arr[y][x] != 0) continue outer; - } nEmptyRows++; } return nCols - 1 > arr.length - nEmptyRows; } - // Returns {gcd(a,b), x, y} such that ax+by=gcd(a,b) + /** Returns {gcd(a,b), x, y} such that a*x + b*y = gcd(a,b). */ static int[] egcd(int a, int b) { if (b == 0) return new int[] {a, 1, 0}; int[] ret = egcd(b, a % b); @@ -105,7 +134,7 @@ static int[] egcd(int a, int b) { return ret; } - // Returns the inverse of x mod m + /** Returns the multiplicative inverse of x mod m. */ static int modInv(int x, int m) { return (egcd(x, m)[1] + m) % m; } diff --git a/src/main/java/com/williamfiset/algorithms/linearalgebra/RotateSquareMatrixInplace.java b/src/main/java/com/williamfiset/algorithms/linearalgebra/RotateSquareMatrixInplace.java index d7bbc46a1..2cf8897d9 100644 --- a/src/main/java/com/williamfiset/algorithms/linearalgebra/RotateSquareMatrixInplace.java +++ b/src/main/java/com/williamfiset/algorithms/linearalgebra/RotateSquareMatrixInplace.java @@ -1,25 +1,44 @@ +package com.williamfiset.algorithms.linearalgebra; + +import java.util.Arrays; + /** - * Rotate the entries of a square matrix 90 degrees clockwise. + * In-Place Square Matrix Rotation (90 Degrees Clockwise) + * + * Rotates an n x n matrix 90 degrees clockwise by cycling four elements + * at a time in concentric rings from the outside in. No extra matrix is + * allocated — the rotation is done in-place using a single temp variable. * - *

Time Complexity: O(n^2) + * The key insight is that element (i, j) moves to (j, n-1-i). By cycling + * four elements at once — top→right→bottom→left→top — each element is + * visited exactly once. + * + * Time: O(n^2) + * Space: O(1) * * @author William Fiset, william.alexandre.fiset@gmail.com */ -package com.williamfiset.algorithms.linearalgebra; - public class RotateSquareMatrixInplace { - // Rotates the entries of a square matrix 90 degrees clockwise. + /** + * Rotates the entries of a square matrix 90 degrees clockwise in-place. + * + * @param matrix the n x n matrix to rotate + * + * Time: O(n^2) + */ static void rotate(int[][] matrix) { int n = matrix.length; + // Process concentric rings from outside in for (int i = 0; i < n / 2; i++) { int invI = n - i - 1; + // Cycle four elements at a time along this ring for (int j = i; j < invI; j++) { int invJ = n - j - 1, tmp = matrix[i][j]; - matrix[i][j] = matrix[invJ][i]; - matrix[invJ][i] = matrix[invI][invJ]; - matrix[invI][invJ] = matrix[j][invI]; - matrix[j][invI] = tmp; + matrix[i][j] = matrix[invJ][i]; // left → top + matrix[invJ][i] = matrix[invI][invJ]; // bottom → left + matrix[invI][invJ] = matrix[j][invI]; // right → bottom + matrix[j][invI] = tmp; // top → right } } } @@ -35,7 +54,8 @@ public static void main(String[] args) { }; rotate(matrix); - for (int[] row : matrix) System.out.println(java.util.Arrays.toString(row)); + for (int[] row : matrix) + System.out.println(Arrays.toString(row)); // prints: // [21, 16, 11, 6, 1] // [22, 17, 12, 7, 2] @@ -44,7 +64,8 @@ public static void main(String[] args) { // [25, 20, 15, 10, 5] rotate(matrix); - for (int[] row : matrix) System.out.println(java.util.Arrays.toString(row)); + for (int[] row : matrix) + System.out.println(Arrays.toString(row)); // prints: // [25, 24, 23, 22, 21] // [20, 19, 18, 17, 16] @@ -53,7 +74,8 @@ public static void main(String[] args) { // [5, 4, 3, 2, 1] rotate(matrix); - for (int[] row : matrix) System.out.println(java.util.Arrays.toString(row)); + for (int[] row : matrix) + System.out.println(Arrays.toString(row)); // prints: // [5, 10, 15, 20, 25] // [4, 9, 14, 19, 24] @@ -62,13 +84,13 @@ public static void main(String[] args) { // [1, 6, 11, 16, 21] rotate(matrix); - for (int[] row : matrix) System.out.println(java.util.Arrays.toString(row)); + for (int[] row : matrix) + System.out.println(Arrays.toString(row)); // prints: // [1, 2, 3, 4, 5] // [6, 7, 8, 9, 10] // [11, 12, 13, 14, 15] // [16, 17, 18, 19, 20] // [21, 22, 23, 24, 25] - } } diff --git a/src/main/java/com/williamfiset/algorithms/linearalgebra/Simplex.java b/src/main/java/com/williamfiset/algorithms/linearalgebra/Simplex.java index 48722d2d8..4c4c74cf4 100644 --- a/src/main/java/com/williamfiset/algorithms/linearalgebra/Simplex.java +++ b/src/main/java/com/williamfiset/algorithms/linearalgebra/Simplex.java @@ -1,34 +1,44 @@ +package com.williamfiset.algorithms.linearalgebra; + /** - * This simplex algorithm maximizes an expression subject to a set of constraints + * Simplex Algorithm for Linear Programming + * + * Maximizes a linear objective function subject to linear inequality + * constraints. Uses the standard tableau simplex method with Bland's-like + * pivot selection (most negative coefficient in the objective row). + * + * Input format (tableau matrix m): + * - m[0] = objective row: m[0][0] is the current objective value, + * m[0][j] (j >= 1) are the negated coefficients of the objective function + * - m[i] (i >= 1) = constraint rows: m[i][0] is the RHS constant, + * m[i][j] (j >= 1) are the constraint coefficients + * + * Before calling, normalize the problem: + * 1) RHS must be non-negative (multiply by -1 if needed) + * 2) Add slack variables for <= inequalities + * 3) Add surplus + artificial variables for >= inequalities + * 4) For artificial variables, first maximize -(sum of artificials); + * if optimum is 0, remove artificial columns and re-run * - *

Time complexity: O(n^3) + * Time: O(n^3) per pivot, exponential worst case (rare in practice) + * Space: O(1) (in-place) * * @author Thomas Finn Lidbetter */ -package com.williamfiset.algorithms.linearalgebra; - public class Simplex { - static final double EPS = 1e-9; + private static final double EPS = 1e-9; - // The matrix given as an argument represents the function to be maximized - // and each of the constraints. Constraints and objective function must be - // normalized first through the following steps: - // 1) RHS must be non-negative so multiply any inequalities failing this by -1 - // 2) Add positive coefficient slack variable on LHS of any <= inequality - // 3) Add negative coefficient surplus variable on LHS of any >= inequality - // 4) Add positive coefficient artificial variable on LHS of any >= inequality and any = equality. - // - // If any artificial variables were added, perform simplex once, maximizing the - // negated sum of the artificial variables. If the maximum value is 0, take the - // resulting matrix and remove the artificial variable columns and replace function - // to maximise with original and run simplex again. If maximum value of simplex with - // artificial variables is non-zero there is no solution. First column of m is the constants - // on the RHS of all constraints. First row is the expression to maximise with all - // coefficients negated. M[i][j] is the coefficient of the (j-1)th term in the - // (i-1)th constraint (0 based). + /** + * Runs the simplex algorithm on the given tableau and returns the + * maximum value of the objective function. + * + * @param m the simplex tableau (modified in-place) + * @return the maximum objective value (m[0][0] after termination) + */ public static double simplex(double[][] m) { while (true) { + // Find the most negative coefficient in the objective row (pivot column) double min = -EPS; int c = -1; for (int j = 1; j < m[0].length; j++) { @@ -37,7 +47,9 @@ public static double simplex(double[][] m) { c = j; } } - if (c < 0) break; + if (c < 0) break; // All coefficients non-negative → optimal + + // Find the pivot row using the minimum ratio test min = Double.MAX_VALUE; int r = -1; for (int i = 1; i < m.length; i++) { @@ -49,12 +61,16 @@ public static double simplex(double[][] m) { } } } + + // Pivot: scale pivot row, then eliminate pivot column from all other rows double v = m[r][c]; - for (int j = 0; j < m[r].length; j++) m[r][j] /= v; + for (int j = 0; j < m[r].length; j++) + m[r][j] /= v; for (int i = 0; i < m.length; i++) { if (i != r) { v = m[i][c]; - for (int j = 0; j < m[i].length; j++) m[i][j] -= m[r][j] * v; + for (int j = 0; j < m[i].length; j++) + m[i][j] -= m[r][j] * v; } } } diff --git a/src/test/java/com/williamfiset/algorithms/datastructures/segmenttree/SegmentTreeWithPointersTest.java b/src/test/java/com/williamfiset/algorithms/datastructures/segmenttree/SegmentTreeWithPointersTest.java index a2c524ce2..b1ae610be 100644 --- a/src/test/java/com/williamfiset/algorithms/datastructures/segmenttree/SegmentTreeWithPointersTest.java +++ b/src/test/java/com/williamfiset/algorithms/datastructures/segmenttree/SegmentTreeWithPointersTest.java @@ -5,44 +5,100 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import com.williamfiset.algorithms.utils.TestUtils; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; public class SegmentTreeWithPointersTest { - @BeforeEach - public void setup() {} + @Test + public void testNullInputThrows() { + assertThrows(IllegalArgumentException.class, () -> new Node(null)); + } + + @Test + public void testNegativeSizeThrows() { + assertThrows(IllegalArgumentException.class, () -> new Node(-10)); + } + + @Test + public void testSingleElement() { + int[] values = {7}; + Node tree = new Node(values); + assertThat(tree.sum(0, 1)).isEqualTo(7); + assertThat(tree.min(0, 1)).isEqualTo(7); + } @Test - public void testIllegalSegmentTreeCreation1() { - assertThrows( - IllegalArgumentException.class, - () -> { - Node tree = new Node(null); - }); + public void testSumQuerySingleElements() { + int[] values = {1, 2, 3, 4, 5}; + Node tree = new Node(values); + for (int i = 0; i < values.length; i++) { + assertThat(tree.sum(i, i + 1)).isEqualTo(values[i]); + } } @Test - public void testIllegalSegmentTreeCreation2() { - assertThrows( - IllegalArgumentException.class, - () -> { - int size = -10; - Node tree = new Node(size); - }); + public void testSumQueryFullRange() { + int[] values = {1, 2, 3, 4, 5}; + Node tree = new Node(values); + assertThat(tree.sum(0, 5)).isEqualTo(15); + } + + @Test + public void testMinQuerySingleElements() { + int[] values = {5, 1, 3, 2, 4}; + Node tree = new Node(values); + for (int i = 0; i < values.length; i++) { + assertThat(tree.min(i, i + 1)).isEqualTo(values[i]); + } + } + + @Test + public void testMinQueryFullRange() { + int[] values = {5, 1, 3, 2, 4}; + Node tree = new Node(values); + assertThat(tree.min(0, 5)).isEqualTo(1); + assertThat(tree.min(0, 2)).isEqualTo(1); + assertThat(tree.min(2, 5)).isEqualTo(2); } @Test - public void testSumQuery() { + public void testRangeUpdate() { int[] values = {1, 2, 3, 4, 5}; Node tree = new Node(values); - assertThat(tree.sum(0, 1)).isEqualTo(1); - assertThat(tree.sum(1, 2)).isEqualTo(2); - assertThat(tree.sum(2, 3)).isEqualTo(3); - assertThat(tree.sum(3, 4)).isEqualTo(4); - assertThat(tree.sum(4, 5)).isEqualTo(5); + // Add 10 to elements in [1, 4): values become {1, 12, 13, 14, 5} + tree.update(1, 4, 10); + assertThat(tree.sum(0, 5)).isEqualTo(45); + assertThat(tree.sum(1, 4)).isEqualTo(39); + assertThat(tree.min(0, 5)).isEqualTo(1); + assertThat(tree.min(1, 4)).isEqualTo(12); } + @Test + public void testNegativeValues() { + int[] values = {-3, -1, -4, -1, -5}; + Node tree = new Node(values); + assertThat(tree.sum(0, 5)).isEqualTo(-14); + assertThat(tree.min(0, 5)).isEqualTo(-5); + assertThat(tree.min(0, 3)).isEqualTo(-4); + } + + @Test + public void testMultipleUpdates() { + Node tree = new Node(5); + // Start with all zeros, add 1 to entire range + tree.update(0, 5, 1); + assertThat(tree.sum(0, 5)).isEqualTo(5); + assertThat(tree.min(0, 5)).isEqualTo(1); + + // Add 2 to [2, 4): values become {1, 1, 3, 3, 1} + tree.update(2, 4, 2); + assertThat(tree.sum(0, 5)).isEqualTo(9); + assertThat(tree.min(0, 5)).isEqualTo(1); + assertThat(tree.min(2, 4)).isEqualTo(3); + } + + // Brute-force cross-validation for sum queries on random data @Test public void testAllSumQueries() { int n = 100; @@ -58,12 +114,31 @@ public void testAllSumQueries() { } } - // Finds the sum in an array between [l, r) in the `values` array + // Brute-force cross-validation for min queries on random data + @Test + public void testAllMinQueries() { + int n = 100; + int[] ar = TestUtils.randomIntegerArray(n, -1000, +1000); + Node tree = new Node(ar); + + for (int i = 0; i < n; i++) { + for (int j = i + 1; j < n; j++) { + int bfMin = bruteForceMin(ar, i, j); + int segTreeMin = tree.min(i, j); + assertThat(bfMin).isEqualTo(segTreeMin); + } + } + } + private static long bruteForceSum(int[] values, int l, int r) { long s = 0; - for (int i = l; i < r; i++) { - s += values[i]; - } + for (int i = l; i < r; i++) s += values[i]; return s; } + + private static int bruteForceMin(int[] values, int l, int r) { + int m = Integer.MAX_VALUE; + for (int i = l; i < r; i++) m = Math.min(m, values[i]); + return m; + } } diff --git a/src/test/java/com/williamfiset/algorithms/dp/BUILD b/src/test/java/com/williamfiset/algorithms/dp/BUILD index 5c6c2b8cd..91a1c716e 100644 --- a/src/test/java/com/williamfiset/algorithms/dp/BUILD +++ b/src/test/java/com/williamfiset/algorithms/dp/BUILD @@ -39,5 +39,16 @@ java_test( deps = TEST_DEPS, ) +# bazel test //src/test/java/com/williamfiset/algorithms/dp:MinimumWeightPerfectMatchingTest +java_test( + name = "MinimumWeightPerfectMatchingTest", + srcs = ["MinimumWeightPerfectMatchingTest.java"], + main_class = "org.junit.platform.console.ConsoleLauncher", + use_testrunner = False, + args = ["--select-class=com.williamfiset.algorithms.dp.MinimumWeightPerfectMatchingTest"], + runtime_deps = JUNIT5_RUNTIME_DEPS, + deps = TEST_DEPS, +) + # Run all tests # bazel test //src/test/java/com/williamfiset/algorithms/dp:all diff --git a/src/test/java/com/williamfiset/algorithms/dp/MinimumWeightPerfectMatchingTest.java b/src/test/java/com/williamfiset/algorithms/dp/MinimumWeightPerfectMatchingTest.java new file mode 100644 index 000000000..a72ea7480 --- /dev/null +++ b/src/test/java/com/williamfiset/algorithms/dp/MinimumWeightPerfectMatchingTest.java @@ -0,0 +1,175 @@ +package com.williamfiset.algorithms.dp; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import org.junit.jupiter.api.Test; + +public class MinimumWeightPerfectMatchingTest { + + @Test + public void testNullInput() { + assertThrows(IllegalArgumentException.class, () -> new MinimumWeightPerfectMatching(null)); + } + + @Test + public void testEmptyMatrix() { + assertThrows( + IllegalArgumentException.class, () -> new MinimumWeightPerfectMatching(new double[0][0])); + } + + @Test + public void testOddSizeMatrix() { + double[][] cost = new double[3][3]; + assertThrows(IllegalArgumentException.class, () -> new MinimumWeightPerfectMatching(cost)); + } + + /** Two nodes — only one possible matching: (0, 1). */ + @Test + public void testTwoNodes() { + double[][] cost = { + {0, 5}, + {5, 0} + }; + MinimumWeightPerfectMatching mwpm = new MinimumWeightPerfectMatching(cost); + assertThat(mwpm.getMinWeightCost()).isEqualTo(5.0); + + int[] matching = mwpm.getMinWeightCostMatching(); + assertThat(matching).isEqualTo(new int[] {0, 1}); + } + + /** Four nodes where the optimal pairing is (0,2) and (1,3) with cost 1+1=2. */ + @Test + public void testFourNodes_recursiveSolver() { + double[][] cost = { + {0, 2, 1, 2}, + {2, 0, 2, 1}, + {1, 2, 0, 2}, + {2, 1, 2, 0}, + }; + MinimumWeightPerfectMatching mwpm = new MinimumWeightPerfectMatching(cost); + assertThat(mwpm.getMinWeightCost()).isEqualTo(2.0); + + int[] matching = mwpm.getMinWeightCostMatching(); + assertThat(matching).isEqualTo(new int[] {0, 2, 1, 3}); + } + + /** Same four-node case but using the iterative solver. */ + @Test + public void testFourNodes_iterativeSolver() { + double[][] cost = { + {0, 2, 1, 2}, + {2, 0, 2, 1}, + {1, 2, 0, 2}, + {2, 1, 2, 0}, + }; + MinimumWeightPerfectMatching mwpm = new MinimumWeightPerfectMatching(cost); + mwpm.solveIterative(); + assertThat(mwpm.getMinWeightCost()).isEqualTo(2.0); + } + + /** Six nodes with three clear cheapest pairs: (0,5), (1,2), (3,4). */ + @Test + public void testSixNodes() { + double[][] cost = { + {0, 9, 9, 9, 9, 1}, + {9, 0, 1, 9, 9, 9}, + {9, 1, 0, 9, 9, 9}, + {9, 9, 9, 0, 1, 9}, + {9, 9, 9, 1, 0, 9}, + {1, 9, 9, 9, 9, 0}, + }; + MinimumWeightPerfectMatching mwpm = new MinimumWeightPerfectMatching(cost); + assertThat(mwpm.getMinWeightCost()).isEqualTo(3.0); + + int[] matching = mwpm.getMinWeightCostMatching(); + assertThat(matching).isEqualTo(new int[] {0, 5, 1, 2, 3, 4}); + } + + /** All pairs have equal cost — any matching gives the same total. */ + @Test + public void testUniformCost() { + double[][] cost = { + {0, 3, 3, 3}, + {3, 0, 3, 3}, + {3, 3, 0, 3}, + {3, 3, 3, 0}, + }; + MinimumWeightPerfectMatching mwpm = new MinimumWeightPerfectMatching(cost); + assertThat(mwpm.getMinWeightCost()).isEqualTo(6.0); + } + + /** Matching output contains each node exactly once. */ + @Test + public void testMatchingContainsAllNodes() { + double[][] cost = { + {0, 1, 5, 5}, + {1, 0, 5, 5}, + {5, 5, 0, 2}, + {5, 5, 2, 0}, + }; + MinimumWeightPerfectMatching mwpm = new MinimumWeightPerfectMatching(cost); + int[] matching = mwpm.getMinWeightCostMatching(); + + boolean[] seen = new boolean[4]; + for (int node : matching) + seen[node] = true; + for (int i = 0; i < 4; i++) + assertThat(seen[i]).isTrue(); + } + + /** Pairs in matching are sorted: left nodes ascending, and a < b in each pair (a,b). */ + @Test + public void testMatchingIsSorted() { + double[][] cost = { + {0, 9, 9, 1, 9, 9}, + {9, 0, 9, 9, 9, 1}, + {9, 9, 0, 9, 1, 9}, + {1, 9, 9, 0, 9, 9}, + {9, 9, 1, 9, 0, 9}, + {9, 1, 9, 9, 9, 0}, + }; + MinimumWeightPerfectMatching mwpm = new MinimumWeightPerfectMatching(cost); + int[] matching = mwpm.getMinWeightCostMatching(); + + // Each pair (a,b) has a < b + for (int i = 0; i < matching.length; i += 2) + assertThat(matching[i]).isLessThan(matching[i + 1]); + + // Left nodes are in ascending order + for (int i = 2; i < matching.length; i += 2) + assertThat(matching[i]).isGreaterThan(matching[i - 2]); + } + + /** Verify cost equals sum of matched pair costs. */ + @Test + public void testCostMatchesMatchingSum() { + double[][] cost = { + {0, 3, 7, 2}, + {3, 0, 1, 8}, + {7, 1, 0, 4}, + {2, 8, 4, 0}, + }; + MinimumWeightPerfectMatching mwpm = new MinimumWeightPerfectMatching(cost); + double minCost = mwpm.getMinWeightCost(); + int[] matching = mwpm.getMinWeightCostMatching(); + + double sum = 0; + for (int i = 0; i < matching.length; i += 2) + sum += cost[matching[i]][matching[i + 1]]; + assertThat(minCost).isEqualTo(sum); + } + + /** Solving twice returns the same result (tests caching via solved flag). */ + @Test + public void testIdempotent() { + double[][] cost = { + {0, 4}, + {4, 0} + }; + MinimumWeightPerfectMatching mwpm = new MinimumWeightPerfectMatching(cost); + double first = mwpm.getMinWeightCost(); + double second = mwpm.getMinWeightCost(); + assertThat(first).isEqualTo(second); + } +}