树状数组(Fenwick Tree)

树状数组(Fenwick Tree),又称 Binary Indexed Tree(BIT),可以在 O(log n) 时间内完成前缀和查询和单点更新,是处理动态区间和的利器。

什么是树状数组?

树状数组是一种支持单点更新 + 区间求和的数据结构,比线段树更简洁,代码量更少。

核心操作

  • update(i, delta):将第 i 个元素增加 delta
  • query(i):查询前 i 个元素的和(前缀和)
  • range_sum(l, r):查询区间 [l, r] 的和

时间复杂度

操作 朴素数组 树状数组
单点更新 O(1) O(log n)
前缀查询 O(n) O(log n)
区间查询 O(n) O(log n)

树状数组的工作原理

核心思想:二进制分解

树状数组利用了二进制的前缀包含关系

假设 n = 8 (1000 in binary)

数组索引:  1   2   3   4   5   6   7   8
           ↓   ↓   ↓   ↓   ↓   ↓   ↓   ↓
BIT 存储:  [1] [1-2] [3] [1-4] [5] [5-6] [7] [1-8]
                    ↑           ↑
                奇数长度     偶数长度

每个 BIT 节点存储特定长度的区间和:

  • BIT[1] = arr[1]
  • BIT[2] = arr[1] + arr[2]
  • BIT[3] = arr[3]
  • BIT[4] = arr[1] + arr[2] + arr[3] + arr[4]

lowbit 函数

def lowbit(x):
    """返回 x 二进制最低位的 1 表示的数值"""
    return x & (-x)

# lowbit 规律:
# lowbit(1) = 1   1 & -1 = 1
# lowbit(2) = 2   10 & -10 = 10
# lowbit(3) = 1   11 & -11 = 1
# lowbit(4) = 4   100 & -100 = 100
# lowbit(5) = 1   101 & -101 = 1
# lowbit(6) = 2   110 & -110 = 10
# lowbit(8) = 8   1000 & -1000 = 1000

为什么 lowbit 有效?

在补码表示中,-x = ~x + 1,所以 x & -x 只保留最低位的 1。


树状数组模板

class FenwickTree:
    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 1)  # 1-indexed
    
    def update(self, i, delta):
        """将第 i 个元素增加 delta"""
        while i <= self.n:
            self.tree[i] += delta
            i += lowbit(i)
    
    def query(self, i):
        """查询前 i 个元素的和"""
        result = 0
        while i > 0:
            result += self.tree[i]
            i -= lowbit(i)
        return result
    
    def range_sum(self, l, r):
        """查询区间 [l, r] 的和"""
        return self.query(r) - self.query(l - 1)
    
    def add(self, i, delta):
        """单点增加(别名)"""
        self.update(i, delta)

LeetCode 实战

题目 1:区域和检索 - 数组可修改

LeetCode 307 - 设计一个支持「单点更新」和「区间求和」的数据结构。

class NumArray:
    def __init__(self, nums):
        self.n = len(nums)
        self.tree = [0] * (self.n + 1)
        
        # 构建 BIT
        for i, num in enumerate(nums, 1):
            self.tree[i] += num
            j = i + lowbit(i)
            if j <= self.n:
                self.tree[j] += self.tree[i]
        
        self.nums = nums  # 存储原数组用于验证
    
    def update(self, index, val):
        """单点更新:将 nums[index] 改为 val"""
        delta = val - self.nums[index]
        self.nums[index] = val
        
        i = index + 1
        while i <= self.n:
            self.tree[i] += delta
            i += lowbit(i)
    
    def sumRange(self, left, right):
        """区间求和 [left, right]"""
        return self.range_sum(left + 1, right + 1)
    
    def query(self, i):
        result = 0
        while i > 0:
            result += self.tree[i]
            i -= lowbit(i)
        return result
    
    def range_sum(self, l, r):
        return self.query(r) - self.query(l - 1)

题目 2:翻转对

LeetCode 493 - 统计翻转对 (i, j) 其中 i < j 且 nums[i] > 2 * nums[j]。

class Solution:
    def reversePairs(self, nums):
        """利用 BIT 统计每个元素之前有多少个大于它的两倍"""
        if not nums:
            return 0
        
        # 离散化:只保留有效的比较值
        vals = sorted(set(nums + [x * 2 for x in nums]))
        tree = FenwickTree(len(vals))
        
        result = 0
        for i in range(len(nums)):
            # 查询当前元素之前有多少个数大于 2*nums[i]
            # 即查询 (2*nums[i], +∞) 区间
            x = nums[i] * 2
            idx = bisect.bisect_right(vals, x)  # 第一个 > 2*nums[i] 的位置
            if idx < len(vals):
                # 总数 - query(idx) = 大于 2*nums[i] 的数量
                result += tree.range_sum(idx + 1, len(vals))
            
            # 将当前元素加入 BIT
            num_idx = bisect.bisect_left(vals, nums[i]) + 1
            tree.add(num_idx, 1)
        
        return result

