线段树专题
第一道题是来自codeforces的438D,题目的题意是给你一个初始化的数组,然后给你几种操作,如果操作码为1的时候计算区间【l,r】的区间和,然后操作码为2时,将区间【l,r】之间的数字进行取摸,如果操作码为3的时候将位置为k的数字改成x。
初步一看就能很快想到线段树,用线段树维护我们的树,唯一的难点在我们对区间取摸的操作,如果反复取摸会出现t的情况,那我们怎么优美的暴力呢,如果我们知道每个区间的最大值,我只要判断区间最大值和取摸数之间的关系,就可以知道是否需要去更改区间,那么我们的时间复杂度一下就降下来了,而其他的操作就是正常的线段树单点修改和区间求和的操作,下面是代码。
#include<cstdio>
#include<iostream>
#include<cstring>
#include<vector>
#include<stack>
#include<queue>
#include<algorithm>
#include<cmath>
#include<set>
using namespace std;
#define inf 0x3f3f3f3f
#define ll long long
const int maxn = 1e5+10;
ll gcd(ll a, ll b) { return b ? gcd(b, a % b) : a; }
ll qpow(ll a, ll b, ll m) {
ll ans = 1;
ll k = a;
while (b) {
if (b & 1)ans = ans * k % m;
k = k * k % m;
b >>= 1;
}
return ans;
}
int fa[maxn];
int find(int x)
{
return fa[x] == x ? x : fa[x] = find(fa[x]);
}
void baba(int x, int y)
{
int fx = find(x), fy = find(y);
if (fx != fy)
{
fa[fx] = fy;
}
}
ll sum[maxn * 4], val[maxn * 4];
//sum线段树的总和,val线段树的线段最大值
void build(int x, int l, int r)
{
if (l == r)
{
cin >> sum[x];
val[x] = sum[x];
return;
}
int mid = (l + r) / 2;
build(x * 2, l, mid);
build(x * 2 + 1, mid + 1, r);
val[x] = max(val[x * 2], val[x * 2 + 1]);
sum[x] = sum[x * 2] + sum[x * 2 + 1];
}
void update(int x, int l, int r, int weizhi, int exchange)
{
if (l == r)
{
val[x] = sum[x] = exchange;
return;
}
int mid = (l + r) / 2;
if (weizhi <= mid)
{
update(x * 2, l, mid, weizhi, exchange);
}
else
{
update(x * 2 + 1, mid + 1, r, weizhi, exchange);
}
val[x] = max(val[x * 2], val[x * 2 + 1]);
sum[x] = sum[x * 2] + sum[x * 2 + 1];
}
void getmod(int x, int l, int r, int givel, int giver, int mod)
{
if (l > giver || r < givel)return;
if (l >= givel && r <= giver && val[x] < mod)return;
if (l == r)
{
sum[x] %= mod;
val[x] %= mod;
return;
}
int mid = (l + r) / 2;
getmod(x * 2, l, mid, givel, giver, mod);
getmod(x * 2+1, mid+1, r, givel, giver, mod);
val[x] = max(val[x * 2], val[x * 2 + 1]);
sum[x] = sum[x * 2] + sum[x * 2 + 1];
}
ll getsum(int x, int l, int r, int givel, int giver)
{
if (l > giver || r < givel)return 0;
if (l >=givel && r <= giver)return sum[x];
int mid = (l + r) / 2;
ll qw = getsum(x * 2, l, mid, givel, giver);
ll qr= getsum(x * 2+1, mid+1, r, givel, giver);
return qw + qr;
}
int main()
{
int n, m;
cin >> n >> m;
build(1, 1, n);
for (int i = 0; i < m; i++)
{
int lazy;
cin >> lazy;
if (lazy == 1)
{
int l, r;
cin >> l >> r;
ll ans = getsum(1, 1, n, l, r);
cout << ans << endl;
}
else if (lazy == 2)
{
int l, r, x;
cin >> l >> r >> x;
getmod(1, 1, n, l, r, x);
}
else if (lazy == 3)
{
int k,x;
cin >> k >> x;
update(1, 1, n, k, x);
}
}
}
第二道题是codeforces的600E,这道题是一道有关线段树合并的题目,首先我们先了解一下什么是线段树合并,就是将已有的两棵线段树合并成一颗,相同的位置的信息进行整合,就是将一棵线段树的每一个位置取出来插入另一棵比较高效的线段树中合并。
线段树合并的原理如下:
对于两棵树的节点u和v:
1.如果u为空,返回v;
2.如果v为空,返回u;
3.否则,新建一个节点T,整合u和v的信息,然后递归合并u和v的左右子树
然后本道题的题意一个树有n个结点,每个结点都是一个颜色,每个颜色有一个编号,求树中每个子树的最多的颜色编号的和,很绕口,就是将每个结点都先看成一个线段树,枚举当前结点的每个儿子,然后合并线段树,使得线段树上的信息合并。这就是本题的解法。除此之外本题还有dsu on tree的做法,启发式合并的做法,主席树的做法,莫队的做法。
#include<cstdio>
#include<iostream>
#include<cstring>
#include<vector>
#include<stack>
#include<queue>
#include<algorithm>
#include<cmath>
#include<set>
using namespace std;
#define inf 0x3f3f3f3f
#define ll long long
const int maxn = 2e6+10;
ll gcd(ll a, ll b) { return b ? gcd(b, a % b) : a; }
ll qpow(ll a, ll b, ll m) {
ll ans = 1;
ll k = a;
while (b) {
if (b & 1)ans = ans * k % m;
k = k * k % m;
b >>= 1;
}
return ans;
}
struct tree
{
int l, r, maxx;
ll ans;
}p[maxn];
int n;
int a[maxn], date[maxn], head[maxn], Next[maxn];
int root[maxn];
ll ans[maxn];
int tot = 0,cnt=0;
void add(int x, int y)
{
date[++tot] = y;
Next[tot] = head[x];
head[x] = tot;
}
void spread(int a)
{
if (p[p[a].l].maxx > p[p[a].r].maxx)
{
p[a].maxx = p[p[a].l].maxx;
p[a].ans = p[p[a].l].ans;
}
else if (p[p[a].l].maxx < p[p[a].r].maxx)
{
p[a].maxx = p[p[a].r].maxx;
p[a].ans = p[p[a].r].ans;
}
else
{
p[a].maxx = p[p[a].l].maxx;
p[a].ans = p[p[a].l].ans + p[p[a].r].ans;
}
}
int merge(int r1, int r2, int l, int r)
{
if (!r1 || !r2)return r1 + r2;
if (l == r)
{
p[r1].maxx += p[r2].maxx;
return r1;
}
int mid = (l + r)/2;
p[r1].l = merge(p[r1].l, p[r2].l, l, mid);
p[r1].r = merge(p[r1].r, p[r2].r, mid + 1, r);
spread(r1);
return r1;
}
void update(int x, int l, int r, int pos)
{
if (l == r)
{
p[x].maxx++;
p[x].ans = l;
return;
}
int mid = (l + r) / 2;
if (pos <= mid)
{
if (!p[x].l)p[x].l = ++cnt;
update(p[x].l, l, mid, pos);
}
else
{
if (!p[x].r)p[x].r = ++cnt;
update(p[x].r, mid+1, r, pos);
}
spread(x);
}
void dfs(int x, int fa)
{
for (int i = head[x]; i ; i=Next[i])
{
if (date[i] == fa)continue;
dfs(date[i], x);
root[x] = merge(root[x], root[date[i]], 1, n);
}
if (!root[x])root[x] = ++cnt;
update(root[x], 1, n, a[x]);
ans[x] = p[root[x]].ans;
}
int main()
{
cin >> n;
memset(ans, 0, sizeof(ans));
for (int i = 1; i <= n; i++)
{
cin >> a[i];
}
for (int i = 1; i < n; i++)
{
int u, v;
cin >> u >> v;
add(u, v);
add(v, u);
}
dfs(1, 0);
for (int i = 1; i <= n; i++)
{
printf("%lld", ans[i]);
if (i != n)printf(" ");
else printf("\n");
}
}