后缀数组+单调栈
题解1
题解2
题解3
#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
typedef long long ll;
const int N=200010;
char s[N],s1[N],s2[N];
int rk[N],sa[N],cnt[N],height[N];
int x[N],y[N];
int n,m,K;
int n1,n2;
void rsort()
{
for(int i=1;i<=m;i++) cnt[i]=0;
for(int i=1;i<=n;i++) cnt[x[i]]++;
for(int i=1;i<=m;i++) cnt[i]+=cnt[i-1];
for(int i=n;i;i--) sa[cnt[x[y[i]]]--]=y[i];
}
void SA()
{
n=strlen(s+1);
m=300;
for(int i=1;i<=n;i++) x[i]=s[i],y[i]=i;
rsort();
for(int k=1;k<=n;k<<=1)
{
int p=0;
for(int i=n-k+1;i<=n;i++) y[++p]=i;
for(int i=1;i<=n;i++) if(sa[i]>k) y[++p]=sa[i]-k;
rsort();swap(x,y);
x[sa[1]]=1,p=1;
for(int i=2;i<=n;i++)
x[sa[i]]=(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+k]==y[sa[i-1]+k]?p:++p);
if(p==n) break;
m=p;
}
for(int i=1;i<=n;i++) rk[sa[i]]=i;
for(int i=1,j=0;i<=n;i++)
{
if(j) --j;
while(s[i+j]==s[sa[rk[i]-1]+j]) j++;
height[rk[i]]=j;
}
}
int init()
{
n1=strlen(s1+1);
n2=strlen(s2+1);
for(int i=1;i<=n1;i++) s[i]=s1[i];
s[n1+1]='*';
for(int i=1;i<=n2;i++) s[i+n1+1]=s2[i];
s[n1+n2+1+1]='\0';
return (n1+n2+1);
}
int st[N][2];
ll solve()
{
ll ans=0,tot=0;
int tt=0;
for(int i=1;i<=n;i++)
{
if(height[i]<K){tt=0,tot=0;continue;}
int cnt=0;
if(sa[i-1]<=n1)
{
cnt++;
tot+=height[i]-K+1;
}
while(tt&&height[i]<=st[tt][0])
{
tot-=1ll*(st[tt][0]-height[i])*st[tt][1];
cnt+=st[tt][1];
tt--;
}
st[++tt][0]=height[i];
st[tt][1]=cnt;
if(sa[i]>n1+1) ans+=tot;
}
tt=0;
for(int i=1;i<=n;i++)
{
if(height[i]<K){tt=0,tot=0;continue;}
int cnt=0;
if(sa[i-1]>n1+1)
{
cnt++;
tot+=height[i]-K+1;
}
while(tt&&height[i]<=st[tt][0])
{
tot-=1ll*(st[tt][0]-height[i])*st[tt][1];
cnt+=st[tt][1];
tt--;
}
st[++tt][0]=height[i];
st[tt][1]=cnt;
if(sa[i]<=n1+1) ans+=tot;
}
return ans;
}
int main()
{
while(scanf("%d",&K),K)
{
scanf("%s%s",s1+1,s2+1);
n=init();
SA();
printf("%lld\n",solve());
}
return 0;
}