2021牛客多校第一场[线段树合并+树重心] C:Cut the Tree

2021牛客多校补题

前置知识:线段树合并 树重心

2021牛客暑期多校训练营1 C.Cut the Tree

题意

给一个带点权的树,你可以删去树上一个点,最小化所有子树最长上升子序列的长度最大值 。 N <= 100000

题解

首先竟然是最小化所有子树的最长上升子序列,那么我们第一步就是需要找到一颗树上的最长上升子序列,考虑暴力的做法,只使用搜索做对于每个点暴力搜索统计最高会用到n^ 2复杂度很明显会超时。

我们换一种想法,对树中的每个点存储当前树与子树的最长上升子序列,对于每个点都存储一段区间,如果每次插入一个值都要遍历整颗树那么还是会超时,所以可以考虑线段树合并,因为每个结点相当于存储了一个线段树,我们只需要从树中的叶子结点依次向上合并即可.

因为可能会出现以下情况这时候可以由中间6点分开左边为最长上升子序列右边为最长下降子序列.所以需要单独开一个数组来维护最长下降子序列.

那么子序列找到了,现在来考虑删点的操作,直观上很明显要想答案最优,删除的点一定在最长子序列的链上,而且删除的点要尽可能的在中间,这样能确保这条链分出来的序列会最小,但这样还是很难确定答案,因为就算删除最长链的最中心点,还是可能会出现其他比最长链短一点链更长,所以,要想比删除最长链中点要优一点的做法就是找所有链都有交集的点去删.因为若干条链的交集一定还是一条链,所以删除了一条链的中点后还是可以继续删除除删除点以外的最长链,直到交集中没有其他链为止.

那么如何找到最长链中心呢,可以先找到树的重心,求去除中心后各边的最大子序列子。

//找树重心
void dfs(int to, int fa){
    sz[to] = 1;
    int mx = 0;
    for(int i = head[to]; i + 1; i = e[i].nxt){
        int y = e[i].to;
        if(y == fa || vis[y]) continue;
        dfs(y, to);
        sz[to] += sz[y];
        mx = max(mx, sz[y]);
    }
    mx = max(mx, sum - sz[to]);
    if(mx < minNode){
        minNode = mx;
        root = to;
    }
}

之后在找去除已找到重心点的重心找最长子序列反复这样操作直到找不到新的重心点为止。

void slove(int x){
    vis[x] = 1;
    int mx = 0, ret = -1;
    for(int i = head[x]; i + 1; i = e[i].nxt){	//对所有经过重心的边求最长子序列
        int y = e[i].to;
        res = T = 0;
        int tt = result_dfs(y, x);
        if(tt > mx){	//找到最长的子序列
            mx = tt; ret = y;
        }
    }
    ans = min(ans, mx);
    if(vis[ret] || ret == -1) return;
    minNode = root = sum = sz[ret];
    dfs(ret, 0);
    slove(root);
}

之后在附上找最长链的代码,因为在找重心周边的子序列时可能会出现比更长的子序列在重心边的子树下。所以在线段树合并时需要吧这种情况加进去

int merge(int x, int y){	//线段树合并
    if(!x) return y;
    if(!y) return x;
    up[x] = max(up[x], up[y]);
    down[x] = max(down[x], down[y]);
    res = max(res, max(up[lc[x]] + down[rc[y]], up[lc[y]] + down[rc[x]]));	//找重心对应边子树的最长子序列
    lc[x] = merge(lc[x], lc[y]);
    rc[x] = merge(rc[x], rc[y]);
    return x;
}
int result_dfs(int to, int fa){
    rt[to] = 0;
    for(int i = head[to]; i + 1; i = e[i].nxt){	//找到叶子结点从下往上
        ll y = e[i].to;
        if(y == fa) continue;
        result_dfs(y, to);
    }
    int now_up = 0, now_down = 0;
    for(int i = head[to]; i + 1; i = e[i].nxt){
        int y = e[i].to;
        if(y == fa) continue;
        int ups = query(rt[y], 1, n, 1, max(a[to] - 1, (int)1), up);	
        int downs = query(rt[y], 1, n, min(a[to] + 1, n), n, down);
        rt[to] = merge(rt[to], rt[y]);
        res = max(res, ups + now_down + 1);
        res = max(res, now_up + downs + 1);
        now_up = max(now_up, ups);
        now_down = max(now_down, downs);
    }
    update(rt[to], 1, n, a[to], a[to], now_up + 1, now_down + 1);
    return res;
}

