51nod 1387 移数字

2018/03/06 17:51
阅读数 9

任意门


  回来拉模版的时候意外发现这个题还没写题解,所以就随便补点吧。

  题意其实就是要你求n的阶乘在模意义下的值。

  首先找出来一个最大的$m$满足$m^2<=n$,对于大于$m^2$部分的数我们直接暴力求就行了,问题是求$m^2$以内的答案。

  先构造一个多项式$f(x)=(x+1)(x+2)(x+3)……(x+m)$,然后求它在$x=0、x=m……x=m(m-1)$位置的值,然后求个值全部乘起来就行了。

  

  稍微说下怎么做多点求值,构造两个多项式

  $$G_1(x)=(x-x_1)(x-x_2)……(x-x_{\left\lfloor\frac{m}{2}\right\rfloor})=x(x-m)……(x-\left\lfloor\frac{m}{2}\right\rfloor m)$$

  $$G_2(x)=(x-x_{\left\lfloor\frac{m}{2}\right\rfloor+1})……(x-x_m)=(x-\left\lfloor\frac{m}{2}\right\rfloor m-m)……(x-m^2+m)$$

  然后拿$f(x)$对$G_1(x)$取模得到一个$\left\lfloor\frac{m}{2}\right\rfloor$次的多项式,这个多项式在$x_1、x_2……x_{\left\lfloor\frac{m}{2}\right\rfloor}$位置的值跟$f(x)$是一样的(这是因为构造出来的式子在这些位置都等于0,而我们可以把多项式除法看成很多次减法,所以这个值不会变),后半部分同理用$G_2(x)$处理,这个时候问题规模就减半了,由此递归即可。

  题目最后复杂度是$O(\sqrt{n}log^2\sqrt{n})$

  

  

