线段树(Segment Tree)

线段树是一种二叉树结构,用于高效处理区间查询和区间更新。可以在 O(log n) 时间内完成「区间求和」「区间最大值」等操作。

什么是线段树?

线段树是一棵平衡二叉树(每个非叶节点恰好有两个子节点,但不一定是完全二叉树),每个节点代表原数组的一段下标区间,并存储该区间的聚合值(如区间和):

数组: [1, 3, 5, 7, 9, 11, 13, 15]
        (下标 0-7)

线段树结构(每个节点标注「下标区间 = 区间和」):
                       [0,7]=64 (根节点:整个数组)
                     /           \
             [0,3]=16             [4,7]=48
             /      \             /      \
       [0,1]=4  [2,3]=12    [4,5]=20  [6,7]=28
        /  \      /  \        /  \      /  \
       1    3    5    7      9   11   13    15   (叶子节点:单个元素)

为什么是二叉树?

  • 每个节点把区间劈成两半
  • 叶子节点是单个元素
  • 深度约等于 log₂(n)

线段树 vs 树状数组

| 特性 | 线段树 | 树状数组 |
| --- | --- | --- |
| 功能 | 通用(任意可合并的区间操作) | 特定(以前缀和为主) |
| 代码复杂度 | 较高 | 低(几行搞定) |
| 区间更新 | O(log n)(需懒标记) | O(log n)(需差分技巧) |
| 区间查询 | O(log n) | O(log n) |
| 懒更新 | 支持 | 不支持 |
| 灵活性 | 高 | 低 |

选择原则

  • 只需要「单点更新 + 前缀和/区间和」→ 用树状数组,代码最短
  • 需要「区间更新 + 区间查询」,尤其是区间赋值、区间最值这类树状数组难以处理的操作 → 用线段树(配合懒标记)

线段树模板

基本结构

class SegmentTree:
    """Segment tree supporting range-sum queries and point assignment.

    Both operations run in O(log n). To turn it into a max/min tree, change
    the ``+`` combine steps and the ``0`` identity returned for disjoint
    segments (e.g. to ``max`` and ``float('-inf')``).
    """

    def __init__(self, nums):
        self.n = len(nums)
        # 4n slots always suffice for a recursive tree rooted at index 1.
        self.tree = [0] * (4 * self.n)
        if self.n > 0:
            self._build(1, 0, self.n - 1, nums)

    def _build(self, node, start, end, nums):
        """Recursively fill ``tree[node]`` with the sum of ``nums[start..end]``."""
        if start == end:
            self.tree[node] = nums[start]
        else:
            mid = (start + end) // 2
            left_node = node * 2
            right_node = node * 2 + 1

            self._build(left_node, start, mid, nums)
            self._build(right_node, mid + 1, end, nums)

            # Combine step: change to max/min for other aggregate trees.
            self.tree[node] = self.tree[left_node] + self.tree[right_node]

    def query(self, left, right):
        """Return the sum of elements in the inclusive index range [left, right].

        Bounds outside [0, n-1] are effectively clipped; an empty tree or an
        empty range (left > right) yields 0.
        """
        if self.n == 0:
            return 0
        return self._query(1, 0, self.n - 1, left, right)

    def _query(self, node, start, end, left, right):
        """Sum of the part of [left, right] that lies inside node's [start, end]."""
        if left <= start and end <= right:
            # Node's segment lies entirely inside the query range.
            return self.tree[node]

        if end < left or start > right:
            # No overlap: contribute the identity element for sum.
            return 0

        # Partial overlap: split and recurse into both children.
        mid = (start + end) // 2
        left_sum = self._query(node * 2, start, mid, left, right)
        right_sum = self._query(node * 2 + 1, mid + 1, end, left, right)

        return left_sum + right_sum

    def update(self, index, val):
        """Set element ``index`` to ``val``.

        Raises:
            IndexError: if ``index`` is not in [0, n-1]. Without this check an
                out-of-range index would silently overwrite the first or the
                last element, because the descent always ends at some leaf.
        """
        if not 0 <= index < self.n:
            raise IndexError(f"index {index} out of range for size {self.n}")
        self._update(1, 0, self.n - 1, index, val)

    def _update(self, node, start, end, index, val):
        """Descend to the leaf for ``index``, assign it, then refresh ancestors."""
        if start == end:
            self.tree[node] = val
            return

        mid = (start + end) // 2
        if index <= mid:
            self._update(node * 2, start, mid, index, val)
        else:
            self._update(node * 2 + 1, mid + 1, end, index, val)

        # Recompute this node from its (possibly changed) children.
        self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]

LeetCode 实战

题目 1:区域和检索 - 数组不可变

LeetCode 303 - 数组不可变,只有区间求和查询。实际上用前缀和即可做到 O(1) 查询;这里用线段树实现,只为演示模板。

class NumArray:
    """Range-sum queries over an immutable array (LeetCode 303), via a segment tree."""

    def __init__(self, nums):
        self.n = len(nums)
        self.tree = [0] * (4 * self.n)
        if self.n > 0:
            self._build(1, 0, self.n - 1, nums)

    def _build(self, node, l, r, nums):
        """Store sum(nums[l..r]) at ``tree[node]``, building children first."""
        if l == r:
            # Leaf: the segment is a single element.
            self.tree[node] = nums[l]
            return
        half = (l + r) // 2
        lc, rc = 2 * node, 2 * node + 1
        self._build(lc, l, half, nums)
        self._build(rc, half + 1, r, nums)
        self.tree[node] = self.tree[lc] + self.tree[rc]

    def sumRange(self, left, right):
        """Return nums[left] + ... + nums[right] (inclusive)."""
        return self._query(1, 0, self.n - 1, left, right)

    def _query(self, node, l, r, ql, qr):
        """Sum of the overlap between [ql, qr] and node's segment [l, r]."""
        # Disjoint and fully-covered are mutually exclusive, so test disjoint first.
        if r < ql or l > qr:
            return 0
        if ql <= l and r <= qr:
            return self.tree[node]
        half = (l + r) // 2
        left_part = self._query(2 * node, l, half, ql, qr)
        right_part = self._query(2 * node + 1, half + 1, r, ql, qr)
        return left_part + right_part

题目 2:区域和检索 - 数组可修改

LeetCode 307 - 支持单点更新的区间求和。

class NumArray:
    """Range-sum queries with point updates (LeetCode 307), via a segment tree."""

    def __init__(self, nums):
        self.n = len(nums)
        self.tree = [0] * (4 * self.n)
        if self.n > 0:
            self._build(1, 0, self.n - 1, nums)

    def _build(self, node, l, r, nums):
        """Store sum(nums[l..r]) at ``tree[node]``, building children first."""
        if l == r:
            self.tree[node] = nums[l]
            return
        half = (l + r) // 2
        lc, rc = 2 * node, 2 * node + 1
        self._build(lc, l, half, nums)
        self._build(rc, half + 1, r, nums)
        self.tree[node] = self.tree[lc] + self.tree[rc]

    def update(self, index, val):
        """Overwrite element ``index`` with ``val``."""
        self._update(1, 0, self.n - 1, index, val)

    def _update(self, node, l, r, idx, val):
        """Walk down to the leaf for ``idx``, set it, then rebuild the path's sums."""
        if l == r:
            self.tree[node] = val
            return
        half = (l + r) // 2
        lc, rc = 2 * node, 2 * node + 1
        if idx > half:
            self._update(rc, half + 1, r, idx, val)
        else:
            self._update(lc, l, half, idx, val)
        self.tree[node] = self.tree[lc] + self.tree[rc]

    def sumRange(self, left, right):
        """Return nums[left] + ... + nums[right] (inclusive)."""
        return self._query(1, 0, self.n - 1, left, right)

    def _query(self, node, l, r, ql, qr):
        """Sum of the overlap between [ql, qr] and node's segment [l, r]."""
        # Disjoint and fully-covered are mutually exclusive, so test disjoint first.
        if r < ql or l > qr:
            return 0
        if ql <= l and r <= qr:
            return self.tree[node]
        half = (l + r) // 2
        left_part = self._query(2 * node, l, half, ql, qr)
        right_part = self._query(2 * node + 1, half + 1, r, ql, qr)
        return left_part + right_part

懒更新线段树(Lazy Propagation)

为什么需要懒更新?

普通线段树没有区间更新操作,只能对区间内的每个位置逐一做单点更新,代价为 O(k log n)(k 为区间长度),最坏接近 O(n log n)。

懒更新:当某个节点的区间被更新区间完全覆盖时,只修改这个节点的值,并在它身上打一个「懒标记」,暂不更新它的子节点。等以后的查询或更新需要深入访问其子节点时,再把标记下推(push down)给子节点。这样区间更新也只需 O(log n)。

class LazySegmentTree:
    """Range-add / range-sum segment tree with lazy propagation.

    ``tree[node]`` is the sum of node's segment, including every update that
    reached it. ``lazy[node]`` is a per-element addend that is already
    counted in ``tree[node]`` but has not yet been pushed to the children.
    """

    def __init__(self, nums):
        self.n = len(nums)
        self.tree = [0] * (4 * self.n)
        self.lazy = [0] * (4 * self.n)  # pending per-element addend for each node
        if self.n > 0:
            self._build(1, 0, self.n - 1, nums)

    def _build(self, node, l, r, nums):
        """Store sum(nums[l..r]) at ``tree[node]``."""
        if l == r:
            self.tree[node] = nums[l]
        else:
            mid = (l + r) // 2
            self._build(node * 2, l, mid, nums)
            self._build(node * 2 + 1, mid + 1, r, nums)
            self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]

    def _push_down(self, node, l, r):
        """Move node's pending addend onto both children of segment [l, r].

        Only called on internal nodes: a leaf is always either fully covered
        or disjoint, so the callers return before reaching this.
        """
        if self.lazy[node] != 0:
            mid = (l + r) // 2
            left, right = node * 2, node * 2 + 1

            # Left child covers [l, mid]: mid - l + 1 elements.
            self.lazy[left] += self.lazy[node]
            self.tree[left] += self.lazy[node] * (mid - l + 1)

            # Right child covers [mid + 1, r]: r - mid elements.
            self.lazy[right] += self.lazy[node]
            self.tree[right] += self.lazy[node] * (r - mid)

            # The tag now lives on the children.
            self.lazy[node] = 0

    def range_update(self, ql, qr, val):
        """Add ``val`` to every element in the inclusive range [ql, qr].

        A no-op on an empty tree (there is no root segment to update).
        """
        if self.n == 0:
            return
        self._range_update(1, 0, self.n - 1, ql, qr, val)

    def _range_update(self, node, l, r, ql, qr, val):
        """Apply the addend to the overlap of [ql, qr] with node's segment [l, r]."""
        if ql <= l and r <= qr:
            # Fully covered: update this node and defer the children via the tag.
            self.tree[node] += val * (r - l + 1)
            self.lazy[node] += val
            return

        if r < ql or l > qr:
            return

        # Partial overlap: children must be up to date before we recurse.
        self._push_down(node, l, r)

        mid = (l + r) // 2
        self._range_update(node * 2, l, mid, ql, qr, val)
        self._range_update(node * 2 + 1, mid + 1, r, ql, qr, val)

        self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]

    def range_query(self, ql, qr):
        """Return the sum of the inclusive range [ql, qr]; 0 on an empty tree."""
        if self.n == 0:
            return 0
        return self._range_query(1, 0, self.n - 1, ql, qr)

    def _range_query(self, node, l, r, ql, qr):
        """Sum of the overlap between [ql, qr] and node's segment [l, r]."""
        if ql <= l and r <= qr:
            return self.tree[node]

        if r < ql or l > qr:
            return 0

        # Children may hold stale sums until the pending tag is pushed.
        self._push_down(node, l, r)

        mid = (l + r) // 2
        return (self._range_query(node * 2, l, mid, ql, qr) +
                self._range_query(node * 2 + 1, mid + 1, r, ql, qr))

题目 3:我的日程安排表 II

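LeetCode 731 - 每次预订一个半开区间 [start, end)。预订前先查询该区间内已被覆盖的最大次数:若已达到 2,这次预订会造成三重预订,应当拒绝;否则把区间内每个时间点的覆盖次数加 1。坐标范围高达 10^9,因此用字典按需开点(动态开点),并让懒标记留在节点上不下推(标记永久化)。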

class MyCalendarTwo:
    """LeetCode 731: accept a booking unless it would create a triple booking.

    This is a dynamic ("nodes created on demand") max segment tree over the
    integer time points [0, 10**9], stored in dicts so that only visited
    nodes use memory. Lazy tags are permanent and are never pushed down:
    ``lazy[node]`` counts the bookings that fully cover node's interval, and
    ``tree[node]`` is the maximum coverage inside that interval counting only
    tags at ``node`` and below. The root is node 1 so that the children
    ``2 * node`` and ``2 * node + 1`` never collide with their parent.
    """

    _LO, _HI = 0, 10**9  # coordinate range covered by the root

    def __init__(self):
        self.tree = {}
        self.lazy = {}

    def book(self, start, end):
        """Try to add the booking [start, end).

        Returns True and records the booking if no time point would then be
        covered three times. Otherwise returns False and leaves the calendar
        unchanged.
        """
        lo, hi = self._LO, self._HI
        # Integer time points: [start, end) is the closed range [start, end - 1].
        if self.query(lo, hi, start, end - 1) < 2:
            self.update(lo, hi, start, end - 1, 1)
            return True
        return False

    def update(self, nl, nr, ql, qr, val, node=1, l=None, r=None):
        """Add ``val`` to every point of [ql, qr] inside node's interval [nl, nr].

        ``l`` and ``r`` are kept only for compatibility with the original
        signature; they duplicated ``nl``/``nr`` and are ignored.
        """
        if ql <= nl and nr <= qr:
            # Fully covered: record the tag here instead of descending.
            self.tree[node] = self.tree.get(node, 0) + val
            self.lazy[node] = self.lazy.get(node, 0) + val
            return

        if nr < ql or nl > qr:
            return

        mid = (nl + nr) // 2
        self.update(nl, mid, ql, qr, val, node * 2)
        self.update(mid + 1, nr, ql, qr, val, node * 2 + 1)

        # Own permanent tag plus the busier child.
        self.tree[node] = self.lazy.get(node, 0) + max(
            self.tree.get(node * 2, 0),
            self.tree.get(node * 2 + 1, 0),
        )

    def query(self, nl, nr, ql, qr, node=1, l=None, r=None):
        """Return the maximum coverage over [ql, qr] inside node's interval [nl, nr].

        The result does not include tags stored on ancestors; the caller (or
        the recursion) adds those. ``l`` and ``r`` are ignored, as in ``update``.
        """
        if ql <= nl and nr <= qr:
            return self.tree.get(node, 0)

        if nr < ql or nl > qr:
            return 0

        mid = (nl + nr) // 2
        # Tags on this node apply to every point below it.
        return self.lazy.get(node, 0) + max(
            self.query(nl, mid, ql, qr, node * 2),
            self.query(mid + 1, nr, ql, qr, node * 2 + 1),
        )

线段树的变体

1. 区间最大值线段树

把合并时的 + 换成 max 即可;注意查询时与查询区间不相交的分支要返回 float('-inf') 而不是 0,否则数组元素全为负数时结果会出错。

def _build(self, node, l, r, nums):
    if l == r:
        self.tree[node] = nums[l]
    else:
        mid = (l + r) // 2
        self._build(node * 2, l, mid, nums)
        self._build(node * 2 + 1, mid + 1, r, nums)
        self.tree[node] = max(self.tree[node * 2], self.tree[node * 2 + 1])

2. 区间赋值线段树

不是给区间加值,而是把区间内所有元素赋值为同一个固定值。懒标记数组要初始化为 None 表示「无标记」(0 本身是合法的赋值,不能用来表示无标记):

def _push_down(self, node, l, r):
    if self.lazy[node] is not None:
        mid = (l + r) // 2
        self.tree[node * 2] = self.lazy[node]
        self.tree[node * 2 + 1] = self.lazy[node]
        self.lazy[node * 2] = self.lazy[node]
        self.lazy[node * 2 + 1] = self.lazy[node]
        self.lazy[node] = None

💡 小结

线段树的核心操作

| 操作 | 普通线段树 | 懒更新线段树 |
| --- | --- | --- |
| 单点查询 | O(log n) | O(log n) |
| 单点更新 | O(log n) | O(log n) |
| 区间查询 | O(log n) | O(log n) |
| 区间更新 | O(k log n) | O(log n) |

(k 为更新区间的长度:普通线段树只能逐点更新。)

适用场景

  • 区间求和、最大值、最小值
  • 区间加法、赋值
  • 统计覆盖次数(731)
  • 动态第 K 大(权值线段树);静态区间第 K 大则需要可持久化线段树(主席树)

代码模板

class SegmentTree:
    """Compact range-sum segment tree template (build + query)."""

    def __init__(self, nums):
        self.n = len(nums)
        self.tree = [0] * (4 * self.n)
        # An empty input has no root segment; building it would recurse forever.
        if self.n > 0:
            self._build(nums)

    def _build(self, nums, node=1, l=0, r=None):
        """Store sum(nums[l..r]) at ``tree[node]``; ``r`` defaults to n - 1."""
        if r is None:
            r = self.n - 1
        if l == r:
            self.tree[node] = nums[l]
        else:
            mid = (l + r) // 2
            self._build(nums, node * 2, l, mid)
            self._build(nums, node * 2 + 1, mid + 1, r)
            self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]

    def query(self, ql, qr, node=1, l=0, r=None):
        """Return the sum over the inclusive range [ql, qr].

        Returns 0 for an empty tree or an empty/disjoint range. ``node``,
        ``l`` and ``r`` are recursion state; callers normally omit them.
        """
        if self.n == 0:
            return 0
        if r is None:
            r = self.n - 1
        if ql <= l and r <= qr:
            return self.tree[node]
        if r < ql or l > qr:
            return 0
        mid = (l + r) // 2
        return (self.query(ql, qr, node * 2, l, mid) +
                self.query(ql, qr, node * 2 + 1, mid + 1, r))

相关阅读:树状数组(Fenwick Tree)