Sorry, your browser cannot access this site
This page requires browser support (enable) JavaScript
Learn more >

P3384 【模板】重链剖分/树链剖分

适用问题

对一个给定的树,维护(修改、查询)其节点点权/边权,操作可 \(O(\log n)\) 次处理(如路径、子树)。

实现

将一棵树剖分成若干条链,每一条链通过区间数据结构维护。

将一个点的子树中最大的那个做为重儿子,其他的叫做轻儿子,对于所有点的重儿子连的链称为重链。我们将重链作为主链,将轻儿子构成的其他轻链加在重链的后面。很明显,对于每一条链,它的节点都是相连的(废话)。对于每一个子树,它的所有的节点都排在子树根的后面。

所以我们就把一个树成功地拆解成了一条链。

因为这条链有如上我所说的几个性质,题目要求如果是能在链上连续维护的(如求子树权值和,求任意两点间距离和),我们就可以用数据结构维护它了,比如线段树。

由于树的重心的性质,一次操作拆成了 \(O(\log n)\) 次修改。

剖分部分

首先我们需要两次 dfs。

dfs1

用于建树,顺便记录每个节点的父亲 和 该点的深度 和 它的子树的大小 和 它的重儿子。

fa[] 父亲节点 dep[] 深度 siz[] 子树大小 son[] 重儿子(重儿子为子树大者)

1
2
3
4
5
6
7
8
9
10
11
12
13
void dfs1(int x,int f,int depth)
{
siz[x]=1; dep[x]=depth; fa[x]=f;
int maxson=-1;
for(R int i=head[x];i;i=e[i].nxt)
{
int u=e[i].to;
if(u==f)continue;
dfs1(u,x,depth+1);
siz[x]+=siz[u];
if(siz[u]>maxson)son[x]=u,maxson=siz[u];
}
}

dfs2

用于将一个树退化成链,记录节点在链的编号 和 节点的链首节点 和 链上节点的权值。

cnt 时间戳 id[] 编号 top[] 链首节点 a[] 原值 w[] 链上的值(便于维护)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
void dfs2(int x,int topf)
{
id[x]=++cnt;
w[cnt]=a[x];
top[x]=topf;
if(!son[x])return;
dfs2(son[x],topf);
for(R int i=head[x];i;i=e[i].nxt)
{
int u=e[i].to;
if(u==fa[x]||u==son[x])continue;
dfs2(u,u);
}
}

求两点间距离

要求两点间距离:

若两点在一条链上(top[x]==top[y])我们直接求两点间的距离即可。

若两点不在一条链上,那么求更深的那个点 x 到此刻链首 top[x] 的距离,然后令 x=fa[top[x]],即可将 x 更新到新链上。重复操作,每次只将深度更大的点向上更新。最终两点会处于同一条链(重链或轻链,最远是重链)上,然后再加上两点间的和就可以了。

用线段树维护。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
inline int queryrange(int x,int y)
{
int ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
res=0;
st.query(1,id[top[x]],id[x]);
ans=(ans+res)%mod;
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
res=0;
st.query(1,id[x],id[y]);
ans=(ans+res)%mod;
return ans;
}

更新两点间距离

和上面是一样的,分成若干条链更新就可以了。

1
2
3
4
5
6
7
8
9
10
inline void updaterange(int x,int y,int k){
k%=mod;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
st.update(1,id[top[x]],id[x],k);
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
st.update(1,id[x],id[y],k);
}

更新/查询子树权值(和)

如上,子树节点在链上一定是在根的后面并且连续的。

所以要更新以 x 为根节点的所有子树结点,就更新 id[x]~id[x]+size[x]-1 的范围即可。

模板

