2018年3月7日 星期三

(POI) XVIII OI Round III (finals) - day 0 (trial) Task Dynamite

https://szkopul.edu.pl/problemset/problem/Xg9hVYries5K7yMcyZYI4NLc/site/?key=statement

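A neat problem combining binary search on the answer with DP on a tree: for a fixed time T, a greedy leaf-to-root DP computes the minimum number of fuses needed so that every dynamite chamber lies within distance T of a lit chamber.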


#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
#include <vector>
using namespace std;

const int N = 300006;

// ---- Tree and input --------------------------------------------------------
std::vector<int> G[N];  // adjacency lists; vertices are numbered 1..n
int deg[N];             // degree of each vertex
int deg2[N];            // scratch copy of deg, consumed by the leaf peeling in can()
int d[N];               // d[i] != 0 iff chamber i holds dynamite
int n, m;               // number of chambers, number of fuses that may be lit

// ---- Per-vertex DP state used by can() ------------------------------------
// A vertex v is processed once all neighbours but one (its "parent") are done.
// Each processed vertex sends one signed value `x` to its neighbours:
//   x > 0  : some still-pending dynamite in its subtree is x edges from the parent;
//   x < 0  : a lit fuse in its subtree reaches the parent with -x edges of range left;
//   x == 0 : nothing pending (a fire reaching the parent with exactly 0 range left
//            is recorded separately in son_have_negative).
int dp_mx[N];                // max of received values: farthest pending dynamite below v
int dp_mn[N];                // min of received values: minus the best remaining fire range at v
int dp_val[N];               // resolved state of v (>0 pending dynamite, <0 fire, 0 nothing)
bool son_have_negative[N];   // some neighbour sent a fire state, i.e. a fire reaches v

/**
 * Decide whether all dynamite can explode within t_max time units using at
 * most m lit fuses.
 *
 * Greedy leaf-to-root DP over a leaf-peeling order; the last vertex processed
 * acts as the root. A fuse is lit at a vertex only when pending dynamite lies
 * exactly t_max edges below it, i.e. when postponing would be too late.
 *
 * @param t_max  candidate answer (time limit), expected to be >= 0
 * @return true iff the greedy needs at most m fuses
 */
bool can(int t_max)
{
    if (t_max == 0)
    {
        // At time 0 only the chambers that are lit directly have exploded.
        int dynamite = 0;
        for (int i = 1; i <= n; i++)
        {
            if (d[i] != 0) dynamite++;
        }
        return dynamite <= m;
    }

    // Only indices 1..n are ever touched, so there is no need to clear all N entries.
    std::fill(dp_mx, dp_mx + n + 1, 0);
    std::fill(dp_mn, dp_mn + n + 1, 0);
    std::fill(son_have_negative, son_have_negative + n + 1, false);

    std::queue<int> que;
    for (int i = 1; i <= n; i++)
    {
        deg2[i] = deg[i];
        // "<= 1" rather than "== 1" so that a single-vertex tree is processed too.
        if (deg2[i] <= 1) que.push(i);
    }

    int fires = 0;
    int processed = 0;
    while (!que.empty())
    {
        const int t = que.front();
        que.pop();
        processed++;

        // A fire wins when its remaining range reaches the farthest pending
        // dynamite. A tie also counts as covered, and the fire state must then
        // be kept so that its range still propagates upwards.
        if (dp_mn[t] < 0 && -dp_mn[t] >= dp_mx[t])
            dp_val[t] = dp_mn[t];
        else
            dp_val[t] = dp_mx[t];

        // t's own dynamite is unreached iff nothing is pending and no fire reaches t.
        const bool own_pending = dp_val[t] == 0 && d[t] != 0 && !son_have_negative[t];

        if (processed == n)
        {
            // Root: any dynamite still pending (including its own) needs one more fuse.
            if (dp_val[t] > 0 || own_pending) fires++;
            break;
        }

        if (dp_val[t] == t_max)
        {
            // Pending dynamite is exactly t_max away: light a fuse here.
            fires++;
            dp_val[t] = -t_max;
        }

        // Shift the state by one edge for the parent.
        int up;
        if (dp_val[t] != 0)
            up = dp_val[t] + 1;
        else
            up = own_pending ? 1 : 0;

        for (int u : G[t])
        {
            if (--deg2[u] == 1) que.push(u);
            dp_mx[u] = std::max(dp_mx[u], up);
            dp_mn[u] = std::min(dp_mn[u], up);
            if (dp_val[t] < 0) son_have_negative[u] = true;
        }
    }
    return fires <= m;
}

int main ()
{
    // Header: number of chambers and number of fuses.
    scanf("%d %d", &n, &m);

    // Dynamite flags. If every dynamite chamber can be lit directly,
    // everything explodes at time 0.
    int dynamite_total = 0;
    for (int v = 1; v <= n; ++v)
    {
        scanf("%d", &d[v]);
        dynamite_total += d[v];
    }
    if (dynamite_total <= m)
    {
        puts("0");
        return 0;
    }

    // The n - 1 corridors of the tree.
    for (int e = 1; e < n; ++e)
    {
        int a, b;
        scanf("%d %d", &a, &b);
        G[a].push_back(b);
        G[b].push_back(a);
        ++deg[a];
        ++deg[b];
    }

    // Binary search for the smallest t in (lo, hi] with can(t) true;
    // this relies on can() being monotone in t.
    int lo = 0;
    int hi = n + 1;
    while (hi - lo > 1)
    {
        const int mid = lo + (hi - lo) / 2;
        (can(mid) ? hi : lo) = mid;
    }
    printf("%d\n", hi);
    return 0;
}

沒有留言:

張貼留言