Coverage for src/ensae_projects/challenge/blossom.py: 83%
436 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-07-20 04:37 +0200
« prev ^ index » next coverage.py v7.1.0, created at 2023-07-20 04:37 +0200
1"""
2@file
3@brief Source: https://github.com/koniiiik/edmonds-blossom/blob/master/blossom.py
5An implementation of Edmonds' blossom algorithm for finding minimum-weight
6maximum matchings.
7"""
9import fractions
10import functools
11import sys
# Stands for "no bound" when computing how far charges may be adjusted.
INF = float('inf')

# Possible levels of blossoms.
# Outermost blossoms inside an alternating tree sit on an even or odd level.
# A child's level is computed as ``1 - parent.level`` elsewhere in this
# module, which relies on these two values being exactly 0 and 1.
LEVEL_EVEN, LEVEL_ODD = 0, 1
# Out-of-tree blossoms; they always appear in pairs connected by a matched
# edge.
LEVEL_OOT = -1
# Blossoms embedded in another blossom.
LEVEL_EMBED = -2
def cached_property(fun):
    """Turn *fun* into a read-only property whose result is memoized.

    Results are kept per instance in a ``_cache`` dict, keyed by the
    decorated function, so *fun* runs at most once per instance (unless it
    raises, in which case nothing is stored).
    """
    @functools.wraps(fun)
    def _memoized(self):
        try:
            cache = self._cache
        except AttributeError:
            # First cached property accessed on this instance.
            cache = self._cache = {}
        if fun not in cache:
            cache[fun] = fun(self)
        return cache[fun]
    return property(_memoized)
class MaximumDualReached(Exception):
    """
    Raised when the dual solution cannot be improved any further, which
    ends the main charge-adjustment loop.
    """
class EdgeTraversalError(Exception):
    """Base class for failed attempts to cross an edge to its other end."""
class EdgeNotOutgoing(EdgeTraversalError):
    """
    Raised when an edge is traversed from a vertex set that already holds
    both of its endpoints, so the edge does not leave that set.
    """
class EdgeNotIncident(EdgeTraversalError):
    """
    Raised when an edge is traversed from a vertex set that holds neither
    of its endpoints, so the edge does not touch that set.
    """
class TreeStructureChanged(Exception):
    """
    Raised right after an alternating tree has been modified, to abandon
    the traversal in progress so that a fresh one can start.
    """
class StructureUpToDate(Exception):
    """
    Signals that every tree is up to date: a full pass found none of the
    four tree-altering cases.
    """
class Edge:
    """An undirected weighted edge between two vertices.

    Equality and hashing are by identity, so two edges joining the same
    endpoints with the same value are still distinct objects. ``selected``
    is 1 while the edge belongs to the current matching and 0 otherwise.
    """

    def __init__(self, v1, v2, value):
        self.vertices = frozenset((v1, v2))
        self.value = value
        self.selected = 0

    @property
    def extremities(self):
        """The ids of both endpoints as an ascending tuple."""
        return tuple(sorted(v.id for v in self.vertices))

    def __hash__(self):
        return hash(id(self))

    def __eq__(self, other):
        return self is other

    def __str__(self):
        return "(%d, %d)" % self.extremities

    def __repr__(self):
        return "<Edge: %s>" % (self,)

    def traverse_from(self, v):
        """Return the endpoint of this edge lying outside *v*.

        *v* may be a vertex, a Blossom or a set of vertices, and must
        contain exactly one endpoint. Raises EdgeNotOutgoing when it holds
        both endpoints and EdgeNotIncident when it holds neither.
        """
        inside = v.members if isinstance(v, Blossom) else v
        outside = self.vertices - inside
        if not outside:
            raise EdgeNotOutgoing()
        if len(outside) > 1:
            raise EdgeNotIncident()
        (other_end,) = outside
        return other_end

    @staticmethod
    def _owner_chain(blossom):
        """Yield *blossom*, then each enclosing blossom, innermost first."""
        while blossom is not None:
            yield blossom
            blossom = blossom.owner

    def calculate_charge(self):
        """Return the total charge currently placed on this edge.

        That is the charge of every blossom containing either endpoint
        (the endpoints themselves included), minus twice the charge of
        every blossom containing both endpoints.
        """
        endpoints = iter(self.vertices)
        enclosing_first = set()
        total = 0
        for b in self._owner_chain(next(endpoints)):
            total += b.charge
            enclosing_first.add(b)

        # Innermost blossom that encloses both endpoints, if any.
        shared = None
        for b in self._owner_chain(next(endpoints)):
            total += b.charge
            if shared is None and b in enclosing_first:
                shared = b

        # Everything from ``shared`` outwards was counted once per endpoint.
        for b in self._owner_chain(shared):
            total -= 2 * b.charge
        return total

    def get_remaining_charge(self):
        """Return the slack: the edge value minus its current charge."""
        return self.value - self.calculate_charge()

    def toggle_selection(self):
        """Flip this edge in or out of the current matching.

        Only tight edges (zero remaining charge) may be toggled.
        """
        assert self.get_remaining_charge() == 0, (
            "toggle_selection called on non-tight edge")
        self.selected = 1 - self.selected
