拉格朗日插值法求多项式系数

拉格朗日插值

公式

\[L(x) = \sum_{j=0}^{k} y_{j}l_{j}(x)\]

其中 \(l_{j}(x) = \prod_{i=0, i \neq j}^{k} \frac{x-x_{i}}{x_{j}-x_{i}}\)

上述公式可以保证 $L(x_{j}) = y_{j}$

给出 $n$ 个横坐标互不相同的点,可以唯一确定一个次数不超过 $n-1$ 的多项式 $f(x)$。对给定的 $k$ 求 $f(k)$,每次求值的时间复杂度为 $O(n^2)$,代码如下:

C++ 实现(朴素 $O(n^2)$ 拉格朗日插值,模 $998244353$):
#include <bits/stdc++.h>
#define LL long long
using namespace std;
// Upper bound on the number of sample points the global array can hold.
const int maxn = 100005;
// Modulus used for all arithmetic in this program.
const LL mod = 998244353;

// Fast (binary) exponentiation: returns a^b modulo `mod`.
// The `mod` parameter shadows the global constant of the same name.
// Returns 1 when b == 0. Products are formed in long long, so this presumably
// expects 0 <= a < mod with mod small enough that a*a cannot overflow.
LL quick_mod(LL a, LL b, LL mod){
	LL result = 1;
	for(; b; b >>= 1){
		if(b & 1){
			result = result*a % mod;
		}
		a = a*a % mod;
	}
	return result;
}

// One sample point (x, y) of the polynomial being interpolated.
struct Point{
	LL x;
	LL y;
};
// Sample points; main() fills and Lagrange() reads indices 1..n.
Point p[maxn];
// n: number of sample points; k: abscissa at which the polynomial is queried.
int n, k;
// Evaluates, modulo `mod`, the Lagrange interpolating polynomial through the
// points p[1..n] at abscissa k:
//   L(k) = sum_j y_j * prod_{i != j} (k - x_i) / (x_j - x_i).
// Cost: the double loop below does O(n^2) multiplications plus one modular
// inverse (via quick_mod with exponent mod-2) per point.
// k and every coordinate are first reduced into [0, mod), so a difference can
// never leave a negative remainder and no product of two unreduced values is
// formed.
// The x-coordinates are expected to be pairwise distinct modulo `mod`;
// otherwise a denominator is 0, its "inverse" is 0, and that term contributes 0.
LL Lagrange(LL k){
	const LL kk = (k % mod + mod) % mod;
	LL ans = 0;
	for(int j=1; j<=n; j++){
		const LL xj = (p[j].x % mod + mod) % mod;
		const LL yj = (p[j].y % mod + mod) % mod;
		LL fz = 1, fm = 1; // numerator and denominator of the j-th basis polynomial
		for(int i=1; i<=n; i++){
			if(i == j) continue;
			const LL xi = (p[i].x % mod + mod) % mod;
			fz = fz * ((kk - xi + mod) % mod) % mod;
			fm = fm * ((xj - xi + mod) % mod) % mod;
		}
		// Division by fm is multiplication by fm^(mod-2) (Fermat), valid for prime mod.
		ans = (ans + yj * fz % mod * quick_mod(fm, mod-2, mod)) % mod;
	}
	return ans;
}
int main(){
	cin >> n >> k;
	for(int i=1; i<=n; i++) cin >> p[i].x >> p[i].y;
	cout << Lagrange(k) << endl;
}

变形:重心拉格朗日插值法

令 $l(x) = (x-x_{0})(x-x_{1})\cdots(x-x_{k}) = \prod_{i=0}^{k}(x-x_{i})$

