2017年8月28日 星期一

(Hackerrank) Coprime Paths [樹上莫隊]

https://www.hackerrank.com/challenges/coprime-paths

樹上莫隊 (Mo's algorithm on trees)


#include <iostream>
#include <stdio.h>
#include <vector>
#include <algorithm>
#include <cstring>
#include <stack>
#include <cmath>
using namespace std;

using LL = long long;

constexpr int MAX_N = 25006;     // capacity of the per-node and per-query arrays
constexpr int MAX_P = 10000001;  // size of the value-indexed tables (indices 0..10^7)
constexpr int MAX_Q = 15;        // binary-lifting levels: 2^15 = 32768 > MAX_N

int prime[MAX_P];                // prime[v]: a prime divisor of v once build() has run

// Defined further down; struct P calls them before their definitions.
void addd(int x, LL val);
void subb(int x, LL val);

struct P {
    int n;
    int p[3];
    int tot;
    void input() {
        scanf("%d",&n);
        tot=0;
        while (n != 1) {
            int pp=prime[n];
            p[tot++] = pp;
            while (n%pp==0) n/=pp;
        }
    }
    void add() {
        for (int i=0;(1<<tot)>i;i++) {
            int sz=0;
            int xx=1;
            for (int j=0;tot>j;j++) {
                if (((1<<j)&i) != 0) {
                    sz++;
                    xx *= p[j];
                }
            }
            if (sz == 0) continue;
            else if (sz%2 == 0) addd(xx,1);
            else addd(xx,-1);
        }
    }
    void sub() {
        for (int i=0;(1<<tot)>i;i++) {
            int sz=0;
            int xx=1;
            for (int j=0;tot>j;j++) {
                if (((1<<j)&i) != 0) {
                    sz++;
                    xx *= p[j];
                }
            }
            if (sz == 0) continue;
            else if (sz%2 == 0) subb(xx,1);
            else subb(xx,-1);
        }
    }
} p[MAX_N];

void build() {
    prime[0] = prime[1] = 1;
    for (int i=2;MAX_P>i;i++) {
        if (prime[i] == 0) {
            prime[i] = i;
            for (LL j=i;MAX_P>j;j+=i) {
                prime[j] = i;
            }
        }
    }
}

// ---- Tree storage and DFS bookkeeping (indexed by node id) ----
vector<int> G[MAX_N];            // adjacency lists
stack<int, vector<int>> st;      // pending nodes while dfs() builds blocks
int B;                           // target block size
int block[MAX_N], b_cnt;         // block[x]: block id of node x; b_cnt: blocks sealed so far
int dfn[MAX_N], dfs_time;        // dfn[x]: 0-based order in which dfs() entered x
int depth[MAX_N];                // depth[x]: depth of x as assigned by dfs()
int pa[MAX_N];                   // pa[x]: parent of x in the DFS tree
int pin[MAX_N], pout[MAX_N];     // entry / exit timestamps, used by is_anc()
int stamp;                       // shared counter behind pin[] / pout[]

#define SZ(x) ((int)(x).size())

void dfs(int u,int cur_depth,int par) {
    pin[u] = ++stamp;
    pa[u] = par;
    depth[u] = cur_depth;
    dfn[u] = dfs_time++;
    int buttom = SZ(st);
    for (int v:G[u]) {
        if (v==par) continue;
        dfs(v,cur_depth+1,u);
        if (SZ(st) - buttom >= B) {
            while (SZ(st) != buttom) {
                block[st.top()] = b_cnt;
                st.pop();
            }
            b_cnt++;
        }
    }
    st.emplace(u);
    pout[u] = ++stamp;
}

/// Runs dfs() from `root` over an n-node tree and gives every node a block id
/// for Mo's algorithm on trees, with target block size floor(sqrt(n)).
/// Nodes dfs() leaves on the stack always include the root. They join the
/// last sealed block.
void make_block(int root,int n) {
    B = static_cast<int>(sqrt(n));
    b_cnt = dfs_time = 0;
    dfs(root,1,root);  // the root is recorded as its own parent
    // When dfs() never sealed a block (e.g. n == 1), the leftovers used to be
    // labelled b_cnt - 1 == -1, which is not a valid block id. Open block 0
    // for them instead.
    if (b_cnt == 0) b_cnt = 1;
    while (SZ(st) != 0) {
        block[st.top()] = b_cnt-1;
        st.pop();
    }
}

int cnt[MAX_P] = {};  // cnt[d]: nodes in the active set whose value has square-free divisor d (maintained by addd/subb)

/// One path query (u, v) plus its position `id` in the input.
struct QUERY {
    int u,v,id;

    /// Stores the query, swapping the endpoints so that dfs() entered u no
    /// later than v (dfn[u] <= dfn[v]).
    void give_val(int from,int to,int index) {
        u=from;
        v=to;
        if (dfn[u] > dfn[v]) swap(u,v);
        id=index;
    }

    /// Mo's-on-tree order: by block of u, then by DFS order of v.
    /// Declared const: std::sort's comparison must accept possibly-const
    /// operands and must not modify them.
    bool operator<(const QUERY &b) const {
        if (block[u] != block[b.u]) return block[u] < block[b.u];
        return dfn[v] < dfn[b.v];
    }
} query[MAX_N];

int ans[MAX_N] = {};         // ans[id]: answer to the id-th query, printed in input order

int lca[MAX_Q][MAX_N] = {};  // lca[k][x]: 2^k-th ancestor of x (binary lifting, filled by pre_lca)

/// Fills the binary-lifting table for nodes 1..n: lca[k][x] is the 2^k-th
/// ancestor of x. It saturates at the root, because make_block() records the
/// root as its own parent. Requires pa[] from dfs().
void pre_lca(int n) {
    for (int x = 1; x <= n; ++x) lca[0][x] = pa[x];
    for (int k = 1; k < MAX_Q; ++k)
        for (int x = 1; x <= n; ++x)
            lca[k][x] = lca[k - 1][lca[k - 1][x]];  // two 2^(k-1) jumps
}

/// True when `par` is an ancestor of `son` (a node counts as its own ancestor),
/// judged by the nesting of the pin/pout timestamps assigned in dfs().
bool is_anc(int son,int par) {
    const bool entered_before = pin[par] <= pin[son];
    const bool left_after = pout[son] <= pout[par];
    return entered_before && left_after;
}

/// Lowest common ancestor of u and v via the binary-lifting table.
/// Lifts u as long as the jump target is not an ancestor of v; the parent of
/// the final position is the LCA.
int get_lca(int u,int v) {
    if (depth[u] > depth[v]) swap(u,v);  // u is now no deeper than v
    if (is_anc(v,u)) return u;
    for (int k = MAX_Q; k-- > 0; ) {
        const int up = lca[k][u];
        if (!is_anc(v, up)) u = up;
    }
    return lca[0][u];
}

// NOTE(review): presumably renamed through a macro because a global called
// `minus` would be ambiguous with std::minus, visible via `using namespace std`.
#define minus sagiri

LL tot_sz = 0;  // number of nodes currently in the active set (changed by add()/sub())
LL minus = 0;   // running sum of val * C(cnt[d], 2), maintained by addd()/subb()

void addd(int x,LL val) {
    if (cnt[x] >= 1) {
        minus -= val*(cnt[x])*(cnt[x]-1)/2;
        cnt[x]++;
        minus += val*(cnt[x])*(cnt[x]-1)/2;
    }
    else {
        cnt[x]++;
    }
}

void subb(int x,LL val) {
    if (cnt[x] >= 2) {
        minus -= val*(cnt[x])*(cnt[x]-1)/2;
        cnt[x]--;
        minus += val*(cnt[x])*(cnt[x]-1)/2;
    }
    else {
        cnt[x]--;
    }
}

/// Inserts a node with value x into the active set: registers its square-free
/// divisors via P::add() and increments the node count.
void add(P x) {
    x.add();
    ++tot_sz;
}

/// Removes a node with value x from the active set: unregisters its square-free
/// divisors via P::sub() and decrements the node count.
void sub(P x) {
    x.sub();
    --tot_sz;
}

bool in_set[MAX_N];  // in_set[x]: node x is currently in the active set

/// Toggles node x's membership in the active set.
void flip(int x) {
    const bool was_in = in_set[x];
    in_set[x] = !was_in;
    if (was_in) sub(p[x]);
    else add(p[x]);
}

/// Toggles every node on the tree path a..b except their LCA. This moves the
/// Mo window from one endpoint to another.
void move(int a,int b) {
    const int meet = get_lca(a,b);  // named so it no longer shadows the global lca table
    while (a != meet) { flip(a); a = pa[a]; }
    while (b != meet) { flip(b); b = pa[b]; }
}

int main () {
    build();
    int n,q;
    scanf("%d %d",&n,&q);
    for (int i=1;n>=i;i++) {
        p[i].input();
    }
    for (int i=1;n-1>=i;i++) {
        int a,b;
        scanf("%d %d",&a,&b);
        G[a].push_back(b);
        G[b].push_back(a);
    }
    make_block(1,n);
    pre_lca(n);
    for (int i=1;q>=i;i++) {
        int a,b;
        scanf("%d %d",&a,&b);
        query[i].give_val(a,b,i);
    }
    sort(query+1,query+q+1);
    int u=1,v=1;
    tot_sz = minus = 0;
    for (int i=1;q>=i;i++) {
        int uu=query[i].u,vv=query[i].v;
        move(u,uu);
        move(v,vv);
        u=uu;
        v=vv;
        int lca=get_lca(u,v);
        add(p[lca]);
        ans[query[i].id] = tot_sz*(tot_sz-1)/2 + minus;
        sub(p[lca]);
    }
    for (int i=1;q>=i;i++) {
        printf("%d\n",ans[i]);
    }
}


沒有留言:

張貼留言