思路
首先,题目说:
从编号为 的叶子节点到树根每一个点的点权的 值 上 为 。
然后,根据按位 的性质,,当且仅当:
-
当 的第 个二进制位为 时, 的第 个二进制位也为 。
-
当 的第 个二进制位为 时, 的第 个二进制位任意。
-
就是 的取值数量。
根据乘法原理, 就是整棵树的取值数量。
我们现在就要求所有节点 位位数的和。
可以先求出每一层节点 位位数的和,再加起来,就是答案了。
我们先手推一个 的情况,每一个 取最多的 。

首先,我们知道:叶节点 取值为 ;
非叶节点 为了满足它的两个子节点 和 的要求,取值应为 。
对于第 层, 与 的最后一位一个为 ,一个为 ,或起来为 ;所以,所有第 层的节点的取值的末尾都是 ,删掉对 的个数没有影响。
删掉之后,第 层的 ,,, 就变成了 ,,,。我们发现,这不是 层二叉树的 个叶子节点的取值吗?
于是,我们就可以知道: 层二叉树所有节点 位位数的和 $ = $ $\sum_{i = 1}^n $ 层二叉树第 层节点的和。
再拿 层的 ,,, 举例子,可以发现:第 层的 ,,,,,,, 是由 ,,, 分别在开头加了一个 和一个 得来的。加了一个 , 的个数就加了 ;加了一个 , 的个数没变。
设 为 层二叉树第 层所有节点 位位数的和, 为 层二叉树前 层所有节点 位位数的和,可以得到递推式:
因为 ,所以不难想到用矩阵快速幂优化。
目标矩阵:, 即为答案。
推导一个矩阵 ,使得:
根据转移方程,可得 。
设计初始矩阵为 ,也就是 。
目标矩阵即为 。
根据欧拉定理的推论:
为质数
然后,矩阵乘法运算的时候,将模数设为 即可。
最后求的答案为 。
代码
#include <algorithm>
#include <cstdio>
#define ll long long
using namespace std;
int mod = 998244353;
int _mod = mod - 1;
struct matrix{
int n;
ll val[101][101];
void out(){
for (int i = 1;i <= n;i ++){
for (int j = 1;j <= n;j ++) printf("%lld ", val[i][j]);
printf("\n");
}
}
matrix operator*(matrix b){
matrix res;
res.n = n;
if (n != b.n) exit(1);
for (int i = 1;i <= n;i ++){
for (int j = 1;j <= n;j ++){
res.val[i][j] = 0;
for (int k = 1;k <= n;k ++) res.val[i][j] = (res.val[i][j] + (val[i][k] * b.val[k][j]) % _mod) % _mod;
}
}
return res;
}
}M, B, S;
matrix one(int x){
matrix res;
res.n = x;
for (int i = 1;i <= x;i ++){
for (int j = 1;j <= x;j ++) res.val[i][j] = 0;
}
for (int i = 1;i <= x;i ++) res.val[i][i] = 1;
return res;
}
matrix mpow(matrix m, ll p){
matrix ans = one(m.n), base = m;
while (p){
if (p & 1) ans = ans * base;
base = base * base;
p >>= 1;
}
return ans;
}
ll qpow(ll x, ll p){
ll ans = 1, base = x;
while (p){
if (p & 1) ans = ans * base % mod;
base = base * base % mod;
p >>= 1;
}
return ans % mod;
}
const ll b[4][4] = {{0, 0, 0, 0}, {0, 0, 1, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}};
const ll m[4][4] = {{0, 0, 0, 0}, {0, 2, 0, 1}, {0, 1, 2, 0}, {0, 0, 0, 1}};
ll n, po;
int main(){
B.n = M.n = 3;
for (int i = 1;i <= 3;i ++){
for (int j = 1;j <= 3;j ++){
B.val[i][j] = b[i][j];
M.val[i][j] = m[i][j];
}
}
scanf("%lld", &n);
S = B * mpow(M, n - 1);
printf("%lld", qpow(2, (S.val[1][1] + S.val[1][3]) % _mod));
}