题目 :洛谷 P2590 [ZJOI2008]树的统计
题目链接:https://www.luogu.com.cn/problem/P2590
我的解题算法:树链剖分
错误类型:Wrong answer
我的提交记录:https://www.luogu.com.cn/record/78801196
主要是不知道哪里错了
我的代码:
```c++
#include <bits/stdc++.h>
using namespace std;
namespace Main
{
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(!isdigit(ch))
{
if(ch=='-')f=-1;
ch=getchar();
}
while(isdigit(ch))
{
x=(x<<1)+(x<<3)+(ch^48);
ch=getchar();
}
return x*f;
}
void write(int x)
{
if(x<0)
{
putchar('\n');
x=-x;
}
if(x>=10)
{
write(x/10);
}
putchar(x%10^48);
}
const int maxn=3e4+5;
const int inf=0x3f3f3f3f;
int q,n;
int head[maxn];
struct EDGE
{
int to,nxt;
}edge[maxn<<1];
int ecnt;
inline void add(int u,int to)
{
edge[++ecnt].to=to;
edge[ecnt].nxt=head[u];
head[u]=ecnt;
}
int w[maxn];
//----------dfs--------------
int dep[maxn],fa[maxn],hs[maxn],nodesize[maxn],top[maxn],dfn[maxn],noderank[maxn],nodecnt;
void dfs(int u,int _fa)
{
dep[u]=dep[_fa]+1;
nodesize[u]=1;
fa[u]=_fa;
for(int i=head[u];i;i=edge[i].nxt)
{
int to=edge[i].to;
if(to==_fa)continue;
dfs(to,u);
nodesize[u]+=nodesize[to];
if(nodesize[to]>nodesize[hs[u]])hs[u]=to;
}
}
void dfs2(int u,int _top)
{
top[u]=_top;
dfn[u]=++nodecnt;
noderank[nodecnt]=u;
// printf("noderank[%d] = %d\n",nodecnt,u);
if(hs[u])dfs2(hs[u],_top);
for(int i=head[u];i;i=edge[i].nxt)
{
int to=edge[i].to;
if(to==fa[u]||to==hs[u])continue;
dfs2(to,to);
}
}
//------------Tree-----------------
int L[maxn<<3],R[maxn<<3],val[maxn<<3],imax[maxn<<3];
inline void pushup(int i)
{
val[i]=val[i<<1]+val[i<<1|1];
imax[i]=max(imax[i<<1],imax[i<<1|1]);
}
void Build(int pos,int l,int r)
{
L[pos]=l,R[pos]=r;
if(l==r)
{
imax[pos]=val[pos]=w[noderank[l]];
return;
}
int mid=l+r>>1;
Build(pos<<1,l,mid);
Build(pos<<1|1,mid+1,r);
pushup(pos);
}
void change(int pos,int l,int r,int mbpos,int _val)
{
if(l==r&&r==mbpos)
{
val[pos]=_val;
imax[pos]=_val;
return;
}
if(l>mbpos||r<mbpos)return;
int mid=l+r>>1;
change(pos<<1,l,mid,mbpos,_val);
change(pos<<1|1,mid+1,r,mbpos,_val);
pushup(pos);
}
int query(int pos,int l,int r,int _L,int _R,int id)
{
//id=1: max id=2: sum
if(l>=_L&&r<=_R)
{
if(id==1)
{
return imax[pos];
}
if(id==2)
{
return val[pos];
}
}
if(r<_L||l>_R)
{
if(id==1)return -inf;
if(id==2)return 0;
}
int mid=l+r>>1;
if(id==1)
{
int res=-inf;
if(mid>=_L)res=query(pos<<1,l,mid,_L,_R,id);
if(mid<_R)res=max(res,query(pos<<1|1,mid+1,r,_L,_R,id));
return res;
}
if(id==2)
{
int res=0;
if(mid>=_L)res+=query(pos<<1,l,mid,_L,_R,id);
if(mid<_R)res+=query(pos<<1|1,mid+1,r,_L,_R,id);
return res;
}
}
int Query(int x,int y,int id)
{
if(id==1)
{
int res=-inf;
//max
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
res=max(res,query(1,1,n,dfn[top[x]],dfn[x],1));
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
res=max(res,query(1,1,n,dfn[x],dfn[y],1));
return res;
}
if(id==2)
{
int res=0;
//sum
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
res+=query(1,1,n,dfn[top[x]],dfn[x],2);
x=fa[top[x]]; }
if(dep[x]>dep[y])swap(x,y);
res+=query(1,1,n,dfn[x],dfn[y],2);
return res;
}
}
void main()
{
scanf("%d",&n);
int a,b;
for(int i=1;i<n;i++)
{
scanf("%d%d",&a,&b);
add(a,b);
add(b,a);
}
for(int i=1;i<=n;i++)
{
scanf("%d",&w[i]);
}
//------------------
dfs(1,0);
dfs2(1,1);
Build(1,1,n);
//----------------
scanf("%d",&q);
char op[6];
int u,t;
while(q--)
{
scanf("%s%d%d",op,&u,&t);
if(op[0]=='C')
{//CHANGE
change(1,1,n,u,t);
}
if(op[1]=='M')
{//QMAX
printf("%d\n",Query(u,t,1));
}
if(op[1]=='S')
{//QSUM
printf("%d\n",Query(u,t,2));
}
}
}
}
int main()
{
Main::main();
return 0;
}