2962: 序列操作

Description

有一个长度为 \(n\) 的序列, 有三个操作:

  1. I a b c 表示将 \([a, b]\) 这一段区间的元素集体增加 \(c\);
  2. R a b 表示将 \([a, b]\) 区间内所有元素变成相反数;
  3. Q a b c 表示询问 \([a, b]\) 这一段区间中选择 \(c\) 个数相乘的所有方案的和 \(\bmod 19940417\) 的值.

Input

第一行两个数 \(n, q\) 表示序列长度和操作个数.

第二行 \(n\) 个非负整数, 表示序列.

接下来 \(q\) 行每行输入一个操作, 意义如题目描述.

Output

对于每个询问, 输出选出 \(c\) 个数相乘的所有方案的和 \(\bmod 19940417\) 的值.

Sample Input

5 5
1 2 3 4 5
I 2 3 1
Q 2 4 2
R 1 5
I 1 3 -1
Q 1 5 1

Sample Output

40
19940397

Data Range

对于 100% 的数据, 满足:\(n \leq 50000, q \leq 50000\), 初始序列的元素的绝对值 \(\leq 10^9\), 并且所有操作满足如下的性质:

  • I a b c 中保证 \([a, b]\) 是一个合法区间,\(|c| \leq 10^9\);
  • R a b 中保证 \([a, b]\) 是个合法的区间;
  • Q a b c 中保证 \([a, b]\) 是个合法的区间, 且 \(1 \leq c \leq min(b - a + 1, 20)\).

Explanation

一道双倍经验题...... 不过清橙上可以没有权限号随便交也是比较良心的......

瞄一眼区间操作就是线段树对吧...... 用 Splay 没有什么意义

这样搞来就可以比较明确地认为我们可以维护选择 \(0 - 20\) 个元素地乘积之和是多少, 而且 可以 \(O(c^2)\) 合并, 或者当 \(c\) 比较大 (这道题不是这样) 时用卷积 \(O(c \log c)\) 合并也不是不可以......

然后把区间数据用 interval 结构维护, 套上一个裸的线段树即可.

顺便吐槽一下...... 好像我把某个位置的 change 的 r 写成了 l 无限 WA...... 调了三天 QwQ

Source Code


#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <cstring>

#define rep(_var,_begin,_end) for(int _var=_begin;_var<=_end;_var++)
#define range(_begin,_end) rep(_,_begin,_end)
using namespace std;
typedef long long lli;
const int maxn = 100100, maxk = 21;
const lli modr = 19940417;

// Combinatorics
lli C[maxn][maxk];

struct interval {
    lli cnt[maxk];
    int size;
    interval(void) { size = 0;
        range(0, 20) cnt[_] = 0; }
    interval(int x) { size = 1;
        range(2, 20) cnt[_] = 0;
        cnt[0] = 1; cnt[1] = x; }
    void add(lli x) {
        lli tmp = x; for (int i = 20; i > 0; i--, tmp = x)
            for (int j = 1; j <= i; j++, tmp = tmp * x % modr)
                cnt[i] = (cnt[i] + cnt[i-j] * tmp % modr * C[size-i+j][j]) % modr;
        return ; }
    void inverse(void) {
        for (int i = 1; i <= 20; i += 2)
            cnt[i] = (modr - cnt[i]) % modr;
        return ; }
}; interval join(interval a, interval b) {
    interval c; c.size = a.size + b.size;
    rep(i, 0, 20) rep(j, 0, i)
        c.cnt[i] = (c.cnt[i] + a.cnt[j] * b.cnt[i-j]) % modr;
    return c;
}

