2018年5月23日 星期三

(IOI) 2009 Day 2 p3 Regions

https://contest.yandex.com/ioi/contest/1363/problems/G/

首先,可以先把長官關係轉化成一顆樹,並且給定尤拉序列(Euler tour)。

所以,接下來詢問就變成:有多少區間pair(x,y),使得區間 x 屬於 r1,區間 y屬於 r2,並且r1完整包覆r2。

到這邊,可以先提出很多種算法,到時候再一起併起來之類的(?)

 定義 S1 是 r1 集合的大小, S2 是 r2集合的大小

(1) 預處理答案

預處理所有的 (r1,r2) 組合,注意到對於一個r1,可以在O(N)的時間 (走訪這棵樹) 求出所有 r2的答案。

(2)  O(S1 + S2)

如果 r1 和 r2 是已經排好序的話,就可以好好的使用類似雙指針 (類似而已,實作細節可參考底下的query3() )。

(3) O(S1 log S2) or O(S2 log S1)

好好的寫XD。

對於每個 r ,先按照尤拉序列走訪後的區間排序,之後開個持久化線段樹,好好的想辦法讓每次的問題都做到 log 的查詢時間

這樣一來,有兩種併法:

(一)

(1) + (2)
如果 S1, S2 其中一個 > C,使用 (1),否則使用(2)
取C = N^0.5的話,分析複雜度後會是好的。

(二)

(2) + (3)  //下面code的寫法

如果S1, S2其中一個 > C,使用(3) ,選擇複雜度比較小的那個
否則使用 (2)

適當的取C,會是好的XDDD。

#include <bits/stdc++.h>
using namespace std;

struct Node
{
    Node *lc,*rc;
    int sum;
    Node():lc(NULL),rc(NULL),sum(0){}
    void pull()
    {
        sum = lc->sum + rc->sum;
    }
};

Node* Build(int L,int R)
{
    Node* node = new Node();
    if (L==R) return node;
    int mid=(L+R)>>1;
    node->lc = Build(L,mid);
    node->rc = Build(mid+1,R);
    return node;
}

Node* getNode(Node* old)
{
    Node* node = new Node();
    node->lc = old->lc;
    node->rc = old->rc;
    node->sum = old->sum;
    return node;
}

void modify(Node* old,Node* node,int L,int R,int pos,int val)
{
    if (L==R)
    {
        node->sum += val;
        return;
    }
    int mid=(L+R)>>1;
    if (pos <= mid)
    {
        node->lc = getNode(old->lc);
        modify(old->lc,node->lc,L,mid,pos,val);
    }
    else
    {
        node->rc = getNode(old->rc);
        modify(old->rc,node->rc,mid+1,R,pos,val);
    }
    node->pull();
}

int query(Node* node,int L,int R,int l,int r)
{
    if (l>R || L>r) return 0;
    else if (l<=L && R<=r) return node->sum;
    int mid=(L+R)>>1;
    return query(node->lc,L,mid,l,r) + query(node->rc,mid+1,R,l,r);
}

typedef pair<int,int> pii;
const int N = 200006;

#define SZ(x) ((int)(x).size())

vector<int> G[N];
vector<pii> rr[N];
int r_pos[N];



#define F first
#define S second

pii p[N];

int stamp;

void dfs(int now)
{
    p[now].F = (++stamp);
    for (int i:G[now])
    {
        dfs(i);
    }
    p[now].S = stamp;
}

Node* root[N];
int pre[N];

int n;

int query1(int x,int y)
{
    //how many interval cover y
    //size of y is smaller
    int ret=0;
    int lastpos = 0;
    for (pii p:rr[y])
    {
        int pos = lower_bound(rr[x].begin() + lastpos,rr[x].end(),p) - rr[x].begin();
        lastpos = pos;
        if (pos)
        {
            ret += query(root[ pre[x]+pos ],1,n,p.S,n);
        }
    }
    return ret;
}