则上述 \(l_{j}(x) = \frac{l(x)}{x-x_{j}} \cdot \frac{1}{\prod_{i=0, i \neq j}^{k} (x_{j}-x_{i})}\)。定义权重 $w_{j}$: \(w_{j} = \frac{1}{\prod_{i=0, i \neq j}^{k} (x_{j}-x_{i})}\),则 $l_{j}(x)$ 可以进一步化为: \(l_{j}(x) = l(x) \frac{w_{j}}{x-x_{j}}\)。于是有重心拉格朗日多项式: \(L(x) = l(x) \sum_{j=0}^{k} \frac{w_{j}}{x-x_{j}}y_{j}\)。若此时新增加一个点 $(x_{k+1}, y_{k+1})$,

需要更新以下值 :

  • $k$ := $ k+1$
  • $l(x)$ := $l(x)(x-x_{k+1})$
  • 对每个已有的 $j$:$w_{j}$ := $\frac{w_{j}}{x_{j}-x_{k+1}}$;同时为新点计算 $w_{k+1} = \frac{1}{\prod_{i=0}^{k}(x_{k+1}-x_{i})}$(此处 $k$ 为加点之前的值)。更新所有的 $w_{j}$ 共需要 $O(n)$ 时间

重心拉格朗日插值(二型)

考虑对函数 $g(x) \equiv 1$ (即 $y_{j} = 1$)进行插值: \(g(x) = l(x) \sum_{j=0}^{k} \frac{w_{j}}{x-x_{j}}\) 则 $L(x) $ 中的$l(x)$ 可以消去,得到重心拉格朗日插值公式二型: \(L(x) = \frac{L(x)}{g(x)} = \frac{\sum_{j=0}^{k} \frac{w_{j}}{x-x_{j}}y_{j}}{\sum_{j=0}^{k} \frac{w_{j}}{x-x_{j}}}\) 此时增加一个点时不必考虑 $l(x)$ 了

例题

自然数k次幂和

给定 $n$ 和 $k$,求 $\left(\sum_{i=1}^{n} i^k\right) \bmod (10^9+7)$ 的值。($n \le 10^{18}$,$k \le 10^{6}$)

当 $k$ 比较小时,例如 $k = 1$ 时有 \(\sum_{i=1}^{n} i = \frac{n(n+1)}{2}\)。一般地,$f(n) = \sum_{i=1}^{n}i^k$ 是关于 $n$ 的 $k+1$ 次多项式,所以只要知道 $k+2$ 个点 $(x_{i}, y_{i})$(例如取 $x_{i} = i$,$y_{i} = f(i)$)便可通过插值求出 $f(n)$ 的值。

1

**注:上述二型公式的求和中,下标 $j$ 改记为 $i$。**

C++ 实现(横坐标取连续整数,用前后缀积与阶乘逆元完成插值,模 $10^9+7$):
#include<bits/stdc++.h>

using namespace std;
// Capacity of the global work arrays.
const int N = 1000100;
typedef long long ll;
// Modulus used for all arithmetic in this program.
const ll mod = 1000000007;
// p: prefix sums of i^k filled by main(); x: sample abscissas;
// s1 / s2: prefix / suffix product scratch space used by lagrange();
// ifac: inverse-factorial scratch space used by lagrange().
ll p[N];
ll x[N];
ll s1[N];
ll s2[N];
ll ifac[N];

// Binary exponentiation: returns a^b modulo `mod`, normalised into [0, mod).
// Returns 1 when b == 0.
ll qpow(ll a,ll b){
	ll acc=1;
	for(; b!=0; b>>=1){
		if(b&1){
			acc=acc%mod*a%mod;
		}
		a=a%mod*a%mod;
	}
	return (acc%mod+mod)%mod;
}