class Blossom:
    """A blossom: an odd-length cycle of sub-blossoms, possibly nested.

    Single vertices are the trivial case (the ``Vertex`` subclass). Each
    blossom carries a dual ``charge``, a ``level`` (one of the ``LEVEL_*``
    constants), an ``owner`` (the blossom directly containing it, if any)
    and its links in an alternating tree (``parent``, ``parent_edge``,
    ``children``). Equality and hashing are by identity.
    """

    # For nontrivial blossoms, the charge cannot decrease below 0.
    minimum_charge = 0

    def __init__(self, cycle, charge=0, level=LEVEL_EVEN):
        """Create a blossom from *cycle*.

        :param cycle: iterable of sub-blossoms of odd length; the first
            element is the base.
        :param charge: initial dual charge.
        :param level: initial level, one of the ``LEVEL_*`` constants.
        """
        # Stored as a Fraction so that, with Fraction edge values, charge
        # arithmetic (including the halving in get_max_delta) stays exact.
        self.charge = fractions.Fraction(charge)
        self.level = level
        # Reference to the blossom directly containing this one.
        self.owner = None
        # The cycle of blossoms this one consists of. The first element is
        # the base.
        self.cycle = tuple(cycle)
        assert len(self.cycle) % 2 == 1
        # References to parent and children in a tree. For out-of-tree
        # pairs the parent is a reference to the peer. For embedded
        # blossoms the parent is the predecessor in the cycle.
        self.parent = None
        self.parent_edge = None
        self.children = set()

    def __hash__(self):
        return hash(id(self))

    def __eq__(self, other):
        return self is other

    def __str__(self):
        # Space-separated string forms of the cycle members, in parentheses.
        return "(%s)" % (' '.join(str(b) for b in self.cycle))

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, str(self))

    @cached_property
    def outgoing_edges(self):
        """
        Returns a list of pairs (edge, target vertex).

        Only edges leaving this blossom are included: pairs whose target
        vertex is itself a member are filtered out. The list order is
        arbitrary (it is built from a set). Computed once and cached.
        """
        members = self.members
        return list({(e, v)
                     for blossom in self.cycle
                     for e, v in blossom.outgoing_edges
                     if v not in members})

    @cached_property
    def members(self):
        """Set of all vertices contained in this blossom, at any depth.

        Computed once and cached.
        """
        return {v for blossom in self.cycle for v in blossom.members}

    def get_outermost_blossom(self):
        """Follow ``owner`` links up to the top-level blossom containing
        self (self itself when it has no owner)."""
        if self.owner is not None:
            return self.owner.get_outermost_blossom()
        return self

    def get_root(self):
        """Follow ``parent`` links to the root of this alternating tree.

        Only valid for blossoms on an even or odd level.
        """
        assert self.level in (LEVEL_EVEN, LEVEL_ODD), ("get_root called on "
                                                       "an out-of-tree blossom")
        if self.parent is None:
            return self
        return self.parent.get_root()

    def get_base_vertex(self):
        """Return the base vertex, reached by descending into cycle[0]."""
        return self.cycle[0].get_base_vertex()

    def get_max_delta(self):
        """
        Finds the maximum allowed charge adjust for this blossom and its
        children.

        The result is the minimum over this blossom and, recursively, all
        its descendants in the tree. It may be INF when nothing bounds the
        adjustment.
        """
        # NOTE(review): ``delta`` is assigned only for odd and even levels;
        # on any other level the final ``min`` raises UnboundLocalError.
        # Presumably this is only ever called on tree blossoms.
        if self.level == LEVEL_ODD:
            # Blossoms on odd levels are going to be decreased.
            delta = self.charge - self.minimum_charge
        elif self.level == LEVEL_EVEN:
            # Even levels get increased. We need to check each outgoing
            # edge.
            delta = INF
            for e, v in self.outgoing_edges:
                b = v.get_outermost_blossom()
                remaining = e.get_remaining_charge()

                if b.level == LEVEL_EVEN:
                    # Both ends of e are on even level, both get
                    # increased, therefore each can only get one half of
                    # remaining capacity.
                    delta = min(delta, remaining / 2)
                elif b.level == LEVEL_OOT:
                    # The other end is an out-of-tree blossom whose charge
                    # remains the same.
                    delta = min(delta, remaining)
                # Odd blossoms don't limit us in any way as the difference
                # gets canceled out on the other end.

        # Recurse into children.
        return min([delta] + [child.get_max_delta() for child in self.children])

    def adjust_charge(self, delta):
        """
        Decides what is supposed to happen on charge adjusts and recurses.

        Even blossoms gain *delta*, odd blossoms lose it, and blossoms on
        any other level keep their charge. The same *delta* is then applied
        to every child.
        """
        if self.level == LEVEL_EVEN:
            self.charge += delta
        elif self.level == LEVEL_ODD:
            self.charge -= delta
            assert self.charge >= self.minimum_charge, ("the charge of a "
                                                        "blossom dropped "
                                                        "below minimum")

        for child in self.children:
            child.adjust_charge(delta)

    def alter_tree(self, roots):
        """Detects and handles the four cases where trees need to be altered.

        Odd blossoms whose charge reached 0 are expanded (P1). Even blossoms
        look for fresh tight edges (P2-P4, see handle_tight_edges). The
        handlers raise TreeStructureChanged after modifying the structure,
        which aborts this traversal. Otherwise the children are visited
        recursively.

        :param roots: mutable set of tree roots, updated in place by the
            handlers.
        """
        if self.level == LEVEL_ODD:
            if self.charge == 0:
                self.expand(roots)
        elif self.level == LEVEL_EVEN:
            self.handle_tight_edges(roots)
        else:
            assert False, ("alter_tree called on blossom of level %d" %
                           self.level)

        for child in self.children:
            child.alter_tree(roots)

    def handle_tight_edges(self, roots):
        """Finds any fresh tight edges.

        If a tight edge leads to an out-of-tree blossom, attach the pair
        (P2).

        If a tight edge leads to a blossom in the same tree as this one
        (the root blossom is the same), shrink (P3).

        If a tight edge leads to a blossom with a different root, augment
        (P4).
        """
        assert self.level == LEVEL_EVEN, ("handle_tight_edges called on "
                                          "non-even blossom.")

        for e, v in self.outgoing_edges:
            # The edge to our own parent is already part of the tree.
            if e is self.parent_edge:
                continue
            remaining_charge = e.get_remaining_charge()
            assert remaining_charge >= 0, "found an overcharged edge"
            if remaining_charge > 0:
                continue
            other_blossom = v.get_outermost_blossom()
            if other_blossom.level == LEVEL_ODD:
                # This blossom can be from a different tree -- we don't
                # particularly care. The charge adjusts will be
                # compensated anyway.
                continue
            if other_blossom.level == LEVEL_OOT:
                # attach_out_of_tree_pair always raises, so this
                # ``continue`` is never reached.
                self.attach_out_of_tree_pair(other_blossom, e)
                continue
            if other_blossom.get_root() == self.get_root():
                self.shrink_with_peer(other_blossom, e, roots)
            else:
                self.augment_matching(other_blossom, e, roots)

    def attach_out_of_tree_pair(self, target, edge):
        """Handles case (P2).

        *target* is an out-of-tree blossom reached from self over the tight
        *edge*. It becomes an odd child of self, and its former peer becomes
        the even child of *target*. The peer keeps ``parent`` (which already
        references *target*) and its ``parent_edge``.

        :raises TreeStructureChanged: always, after attaching.
        """
        assert self.level == LEVEL_EVEN
        assert target.level == LEVEL_OOT
        assert len(target.children) == 0

        self.children.add(target)
        target_peer = target.parent
        target.parent = self
        target.parent_edge = edge
        target.level = LEVEL_ODD
        target_peer.level = LEVEL_EVEN
        target.children.add(target_peer)
        assert len(target_peer.children) == 0
        raise TreeStructureChanged("Attached blossom on edge %s" % edge)

    def shrink_with_peer(self, other, edge, roots):
        """Shrinks the cycle along given edge into a new blossom. (P3)

        *self* and *other* are even blossoms of the same tree joined by the
        tight *edge*. The cycle through their closest common ancestor is
        replaced, in the tree, by a new blossom. That blossom takes over the
        ancestor's parent links and all children hanging off the cycle.

        :raises TreeStructureChanged: always, after shrinking.
        """
        assert self.level == LEVEL_EVEN
        assert other.level == LEVEL_EVEN
        # Find the closest common ancestor and the chains of parents
        # leading to them.
        ancestors, parent_chain1, parent_chain2 = {}, [], []
        blossom = self
        while blossom is not None:
            # Map each ancestor of self to its index in parent_chain1.
            ancestors[blossom] = len(parent_chain1)
            parent_chain1.append(blossom)
            assert blossom.parent is None or blossom in blossom.parent.children
            blossom = blossom.parent

        blossom = other
        while blossom not in ancestors:
            parent_chain2.append(blossom)
            assert blossom.parent is None or blossom in blossom.parent.children
            blossom = blossom.parent

        parent_chain2.append(blossom)
        common_ancestor = blossom
        # We need to store these values here since they get rewritten in the
        # following loop.
        new_parent = common_ancestor.parent
        new_parent_edge = common_ancestor.parent_edge

        # Remove references to other components of the new blossom from each
        # component's children list.
        for blossom in (self, other):
            while blossom is not common_ancestor:
                blossom.parent.children.remove(blossom)
                blossom = blossom.parent

        # Repoint the parent references in parent_chain2 to point in the
        # other direction. This will close the cycle.
        prev_edge, prev_blossom = edge, self
        for blossom in parent_chain2:
            prev_edge, prev_blossom, blossom.parent, blossom.parent_edge = (
                blossom.parent_edge, blossom, prev_blossom, prev_edge
            )

        # The cycle now consists of the reverse of parent_chain1 up to and
        # including common_ancestor + parent_chain2 sans common_ancestor.
        # Each element's parent is now its predecessor in this list, and
        # common_ancestor's parent is the last element.
        cycle = parent_chain1[
            ancestors[common_ancestor]::-1] + parent_chain2[:-1]

        new_blossom = Blossom(cycle)
        for blossom in cycle:
            new_blossom.children.update(blossom.children)
            blossom.owner = new_blossom
            blossom.children.clear()
            blossom.level = LEVEL_EMBED

        for child in new_blossom.children:
            child.parent = new_blossom

        # The new blossom replaces common_ancestor wherever it was attached.
        if new_parent is None:
            registry = roots
        else:
            registry = new_parent.children

        registry.remove(common_ancestor)
        registry.add(new_blossom)
        new_blossom.parent = new_parent
        new_blossom.parent_edge = new_parent_edge

        raise TreeStructureChanged("Shrunk cycle on edge %s" % edge)

    def augment_matching(self, other_blossom, edge, roots):
        """Augments the matching along the alternating path containing given edge. (P4)

        Flips the alternating paths from self and from *other_blossom* (in a
        different tree) up to their roots, then selects the tight *edge*
        itself. Blossoms visited by flip_root_path end up in out-of-tree
        pairs.

        :raises TreeStructureChanged: always, after augmenting.
        """
        # Always look for a path of even length within a blossom. Recurse
        # into sub-blossoms.
        assert edge.selected == 0, ("trying to augment via an already "
                                    "selected edge")
        self.flip_root_path(edge, roots)
        other_blossom.flip_root_path(edge, roots)
        edge.toggle_selection()
        raise TreeStructureChanged("Augmented on edge %s" % edge)

    def flip_root_path(self, edge, roots):
        """Flips edge selection on the alternating path from self to root.

        Argument edge is the edge from a child from which the alternating
        path leads through self.

        The children of this blossom are detached and this blossom becomes
        part of an out-of-tree pair around the given edge.
        """
        assert self.level in (LEVEL_EVEN, LEVEL_ODD)
        v1 = self.get_base_vertex()
        # v2 is the vertex of self where the path leaves it: the endpoint of
        # *edge* for even blossoms, of the parent edge for odd ones.
        if self.level == LEVEL_EVEN:
            v2 = next(iter(self.members & edge.vertices))
        else:
            v2 = next(iter(self.members & self.parent_edge.vertices))

        self.flip_alternating_path(v1, v2)

        # We need to store these values because detach_from_parent
        # modifies them and we need to recurse later.
        prev_parent, prev_parent_edge = self.parent, self.parent_edge

        if self.level == LEVEL_EVEN:
            self.detach_children(roots)
            # Become a peer to the blossom on the other side of edge.
            self.detach_from_parent(edge, roots)
        else:
            # Adjust self to become a peer to our parent.
            # Our child should be detached by now.
            assert len(self.children) == 0
            assert self.parent is not None
            self.parent.children.remove(self)
            self.level = LEVEL_OOT

        # Continue towards the root over the edge to the previous parent.
        if prev_parent is not None:
            prev_parent_edge.toggle_selection()
            prev_parent.flip_root_path(prev_parent_edge, roots)

    def flip_alternating_path(self, v1, v2):
        """Flips edge selection on the alternating path from v1 to v2.

        v1 and v2 are the two vertices at the boundaries of this blossom
        along the augmenting alternating path. One of the two vertices
        needs to be the base of this blossom.

        Afterwards the cycle is rotated so that the sub-blossom that
        contained the non-base vertex becomes the new base.
        """
        assert v1 in self.members
        assert v2 in self.members
        if v1 is v2:
            return

        if v1 not in self.cycle[0].members:
            v1, v2 = v2, v1

        # v1 is in the base blossom, find the blossom containing v2
        lasti = 0
        for i, b in enumerate(self.cycle):
            lasti = i
            if v2 in b.members:
                break
        i = lasti

        # Trivial case: if both v1 and v2 are in the same blossom, we
        # don't need to do anything at this level.
        if i == 0:
            self.cycle[0].flip_alternating_path(v1, v2)
            return

        # self.cycle has odd length, pick the direction in which the path
        # from cycle[i] to cycle[0] has even length.
        sub_calls, edges = [], []
        if i % 2 == 0:
            # Proceed from the base forwards toward self.cycle[i].
            start, finish = 0, i
        else:
            # Proceed from self.cycle[i] forwards toward base.
            # A negative start lets the loop wrap around the cycle end.
            start, finish = i - len(self.cycle), 0
            v1, v2 = v2, v1

        # Collect the cycle edges along the path and, for each sub-blossom
        # on it, the pair of vertices where the path enters and leaves it.
        prev_vertex = v1
        for j in range(start, finish):
            edge = self.cycle[j + 1].parent_edge
            sub_calls.append((self.cycle[j], prev_vertex,
                              edge.traverse_from(self.cycle[j + 1])))
            edges.append(edge)
            prev_vertex = edge.traverse_from(self.cycle[j])
        sub_calls.append((self.cycle[finish], prev_vertex, v2))

        assert len(sub_calls) % 2 == 1
        assert len(edges) % 2 == 0

        for e in edges:
            e.toggle_selection()
        for blossom, x1, x2 in sub_calls:
            blossom.flip_alternating_path(x1, x2)

        # Shift self.cycle to new base to keep the invariant that cycle[0]
        # is the base. cycle[i] should become cycle[0].
        self.cycle = self.cycle[i:] + self.cycle[:i]

    def detach_children(self, roots):
        """Detaches all children and turns them into out-of-tree pairs.
        """
        # We need to make a copy of self.children here since it gets
        # modified in each iteration.
        for child in list(self.children):
            child.detach_from_parent(None, roots)

    def detach_from_parent(self, edge, roots):
        """Detaches itself from the parent and forms an out-of-tree pair.

        If called on an odd blossom, edge needs to be None and the only
        child will be chosen. Otherwise, an edge leading to a peer needs
        to be supplied.

        If an edge is specified, it assumes the peer adjusts itself,
        otherwise we also adjust the single child that becomes the peer.
        """
        # Exactly one of: self is odd, or an edge was supplied.
        assert (self.level == LEVEL_ODD) ^ (edge is not None)

        if self.parent is not None:
            self.parent.children.remove(self)
        else:
            roots.remove(self)
        if edge is not None:
            peer = edge.traverse_from(self).get_outermost_blossom()
            self.detach_children(roots)
        else:
            peer = next(iter(self.children))
            self.children.clear()
        self.level = LEVEL_OOT
        self.parent = peer
        if edge is None:
            # Both members of the pair share the edge that joined them.
            self.parent_edge = peer.parent_edge
            peer.level = LEVEL_OOT
            peer.detach_children(roots)
        else:
            self.parent_edge = edge

    def expand(self, roots):
        """
        Expands this blossom back into its constituents. (P1)

        The sub-blossoms on the even-length stretch of the cycle between the
        base and the sub-blossom touching the parent edge take this
        blossom's place in the tree. The remaining sub-blossoms become
        out-of-tree pairs.

        :raises TreeStructureChanged: always, after expanding.
        """
        assert self.level == LEVEL_ODD
        assert self.charge == 0

        base = self.cycle[0]
        # Since self is odd, it has exactly one child.
        assert len(self.children) == 1
        child = next(iter(self.children))
        assert child.level == LEVEL_EVEN
        child.parent = base
        base.children.add(child)

        # Find the component connected to parent.
        boundary_vertex = next(iter(self.parent_edge.vertices & self.members))
        lasti = 0
        last_bl = None
        for i, blossom in enumerate(self.cycle):
            lasti = i
            last_bl = blossom
            if boundary_vertex in blossom.members:
                break
        i = lasti
        blossom = last_bl

        if i % 2 == 0:
            # Repoint the even-length part of the cycle to form a part of
            # the tree.
            prev_parent, prev_edge = self.parent, self.parent_edge
            for j in range(i, -1, -1):
                b = self.cycle[j]
                assert b.level == LEVEL_EMBED
                assert prev_parent.level in (LEVEL_EVEN, LEVEL_ODD)
                assert b.owner is self
                # Levels alternate along the tree path (EVEN=0, ODD=1).
                b.level = 1 - prev_parent.level
                b.owner = None
                # NOTE(review): prev_parent was already dereferenced by the
                # assert above, so this check is always true here.
                if prev_parent is not None:
                    prev_parent.children.add(b)
                prev_parent, prev_edge, b.parent, b.parent_edge = (
                    b, b.parent_edge, prev_parent, prev_edge
                )
            pairs_start, pairs_end = -1, i - len(self.cycle)
        else:
            # The even part of our cycle has the correct pointers already,
            # fix their levels and ownership.
            blossom.parent = self.parent
            blossom.parent_edge = self.parent_edge
            for j in range(i - len(self.cycle), 1):
                b = self.cycle[j]
                assert b.level == LEVEL_EMBED
                assert b.parent.level in (LEVEL_EVEN, LEVEL_ODD)
                assert b.owner is self
                b.level = 1 - b.parent.level
                b.parent.children.add(b)
                b.owner = None
            pairs_start, pairs_end = i - 1, 0

        # Turn the odd-length part of the cycle into out-of-tree pairs.
        for j in range(pairs_start, pairs_end, -2):
            b = self.cycle[j]
            peer = b.parent
            assert b.level == LEVEL_EMBED
            assert b.parent.level == LEVEL_EMBED
            assert b.owner is self
            assert peer.owner is self
            assert len(b.children) == 0
            assert len(peer.children) == 0
            peer.parent, peer.parent_edge = b, b.parent_edge
            b.owner = peer.owner = None
            b.level = peer.level = LEVEL_OOT

        # cycle[i] takes this blossom's place among its parent's children
        # (or among the roots).
        if self.parent is None:
            registry = roots
        else:
            registry = self.parent.children
        registry.remove(self)
        registry.add(self.cycle[i])

        raise TreeStructureChanged("Expanded a blossom")