当然要根据题目需要做各种修改。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#include<bits/stdc++.h>
using namespace std;
inline int read(){
int x=0,w=0;char c=getchar();
while(!isdigit(c)) w|=c=='-',c=getchar();
while(isdigit(c)) x=x*10+c-'0',c=getchar();
return w?-x:x;
}
namespace star
{
const int maxn=1e5+10,maxm=1e5+10;
int n,m,root,mod;
int fa[maxn],dep[maxn],son[maxn],siz[maxn],top[maxn],dfn[maxn],a[maxn],w[maxn];
int ecnt,head[maxn],to[maxn<<1],nxt[maxn<<1];
inline void addedge(int a,int b){
to[++ecnt]=b,nxt[ecnt]=head[a],head[a]=ecnt;
to[++ecnt]=a,nxt[ecnt]=head[b],head[b]=ecnt;
}
struct SegmentTree{
#define ls (ro<<1)
#define rs (ro<<1|1)
#define mid ((l+r)>>1)
int val[maxn<<2],tag[maxn<<2];
inline void pushup(int ro){val[ro]=(val[ls]+val[rs])%mod;}
inline void pushdown(int ro,int l,int r){
tag[ls]+=tag[ro],val[ls]=(val[ls]+(mid-l+1)*tag[ro])%mod;
tag[rs]+=tag[ro],val[rs]=(val[rs]+(r-mid)*tag[ro])%mod;
tag[ro]=0;
}
void build(const int &ro=1,const int &l=1,const int &r=n){
if(l==r) return val[ro]=w[l]%mod,void();
build(ls,l,mid),build(rs,mid+1,r);
pushup(ro);
}
void update(int x,int y,int k,const int &ro=1,const int &l=1,const int &r=n){
if(x==l and y==r) return tag[ro]+=k,val[ro]=(val[ro]+k*(r-l+1))%mod,void();
if(tag[ro]) pushdown(ro,l,r);
if(y<=mid) update(x,y,k,ls,l,mid);
else if(x>mid) update(x,y,k,rs,mid+1,r);
else update(x,mid,k,ls,l,mid),update(mid+1,y,k,rs,mid+1,r);
pushup(ro);
}
int query(int x,int y,const int &ro=1,const int &l=1,const int &r=n){
if(x==l and y==r) return val[ro];
if(tag[ro]) pushdown(ro,l,r);
if(y<=mid) return query(x,y,ls,l,mid);
if(x>mid) return query(x,y,rs,mid+1,r);
return (query(x,mid,ls,l,mid)+query(mid+1,y,rs,mid+1,r))%mod;
}
#undef ls
#undef rs
#undef mid
}st;
void dfs1(int x,int Fa){
fa[x]=Fa,siz[x]=1,dep[x]=dep[Fa]+1;
int mx=-1;
for(int u,i=head[x];i;i=nxt[i]) if((u=to[i])!=Fa){
dfs1(u,x);
siz[x]+=siz[u];
if(siz[u]>mx) son[x]=u,mx=siz[u];
}
}
void dfs2(int x,int topf){
w[dfn[x]=++dfn[0]]=a[x];
top[x]=topf;
if(!son[x]) return;
dfs2(son[x],topf);
for(int u,i=head[x];i;i=nxt[i]) if((u=to[i])!=fa[x] and u!=son[x]) dfs2(u,u);
}
inline void update(int x,int y,int k){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
st.update(dfn[top[x]],dfn[x],k);
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
st.update(dfn[x],dfn[y],k);
}
inline int query(int x,int y){
int ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans=(ans+st.query(dfn[top[x]],dfn[x]))%mod;
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
ans=(ans+st.query(dfn[x],dfn[y]))%mod;
return ans;
}
inline void work(){
n=read(),m=read(),root=read(),mod=read();
for(int i=1;i<=n;i++) a[i]=read();
for(int i=1;i<n;i++) addedge(read(),read());
dfs1(root,root),dfs2(root,root);
st.build();
while(m--){
int a,b,c;
switch(read()){
case 1:{
a=read(),b=read(),c=read()%mod;
update(a,b,c);
break;
}
case 2:{
printf("%d\n",query(read(),read()));
break;
}
case 3:{
a=read(),b=read();
st.update(dfn[a],dfn[a]+siz[a]-1,b);
break;
}
case 4:{
a=read();
printf("%d\n",st.query(dfn[a],dfn[a]+siz[a]-1));
break;
}
}
}
}
}
signed main(){
star::work();
return 0;
}

给小狼留言