2018年2月23日 星期五

(POJ) 2417. Discrete Logging [baby step giant step, bsgs算法]

http://poj.org/problem?id=2417

題目可以簡化成求a ^ n === b (mod p)。

有一個想法是:把n拆成 x * m + y,y < m,在這裡取m = ceil( sqrt(n) )。

這時,我們就可以枚舉a ^ y,把這個表存在hash表裡面,接著枚舉a ^ x*m。

#include <iostream>
#include <cstdio>
#include <map>
#include <cmath>
#include <utility>
using namespace std;

int mull(int a,int b,int mod)
{
    return a*1LL*b%mod;
}

int ppow(int a,int n,int mod)
{
    int ret=1;
    int now=a;
    while (n)
    {
        if (n&1)
        {
            ret = mull(ret,now,mod);
        }
        now = mull(now,now,mod);
        n >>= 1;
    }
    return ret;
}

//hash table

const int C = 397184;

int get_pos(int x)
{
    return x%C;
}

int now_vis;

int vis_id[C];
map<int,int> mpp[C];

bool has_val(int pos,int ori_pos)
{
    if (vis_id[pos] == now_vis)
    {
        if (mpp[pos].find(ori_pos) != mpp[pos].end()) return true;
        else return false;
    }
    return false;
}

void set_val(int pos,int val,int ori_pos)
{
    if (vis_id[pos] != now_vis) mpp[pos].clear();
    vis_id[pos] = now_vis;
    mpp[pos][ori_pos] = val;
}

int get_val(int pos,int ori_pos)
{
    return mpp[pos][ori_pos];
}

void solve(int p,int b,int n)
{
    //b^L == n (mod p)
    //P < 2**31, 2 <= B < P, 1 <= N < p
    if (p%b==0)
    {
        puts("no solution");
        return;
    }
    if (n==1)
    {
        puts("0");
        return;
    }
    int x = ceil(sqrt(double(p)));
    int now_val = 1;
    for (int i=0;x>i;i++)
    {
        if (has_val(get_pos(now_val),now_val) == false)
        {
            set_val(get_pos(now_val),i,now_val);
            //mp[now_val] = i;
        }
        now_val = mull(now_val,b,p);
        if (now_val == n)
        {
            printf("%d\n",i+1);
            return;
        }
    }
    int now_val_rev = ppow(now_val,p-2,p);
    int tmp_now_val = now_val;
    int tmp_now_val_rev = now_val_rev;
    for (int i=1;x>i;i++)
    {
        int target_val = mull(n,now_val_rev,p);
        if (has_val(get_pos(target_val),target_val))
        {
            printf("%lld\n",get_val(get_pos(target_val),target_val) + i*1LL*x);
            return;
        }
        now_val = mull(now_val,tmp_now_val,p);
        now_val_rev = mull(now_val_rev,tmp_now_val_rev,p);
    }
    puts("no solution");
}

int main ()
{
    int p,b,n;
    while (scanf("%d %d %d",&p,&b,&n) != EOF)
    {
        now_vis++;
        solve(p,b,n);
    }
}



沒有留言:

張貼留言