Cydiater

「BZOJ 4700」适者

敌方有$n$台人形兵器,每台的攻击力为$a_i$,护甲值为$d_i$。我方只有一台人形兵器,攻击力为$ATK$。战斗看作回合制,每回合进程如下:

  • 我方选择对方某台人形兵器并攻击,令其护甲值减少$ATK$,若$ATK \leq 0$则被破坏。
  • 敌方每台未被破坏的人形兵器攻击我方基地造成$a_i$点损失。

但是,在第一回合开始之前,某两台敌方的人形兵器被干掉了(秒杀)。问最好情况下,我方基地会受到多少点损失。

好久没管博客了,来更篇题解!

首先看到这道题,我们考虑如果没有秒杀的情况如何计算答案呢。很显然我们发现当攻击一个怪的时候一定是一直打直到打死最优。那么对于不同的攻打顺序对应了不同的答案,对于防御,我们直接对去除去我们的攻击上取整,代表打死他所需要的次数。我们考虑如何得到最优的顺序。

方便起见,我们设第$i$个人的攻击是$a_i$,防御是$b_i$,前$i$个人的防御和是$p_i$,后$i$个人的攻击和是$s_i$我们考虑两个位置上的元素$a_i,b_i$以及$a_j,b_j$,什么时候$i$应该在$j$前面呢。很显然如果满足

$$a_i\times(b_i - 1) + a_j\times(b_i + b_j - 1) < a_j\times(b_j - 1) + a_i\times(b_i + b_j - 1)$$

的时候 $i$应该在$j$前面,我们化简一下得到$a_i \times b_j < a_j \times b_i$

接着考虑删掉两个数的情况,容易想到可以枚举一个数,然后快速确定第二数的最优位置。我们得到当删掉$i<j$时,答案要扣去的影响:

$$delta = b_i\times s_i - a_i + b_j \times s_j - a_j + a_i \times (p_i - b_i) + a_j \times (p_j - b_i - b_j)$$

然后我们化简一下,发现其是一个斜率的形式,然后 CDQ 分治维护一下即可。

#include <bits/stdc++.h>

using namespace std;

#define ll             long long
#define db            double
#define up(i,j,n)        for (int i = j; i <= n; i++)
#define down(i,j,n)    for (int i = j; i >= n; i--)
#define cadd(a,b)        a = add (a, b)
#define cpop(a,b)        a = pop (a, b)
#define cmul(a,b)        a = mul (a, b)
#define pr            pair<ll, ll>
#define fi            first
#define se            second
#define SZ(x)        (int)x.size()
#define bin(i)        (1 << (i))
#define Auto(i,node)    for (int i = LINK[node]; i; i = e[i].next)

template<typename T> inline bool cmax(T & x, T y){return y > x ? x = y, true : false;}
template<typename T> inline bool cmin(T & x, T y){return y < x ? x = y, true : false;}

const int MAXN = 3e5 + 5;
const int oo = 0x3f3f3f3f;

inline int read(){
    char ch = getchar(); int x = 0, f = 1;
    while (ch > '9' || ch < '0') {if (ch == '-') f = -1; ch = getchar();}
    while (ch >= '0' && ch <= '9') {x = x * 10 + ch - '0'; ch = getchar();}
    return x * f;
}

inline pr operator + (const pr & a, const pr & b) {return make_pair(a.fi + b.fi, a.se + b.se);}
inline pr operator - (const pr & a, const pr & b) {return make_pair(a.fi - b.fi, a.se - b.se);}

inline ll crs(const pr & a, const pr & b) {return a.fi * b.se - a.se * b.fi;}

int N, T;

ll s[MAXN], p[MAXN], a[MAXN], b[MAXN], all, ans;
ll c[MAXN], d[MAXN];

struct Node {
    ll k, x, y;
    int o;
    inline bool operator < (const Node & v) const {return x < v.x;}
    Node () {}
    Node (int i) : o(i) {    
        x = b[i];
        y = b[i] * s[i] - a[i] + a[i] * p[i] - a[i] * b[i];
        k = -a[i];
    }
    pr xy() {return make_pair(x, y);}
}o[MAXN], cl[MAXN], cr[MAXN], tmp[MAXN];

int top, ord[MAXN];
pr q[MAXN];

inline bool cmp(int x, int y){return a[x] * b[y] > a[y] * b[x];}

inline ll calc(pr x, Node & y) {
    return y.k * x.fi + y.y + x.se;
}

void Work(int le, int ri){
    int mi = (le + ri) >> 1;
    if (le == ri) return;
    int dl = le, dr = mi + 1;
    up (i, le, ri) {
        if (o[i].o <= mi) cl[dl++] = o[i];
        else cr[dr++] = o[i];
    }
    up (i, le, mi) o[i] = cl[i];
    up (i, mi + 1, ri) o[i] = cr[i];
    Work(mi + 1, ri);
    top = 0;
    up (i, le, mi) {
        while (top >= 2 && crs(q[top - 1] - o[i].xy(), q[top] - o[i].xy()) >= 0) top--;
        q[++top] = o[i].xy();
    }
    int cur = 1;
    up (i, mi + 1, ri) {
        while (cur + 1 <= top && calc(q[cur + 1], o[i]) > calc(q[cur], o[i])) cur++;
        cmax(ans, calc(q[cur], o[i]));
    }
    Work(le, mi);
    dl = le; dr = mi + 1;
    up (i, le, ri) {
        if (dl == mi + 1) tmp[i] = o[dr++];
        else if (dr == ri + 1) tmp[i] = o[dl++];
        else tmp[i] = o[dl].k < o[dr].k ? o[dl++] : o[dr++];
    }
    up (i, le, ri) o[i] = tmp[i];
}

int main(){
    scanf("%d%d", &N, &T);
    up (i, 1, N) {
        a[i] = read();
        b[i] = (read() + T - 1) / T;
        ord[i] = i;
    } 
    sort(ord + 1, ord + N + 1, cmp);
    up (i, 1, N) {
        c[i] = a[ord[i]];
        d[i] = b[ord[i]];
    }
    up (i, 1, N) {
        a[i] = c[i];
        b[i] = d[i];
    }
    up (i, 1, N) p[i] = p[i - 1] + b[i];
    down (i, N, 1) s[i] = s[i + 1] + a[i];
    up (i, 1, N) all += a[i] * (p[i] - 1);
    ans = 0;
    up (i, 1, N) o[i] = Node(i);
    sort(o + 1, o + N + 1);
    Work(1, N);
    printf("%lld\n", all - ans);
    return 0;
}