Description
Farmer John is arranging his \(n\) cows in a line to take a photo \((1 \leq n \leq 100000)\). The height of the \(i\)-th cow in the sequence is \(h_i\), and the heights of all cows are distinct. As with all photographs of his cows, Farmer John wants this one to come out looking as nice as possible. He decides that cow \(i\) looks “unbalanced” if \(L_i\) and \(R_i\) differ by more than a factor of \(2\), where \(L_i\) and \(R_i\) are the numbers of cows taller than cow \(i\) on her left and right, respectively. That is, cow \(i\) is unbalanced if the larger of \(L_i\) and \(R_i\) is strictly more than twice the smaller of these two numbers. Farmer John is hoping that not too many of his cows are unbalanced. Please help Farmer John compute the total number of unbalanced cows.
农夫约翰正在安排他的 \(n\) 头牛拍照片, 牛从 \(1\) 到 \(n\) 编号, 每头牛有一个身高, 排成一行 \((h_1, h_2, \ldots, h_n)\). 记牛 \(i\) 左边比它高的牛的数量为 \(L_i\), 右边比它高的牛的数量为 \(R_i\). 如果牛 \(i\) 满足 \(\max(L_i, R_i) > 2 \times \min(L_i, R_i)\), 则称牛 \(i\) 是不平衡的. 现在农夫约翰需要你告诉他有多少头牛不平衡.
Input
The first line of input contains \(n\). The next \(n\) lines contain \(h_1, \ldots, h_n\), each a nonnegative integer at most \(1{,}000{,}000{,}000\). 输入第一行为 \(n\) \((n \leq 10^5)\), 接下来 \(n\) 行每行一个数, 第 \(i\) 个数表示第 \(i\) 头牛的身高, 不超过 \(10^9\).
Output
Please output a count of the number of cows that are unbalanced.
输出有多少头牛是不平衡的.
Sample Input
7
34
6
23
0
5
99
2
Sample Output
3
Explanation
直接维护一个 Splay 树就可以了 (虽然线段树+离散化也行): 从左到右依次插入身高, 每次插入后查询树中比当前身高大的数的个数, 即得 \(L_i\); 再从右到左做一遍, 得到 \(R_i\).
注意数值有可能重复 (尽管题面保证身高互不相同), 因此每个节点需要额外记录该值出现的次数 (代码中的 vsz). 子树大小 size 要按出现次数累加, 不能简单地按节点个数计算.
很奇怪为什么 Splay 没有被卡常~
Source Code
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <iostream>
// Iterates _var over the closed range [_begin, _end]. Both bounds are
// parenthesised so expression arguments (e.g. `c ? a : b`) keep their
// meaning; _end is re-evaluated on every iteration.
#define rep(_var,_begin,_end) for(int _var=(_begin);_var<=(_end);_var++)
// Repeats the following statement once for each value in [_begin, _end];
// the loop variable is named `_`.
#define range(_begin,_end) rep(_,_begin,_end)
// In-place x = min(x, y) / x = max(x, y). The whole expansion is one
// parenthesised expression with no baked-in semicolon, so a use such as
// `if (c) minimize(a, b); else ...` compiles. x is evaluated twice.
#define minimize(x_,y_) ((x_) = min((x_), (y_)))
#define maximize(x_,y_) ((x_) = max((x_), (y_)))
using namespace std;
using lli = long long;
// Capacity of the fixed-size arrays (per-cow arrays and splay-tree node pools).
const int maxn = 200100;
// 0x7fffffff (INT_MAX for a 32-bit int); used as the +/- sentinel keys of the
// splay trees.
const int infinit = 0x7fffffff;
/// Array-backed splay tree over int keys.
///
/// Node 0 is the null node: a child/parent link of 0 means "none", and the
/// code relies on row 0 staying all zeros (so size(0) == 0). That holds for
/// the namespace-scope objects declared after the class, which are
/// zero-initialised.
/// Real nodes are numbered 1..ncnt in allocation order. At most maxn - 1
/// nodes fit, and no bounds check is performed.
///
/// Each distinct key is stored in one node. Inserting an existing key
/// increments that node's multiplicity `vsz`. `size` is the subtree total of
/// multiplicities, so ranks and counts respect duplicates.
class SplayTree
{
public:
    // One row per node: [0] left child, [1] right child, [2] parent,
    // [3] subtree size (sum of vsz), [4] key, [5] multiplicity of the key.
    int arr_i[maxn][6];
#define lc(_x) arr_i[_x][0]
#define rc(_x) arr_i[_x][1]
#define ch(_x,_y) arr_i[_x][_y]
#define par(_x) arr_i[_x][2]
#define size(_x) arr_i[_x][3]
#define val(_x) arr_i[_x][4]
#define vsz(_x) arr_i[_x][5]
    int root, ncnt;  // current root (0 when empty), nodes allocated so far