class Vertex(Blossom):
    """A single vertex of the input graph, treated as a trivial blossom.

    Its cycle holds just itself, so it is its own base vertex and its only
    member.
    """

    # Unlike nontrivial blossoms, a lone vertex may carry a negative charge.
    minimum_charge = -INF

    def __init__(self, idi):
        # ``id`` and ``edges`` are set before the base initializer builds
        # the one-element cycle ``(self,)``.
        self.id = idi
        self.edges = []
        super().__init__(cycle=[self], charge=0)

    def __str__(self):
        return "%d" % (self.id,)

    def add_edge_to(self, other, value):
        """Create an edge of weight *value* to *other* and register it with
        both endpoints."""
        new_edge = Edge(self, other, value)
        for endpoint in (self, other):
            endpoint.edges.append(new_edge)

    @cached_property
    def outgoing_edges(self):
        """Pairs (edge, opposite endpoint) for every incident edge."""
        return [(edge, edge.traverse_from(self)) for edge in self.edges]

    @cached_property
    def members(self):
        """A vertex contains exactly itself."""
        return {self}

    def get_base_vertex(self):
        return self

    def expand(self, roots):
        """A lone vertex has nothing to expand, so this does nothing."""
def get_max_delta(roots):
    """
    Returns the maximal value by which we can improve the dual solution
    by adjusting charges on alternating trees.

    :param roots: set of root blossoms of the alternating trees.
    :returns: the largest admissible positive, finite charge adjustment.
    :raises MaximumDualReached: when no tree is left (all blossoms are
        matched), when no positive adjustment is possible, or when the
        adjustment is unbounded (INF). In the INF case, no even blossom has
        a limiting outgoing edge and no nontrivial odd blossom exists.
        Applying INF would push charges to infinity (and make the next
        tight-edge check fail on NaN), so there is nothing left to
        improve.
    """
    if not roots:
        # All blossoms are matched.
        raise MaximumDualReached()

    delta = min(root.get_max_delta() for root in roots)

    assert delta >= 0

    # An unbounded delta happens e.g. for graphs without a perfect
    # matching: the remaining trees can grow forever without any edge
    # becoming tight, so the tree structure can no longer change.
    if delta <= 0 or delta == INF:
        raise MaximumDualReached()

    return delta