线段树合并最高复杂度会到nlogn但因为每次合并树不会是满树所以会比这复杂度再低一点,之后找链的重心次数最多不会超过logn次所以总复杂度为O(nlog^2n)

完整代码


int n;
struct Node{
    int to, nxt;
}e[M];
int head[M], tot = 0;
void add(int a, int b){
    e[tot].to = b;
    e[tot].nxt = head[a];
    head[a] = tot ++;
}
int a[N];
int root = 0, minNode = inf, sum = 0;
int sz[N], vis[N];
void dfs(int to, int fa){
    sz[to] = 1;
    int mx = 0;
    for(int i = head[to]; i + 1; i = e[i].nxt){
        int y = e[i].to;
        if(y == fa || vis[y]) continue;
        dfs(y, to);
        sz[to] += sz[y];
        mx = max(mx, sz[y]);
    }
    mx = max(mx, sum - sz[to]);
    if(mx < minNode){
        minNode = mx;
        root = to;
    }
}
int res = 0, ans = inf, T = 0;
int rt[M], lc[M], rc[M], up[M], down[M];
int query(int &rt, int l, int r, int L, int R, int *h){
    if(!rt) return 0;
    if(L <= l && r <= R) return h[rt];
    int mid = (l + r) >> 1, ans = 0;
    if(L <= mid) ans = max(ans, query(lc[rt], l, mid, L, R, h));
    if(mid < R) ans = max(ans, query(rc[rt], mid + 1, r, L, R, h));
    return ans;
}
int merge(int x, int y){
    if(!x) return y;
    if(!y) return x;
    up[x] = max(up[x], up[y]);
    down[x] = max(down[x], down[y]);
    res = max(res, max(up[lc[x]] + down[rc[y]], up[lc[y]] + down[rc[x]]));
    lc[x] = merge(lc[x], lc[y]);
    rc[x] = merge(rc[x], rc[y]);
    return x;
}
void update(int &rt, int l, int r, int L, int R, int v1, int v2){
    if(!rt){
        rt = ++ T;
        up[rt] = down[rt] = lc[rt] = rc[rt] = 0;
    }
    if(L <= l && r <= R){
        up[rt] = v1;
        down[rt] =  v2;
        return ;
    }
    int mid = (l + r) >> 1;
    if(L <= mid) update(lc[rt], l, mid, L, R, v1, v2);
    else update(rc[rt], mid + 1, r, L, R, v1, v2);
    up[rt] = max(up[lc[rt]], up[rc[rt]]);
    down[rt] = max(down[lc[rt]], down[rc[rt]]);
}
int result_dfs(int to, int fa){
    rt[to] = 0;
    for(int i = head[to]; i + 1; i = e[i].nxt){
        ll y = e[i].to;
        if(y == fa) continue;
        result_dfs(y, to);
    }
    int now_up = 0, now_down = 0;
    for(int i = head[to]; i + 1; i = e[i].nxt){
        int y = e[i].to;
        if(y == fa) continue;
        int ups = query(rt[y], 1, n, 1, max(a[to] - 1, (int)1), up);
        int downs = query(rt[y], 1, n, min(a[to] + 1, n), n, down);
        rt[to] = merge(rt[to], rt[y]);
        res = max(res, ups + now_down + 1);
        res = max(res, now_up + downs + 1);
        now_up = max(now_up, ups);
        now_down = max(now_down, downs);
    }
    update(rt[to], 1, n, a[to], a[to], now_up + 1, now_down + 1);
    return res;

}
void slove(int x){
    vis[x] = 1;
    int mx = 0, ret = -1;
    for(int i = head[x]; i + 1; i = e[i].nxt){
        int y = e[i].to;
        res = T = 0;
        int tt = result_dfs(y, x);
        if(tt > mx){
            mx = tt; ret = y;
        }
    }
    ans = min(ans, mx);
    if(vis[ret] || ret == -1) return;
    minNode = root = sum = sz[ret];
    dfs(ret, 0);
    slove(root);
}
int main(){
    cin >> n;
    MM(head, - 1);
    FOR(i, 1, n - 1){
        int l, r; cin >> l >> r;
        add(l, r); add(r, l);
    }
    FOR(i, 1, n)  cin >> a[i];
    sum = n;
    dfs(1, 0);
    slove(root);
    PLN(ans);

    return 0;
}
赞赏