【模板】点分治

点分质

学习连接:https://zhuanlan.zhihu.com/p/42102528

题目一

题目大意

给定一棵有n个点的树,询问树上距离为k的点对是否存在。

题目链接

https://www.luogu.org/problemnew/show/P3806

代码

#include <bits/stdc++.h>
const int N=10005;
const int NN=10000005;
using namespace std;
struct node {
    int to,w,next;
} e[N*2];
int head[N],cnt;
int n,m,q[105],sum;
int maxx[N],sz[N],rt;
int vis[N],dis[N],f[N],tot,ff[N];
int judge[NN];
int book[NN];
void init() {
    memset(head,-1,sizeof(head));
    cnt=0;
}
void add(int u,int v,int w) {
    e[cnt]= {v,w,head[u]};
    head[u]=cnt++;
}
void getrt(int x,int pre) {
    //找树的重心
    sz[x]=1;
    maxx[x]=0;
    for(int i=head[x]; i!=-1; i=e[i].next) {
        int to=e[i].to;
        if(to!=pre&&!vis[to]) {
            getrt(to,x);
            sz[x]+=sz[to];
            maxx[x]=max(maxx[x],sz[to]);
        }
    }
    maxx[x]=max(maxx[x],sum-sz[x]);
    if(maxx[x]<maxx[rt]) {
        rt=x;
    }
}
void getdis(int x,int pre) {
    //处理根到任意一点的距离
    f[++tot]=dis[x];
    for(int i=head[x]; i!=-1; i=e[i].next) {
        int to=e[i].to;
        if(to==pre||vis[to])
            continue;
        dis[to]=dis[x]+e[i].w;
        getdis(to,x);
    }
}
void cal(int x) {
    int p=0;
    for(int i=head[x]; i!=-1; i=e[i].next) {
        int to=e[i].to;
        int w=e[i].w;
        if(vis[to])
            continue;
        tot=0;
        dis[to]=w;
        getdis(to,x);//处理子树中所有的dis
        for(int j=1; j<=tot; j++) {
            for(int k=1; k<=m; k++) {
                if(q[k]>=f[j]) {
                    //若查询q[k],当前有f[j],若之前出现过q[k]-f[j]则可以
                    book[q[k]]|=judge[q[k]-f[j]];
                }
            }
        }
        for(int j=1; j<=tot; j++) {
            //将出现过的dis都放入ff,judge代表以x为根的树中出现过
            ff[++p]=f[j];
            judge[f[j]]=1;
        }
    }
    for(int i=1; i<=p; i++) {//恢复
        judge[ff[i]]=0;
    }
}
void solve(int x) {
    vis[x]=1;
    judge[0]=1;
    cal(x);
    for(int i=head[x]; i!=-1; i=e[i].next) {
        int to=e[i].to;
        if(vis[to])
            continue;
        //找到子树的重心并递归处理
        rt=0;
        maxx[rt]=0x3f3f3f3f;
        sum=sz[to];
        getrt(to,0);
        solve(rt);
    }
}
int main() {
    ios::sync_with_stdio(false);
    init();
    cin>>n>>m;
    for(int i=1; i<n; i++) {//建边
        int u,v,w;
        cin>>u>>v>>w;
        add(u,v,w);
        add(v,u,w);
    }
    for(int i=1; i<=m; i++) {
        cin>>q[i];
    }
    rt=0;
    maxx[rt]=0x3f3f3f3f;
    sum=n;
    getrt(1,0);//第一次先找整棵树的重心
    solve(rt);//点分质
    for(int i=1; i<=m; i++) {
        if(book[q[i]])
            cout<<"AYE"<<endl;
        else
            cout<<"NAY"<<endl;
    }
    return 0;
}

题目二

题目链接

https://codeforces.com/problemset/problem/161/D

题目大意

问有多少条长度等于k的路径

代码

与之前相似只不过将ff标记数组改为了num数量数组

