拉格朗日插值

众所周知,n+1n + 1xx 坐标不同的点可以确定唯一的一个 nn 次多项式。

如何确定呢?我们考虑构造 n+1n + 1 个多项式,第 ii 个多项式在 xix_i 点上取到 yiy_i 的值,而在其他的 xj(ij)x_j(i \ne j) 点上取到 00,这 n+1n + 1 个多项式的和就是最终确定的多项式。

现在的问题就变成如何构造。首先满足在别的点上取到 00。可以构造一个这样的多项式:1jn+1,ij(xxj)\prod_{1 \le j \le n + 1, i \ne j} (x - x_j),容易发现,这个多项式在别的点上都能取到 00

然后考虑在 xix_i 点上取到 yiy_i。可以发现,我们只需要把它缩放一下就行了。具体来说,假设刚刚的多项式在 xix_i 处取到 yy',然后我们给刚刚的多项式乘上 yiy\frac{y_i}{y'} 即可。所以第 ii 个的多项式就是:

yi1jn+1,ij(xxj)1jn+1,ij(xixj)y_i\dfrac{\prod_{1 \le j \le n + 1, i \ne j}(x - x_j)}{\prod_{1 \le j \le n + 1, i \ne j}(x_i - x_j)}

于是最终确定的多项式就为:

i=1n+1yi1jn+1,ij(xxj)1jn+1,ij(xixj)\sum_{i = 1}^{n + 1} y_i\dfrac{\prod_{1 \le j \le n + 1, i \ne j}(x - x_j)}{\prod_{1 \le j \le n + 1, i \ne j}(x_i - x_j)}

在具体应用中,一般在某些要求的东西是一个次数较小的多项式,而需要求一个任意点的值时,可以先把它在较小点上的一些取值求出来,拉格朗日插值得到多项式,再把要求的地方的值代入即可。

模板题代码:

#include <algorithm>
#include <cstring>
#include <cstdio>
#include <vector>
#define mod 998244353
using namespace std;

const int N = 2001;
int n, X, x[N], y[N], ans;
int qpow(int base, int p){int ans = 1;
	for (;p;p >>= 1, base = 1ll * base * base % mod){
		if (p & 1) ans = 1ll * ans * base % mod;
	}return ans;
}
int main(){
	scanf("%d%d", &n, &X), n --;
	for (int i = 0;i <= n;i ++) scanf("%d%d", &x[i], &y[i]);
	for (int i = 0;i <= n;i ++){int tp1 = y[i], tp2 = 1;
		for (int j = 0;j <= n;j ++){
			if (i == j) continue;
			tp1 = 1ll * tp1 * (X - x[j] + mod) % mod;
			tp2 = 1ll * tp2 * (x[i] - x[j] + mod) % mod;
		}
		ans = (ans + 1ll * tp1 * qpow(tp2, mod - 2)) % mod;
	}printf("%d", ans);
}

xx 点连续的线性做法

i=1n+1yi1jn+1,ij(xxj)1jn+1,ij(xixj)\sum_{i = 1}^{n + 1} y_i\dfrac{\prod_{1 \le j \le n + 1, i \ne j}(x - x_j)}{\prod_{1 \le j \le n + 1, i \ne j}(x_i - x_j)}

这个柿子暴力求是 O(n2)O(n^2) 的,但是在 xix_i 连续时可以 O(n)O(n) 求,我们考虑 O(1)O(1) 求出右边的柿子。

右边的分子可以预处理出 prd=1jn+1(xxj)prd = \prod_{1 \le j \le n + 1}(x - x_j),然后每次乘上一个 1xxi\frac 1{x - x_i} 即可。然后我们发现分母是两个阶乘之积,预处理阶乘即可。

模板题代码:

#include <algorithm>
#include <cstring>
#include <cstdio>
#include <vector>
#define mod 1000000007
using namespace std;

