挺好一个题。T4 漏了一万次情况导致没时间做这个了,赛时只写了带 log\log 做法。

考虑把最小异或和之和拆开,设 F(k)F(k) 为选出一个子集使得最小异或和 >k> k 的方案数,那么答案即为 i=0F(k)\sum_{i = 0} F(k)。可以发现 kk 的上界是 O(mn)\mathcal{O}(\frac mn) 级别的,其中 mm 为值域。这是因为,如果我们将 kk 收缩到最大的形如 2p12^p - 1 的数,那么条件等价于 ai2p\left \lfloor \frac{a_i}{2^p} \right \rfloor 互不相同,于是有 2pmn2^p \le \frac mn,可以推出 k<2p+12mnk < 2^{p + 1} \le \frac{2m}n

考虑如何计算 F(k)F(k)。根据最小异或和只可能在数值上相邻两个数之间产生的经典结论,我们考虑按值域从大到小来 DP。设 fi,jf_{i, j} 表示目前已经确定了前 ii 大的数,第 ii 大的数为 jj 时的方案数,并预处理出 cntxcnt_xi=1n[xli]\sum_{i = 1}^n [x \le l_i],有转移方程:

fi,j(cntxi+1)fi1,x (xj>k,x>j)f_{i, j} \gets (cnt_x - i + 1) f_{i - 1, x} \ (x \oplus j> k, x > j)

对于单个 kk 做一遍是 O(nm2)\mathcal{O}(nm^2),总时间复杂度为 O(m3)\mathcal{O}(m^3)。考虑优化。

可以发现满足 xj>kx \oplus j > kxx 在值域上会形成 O(logm)\mathcal{O}(\log m) 个区间,具体来说我们枚举 xjx \oplus j 是在哪一位严格大于 kk 的即可。那么我们可以将单个 kk 的时间复杂度优化到 O(nmlogm)\mathcal{O}(nm \log m),总时间复杂度为 O(m2logm)\mathcal{O}(m^2 \log m)

这是求解单个 F(k)F(k) 的参考代码:

int solve(int x){int ans = 0;
	for (int i = 0;i < N;i ++) f[i] = a[i];
	for (int i = 1;i < n;i ++){
		for (int j = 0;j < N;j ++) s[j] = f[j], f[j] = 0;
		for (int j = 1;j < N;j ++) add(s[j], s[j - 1]);
		for (int j = 0;j < N;j ++) if (a[j] > i){int tp = 0;
			for (int k = D - 1;k >= 0;k --){
				if (j >> k & 1 ^ 1) tp ^= 1 << k;
				if (x >> k & 1 ^ 1)
				add(f[j], S(max(j + 1, tp), tp + (1 << k) - 1)), tp ^= 1 << k;
			}mul(f[j], a[j] - i);
		}
	}
	for (int i = 0;i < N;i ++) add(ans, f[i]);
	return ans;
}

实际上这个东西跑的很快,最慢点 2.07s,卡一下肯定能过。

但是我们可以做到更优。注意到我们在令 jj+1j \gets j + 1 的时候实际上是对于 jj 的二进制表示的一段后缀进行了反转,而这 O(logm)\mathcal{O}(\log m) 个区间中也只有被反转的位对应的区间会受影响。显然被反转的后缀的总长度是 O(m)\mathcal{O}(m) 级别的,我们在每次 jj+1j \gets j + 1 的时候重新处理受影响的区间即可,那么总时间复杂度为 O(m2)\mathcal{O}(m^2)

这是求解单个 F(k)F(k) 的参考代码:

int solve(int x){int ans = 0;
	for (int i = 0;i < N;i ++) f[i] = a[i];
	for (int i = 1;i < n;i ++){
		for (int j = 0;j < N;j ++) s[j] = f[j], f[j] = 0;
		for (int j = 1;j < N;j ++) add(s[j], s[j - 1]);
		int tp = 0, now = 0;
		for (int k = D - 1;k >= 0;k --){tp ^= 1 << k;
			if (x >> k & 1 ^ 1)
			add(now, S(max(1, tp), tp + (1 << k) - 1)), tp ^= 1 << k;
		}
		for (int j = 0, p;a[j] > i;){
			mul(f[j] = now, a[j] - i), p = __builtin_ctz(j + 1);
			for (int k = 0;k <= p;k ++){
				if (x >> k & 1 ^ 1)
				tp ^= 1 << k, add(now, mod - S(max(j + 1, tp), tp + (1 << k) - 1));
				if (j >> k & 1 ^ 1) tp ^= 1 << k;
			}j ++;
			for (int k = p;k >= 0;k --){
				if (j >> k & 1 ^ 1) tp ^= 1 << k;
				if (x >> k & 1 ^ 1)
				add(now, S(max(j + 1, tp), tp + (1 << k) - 1)), tp ^= 1 << k;
			}
		}
	}
	for (int i = 0;i < N;i ++) add(ans, f[i]);
	return ans;
}