题目 3:计算右侧小于当前元素的个数

LeetCode 315 - 统计每个元素右侧有多少个小于它的元素。

class Solution:
    def countSmaller(self, nums):
        if not nums:
            return []
        
        # 离散化
        sorted_vals = sorted(set(nums))
        tree = FenwickTree(len(sorted_vals))
        
        result = [0] * len(nums)
        
        # 从右往左遍历
        for i in range(len(nums) - 1, -1, -1):
            # 查询当前元素之前(更小的)有多少个
            idx = bisect.bisect_left(sorted_vals, nums[i]) + 1
            result[i] = tree.query(idx - 1)  # 比当前元素小的数量
            tree.add(idx, 1)
        
        return result

二维树状数组

树状数组可以扩展到二维:

class FenwickTree2D:
    def __init__(self, m, n):
        self.m, self.n = m, n
        self.tree = [[0] * (n + 1) for _ in range(m + 1)]
    
    def update(self, x, y, delta):
        """更新二维点 (x, y) 的值"""
        i = x
        while i <= self.m:
            j = y
            while j <= self.n:
                self.tree[i][j] += delta
                j += lowbit(j)
            i += lowbit(i)
    
    def query(self, x, y):
        """查询矩阵 [1,1] 到 [x,y] 的和"""
        result = 0
        i = x
        while i > 0:
            j = y
            while j > 0:
                result += self.tree[i][j]
                j -= lowbit(j)
            i -= lowbit(i)
        return result
    
    def range_sum(self, x1, y1, x2, y2):
        """查询矩形 [x1,y1] 到 [x2,y2] 的和"""
        return (self.query(x2, y2) - self.query(x1-1, y2) 
                - self.query(x2, y1-1) + self.query(x1-1, y1-1))

LeetCode 308 - 二维区域和检索(矩阵可变):

class NumMatrix:
    def __init__(self, matrix):
        if not matrix or not matrix[0]:
            return
        self.matrix = matrix
        self.m, self.n = len(matrix), len(matrix[0])
        self.tree = [[0] * (self.n + 1) for _ in range(self.m + 1)]
        
        # 构建 BIT
        for i in range(1, self.m + 1):
            for j in range(1, self.n + 1):
                self.tree[i][j] += matrix[i-1][j-1]
                ni = i + lowbit(i)
                nj = j + lowbit(j)
                if ni <= self.m:
                    self.tree[ni][j] += self.tree[i][j]
                if nj <= self.n:
                    self.tree[i][nj] += self.tree[i][j]
                if ni <= self.m and nj <= self.n:
                    self.tree[ni][nj] -= self.tree[i][j]
    
    def update(self, row, col, val):
        delta = val - self.matrix[row][col]
        self.matrix[row][col] = val
        
        i = row + 1
        while i <= self.m:
            j = col + 1
            while j <= self.n:
                self.tree[i][j] += delta
                j += lowbit(j)
            i += lowbit(i)
    
    def sumRegion(self, row1, col1, row2, col2):
        x1, y1 = row1 + 1, col1 + 1
        x2, y2 = row2 + 1, col2 + 1
        
        res = self.tree[x2][y2] - self.tree[x1-1][y2] \
              - self.tree[x2][y1-1] + self.tree[x1-1][y1-1]
        
        # 手动计算(因为上面的初始化比较特殊)
        result = 0
        for i in range(x1, x2 + 1):
            for j in range(y1, y2 + 1):
                result += self.matrix[i-1][j-1]
        return result

💡 小结

树状数组的核心:

操作 时间复杂度 空间
单点更新 O(log n)  
前缀查询 O(log n) O(n)
区间查询 O(log n)  

适用场景

  • 动态数组的区间求和(307)
  • 逆序对统计(315、493)
  • 统计小于某值的元素个数
  • 二维区间和(308)

注意:树状数组要求索引从 1 开始!

模板代码

class FenwickTree:
    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 1)
    
    def add(self, i, delta):
        while i <= self.n:
            self.tree[i] += delta
            i += i & -i
    
    def sum(self, i):
        res = 0
        while i > 0:
            res += self.tree[i]
            i -= i & -i
        return res

相关阅读:字典树(Trie)