题解

2 条题解

  • 0
    @ 2025-06-10 18:11:54
    #pragma GCC optimize ("unroll-loops")
    #pragma GCC optimize ("Ofast")
    #include<cstdio>
    #include<iostream>
    #include<cstring>
    #include<cmath>
    #include<algorithm>
    #define N 32774
    #define ll long long
    #define reg register
    #define add(x,y) (x+y>=p?x+y-p:x+y)
    #define dec(x,y) (x<y?x-y+p:x-y)
    using namespace std;
    
    // Output modulus; read from stdin in main() and used by add()/dec() and by multiply()'s final reduction.
    int p;
    
    // rev: bit-reversal permutation for the largest transform length; rt: root table indexed as rt[mid|k].
    // Both tables are filled by init().
    int rev[N],rt[N];
    // Number of doublings performed in init(), i.e. log2 of the largest transform length.
    int siz;
    
    // Modulus used internally by the number-theoretic transform (separate from the output modulus p).
    #define md 998244353
    
    inline int power(int a,int t){
        int res = 1;
        while(t){
            if(t&1) res = (ll)res*a%md;
            a = (ll)a*a%md;
            t >>= 1; 
        }
        return res;
    }
    
    // Builds the transform tables for the smallest power-of-two length lim with lim > n:
    //   rev[i]     - bit-reversal of i over siz bits,
    //   rt[mid|k]  - k-th power of the root of order 2*mid, for every power-of-two mid < lim.
    // The caller must ensure lim <= N. siz is reset on entry so that calling init() again
    // rebuilds consistent tables instead of accumulating into the previous value.
    void init(int n){
        siz = 0;
        int lim = 1;
        while(lim<=n) lim <<= 1,++siz;
        rev[0] = 0;
        for(int i=1;i!=lim;++i) rev[i] = (rev[i>>1]>>1)|((i&1)<<(siz-1));
        // Top level: consecutive powers of w = 3^((md-1)/lim) stored in rt[lim/2 .. lim-1].
        rt[lim>>1] = 1;
        const int w = power(3,(md-1)>>siz);
        for(int i=(lim>>1)+1;i!=lim;++i) rt[i] = (long long)rt[i-1]*w%md;
        // Lower levels reuse every second root of the level above.
        for(int i=(lim>>1)-1;i;--i) rt[i] = rt[i<<1];
    }
    
    inline void dft(int *f,int lim){
        static unsigned long long a[N];
        reg int x,shift = siz-__builtin_ctz(lim);
        for(reg int i=0;i!=lim;++i) a[rev[i]>>shift] = f[i];
        for(reg int mid=1;mid!=lim;mid<<=1)
        for(reg int j=0;j!=lim;j+=(mid<<1))
        for(reg int k=0;k!=mid;++k){
            x = a[j|k|mid]*rt[mid|k]%md;
            a[j|k|mid] = a[j|k]+md-x;
            a[j|k] += x;
        }
        for(reg int i=0;i!=lim;++i) f[i] = a[i]%md;
    }
    
    // In-place inverse transform of f[0..lim-1]: reverse f[1..lim-1], run the forward
    // transform, then scale every entry by the inverse of lim modulo md.
    inline void idft(int *f,int lim){
        std::reverse(f+1,f+lim);
        dft(f,lim);
        // For lim = 2^t, md - (md-1)/2^t is the modular inverse of lim.
        const long long inv_len = md-((md-1)>>__builtin_ctz(lim));
        for(int *it=f;it!=f+lim;++it) *it = (int)(*it*inv_len%md);
    }
    
    // Returns the smallest power of two strictly greater than n (1 for n <= 0).
    // __builtin_clz(0) is undefined, so the n <= 0 case is handled explicitly; it is
    // reachable when two degree-0 polynomials are multiplied.
    inline int getlen(int n){
        if(n<=0) return 1;
        return 1<<(32-__builtin_clz(n));
    }
    
    void multiply(const int *A,const int *B,int n,int m,int *R,int len){
        static int f[N],g[N];
        memcpy(f,A,(n+1)<<2),memcpy(g,B,(m+1)<<2);
        int lim = getlen(n+m);
        memset(f+n+1,0,(lim-n)<<2);
        memset(g+m+1,0,(lim-m)<<2);
        dft(f,lim),dft(g,lim);
        for(reg int i=0;i!=lim;++i) f[i] = (ll)f[i]*g[i]%md;
        idft(f,lim);
        for(reg int i=0;i<=len;++i) R[i] = f[i]%p;
    }
    
    inline void inverse(const int *f,int n,int *R){
        static int g[N],h[N];
        memset(g,0,getlen(n<<1)<<2);
        int lim = 1,top = 0;
        int s[30];
        while(n){
            s[++top] = n;
            n >>= 1;
        }
        g[0] = 1;
        while(top--){
            n = s[top+1];
            while(lim<=(n<<1)) lim <<= 1;
            memcpy(h,f,(n+1)<<2);
            memset(h+n+1,0,(lim-n)<<2);
            multiply(h,g,n,n,h,n);
            multiply(h,g,n,n,h,n);
            for(reg int i=0;i<=n;++i) g[i] = dec(add(g[i],g[i]),h[i]);
        }
        memcpy(R,g,(n+1)<<2);
    }
    
    // Polynomial power by repeated squaring: writes into R[0..n] the coefficients of f^k
    // truncated to degree n, reduced modulo p (via multiply()).
    //   f - base polynomial, coefficients 0..n; OVERWRITTEN with its successive squares.
    //   k - exponent, expected to be non-negative; k == 0 yields the constant polynomial 1.
    inline void power(int *f,int n,int k,int *R){
        int g[N];
        // The accumulator must start as the polynomial 1. multiply() reads g[0..n], so every
        // one of those coefficients has to be initialised, not just g[0].
        memset(g,0,sizeof(int)*(n+1));
        g[0] = 1;
        while(1){
            if(k&1) multiply(g,f,n,n,g,n);
            k >>= 1;
            if(k==0) break;
            multiply(f,f,n,n,f,n);
        }
        memcpy(R,g,(n+1)<<2);
    }
    
    // Inputs, in the order they are read: m, p, A, o, s, u (p is declared with the NTT globals).
    int m,A,o,s,u;
    // Work polynomials: f holds F and later 1/(1-F); g holds F(x)/x; h holds F^(A+1).
    int f[N],g[N],h[N];
    
    // Reads m, p, A, o, s, u and prints, modulo p, the coefficient of x^m in
    //   (1 - F(x)^(A+1)) / (1 - F(x)),   where F(x) = sum_{i=1..m} ((o*i*i + s*i + u) mod p) x^i.
    // When A >= m the F^(A+1) term cannot reach x^m, so only 1/(1-F) is needed.
    // The table size N requires the transform length for 2*m to fit, i.e. m <= 16383.
    int main(){
        // Bail out on malformed input; p must be positive because it is used as a modulus.
        if(scanf("%d%d%d%d%d%d",&m,&p,&A,&o,&s,&u)!=6 || p<=0) return 0;
        init(m<<1);
        for(int i=1;i<=m;++i){
            // Horner form with 64-bit intermediates: o*i*i overflows int already for moderate o and i.
            const long long inner = ((long long)o*i+s)%p;
            f[i] = (int)((inner*i+u)%p);
        }
        // g(x) = F(x)/x truncated to degree m-A-1: the only part of F that F^(A+1) needs up to x^m.
        if(A<m) memcpy(g,f+1,(m-A)<<2);
        // f <- 1 - F(x) (mod p), then f <- 1/(1 - F(x)) truncated to degree m.
        f[0] = 1;
        for(int i=1;i<=m;++i) f[i] = f[i]==0?0:p-f[i];
        inverse(f,m,f);
        if(A>=m){
            printf("%d",f[m]);
            return 0;
        }
        // h <- g^(A+1), then shift by A+1 positions so that h = F^(A+1), and negate it.
        power(g,m-A-1,A+1,h);
        for(int i=m;i>A;--i) h[i] = h[i-A-1];
        for(int i=A+1;i<=m;++i) h[i] = h[i]==0?0:p-h[i];
        // Coefficient m of (1 - F^(A+1)) * (1/(1-F)); the "1" contributes f[m].
        int ans = f[m];
        for(int i=A+1;i<=m;++i)
            if(h[i]!=0) ans = (int)((ans+(long long)h[i]*f[m-i])%p);  // 64-bit product: h*f can exceed int
        printf("%d",ans);
        return 0;
    }
    
  • -1
    @ 2016-09-20 15:29:54
    #include<cstdio>
    #define rec(x,y) rec[aa[y]+x]
    using namespace std;
    
    // Scalars read from stdin (m, p, n, a, b, c) plus four running flat indices l0..l3 into rec
    // that the inner loop of main() advances by hand.
    int n,m,a,b,c,p,l1,l2,l3,l0;
    // rec: flattened triangular table addressed through the macro rec(x,y) = rec[aa[y]+x];
    // aa[y]: offset of row y (aa[y] = y*(y-1)/2); f[i]: quadratic a*i*i+b*i+c reduced mod p.
    // NOTE(review): rec alone is about 400 MB of static storage.
    int rec[100000001],ans=0,aa[10001],f[255];
    // Reads m, p, n, a, b, c; fills rec row by row for rows 1..min(n,m) and prints the sum of
    // rec(i,m) over those rows, modulo p.
    int main(){
        scanf("%d%d%d%d%d%d",&m,&p,&n,&a,&b,&c);
        register int i,j;
        // Only rows 1..m are meaningful, so clamp n.
        if (n>m) n=m;
        aa[1]=0;
        a%=p;b%=p;c%=p;
        // f[i] = f[i-1] + a*(2i-1) + b telescopes to a*i*i + b*i + c; the while loop keeps it below p.
        // NOTE(review): f has 255 slots and f[1..p-1] are written, so this presumably relies on
        // p <= 255 - confirm against the problem limits.
        f[0]=c;
        for (i=1;i<p;i++){
            f[i]=f[i-1]+a*(2*i-1)+b;
            while (f[i]>=p) f[i]-=p;
        }
        // Row offsets: row i starts i-1 slots after row i-1.
        for (i=2;i<=m;i++) aa[i]=aa[i-1]+i-1;
        // Row 1: rec(1,i) = f[i mod p]. j tracks i mod p; for i > p the value is copied from the
        // already-filled entry with the same residue (rec(1,j), or rec(1,p) when j == 0).
        for (i=1,j=1;i<=m;i++,j++,j-=j>=p?p:0)
        if (i>p) rec(1,i)=rec(1,(j==0?p-1:j-1)+1);else rec(1,i)=f[j];
        // Two constants of the recurrence, derived from rec(1,1), rec(1,2) and a, normalised into [0,p).
        int k=((rec(1,2)-3*rec(1,1))+3*p)%p,y=(2*a-rec(1,2)+2*rec(1,1)+100*p)%p;
        ans=rec(1,m);
        for (i=2;i<=n;i++){
            // Diagonal entry comes from the previous row's diagonal.
            rec(i,i)=(rec(i-1,i-1)*f[1])%p;
            // d accumulates contributions from row i-1 across the whole row; it is never reduced mod p.
            // NOTE(review): d is a plain int, so this relies on the limits keeping the sum in range.
            int d=rec(i,i);
            j=i+1;
            d+=((k*rec(i-1,j-2))+(f[1]*rec(i-1,j-1)));
            rec(i,j)=(2*rec(i,j-1)+d+p)%p;
            j++;
            // Flat indices: l0 -> rec(i,j), l1 -> rec(i-1,j-1), l2 -> rec(i-1,j-2), l3 -> rec(i-1,j-3).
            // Consequently l1+1 is rec(i,j-1) and l2+1 is rec(i,j-2).
            l0=aa[j]+i;l1=aa[j-1]+i-1;l2=aa[j-2]+i-1;l3=aa[j-3]+i-1;
            for (;j<=m;j++){
                d+=((y*rec[l3])+(k*rec[l2])+(f[1]*rec[l1]));
                rec[l0]=(2*rec[l1+1]-rec[l2+1]+d+p)%p;
                // Moving one column to the right advances each index by the length of the row it leaves.
                l0+=j;l1+=j-1;l2+=j-2;l3+=j-3;
            }
            ans+=rec(i,m);
        }
        printf("%d\n",ans%p);
    }```
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    ac全靠rp
    测试数据 #0: Accepted, time = 0 ms, mem = 391936 KiB, score = 10
    测试数据 #1: Accepted, time = 0 ms, mem = 391936 KiB, score = 10
    测试数据 #2: Accepted, time = 0 ms, mem = 391932 KiB, score = 10
    测试数据 #3: Accepted, time = 0 ms, mem = 391932 KiB, score = 10
    测试数据 #4: Accepted, time = 15 ms, mem = 391940 KiB, score = 10
    测试数据 #5: Accepted, time = 0 ms, mem = 391936 KiB, score = 10
    测试数据 #6: Accepted, time = 328 ms, mem = 391936 KiB, score = 10
    测试数据 #7: Accepted, time = 328 ms, mem = 391936 KiB, score = 10
    测试数据 #8: Accepted, time = 984 ms, mem = 391932 KiB, score = 10
    测试数据 #9: Accepted, time = 1000 ms, mem = 391936 KiB, score = 10
    Accepted, time = 2655 ms, mem = 391940 KiB, score = 100
    代码
    ```c++
    #include<cstdio>
    #define rec(x,y) rec[aa[y]+x]
    using namespace std;
    
    // Scalars read from stdin (m, p, n, a, b, c) plus four running flat indices l0..l3 into rec
    // that the inner loop of main() advances by hand.
    int n,m,a,b,c,p,l1,l2,l3,l0;
    // rec: flattened triangular table addressed through the macro rec(x,y) = rec[aa[y]+x];
    // aa[y]: offset of row y (aa[y] = y*(y-1)/2); f[i]: quadratic a*i*i+b*i+c reduced mod p.
    // NOTE(review): rec alone is about 400 MB of static storage.
    int rec[100000001],ans=0,aa[10001],f[255];
    // Reads m, p, n, a, b, c; fills rec row by row for rows 1..min(n,m) and prints the sum of
    // rec(i,m) over those rows, modulo p.
    int main(){
        scanf("%d%d%d%d%d%d",&m,&p,&n,&a,&b,&c);
        register int i,j;
        // Only rows 1..m are meaningful, so clamp n.
        if (n>m) n=m;
        aa[1]=0;
        a%=p;b%=p;c%=p;
        // f[i] = f[i-1] + a*(2i-1) + b telescopes to a*i*i + b*i + c; the while loop keeps it below p.
        // NOTE(review): f has 255 slots and f[1..p-1] are written, so this presumably relies on
        // p <= 255 - confirm against the problem limits.
        f[0]=c;
        for (i=1;i<p;i++){
            f[i]=f[i-1]+a*(2*i-1)+b;
            while (f[i]>=p) f[i]-=p;
        }
        // Row offsets: row i starts i-1 slots after row i-1.
        for (i=2;i<=m;i++) aa[i]=aa[i-1]+i-1;
        // Row 1: rec(1,i) = f[i mod p]. j tracks i mod p; for i > p the value is copied from the
        // already-filled entry with the same residue (rec(1,j), or rec(1,p) when j == 0).
        for (i=1,j=1;i<=m;i++,j++,j-=j>=p?p:0)
        if (i>p) rec(1,i)=rec(1,(j==0?p-1:j-1)+1);else rec(1,i)=f[j];
        // Two constants of the recurrence, derived from rec(1,1), rec(1,2) and a, normalised into [0,p).
        int k=((rec(1,2)-3*rec(1,1))+3*p)%p,y=(2*a-rec(1,2)+2*rec(1,1)+100*p)%p;
        ans=rec(1,m);
        for (i=2;i<=n;i++){
            // Diagonal entry comes from the previous row's diagonal.
            rec(i,i)=(rec(i-1,i-1)*f[1])%p;
            // d accumulates contributions from row i-1 across the whole row; it is never reduced mod p.
            // NOTE(review): d is a plain int, so this relies on the limits keeping the sum in range.
            int d=rec(i,i);
            j=i+1;
            d+=((k*rec(i-1,j-2))+(f[1]*rec(i-1,j-1)));
            rec(i,j)=(2*rec(i,j-1)+d+p)%p;
            j++;
            // Flat indices: l0 -> rec(i,j), l1 -> rec(i-1,j-1), l2 -> rec(i-1,j-2), l3 -> rec(i-1,j-3).
            // Consequently l1+1 is rec(i,j-1) and l2+1 is rec(i,j-2).
            l0=aa[j]+i;l1=aa[j-1]+i-1;l2=aa[j-2]+i-1;l3=aa[j-3]+i-1;
            for (;j<=m;j++){
                d+=((y*rec[l3])+(k*rec[l2])+(f[1]*rec[l1]));
                rec[l0]=(2*rec[l1+1]-rec[l2+1]+d+p)%p;
                // Moving one column to the right advances each index by the length of the row it leaves.
                l0+=j;l1+=j-1;l2+=j-2;l3+=j-3;
            }
            ans+=rec(i,m);
        }
        printf("%d\n",ans%p);
    }```
    
  • 1

信息

ID
1955
难度
9
分类
(无)
标签
递交数
357
已通过
23
通过率
6%
被复制
3
上传者