无脑 O(nlog2n)O(n \log^2 n) 做法,但是跑的比官方题解 O(nlogn)O(n \log n) 快。

先套路的设 "("\texttt{"("}11")"\texttt{")"}1-1,那么一条路径合法,当且仅当从上(深度小的一端)至下(深度大的一端)的所有前缀和都 0\ge 0,同理,从下至上的所有前缀和都 0\le 0

考虑从下至上树形 DP,设 fu,if_{u, i} 表示 uu 子树内,uu 到每个叶子的路径上的权值和都恰好为 ii 的最大连通块的大小。根据上面的结论,i0i \le 0 时这才是一个合法的状态。

可以推出转移方程:

fu,i+au1+vfv,ifu,au1f_{u, i + a_u} \gets 1 + \sum_v f_{v, i}\\ f_{u, a_u} \gets 1

所有状态初始设为 00

暴力转移是 O(uszu)O(\sum_u sz_u) 的,可以卡到 O(n2)O(n^2),我们考虑优化。稍微将转移方程进行变形,得:

fu,ivfv,ifu,00fu,i+au=1+fu,if'_{u, i} \gets \sum_v f_{v, i}\\ f'_{u, 0} \gets 0\\ f_{u, i + a_u} = 1 + f'_{u, i}

可以发现,我们只需要支持单点修改、快速合并每棵子树的集合、整体加、以及整体平移即可快速维护 ff

因为我们要支持合并,所以考虑树上启发式合并。对于每个 fuf_u 都开一个 map 维护非 00 的位置,合并时将小的集合暴力合并到大的集合里,而整体加、整体平移可以打两个 tag 分别维护。最后将平移后下标 >0> 0 的所有位置都删掉即可。

时间复杂度 O(nlog2n)O(n \log^2 n),跑不满,可以通过。

注意合并与修改时对 tag 的处理。

#include <algorithm>
#include <cstring>
#include <cstdio>
#include <vector>
#include <map>
using namespace std;
class edge{
	public:
		edge(int x, int y, int z, int zz){v = x, w = y, c = z, nxt = zz;}
		int v, w, c, nxt;
};
class graph{
	private:
		vector<int> head;vector<edge> e;
	public:
		void init(int n){head.clear(), e.clear();for (int i = 0;i <= n;i ++) head.emplace_back(-1);}
		graph(){}graph(int n){init(n);}
		inline int addedge(int u, int v, int w = 0, int c = 0){e.emplace_back(v, w, c, head[u]), head[u] = e.size() - 1;return e.size() - 1;}
		inline int add2edge(int u, int v, int w = 0){int tmp = addedge(u, v, w);addedge(v, u, w);return tmp;}
		inline int addfedge(int u, int v, int w, int c){int tmp = addedge(u, v, 0, c);addedge(v, u, w, -c);return tmp;}
		inline int& h(int u){return head[u];}
		inline edge& operator[](int i){return e[i];}
}G;
const int N = 500005;map<int, int> f[N];
int n, a[N], mv1[N], mv2[N], ans;char ch[N];
void dp(int u, int fa){
	for (int i = G.h(u);i != -1;i = G[i].nxt){
		int v = G[i].v;if (v == fa) continue;
		dp(v, u);if (f[u].size() < f[v].size())
		swap(f[u], f[v]), swap(mv1[u], mv1[v]), swap(mv2[u], mv2[v]);
		for (auto [x, y] : f[v]){
			int p = x - mv1[v] + mv1[u], pp = y + mv2[v];
			if (f[u].find(p) != f[u].end()) f[u][p] += pp;
			else f[u][p] = pp - mv2[u];
		}f[v].clear();
	}
	if (f[u].find(mv1[u]) == f[u].end())
	f[u][mv1[u]] = -mv2[u];mv1[u] -= a[u], mv2[u] ++;
	if (f[u].find(mv1[u]) != f[u].end())
	ans = max(ans, f[u][mv1[u]] + mv2[u]);
	else f[u][mv1[u]] = -mv2[u];
	while (f[u].size()){auto ptr = f[u].end();ptr --;
		if (ptr -> first - mv1[u] > 0) f[u].erase(ptr);
		else break;
	}
}
int main(){scanf("%d%s", &n, ch + 1), G.init(n);
	for (int i = 1;i <= n;i ++) a[i] = ch[i] == '(' ? 1 : -1;
	for (int i = 1, u, v;i < n;i ++) scanf("%d%d", &u, &v), G.add2edge(u, v);
	dp(1, 0), printf("%d", ans);
}