树状数组(Fenwick Tree),又称 Binary Indexed Tree(BIT),可以在 O(log n) 时间内完成前缀和查询和单点更新,是处理动态区间和的利器。
什么是树状数组?
树状数组是一种支持单点更新 + 区间求和的数据结构,比线段树更简洁,代码量更少。
核心操作:
update(i, delta):将第 i 个元素增加 deltaquery(i):查询前 i 个元素的和(前缀和)range_sum(l, r):查询区间 [l, r] 的和
时间复杂度:
| 操作 | 朴素数组 | 树状数组 |
|---|---|---|
| 单点更新 | O(1) | O(log n) |
| 前缀查询 | O(n) | O(log n) |
| 区间查询 | O(n) | O(log n) |
树状数组的工作原理
核心思想:二进制分解
树状数组利用了二进制的前缀包含关系:
假设 n = 8 (1000 in binary)
数组索引: 1 2 3 4 5 6 7 8
↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓
BIT 存储: [1] [1-2] [3] [1-4] [5] [5-6] [7] [1-8]
↑ ↑
奇数长度 偶数长度
每个 BIT 节点存储特定长度的区间和:
BIT[1]= arr[1]BIT[2]= arr[1] + arr[2]BIT[3]= arr[3]BIT[4]= arr[1] + arr[2] + arr[3] + arr[4]- …
lowbit 函数
def lowbit(x):
"""返回 x 二进制最低位的 1 表示的数值"""
return x & (-x)
# lowbit 规律:
# lowbit(1) = 1 1 & -1 = 1
# lowbit(2) = 2 10 & -10 = 10
# lowbit(3) = 1 11 & -11 = 1
# lowbit(4) = 4 100 & -100 = 100
# lowbit(5) = 1 101 & -101 = 1
# lowbit(6) = 2 110 & -110 = 10
# lowbit(8) = 8 1000 & -1000 = 1000
为什么 lowbit 有效?
在补码表示中,-x = ~x + 1,所以 x & -x 只保留最低位的 1。
树状数组模板
class FenwickTree:
def __init__(self, n):
self.n = n
self.tree = [0] * (n + 1) # 1-indexed
def update(self, i, delta):
"""将第 i 个元素增加 delta"""
while i <= self.n:
self.tree[i] += delta
i += lowbit(i)
def query(self, i):
"""查询前 i 个元素的和"""
result = 0
while i > 0:
result += self.tree[i]
i -= lowbit(i)
return result
def range_sum(self, l, r):
"""查询区间 [l, r] 的和"""
return self.query(r) - self.query(l - 1)
def add(self, i, delta):
"""单点增加(别名)"""
self.update(i, delta)
LeetCode 实战
题目 1:区域和检索 - 数组可修改
LeetCode 307 - 设计一个支持「单点更新」和「区间求和」的数据结构。
class NumArray:
def __init__(self, nums):
self.n = len(nums)
self.tree = [0] * (self.n + 1)
# 构建 BIT
for i, num in enumerate(nums, 1):
self.tree[i] += num
j = i + lowbit(i)
if j <= self.n:
self.tree[j] += self.tree[i]
self.nums = nums # 存储原数组用于验证
def update(self, index, val):
"""单点更新:将 nums[index] 改为 val"""
delta = val - self.nums[index]
self.nums[index] = val
i = index + 1
while i <= self.n:
self.tree[i] += delta
i += lowbit(i)
def sumRange(self, left, right):
"""区间求和 [left, right]"""
return self.range_sum(left + 1, right + 1)
def query(self, i):
result = 0
while i > 0:
result += self.tree[i]
i -= lowbit(i)
return result
def range_sum(self, l, r):
return self.query(r) - self.query(l - 1)
题目 2:翻转对
LeetCode 493 - 统计翻转对 (i, j) 其中 i < j 且 nums[i] > 2 * nums[j]。
class Solution:
def reversePairs(self, nums):
"""利用 BIT 统计每个元素之前有多少个大于它的两倍"""
if not nums:
return 0
# 离散化:只保留有效的比较值
vals = sorted(set(nums + [x * 2 for x in nums]))
tree = FenwickTree(len(vals))
result = 0
for i in range(len(nums)):
# 查询当前元素之前有多少个数大于 2*nums[i]
# 即查询 (2*nums[i], +∞) 区间
x = nums[i] * 2
idx = bisect.bisect_right(vals, x) # 第一个 > 2*nums[i] 的位置
if idx < len(vals):
# 总数 - query(idx) = 大于 2*nums[i] 的数量
result += tree.range_sum(idx + 1, len(vals))
# 将当前元素加入 BIT
num_idx = bisect.bisect_left(vals, nums[i]) + 1
tree.add(num_idx, 1)
return result
题目 3:计算右侧小于当前元素的个数
LeetCode 315 - 统计每个元素右侧有多少个小于它的元素。
class Solution:
def countSmaller(self, nums):
if not nums:
return []
# 离散化
sorted_vals = sorted(set(nums))
tree = FenwickTree(len(sorted_vals))
result = [0] * len(nums)
# 从右往左遍历
for i in range(len(nums) - 1, -1, -1):
# 查询当前元素之前(更小的)有多少个
idx = bisect.bisect_left(sorted_vals, nums[i]) + 1
result[i] = tree.query(idx - 1) # 比当前元素小的数量
tree.add(idx, 1)
return result
二维树状数组
树状数组可以扩展到二维:
class FenwickTree2D:
def __init__(self, m, n):
self.m, self.n = m, n
self.tree = [[0] * (n + 1) for _ in range(m + 1)]
def update(self, x, y, delta):
"""更新二维点 (x, y) 的值"""
i = x
while i <= self.m:
j = y
while j <= self.n:
self.tree[i][j] += delta
j += lowbit(j)
i += lowbit(i)
def query(self, x, y):
"""查询矩阵 [1,1] 到 [x,y] 的和"""
result = 0
i = x
while i > 0:
j = y
while j > 0:
result += self.tree[i][j]
j -= lowbit(j)
i -= lowbit(i)
return result
def range_sum(self, x1, y1, x2, y2):
"""查询矩形 [x1,y1] 到 [x2,y2] 的和"""
return (self.query(x2, y2) - self.query(x1-1, y2)
- self.query(x2, y1-1) + self.query(x1-1, y1-1))
LeetCode 308 - 二维区域和检索(矩阵可变):
class NumMatrix:
def __init__(self, matrix):
if not matrix or not matrix[0]:
return
self.matrix = matrix
self.m, self.n = len(matrix), len(matrix[0])
self.tree = [[0] * (self.n + 1) for _ in range(self.m + 1)]
# 构建 BIT
for i in range(1, self.m + 1):
for j in range(1, self.n + 1):
self.tree[i][j] += matrix[i-1][j-1]
ni = i + lowbit(i)
nj = j + lowbit(j)
if ni <= self.m:
self.tree[ni][j] += self.tree[i][j]
if nj <= self.n:
self.tree[i][nj] += self.tree[i][j]
if ni <= self.m and nj <= self.n:
self.tree[ni][nj] -= self.tree[i][j]
def update(self, row, col, val):
delta = val - self.matrix[row][col]
self.matrix[row][col] = val
i = row + 1
while i <= self.m:
j = col + 1
while j <= self.n:
self.tree[i][j] += delta
j += lowbit(j)
i += lowbit(i)
def sumRegion(self, row1, col1, row2, col2):
x1, y1 = row1 + 1, col1 + 1
x2, y2 = row2 + 1, col2 + 1
res = self.tree[x2][y2] - self.tree[x1-1][y2] \
- self.tree[x2][y1-1] + self.tree[x1-1][y1-1]
# 手动计算(因为上面的初始化比较特殊)
result = 0
for i in range(x1, x2 + 1):
for j in range(y1, y2 + 1):
result += self.matrix[i-1][j-1]
return result
💡 小结
树状数组的核心:
| 操作 | 时间复杂度 | 空间 |
|---|---|---|
| 单点更新 | O(log n) | |
| 前缀查询 | O(log n) | O(n) |
| 区间查询 | O(log n) |
适用场景:
- 动态数组的区间求和(307)
- 逆序对统计(315、493)
- 统计小于某值的元素个数
- 二维区间和(308)
注意:树状数组要求索引从 1 开始!
模板代码:
class FenwickTree:
def __init__(self, n):
self.n = n
self.tree = [0] * (n + 1)
def add(self, i, delta):
while i <= self.n:
self.tree[i] += delta
i += i & -i
def sum(self, i):
res = 0
while i > 0:
res += self.tree[i]
i -= i & -i
return res
相关阅读:字典树(Trie)