I love max and multiply
Code
代码抄的std
#include<bits/stdc++.h>
using namespace std;
using ll=long long;
template <class T=int> T rd()
{
T res=0;T fg=1;
char ch=getchar();
while(!isdigit(ch)) {if(ch=='-') fg=-1;ch=getchar();}
while( isdigit(ch)) res=(res<<1)+(res<<3)+(ch^48),ch=getchar();
return res*fg;
}
const int N=(1<<20)+10,mod=998244353,INF=2e9;
int A[N],B[N];
ll C[N];
int mxa[N],mna[N];
int mxb[N],mnb[N];
int n,m;
void solve(int k)
{
if(mna[k]!= INF&&mnb[k]!= INF) C[k]=max(C[k],1ll*mna[k]*mnb[k]);
if(mna[k]!= INF&&mxb[k]!=-INF) C[k]=max(C[k],1ll*mna[k]*mxb[k]);
if(mxa[k]!=-INF&&mnb[k]!= INF) C[k]=max(C[k],1ll*mxa[k]*mnb[k]);
if(mxa[k]!=-INF&&mxb[k]!=-INF) C[k]=max(C[k],1ll*mxa[k]*mxb[k]);
}
int main()
{
int Tc=rd();
while(Tc--)
{
n=rd();
for(int i=0;i<n;i++) A[i]=rd();
for(int i=0;i<n;i++) B[i]=rd();
m=1;
while(m<n) m<<=1;
for(int i=0;i<n;i++) mxa[i]=mna[i]=A[i],mxb[i]=mnb[i]=B[i];
for(int i=n;i<m;i++) mxa[i]=mxb[i]=-INF,mna[i]=mnb[i]=INF;
for(int j=1;j<m;j<<=1)
for(int i=m-1;i>=0;i--)
if(!(i&j)) // i的第j位是1
{
mxa[i]=max(mxa[i],mxa[i^j]);
mxb[i]=max(mxb[i],mxb[i^j]);
mna[i]=min(mna[i],mna[i^j]);
mnb[i]=min(mnb[i],mnb[i^j]);
}
C[n]=-1e18;
for(int i=n-1;i>=0;i--)
{
C[i]=-1e18;
solve(i);
C[i]=max(C[i],C[i+1]);
}
ll ans=0;
for(int i=0;i<n;i++) ans=(ans+C[i]%mod)%mod;
ans=(ans+mod)%mod;
printf("%lld\n",ans);
}
return 0;
}