class SegmentTree
{
public:
    struct node
    {
        node *lc, *rc;
        int lb, mb, rb;
        interval val;
        lli lazyadd;
        bool lazyinv;
    } *root, npool[maxn<<1];
    int n, ncnt;
    node* make_node(void)
    {
        node *p = &npool[++ncnt];
        p->lc = p->rc = NULL;
        p->lb = p->mb = p->rb = 0;
        p->val = 0;
        p->lazyadd = 0;
        p->lazyinv = false;
        return p;
    }
    void mark_add(node *p, lli v)
    {
        p->lazyadd = (p->lazyadd + v) % modr;
        p->val.add(v);
        return ;
    }
    void mark_inv(node *p)
    {
        p->lazyadd = (modr - p->lazyadd) % modr;
        p->lazyinv ^= 1;
        p->val.inverse();
        return ;
    }
    void dispatch_lazy(node *p)
    {
        if (p->lazyinv) {
            mark_inv(p->lc);
            mark_inv(p->rc);
            p->lazyinv = false;
        }
        if (p->lazyadd) {
            mark_add(p->lc, p->lazyadd);
            mark_add(p->rc, p->lazyadd);
            p->lazyadd = 0;
        }
        return ;
    }
    interval query(node *p, int l, int r)
    {
        if (p->lb == l && p->rb == r) {
            return p->val;
        }
        dispatch_lazy(p);
        if (r <= p->mb) {
            return query(p->lc, l, r);
        } else if (l > p->mb) {
            return query(p->rc, l, r);
        } else {
            return join(query(p->lc, l, p->mb),
                query(p->rc, p->mb + 1, r));
        }
        return interval();
    }
    lli query(int l, int r, int c)
    {
        if (l > r) swap(l, r);
        interval res = this->query(root, l, r);
        return res.cnt[c];
    }
    void change(node *p, int l, int r, lli val)
    {
        if (p->lb == l && p->rb == r) {
            mark_add(p, val);
            return ;
        }
        dispatch_lazy(p);
        if (r <= p->mb)
            change(p->lc, l, r, val);
        else if (l > p->mb)
            change(p->rc, l, r, val);
        else
            change(p->lc, l, p->mb, val),
            change(p->rc, p->mb + 1, r, val);
        p->val = join(p->lc->val, p->rc->val);
        return ;
    }
    void change(int l, int r, lli val)
    {
        val = (val % modr + modr) % modr;
        this->change(root, l, r, val);
        return ;
    }
    void inverse(node *p, int l, int r)
    {
        if (p->lb == l && p->rb == r) {
            mark_inv(p);
            return ;
        }
        dispatch_lazy(p);
        if (r <= p->mb)
            inverse(p->lc, l, r);
        else if (l > p->mb)
            inverse(p->rc, l, r);
        else
            inverse(p->lc, l, p->mb),
            inverse(p->rc, p->mb + 1, r);
        p->val = join(p->lc->val, p->rc->val);
        return ;
    }
    void inverse(int l, int r)
    {
        this->inverse(root, l, r);
        return ;
    }
    node* build_tree(int l, int r, lli arr[])
    {
        node *p = make_node();
        int mid = (l + r) >> 1;
        p->lb = l; p->mb = mid; p->rb = r;
        if (p->lb == p->rb) {
            lli val = (arr[mid] % modr + modr) % modr;
            p->val = interval(val);
        } else {
            p->lc = build_tree(l, mid, arr);
            p->rc = build_tree(mid + 1, r, arr);
            p->val = join(p->lc->val, p->rc->val);
        }
        return p;
    }
    void init(int n, lli arr[])
    {
        this->n = n;
        this->root = this->build_tree(1, n, arr);
        return ;
    }
} st;

int n, m;
lli arr[maxn];
char str[64];

int main(int argc, char** argv)
{
    // Preloading input
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++)
        scanf("%lld", &arr[i]);
    // Loading C(...) function with Yang Hui triangle
    C[0][0] = 0;
    for (int i = 0; i <= n; i++) {
        C[i][0] = 1;
        for (int j = 1; j <= i && j < maxk; j++)
            C[i][j] = (C[i-1][j] + C[i-1][j-1]) % modr;
    }
    // Initializing
    st.init(n, arr);
    // Can you answer these queries?
    for (int i = 1; i <= m; i++) {
        int a, b, c;
        scanf("%s", str);
        if (str[0] == 'I') {
            scanf("%d%d%d", &a, &b, &c);
            c = (c % modr + modr) % modr;
            st.change(a, b, c);
        } else if (str[0] == 'R') {
            scanf("%d%d", &a, &b);
            st.inverse(a, b);
        } else if (str[0] == 'Q') {
            scanf("%d%d%d", &a, &b, &c);
            lli res = st.query(a, b, c);
            printf("%lld\n", res);
        }
    }
    return 0;
}