轻重链剖分教程(自用) + 模版

轻重链剖分教程(自用)

轻重链剖分

前置知识: DFS序, 线段树, LCA

轻重链剖分的作用

  1. 将树从x到y结点上的最短路径上的所有值都加上v

  2. 求树从x到y结点上的最短路径上的所有值的和

  3. 将以x为根结点的所有结点值加上v

  4. 求以x为根结点的所有结点值的和

    前面2问中x到y结点上的最短路径很明显就是求x和y的LCA

算法概念

​ 该算法就和它的名字一样就是将一颗树划分成多个链,然后维护这些链. 对于每条链可以看作是一个区间,对于区间操作一般我们选择线段树进行维护(线段树复杂度低,性质多,代码多,线段树nb!)。

​ 对于这个算法需要知道一些定义:

概念 定义
重儿子[结点] 父亲结点的所有儿子中子树结点数最多的结点
轻儿子[结点] 父亲结点中除了重儿子以外的所有儿子结点
重边[一条边] 父亲结点和重儿子连成的一条边
轻边[一条边] 父亲结点和轻儿子连成的一条边
重链[多条边] 由多条重边链接而成的路径,链头为轻链
轻链 由多条轻边链接而成的路径,链头是它自己

如下面一张图:

  • 粉色点代表重儿子
  • 黑色点代表轻儿子
  • 相邻红色边连起来代表重边
  • 相邻蓝色边连起来代表轻边
  • 所有红色边代表重链
  • 所有蓝色边代表轻链

对于儿子数相同的两个点来说选谁当重儿子都可以

算法实现

​ 首先需要对一颗树中的所有点进行重新编号, 默认一颗树上的点的编号可能会很乱,为了方便管理需要对这些点进行一个重新编号,但这编号不能乱编

​ 对于一颗树的编号顺序:

1. 先遍历重儿子结点进行编号

2. 按照dfs序进行进行编号,在遍历完重儿子后,就遍历其他轻儿子

使用这种方式遍历编号有几点好处.

  1. 重链中的所有点一定是连续的,因为每次是优先访问重儿子,所以它们之间的结点必然是连续的
  2. 如果(u,v)是一条轻边,那么size[v] <= size[u] / 2,如果大于那么v结点必定是重儿子
  3. 每个点到根结点,必定不会超过logn条重路径,和logn条轻路径


​ 对于一条重链来说我们需要存储链头(即深度最浅的结点),对于一条轻链来说它的链头就是它自己,重链的头结点也算是一个轻儿子.

​ 总结元素:

定义 含义
deep[x] 结点x的深度
fa[x] x结点的父亲
size[x] 以x为根结点的树大小
wson[x] 结点x的重儿子
dfn[x] 树的dfs序,为重新编号的序列
top[x] x结点的链头
pre[x] x结点的前驱结点

​ 求出上述结点只需要进行两次dfs操作即可,对于第一次dfs,主要求出每个结点的深度,父亲,树大小,以及重儿子,对于第二dfs根据重儿子求dfs序,并且记录链头以及前驱结点.

void dfs1(ll to, ll fa){
    sz[to] = 1;
    fas[to] = fa;   //设置父结点
    deep[to] = deep[fa] + 1; //计算深度
    for(ll i = head[to]; i + 1; i = e[i].nxt){
        if(e[i].to == fa) continue;
        dfs1(e[i].to, to);
        sz[to] += sz[e[i].to];  //计算树的大小
        if(sz[wson[to]] < sz[e[i].to]){ //获取重儿子
            wson[to] = e[i].to;
        }
    }
}
ll cnt = 0;
void dfs2(ll to, ll fa){
    dfn[to] = ++ cnt;   //dfs序列(重新编号)
    pre[cnt] = to;  //设置前驱结点
    top[to] = fa;   //设置链头
    if(wson[to]) dfs2(wson[to], fa);    //优先遍历重儿子
    else return;
    for(ll i = head[to]; i + 1; i = e[i].nxt){
        if(wson[to] == e[i].to || e[i].to == fas[to]) continue;
        dfs2(e[i].to, e[i].to); //轻儿子的链头是它自己
    }
}

