「Luogu 4927」梦美与线段树

Description

题目链接:Luogu 4927

有一棵维护区间和的线段树,每个节点的权值是该节点所对应区间的元素 $a_i$ 的权值和。梦美会从这棵线段树的根节点开始游历,当她要进入子节点时,假设左右儿子的权值为 $sum_l$ 和 $sum_r$,当前节点的权值为 $sum_{cnr}$,那么梦美会以 $\frac{sum_l}{sum_{cnr}}$ 的概率进入左子树,否则进入右子树。

梦美有时会把下标在 $[l,r]$ 的序列的元素权值加上 $v$。梦美每次游历时,梦美会把经过的节点权值累加,现在她希望求出这个权值的期望。答案化成最简分数为 $\frac{p}{q}$,输出 $p\cdot q^{-1}\bmod 998244353$。

数据范围:$1\le n,m\le 10^5$,$1\le a_i,v\le 10^9$


Solution

计算答案

我们考虑期望原来的定义:$\frac{\text{所有情况的结果总和}}{\text{总情况数}}$。由于没给节点的权值都是左右儿子的权值和,经过的概率也和权值有关,我们记 $val_i$ 为节点 $i$ 的权值,那么可以看做是从根节点 $1$ 出发进行 $val_1$ 次游历,那么走到每一个节点的次数就是这个节点的权值。

那么所有情况的结果总和就是 ${val_i}^2$,总情况数就是 $val_1$。因此,如果没有修改操作,我们只要维护每个点的 ${val_i}^2$ 即可。

修改操作

如果某一个节点的代表的长度为 $len$,原来的权值为 $val$,区间加 $v$,那么这个节点的权值变为 $val+len\times v$,权值的平方的增量为 $(val+len\times v)^2-val^2=2\times val\times len\times v-len^2\times v^2$,那么我们就需要维护 $val$ 和 $len^2$ 和 $val\times len$。

如何维护

其中 $val$ 和 $len^2$ 很好维护,只是 $val\times len$ 比较麻烦。

对于一个节点而言,区间加之后它的权值由 $val$ 变成了 $val+len\times v$,那么 $val\times len$ 变成了 $(val+len\times v)\times len$,增量为 $len^2\times v$,而其中的 $len^2$ 我们可以维护了,因此 $val\times len$ 也可以轻松维护了!

代码里 $val[i]$ 表示节点 $i$ 的权值,$ans[i]$ 表示子树 $i$ 的答案,$len[i]$ 表示子树内所有节点的 $len^2$ 的和,$sum[i]$ 表示子树内所有节点的 $val\times len$ 的和。

注意:这题卡 $\text{long long}$ 因此要用 $\text{unsigned long long}$ 或者 $\text{__int128}$ 或者手写高精度!

时间复杂度:$O(n\log n)$


Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include <cstdio>
#define lson rt<<1
#define rson rt<<1|1

const int N=1e5+5;
const int mod=998244353;
int n,m;
__int128 val[N<<2],ans[N<<2],len[N<<2],sum[N<<2],tag[N<<2];

__int128 gcd(__int128 x,__int128 y) {return y?gcd(y,x%y):x;}
int pow(int x,int p) {
int res=1;
for(;p;p>>=1,x=1LL*x*x%mod) if(p&1) res=1LL*x*res%mod;
return res;
}
void print(__int128 x) {
if(x>9) print(x/10);
putchar(x%10+'0');
}
void update(int rt,int l,int r,__int128 x) {
tag[rt]+=x;
ans[rt]+=2*x*sum[rt]+len[rt]*x*x;
val[rt]+=x*(r-l+1);
sum[rt]+=x*len[rt];
}
void pushup(int rt,int l,int r) {
val[rt]=val[lson]+val[rson];
sum[rt]=sum[lson]+sum[rson]+val[rt]*(r-l+1);
ans[rt]=val[rt]*val[rt]+ans[lson]+ans[rson];
}
void pushdown(int rt,int l,int r) {
if(!tag[rt]) return;
int mid=(l+r)>>1;
update(lson,l,mid,tag[rt]);
update(rson,mid+1,r,tag[rt]);
tag[rt]=0;
}
void build(int rt,int l,int r) {
if(l==r) {
int x; scanf("%d",&x);
sum[rt]=val[rt]=x,ans[rt]=1LL*x*x,len[rt]=1;
return;
}
int mid=(l+r)>>1;
build(lson,l,mid);
build(rson,mid+1,r);
pushup(rt,l,r);
len[rt]=len[lson]+len[rson]+1LL*(r-l+1)*(r-l+1);
}
void modify(int x,int y,int rt,int l,int r,__int128 k) {
if(x<=l&&r<=y) {
update(rt,l,r,k);
return;
}
pushdown(rt,l,r);
int mid=(l+r)>>1;
if(x<=mid) modify(x,y,lson,l,mid,k);
if(mid<y) modify(x,y,rson,mid+1,r,k);
pushup(rt,l,r);
}
int main() {
scanf("%d%d",&n,&m);
build(1,1,n);
for(int opt,l,r,v;m--;) {
scanf("%d",&opt);
if(opt==1) {
scanf("%d%d%d",&l,&r,&v);
modify(l,r,1,1,n,v);
} else {
__int128 x=ans[1],y=val[1],d=gcd(x,y);
x/=d,y/=d,x%=mod;
print(y==1?x:1LL*x*pow(y%mod,mod-2)%mod),puts("");
}
}
return 0;
}
0%