Segment Tree

TODO

What is a segment tree?

Imagine that you are given an array of integers A of length n. It has \({{n+1} \choose 2} = n(n+1)/2\) non-empty subarrays in total, each written as A[i:j+1] with 0<=i<=j<n. To maintain some statistic (e.g., sum, min, or max) of each subarray for later queries, there are two straightforward solutions:

  • If we naively store the statistic for all \(n(n+1)/2\) subarrays, we need \(O(n^2)\) space, and a query on a given interval [i,j] costs \(O(1)\) time. However, updating a single element of A takes up to \(O(n^2)\) time, because every interval containing that element must be brought up to date.

  • Conversely, if we keep only the individual elements A[i], in other words the n intervals of length 1, exactly as a built-in array does, we need only \(O(n)\) space and \(O(1)\) time to update any single element. But each interval query now costs \(O(n)\) time.
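
The two extremes above can be sketched as follows, tracking sums only. This is a minimal illustration rather than part of the template; the class names `AllIntervals` and `PlainArray` are hypothetical.

```python
class AllIntervals:
    """Stores the sum of every interval [i, j]: O(n^2) space, O(1) query."""

    def __init__(self, arr):
        self.arr = list(arr)
        n = len(self.arr)
        # table[i][j] holds sum(arr[i..j]) for i <= j; entries with j < i stay 0.
        self.table = [[0] * n for _ in range(n)]
        for i in range(n):
            running = 0
            for j in range(i, n):
                running += self.arr[j]
                self.table[i][j] = running

    def query(self, i, j):
        return self.table[i][j]

    def update(self, idx, val):
        delta = val - self.arr[idx]
        self.arr[idx] = val
        # Every interval [i, j] with i <= idx <= j must be patched: O(n^2) work.
        for i in range(idx + 1):
            for j in range(idx, len(self.arr)):
                self.table[i][j] += delta


class PlainArray:
    """Stores only the elements: O(n) space, O(1) update, O(n) query."""

    def __init__(self, arr):
        self.arr = list(arr)

    def query(self, i, j):
        return sum(self.arr[i:j + 1])

    def update(self, idx, val):
        self.arr[idx] = val


a = AllIntervals(range(8))
b = PlainArray(range(8))
print(a.query(1, 7), b.query(1, 7))  # 28 28
```

Both classes answer the same questions; they only differ in where the cost is paid.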

As you can see, this is a trade-off between space and time, and also between the costs of construction, updates, and queries. A segment tree balances these costs: it can be built in \(O(n)\) time and space, and it supports both updates and queries in \(O(\log n)\) time each, which makes it suitable when many of both operations are interleaved.

We build a segment tree from an array by recursively dividing every subarray [left,right] into two smaller subarrays as [left,mid] and [mid+1,right], where mid=(left+right)//2. Every subarray [left,right] is represented as a TreeNode, with left child [left,mid] and right child [mid+1,right]. For example, assume we have an array A initialized as:

A = [0, 1, 2, 3, 4, 5, 6, 7]

First, we build the segment tree shown in the picture below and track the sum of every subarray represented by a TreeNode. We start with a special case in which the segment tree is a perfect binary tree:

We create 2n-1 TreeNodes in total, one for each subarray produced by the recursive division. The leaf nodes store single elements (intervals [i,i]). Every query or update request is handled by a post-order DFS starting from the root of the tree (the whole array [0,n-1]). In the example above, the sum of each TreeNode is computed by adding the sums of its two children. Therefore, to update a single element A[i], we walk down from the root, at each TreeNode moving to the child whose interval contains index i, until we reach the leaf [i,i]. After a child is modified, the sum of its parent must be recomputed from the updated sums of its two children on the way back up.

Note

Even though each node in the picture is depicted as a subarray, that is only for demonstration. A TreeNode never stores the elements of its interval [i,j]; it only keeps some information about the interval (e.g., the sum in this case) and pointers to its children, which takes \(O(1)\) space per TreeNode.
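
The recursive division and the bottom-up recomputation of sums described above can be sketched with explicit nodes. This is a minimal sketch, not the template used later on this page; the field names `lo`, `hi`, `total` and the helpers `build_tree`, `update_tree`, and `count_nodes` are illustrative assumptions.

```python
class TreeNode:
    """One interval [lo, hi] with its sum; O(1) space per node."""

    def __init__(self, lo, hi):
        self.lo = lo
        self.hi = hi
        self.total = 0
        self.left = None
        self.right = None


def build_tree(arr, lo, hi):
    node = TreeNode(lo, hi)
    if lo == hi:
        node.total = arr[lo]  # leaf: a single element
        return node
    mid = (lo + hi) // 2
    node.left = build_tree(arr, lo, mid)
    node.right = build_tree(arr, mid + 1, hi)
    node.total = node.left.total + node.right.total
    return node


def update_tree(node, idx, val):
    if node.lo == node.hi:
        node.total = val
        return
    mid = (node.lo + node.hi) // 2
    # Descend into the only child whose interval contains idx.
    if idx <= mid:
        update_tree(node.left, idx, val)
    else:
        update_tree(node.right, idx, val)
    # Recompute this node's sum on the way back up.
    node.total = node.left.total + node.right.total


def count_nodes(node):
    if node is None:
        return 0
    return 1 + count_nodes(node.left) + count_nodes(node.right)


root = build_tree([0, 1, 2, 3, 4, 5, 6, 7], 0, 7)
print(root.total, count_nodes(root))  # 28 15
update_tree(root, 3, 10)
print(root.total)  # 35
```

For n = 8 the tree has 2n-1 = 15 nodes, and an update touches only the nodes on one root-to-leaf path.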

Let’s now query the sum of an interval [i,j], which may not be represented exactly by any single TreeNode in the tree. For example, to query sum(A[1:8]), i.e., the interval [1,7], we traverse the tree as follows:

Starting from the root node, we apply these rules at each node:

  • If the interval of the current node is completely contained in [1,7], just return its sum.

  • Otherwise, if the left half of the current node’s interval overlaps [1,7], search the left child and add its returned sum to the current result.

  • Similarly, if the right half of the current node’s interval overlaps [1,7], search the right child and add its returned sum to the current result.
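
To see which nodes such a traversal actually uses, the rules above can be traced with a small helper. This is an illustrative sketch: the function name `covering_intervals` is hypothetical, and it assumes L <= R and that [L,R] lies inside the root interval.

```python
def covering_intervals(l, r, L, R):
    """Return the node intervals whose sums are added when querying [L, R].

    (l, r) is the interval of the current node; nodes are split at
    mid = (l + r) // 2, the same rule used throughout this page.
    """
    if L <= l and r <= R:
        return [(l, r)]  # node fully inside the query: use its stored sum
    mid = (l + r) // 2
    parts = []
    if L <= mid:  # the query overlaps the left half
        parts += covering_intervals(l, mid, L, R)
    if R > mid:  # the query overlaps the right half
        parts += covering_intervals(mid + 1, r, L, R)
    return parts


A = [0, 1, 2, 3, 4, 5, 6, 7]
parts = covering_intervals(0, 7, 1, 7)
print(parts)  # [(1, 1), (2, 3), (4, 7)]
print(sum(sum(A[a:b + 1]) for a, b in parts))  # 28
```

The query [1,7] is answered from three stored sums instead of seven elements.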

Template (for fixed n-length array)

Initialize

If the array has a fixed length n, the segment tree can be stored in a 1-D array instead of linked TreeNodes. The height of the tree is \(\lceil \log_2 n \rceil\), so it has at most \(2^{\lceil \log_2 n \rceil + 1} - 1 \le 4n - 1\) node slots. Thus, we usually allocate an array of length 4n to hold the nodes of the segment tree.
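
The 4n bound can be checked empirically. The sketch below assumes 1-based node numbering with children at 2*i and 2*i+1 and the usual split at mid = (l + r) // 2; the helper name `max_index_used` is hypothetical.

```python
def max_index_used(root, l, r):
    """Largest array index occupied by the subtree rooted at `root` over [l, r]."""
    if l == r:
        return root
    mid = (l + r) // 2
    return max(max_index_used(2 * root, l, mid),
               max_index_used(2 * root + 1, mid + 1, r))


for n in (1, 2, 5, 8, 1000):
    # The largest index always stays below 4 * n, so an array of length 4n suffices.
    print(n, max_index_used(1, 0, n - 1), 4 * n)
```

For n = 8 the largest index is 15, which matches the 2n-1 nodes of a perfect tree; for other n the tree is not perfect and some slots below 4n stay unused.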

For example, if we are only interested in info including sum, min and max of each interval, we create a 4n-length array for each of them.

class SegTree:
    def __init__(self, n:int):
        self.n = n
        self.sum = [0]*(4*n)
        self.min = [0]*(4*n)
        self.max = [0]*(4*n)
        self.lazy = [0]*(4*n)  # pending range additions, used by push_down and update_range

Recall how a linear array nodes represents a binary tree (indexed from 1): if a node is stored at nodes[i], its left child is stored at nodes[2*i] and its right child at nodes[2*i+1].

Build from Array

Then, if we start from an existing array and want to build the segment tree from it, we use build(self, root:int,l:int, r:int, arr:list) defined below:

    def build(self, root:int,l:int, r:int, arr:list) -> None:
        if l == r:
            self.sum[root] = arr[l]
            self.min[root] = arr[l]
            self.max[root] = arr[l]
            return
        mid = (l+r)//2
        self.build(root*2, l, mid, arr)
        self.build(root*2+1, mid+1, r, arr)
        self.sum[root] = self.sum[root*2] + self.sum[root*2+1]
        self.min[root] = min(self.min[root*2], self.min[root*2+1])
        self.max[root] = max(self.max[root*2], self.max[root*2+1])
  • root: current root node of the subtree to build, represented as an int

  • l and r: the subarray arr[l:r+1] that we need to build the segment tree

  • Time complexity: \(O(n)\)

To initialize the segment tree from a known array arr, call self.build(1,0,n-1,arr). Note that the tree is always rooted at node 1, because the child-index formula above requires it: with the root at index 0, the left child 2*0 would be the root itself.

Update (Single Element)

To update the value of a single element in the array, we need to visit all intervals that contain that element; these form a path from the root node to a leaf.

    def update(self, root:int, l:int, r:int, idx:int, val:int) -> None:
        if l == r:
            self.sum[root] = val
            self.min[root] = val
            self.max[root] = val
            return
        mid = (l+r)//2
        if idx <= mid:
            self.update(root*2, l, mid, idx, val)
        else:
            self.update(root*2+1, mid+1, r, idx, val)
        self.sum[root] = self.sum[root*2] + self.sum[root*2+1]
        self.min[root] = min(self.min[root*2], self.min[root*2+1])
        self.max[root] = max(self.max[root*2], self.max[root*2+1])
  • root: current root node of the subtree that represents the current interval

  • l and r: the endpoints of the current interval

  • idx: the index of element in the raw array to update

  • val: the value used to update the element at idx

  • Time Complexity: \(O(\log n)\)

  • Entry: self.update(root=1,l=0,r=n-1,idx=idx,val=val)

Query

To query some data (e.g., the sum here) of an interval [L,R], we search the intervals level by level:

  1. If the current interval [l,r] is entirely contained in [L,R], just return its data.

  2. If not, recursively query the left child ([l,mid]) and the right child ([mid+1,r]), skipping any child that does not overlap [L,R], and combine the results.

    def query_sum(self, root:int, l:int, r:int, L:int, R:int) -> int:
        if l >= L and r <= R:
            return self.sum[root]
        mid = (l + r) // 2
        res = 0
        if L <= mid:
            res += self.query_sum(root*2, l, mid, L, R)
        if R > mid:
            res += self.query_sum(root*2+1, mid+1, r, L, R)
        return res
  • root: current root node of the subtree that represents the current interval

  • l and r: the endpoints of the current interval

  • L and R: the endpoints of the interval to query

  • Time Complexity: \(O(\log n)\)

  • Entry: self.query_sum(root=1,l=0,r=n-1,L=L,R=R)
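
Putting build, point update, and query together, an end-to-end run might look like the sketch below. It is a compact, sum-only variant of the template; the class name `SumSegTree` and the default-argument wrappers are illustrative assumptions, and it assumes a non-empty array.

```python
class SumSegTree:
    """Array-based segment tree for sums, rooted at index 1."""

    def __init__(self, arr):
        self.n = len(arr)
        self.sum = [0] * (4 * self.n)
        self._build(1, 0, self.n - 1, arr)

    def _build(self, root, l, r, arr):
        if l == r:
            self.sum[root] = arr[l]
            return
        mid = (l + r) // 2
        self._build(2 * root, l, mid, arr)
        self._build(2 * root + 1, mid + 1, r, arr)
        self.sum[root] = self.sum[2 * root] + self.sum[2 * root + 1]

    def update(self, idx, val, root=1, l=0, r=None):
        """Set element idx to val."""
        if r is None:
            r = self.n - 1
        if l == r:
            self.sum[root] = val
            return
        mid = (l + r) // 2
        if idx <= mid:
            self.update(idx, val, 2 * root, l, mid)
        else:
            self.update(idx, val, 2 * root + 1, mid + 1, r)
        self.sum[root] = self.sum[2 * root] + self.sum[2 * root + 1]

    def query(self, L, R, root=1, l=0, r=None):
        """Sum of elements L..R inclusive."""
        if r is None:
            r = self.n - 1
        if L <= l and r <= R:
            return self.sum[root]
        mid = (l + r) // 2
        res = 0
        if L <= mid:
            res += self.query(L, R, 2 * root, l, mid)
        if R > mid:
            res += self.query(L, R, 2 * root + 1, mid + 1, r)
        return res


tree = SumSegTree([0, 1, 2, 3, 4, 5, 6, 7])
print(tree.query(1, 7))  # 28
tree.update(3, 10)
print(tree.query(1, 7))  # 35
print(tree.query(0, 2))  # 3
```

Hiding root, l, and r behind default arguments keeps call sites short while the recursion itself stays identical to the template.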

Lazy Propagation

To update a range of elements in the array, we don’t have to update every single element in the range. Instead, we record the update on the TreeNode that covers the interval and mark it as lazy, which indicates that the update has not yet been propagated to its descendants. When we later visit a TreeNode, we first check whether it is lazy and, if so, apply the pending update to the node and pass it on to its children.

    def push_down(self, root:int, l:int, r:int) -> None:
        if self.lazy[root] != 0:
            self.sum[root] += (r-l+1) * self.lazy[root]
            self.min[root] += self.lazy[root]
            self.max[root] += self.lazy[root]
            if l != r:
                self.lazy[root*2] += self.lazy[root]
                self.lazy[root*2+1] += self.lazy[root]
            self.lazy[root] = 0
  • root: current root node of the subtree that represents the current interval

  • l and r: the endpoints of the current interval

  • Time Complexity: \(O(1)\)

  • Usage: called as self.push_down(root, l, r) on each node visited during a range update or query

Update Range

To update a range of elements in the array, we can use the lazy propagation technique. We only need to update the current node and mark it as lazy, without updating its children immediately.

    def update_range(self, root:int, l:int, r:int, L:int, R:int, val:int) -> None:
        # Apply any pending update first, so that this node is current even
        # when it lies outside [L,R] and its parent re-reads it below.
        self.push_down(root, l, r)
        if l > R or r < L:
            return
        if l >= L and r <= R:
            self.lazy[root] += val
            self.push_down(root, l, r)
            return
        mid = (l + r) // 2
        self.update_range(root*2, l, mid, L, R, val)
        self.update_range(root*2+1, mid+1, r, L, R, val)
        self.sum[root] = self.sum[root*2] + self.sum[root*2+1]
        self.min[root] = min(self.min[root*2], self.min[root*2+1])
        self.max[root] = max(self.max[root*2], self.max[root*2+1])
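
A range query on a lazy tree also has to apply pending updates before it reads a node. The sketch below is a self-contained, sum-only variant under that assumption: the class name `LazySumTree` is hypothetical, `lazy[root]` holds an addition not yet applied to `root` itself, and every visit pushes pending additions down first so that a parent is always recomputed from up-to-date children.

```python
class LazySumTree:
    """Range add / range sum with lazy propagation, rooted at index 1."""

    def __init__(self, arr):
        self.n = len(arr)
        self.sum = [0] * (4 * self.n)
        self.lazy = [0] * (4 * self.n)
        self._build(1, 0, self.n - 1, arr)

    def _build(self, root, l, r, arr):
        if l == r:
            self.sum[root] = arr[l]
            return
        mid = (l + r) // 2
        self._build(2 * root, l, mid, arr)
        self._build(2 * root + 1, mid + 1, r, arr)
        self.sum[root] = self.sum[2 * root] + self.sum[2 * root + 1]

    def _push_down(self, root, l, r):
        if self.lazy[root] != 0:
            # Apply the pending addition to this node, then hand it to the children.
            self.sum[root] += (r - l + 1) * self.lazy[root]
            if l != r:
                self.lazy[2 * root] += self.lazy[root]
                self.lazy[2 * root + 1] += self.lazy[root]
            self.lazy[root] = 0

    def update_range(self, L, R, val, root=1, l=0, r=None):
        """Add val to every element in L..R inclusive."""
        if r is None:
            r = self.n - 1
        self._push_down(root, l, r)
        if l > R or r < L:
            return
        if L <= l and r <= R:
            self.lazy[root] += val
            self._push_down(root, l, r)
            return
        mid = (l + r) // 2
        self.update_range(L, R, val, 2 * root, l, mid)
        self.update_range(L, R, val, 2 * root + 1, mid + 1, r)
        self.sum[root] = self.sum[2 * root] + self.sum[2 * root + 1]

    def query_sum(self, L, R, root=1, l=0, r=None):
        """Sum of elements L..R inclusive."""
        if r is None:
            r = self.n - 1
        if l > R or r < L:
            return 0
        self._push_down(root, l, r)
        if L <= l and r <= R:
            return self.sum[root]
        mid = (l + r) // 2
        return (self.query_sum(L, R, 2 * root, l, mid)
                + self.query_sum(L, R, 2 * root + 1, mid + 1, r))


t = LazySumTree([0, 1, 2, 3, 4, 5, 6, 7])
t.update_range(0, 3, 5)
t.update_range(0, 1, 1)
print(t.query_sum(0, 3))  # 28
print(t.query_sum(2, 5))  # 24
```

Both operations visit \(O(\log n)\) nodes, since a fully covered node is handled without descending into its children.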

Read More

See [1]

Comparison with Binary Indexed Tree (Fenwick Tree)

A BIT is a data structure that can also be used to maintain the prefix sums of an array. It is more space-efficient than a segment tree, as it only requires an array of about n entries rather than 4n. In its basic form it supports point updates and prefix-sum queries, so a range sum is obtained as the difference of two prefix sums. However, it does not directly handle operations that have no inverse, such as range min or max, or general lazy range updates. Therefore, if you need such range updates or range queries, a segment tree is more suitable.
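
For comparison, a basic BIT can be sketched as follows. This is a minimal illustration with hypothetical method names (`add`, `prefix_sum`, `range_sum`); it uses 0-based indices at the interface and 1-based indices internally.

```python
class FenwickTree:
    """Point add and prefix sum in O(log n), using n + 1 array entries."""

    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 1)  # index 0 is unused

    def add(self, idx, delta):
        """Add delta to element idx (0-based)."""
        i = idx + 1
        while i <= self.n:
            self.tree[i] += delta
            i += i & (-i)  # move to the next node responsible for idx

    def prefix_sum(self, idx):
        """Sum of elements 0..idx inclusive; idx = -1 gives 0."""
        i = idx + 1
        total = 0
        while i > 0:
            total += self.tree[i]
            i -= i & (-i)  # strip the lowest set bit
        return total

    def range_sum(self, l, r):
        """Sum of elements l..r inclusive, as a difference of two prefix sums."""
        return self.prefix_sum(r) - self.prefix_sum(l - 1)


bit = FenwickTree(8)
for i, x in enumerate([0, 1, 2, 3, 4, 5, 6, 7]):
    bit.add(i, x)
print(bit.range_sum(1, 7))  # 28
```

The subtraction in `range_sum` works because addition has an inverse; the same trick is not available for min or max.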

Alternative Templates

Some subtle modifications can be made to the template above to fit different scenarios.

With hashmap: query how many times an element appears in a range
class SegmentTree:
    def __init__(self, nums):
        self.n = len(nums)
        self.tree = [{} for _ in range(4 * self.n)]
        self.build(nums, 0, 0, self.n - 1)
    
    def build(self, nums, node, l, r):
        if l == r:
            self.tree[node] = {nums[l]: 1}
            return
        mid = (l + r) // 2
        self.build(nums, 2*node+1, l, mid)
        self.build(nums, 2*node+2, mid+1, r)
        self.tree[node] = self.merge(self.tree[2*node+1], self.tree[2*node+2])
    
    def merge(self, left, right):
        merged = left.copy()
        for key, val in right.items():
            merged[key] = merged.get(key, 0) + val
        return merged

    def update(self, index, old_val, new_val, node=0, l=0, r=None):
        if r is None:
            r = self.n - 1
        if index < l or index > r:
            return
        if old_val in self.tree[node]:
            self.tree[node][old_val] -= 1
            if self.tree[node][old_val] == 0:
                del self.tree[node][old_val]
        self.tree[node][new_val] = self.tree[node].get(new_val, 0) + 1
        if l != r:
            mid = (l + r) // 2
            self.update(index, old_val, new_val, 2*node+1, l, mid)
            self.update(index, old_val, new_val, 2*node+2, mid+1, r)

    def query(self, ql, qr, k, node=0, l=0, r=None):
        if r is None:
            r = self.n - 1
        if qr < l or ql > r:
            return 0
        if ql <= l and r <= qr:
            return self.tree[node].get(k, 0)
        mid = (l + r) // 2
        return self.query(ql, qr, k, 2*node+1, l, mid) + self.query(ql, qr, k, 2*node+2, mid+1, r)
Binary search on a segment tree

Binary search to find the first occurrence \(\ge k\)

A template question is LC3479: let each node store the max value of its range; first check whether the left subtree contains a value of at least k, and only if it does not, check the right subtree.

class SegmentTree:
    def __init__(self, data):
        self.n = len(data)
        self.tree = [0] * (4 * self.n)
        self.build(data, 0, 0, self.n - 1)

    def build(self, data, node, start, end):
        if start == end:
            self.tree[node] = data[start]  # leaf stores the element value
        else:
            mid = (start + end) // 2
            self.build(data, 2 * node + 1, start, mid)
            self.build(data, 2 * node + 2, mid + 1, end)
            self.tree[node] = max(self.tree[2 * node + 1], self.tree[2 * node + 2])


    def find_and_delete(self, node, start, end, l, r, x):
        # Find the lowest index in [l, r] whose value is >= x, reset that value
        # to 0, and return True; return False if no such index exists.
        if start > r or end < l:
            return False
        if start == end:
            if self.tree[node] >= x:
                self.tree[node] = 0
                return True
            return False
        mid = (start + end) // 2
        if self.tree[node] < x:
            return False
        if self.tree[2 * node + 1] >= x:
            if self.find_and_delete(2 * node + 1, start, mid, l, r, x):
                self.tree[node] = max(self.tree[2 * node + 1], self.tree[2 * node + 2])
                return True
        if self.tree[2 * node + 2] >= x:
            if self.find_and_delete(2 * node + 2, mid + 1, end, l, r, x):
                self.tree[node] = max(self.tree[2 * node + 1], self.tree[2 * node + 2])
                return True
        return False
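
The descent itself can be isolated into a function that only reports the index, without deleting anything. This is a simplified sketch of the same idea over the whole array; the class name `MaxTree` and the method `find_first` are illustrative assumptions, and the root is stored at index 0 as in the class above.

```python
class MaxTree:
    """Max segment tree supporting 'first index with value >= x'."""

    def __init__(self, data):
        self.n = len(data)
        self.tree = [0] * (4 * self.n)
        self._build(data, 0, 0, self.n - 1)

    def _build(self, data, node, lo, hi):
        if lo == hi:
            self.tree[node] = data[lo]
            return
        mid = (lo + hi) // 2
        self._build(data, 2 * node + 1, lo, mid)
        self._build(data, 2 * node + 2, mid + 1, hi)
        self.tree[node] = max(self.tree[2 * node + 1], self.tree[2 * node + 2])

    def find_first(self, x, node=0, lo=0, hi=None):
        """Smallest index whose value is >= x, or -1 if there is none."""
        if hi is None:
            hi = self.n - 1
        if self.tree[node] < x:
            return -1  # nothing in this subtree is large enough
        if lo == hi:
            return lo
        mid = (lo + hi) // 2
        # Prefer the left half; the max stored there tells us whether it can succeed.
        if self.tree[2 * node + 1] >= x:
            return self.find_first(x, 2 * node + 1, lo, mid)
        return self.find_first(x, 2 * node + 2, mid + 1, hi)


t = MaxTree([3, 5, 4])
print(t.find_first(4))  # 1
print(t.find_first(3))  # 0
print(t.find_first(6))  # -1
```

Each call descends exactly one level, so the search takes \(O(\log n)\) time, in contrast to a binary search that issues a separate range-max query at every step.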

Exercise: LC3721

Segment tree on sets

LC.3901 Good Subsequence Queries

from math import gcd

class Solution:
    def countGoodSubseq(self, nums: list[int], p: int, queries: list[list[int]]) -> int:
        n = len(nums)

        def trans(x):
            return x // p if x % p == 0 else 0

        size = 1
        while size < n:
            size <<= 1

        tree_g = [0] * (size << 1)
        tree_rem = [() for _ in range(size << 1)]

        def pull(i):
            l = i << 1
            r = l | 1
            gl, gr = tree_g[l], tree_g[r]
            tree_g[i] = gcd(gl, gr)

            tmp = set()
            for x in tree_rem[l]:
                tmp.add(gcd(x, gr))
            for x in tree_rem[r]:
                tmp.add(gcd(gl, x))
            tree_rem[i] = tuple(tmp)

        cnt_div = 0
        for i, x in enumerate(nums):
            v = trans(x)
            if v > 0:
                cnt_div += 1
            idx = size + i
            tree_g[idx] = v
            tree_rem[idx] = (0,)

        for i in range(n, size):
            idx = size + i
            tree_g[idx] = 0
            tree_rem[idx] = ()

        for i in range(size - 1, 0, -1):
            pull(i)

        res = 0

        for idx, val in queries:
            old_v = trans(nums[idx])
            new_v = trans(val)

            if old_v == 0 and new_v > 0:
                cnt_div += 1
            elif old_v > 0 and new_v == 0:
                cnt_div -= 1

            nums[idx] = val

            pos = size + idx
            tree_g[pos] = new_v
            tree_rem[pos] = (0,)

            pos >>= 1
            while pos:
                pull(pos)
                pos >>= 1

            if cnt_div == 0:
                continue

            if cnt_div < n:
                if tree_g[1] == 1:
                    res += 1
            else:
                if 1 in tree_rem[1]:
                    res += 1

        return res

Exercises