传送门:铁人两项
简述一下题目:
给出一个(不一定联通)的图,求有多少个三元组(s,c,f)满足s,c,f都是图中的点,且存在一条从s到c的路径和一条从c到f的路径,使得两条路径没有公共点(除c以外)。
这个题当时刚接触到圆方树,我的想法跟正解十分接近使我非常兴奋。
这个题我们想一下如果n2的话我们要怎么做:
枚举两个圆点s,f。路径上所有的点双中的点都可以作为c。如何方便地统计呢?首先我们建出圆方树,把圆点权值设为-1(因为正常计算有重复路径,这样直接免去容斥减的过程),方点权值设为点双的大小,则s到f的路径上的点(包括s,f)的权值和,也就是c的个数。这是n方的。
如果枚举中间点,则很容易求出树上有多少个圆圆路径经过这个点,通过这个点的子树大小直接O(1)进行计算即可,这样我们只用把所有的点枚举一遍即可,这样是O(n)的。
代码先咕咕咕
代码成功的没有咕咕咕:
#define B cout << "BreakPoint" << endl;
#define O(x) cout << #x << " " << x << endl;
#define O_(x) cout << #x << " " << x << " ";
#define Msz(x) cout << "Sizeof " << #x << " " << sizeof(x)/1024/1024 << " MB" << endl;
#include<cstdio>
#include<cmath>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<queue>
#include<stack>
#define LL long long
#define inf 1000000009
#define N 1000005
using namespace std;
inline int read() {
int s = 0,w = 1;
char ch = getchar();
while(ch < '0' || ch > '9') {
if(ch == '-')
w = -1;
ch = getchar();
}
while(ch >= '0' && ch <= '9') {
s = s * 10 + ch - '0';
ch = getchar();
}
return s * w;
}
LL ans;
int vis[N];
int n,m,top,res,tot,s;
int dfn[N],low[N],stk[N],val[N],sz[N];
struct Graph {
int head[N],nxt[N << 1],to[N << 1];
int ecnt;
inline void add(int u,int v) {
to[++ecnt] = v;
nxt[ecnt] = head[u];
head[u] = ecnt;
return;
}
inline void init(int u,int v) {
add(u,v);
add(v,u);
return;
}
} eold,enew;
inline void cmin(int &x,int y) {
if(x > y) x = y;
return;
}
void tarjan(int u) {
dfn[u] = low[u] = ++tot;
stk[++top] = u;
sz[u] = 1;
for(int i = eold.head[u]; i; i = eold.nxt[i]) {
int v = eold.to[i];
if(!dfn[v]) {
tarjan(v);
cmin(low[u],low[v]);
if(low[v] >= dfn[u]) {
int t = 0,cnt = 1;
res++;
while(t != v) {
t = stk[top--];
cnt++;
enew.init(res,t);
sz[res] += sz[t];
}
val[res] = cnt;
sz[u] += sz[res];
enew.init(res,u);
}
} else {
cmin(low[u],dfn[v]);
}
}
return ;
}
void dfs(int u,int fa) {
int x = u <= n;
ans += 2ll * sz[u] * (s - sz[u]) * val[u];
for(int i = enew.head[u]; i; i = enew.nxt[i]) {
int v = enew.to[i];
if(v == fa) {
continue;
}
ans += 2ll * x * sz[v] * val[u];
x += sz[v];
dfs(v,u);
}
return ;
}
void pre() {
n = read(),m = read();
res = n;
memset(val,-1,sizeof(val));
for(int i = 1; i <= m; i++) {
int u = read(),v = read();
eold.init(u,v);
}
return ;
}
void solve() {
for(int i = 1; i <= n; i++) {
if(!dfn[i]) {
tarjan(i);
s = sz[i];
dfs(i,-1);
}
}
printf("%lld",ans);
return ;
}
int main() {
pre();
solve();
return 0;
}