def read_input(input_file):
    """Parse a weighted graph and return its vertices keyed by id.

    The first line holds two integers ``N M`` (vertex and edge counts). They
    are validated but not otherwise used. Every following non-blank line
    holds three integers ``u v w``: an edge between vertices ``u`` and ``v``
    of weight ``w``. Vertices are created as they appear in edge lines, so
    vertices without any incident edge are not returned.

    :param input_file: an open text file, ``sys.stdin``, or any iterable of
        text lines.
    :returns: dict mapping vertex id to ``Vertex``.
    :raises ValueError: if the header line is missing or malformed, or if an
        edge line does not consist of exactly three integers.
    """
    lines = iter(input_file)
    try:
        _num_vertices, _num_edges = [int(x) for x in next(lines).split()]
    except (StopIteration, ValueError) as err:
        # Empty input or a header that is not exactly two integers.
        raise ValueError("Unable to process '{0}'".format(input_file)) from err
    vertices = {}

    for line in lines:
        fields = line.split()
        if not fields:
            # Tolerate blank lines, e.g. a trailing newline at end of file.
            continue
        u, v, w = [int(x) for x in fields]
        for _v in (u, v):
            if _v not in vertices:
                vertices[_v] = Vertex(_v)

        u, v = vertices[u], vertices[v]
        u.add_edge_to(v, fractions.Fraction(w))

    return vertices