// Lagrange interpolation specialised to consecutive integer abscissas.
// Uses the n+1 samples (x[i], y[i]), i = 0..n, where x[i] = x[0] + i, and
// returns the interpolating polynomial's value at xi, reduced into [0, mod).
// Scratch space: the global arrays s1 (prefix products of (xi - x[i])),
// s2 (suffix products) and ifac (inverse factorials); indices 0..n+1 are used.
// Because the abscissas are consecutive, the denominator of the i-th basis
// polynomial is i! * (n-i)! * (-1)^(n-i), so only inverse factorials are needed.
// Assumes mod is prime and n < mod so that every 1..n is invertible.
ll lagrange(ll n, ll *x, ll *y, ll xi) {
    ll ans = 0;
    // Each difference is reduced BEFORE it is multiplied: the caller passes xi
    // up to about 1e18, and (residue < mod) * (xi - x[i]) would overflow
    // signed 64-bit arithmetic.
    s1[0] = (xi-x[0])%mod, s2[n+1] = 1;
    for (ll i = 1; i <= n; i++) s1[i] = s1[i-1]*((xi-x[i])%mod)%mod;
    for (ll i = n; i >= 0; i--) s2[i] = s2[i+1]*((xi-x[i])%mod)%mod;
    // First pass: modular inverses via inv[i] = -(mod/i) * inv[mod % i].
    ifac[0] = ifac[1] = 1;
    for (ll i = 2; i <= n; i++) ifac[i] = -(mod/i)*ifac[mod%i]%mod;
    // Second pass: turn inverses into inverse factorials in place.
    for (ll i = 2; i <= n; i++) ifac[i] = ifac[i]*ifac[i-1]%mod;
    for (ll i = 0; i <= n; i++) {
        const ll left = (i == 0 ? 1 : s1[i-1]);   // prod_{j<i} (xi - x[j])
        ll term = y[i]%mod*left%mod*s2[i+1]%mod*ifac[i]%mod*ifac[n-i]%mod;
        if ((n-i)&1) term = -term;                // sign (-1)^(n-i) of the denominator
        ans = (ans+term)%mod;
    }
    // Intermediate values may be negative; normalise once at the end.
    return (ans+mod)%mod;
}
// Reads n and k and prints (1^k + 2^k + ... + n^k) modulo `mod`.
// The sum is a polynomial of degree k+1 in n, so it is sampled at 0..k+2 and
// evaluated at n by lagrange() when n lies beyond the sampled range.
// Presumably k+2 must stay below the array capacity N — TODO confirm limits.
int main(){
	ll n,k;
	cin>>n>>k;
	if(n<=0){
		// Empty sum; also keeps p[n] below from being indexed out of range.
		cout<<0<<'\n';
		return 0;
	}
	if(k==0){
		// Sum of n ones is n, which may exceed mod and must be reduced like
		// every other answer.
		cout<<n%mod<<'\n';
		return 0;
	}
	p[0]=0;
	x[0]=0;
	for(ll i=1;i<=k+2;i++){
		p[i]=(p[i-1]+qpow(i,k))%mod;   // prefix sum of i^k
		x[i]=i;
	}
	if(n<=k+2){
		cout<<p[n]<<'\n';               // already tabulated
	}
	else{
		cout<<lagrange(k+2,x,p,n)<<'\n';
	}
	return 0;
}

【模板】求多项式系数

C++ 实现(NTT + 多项式多点求值 + 分治合并,输出插值多项式的各项系数,模 $998244353$):
#include<bits/stdc++.h>
// Shorthand for loop counters. It used to expand to `register int`, but the
// `register` storage-class specifier was deprecated in C++11 and removed in
// C++17, so it now expands to a plain int (no behavioural difference).
#define ri int
using namespace std;
// Not referenced by the code visible in this listing; presumably a read-buffer
// length left over from a fast-input routine (TODO confirm).
const int rlen=(1<<18)+1;
typedef long long ll;
// Modulus used for all arithmetic in this program.
const int mod=998244353;

// Value-returning modular helpers. Operands are expected to already lie in
// [0, mod); the results then lie in [0, mod) as well.
inline int add(const int&a,const int&b){
    if(a+b>=mod)return a+b-mod;
    return a+b;
}
inline int dec(const int&a,const int&b){
    if(a>=b)return a-b;
    return a-b+mod;
}
inline int mul(const int&a,const int&b){
    // Widen to 64 bits so the product cannot overflow before the reduction.
    return (ll)a*b%mod;
}