const int N = 1000003;
int n, k, sm[N], fac[N], prd = 1, ans;
int qpow(int base, int p){int ans = 1;
	for (;p;p >>= 1, base = 1ll * base * base % mod){
		if (p & 1) ans = 1ll * ans * base % mod;
	}return ans;
}
int main(){
	scanf("%d%d", &n, &k), fac[0] = 1;
	for (int i = 1;i <= k + 2;i ++) sm[i] = (sm[i - 1] + qpow(i, k)) % mod;
	if (n <= k + 2){printf("%d\n", sm[n]);return 0;}
	for (int i = 1;i <= k + 2;i ++) fac[i] = 1ll * fac[i - 1] * i % mod;
	for (int i = 1;i <= k + 2;i ++) prd = 1ll * prd * (n - i + mod) % mod;
	for (int i = 1;i <= k + 2;i ++){
		int tp = ((k + 2 - i & 1 ? -1ll : 1ll) * fac[i - 1] * fac[k + 2 - i] % mod + mod) % mod;
		int tmp = 1ll * prd * qpow((n - i + mod), mod - 2) % mod * qpow(tp, mod - 2) % mod * sm[i] % mod;
		ans = (ans + tmp) % mod;
	}printf("%d", ans);
}

实际应用

[集训队互测 2012] calc

首先可以考虑 dp,然后设 fi,jf_{i, j} 为考虑 1i1 \sim i,选了 jj 个数的乘积和,最后答案即为 fk,nn!f_{k, n}n!,可以推出转移方程 fi,j=fi1,j+i×fi1,j1f_{i, j} = f_{i - 1, j} + i \times f_{i - 1, j - 1}。发现 fi,jf_{i, j} 是与 ii 有关的多项式,将它记作 fj(i)f_j(i),设它的次数为 g(j)g(j)

我们发现转移方程是一个前缀和的形式,给它差分一下得到 fj(i)fj(i1)=i×fj1(i1)f_j(i) - f_j(i - 1) = i \times f_{j - 1}(i - 1)。然后可以发现对于一个 kk 次多项式 p(x)p(x)p(x)p(x1)p(x) - p(x - 1)k1k - 1 次多项式,因为 xkx^k 这项会被抵消掉。所以柿子左边是一个 g(j)1g(j) - 1 次多项式,右边是一个 g(j1)+1g(j - 1) + 1 次多项式,有等式 g(j)1=g(j1)+1g(j) - 1 = g(j - 1) + 1,移项可得 g(j)=g(j1)+2g(j) = g(j - 1) + 2,又有 g(0)=0g(0) = 0,所以 g(j)=2jg(j) = 2j

所以,fk,nf_{k, n} 是一个 2n2n 次多项式,我们只需要对于所有 1i2n+11 \le i \le 2n + 1ii 求出 fi,nf_{i, n} 再对于所有的点 (i,fi,n)(i, f_{i, n}) 跑拉格朗日插值代入 kk 即可,时间复杂度为 O(n2)O(n^2)

参考代码:

#include <algorithm>
#include <cstring>
#include <cstdio>
#include <vector>
#define inf 0x3f3f3f3f
#define llinf 0x3f3f3f3f3f3f3f3f
typedef long long ll;
typedef __int128_t lll;
typedef unsigned int uint;
typedef unsigned long long ull;
typedef __uint128_t ulll;
using namespace std;

const int N = 1002;
int n, k, mod, f[N][N >> 1], prd, ans;
int qpow(int bse, int p){int ans = 1;
	for (;p;p >>= 1, bse = 1ll * bse * bse % mod){
		if (p & 1) ans = 1ll * ans * bse % mod;
	}return ans;
}
int main(){
	scanf("%d%d%d", &k, &n, &mod), f[0][0] = prd = 1;
	for (int i = 1;i <= n;i ++) prd = 1ll * prd * i % mod;
	for (int i = 1;i <= 2 * n + 1;i ++){
		for (int j = 0;j <= i && j <= n;j ++){f[i][j] = f[i - 1][j];
			if (j) f[i][j] = (f[i][j] + 1ll * i * f[i - 1][j - 1]) % mod;
		}
	}
	for (int i = 1;i <= 2 * n + 1;i ++){int tp1 = f[i][n], tp2 = 1;
		for (int j = 1;j <= 2 * n + 1;j ++){
			if (i == j) continue;
			tp1 = 1ll * tp1 * (k - j + mod) % mod;
			tp2 = 1ll * tp2 * (i - j + mod) % mod;
		}
		ans = (ans + 1ll * tp1 * qpow(tp2, mod - 2)) % mod;
	}printf("%d\n", 1ll * ans * prd % mod);
}

CF1874E Jellyfish and Hack

