Description

给出 \(1 - n\) 的一个排列, 统计该排列有多少个长度为奇数的连续子序列的中位数是 b. 中位数是指把所有元素从小到大排列后, 位于中间的数.

Input

第一行为两个正整数 \(n\)\(b\), 第二行为 \(1 - n\) 的排列.

Output

输出一个整数, 即中位数为 \(b\) 的连续子序列个数.

Sample Input

7 4
5 7 2 4 3 1 6

Sample Output

4

Sample Explanation

四个连续子序列分别是: {4}, {7, 2, 4}, {5, 7, 2, 4, 3} 和 {5, 7, 2, 4, 3, 1, 6}

Data Range

对于 100% 的数据, 保证:\(n \leq 100000\)

Solution

一道大水题~

首先这是一个两两互异的数的排列, 所以不可能出现多于两个 \(b\) (没看到这点可是把我害惨了...... 想了半天 “如果出给我一个全是 \(b\) 的数据怎么办”) . 所以我们可以把这个位置作为固定的位置, 然后考虑排列问题.

这样一来, 数的值就没有关系了, 要存的也就是它们与 \(b\) 的大小关系.

找到 b 在数列中的位置设为 point, 比 b 大的赋值为-1, 比 b 小的赋值为 1;

然后求出 sum [i, point] 的值出现了几次记为 lfre [sum [i, point]] ++; ans+=lfre [sum [i, point]] * rfre [-sum [i, point]] ;

由于 c++数组不能是负数, 所以稍微处理一下

可以想到的是, 我们记 \(\gt = 1, = = 0, \lt = -1\), 那么含 \(b\) 为中位数的一段区间, 其各数之和必为 \(0\), 比如:

\[\sum_{i = l}^{r} proc[i] = 0\]

然后维护一下前面有多少个以每一个数作为标记的数量, 反之亦然.

算法复杂度是一个显然的 \(O(n)\)

注意数组越界问题

Source Code


#include <iostream>
#include <cstdlib>
#include <cstdio>

using namespace std;
typedef long long lli;
const int maxn = 200100;

int n, b;
int loc; // The position where b is located
int arr[maxn], sum[maxn],
    lcnt_arr[maxn], rcnt_arr[maxn];

// To avoid array out-of-bound, we use this trick
#define lcnt(__x) lcnt_arr[(__x)+n]
#define rcnt(__x) rcnt_arr[(__x)+n]

int main(int argc, char** argv)
{
    scanf("%d%d", &n, &b);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &arr[i]);
        if (arr[i] > b) {
            arr[i] = 1;
        } else if (arr[i] == b) {
            arr[i] = 0;
            loc = i;
        } else {
            arr[i] = -1;
        }
    }
    // Done reading data, processing sum, and occurence counts
    lcnt(0) = rcnt(0) = 1;
    for (int i = loc - 1; i >= 1; i--) {
        sum[i] = sum[i + 1] + arr[i];
        lcnt(sum[i])++;
    }
    for (int i = loc + 1; i <= n; i++) {
        sum[i] = sum[i - 1] + arr[i];
        rcnt(sum[i])++;
    }
    // Processing final results, shouldn't be too much.
    int res = 0;
    for (int i = -n; i <= n; i++)
        res += lcnt(i) * rcnt(-i);
    printf("%d\n", res);
    return 0;
}