对于两个dfs复杂度为o(2m)

​ 对于求x到y的最短路径通过链头进行LCA操作即可,每次优先修改链头深度最大的点向上跳跃一个区间并通过线段树修改跳跃区间值

void up_lca_tree(ll x, ll y, ll v){
    while(top[x] != top[y]){
        if(deep[top[x]] < deep[top[y]]) swap(x, y);
        update(1, 1, n, dfn[top[x]], dfn[x], v);
        x = fas[top[x]];
    }
    if(deep[x] < deep[y]) swap(x, y);
    update(1, 1, n, dfn[y], dfn[x], v);
}

​ 如下图找6和9点最短路径,先找6的链头为6(轻儿子的链头是自己)并修改,然后找3的链头为1并修改1-3结点的值(因为重链结点是连续的), 以此类推.最后能维护这段最短距离的区间(绿色指向为x的父亲)

​ 查找同理:

ll qe_lca_tree(ll x, ll y){
    ll ans = 0;
    while(top[x] != top[y]){
        if(deep[top[x]] < deep[top[y]]) swap(x, y);
        ans += query(1, 1, n, dfn[top[x]], dfn[x]);
        x = fas[top[x]];
    }
    if(deep[x] < deep[y]) swap(x, y);
    ans += query(1, 1, n, dfn[y], dfn[x]);
    return ans;
}

显然我们知道线段树的修改和查找复杂度为O(logn),查找树上的距离取决于这条路径上的重链和轻链,由性质可知,重链和轻链的数量不会超过2logn,所以最短距离的修改和查询的总复杂度为O(log^2n)

​ 对于上述的第3,4问修改/查找以x为根结点的所有子结点的值,其实通过上述的编号性质能得出,以x为根结点的所有子树编号范围为:dfn[x] + sz[dfn[x]] - 1,只需要使用线段树进行修改/查询即可

void tree_up(ll x, ll v){
    update(1, 1, n, dfn[x], dfn[x] + sz[x] - 1, v);
}
ll tree_qe(ll x){
    return query(1 , 1, n, dfn[x], dfn[x] + sz[x] - 1);
}

树链剖分的总复杂度为O(mlog^n)

模版代码

题目:【模板】轻重链剖分

链接: https://www.luogu.com.cn/problem/P3384