首先如果 lim>n(n+1)2lim > \frac{n(n + 1)}2 则无解,输出 00。否则考虑暴力 dp,设 fi,jf_{i, j} 表示长度为 ii,耗时为 jj 的排列个数,直接暴力枚举第一个数 pp 转移可以做到 O(n6)O(n^6)

fi,j=p=1i(i1p1)k=0jifp1,kfip,jikf_{i, j} = \sum_{p = 1}^i \binom{i - 1}{p - 1} \sum_{k = 0}^{j - i} f_{p - 1, k}f_{i - p, j - i - k}

发现 jj 这一维的转移是一个卷积的形式,那么我们将每个 fif_i 都写成多项式形式 FiF_i,即 fi,j=[xj]Fif_{i, j} = [x^j]F_i,则转移方程可以写作:

Fi=p=1i(i1p1)xiFp1FipF_i = \sum_{p = 1}^i \binom{i - 1}{p - 1} x^i F_{p - 1} F_{i - p}

根据 ff 的定义,FnF_n 的最高次不会超过 n(n+1)2\frac{n(n + 1)}2。所以我们将 1n(n+1)2+11 \sim \frac{n(n + 1)}2 + 1 都代入 xx 求出 n(n+1)2+1\frac{n(n + 1)}2 + 1 个不同的 FnF_n 的值,然后暴力插值求出 FnF_n 的各项系数。

FnF_n 的次数为 mm,则我们可以先 O(m2)O(m^2) 预处理出 prd=1jm+1(xxj)prd = \prod_{1 \le j \le m + 1}(x - x_j) 的各项系数,每次 O(m)O(m) 多项式除法除掉一个 (xxi)(x - x_i),再乘上一个 yi1jm+1(xxj)\frac{y_i}{\prod_{1 \le j \le m + 1}(x - x_j)}O(m)O(m) 加入。

每次求出单个值的时间复杂度是 O(n2)O(n^2),总共有 O(n2)O(n^2) 个值,时间复杂度为 O(n4)O(n^4);而暴力插值的时间复杂度为 O(m2)=O(n4)O(m^2) = O(n^4),总时间复杂度为 O(n4)O(n^4)

参考代码:

#include <algorithm>
#include <cstring>
#include <cstdio>
#include <vector>
#define mod 1000000007
using namespace std;
const int N = 205, M = 20105;
int n, m, f[N], y[M], C[N][N], co[M], nw[M], ans[M], sum;
inline void add(int& x, int y){x += y;if (x >= mod) x -= mod;}
int qpow(int bse, int p){int ans = 1;for (;p;p >>= 1, bse = 1ll * bse * bse % mod) if (p & 1) ans = 1ll * ans * bse % mod;return ans;}
int main(){scanf("%d%d", &n, &m);
	if (m > n * n + n >> 1){printf("0");return 0;}
	for (int i = 0;i <= n;i ++){C[i][0] = 1;
		for (int j = 1;j <= i;j ++)
		add(C[i][j] = C[i - 1][j], C[i - 1][j - 1]);
	}int R = n * n + n >> 1;R ++;
	for (int x = 1;x <= R;x ++){
		memset(f, 0, sizeof(f)), f[0] = 1;int px = x;
		for (int i = 1;i <= n;i ++){for (int p = 1;p <= i;p ++) 
			f[i] = (f[i] + 1ll * f[p - 1] * f[i - p] % mod * C[i - 1][p - 1]) % mod;
			f[i] = 1ll * f[i] * px % mod, px = 1ll * px * x % mod;
		}y[x] = f[n];
	}co[0] = 1;
	for (int i = 1;i <= R;i ++)
	for (int j = R;j >= 0;j --)
	co[j] = (1ll * co[j] * (mod - i) + (j ? co[j - 1] : 0)) % mod;
	for (int i = 1;i <= R;i ++){int now = 1;
		for (int j = 1;j <= R;j ++) if (i != j)
		now = 1ll * now * (i - j + mod) % mod;
		now = 1ll * qpow(now, mod - 2) * y[i] % mod;
		int iv = qpow(mod - i, mod - 2);
		for (int j = 0;j <= R;j ++)
		nw[j] = 1ll * (co[j] - (j ? nw[j - 1] : 0) + mod) * iv % mod; 
		for (int j = 0;j <= R;j ++)
		ans[j] = (ans[j] + 1ll * nw[j] * now) % mod;
	}
	for (int j = m;j <= R;j ++) add(sum, ans[j]);
	printf("%d", sum);
}