最近公共祖先问题模板:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<cstdlib>
#include<string>
#include<vector>
#define N 30000
using namespace std;
int id[N],vis[N],depth[N],in[N],dp[N][20],k;
vector<int>G[N];
int Min(int i,int j)
{
if(depth[i]<=depth[j])
return i;
return j;
}
void rmq_init(int n)
{
for(int i=1; i<=n; i++)
dp[i][0]=i;
for(int j=1; (1<<j)<=n; j++)
for(int i=1; i+(1<<j)-1<=n; i++)
dp[i][j]=Min(dp[i][j-1],dp[i+(1<<(j-1))][j-1]);
}
void dfs(int v,int d)
{
id[v]=k;
vis[k]=v;
depth[k]=d;
k++;
for(int i=0; i<G[v].size(); i++)
{
dfs(G[v][i],d+1);
vis[k]=v;
depth[k]=d;
k++;
}
}
void init(int root,int n)
{
k=1;
dfs(root,0);
rmq_init(n*2-1);
}
int query(int l,int r)
{
int k=(int)(log10(r-l+1)/log10(2.0));
return Min(dp[l][k],dp[r-(1<<k)+1][k]);
}
int main()
{
int t,n,a,b,x,y;
cin>>t;
while(t--)
{
for(int i=0; i<N; i++)
G[i].clear();
memset(in,0,sizeof(in));
cin>>n;
for(int i=1; i<n; i++)
{
scanf("%d%d",&a,&b);
G[a].push_back(b);
in[b]=1;
}
int i;
for(i=1; i<=n; i++)
if(in[i]==0)
break;
init(i,n);
cin>>x>>y;
int ans=query(min(id[x],id[y]),max(id[x],id[y]));
printf("%d\n",vis[ans]);
}
return 0;
}
求解树上两点相距的最近距离模板:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<cstdlib>
#include<string>
#include<vector>
#define N 40010
using namespace std;
int vis[2*N],depth[2*N],id[N],dis[N],dp[N*2][20],k;
struct node
{
int to,w;
};
vector<node>G[N];
int Min(int a,int b)
{
if(depth[a]<=depth[b])
return a;
return b;
}
void rmq_init(int n)
{
for(int i=1; i<=n; i++)
dp[i][0]=i;
for(int j=1; (1<<j)<=n; j++)
for(int i=1; i+(1<<j)-1<=n; i++)
dp[i][j]=Min(dp[i][j-1],dp[i+(1<<(j-1))][j-1]);
}
void dfs(int v,int p,int d)
{
id[v]=k;
vis[k]=v;
depth[k]=d;
k++;
for(int i=0; i<G[v].size(); i++)
{
if(G[v][i].to!=p)
{
dis[G[v][i].to]=dis[v]+G[v][i].w;
dfs(G[v][i].to,v,d+1);
vis[k]=v;
depth[k]=d;
k++;
}
}
}
void init(int n)
{
k=1;
dfs(1,-1,0);
rmq_init(2*n-1);
}
int query(int l,int r)
{
int k=(int)(log10(r-l+1)/log10(2.0));
int num=Min(dp[l][k],dp[r-(1<<k)+1][k]);
return vis[num];
}
int main()
{
int n,m,a,b,c,q,x,y;
char ch[10];
while(cin>>n>>m)
{
for(int i=0; i<N; i++)
{
G[i].clear();
id[i]=0;
dis[i]=0;
}
for(int i=0; i<m; i++)
{
scanf("%d%d%d%s",&a,&b,&c,ch);
node nd;
nd.to=b,nd.w=c;
G[a].push_back(nd);
nd.to=a,nd.w=c;
G[b].push_back(nd);
}
init(n);
cin>>q;
while(q--)
{
scanf("%d%d",&x,&y);
int ans=query(min(id[x],id[y]),max(id[x],id[y]));
printf("%d\n",dis[x]+dis[y]-2*dis[ans]);
}
}
return 0;
}