loj #6059. 「2017 山东一轮集训 Day1」Sum — 倍增+NTT

#6059. 「2017 山东一轮集训 Day1」Sum

内存限制:256 MiB 时间限制:1500 ms

题目描述

求有多少 n 位十进制数是 p pp 的倍数且每位之和小于等于 mi(mi=0,1,2,…,m−1,m) ,允许前导 0,答案对 998244353 取模。

输入格式

一行三个整数 n,p,m

输出格式

输出一行 m+1个正整数,分别表示 mi=0,1,2,…,m−1,m 时的答案。

样例

样例输入

2 3 3

样例输出

1 1 1 5

数据范围与提示

 

首先裸dp比较好推,f[i][j][k]表示第i位,模p为j,数字和为k的方案数

由于n很大,我们可以考虑二进制拆分

倍增求出f[2^i][][],然后将n的二进制位是1的合并起来就好了

然后就是考虑如何合并两个数组

可以先枚举k,k=k1+k2,然后p^2的枚举每一个余数就好

我们不难发现这就是一个卷积

用NTT加速即可

这里复杂度是(p*mlogm+p^2 m)的,因为我们只需要一次正变换,然后乘完之后再变换回来即可

所以总复杂度(p*mlogm+p^2 m)logn

 

#include<map>
#include<cmath>
#include<queue>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define inf 1000000007
#define mod 998244353
#define ll long long
#define N 4010
int r[N],L=-1,G=3,nn;
ll inv;
ll ksm(ll a,int b)
{
	ll sum=1;
	while(b)
	{
		if(b&1) sum=sum*a%mod;
		a=a*a%mod;b>>=1;
	}
	return sum;
}
ll t,f[55][N],g[55][N],h[55][N];
void NTT(ll *x,int f)
{
	int i,j,k;
	ll wn,w,X,Y;
	for(i=0;i<nn;i++) if(i<r[i]) swap(x[i],x[r[i]]);
	for(i=1;i<nn;i<<=1)
	{
		wn=ksm(G,(mod-1)/(i<<1));
		for(j=0;j<nn;j+=(i<<1))
		{
			for(k=0,w=1;k<i;k++,w=w*wn%mod)
			{
				X=x[j+k];Y=w*x[j+k+i]%mod;
				x[j+k]=(X+Y)%mod;x[j+k+i]=(X-Y+mod)%mod;
			}
		}
	}
	if(f==-1)
	{
		reverse(x+1,x+nn);
		for(i=0;i<nn;i++) x[i]=x[i]*inv%mod;
	}
}
int n,m,p;
void sol1()
{
	register int i,j,k;
	for(i=0;i<p;i++) NTT(f[i],1),NTT(g[i],1);
	memset(h,0,sizeof(h));
	for(i=0;i<p;i++)
		for(j=0;j<p;j++)
			for(k=0;k<nn;k++) (h[(t*i+j)%p][k]+=f[i][k]*g[j][k])%=mod;
	for(i=0;i<p;i++)
	{
		NTT(h[i],-1),NTT(g[i],-1);
		for(j=0;j<=m;j++) f[i][j]=h[i][j];
		for(j=m+1;j<nn;j++) f[i][j]=0;
	}
}
void sol2()
{
	register int i,j,k;
	for(i=0;i<p;i++) NTT(g[i],1);
	memset(h,0,sizeof(h));
	for(i=0;i<p;i++)
		for(j=0;j<p;j++)
			for(k=0;k<nn;k++) (h[(t*i+j)%p][k]+=g[i][k]*g[j][k])%=mod;
	for(i=0;i<p;i++)
	{
		NTT(h[i],-1);
		for(j=0;j<=m;j++) g[i][j]=h[i][j];
		for(j=m+1;j<nn;j++) g[i][j]=0;
	}
}
int main()
{
	scanf("%d%d%d",&n,&p,&m);
	for(nn=1;nn<=m*2;nn<<=1) L++; inv=ksm(nn,mod-2);
	for(int i=0;i<nn;i++) r[i]=(r[i>>1]>>1)|((i&1)<<L);
	t=10;f[0][0]=1;
	for(int i=0;i<=9&&i<=m;i++) g[i%p][i]++;
	while(n)
	{
		if(n&1) sol1();
		sol2();n>>=1;t=t*t%p;
	}
	for(int i=1;i<=m;i++) (f[0][i]+=f[0][i-1])%=mod;
	for(int i=0;i<=m;i++) printf("%lld%c",f[0][i]," \n"[i==m]);
	return 0;
}

 

评论

还没有任何评论,你来说两句吧

发表评论

衫小寨 出品