Union Find¶
Disjoint Set Union (DSU) (or Union Find in some contexts) is a data structure that keeps track of a set of elements partitioned into a number of disjoint (non-overlapping) subsets. It provides two main operations: find and union.
Basic Implementation¶
The very basic implementation of Union Find can be done using a list to represent the parent of each element. Initially, each element is its own parent (i.e., each element is in its own set, I use root to represent the parent of each element). The find function is used to find the representative (or “root”) of the set that an element belongs to, while the union function is used to merge two sets together.
Notice that in the find function, we perform path compression by setting the parent of each visited node directly to the root.
root = list(range(n))
def find(x):
if x == root[x]: return x
# path compression
root[x] = find(root[x])
return root[x]
def union(x,y):
x, y = find(x),find(y)
root[y] = x
This basic implementation has a time complexity of \(O(n)\) for the find operation in the worst case, which can happen when the tree becomes very deep.
Optimization:Heuristical Union¶
To optimize the time complexity, we can use union by rank (or union by size), which can help keep the tree flat. By the trick of union by rank, we always attach the smaller tree under the root of the larger tree. This way, we can ensure that the depth of the tree remains logarithmic.
The easier and more intuitive way is to keep track of the size of each tree and always attach the smaller tree under the root of the larger tree.
size = [1] * n
def union(x,y):
x, y = find(x),find(y)
if x == y: return
if size[x] < size[y]:
x, y = y, x
root[y] = x
size[x] += size[y]
Here, rank refers to the depth of the tree. We always attach the tree with smaller rank under the root of the tree with larger rank. If both trees have the same rank, we can choose one as the new root and increase its rank by 1.
rank = [0] * n
def union(x,y):
x, y = find(x),find(y)
if x == y: return
if rank[x] < rank[y]:
x, y = y, x
elif rank[x] == rank[y]:
rank[x] += 1
root[y] = x
Sometimes, we feel hard to understand the concept of rank, so we usually use size instead of rank to achieve the same effect.
Time Complexity: \(O(\alpha(n))\) for both find and union, where \(\alpha(n)\) is the inverse Ackermann function, which grows very slowly and is practically constant for all reasonable values of \(n\). By using both path compression and union by rank/size, we can achieve almost constant time complexity for both operations.
Variant: Weighted DSU¶
We can also maintain some additional information for nodes that related to the parent/root of each element, not only the parent itself. For example, we can maintain the distance from each node to its parent, and then we can easily calculate the distance between any two nodes in the same set by using the distance to their common ancestor (the root of the set). This is called Weighted DSU.
root = list(range(n))
dist = [0] * n
def find(x):
if x != root[x]:
# path compression
origin = root[x]
root[x] = find(root[x])
dist[x] += dist[origin]
return root[x]
def union(x, y, w):
x, y = find(x), find(y)
if x == y: return
root[y] = x
dist[y] = dist[x] + w - dist[y]
For example, in LC 3887. Incremental Even Weighted Cycle Queries, for every connected component, we can maintain the distance from each node to the root of the component. More specifically, only the parity of the distance matters in this problem, so we can maintain the distance modulo 2. Once we are about to union two nodes in the same component, we need to check if the parity of the distance from these two nodes to their common ancestor (the root of the component) plus the weight of the edge between them is even or not. If it is even, we can add this edge without creating an odd cycle, otherwise, we cannot add this edge.
class Solution:
def numberOfEdgesAdded(self, n: int, edges: List[List[int]]) -> int:
res = 0
parent = list(range(n))
dist = [0] * n
size = [1] * n
def find(x: int) -> int:
if parent[x] != x:
p, d = find(parent[x])
parent[x] = p
dist[x] += d
return parent[x], dist[x]
def union(x: int, y: int, w: int) -> bool:
px, dx = find(x)
py, dy = find(y)
if px == py:
return False
if size[px] < size[py]:
px, py, dx, dy = py, px, dy, dx
parent[py] = px
dist[py] = w + dx - dy
size[px] += size[py]
return True
for u, v, w in edges:
pu, du = find(u)
pv, dv = find(v)
if pu != pv:
union(u, v, w)
res += 1
elif (du + dv + w) % 2 == 0:
res += 1
return res