// In-place variants: the first argument receives the result.
inline void Add(int&a,const int&b){
    a=(a+b>=mod)?a+b-mod:a+b;
}
inline void Dec(int&a,const int&b){
    a=(a>=b)?a-b:a-b+mod;
}
inline void Mul(int&a,const int&b){
    a=(ll)a*b%mod;
}

// Binary exponentiation: returns a^p modulo mod (1 when p == 0).
inline int ksm(int a,int p){
    int ret=1;
    while(p){
        if(p&1)ret=mul(ret,a);
        a=mul(a,a);
        p>>=1;
    }
    return ret;
}
// pos[t]: bit-reversal permutation for transforms of length 2^t (filled by ntt_init).
vector<int>pos[19];
// A, B: shared work buffers used by polynomial multiplication.
vector<int>A;
vector<int>B;
// w[t]: root-of-unity powers for butterfly level t (filled by ntt_init).
vector<int>w[19];
// lim: current transform length; tim: its base-2 logarithm.
int lim;
int tim;
// invv[t]: modular inverse of 2^t, the scaling factor of an inverse transform.
int invv[19];
// Exclusive upper bound on the butterfly half-lengths that ntt_init precomputes.
const int Lim=1<<19;
inline void ntt_init(){
    for(ri mid=1,inv2=(mod+1)/2,w0,t=0,mult=1;mid<Lim;mid<<=1,++t,Mul(mult,inv2)){
        w[t].resize(mid);
        w[t][0]=1,w0=ksm(3,(mod-1)/(mid<<1));
        for(ri i=1;i<mid;++i)w[t][i]=mul(w[t][i-1],w0);
        invv[t]=mult;
        pos[t].resize(mid);
        for(ri i=1;i<mid;++i)pos[t][i]=(pos[t][i>>1]>>1)|((i&1)<<(t-1));
    }
}
// Selects the transform size for a result of degree `up`: sets lim to the
// smallest power of two strictly greater than up and tim to log2(lim).
inline void init(const int&up){
    tim=0;
    for(lim=1;lim<=up;lim<<=1)++tim;
}
// In-place iterative number-theoretic transform of the first lim entries of a,
// using the tables prepared by ntt_init() and the current globals lim / tim.
// type == -1 selects the inverse transform; any other value the forward one.
// The inverse pass reverses a[1..] and so presumably expects a.size() == lim.
inline void ntt(vector<int>&a,const int&type){
    // Bit-reversal reordering.
    const vector<int>&perm=pos[tim];
    for(int i=0;i<lim;++i){
        const int j=perm[i];
        if(i<j)swap(a[i],a[j]);
    }
    // Butterfly passes, doubling the block length each level.
    int level=0;
    for(int half=1;half<lim;half<<=1,++level){
        const vector<int>&roots=w[level];
        const int span=half<<1;
        for(int base=0;base<lim;base+=span){
            for(int k=0;k<half;++k){
                const int lo=a[base+k];
                const int hi=mul(a[base+k+half],roots[k]);
                a[base+k]=add(lo,hi);
                a[base+k+half]=dec(lo,hi);
            }
        }
    }
    if(type!=-1)return;
    // Inverse transform: reverse all but the first entry, then divide by lim.
    reverse(a.begin()+1,a.end());
    const int scale=invv[tim];
    for(int i=0;i<lim;++i)a[i]=mul(a[i],scale);
}
// Dense polynomial over Z/mod: a[i] is the coefficient of x^i.
struct poly{
    vector<int>a;
    // Builds x * X^k, i.e. k+1 coefficients that are all zero except a[k] = x.
    // k must be >= 0: a[k] is written unconditionally.
    poly(int k=0,int x=0):a(k+1){
        a[k]=x;
    }
    // Index of the last stored coefficient (size - 1).
    inline int deg()const{
        return static_cast<int>(a.size())-1;
    }
    inline int&operator[](const int&k){
        return a[k];
    }
    inline const int&operator[](const int&k)const{
        return a[k];
    }
    // Returns a copy resized to exactly k+1 coefficients (truncated or zero-padded).
    inline poly extend(const int&k){
        poly ret=*this;
        ret.a.resize(k+1);
        return ret;
    }
    // Reverses the coefficient order IN PLACE and returns a copy of the result.
    inline poly rev(){
        reverse(a.begin(),a.end());
        return *this;
    }
    // Evaluates the polynomial at x by accumulating a[i] * x^i term by term.
    inline int getval(const int&x){
        int ret=0;
        int power=1;
        const int up=deg();
        for(int i=0;i<=up;++i){
            Add(ret,mul(power,a[i]));
            Mul(power,x);
        }
        return ret;
    }
};
// Coefficient-wise sum modulo mod; the result has degree max(deg a, deg b).
inline poly operator+(const poly&a,const poly&b){
    const int n=a.deg();
    const int m=b.deg();
    poly sum(n>m?n:m);
    for(int i=0;i<=n;++i)sum[i]=add(sum[i],a[i]);
    for(int i=0;i<=m;++i)sum[i]=add(sum[i],b[i]);
    return sum;
}
// Coefficient-wise difference a - b modulo mod; degree max(deg a, deg b).
inline poly operator-(const poly&a,const poly&b){
    const int n=a.deg();
    const int m=b.deg();
    poly diff(n>m?n:m);
    for(int i=0;i<=n;++i)diff[i]=add(diff[i],a[i]);
    for(int i=0;i<=m;++i)diff[i]=dec(diff[i],b[i]);
    return diff;
}
// Scalar multiple: every coefficient of b multiplied by a, modulo mod.
inline poly operator*(const int&a,const poly&b){
    poly scaled=b;
    for(int&coef:scaled.a)coef=mul(coef,a);
    return scaled;
}
// Polynomial product via NTT. Uses the shared buffers A and B and overwrites
// the globals lim / tim through init(); the result has degree deg a + deg b.
inline poly operator*(const poly&a,const poly&b){
    const int n=a.deg();
    const int m=b.deg();
    init(n+m);
    // Load both operands into zero-padded buffers of length lim.
    A.assign(lim,0);
    B.assign(lim,0);
    for(int i=0;i<=n;++i)A[i]=a[i];
    for(int i=0;i<=m;++i)B[i]=b[i];
    ntt(A,1);
    ntt(B,1);
    for(int i=0;i<lim;++i)A[i]=mul(A[i],B[i]);   // pointwise product
    ntt(A,-1);
    poly product;
    product.a=A;
    return product.extend(n+m);
}
// Power-series inverse: returns f with a*f == 1 modulo x^k, computed by Newton
// iteration f <- f*(2 - a*f), doubling the precision each round.
// The loop doubles from 2 while <= k, so k is presumably a power of two; a[0]
// must be invertible modulo mod. Overwrites the globals lim / tim.
inline poly poly_inv(poly a,const int&k){
    poly cur(0,ksm(a[0],mod-2));   // inverse modulo x^1
    int level=1;
    for(int len=2;len<=k;len<<=1,++level){
        poly trunc=a.extend(len-1);   // a modulo x^len
        // Transform length 2*len leaves room for the product cur*cur*trunc.
        lim=len<<1;
        tim=level+1;
        cur=cur.extend(lim-1);
        trunc=trunc.extend(lim-1);
        ntt(cur.a,1);
        ntt(trunc.a,1);
        for(int i=0;i<lim;++i)cur[i]=mul(cur[i],dec(2,mul(cur[i],trunc[i])));
        ntt(cur.a,-1);
        cur=cur.extend(len-1);        // keep only the terms below x^len
    }
    return cur;
}
// Formal derivative: maps sum a[i]*x^i to sum i*a[i]*x^(i-1), coefficients mod `mod`.
// A constant (or empty) polynomial has derivative 0 and is returned as the
// degree-0 zero polynomial; constructing poly(-1) for it would write one
// element before the start of an empty coefficient vector.
inline poly poly_direv(const poly&a){
    const int up=a.deg();
    if(up<=0)return poly();
    poly ret(up-1);
    for(int i=1;i<=up;++i)ret[i-1]=mul(a[i],i);
    return ret;
}
// Polynomial quotient a / b (the remainder is discarded), via the standard
// reversal trick: reverse both operands, invert the reversed divisor as a power
// series modulo x^(up+1), multiply, truncate and reverse back.
// Expects deg a >= deg b; the visible caller (operator%) checks this first.
inline poly operator/(const poly&a,const poly&b){
    const int up=a.deg()-b.deg();   // degree of the quotient
    poly num=a;
    poly den=b;
    num=num.rev().extend(up);
    den=den.rev().extend(up);
    int len=1;
    while(len<=up)len<<=1;          // power of two strictly above up
    den=poly_inv(den,len).extend(up);
    return (num*den).extend(up).rev();
}
// Polynomial remainder of a modulo b. When deg a < deg b, a is returned
// unchanged; otherwise a - (a/b)*b truncated to deg b - 1.
inline poly operator%(const poly&a,const poly&b){
    if(a.deg()<b.deg())return a;
    return (a-a/b*b).extend(b.deg()-1);
}
// Capacity of the per-point arrays (points are stored 1-based).
const int N=100005;
// Segment-tree shorthands. They expand in terms of the locals p (node index)
// and l, r (index range) of whichever function uses them.
#define lc (p<<1)
#define rc ((p<<1)+1)
#define mid ((l+r)>>1)
// T[p]: product of (x - fx[i]) over the index range covered by tree node p.
poly T[N<<2];
// n: number of points; fx / fy: their coordinates;
// val[i]: fy[i] divided by the derivative of T[1] evaluated at fx[i].
int n;
int fx[N],fy[N];
int val[N];
inline void build(int p,int l,int r){
    if(l==r){T[p]=poly(1,1),T[p][0]=mod-fx[l];return;}
    build(lc,l,mid),build(rc,mid+1,r);
    T[p]=T[lc]*T[rc];
}
// Multipoint evaluation by remainder tree: t is reduced modulo each child's
// product polynomial on the way down. At leaf l the value of t at fx[l] is
// inverted and multiplied by fy[l], and the result is stored in val[l].
inline void getval(int p,int l,int r,poly t){
    if(l!=r){
        getval(lc,l,mid,t%T[lc]);
        getval(rc,mid+1,r,t%T[rc]);
        return;
    }
    const int denom=t.getval(fx[l]);
    val[l]=mul(ksm(denom,mod-2),fy[l]);
}
// Divide-and-conquer combination: returns sum over i in [l, r] of
// val[i] * prod_{j in [l, r], j != i} (x - fx[j]). Each half's partial sum is
// multiplied by the other half's product polynomial before the two are added.
inline poly dc_ntt(int p,int l,int r){
    if(l==r)return poly(0,val[l]);
    const poly left=dc_ntt(lc,l,mid);
    const poly right=dc_ntt(rc,mid+1,r);
    return left*T[rc]+T[lc]*right;
}
#undef lc
#undef rc
#undef mid
int main(){
    cin >> n;
    ntt_init();
    for(ri i=1;i<=n;++i) cin >> fx[i] >> fy[i];
    build(1,1,n);
    getval(1,1,n,poly_direv(T[1]));
    poly f=dc_ntt(1,1,n);
    for(ri i=0;i<n;++i)cout<<f[i]<<' ';
    return 0;
}

# 使用单$作为行内数学公式分界符