#include <bits/stdc++.h>
const int N=50005;
using namespace std;
int cnt=0;
int head[N],sz[N],vis[N],n,k;
int rt,maxx[N],sum;
int num[550],numm[550],dis[N];
long long ans=0;
struct node{
    int v,next,w;
}edge[N*2];
void init(){
    memset(head,-1,sizeof(head));
}
void add(int u,int v,int w){
    edge[cnt]={v,head[u],w};
    head[u]=cnt++;
}
void getrt(int x,int pre){
    sz[x]=1;
    maxx[x]=0;
    for(int i=head[x];i!=-1;i=edge[i].next){
        int to=edge[i].v;
        if(to!=pre&&!vis[to]){
            getrt(to,x);
            sz[x]+=sz[to];
            maxx[x]=max(maxx[x],sz[to]);
        }
    }
    maxx[x]=max(maxx[x],sum-sz[x]);
    if(maxx[x]<maxx[rt]){
        rt=x;
    }
}
void getdis(int x,int pre){
    if(dis[x]<=k)
        num[dis[x]]++;
    for(int i=head[x];i!=-1;i=edge[i].next){
        int to=edge[i].v;
        if(to==pre||vis[to])
            continue;
        dis[to]=dis[x]+edge[i].w;
        getdis(to,x);
    }
}
void cal(int x){
    for(int i=head[x];i!=-1;i=edge[i].next){
        int to=edge[i].v;
        int w=edge[i].w;
        if(vis[to])
            continue;
        for(int j=0;j<=k;j++)
            num[j]=0;
        dis[to]=w;
        getdis(to,x);
        for(int j=0;j<=k;j++){
            ans+=num[j]*numm[k-j];
        }
        for(int j=0;j<=k;j++)
            numm[j]+=num[j];
    }
    for(int i=0;i<=k;i++){
        numm[i]=0;
    }
}
void solve(int x){
    vis[x]=1;
    numm[0]=1;
    cal(x);
    for(int i=head[x];i!=-1;i=edge[i].next){
        int to=edge[i].v;
        if(vis[to])
           continue;
        rt=0;
        maxx[rt]=0x3f3f3f3f;
        sum=sz[to];
        getrt(to,0);
        solve(rt);
    }
}
int main()
{
    ios::sync_with_stdio(false);
    init();
    cin>>n>>k;
    int u,v;
    for(int i=1;i<n;i++){
        cin>>u>>v;
        add(u,v,1);
        add(v,u,1);
    }
    rt=0;
    maxx[rt]=0x3f3f3f3f;
    sum=n;
    getrt(1,0);
    solve(rt);
    cout<<ans<<endl;
    return 0;
}

题目三

题目链接

https://www.luogu.org/problemnew/show/P4178

题目大意

输出距离小于等于k的点对的数量。

解题思路

我们可以把每一个子节点到当前根(重心)的距离排序,然后用类似双指针的方法来求小于等于K的边的数量。然后再用类似容斥的思想将不合法的情况剪掉。

代码

#include <bits/stdc++.h>
const int N=40040;
using namespace std;
int n,k,cnt,rt,maxx[N],sz[N];
int head[N],vis[N],f[N],dis[N];
int sum=0,tot;
long long ans=0;
struct node
{
    int v,w,next;
}edge[N*2];
void init(){
    memset(head,-1,sizeof(head));
}
void add(int u,int v,int w){
    edge[cnt]={v,w,head[u]};
    head[u]=cnt++;
}
void getrt(int x,int pre){
    sz[x]=1;
    maxx[x]=0;
    for(int i=head[x];i!=-1;i=edge[i].next){
        int to=edge[i].v;
        if(to!=pre&&!vis[to]){
            getrt(to,x);
            sz[x]+=sz[to];
            maxx[x]=max(maxx[x],sz[to]);
        }
    }
    maxx[x]=max(maxx[x],sum-sz[x]);
    if(maxx[x]<maxx[rt]){
        rt=x;
    }
}
void getdis(int x,int pre){
    f[tot++]=dis[x];
    for(int i=head[x];i!=-1;i=edge[i].next){
        int to=edge[i].v;
        int w=edge[i].w;
        if(pre==to||vis[to])
            continue;
        dis[to]=dis[x]+w;
        getdis(to,x);
    }
}
int cal(int x,int w){
    dis[x]=w;
    tot=0;
    getdis(x,0);
    sort(f,f+tot);
    int ans=0;
    int l=0,r=tot-1;
    while(l<=r){
        if(f[l]+f[r]<=k){
            ans+=r-l;
            l++;
        }
        else
            r--;
    }
    return ans;
}
void solve(int x){
    vis[x]=1;
    ans+=cal(x,0);
    for(int i=head[x];i!=-1;i=edge[i].next){
        int to=edge[i].v;
        int w=edge[i].w;
        if(vis[to])
            continue;
        ans-=cal(to,w);
        maxx[rt=0]=0x3f3f3f3f;
        sum=sz[to];
        getrt(to,0);
        solve(rt);
    }
}
int main()
{
    ios::sync_with_stdio(false);
    init();
    cin>>n;
    for(int i=1;i<n;i++){
        int u,v,w;
        cin>>u>>v>>w;
        add(u,v,w);
        add(v,u,w);
    }
    cin>>k;
    rt=0;
    maxx[rt]=0x3f3f3f3f;
    sum=n;
    getrt(1,0);
    solve(rt);
    cout<<ans<<endl;
    return 0;
}

发表评论

电子邮件地址不会被公开。 必填项已用*标注