线段树是一种二叉树结构,用于高效处理区间查询和区间更新。可以在 O(log n) 时间内完成「区间求和」「区间最大值」等操作。
什么是线段树?
线段树是一种完全二叉树,每个节点代表一个区间:
数组: [1, 3, 5, 7, 9, 11, 13, 15]
(下标 0-7)
线段树结构:
[1, 15] (根节点:整个区间)
/ \
[1, 8] [9, 15]
/ \ / \
[1, 4] [5, 8] [9, 12] [13, 15]
/ \ / \ / \ / \
[1,2] [3,4][5,6][7,8][9,10][11,12][13,14][15,15]
/ \ / \ / \ / \ / \ / \ / \
[1][3] [5][7][9][11][13][15] ...
为什么是二叉树?
- 每个节点把区间劈成两半
- 叶子节点是单个元素
- 深度约等于 log₂(n)
线段树 vs 树状数组
| 特性 | 线段树 | 树状数组 |
|---|---|---|
| 功能 | 通用(任意区间操作) | 特定(前缀和为主) |
| 代码复杂度 | 较高 | 低(几行搞定) |
| 区间更新 | O(log n) | O(log n) |
| 区间查询 | O(log n) | O(log n) |
| 懒更新 | 支持 | 不支持 |
| 灵活性 | 高 | 低 |
选择原则:
- 只会「单点更新 + 前缀和」→ 用树状数组
- 需要「区间更新 + 区间查询」→ 必须用线段树
线段树模板
基本结构
class SegmentTree:
def __init__(self, nums):
self.n = len(nums)
# 4倍空间:线段树节点数不超过 4n
self.tree = [0] * (4 * self.n)
if self.n > 0:
self._build(1, 0, self.n - 1, nums)
def _build(self, node, start, end, nums):
"""递归构建线段树"""
if start == end:
self.tree[node] = nums[start]
else:
mid = (start + end) // 2
left_node = node * 2
right_node = node * 2 + 1
self._build(left_node, start, mid, nums)
self._build(right_node, mid + 1, end, nums)
# 根据需求修改:求和/最大值/最小值
self.tree[node] = self.tree[left_node] + self.tree[right_node]
def query(self, left, right):
"""区间查询 [left, right]"""
return self._query(1, 0, self.n - 1, left, right)
def _query(self, node, start, end, left, right):
if left <= start and end <= right:
# 当前区间完全在查询范围内
return self.tree[node]
if end < left or start > right:
# 完全不在范围内
return 0
# 部分重叠,递归查询
mid = (start + end) // 2
left_sum = self._query(node * 2, start, mid, left, right)
right_sum = self._query(node * 2 + 1, mid + 1, end, left, right)
return left_sum + right_sum
def update(self, index, val):
"""单点更新"""
self._update(1, 0, self.n - 1, index, val)
def _update(self, node, start, end, index, val):
if start == end:
self.tree[node] = val
return
mid = (start + end) // 2
if index <= mid:
self._update(node * 2, start, mid, index, val)
else:
self._update(node * 2 + 1, mid + 1, end, index, val)
self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]
LeetCode 实战
题目 1:区域和检索 - 数组不可变
LeetCode 303 - 前缀和变种,用线段树(或直接前缀和)。
class NumArray:
def __init__(self, nums):
self.n = len(nums)
self.tree = [0] * (4 * self.n)
if self.n > 0:
self._build(1, 0, self.n - 1, nums)
def _build(self, node, l, r, nums):
if l == r:
self.tree[node] = nums[l]
else:
mid = (l + r) // 2
self._build(node * 2, l, mid, nums)
self._build(node * 2 + 1, mid + 1, r, nums)
self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]
def sumRange(self, left, right):
return self._query(1, 0, self.n - 1, left, right)
def _query(self, node, l, r, ql, qr):
if ql <= l and r <= qr:
return self.tree[node]
if r < ql or l > qr:
return 0
mid = (l + r) // 2
return (self._query(node * 2, l, mid, ql, qr) +
self._query(node * 2 + 1, mid + 1, r, ql, qr))
题目 2:区域和检索 - 数组可修改
LeetCode 307 - 支持单点更新的区间求和。
class NumArray:
def __init__(self, nums):
self.n = len(nums)
self.tree = [0] * (4 * self.n)
if self.n > 0:
self._build(1, 0, self.n - 1, nums)
def _build(self, node, l, r, nums):
if l == r:
self.tree[node] = nums[l]
else:
mid = (l + r) // 2
self._build(node * 2, l, mid, nums)
self._build(node * 2 + 1, mid + 1, r, nums)
self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]
def update(self, index, val):
self._update(1, 0, self.n - 1, index, val)
def _update(self, node, l, r, idx, val):
if l == r:
self.tree[node] = val
else:
mid = (l + r) // 2
if idx <= mid:
self._update(node * 2, l, mid, idx, val)
else:
self._update(node * 2 + 1, mid + 1, r, idx, val)
self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]
def sumRange(self, left, right):
return self._query(1, 0, self.n - 1, left, right)
def _query(self, node, l, r, ql, qr):
if ql <= l and r <= qr:
return self.tree[node]
if r < ql or l > qr:
return 0
mid = (l + r) // 2
return (self._query(node * 2, l, mid, ql, qr) +
self._query(node * 2 + 1, mid + 1, r, ql, qr))
懒更新线段树(Lazy Propagation)
为什么需要懒更新?
普通线段树的「区间更新」是 O(n) 的:每个节点都要更新。
懒更新:先记录这个更新,等查询时再真正应用。
class LazySegmentTree:
def __init__(self, nums):
self.n = len(nums)
self.tree = [0] * (4 * self.n)
self.lazy = [0] * (4 * self.n) # 懒标记
if self.n > 0:
self._build(1, 0, self.n - 1, nums)
def _build(self, node, l, r, nums):
if l == r:
self.tree[node] = nums[l]
else:
mid = (l + r) // 2
self._build(node * 2, l, mid, nums)
self._build(node * 2 + 1, mid + 1, r, nums)
self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]
def _push_down(self, node, l, r):
"""下推懒标记到子节点"""
if self.lazy[node] != 0:
mid = (l + r) // 2
left, right = node * 2, node * 2 + 1
# 应用到左子节点
self.lazy[left] += self.lazy[node]
self.tree[left] += self.lazy[node] * (mid - l + 1)
# 应用到右子节点
self.lazy[right] += self.lazy[node]
self.tree[right] += self.lazy[node] * (r - mid)
# 清除当前节点的懒标记
self.lazy[node] = 0
def range_update(self, ql, qr, val):
"""区间更新 [ql, qr] 加上 val"""
self._range_update(1, 0, self.n - 1, ql, qr, val)
def _range_update(self, node, l, r, ql, qr, val):
if ql <= l and r <= qr:
# 完全覆盖,直接应用
self.tree[node] += val * (r - l + 1)
self.lazy[node] += val
return
if r < ql or l > qr:
return
# 下推懒标记
self._push_down(node, l, r)
mid = (l + r) // 2
self._range_update(node * 2, l, mid, ql, qr, val)
self._range_update(node * 2 + 1, mid + 1, r, ql, qr, val)
self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]
def range_query(self, ql, qr):
"""区间查询 [ql, qr]"""
return self._range_query(1, 0, self.n - 1, ql, qr)
def _range_query(self, node, l, r, ql, qr):
if ql <= l and r <= qr:
return self.tree[node]
if r < ql or l > qr:
return 0
self._push_down(node, l, r)
mid = (l + r) // 2
return (self._range_query(node * 2, l, mid, ql, qr) +
self._range_query(node * 2 + 1, mid + 1, r, ql, qr))
题目 3:我的日程安排表 II
LeetCode 731 - 区间添加(可能有重叠),查询某区间内被覆盖的最多次数。
class MyCalendarTwo:
def __init__(self):
self.tree = {}
self.lazy = {}
def book(self, start, end):
"""在 [start, end) 添加日程,返回是否有三重重叠"""
if self.query(0, 10**9, start, end - 1, 0, 10**9) < 2:
self.update(0, 10**9, start, end - 1, 1, 0, 10**9)
return True
return False
def update(self, nl, nr, ql, qr, val, node, l, r):
if ql <= nl and nr <= qr:
self.tree[node] = self.tree.get(node, 0) + val
self.lazy[node] = self.lazy.get(node, 0) + val
return
if nr < ql or nl > qr:
return
mid = (nl + nr) // 2
self.update(nl, mid, ql, qr, val, node * 2, l, mid)
self.update(mid + 1, nr, ql, qr, val, node * 2 + 1, mid + 1, r)
self.tree[node] = self.lazy.get(node, 0) + max(
self.tree.get(node * 2, 0),
self.tree.get(node * 2 + 1, 0)
)
def query(self, nl, nr, ql, qr, node, l, r):
if ql <= nl and nr <= qr:
return self.tree.get(node, 0)
if nr < ql or nl > qr:
return 0
mid = (nl + nr) // 2
return self.lazy.get(node, 0) + max(
self.query(nl, mid, ql, qr, node * 2, l, mid),
self.query(mid + 1, nr, ql, qr, node * 2 + 1, mid + 1, r)
)
线段树的变体
1. 区间最大值线段树
把 + 换成 max:
def _build(self, node, l, r, nums):
if l == r:
self.tree[node] = nums[l]
else:
mid = (l + r) // 2
self._build(node * 2, l, mid, nums)
self._build(node * 2 + 1, mid + 1, r, nums)
self.tree[node] = max(self.tree[node * 2], self.tree[node * 2 + 1])
2. 区间赋值线段树
不是加值,而是赋值为固定值:
def _push_down(self, node, l, r):
if self.lazy[node] is not None:
mid = (l + r) // 2
self.tree[node * 2] = self.lazy[node]
self.tree[node * 2 + 1] = self.lazy[node]
self.lazy[node * 2] = self.lazy[node]
self.lazy[node * 2 + 1] = self.lazy[node]
self.lazy[node] = None
💡 小结
线段树的核心操作:
| 操作 | 普通线段树 | 懒更新线段树 |
|---|---|---|
| 单点查询 | O(log n) | O(log n) |
| 单点更新 | O(log n) | O(log n) |
| 区间查询 | O(log n) | O(log n) |
| 区间更新 | O(k log n) | O(log n) |
适用场景:
- 区间求和、最大值、最小值
- 区间加法、赋值
- 统计覆盖次数(731)
- 动态第 K 大(区间第 K 大)
代码模板:
class SegmentTree:
def __init__(self, nums):
self.n = len(nums)
self.tree = [0] * (4 * self.n)
self._build(nums)
def _build(self, nums, node=1, l=0, r=None):
if r is None: r = self.n - 1
if l == r:
self.tree[node] = nums[l]
else:
mid = (l + r) // 2
self._build(nums, node*2, l, mid)
self._build(nums, node*2+1, mid+1, r)
self.tree[node] = self.tree[node*2] + self.tree[node*2+1]
def query(self, ql, qr, node=1, l=0, r=None):
if r is None: r = self.n - 1
if ql <= l and r <= qr: return self.tree[node]
if r < ql or l > qr: return 0
mid = (l + r) // 2
return (self.query(ql, qr, node*2, l, mid) +
self.query(ql, qr, node*2+1, mid+1, r))
相关阅读:树状数组(Fenwick Tree)