SJY摆棋子 HYSBZ – 2648 – 初识 K-D 树 – 模板

关于K-D树

k-d树学习笔记
基础-12:15分钟理解KD树

代码(模板)

#include <iostream>
#include <bits/stdc++.h>
using namespace std;

const int MAX=500050;
const int dim=2;
const int inf=0x3f3f3f3f;

struct node
{
    int l,r;
    //d 数组储存当前点,
    //minn 和 maxn 表示当前节点维护的矩形的边界
    int d[dim],minn[dim],maxn[dim];
    //对节点进行初始化
    inline void maintail()
    {
        for(int i=0;i<dim;i++)
        {
            minn[i]=maxn[i]=d[i];
        }
        l=r=0;
    }
}tree[MAX*2];

//通过修改全局变量 D,实现按不同维度排序
int D;
bool operator <(const node &a,const node &b)
{
    return a.d[D]<b.d[D];
}

void pushup(int p)
{
    int son[2]={tree[p].l,tree[p].r};
    for(int i=0;i<2;i++)
    {
        if(!son[i])continue;
        for(int j=0;j<dim;j++)
        {
            tree[p].maxn[j]=max(tree[son[i]].maxn[j],tree[p].maxn[j]);
            tree[p].minn[j]=min(tree[son[i]].minn[j],tree[p].minn[j]);
        }
    }
}

int build(int l,int r,int now)
{
    int mid=(l+r)>>1;
    D=now;
    //将中间数放到 mid 位置,小于中间数的放左边,大于的放右边,
    //不保证左右边有序,类似快排的一部分,复杂度O(N)
    nth_element(tree+l,tree+mid,tree+r+1);
    //初始化节点信息
    tree[mid].maintail();
    if(l<mid)tree[mid].l=build(l,mid-1,(now+1)%dim);
    //(now+1)%dim实现了维度的交替
    if(mid<r)tree[mid].r=build(mid+1,r,(now+1)%dim);
    //维护子树信息
    pushup(mid);
    return mid;
}

void insert(int &o,int k,int now)
{
    if(o==0)
    {
        o=k;
        return;
    }
    if(tree[k].d[now]<tree[o].d[now])insert(tree[o].l,k,(now+1)%dim);
    else insert(tree[o].r,k,(now+1)%dim);
    pushup(o);
}

int ans;

//当前查找的点距离节点维护的矩阵的最近距离
int partionmin(int o,int k)
{
    int rst=0;
    for(int i=0;i<dim;i++)
    {
        if (tree[k].d[i] > tree[o].maxn[i]) rst += tree[k].d[i] - tree[o].maxn[i];
        if (tree[k].d[i] < tree[o].minn[i]) rst += tree[o].minn[i] - tree[k].d[i];
    }
    return rst;
}

void query(int o,int k)
{
    //通过当前节点储存的点更新答案
    int dm=abs(tree[o].d[0]-tree[k].d[0])+abs(tree[o].d[1]-tree[k].d[1]);
    if(dm<ans)
        ans=dm;
    //计算左右子树距离当前点可能的最近的答案
    int dl=tree[o].l?partionmin(tree[o].l,k):inf;
    int dr=tree[o].r?partionmin(tree[o].r,k):inf;
    //通过搜索顺序进行剪枝
    if(dl<dr)
    {
        //如果最近可能的点都大于答案,那么不可能更新答案
        if(dl<ans)query(tree[o].l,k);
        if(dr<ans)query(tree[o].r,k);
    }
    else
    {
        if(dr<ans)query(tree[o].r,k);
        if(dl<ans)query(tree[o].l,k);
    }
}

int root=0,pos=1;

int main()
{
    int n,m;
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)
    {
        for(int j=0;j<dim;j++)
        {
            scanf("%d",&tree[i].d[j]);
        }
    }
    root=build(1,n,0);
    pos=n+1;
    int op;
    while(m--)
    {
        scanf("%d",&op);
        for(int j=0;j<dim;j++)
        {
            scanf("%d",&tree[pos].d[j]);
        }
        if(op==1)
        {
            tree[pos].maintail();
            insert(root,pos,0);
            pos++;
        }
        else
        {
            ans=inf;
            query(root,pos);
            printf("%d\n",ans);
        }
    }
    return 0;
}

附:欧几里得距离函数

inline ll partionMin(int o, int k) {
    if (tree[o].minn[2] > tree[k].d[2]) return INF;
    ll rst = 0;
    for (int i = 0; i < 2; i++) {
        if (tree[k].d[i] > tree[o].maxn[i]) rst += sqr(tree[k].d[i] - tree[o].maxn[i]);
        if (tree[k].d[i] < tree[o].minn[i]) rst += sqr(tree[o].minn[i] - tree[k].d[i]);
    }
    return rst;
}

发表评论

电子邮件地址不会被公开。 必填项已用*标注