本文仅对树状数组的使用作一个总结,并非讲解。

这里的操作都对长度为 $n$ 的数组 $a$ 进行操作。

单点修改,区间查询

  • 暴力做法:

    • 修改:$a[p]=y$,时间复杂度为 $O(1)$
    • 查询:$\sum\limits_{i=l}^ra[i]$ ,时间复杂度为 $O(n)$
  • 树状数组: $tr$ 数组 对 $a$ 数组进行维护

    • 修改:

      1
      2
      3
      4
      5
      
      void update(int x, int y) {
          while (x <= n) tr[x] += y, x += (x & (-x));
      }
            
      update(x, y);
      
       时间复杂度为 $O(\log n)$
      
      
    • 查询:

      1
      2
      3
      4
      5
      
      int query(int x) {
          int ans = 0;
          while (x >= 1) ans += tr[x], x -= (x & (-x));
          return ans;
      }
      
       时间复杂度为 $O(\log n)$
      
      

区间修改,单点查询

  • 暴力做法:

    • 修改:$a[l]=a[l]+x,\cdots,a[r]=a[r]+x$,时间复杂度为 $O(n)$
    • 查询:$a[p]$,时间复杂度为 $O(1)$
  • 树状数组: $b$ 数组是 $a$ 的差分数组。$tr$ 数组对 $b$ 数组进行维护

    • 修改:
      1
      2
      3
      4
      5
      6
      
      void update(int x, int y) {
          while (x <= n) tr[x] += y, x += (x & (-x));
      }
            
      update(l, x);
      update(r + 1, -x);
      
       时间复杂度为 $O(\log n)$
      
    • 查询:
      1
      2
      3
      4
      5
      6
      7
      
      int query(int x) {
          int ans = 0;
          while (x >= 1) ans += tr[x], x -= (x & (-x));
          return ans;
      }
            
      query(x);
      
       时间复杂度为 $O(\log n)$
      
      

区间修改,区间查询

区间查询的的公式为:$\sum\limits_{i=l}^ra[i]$,我们先考虑如何求 $\sum\limits_{i=1}^p a[i]$

问题转换为如何去求解这个公式,暴力情况下,求 $a[i]$ 是 $O(\log n)$,总时间复杂度为 $O(n\log n)$。

但是我们可以拆分这个公式: $\begin{aligned} \sum\limits_{i=1}^p a[i] &=\sum\limits_{i=1}^p \sum\limits_{j=1}^ib[j] \
&= p \times b[1]+(p-1)\times b[2]+\cdots+1\times b[p] \
&= ((p+1)\times \sum\limits_{i=1}^p b[i])-(1\times b[1]+2\times b[2]+\cdots+p\times b[p])\
\end{aligned}$

所以我们再用一个额外的树状数组去维护 $i\times b[i]$ 即可。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
// 区间修改[x, n]
const int N =100010;
int tr1[N], tr2[N];
void update(int x, int y) {
	int val2 = x * y;
	while (x <= n) {
		tr1[x] += y;
		tr2[x] += val2;
		x += (x & (-x));
	}
}
update(l, d);      // a[l, n] += d;
update(r + 1, -d); // a[r + 1, n] -= d

// 区间查询(1, x)
int query(int x) {
	int p = x;
	int val1 = 0, val2 = 0;
	while (x >= 1) {
		val1 += tr1[x];
		val2 += tr2[x];
		x -= (x & -x);
	}
	return (p + 1) * val1 - val2;
}
// 查询 a[l, r]
query(1, r) - query(1, l - 1);

时间复杂度为 $O(\log n)$