首先,可以先把長官關係轉化成一顆樹,並且給定尤拉序列(Euler tour)。
所以,接下來詢問就變成:有多少區間pair(x,y),使得區間 x 屬於 r1,區間 y屬於 r2,並且r1完整包覆r2。
到這邊,可以先提出很多種算法,到時候再一起併起來之類的(?)
定義 S1 是 r1 集合的大小, S2 是 r2集合的大小
(1) 預處理答案
預處理所有的 (r1,r2) 組合,注意到對於一個r1,可以在O(N)的時間 (走訪這棵樹) 求出所有 r2的答案。
(2) O(S1 + S2)
如果 r1 和 r2 是已經排好序的話,就可以好好的使用類似雙指針 (類似而已,實作細節可參考底下的query3() )。
(3) O(S1 log S2) or O(S2 log S1)
好好的寫XD。
對於每個 r ,先按照尤拉序列走訪後的區間排序,之後開個持久化線段樹,好好的想辦法讓每次的問題都做到 log 的查詢時間
這樣一來,有兩種併法:
(一)
(1) + (2)
如果 S1, S2 其中一個 > C,使用 (1),否則使用(2)
取C = N^0.5的話,分析複雜度後會是好的。
(二)
(2) + (3) //下面code的寫法
如果S1, S2其中一個 > C,使用(3) ,選擇複雜度比較小的那個
否則使用 (2)
適當的取C,會是好的XDDD。
#include <bits/stdc++.h> using namespace std; struct Node { Node *lc,*rc; int sum; Node():lc(NULL),rc(NULL),sum(0){} void pull() { sum = lc->sum + rc->sum; } }; Node* Build(int L,int R) { Node* node = new Node(); if (L==R) return node; int mid=(L+R)>>1; node->lc = Build(L,mid); node->rc = Build(mid+1,R); return node; } Node* getNode(Node* old) { Node* node = new Node(); node->lc = old->lc; node->rc = old->rc; node->sum = old->sum; return node; } void modify(Node* old,Node* node,int L,int R,int pos,int val) { if (L==R) { node->sum += val; return; } int mid=(L+R)>>1; if (pos <= mid) { node->lc = getNode(old->lc); modify(old->lc,node->lc,L,mid,pos,val); } else { node->rc = getNode(old->rc); modify(old->rc,node->rc,mid+1,R,pos,val); } node->pull(); } int query(Node* node,int L,int R,int l,int r) { if (l>R || L>r) return 0; else if (l<=L && R<=r) return node->sum; int mid=(L+R)>>1; return query(node->lc,L,mid,l,r) + query(node->rc,mid+1,R,l,r); } typedef pair<int,int> pii; const int N = 200006; #define SZ(x) ((int)(x).size()) vector<int> G[N]; vector<pii> rr[N]; int r_pos[N]; #define F first #define S second pii p[N]; int stamp; void dfs(int now) { p[now].F = (++stamp); for (int i:G[now]) { dfs(i); } p[now].S = stamp; } Node* root[N]; int pre[N]; int n; int query1(int x,int y) { //how many interval cover y //size of y is smaller int ret=0; int lastpos = 0; for (pii p:rr[y]) { int pos = lower_bound(rr[x].begin() + lastpos,rr[x].end(),p) - rr[x].begin(); lastpos = pos; if (pos) { ret += query(root[ pre[x]+pos ],1,n,p.S,n); } } return ret; } int query2(int x,int y) { //how many interval within y //size of y is smaller //cout << "X = " << x << " , y = " <<y <<endl; int ret=0; int lastpos=0; for (pii p:rr[y]) { int pos = lower_bound(rr[x].begin()+lastpos,rr[x].end(),p) - rr[x].begin(); //cout << " p = " << p.F <<" , " << p.S << " , pos = " << pos << endl; lastpos = pos; if (pos != 0) { //cout << "pos = " << pos << " , pre[x] = " << pre[x] << " , pre[x+1] = " << pre[x+1] <<endl; ret += (query(root[ pre[x+1] ],1,n,p.F,p.S) - query(root[ pre[x]+pos ],1,n,p.F,p.S)); } else { ret += query(root[ pre[x+1] ],1,n,p.F,p.S); } } return ret; } int query3(int x,int y) { int ptrl = 0; int ptrr = -1; int ret=0; for (int i=0;SZ(rr[x])>i;i++) { while (ptrl != SZ(rr[y]) && rr[y][ptrl].F < rr[x][i].F) ++ptrl; while (ptrr != -1 && rr[y][ptrr].F > rr[x][i].S) --ptrr; while (ptrr != SZ(rr[y])-1 && rr[y][ptrr+1].F <= rr[x][i].S) ++ptrr; ret += (ptrr - ptrl + 1); } return ret; } const int C = 706; int main () { int r,q; scanf("%d%d%d",&n,&r,&q); for (int i=1;n>=i;i++) { if (i == 1) { int x; scanf("%d",&x); //rr[x].push_back(i); r_pos[i] = x; continue; } int par,x; scanf("%d%d",&par,&x); G[par].push_back(i); //rr[x].push_back(i); r_pos[i] = x; } dfs(1); for (int i=1;n>=i;i++) { //cout << "i = " <<i << " , r_pos = " << r_pos[i] << " , p = " << p[i].F << " , " << p[i].S << endl; rr[ r_pos[i] ].push_back(p[i]); } for (int i=1;r>=i;i++) { sort(rr[i].begin(),rr[i].end()); pre[i+1] = pre[i] + SZ(rr[i]); //cout << "i = " <<i << " , pre = " << pre[i] <<endl; } int id=1; root[0] = Build(1,n); for (int i=1;r>=i;i++) { if (SZ(rr[i]) == 0) continue; root[id] = getNode(root[0]); modify(root[0],root[id],1,n,rr[i][0].S,1); ++id; //cout << "hi1" <<endl; for (int j=1;SZ(rr[i])>j;j++) { //cout << "jj = " << j << endl; root[id] = getNode(root[id-1]); modify(root[id-1],root[id],1,n,rr[i][j].S,1); ++id; } //cout << "i = " << i <<" , id = " << id << endl; } map<pii,int> mp; while (q--) { int r1,r2; scanf("%d%d",&r1,&r2); if (SZ(rr[r1]) == 0 ||SZ(rr[r2]) == 0) { printf("0\n"); fflush(stdout); continue; } if (mp.find(make_pair(r1,r2)) != mp.end()) { printf("%d\n",mp[ make_pair(r1,r2) ]); fflush(stdout); continue; } int ans = 0; if (SZ(rr[r1]) < C && SZ(rr[r2]) < C) { ans = query3(r1,r2); } else if (SZ(rr[r1]) >= SZ(rr[r2])) { ans = query1(r1,r2); } else { ans = query2(r2,r1); } mp[ make_pair(r1,r2) ] = ans; printf("%d\n",ans); fflush(stdout); } }
沒有留言:
張貼留言