朝花夕拾-树链剖分

树链剖分的名字非常高大上, 其实不难, 本质是将树分解成几条链, 映射在线段树(树状数组、Splay等)上, 当我们需要在树上的路径(题目通常给定两个结点,在他们的路径上操作)进行操作时, 就能转化成在线段树(等数据结构)上进行操作, 接下来以树转线段树为例。

树链剖分最重要的地方在于分解树的规则, 下面是一种通用的分解方法:对于每个结点,如果它有儿子,那么取其子树最重(子树结点最多)的儿子为“重儿子”, 重儿子与其父亲的连边称为“重边”。相对的,其余的儿子成为“轻儿子”,轻儿子与父亲的连边称为轻边 。连续的重边构成“重链”,在同一条重链上的点 在线段树中 按重链的顺序相邻。显然,非叶子结点都有重儿子, 它们必定能映射到唯一的重链上。对于多个同父亲的叶子结点,随便取一个作为重儿子就可以了。可以想象, 轻边连接着不相交的重链,树上的每个结点都能唯一的映射到线段的某一个点上。

树链剖分例子

这样一来,一些数据结构就能够推广到树上。比如能求区间最值、区间求和的线段树,在映射做好后,就能套用在树上。
假如我们要求树上两点(u,v)之间的路径的权值和。例如上图求(11, 14)间路径权值和,分别从11和14往上跳,跳到重链的顶端,(为了不超过两点的最近公共祖先, 要选择更深的重链头跳),11跳到2,用线段树求出(2,11)的和(利用重链在映射线段上连续的特点),暴力加一下轻边(1,2),此时两方面同时跳到重链头1,用线段树求出(1,14)的和,得到答案。

算法分析

主要数据

    • 这部分是由题目给出的数据
    • 题目要求在树上的路径做段修改, 段询问.
    • 如果树很大, 输入代价可能也很大, 应该考虑自定义输入. 字符串输入用scanf.
  1. 高级树形数据结构
    • 能够满足题目要求的修改, 询问操作.
    • 基础数据来自一段预处理序列, 而不是直接来自输入.
  2. 输入树到自定义数据结构的映射
    • 简称树链剖分.
    • 将输入树剪成若干条链, 作为高级数据结构的基础数据.

剖分方法

