前置知识:树的存储与遍历,树的重心。
点分治
淀粉质点分治,是一种在树上静态统计满足某种条件的路径的算法。
显然,设哪个点为根与答案没有关系,所以我们可以任意令一个点为根,设这个根为 。
树上的路径可以分为两类:
- 经过 的路径。
- 不经过 ,在 的某棵子树中。
根据分治思想,对于第二类路径,我们可以在子树内递归进行处理。
对于第一类路径,我们可以将 拆成 与 ,然后就可以使用一遍 DFS 来求出每一条类似的路径。
因为对某一个根 的每一棵子树的 DFS 时间和为 ,所以设遍历层数为 则时间复杂度为 。如果我们每一次设每一棵子树的重心为根,时间复杂度为 。
模板题代码:
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <queue>
struct edge{
int v, w, nxt;
};
struct graph{
int cnt, head[10001];
edge e[20001];
void init(){
cnt = 0;
memset(head, -1, sizeof(head));
}
inline void addedge(int u, int v, int w){
e[++ cnt] = {v, w, head[u]};
head[u] = cnt;
}
inline int h(int u){return head[u];}
inline edge operator[](int x){return e[x];}
}g;
using namespace std;
int n, m, u, v, w, q[101];
int siz[10001], d[10001], dd[10001], cnt, maxs, p;
bool vis[10001], t[10000001], ans[101];
queue<int> del;
void dfs1(int s, int u, int fa){
int maxss = 0;siz[u] = 1;
for (int i = g.h(u);i != -1;i = g[i].nxt){
int v = g[i].v;
if (v == fa || vis[v]) continue;
dfs1(s, v, u);
siz[u] += siz[v];
maxss = max(maxss, siz[v]);
}
maxss = max(maxss, s - siz[u]);
if (maxss < maxs) p = u, maxs = maxss;
} // 找重心并计算 siz
void dfs2(int u, int fa){
dd[++ cnt] = d[u];
for (int i = g.h(u);i != -1;i = g[i].nxt){
int v = g[i].v, w = g[i].w;
if (v == fa || vis[v]) continue;
d[v] = d[u] + w, dfs2(v, u);
}
} // 计算到根的距离
void solve(int u, int fa){
t[0] = 1, del.push(0), vis[u] = 1;
for (int i = g.h(u);i != -1;i = g[i].nxt){
int v = g[i].v, w = g[i].w;
if (v == fa || vis[v]) continue;
cnt = 0, d[v] = w, dfs2(v, u); // 对于每一棵子树计算距离
for (int j = 1;j <= m;j ++){
for (int k = 1;k <= cnt;k ++){
if (q[j] >= dd[k]) ans[j] |= t[q[j] - dd[k]];
}
} // 用桶统计并计算答案
for (int k = 1;k <= cnt;k ++){
if (dd[k] <= 10000000) del.push(dd[k]), t[dd[k]] = 1;
} // 记录要赋 0 的值
}
while (del.size()) t[del.front()] = 0, del.pop(); // 注意不能直接 memset,否则时间复杂度不对
for (int i = g.h(u);i != -1;i = g[i].nxt){
int v = g[i].v, w = g[i].w;
if (v == fa || vis[v]) continue;
maxs = 1919810000;
dfs1(siz[v], v, u);
dfs1(siz[v], p, 0);
solve(p, u); // 选择重心作为根
}
} // 点分治
int main(){
g.init();
scanf("%d%d", &n, &m);
for (int i = 1;i < n;i ++) scanf("%d%d%d", &u, &v, &w), g.addedge(u, v, w), g.addedge(v, u, w);
for (int i = 1;i <= m;i ++) scanf("%d", &q[i]);
maxs = 1919810000;
dfs1(n, 1, 0);
dfs1(n, p, 0);
solve(p, 0);
for (int i = 1;i <= m;i ++) printf(ans[i] ? "AYE\n" : "NAY\n");
}