很好写的 O(mlogn)O(\sum m \log n) 做法,目前 solution size 排第三。

因为 mm 比较小,所以考虑枚举 max\max。容斥一下,设 FxF_x 表示对于每条链,链上的点 maxx\max \le x,其他点任意填的方案数,即 max\max 恰好等于 ii 的方案数即为 ansx=FxFx1ans_x = F_x - F_{x - 1}。答案为 x=1mx(FxFx1)=Fmx=1m1Fx\sum_{x = 1}^m x(F_x - F_{x - 1}) = F_m - \sum_{x = 1}^{m - 1} F_x

maxx\max \le x 是容易求的,即每个点都 x\le x。所以对于一个在路径上的点,它的填法有 xx 种,否则有 mm 种。

考虑树形 DP,设 fkf_k 表示大小为 kk 的树内的路径填法和,gkg_k 表示到大小为 kk 树的内到根的路径填法和,有转移方程:

gk=x(glsmrs+mlsgrs+mk1)fk=gk+xglsgrs+flsmrs+1+mls+1frsg_k = x(g_{ls}m^{rs} + m^{ls}g_{rs} + m^{k - 1})\\ f_k = g_k + xg_{ls}g_{rs} + f_{ls}m^{rs + 1} + m^{ls + 1}f_{rs}

其中 ls,rsls, rs 分别为左右子树大小。状态数看似是 O(n)O(n) 的,但是实际上左右两边至少一边是 2y12^y - 1 的形式,所以对 fnf_n 有用的状态数实际上是 O(logn)O(\log n) 的。

转移时如果用快速幂求出系数,就会多出一个 log\log,但是我们对用到的状态都预处理出 mkm^k 就可以 O(logn)O(\log n) 求出 fnf_n

我们对于每一个 xx,都 O(logn)O(\log n) 求出 FxF_x(即当前 xx 求出的 fnf_n),就可以 O(mlogn)O(m \log n) 求出原本的答案了。

实现时不需要开一个 unordered_map 去记忆化搜索,只需要记录所有 k=2y1k = 2^y - 1 的状态,时间复杂度就是对的。

代码真的很好写:

#include <unordered_map>
#include <algorithm>
#include <cstring>
#include <cstdio>
#define mod 998244353
#define popc __builtin_popcountll
#define log2 __builtin_ctzll
using namespace std;
const int N = 100001;
typedef long long ll;
int t, n, m, ans[N], sum;ll k;
struct node{int vf, vg, vp;}F[64];
inline void add(int& x, int y){x += y;if (x >= mod) x -= mod;}
node dp(ll x){
	if (!x) return {0, 0, 1};
	if (x == 1) return {m, m, n};x ++;
	if (popc(x) == 1 && F[log2(x)].vp) return F[log2(x)];x --;
	ll ri = 1ll << 64 - __builtin_clzll(x);
	ri --;ll lx = ri - 1 >> 1, rx = lx;
	ll ls = ri - x, di = ri + 1 >> 2;
	if (ls <= di) rx -= ls;
	else rx -= di, lx -= ls - di;
	node lf = dp(lx), rf = dp(rx), res;
	res.vp = 1ll * lf.vp * rf.vp % mod;
	res.vg = (1ll * lf.vg * rf.vp + 1ll * rf.vg * lf.vp + res.vp) % mod * m % mod;
	res.vp = 1ll * res.vp * n % mod;
	res.vf = (1ll * lf.vg * rf.vg % mod * m + res.vg +
	(1ll * lf.vf * rf.vp + 1ll * rf.vf * lf.vp) % mod * n) % mod;
	x ++;if (popc(x) == 1) F[log2(x)] = res;x --;return res;
}
int main(){scanf("%d", &t);
	while (t --){scanf("%lld%d", &k, &n), sum = 0;
		for (m = 1;m <= n;m ++)
		memset(F, 0, sizeof(F)), ans[m] = dp(k).vf;
		for (m = n;m >= 1;m --)
			add(ans[m], mod - ans[m - 1]),
			add(sum, 1ll * ans[m] * m % mod);
		printf("%d\n", sum);
	}
}