在生成映射的过程中, 对于原树的每个节点, 有六个信息要处理.

  1. 父亲节点 - 防止重复搜索
  2. 节点深度 - 防止越过最近公共祖先
  3. 子树大小 - 决定重儿子
  4. 重儿子 - 构成重链
  5. 链上编号 - 映射到序列
  6. 重链头 - 找到编号连续的段
  • 1,2,5,6是自顶向下计算的, 3,4必须自底向上计算.
  • bfs计算1,2. 反向bfs计算3,4.
  • 模拟栈(#define成队列数组)计算5,6. 为了保证重链编号连续, 重儿子后入栈.
  • 如果值在边上, 将值放进子节点中. 询问时注意轻边的纳入.

映射询问

  • 如果两点所在重链的重链头不同, 意味着两点不在同一条重链. 为了防止一条重链头已经越过最近公共祖先, 选择重链头深度大的计算. 询问映射段, 把轻边也算在内(轻边信息放在了重链头节点中). 计算后节点换成重链头的父亲.
  • 如果两节点在同一条重链上, 首先把深度安排好, 特判两节点是否相同. 如果询问边, 这时候的计算与轻边无关, 注意把浅节点编号加一.

关于线段树延迟标记

  • pushDown意味着当前段有多余, 要分两半, 但不一定两半都继续递归. 所以递归后两半都可能带有延迟标记.
  • pushUp在修改操作中出现, 且前面必有pushDown, 意味着当前段的延迟操作已经补好. pushUp结束后当前段一定不带延迟标记.

求路径最值也如法炮制。
下面给出实现树上路径求最值、求和的代码,题目对应HYSBZ 1036。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
#include <algorithm>
#include <iostream>
#include <fstream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#define mmst(a, b) memset(a, b, sizeof(a))
#define lson (root<<1)
#define rson ((root<<1)|1)
using namespace std;

const int MAXN = 30010;
const int MAXE = MAXN<<1;

void ri(int &x);

int n, q;

int to[MAXE], nex[MAXE], Mindex[MAXN], cur=0;
int w[MAXN], v[MAXN], top[MAXN], fa[MAXN], son[MAXN], siz[MAXN], depth[MAXN], z=0;

int SUM[MAXN<<2], MAX[MAXN<<2];

void addEdge(const int a, const int b);
void dfs(const int x)
{
son[x] = 0;
siz[x] = 1;
for (int i=Mindex[x]; i!=-1; i=nex[i])
{
if (to[i] != fa[x])
{
fa[to[i]] = x;
depth[to[i]] = depth[x] + 1;
dfs(to[i]);
if (siz[to[i]] > siz[son[x]]) son[x] = to[i];
siz[x] += siz[to[i]];
}
}
}
void set_tree(const int x, const int tp)
{
top[x] = tp;
w[x] = ++z;
if (son[x]!=0) set_tree(son[x], tp);
for (int i=Mindex[x]; i!=-1; i=nex[i])
{
if (to[i]!=fa[x] && to[i]!=son[x])
{
set_tree(to[i], to[i]);
}
}
}

void update(int l, int r, int p, int val, int root)
{
if (l==r) { SUM[root] = MAX[root] = val; return ; }
const int mid = (l+r) >> 1;
if (p<=mid) update(l, mid, p, val, lson);
else update(mid+1, r, p, val, rson);
SUM[root] = SUM[lson] + SUM[rson];
MAX[root] = max(MAX[lson], MAX[rson]);
}

int qMax(int l, int r, int L, int R, int root)
{
if (L<=l && r<=R) return MAX[root];
const int mid = (l+r) >> 1;
int ref = -MAXN;
if (L<=mid) ref = qMax(l, mid, L, R, lson);
if (mid<R) ref = max(ref, qMax(mid+1, r, L, R, rson));
return ref;
}

int qSum(int l, int r, int L, int R, int root)
{
if (L<=l && r<=R) return SUM[root];
const int mid = (l+r) >> 1;
int ref = 0;
if (L<=mid) ref = qSum(l, mid, L, R, lson);
if (mid<R) ref += qSum(mid+1, r, L, R, rson);
return ref;
}

int findMax(int a, int b)
{
int ref = -MAXN;
int f1 = top[a], f2 = top[b];
while (f1 != f2)
{
if (depth[f1] < depth[f2]) { swap(a, b); swap(f1, f2); }
ref = max(ref, qMax(1, z, w[f1], w[a], 1));
a = fa[f1];
f1 = top[a];
}
if (a == b) return max(ref, v[a]);
if (depth[a] > depth[b]) swap(a, b);
return max(ref, qMax(1, z, w[a], w[b], 1));
}

int findSum(int a, int b)
{
int ref = 0;
int f1 = top[a], f2 = top[b];
while (f1 != f2)
{
if (depth[f1] < depth[f2]) { swap(a, b); swap(f1, f2); }
ref += qSum(1, z, w[f1], w[a], 1);
a = fa[f1];
f1 = top[a];
}
if (a == b) return ref + v[a];
if (depth[a] > depth[b]) swap(a, b);
return ref + qSum(1, z, w[a], w[b], 1);
}

int main()
{
// freopen("tes.in", "r", stdin);
mmst(Mindex, -1);
ri(n);
for (int i=1, a, b; i<n; ++i)
{
ri(a); ri(b);
addEdge(a, b);
}

for (int i=1; i<=n; ++i)
ri(v[i]);

siz[0] = 0; depth[1] = 1; fa[1] = 0;
dfs(1);
set_tree(1, 1);

for (int i=1; i<=n; ++i)
update(1, z, w[i], v[i], 1);

ri(q);
char que[7];
int q1, q2;
while (q--)
{
scanf("%s", que);
ri(q1); ri(q2);
if (que[1] == 'M') printf("%d\n", findMax(q1, q2));
else if (que[1] == 'S') printf("%d\n", findSum(q1, q2));
else update(1, z, w[q1], v[q1]=q2, 1);
}

return 0;
}

inline void addEdge(const int a, const int b)
{
to[cur] = b; nex[cur] = Mindex[a]; Mindex[a] = cur++;
to[cur] = a; nex[cur] = Mindex[b]; Mindex[b] = cur++;
}

inline void ri(int &x)
{
char c; bool minus = false;
while ((c=getchar())<'0' || '9'<c) if (c=='-') minus=true;
x = c-'0';
while ('0'<=(c=getchar()) && c<='9')
x = 10*x+c-'0';
if (minus) x = -x;
}

编程时有几点要留意:

  • 注意权值在树中摆放的位置, 在边上或者在结点上,通常可以转化成在结点上。
  • 在一条重链上,深度小的点在线段树左边,深度大的在右边(相对而言)。
  • find()中当f1==f2后,路径a~b是尚未被计算的,特别是要留意if (a==b),将其权值记上。
  • dfs()先解决深度、大小(轻重)、重儿子、父亲的问题。
  • set_tree()后解决top重链头标记、安排映射的问题。

2016.06.02新增

改进两个初始化:

  • 第一个初始化先用宽搜解决深度计数和父亲指针的预处理,然后反向遍历队列,用改进的方式计算size和重儿子。
  • 第二个初始化用粗略的模拟栈,重儿子优先做深搜,由浅到深计算id和top。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
#include <algorithm>
#include <iostream>
#include <fstream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
using namespace std;

void ri(int &);

const int MAXN = 30010;

int n, m;

const bool GETMAX = true, GETSUM = false;

int val[MAXN];

int SUM[MAXN<<2], MAX[MAXN<<2];

int depth[MAXN], fa[MAXN], top[MAXN], wson[MAXN], siz[MAXN], id[MAXN];

int cur = 0;
int to[MAXN<<1], nex[MAXN<<1], _index[MAXN];

void addEdge(int a, int b);

int q[MAXN];
void bfs();
void dfs();
void build(int r);
void update(int l, int r, int p, int v, int root=1);
int queryMAX(int l, int r, int L, int R, int root=1);
int querySUM(int l, int r, int L, int R, int root=1);

int qtree(int a, int b, const bool flag)
{
int ref = (flag == GETMAX ? -MAXN : 0);
while (top[a]!=top[b])
{
if (depth[top[a]] < depth[top[b]]) swap(a, b);
if (flag==GETMAX) ref = max(ref, queryMAX(1, n, id[top[a]], id[a]));
else ref += querySUM(1, n, id[top[a]], id[a]);
a = fa[top[a]];
}
if (depth[a]>depth[b]) swap(a, b);
if (flag == GETMAX) ref = max(ref, queryMAX(1, n, id[a], id[b]));
else ref += querySUM(1, n, id[a], id[b]);
return ref;
}

int main()
{
// freopen("tes.in", "r", stdin);

memset(_index, -1, sizeof(_index));
memset(wson, 0, sizeof(wson));
memset(siz, 0, sizeof(siz));

ri(n);

int u, v;

for (int i=1; i<n; ++i)
{
ri(u); ri(v);
addEdge(u, v);
}

for (int i=1; i<=n; ++i)
{
ri(val[i]);
}

bfs();
dfs();
build(n);

for (int i=1; i<=n; ++i)
update(1, n, id[i], val[i]);

ri(m);
char rd[7];
while (m--)
{
scanf("%s", rd);
ri(u); ri(v);
if (rd[1] == 'M')
printf("%d\n", qtree(u, v, GETMAX));
else if (rd[1]=='S')
printf("%d\n", qtree(u, v, GETSUM));
else update(1, n, id[u], v);
}

return 0;
}

void bfs()
{
int head = 0, tail = 1;
q[0] = 1;
fa[1] = 0;
depth[1] = 0;

while (head < tail)
{
const int x = q[head++];
for (int i=_index[x]; i!=-1; i=nex[i])
{
if (to[i] != fa[x])
{
depth[to[i]] = depth[x] + 1;
fa[to[i]] = x;
q[tail++] = to[i];
}
}
}

for (int i=tail-1; i>=0; --i)
{
const int x = q[i], y = fa[x];
siz[y] += (++siz[x]);
if (siz[wson[y]] < siz[x]) wson[y] = x;
}
}

#define sta q
void dfs()
{
int num = 0;
int tail = 1;
sta[0] = 1;
top[1] = 1;

while (tail != 0)
{
const int x = sta[--tail];
id[x] = ++num;
for (int i=_index[x]; i!=-1; i=nex[i])
{
if (to[i]!=fa[x] && to[i]!=wson[x])
{
top[to[i]] = to[i];
sta[tail++] = to[i];
}
}
if (wson[x] != 0)
{
top[wson[x]] = top[x];
sta[tail++] = wson[x];
}
}
}

#define lson (root<<1)
#define rson ((root<<1)|1)
void update(int l, int r, int p, int v, int root)
{
if (l==r)
{
SUM[root] = MAX[root] = v;
return;
}
const int mid = (l+r)>>1;
if (p<=mid) update(l, mid, p, v, lson);
else update(mid+1, r, p, v, rson);
SUM[root] = SUM[lson] + SUM[rson];
MAX[root] = max(MAX[lson], MAX[rson]);
}

inline void build(int r)
{
r <<= 2;
for (int i=0; i<=r; ++i)
{
SUM[i] = 0;
MAX[i] = -MAXN;
}
}

int queryMAX(int l, int r, int L, int R, int root)
{
if (L<=l && r<=R) return MAX[root];
const int mid = (l+r) >> 1;
int ref = -MAXN;
if (L<=mid) ref = queryMAX(l, mid, L, R, lson);
if (mid<R) ref = max(ref, queryMAX(mid+1, r, L, R, rson));
return ref;
}

int querySUM(int l, int r, int L, int R, int root)
{
if (L<=l && r<=R) return SUM[root];
const int mid = (l+r) >> 1;
int ref = 0;
if (L<=mid) ref = querySUM(l, mid, L, R, lson);
if (mid<R) ref += querySUM(mid+1, r, L, R, rson);
return ref;
}

inline void addEdge(int a, int b)
{
to[cur] = b; nex[cur] = _index[a]; _index[a] = cur++;
to[cur] = a; nex[cur] = _index[b]; _index[b] = cur++;
}

inline void ri(int &x)
{
char c;
bool mus = false;
while ((c=getchar())<'0' || '9'<c) if (c=='-') mus = true;
x = c-'0';
while ('0'<=(c=getchar()) && c<='9')
x = 10*x + c - '0';
if (mus) x = -x;
}