def update_tree_structures(roots):
    """Apply tree alterations until every alternating tree is up to date.

    Each alteration (attach, shrink, augment, expand) modifies the tree
    structure and signals it with TreeStructureChanged. That aborts the
    current pass, and a new pass starts from scratch. The function returns
    once a complete pass over all roots alters nothing.

    :param roots: mutable set of root blossoms. Alterations add and remove
        roots in place, so it is handed to every ``alter_tree`` call.
    """
    while True:
        try:
            # Iterate over a snapshot because alterations mutate ``roots``.
            for root in list(roots):
                root.alter_tree(roots)
        except TreeStructureChanged:
            continue
        return
if __name__ == "__main__":

    def main_blossom():
        """Read a graph, run the primal-dual loop and print the matching.

        The graph is read from the file named by the first command-line
        argument, or from standard input when no argument is given. Each
        charge adjustment is logged to stderr. The total weight of the
        matching is printed first, then one line per matched edge, ordered
        by endpoint ids.
        """
        if len(sys.argv) > 1:
            # Close the input file as soon as the graph has been parsed.
            with open(sys.argv[1]) as input_file:
                vertices = read_input(input_file)
        else:
            vertices = read_input(sys.stdin)
        aroots = set(vertices.values())
        try:
            while True:
                delta = get_max_delta(aroots)
                sys.stderr.write("Adjusting by %s\n" % (delta,))
                for root_ in aroots:
                    root_.adjust_charge(delta)
                update_tree_structures(aroots)
        except MaximumDualReached:
            pass

        matching = set()
        for v in vertices.values():
            matching.update(e for e in v.edges if e.selected)

        total_weight = sum(e.value for e in matching)
        print(total_weight)
        # Sort so the output does not depend on id()-based hash order.
        for e in sorted(matching, key=lambda edge: edge.extremities):
            print("%s %s" % (e, e.value))

    main_blossom()