# Segment Tree {bdg-danger}`TODO`

## What is a segment tree?

Imagine that you are given an array of integers, namely `A`, of length `n`. There are in total $\binom{n+1}{2} = n(n+1)/2$ non-empty subarrays, written as `A[i:j+1]` with `0<=i<=j<n`.

### Build

```py
def build(self, root:int, l:int, r:int, arr:list[int]) -> None:
    # Leaf: the interval [l, r] holds a single element.
    if l == r:
        self.sum[root] = arr[l]
        self.min[root] = arr[l]
        self.max[root] = arr[l]
        return
    mid = (l+r)//2
    # Build the left half [l, mid] and the right half [mid+1, r].
    self.build(root*2, l, mid, arr)
    self.build(root*2+1, mid+1, r, arr)
    # Combine the two children into the current node.
    self.sum[root] = self.sum[root*2] + self.sum[root*2+1]
    self.min[root] = min(self.min[root*2], self.min[root*2+1])
    self.max[root] = max(self.max[root*2], self.max[root*2+1])
```

- `root`: current root node of the subtree to build, represented as an int
- `l` and `r`: the endpoints of the subarray `arr[l:r+1]` covered by the subtree being built
- **Time complexity**: $O(n)$

To initialize the segment tree from a known array, call `self.build(1,0,n-1,arr)`. Note that the tree is always rooted at node `1`: the children of node `root` are stored at `root*2` and `root*2+1`, which does not work for a root at `0` (its left child `0*2` would be `0` itself).

### Update (Single Element)

To update the value of a **single** element in the array, we first need to find all intervals that contain that element. These intervals form a path from the root node to a leaf.

```py
def update(self, root:int, l:int, r:int, idx:int, val:int) -> None:
    # Leaf reached: overwrite the stored data of this single element.
    if l == r:
        self.sum[root] = val
        self.min[root] = val
        self.max[root] = val
        return
    mid = (l+r)//2
    # Descend into the child whose interval contains idx.
    if idx <= mid:
        self.update(root*2, l, mid, idx, val)
    else:
        self.update(root*2+1, mid+1, r, idx, val)
    # Recompute the current node from its children.
    self.sum[root] = self.sum[root*2] + self.sum[root*2+1]
    self.min[root] = min(self.min[root*2], self.min[root*2+1])
    self.max[root] = max(self.max[root*2], self.max[root*2+1])
```

- `root`: current root node of the subtree that represents the current interval
- `l` and `r`: the endpoints of the current interval
- `idx`: the index of the element in the raw array to update
- `val`: the new value of the element at `idx`
- **Time Complexity:** $O(\log n)$
- **Entry**: `self.update(root=1,l=0,r=n-1,idx=idx,val=val)`

### Query

To query some data (e.g., `sum` here) of an interval `[L,R]`, we search the intervals level by level:

1. If the current interval is entirely contained in `[L,R]`, return its data directly.
2. Otherwise, recurse into the left child (`[l,mid]`) and the right child (`[mid+1,r]`), wherever they overlap `[L,R]`, and combine the results.

```py
def query_sum(self, root:int, l:int, r:int, L:int, R:int) -> int:
    # The current interval lies entirely inside the query range.
    if l >= L and r <= R:
        return self.sum[root]
    mid = (l + r) // 2
    res = 0
    # Visit only the children that overlap [L, R].
    if L <= mid:
        res += self.query_sum(root*2, l, mid, L, R)
    if R > mid:
        res += self.query_sum(root*2+1, mid+1, r, L, R)
    return res
```

- `root`: current root node of the subtree that represents the current interval
- `l` and `r`: the endpoints of the current interval
- `L` and `R`: the endpoints of the interval to query
- **Time Complexity:** $O(\log n)$
- **Entry**: `self.query_sum(root=1,l=0,r=n-1,L=L,R=R)`

## Lazy Propagation

To update a range of elements in the array, we don't have to update every single element in the range. Instead, we can update the interval represented by a tree node and mark it as `lazy` to indicate that its children are not up-to-date.

When we need to read the data of a tree node, we first check whether it is `lazy` and, if so, apply the pending value and pass it on to its children.

```py
def push_down(self, root:int, l:int, r:int) -> None:
    if self.lazy[root] != 0:
        # Apply the pending addition to the data of the current node.
        self.sum[root] += (r-l+1) * self.lazy[root]
        self.min[root] += self.lazy[root]
        self.max[root] += self.lazy[root]
        # Defer the same addition to the children (if any).
        if l != r:
            self.lazy[root*2] += self.lazy[root]
            self.lazy[root*2+1] += self.lazy[root]
        self.lazy[root] = 0
```

- `root`: current root node of the subtree that represents the current interval
- `l` and `r`: the endpoints of the current interval
- **Time Complexity:** $O(1)$
- **Usage**: call `self.push_down(root,l,r)` on a node before reading its data or recursing into its children

### Update Range

To update a range of elements in the array, we can use the `lazy` propagation technique. We only need to update the current node and mark its children as `lazy`, without updating them immediately. Every visited node first applies its own pending value, so that the children's data is accurate when the parent is recomputed from them.

```py
def update_range(self, root:int, l:int, r:int, L:int, R:int, val:int) -> None:
    # Apply any pending addition first, so this node is accurate
    # when its parent recomputes sum/min/max from it.
    self.push_down(root, l, r)
    # No overlap with [L, R].
    if l > R or r < L:
        return
    # Fully covered: apply val here and defer it to the children.
    if l >= L and r <= R:
        self.lazy[root] += val
        self.push_down(root, l, r)
        return
    mid = (l + r) // 2
    self.update_range(root*2, l, mid, L, R, val)
    self.update_range(root*2+1, mid+1, r, L, R, val)
    self.sum[root] = self.sum[root*2] + self.sum[root*2+1]
    self.min[root] = min(self.min[root*2], self.min[root*2+1])
    self.max[root] = max(self.max[root*2], self.max[root*2+1])
```

## Read More

See [^1].

[^1]: https://cp-algorithms.com/data_structures/segment_tree.html

### Comparison with Binary Indexed Tree (Fenwick Tree)

A BIT is a data structure that can also be used to maintain the prefix sums of an array. It is more space-efficient than a segment tree, as it only requires `n` entries. However, in its basic form it only supports point updates and prefix-sum queries (a range sum is obtained as the difference of two prefix sums), while a segment tree also supports range updates and queries that cannot be derived from prefixes, such as range `min` and `max`. Therefore, if you need range updates or general range queries, a segment tree is more suitable.

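For comparison, a Fenwick tree with point updates and prefix-sum queries can be sketched as follows. This is a minimal illustration rather than part of the templates on this page; the class name `FenwickTree` and the method names `add`, `prefix_sum` and `range_sum` are hypothetical names chosen for this example.

```py
class FenwickTree:
    """Minimal BIT sketch over n elements (0-indexed API, 1-indexed storage)."""

    def __init__(self, n: int) -> None:
        self.n = n
        self.tree = [0] * (n + 1)  # only n + 1 entries are needed

    def add(self, idx: int, delta: int) -> None:
        """Point update: arr[idx] += delta."""
        i = idx + 1
        while i <= self.n:
            self.tree[i] += delta
            i += i & (-i)  # jump to the next node that covers position i

    def prefix_sum(self, idx: int) -> int:
        """Return arr[0] + ... + arr[idx]; returns 0 for idx == -1."""
        i = idx + 1
        total = 0
        while i > 0:
            total += self.tree[i]
            i -= i & (-i)  # strip the lowest set bit
        return total

    def range_sum(self, left: int, right: int) -> int:
        """Return arr[left] + ... + arr[right] as a difference of prefix sums."""
        return self.prefix_sum(right) - self.prefix_sum(left - 1)


bit = FenwickTree(5)
for i, v in enumerate([3, 1, 4, 1, 5]):
    bit.add(i, v)
print(bit.range_sum(1, 3))  # 1 + 4 + 1 = 6
bit.add(2, 6)               # arr[2] becomes 10
print(bit.range_sum(1, 3))  # 1 + 10 + 1 = 12
```

Range sums work here only because addition can be undone by subtraction; the same trick does not apply to `min` or `max`, which is where the segment tree is needed.
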
## Alternative Templates

Some subtle modifications can be made to the template above to fit different scenarios.

````{dropdown} With hashmap: query how many times an element appears in a range
```py
class SegmentTree:
    def __init__(self, nums):
        self.n = len(nums)
        # Each node stores a {value: count} map for its interval.
        self.tree = [{} for _ in range(4 * self.n)]
        self.build(nums, 0, 0, self.n - 1)

    def build(self, nums, node, l, r):
        if l == r:
            self.tree[node] = {nums[l]: 1}
            return
        mid = (l + r) // 2
        self.build(nums, 2*node+1, l, mid)
        self.build(nums, 2*node+2, mid+1, r)
        self.tree[node] = self.merge(self.tree[2*node+1], self.tree[2*node+2])

    def merge(self, left, right):
        # Add up the counts of both children into a new map.
        merged = left.copy()
        for key, val in right.items():
            merged[key] = merged.get(key, 0) + val
        return merged

    def update(self, index, old_val, new_val, node=0, l=0, r=None):
        if r is None:
            r = self.n - 1
        if index < l or index > r:
            return
        # Every node whose interval contains index swaps old_val for new_val.
        if old_val in self.tree[node]:
            self.tree[node][old_val] -= 1
            if self.tree[node][old_val] == 0:
                del self.tree[node][old_val]
        self.tree[node][new_val] = self.tree[node].get(new_val, 0) + 1
        if l != r:
            mid = (l + r) // 2
            self.update(index, old_val, new_val, 2*node+1, l, mid)
            self.update(index, old_val, new_val, 2*node+2, mid+1, r)

    def query(self, ql, qr, k, node=0, l=0, r=None):
        # Count the occurrences of k in [ql, qr].
        if r is None:
            r = self.n - 1
        if qr < l or ql > r:
            return 0
        if ql <= l and r <= qr:
            return self.tree[node].get(k, 0)
        mid = (l + r) // 2
        return self.query(ql, qr, k, 2*node+1, l, mid) + self.query(ql, qr, k, 2*node+2, mid+1, r)
```
````

````{dropdown} Binary search on a segment tree
Binary search to find the first element $\ge k$.

A template question is [LC3479](https://leetcode.com/problems/fruits-into-baskets-iii): let each node store the *max* value of its range. First check whether the left subtree contains a value of at least `k`; if not, check the right subtree.

```py
class SegmentTree:
    def __init__(self, data):
        self.n = len(data)
        self.tree = [0] * (4 * self.n)
        self.build(data, 0, 0, self.n - 1)

    def build(self, data, node, start, end):
        if start == end:
            self.tree[node] = data[start]  # Leaf stores the element value
        else:
            mid = (start + end) // 2
            self.build(data, 2 * node + 1, start, mid)
            self.build(data, 2 * node + 2, mid + 1, end)
            self.tree[node] = max(self.tree[2 * node + 1], self.tree[2 * node + 2])

    def find_and_delete(self, node, start, end, l, r, x):
        # Find the leftmost position in [l, r] whose value is >= x, reset it to 0,
        # and return whether such a position was found.
        if start > r or end < l:
            return False
        if start == end:
            if self.tree[node] >= x:
                self.tree[node] = 0
                return True
            return False
        mid = (start + end) // 2
        # The max of this subtree is too small: nothing to find here.
        if self.tree[node] < x:
            return False
        # Prefer the left subtree so that the leftmost match is taken.
        if self.tree[2 * node + 1] >= x:
            if self.find_and_delete(2 * node + 1, start, mid, l, r, x):
                self.tree[node] = max(self.tree[2 * node + 1], self.tree[2 * node + 2])
                return True
        if self.tree[2 * node + 2] >= x:
            if self.find_and_delete(2 * node + 2, mid + 1, end, l, r, x):
                self.tree[node] = max(self.tree[2 * node + 1], self.tree[2 * node + 2])
                return True
        return False
```

Exercise: [LC3721](https://leetcode.com/problems/longest-balanced-subarray-ii/description/)
````

````{dropdown} Segment tree on sets
[LC.3901 Good Subsequence Queries](https://leetcode.com/problems/good-subsequence-queries/)

```py
from math import gcd

class Solution:
    def countGoodSubseq(self, nums: list[int], p: int, queries: list[list[int]]) -> int:
        n = len(nums)

        def trans(x):
            # Multiples of p are divided by p; everything else becomes 0.
            return x // p if x % p == 0 else 0

        size = 1
        while size < n:
            size <<= 1
        # tree_g[i]: gcd of all transformed values below node i.
        tree_g = [0] * (size << 1)
        # tree_rem[i]: distinct gcds obtained by dropping exactly one leaf below node i.
        tree_rem = [() for _ in range(size << 1)]

        def pull(i):
            l = i << 1
            r = l | 1
            gl, gr = tree_g[l], tree_g[r]
            tree_g[i] = gcd(gl, gr)
            tmp = set()
            for x in tree_rem[l]:  # drop one leaf on the left, keep the whole right
                tmp.add(gcd(x, gr))
            for x in tree_rem[r]:  # keep the whole left, drop one leaf on the right
                tmp.add(gcd(gl, x))
            tree_rem[i] = tuple(tmp)

        cnt_div = 0  # number of elements divisible by p
        for i, x in enumerate(nums):
            v = trans(x)
            if v > 0:
                cnt_div += 1
            idx = size + i
            tree_g[idx] = v
            tree_rem[idx] = (0,)
        for i in range(n, size):  # padding leaves
            idx = size + i
            tree_g[idx] = 0
            tree_rem[idx] = ()
        for i in range(size - 1, 0, -1):
            pull(i)

        res = 0
        for idx, val in queries:
            old_v = trans(nums[idx])
            new_v = trans(val)
            if old_v == 0 and new_v > 0:
                cnt_div += 1
            elif old_v > 0 and new_v == 0:
                cnt_div -= 1
            nums[idx] = val
            # Point update, then recompute all ancestors.
            pos = size + idx
            tree_g[pos] = new_v
            tree_rem[pos] = (0,)
            pos >>= 1
            while pos:
                pull(pos)
                pos >>= 1
            if cnt_div == 0:
                continue
            if cnt_div < n:
                if tree_g[1] == 1:
                    res += 1
            else:
                if 1 in tree_rem[1]:
                    res += 1
        return res
```
````

## Exercises

- [LC.3691 Maximum Total Subarray Value II](https://leetcode.com/problems/maximum-total-subarray-value-ii/): Segment tree for range query (`min` and `max`) + heap for top-k
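
The hint for LC.3691 can be sketched as follows, assuming the task is to pick `k` distinct subarrays that maximize the total of `max - min` over the chosen subarrays (this reading of the problem is an assumption, and the function names `build_min_max`, `range_value` and `top_k_subarray_values` are hypothetical). The sketch uses a compact bottom-up segment tree instead of the recursive template above, and relies on the fact that for a fixed left end `l`, `max - min` never decreases as the right end grows.

```py
import heapq


def build_min_max(nums):
    """Bottom-up segment trees for range min and range max (leaves at n..2n-1)."""
    n = len(nums)
    mn = [0] * n + list(nums)
    mx = [0] * n + list(nums)
    for i in range(n - 1, 0, -1):
        mn[i] = min(mn[2 * i], mn[2 * i + 1])
        mx[i] = max(mx[2 * i], mx[2 * i + 1])
    return mn, mx


def range_value(mn, mx, n, l, r):
    """Return max(nums[l..r]) - min(nums[l..r]) for the inclusive range [l, r]."""
    lo, hi = float("inf"), float("-inf")
    l += n
    r += n + 1  # switch to the half-open range [l, r)
    while l < r:
        if l & 1:
            lo, hi = min(lo, mn[l]), max(hi, mx[l])
            l += 1
        if r & 1:
            r -= 1
            lo, hi = min(lo, mn[r]), max(hi, mx[r])
        l >>= 1
        r >>= 1
    return hi - lo


def top_k_subarray_values(nums, k):
    """Sum of the k largest values of max - min over distinct subarrays."""
    n = len(nums)
    mn, mx = build_min_max(nums)
    # For each l, the best subarray is [l, n-1]; heapq is a min-heap, so negate.
    heap = [(-range_value(mn, mx, n, l, n - 1), l, n - 1) for l in range(n)]
    heapq.heapify(heap)
    total = 0
    for _ in range(k):
        if not heap:
            break
        neg, l, r = heapq.heappop(heap)
        total += -neg
        if r > l:
            # The next best candidate with the same left end is [l, r-1].
            heapq.heappush(heap, (-range_value(mn, mx, n, l, r - 1), l, r - 1))
    return total


print(top_k_subarray_values([1, 3, 2], 2))  # values 2 ([1,3,2]) and 2 ([1,3]) -> 4
```

Each heap operation triggers one $O(\log n)$ range query, so the sketch runs in $O((n + k)\log n)$.
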