    /// Allocates a detached node with key `v` and multiplicity 1.
    /// @return index of the new node.
    int make_node(int v)
    {
        int p = ++ncnt;
        lc(p) = rc(p) = par(p) = 0;
        size(p) = vsz(p) = 1;
        val(p) = v;
        return p;
    }

    /// Recomputes size(p) from its children and its own multiplicity.
    /// Despite the name, no lazy tag is involved; this is a plain pull-up.
    void update_lazy(int p)
    {
        size(p) = size(lc(p)) + vsz(p) + size(rc(p));
    }

    /// Rotates p above its parent q while keeping in-order key order.
    /// Fixes all parent/child links, then recomputes the size of q (now
    /// p's child) followed by p. Precondition: p has a parent.
    void rotate(int p)
    {
        int q = par(p), g = par(q);
        int x = p == rc(q), y = q == rc(g);
        // p's inner subtree moves across to q.
        ch(q, x) = ch(p, !x);
        if (ch(q, x)) par(ch(q, x)) = q;
        ch(p, !x) = q;
        par(q) = p;
        // p takes q's place under g (g == 0 means q was the root).
        if (g) ch(g, y) = p;
        par(p) = g;
        update_lazy(q);
        update_lazy(p);
    }

    /// Splays p upward until its parent is t; t == 0 makes p the new root.
    /// When p and its parent are children on the same side, the parent is
    /// rotated first (zig-zig). Otherwise p is rotated twice (zig-zag).
    void splay(int p, int t)
    {
        for (int q = 0; (q = par(p)) && q != t; rotate(p))
            if (par(q) && par(q) != t)
                rotate((q == rc(par(q))) == (p == rc(q)) ? q : p);
        if (t == 0) root = p;
    }

    /// Returns the node holding the x-th smallest key (1-based, counting
    /// multiplicities and both sentinels once init() has run), or 0 if x is
    /// out of range or the tree is empty. Does not splay.
    int find(int x)
    {
        int p = root;
        while (p != 0) {
            if (x <= size(lc(p))) {
                p = lc(p);  // target lies in the left subtree
                continue;
            }
            x -= size(lc(p));
            if (x <= vsz(p))
                return p;  // target is one of the copies stored at p
            x -= vsz(p);
            p = rc(p);  // target lies in the right subtree
        }
        return 0;
    }

    /// Inserts key `v` and splays its node to the root.
    /// If `v` is already present, only that node's multiplicity is bumped.
    /// A new node is allocated only for a new key.
    /// @return index of the node holding `v`.
    int insert(int v)
    {
        int q = root, parent = 0, dir = 0;
        // Descend until we reach v's node or fall off the tree.
        while (q != 0 && v != val(q)) {
            parent = q;
            dir = v > val(q);  // 0 = go left, 1 = go right
            q = ch(q, dir);
        }
        if (q != 0) {
            vsz(q) += 1;
            // Keep size(q) exact even when q is already the root, in which
            // case splay() below performs no rotation that would refresh it.
            update_lazy(q);
        } else {
            q = make_node(v);
            if (parent != 0) {
                ch(parent, dir) = q;
                par(q) = parent;
            }
        }
        // Splaying recomputes the sizes of every former ancestor of q.
        splay(q, 0);
        return q;
    }

    /// Splays p to the root and returns how many stored keys are strictly
    /// greater than val(p), counted with multiplicity. The +infinit sentinel
    /// always sits in the right subtree, hence the -1. Assumes init() was
    /// called and val(p) < infinit.
    int query_greater(int p)
    {
        splay(p, 0);
        return size(rc(p)) - 1;
    }

    /// Seeds the tree with the sentinel keys -infinit and +infinit, so every
    /// inserted key has neighbours on both sides.
    void init(void)
    {
        int lp = make_node(-infinit),
            rp = make_node(infinit);
        root = rp;
        lc(root) = lp, par(lp) = root;
        update_lazy(root);
    }

    /// Dumps every allocated node to stdout.
    void debug(void)
    {
        printf("=> Splay tree contains %d nodes (root %d);\n", ncnt, root);
        for (int i = 1; i <= ncnt; i++)
            printf(" #%d: lc %d rc %d par %d size %d val %d vsz %d\n", i, lc(i),
                   rc(i), par(i), size(i), val(i), vsz(i));
    }
} stl, str;  // two independent trees; main scans left-to-right with stl, right-to-left with str
int n, arr[maxn], L[maxn], R[maxn];
int main(int argc, char** argv)
{
scanf("%d", &n);
for (int i = 1; i <= n; i++)
scanf("%d", &arr[i]);
// After the input of data, construct a dynamic splay tree.
stl.init();
str.init();
for (int i = 1; i <= n; i++) {
int p = stl.insert(arr[i]);
L[i] = stl.query_greater(p);
}
for (int i = n; i >= 1; i--) {
int p = str.insert(arr[i]);
R[i] = str.query_greater(p);
}
// Gathered information.
int res = 0;
for (int i = 1; i <= n; i++)
if (max(L[i], R[i]) > 2 * min(L[i], R[i]))
res += 1;
printf("%d\n", res);
return 0;
}