int query2(int x,int y)
{
    //how many interval within y
    //size of y is smaller
    //cout << "X = " << x << " , y = " <<y <<endl;
    int ret=0;
    int lastpos=0;
    for (pii p:rr[y])
    {
        int pos = lower_bound(rr[x].begin()+lastpos,rr[x].end(),p) - rr[x].begin();
        //cout << " p = " << p.F <<" , " << p.S << " , pos = " << pos << endl;
        lastpos = pos;
        if (pos != 0)
        {
            //cout << "pos = " << pos << " , pre[x] = " << pre[x] << " , pre[x+1] = " << pre[x+1] <<endl;
            ret += (query(root[ pre[x+1] ],1,n,p.F,p.S) - query(root[ pre[x]+pos ],1,n,p.F,p.S));
        }
        else
        {
            ret += query(root[ pre[x+1] ],1,n,p.F,p.S);
        }
    }
    return ret;
}

int query3(int x,int y)
{
    int ptrl = 0;
    int ptrr = -1;
    int ret=0;
    for (int i=0;SZ(rr[x])>i;i++)
    {
        while (ptrl != SZ(rr[y]) && rr[y][ptrl].F < rr[x][i].F) ++ptrl;
        while (ptrr != -1 && rr[y][ptrr].F > rr[x][i].S) --ptrr;
        while (ptrr != SZ(rr[y])-1 && rr[y][ptrr+1].F <= rr[x][i].S) ++ptrr;
        ret += (ptrr - ptrl + 1);
    }
    return ret;
}

const int C = 706;

int main ()
{
    int r,q;
    scanf("%d%d%d",&n,&r,&q);
    for (int i=1;n>=i;i++)
    {
        if (i == 1)
        {
            int x;
            scanf("%d",&x);
            //rr[x].push_back(i);
            r_pos[i] = x;
            continue;
        }
        int par,x;
        scanf("%d%d",&par,&x);
        G[par].push_back(i);
        //rr[x].push_back(i);
        r_pos[i] = x;
    }
    dfs(1);
    for (int i=1;n>=i;i++)
    {
        //cout << "i = " <<i << " , r_pos = " << r_pos[i] << " , p = " << p[i].F << " , " << p[i].S << endl;
        rr[ r_pos[i] ].push_back(p[i]);
    }
    for (int i=1;r>=i;i++)
    {
        sort(rr[i].begin(),rr[i].end());
        pre[i+1] = pre[i] + SZ(rr[i]);
        //cout << "i = " <<i << " , pre = " << pre[i] <<endl;
    }
    int id=1;
    root[0] = Build(1,n);
    for (int i=1;r>=i;i++)
    {
        if (SZ(rr[i]) == 0) continue;
        root[id] = getNode(root[0]);
        modify(root[0],root[id],1,n,rr[i][0].S,1);
        ++id;
        //cout << "hi1" <<endl;
        for (int j=1;SZ(rr[i])>j;j++)
        {
            //cout << "jj = " << j << endl;
            root[id] = getNode(root[id-1]);
            modify(root[id-1],root[id],1,n,rr[i][j].S,1);
            ++id;
        }
        //cout << "i = " << i <<" , id = " << id << endl;
    }
    map<pii,int> mp;
    while (q--)
    {
        int r1,r2;
        scanf("%d%d",&r1,&r2);
        if (SZ(rr[r1]) == 0 ||SZ(rr[r2]) == 0)
        {
            printf("0\n");
            fflush(stdout);
            continue;
        }
        if (mp.find(make_pair(r1,r2)) != mp.end())
        {
            printf("%d\n",mp[ make_pair(r1,r2) ]);
            fflush(stdout);
            continue;
        }
        int ans = 0;
        if (SZ(rr[r1]) < C && SZ(rr[r2]) < C)
        {
            ans = query3(r1,r2);
        }
        else if (SZ(rr[r1]) >= SZ(rr[r2]))
        {
            ans = query1(r1,r2);
        }
        else
        {
            ans = query2(r2,r1);
        }
        mp[ make_pair(r1,r2) ] = ans;
        printf("%d\n",ans);
        fflush(stdout);
    }
}


沒有留言:

張貼留言