2017年12月14日 星期四

(POJ) 2778. DNA Sequence [AC自動機 + 矩陣快速冪]

http://poj.org/problem?id=2778

被坑了有點久><

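矩陣快速冪寫成遞迴版本時一直 RE（Runtime Error）><。推測原因是每層遞迴的堆疊框架裡都放著好幾個 106×106 的 long long 矩陣（每個約 88 KB），幾十層就把堆疊用光了，所以下面的 poww 改用迭代寫法。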

#include <iostream>
#include <cstring>
#include <cassert>
#include <string>
using namespace std;

// Wide integer type for counts and matrix entries: a product of two values
// already reduced below `mod` must fit without overflow.
typedef long long int LL;
// Modulus applied to every matrix entry and to the printed answer.
const LL mod = 100000LL;

// Square matrix with a fixed-capacity backing store.
// Convention used throughout this file: `n` is the LARGEST valid index, so
// the active region is (n + 1) x (n + 1), i.e. indices 0..n inclusive.
struct Matrix {
    static const int N =106;   // capacity per dimension
    LL a[N][N];
    int n;                     // largest valid row/column index (not the size)

    // Sets the largest valid index to _n and zeroes the active
    // (_n + 1) x (_n + 1) region; entries outside it are left untouched.
    void init(int _n)
    {
        // The loops below write index _n itself, so _n must stay below N;
        // otherwise they would silently run past the end of `a`.
        assert(_n < N && "Matrix::init: dimension exceeds capacity");
        n=_n;
        for (int i=0;n>=i;i++)
        {
            for (int j=0;n>=j;j++)
            {
                a[i][j] = 0;
            }
        }
    }
};

// Global transition matrix, filled by AC_Machine::get_trans():
// trans.a[to][from] is the number of allowed one-letter moves from state
// `from` to state `to`.
Matrix trans;

// Aho-Corasick automaton over the nucleotide alphabet {A, C, G, T}.
// Node 0 is the root. Until AC_evolution() runs, ch[u][c] == 0 means
// "no trie child"; afterwards ch is a complete transition table.
struct AC_Machine {
    static const int N = 106;     // maximum number of nodes
    static const int SIGMA = 4;   // alphabet size
    int ch[N][SIGMA];  // trie children, later the full transition table
    int fail[N];       // failure link of each node (0 = root)
    int last[N];       // nearest pattern-ending node on the fail chain, 0 if none
    int val[N];        // id of the pattern ending exactly here, 0 if none
    int que[N];        // BFS order written by getFail(), replayed by AC_evolution()
    int sz,qe,qs;      // number of nodes; queue end / queue start

    // Resets the automaton to a lone root with no patterns.
    void init()
    {
        sz = 1;
        qs = qe = 0;
        memset(ch[0],0,sizeof(ch[0]));
        memset(val,0,sizeof(val));
        memset(last,0,sizeof(last));
    }

    // Maps a nucleotide letter to its index in [0, SIGMA).
    int idx(char c)
    {
        switch (c)
        {
            case 'A': return 0;
            case 'C': return 1;
            case 'G': return 2;
            case 'T': return 3;
        }
        // Any other character breaks the input contract. Fail loudly in debug
        // builds and still return a defined value: flowing off the end of a
        // non-void function is undefined behavior.
        assert(false && "AC_Machine::idx: expected one of A, C, G, T");
        return 0;
    }

    // Inserts the NUL-terminated pattern `s` under a non-zero id (0 is
    // reserved for "no pattern" in val[]) and returns the node where it ends.
    int insert(const char* s,int id)
    {
        int now=0;
        int n=strlen(s);
        for (int i=0;n>i;i++)
        {
            int nxt=idx(s[i]);
            if (!ch[now][nxt])
            {
                // New node index sz must fit in every per-node array.
                assert(sz < N && "AC_Machine::insert: node capacity exceeded");
                memset(ch[sz],0,sizeof(ch[sz]));
                ch[now][nxt] = sz;
                sz++;
            }
            now = ch[now][nxt];
        }
        val[now] = id;
        return now;
    }

    // Computes fail[] and last[] breadth-first from the root and leaves the
    // BFS order of all non-root nodes in que[0..qe).
    void getFail()
    {
        qs = qe = 0;
        fail[0] = 0;
        // Depth-1 nodes always fail back to the root.
        for (int c=0;SIGMA >c; c++)
        {
            int nxt=ch[0][c];
            if (nxt)
            {
                fail[nxt] = 0;
                que[qe++] = nxt;
                last[nxt] = 0;
            }
        }
        while (qs != qe)
        {
            int t=que[qs++];
            for (int i=0;SIGMA>i;i++)
            {
                int nxt=ch[t][i];
                if (!nxt) continue;
                que[qe++] = nxt;
                // Walk t's fail chain until some node has an i-child (or we hit root).
                int v=fail[t];
                while (v && !ch[v][i]) v = fail[v];
                fail[nxt] = ch[v][i];
                // Dictionary link: the fail target itself if it ends a pattern,
                // otherwise whatever the fail target already points to.
                last[nxt] = val[ fail[nxt] ] ? fail[nxt] :last[ fail[nxt] ];
            }
        }
    }

    // Fills every missing transition ch[u][c] with ch[fail[u]][c], turning the
    // trie into a complete automaton. Must run after getFail(): it replays
    // que[], where each node's fail target (shallower) appears earlier or is
    // the root, so that target's row is already complete.
    void AC_evolution()
    {
        qs=0;
        while (qs != qe)
        {
            int now=que[qs++];
            for (int c=0;SIGMA>c;c++)
            {
                if (!ch[now][c]) ch[now][c] = ch[fail[now] ][c];
            }
        }
    }

    // Builds the global `trans` matrix: trans.a[v][u] counts the letters that
    // move state u to state v when neither u nor v ends a pattern (val) or has
    // a pattern as a suffix (last).
    void get_trans()
    {
        trans.init(sz);
        for (int i=0;sz>i;i++)
        {
            for (int j=0;SIGMA>j;j++)
            {
                int nxt=ch[i][j];
                if (!val[nxt] && !last[nxt] && !val[i] &&!last[i])
                {
                    // i --> nxt is an allowed move
                    trans.a[nxt][i]++;
                }
            }
        }
    }
} ac;

// Capacity of the per-pattern tables; main() numbers patterns from 1.
static const int N = 12;

// NOTE(review): `s` is not referenced anywhere in the visible source.
std::string s[N];
// ed[i]: trie node at which pattern i ends (written in main()).
int ed[N];

// Matrix product modulo `mod` over indices 0..m1.n inclusive.
// Assumes both operands hold entries already reduced below `mod`: each raw
// sum then has at most n + 1 non-negative terms below mod * mod, which fits
// in LL before the single reduction pass at the end.
Matrix operator*(const Matrix &m1,const Matrix &m2)
{
    const int last = m1.n;
    Matrix prod;
    prod.init(last);

    // i-k-j order walks m2 and prod row by row (cache friendly). All terms
    // are non-negative integers, so the final sums equal the i-j-k order's.
    for (int row = 0; row <= last; ++row)
    {
        for (int mid = 0; mid <= last; ++mid)
        {
            const LL lhs = m1.a[row][mid];
            if (lhs == 0) continue;   // contributes nothing to this row
            for (int col = 0; col <= last; ++col)
                prod.a[row][col] += lhs * m2.a[mid][col];
        }
    }

    for (int row = 0; row <= last; ++row)
        for (int col = 0; col <= last; ++col)
            prod.a[row][col] %= mod;

    return prod;
}

// Raises the square matrix `a` to the n-th power (entries modulo `mod`) by
// binary exponentiation; n == 0 yields the identity. Written iteratively so
// only a few Matrix objects ever live on the stack at once.
Matrix poww(Matrix a,LL n)
{
    Matrix result;
    result.init(a.n);
    for (int d = 0; d <= a.n; ++d)
        result.a[d][d] = 1;

    // `a` is a by-value copy, so it doubles as the running square a^(2^k).
    for (; n != 0; n >>= 1)
    {
        if (n & 1)
            result = result * a;
        a = a * a;
    }
    return result;
}

// Computes a^n modulo `modulus` by iterative binary exponentiation.
// The result is always in [0, modulus): the base is reduced first, so a
// may be >= modulus or negative, and n <= 0 yields 1 % modulus.
// Requires modulus > 0 and (modulus - 1)^2 to fit in a long long.
// (Spelled `long long`, the same type as the file's LL alias.)
long long powww(long long a,long long n,long long modulus)
{
    long long base = a % modulus;
    if (base < 0) base += modulus;        // normalize negative inputs
    long long result = 1 % modulus;       // correct even when modulus == 1
    while (n > 0)
    {
        if (n & 1) result = result * base % modulus;
        base = base * base % modulus;
        n >>= 1;
    }
    return result;
}

int main ()
{
    int m,n;
    cin >> m >> n;
    if (m==0)
    {
        cout << powww(4,n,mod) <<endl;
        return 0;
    }
    ac.init();
    for (int i=1;m>=i;i++)
    {
        char qaq[106];
        cin >> qaq;
        ed[i] = ac.insert(qaq,i);
    }
    ac.getFail();
    ac.AC_evolution();
    ac.get_trans();
    //cout<<"trans.n = "<<trans.n<<endl;
    Matrix ret=poww(trans,n);
    //cout <<"QAQ" << endl;
    LL ans=0;
    for (int i=0;trans.n>=i;i++)
    {
        if (ac.val[i] == 0 && ac.last[i] == 0)ans += ret.a[i][0];
    }
    cout << ans%mod << endl;
}

沒有留言:

張貼留言