#include<cstdio>
#include<cstring>
#include<algorithm>
#define lp (p<<1)
#define rp ((p<<1)|1)
#define ll long long
#define MN 200200
using namespace std;
int read_p,read_ca;
inline int read(){
    read_p=0;read_ca=getchar();
    while(read_ca<'0'||read_ca>'9') read_ca=getchar();
    while(read_ca>='0'&&read_ca<='9') read_p=read_p*10+read_ca-48,read_ca=getchar();
    return read_p;
}
int _n,n,m,t,e[MN],_e[MN],Mmh=0,D[MN],C_a[MN],C_b[MN],C_c[MN],N_c[MN],D_a[MN],D_b[MN],D_c[MN],tot,gg=2,MMH[MN],L[MN*5];
int rt[MN*40],B[MN*40],sz=0;
int MOD=104857601;
inline void M(int &x){while(x>=MOD)x-=MOD;}
inline int mi(int a,int b){
    int mmh=1;
    while (b){
        if (b&1) mmh=1LL*mmh*a%MOD;
        b>>=1;a=1LL*a*a%MOD;
    }
    return mmh;
}
inline void inv(){
    int base=mi(gg,(MOD-1)/tot),_base=mi(base,MOD-2);
    e[0]=_e[0]=1;
    for (register int i=1;i<=tot;i++) e[i]=1LL*e[i-1]*base%MOD,_e[i]=1LL*_e[i-1]*_base%MOD;
}
inline void NTT(int N,int a[],int w[]){
    register int i,j,k,m,z;
    for (i=j=0;i<N;i++){
        if (i>j) swap(a[i],a[j]);
        for (k=N>>1;(j^=k)<k;k>>=1);
    }
    for (i=2;i<=N;i<<=1)
    for (m=i>>1,j=0;j<N;j+=i)
    for (k=0;k<m;k++){
        z=1LL*a[j+k+m]*w[tot/i*k]%MOD;
        a[j+k+m]=a[j+k]>z?a[j+k]-z:MOD-z+a[j+k];
        a[j+k]=a[j+k]-MOD+z;if (a[j+k]<0) a[j+k]+=MOD;
    }
}
inline void cc(int N,int a[],int b[],int c[]){
    memcpy(C_a,a,N<<2);memcpy(C_b,b,N<<2);
    NTT(N,C_a,e);NTT(N,C_b,e);
    for (register int i=0;i<N;i++) c[i]=1LL*C_a[i]*C_b[i]%MOD;
    NTT(N,c,_e);
    int w=mi(N,MOD-2);
    for (register int i=0;i<N;i++) c[i]=1LL*c[i]*w%MOD;
}
inline void cc(int n,int m,int a[],int b[],int c[]){
    int N;
    for (N=1;N<(n+m);N<<=1);
    memcpy(C_a,a,n<<2);memcpy(C_b,b,m<<2);
    fill(C_a+n,C_a+N,0);fill(C_b+m,C_b+N,0);
    NTT(N,C_a,e);NTT(N,C_b,e);
    for (register int i=0;i<N;i++) c[i]=1LL*C_a[i]*C_b[i]%MOD;
    NTT(N,c,_e);
    int w=mi(N,MOD-2);
    for (register int i=0;i<N;i++) c[i]=1LL*c[i]*w%MOD;
}
inline void ny(int p,int a[],int b[]){
    if (p==1) b[0]=mi(a[0],MOD-2);else{
        ny((p+1)>>1,a,b);
        int N=1;
        while (N<(p<<1))N<<=1;
        copy(a,a+p,N_c);fill(N_c+p,N_c+N,0);
        NTT(N,N_c,e);NTT(N,b,e);
        for (register int i=0;i<N;i++) b[i]=(2LL-1LL*N_c[i]*b[i]%MOD+MOD)*b[i]%MOD;
        NTT(N,b,_e);
        int w=mi(N,MOD-2);
        for (register int i=0;i<N;i++) b[i]=1LL*b[i]*w%MOD;
        fill(b+p,b+N,0);
    }
}
inline void re_copy(int n,int a[],int b[]){for (register int i=0;i<n;i++) b[i]=a[n-i-1];}
inline void div(int n,int m,int a[],int b[],int d[],int r[]){
    int N=1,t=n-m+1,i;
    while (N<t<<1)N<<=1;
    memset(D_a,0,N<<2);
    memset(D_b,0,N<<2);
    memset(D_c,0,N<<2);
    memset(d,0,N<<2);
    re_copy(m,b,D_b);
    re_copy(n,a,D_a);
    ny(t,D_b,D_c);
    for (N=1;N<(n<<1);N<<=1);
    cc(n,t,D_a,D_c,D_b);
    re_copy(t,D_b,d);
    fill(d+t,d+N,0);
    cc(t,m,d,b,D_a);
    for (i=0;i<m;i++) r[i]=(1LL*a[i]-D_a[i]+MOD)%MOD;
    fill(r+m,r+N,0);
}
inline bool ju(int x){
    int u=MOD-1;
    for (register int i=2;i*i<=u;i++)
    if (u%i==0) if (mi(x,u/i)==1) return 1;
    return 0;
}
int mmh=1;
inline void Mmhp(int p,int l,int r){
    if (l==r){
        L[p]=sz;
        rt[sz]=l;
        rt[sz+1]=1;
        sz+=2;
        return;
    }
    int mid=l+r>>1;
    Mmhp(lp,l,mid);Mmhp(rp,mid+1,r);
    cc(mid-l+2,r-mid+1,rt+L[lp],rt+L[rp],rt+sz);
    L[p]=sz;
    sz+=r-l+2;
}
inline void Mmhrt(int p,int l,int r){
    if (l==r){
        L[p]=sz;
        rt[sz]=(MOD-1LL*m*l%MOD)%MOD;
        rt[sz+1]=1;
        sz+=2;
        return;
    }
    int mid=l+r>>1;
    Mmhrt(lp,l,mid);Mmhrt(rp,mid+1,r);
    cc(mid-l+2,r-mid+1,rt+L[lp],rt+L[rp],rt+sz);
    L[p]=sz;
    sz+=r-l+2;
}
inline void _Mmh(int p,int l,int r,int fi,int LL){
    div(LL,r-l+2,B+fi,rt+L[p],D,B+sz);
    int mid=l+r>>1,s=sz;
    sz+=r-l+2;
    if (l==r) mmh=1LL*B[s]*mmh%MOD;else _Mmh(lp,l,mid,s,r-l+1),_Mmh(rp,mid+1,r,s,r-l+1);
}
int main(){
    register int i;
    n=read();
    MOD=read();
    if (n>=MOD) return printf("0\n"),0;
    while(ju(gg))gg++;
    for (m=1;(m+1)*(m+1)<=n;m++);
    for (tot=1;tot<((m+2)<<1);tot<<=1);inv();
    for (i=m*m+1;i<=n;i++) mmh=1LL*mmh*i%MOD;
    sz=0;Mmhp(1,1,m);
    for (i=L[1];i<=L[1]+m;i++) B[i-L[1]]=rt[i];
    sz=0;Mmhrt(1,0,m-1);
    sz=m+1;_Mmh(1,0,m-1,0,m+1);
    if (n&1) mmh=1LL*mmh*mi(2,MOD-2)%MOD;
    printf("%d\n",mmh);
}
View Code

 

展开阅读全文
ntt
打赏
0
0 收藏
分享
加载中
更多评论
打赏
0 评论
0 收藏
0
分享
返回顶部
顶部