1+ package com .williamfiset .algorithms .dp ;
2+
3+ import java .awt .geom .Point2D ;
4+ import java .util .ArrayList ;
5+ import java .util .Arrays ;
6+ import java .util .Collections ;
7+ import java .util .List ;
8+
19/**
2- * Implementation of the Minimum Weight Perfect Matching (MWPM) problem. In this problem you are
3- * given a distance matrix which gives the distance from each node to every other node, and you want
4- * to pair up all the nodes to one another minimizing the overall cost.
10+ * Minimum Weight Perfect Matching (MWPM)
511 *
6- * <p>Tested against: UVA 10911 - Forming Quiz Teams
12+ * Given n nodes and a symmetric distance matrix, pairs up all nodes to minimize
13+ * total matching cost. Uses bitmask DP where each state represents a subset of
14+ * matched nodes. Two solvers are provided:
715 *
8- * <p>To Run: bazel run //src/main/java/com/williamfiset/algorithms/dp:MinimumWeightPerfectMatching
16+ * - Top-down (recursive with memoization): naturally skips unreachable states
17+ * - Bottom-up (iterative): builds solutions from pairs upward
918 *
10- * <p>Time Complexity: O(n * 2^n)
19+ * Requires n to be even (otherwise no perfect matching exists) and n <= 32
20+ * (bitmask representation limit).
21+ *
22+ * Tested against: UVA 10911 - Forming Quiz Teams
23+ *
24+ * Time: O(n^2*2^n)
25+ * Space: O(2^n)
1126 *
1227 * @author William Fiset
1328 */
14- package com .williamfiset .algorithms .dp ;
15-
16- import java .awt .geom .*;
17- import java .util .*;
18-
1929public class MinimumWeightPerfectMatching {
2030
2131 // Inputs
2232 private final int n ;
23- private double [][] cost ;
33+ private final double [][] cost ;
2434
2535 // Internal
2636 private final int END_STATE ;
@@ -30,7 +40,13 @@ public class MinimumWeightPerfectMatching {
3040 private double minWeightCost ;
3141 private int [] matching ;
3242
33- // The cost matrix should be a symmetric (i.e cost[i][j] = cost[j][i])
43+ /**
44+ * Creates a MWPM solver for the given cost matrix.
45+ *
46+ * @param cost symmetric n x n distance matrix (cost[i][j] = cost[j][i])
47+ *
48+ * @throws IllegalArgumentException if matrix is null, empty, odd-sized, or too large
49+ */
3450 public MinimumWeightPerfectMatching (double [][] cost ) {
3551 if (cost == null ) throw new IllegalArgumentException ("Input cannot be null" );
3652 n = cost .length ;
@@ -45,8 +61,12 @@ public MinimumWeightPerfectMatching(double[][] cost) {
4561 this .cost = cost ;
4662 }
4763
64+ /**
65+ * Returns the minimum total cost of a perfect matching.
66+ * Lazily solves using the recursive solver if neither solver has run yet.
67+ */
4868 public double getMinWeightCost () {
49- solveRecursive ();
69+ if (! solved ) solveRecursive ();
5070 return minWeightCost ;
5171 }
5272
@@ -68,11 +88,21 @@ public double getMinWeightCost() {
6888 * }</pre>
6989 */
7090 public int [] getMinWeightCostMatching () {
71- solveRecursive ();
91+ if (! solved ) solveRecursive ();
7292 return matching ;
7393 }
7494
75- // Recursive impl
95+ // ==================== Solver 1: Top-down (recursive with memoization) ====================
96+
97+ /**
98+ * Solves using top-down recursion with memoization. Starting from the full set
99+ * of nodes, it finds the lowest-numbered unmatched node and tries pairing it
100+ * with every other unmatched node, recursing on the reduced state.
101+ *
102+ * This approach naturally skips unreachable states (states that can't be formed
103+ * by removing pairs from the full set), so it often visits fewer states than
104+ * the iterative solver.
105+ */
76106 public void solveRecursive () {
77107 if (solved ) return ;
78108 Double [] dp = new Double [1 << n ];
@@ -83,24 +113,17 @@ public void solveRecursive() {
83113 }
84114
85115 private double f (int state , Double [] dp , int [] history ) {
86- if (dp [state ] != null ) {
87- return dp [state ];
88- }
89- if (state == 0 ) {
90- return 0 ;
91- }
92- int p1 , p2 ;
93- // Seek to find active bit position (p1)
94- for (p1 = 0 ; p1 < n ; p1 ++) {
95- if ((state & (1 << p1 )) > 0 ) {
96- break ;
97- }
98- }
116+ if (dp [state ] != null ) return dp [state ];
117+ if (state == 0 ) return 0 ;
118+
119+ // Find the lowest set bit position (p1) — always pair this node first
120+ int p1 = Integer .numberOfTrailingZeros (state );
121+
99122 int bestState = -1 ;
100123 double minimum = Double .MAX_VALUE ;
101124
102- for ( p2 = p1 + 1 ; p2 < n ; p2 ++) {
103- // Position `p2` is on. Try matching the pair (p1, p2) together.
125+ // Try pairing p1 with every other set bit
126+ for ( int p2 = p1 + 1 ; p2 < n ; p2 ++) {
104127 if ((state & (1 << p2 )) > 0 ) {
105128 int reducedState = state ^ (1 << p1 ) ^ (1 << p2 );
106129 double matchCost = f (reducedState , dp , history ) + cost [p1 ][p2 ];
@@ -114,7 +137,17 @@ private double f(int state, Double[] dp, int[] history) {
114137 return dp [state ] = minimum ;
115138 }
116139
117- public void solve () {
140+ // ==================== Solver 2: Bottom-up (iterative) ====================
141+
142+ /**
143+ * Solves using bottom-up iterative DP. Pre-computes all n*(n-1)/2 pair states,
144+ * then iterates over all bitmask states in ascending order, extending each
145+ * valid matching by adding a non-overlapping pair.
146+ *
147+ * This approach visits all 2^n states systematically. It avoids recursion
148+ * overhead and stack depth limits, making it better suited for larger n.
149+ */
150+ public void solveIterative () {
118151 if (solved ) return ;
119152
120153 // The DP state is encoded as a bitmask where the i'th bit is flipped on if the i'th node is
@@ -152,9 +185,8 @@ public void solve() {
152185 for (int state = 0b11; state < (1 << n ); state ++) { // O(2^n)
153186 // Skip states with an odd number of bits (nodes). It's easier (and faster) to
154187 // check dp[state] instead of calling `Integer.bitCount` for the bit count.
155- if (dp [state ] == null ) {
156- continue ;
157- }
188+ if (dp [state ] == null ) continue ;
189+
158190 for (int i = 0 ; i < numPairs ; i ++) { // O(n^2)
159191 int pair = pairStates [i ];
160192 // Ignore states which overlap
@@ -178,65 +210,46 @@ public void solve() {
178210 solved = true ;
179211 }
180212
181- // Populates the `matching` array with a sorted deterministic matching sorted by lowest node
182- // index. For example, if the perfect matching consists of the pairs (3, 4), (1, 5), (0, 2).
183- // The matching is sorted such that the pairs appear in the ordering: (0, 2), (1, 5), (3, 4).
184- // Furthermore, it is guaranteed that for any pair (a, b) that a < b.
213+ /**
214+ * Populates the {@code matching} array with a sorted deterministic matching.
215+ * For example, if the perfect matching consists of the pairs (3, 4), (1, 5), (0, 2),
216+ * the output is sorted as: (0, 2), (1, 5), (3, 4).
217+ * For any pair (a, b), it is guaranteed that a < b.
218+ */
185219 private void reconstructMatching (int [] history ) {
186- // A map between pairs of nodes that were matched together.
187220 int [] map = new int [n ];
188221 int [] leftNodes = new int [n / 2 ];
189222
190- // Reconstruct the matching of pairs of nodes working backwards through computed states.
223+ // Walk backwards through computed states to recover matched pairs
191224 for (int i = 0 , state = END_STATE ; state != 0 ; state = history [state ]) {
192- // Isolate the pair used by xoring the state with the state used to generate it.
193225 int pairUsed = state ^ history [state ];
194226
195- int leftNode = getBitPosition (Integer .lowestOneBit (pairUsed ));
196- int rightNode = getBitPosition (Integer .highestOneBit (pairUsed ));
227+ int leftNode = Integer . numberOfTrailingZeros (Integer .lowestOneBit (pairUsed ));
228+ int rightNode = Integer . numberOfTrailingZeros (Integer .highestOneBit (pairUsed ));
197229
198230 leftNodes [i ++] = leftNode ;
199231 map [leftNode ] = rightNode ;
200232 }
201233
202- // Sort the left nodes in ascending order.
203- java .util .Arrays .sort (leftNodes );
234+ Arrays .sort (leftNodes );
204235
205236 matching = new int [n ];
206237 for (int i = 0 ; i < n / 2 ; i ++) {
207238 matching [2 * i ] = leftNodes [i ];
208- int rightNode = map [leftNodes [i ]];
209- matching [2 * i + 1 ] = rightNode ;
210- }
211- }
212-
213- // Gets the zero base index position of the 1 bit in `k`. `k` must be a power of 2, so there is
214- // only ever 1 bit in the binary representation of k.
215- private int getBitPosition (int k ) {
216- int count = -1 ;
217- while (k > 0 ) {
218- count ++;
219- k >>= 1 ;
239+ matching [2 * i + 1 ] = map [leftNodes [i ]];
220240 }
221- return count ;
222241 }
223242
224- /* Example */
225-
226243 public static void main (String [] args ) {
227- // test1();
228- // for (int i = 0; i < 50; i++) {
229- // if (include(i)) System.out.printf("%2d %7s\n", i, Integer.toBinaryString(i));
230- // }
231- }
232-
233- private static boolean include (int i ) {
234- boolean toInclude = Integer .bitCount (i ) >= 2 && Integer .bitCount (i ) % 2 == 0 ;
235- return toInclude ;
244+ test1 ();
245+ test2 ();
236246 }
237247
248+ // Example 1: Uses the RECURSIVE solver.
249+ // Generates 2D points that form vertical pairs, shuffles them, and verifies
250+ // the MWPM correctly matches each pair (cost = 1 per pair, total = n/2).
238251 private static void test1 () {
239- // int n = 18 ;
252+ System . out . println ( "=== Recursive solver ===" ) ;
240253 int n = 6 ;
241254 List <Point2D > pts = new ArrayList <>();
242255
@@ -248,13 +261,13 @@ private static void test1() {
248261 Collections .shuffle (pts );
249262
250263 double [][] cost = new double [n ][n ];
251- for (int i = 0 ; i < n ; i ++) {
252- for (int j = 0 ; j < n ; j ++) {
264+ for (int i = 0 ; i < n ; i ++)
265+ for (int j = 0 ; j < n ; j ++)
253266 cost [i ][j ] = pts .get (i ).distance (pts .get (j ));
254- }
255- }
256267
257268 MinimumWeightPerfectMatching mwpm = new MinimumWeightPerfectMatching (cost );
269+ mwpm .solveRecursive ();
270+
258271 double minCost = mwpm .getMinWeightCost ();
259272 if (minCost != n / 2 ) {
260273 System .out .printf ("MWPM cost is wrong! Got: %.5f But wanted: %d\n " , minCost , n / 2 );
@@ -275,7 +288,10 @@ private static void test1() {
275288 }
276289 }
277290
291+ // Example 2: Uses the ITERATIVE solver.
292+ // Simple 4-node symmetric matrix where the optimal matching costs 2.0.
278293 private static void test2 () {
294+ System .out .println ("=== Iterative solver ===" );
279295 double [][] costMatrix = {
280296 {0 , 2 , 1 , 2 },
281297 {2 , 0 , 2 , 1 },
@@ -284,12 +300,12 @@ private static void test2() {
284300 };
285301
286302 MinimumWeightPerfectMatching mwpm = new MinimumWeightPerfectMatching (costMatrix );
303+ mwpm .solveIterative ();
304+
287305 double cost = mwpm .getMinWeightCost ();
288306 if (cost != 2.0 ) {
289307 System .out .println ("error cost not 2" );
290308 }
291- System .out .println (cost );
292- // System.out.println(mwpm.solve2());
293-
309+ System .out .println (cost ); // 2.0
294310 }
295311}
0 commit comments