ll n, m, r, p;
ll a[N];
struct Node{
    ll to, nxt;
}e[N];
ll head[N], tot = 0;
void add(ll a, ll b){
    e[tot].to = b;
    e[tot].nxt = head[a];
    head[a] = tot ++ ;
}
ll tree[N], lazy[N];
ll ls[N], rs[N];
ll top[N], pre[N], wson[N], dfn[N], sz[N], deep[N], fas[N];
void build(ll rt, ll l, ll r){
    ls[rt] = l; rs[rt] = r;
    if(l == r){
        tree[rt] = a[pre[l]];
        return;
    }
    ll mid = (l + r) >> 1;
    build(rt << 1, l, mid);
    build(rt << 1 | 1, mid + 1, r);
    tree[rt] = tree[rt << 1] + tree[rt << 1 | 1];
}
ll vis[N];
void dfs1(ll to, ll fa){
    sz[to] = 1;
    fas[to] = fa;
    deep[to] = deep[fa] + 1;
    for(ll i = head[to]; i + 1; i = e[i].nxt){
        if(e[i].to == fa) continue;
        dfs1(e[i].to, to);
        sz[to] += sz[e[i].to];
        if(sz[wson[to]] < sz[e[i].to]){
            wson[to] = e[i].to;
        }
    }
}
ll cnt = 0;
void dfs2(ll to, ll fa){
    dfn[to] = ++ cnt;
    pre[cnt] = to;
    top[to] = fa;
    if(wson[to]) dfs2(wson[to], fa);
    else return;
    for(ll i = head[to]; i + 1; i = e[i].nxt){
        if(wson[to] == e[i].to || e[i].to == fas[to]) continue;
        dfs2(e[i].to, e[i].to);
    }
}
void push_lazy(ll x){
    if(lazy[x]){
        lazy[x << 1] += lazy[x];
        lazy[x << 1 | 1] += lazy[x];
        tree[x << 1] += lazy[x] * (rs[x << 1] - ls[x << 1] + 1);
        tree[x << 1 | 1] += lazy[x] * (rs[x << 1 | 1] - ls[x << 1 | 1] + 1);
        lazy[x] = 0;
    }
}
void update(ll rt, ll l, ll r, ll L, ll R, ll v){
    if(L <= l && r <= R){
        tree[rt] += v * (r - l + 1);
        lazy[rt] += v;
        tree[rt] %= p;
        return ;
    }
    push_lazy(rt);
    ll mid = (l + r) >> 1;
    if(L <= mid) update(rt << 1, l, mid, L, R, v);
    if(mid < R) update(rt << 1 | 1, mid + 1, r, L, R, v);

    tree[rt] = tree[rt << 1] + tree[rt << 1 | 1];
}
ll query(ll rt, ll l, ll r, ll L, ll R){
    if(L <= l && r <= R){
        return tree[rt];
    }
    push_lazy(rt);
    ll ans = 0;
    ll mid = (l + r) >> 1;
    if(L <= mid) ans += query(rt << 1, l, mid, L, R);
    if(mid < R) ans += query(rt << 1 | 1, mid + 1, r, L, R);
    tree[rt] = tree[rt << 1] + tree[rt << 1 | 1];
    return ans;
}
void up_lca_tree(ll x, ll y, ll v){
    v %= p;
    while(top[x] != top[y]){
        if(deep[top[x]] < deep[top[y]]) swap(x, y);
        update(1, 1, n, dfn[top[x]], dfn[x], v);
        x = fas[top[x]];
    }
    if(deep[x] < deep[y]) swap(x, y);
    update(1, 1, n, dfn[y], dfn[x], v);
}
ll qe_lca_tree(ll x, ll y){
    ll ans = 0;
    while(top[x] != top[y]){
        if(deep[top[x]] < deep[top[y]]) swap(x, y);
        ans += query(1, 1, n, dfn[top[x]], dfn[x]);
        x = fas[top[x]];
    }
    if(deep[x] < deep[y]) swap(x, y);
    ans += query(1, 1, n, dfn[y], dfn[x]);
    return ans;
}
void tree_up(ll x, ll v){
    v %= p;
    update(1, 1, n, dfn[x], dfn[x] + sz[x] - 1, v);
}
ll tree_qe(ll x){
    return query(1 , 1, n, dfn[x], dfn[x] + sz[x] - 1);
}
int main(){
    MM(head, -1);
    RLL4(n, m, r, p);
    FOR(i, 1, n) RLL(a[i]);
    FOR(i, 1, n - 1){
        ll l, r; RLL2(l, r);
        add(l, r);
        add(r, l);
    }
    dfs1(r, 0);
    dfs2(r, r);
    build(1, 1, n);
    while(m --){
        ll s; RLL(s);
        if(s == 1){
            ll x, y, z;
            RLL3(x, y, z);
            up_lca_tree(x, y, z);
        }else if(s == 2){
            ll x, y;
            RLL2(x, y);
            PLN(qe_lca_tree(x, y) % p );
        }else if(s == 3){
            ll x, z;
            RLL2(x, z);
            tree_up(x, z);
        }else{
            ll x; RLL(x);
            PLN(tree_qe(x) % p);
        }
    }

    return 0;
}
赞赏