Description

给定一棵有 \(n\) 个节点的无根树和 \(m\) 个操作, 操作有 2 类:

  1. 将节点 \(a\) 到节点 \(b\) 路径上所有点都染成颜色 \(c\);
  2. 询问节点 \(a\) 到节点 \(b\) 路径上的颜色段数量 (连续相同颜色被认为是同一段), 如 112221 由 3 段组成:11,222,1.

请你写一个程序依次完成这 \(m\) 个操作.

Input

第一行包含 \(2\) 个整数 \(n\) 和 \(m\), 分别表示节点数和操作数;

第二行包含 \(n\) 个正整数表示 \(n\) 个节点的初始颜色;

接下来 \(n-1\) 行每行包含两个整数 \(x\) 和 \(y\), 表示 \(x\) 与 \(y\) 之间有一条无向边.

下面 \(m\) 行每行描述一个操作:

  • C a b c 表示这是一个染色操作, 把节点 \(a\) 到节点 \(b\) 路径上所有点 (包括 \(a\) 和 \(b\)) 都染成颜色 \(c\);
  • Q a b 表示这是一个询问操作, 询问节点 \(a\) 到节点 \(b\) (包括 \(a\) 和 \(b\)) 路径上的颜色段数量.

Output

对于每个询问操作, 输出一行答案.

Sample Input

6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5

Sample Output

3
1
2

Data Range

对于所有数据, 保证:\(n \leq 10^5, m \leq 10^5\), 所有的颜色 \(c\) 为整数且在 \([0, 10^9]\) 之间.

Explanation

可以用树链剖分来维护这棵树, 在树链剖分上套一棵线段树.

线段树维护区间中颜色段的数量, 以及区间左端和右端的颜色.

最后的总时间复杂度为 \(O(n \log^2 n)\), 可能需要倍增 LCA.

Source Code


#include <iostream>
#include <cstdlib>
#include <cstdio>

using namespace std;
typedef long long lli;
// Capacity of every per-vertex array; the segment-tree node pool is twice this.
const int maxn = 100100;
// Capacity of the directed-edge pool (each undirected edge takes two slots).
const int maxm = 400100;
// Number of binary-lifting levels that are filled in and scanned (0 .. maxlog-1).
const int maxlog = 17;

// Summary of a colour sequence: the colours at its two ends and the number of
// maximal runs of equal colour it contains.
struct interval {
    int lc, rc; // Colour at the left end / at the right end
    int cols;   // Number of maximal same-colour runs

    // Turns this summary into a single run of colour `col`.
    void set_colour(int col) {
        cols = 1;
        lc = col;
        rc = col;
    }

    // A single run of colour 0.
    interval(void) : lc(0), rc(0), cols(1) {}

    // A single run of colour `col`.
    interval(int col) : lc(col), rc(col), cols(1) {}

    // A summary with explicit end colours `l`, `r` and run count `col`.
    interval(int l, int r, int col) : lc(l), rc(r), cols(col) {}
};

// Concatenates `a` (on the left) with `b` (on the right). When the colours
// touching at the seam are equal, the two runs meeting there count as one.
interval join(interval a, interval b) {
    const int fused = (a.rc == b.lc) ? 1 : 0;
    return interval(a.lc, b.rc, a.cols + b.cols - fused);
}

// Segment tree over positions 1..n. It supports assigning one colour to a
// whole range and asking how many maximal same-colour runs a range contains.
//
// Nodes are taken from the fixed array `npool`; nothing is heap-allocated.
// A node's `lazy` field holds a colour that still has to be pushed down to its
// children, or -1 when nothing is pending, so colours must be non-negative.
class SegmentTree
{
public:
    struct node
    {
        node *lc, *rc;          // Children; both NULL on a leaf
        int lb, mb, rb, lazy;   // Covers [lb, rb], split after mb; pending colour or -1
        interval val;           // Run summary of [lb, rb]
    } *root, npool[maxn<<1];
    int n, ncnt;                // Size given to the last build(); pool slots handed out

    // Hands out the next pool slot (slot 0 is never used) in a cleared state.
    node* make_node(void)
    {
        node *p = &npool[++ncnt];
        p->lc = p->rc = NULL;
        p->lb = p->mb = p->rb = 0;
        p->lazy = -1;
        return p;
    }

    // Pushes a pending assignment on `p` down to both children.
    // Does nothing when no assignment is pending or when `p` is a leaf.
    void dispatch_lazy(node *p)
    {
        if (p->lazy < 0 || p->lb == p->rb)
            return ;
        p->lc->lazy = p->rc->lazy = p->lazy;
        p->lc->val.set_colour(p->lazy);
        p->rc->val.set_colour(p->lazy);
        p->lazy = -1;
    }

    // Assigns colour `col` to [l, r], which must lie inside p's range.
    void change(node *p, int l, int r, int col)
    {
        if (p->lb == l && p->rb == r) {
            // Exact cover: record the colour here and postpone the children.
            p->lazy = col;
            p->val.set_colour(col);
            return ;
        }
        dispatch_lazy(p);
        if (r <= p->mb) {
            change(p->lc, l, r, col);
        } else if (l > p->mb) {
            change(p->rc, l, r, col);
        } else {
            change(p->lc, l, p->mb, col);
            change(p->rc, p->mb + 1, r, col);
        }
        p->val = join(p->lc->val, p->rc->val);
    }

    // Assigns colour `col` to every position in [l, r].
    void change(int l, int r, int col)
    {
        return this->change(root, l, r, col);
    }

    // Returns the run summary of [l, r], which must lie inside p's range.
    interval query(node *p, int l, int r)
    {
        if (p->lb == l && p->rb == r) {
            return p->val;
        }
        dispatch_lazy(p);
        if (r <= p->mb) {
            return query(p->lc, l, r);
        } else if (l > p->mb) {
            return query(p->rc, l, r);
        } else {
            return join(query(p->lc, l, p->mb),
                query(p->rc, p->mb + 1, r));
        }
    }

    // Returns the run summary of [l, r].
    interval query(int l, int r)
    {
        return this->query(root, l, r);
    }

    // Returns the colour currently stored at position `pos`.
    int query(int pos)
    {
        node *p = root;
        while (p->lb < p->rb) {
            dispatch_lazy(p);
            if (pos <= p->mb)
                p = p->lc;
            else
                p = p->rc;
        }
        return p->val.lc;
    }

    // Recursively builds the subtree covering [l, r]; leaf i takes arr[i].
    node* build_tree(int l, int r, int arr[])
    {
        node *p = make_node();
        int mid = (l + r) >> 1;
        p->lb = l; p->rb = r; p->mb = mid;
        if (p->lb == p->rb) {
            p->val = interval(arr[mid]);
        } else {
            p->lc = build_tree(l, mid, arr);
            p->rc = build_tree(mid + 1, r, arr);
            p->val = join(p->lc->val, p->rc->val);
        }
        return p;
    }

    // (Re)builds the tree over positions 1..n from arr[1..n].
    // The pool counter is reset first, so a rebuild reuses the same slots
    // instead of running past the end of npool. For n < 1 the tree is left
    // empty (root == NULL) and no other member may be called on it.
    void build(int n, int arr[])
    {
        this->n = n;
        ncnt = 0;
        if (n < 1) {
            // build_tree(1, 0) would never reach a leaf.
            root = NULL;
            return ;
        }
        root = build_tree(1, n, arr);
    }
} st;

class TreeChainPartition
{
public:
    struct edge
    {
        int u, v;
        edge *next;
    };
    int n, root, ecnt, dcnt;
    int alt_arr[maxn];
    edge *edges[maxn], epool[maxm];
    void add_edge(int u, int v)
    {
        edge *p = &epool[++ecnt],
             *q = &epool[++ecnt];
        p->u = u; p->v = v;
        q->u = v; q->v = u;
        p->next = edges[u]; edges[u] = p;
        q->next = edges[v]; edges[v] = q;
        return ;
    }
    int size[maxn], par[maxn], depth[maxn];
    int maxch[maxn], ctop[maxn], dfn[maxn];
    int jump[maxn][maxlog+1]; // Reserved for LCA
    void dfs1(int p)
    {
        size[p] = 1;
        for (int i = 1; i < maxlog; i++) {
            if (depth[p] < (1<<i))
                break;
            jump[p][i] = jump[jump[p][i-1]][i-1];
        }
        for (edge *ep = edges[p]; ep; ep = ep->next)
            if (ep->v != par[p]) {
                par[ep->v] = p;
                depth[ep->v] = depth[p] + 1;
                jump[ep->v][0] = p;
                dfs1(ep->v);
                size[p] += size[ep->v];
                if (size[ep->v] > size[maxch[p]])
                    maxch[p] = ep->v;
            }
        return ;
    }
    void dfs2(int p, int chaintop)
    {
        dfn[p] = ++dcnt;
        ctop[p] = chaintop;
        if (maxch[p])
            dfs2(maxch[p], chaintop);
        for (edge *ep = edges[p]; ep; ep = ep->next)
            if (depth[ep->v] == depth[p] + 1 && ep->v != maxch[p])
                dfs2(ep->v, ep->v);
        return ;
    }
    int lca(int x, int y)
    {
        if (depth[x] < depth[y])
            swap(x, y);
        // Ensured that x is deeper than y
        int dist = depth[x] - depth[y];
        // Letting x reach the depth par y
        for (int i = 0; i < maxlog; i++)
            if (dist & (1<<i))
                x = jump[x][i];
        // Syncing ancestors
        for (int i = maxlog - 1; i >= 0; i--)
            if (jump[x][i] != jump[y][i])
                x = jump[x][i],
                y = jump[y][i];
        if (x == y)
            return x;
        return jump[x][0];
    }
    void __change(int x, int y, int colour)
    {
        while (ctop[x] != ctop[y]) {
            st.change(dfn[ctop[x]], dfn[x], colour);
            x = jump[ctop[x]][0];
        }
        st.change(dfn[y], dfn[x], colour);
        return ;
    }
    int __query(int x, int y)
    {
        int res = 0;
        while (ctop[x] != ctop[y]) {
            int tmp = st.query(dfn[ctop[x]], dfn[x]).cols;
            res += tmp;
            if (st.query(dfn[jump[ctop[x]][0]]) == st.query(dfn[ctop[x]]))
                res -= 1;
            x = jump[ctop[x]][0];
        }
        int tmp = st.query(dfn[y], dfn[x]).cols;
        res += tmp;
        return res;
    }
    void change(int x, int y, int colour)
    {
        int z = lca(x, y);
        __change(x, z, colour);
        __change(y, z, colour);
        return ;
    }
    int query(int x, int y)
    {
        int z = lca(x, y);
        int res = __query(x, z)
            + __query(y, z) - 1;
        return res;
    }
    void init(int n, int arr[])
    {
        this->n = n;
        dcnt = 0;
        this->root = 1;
        // Generating DFS sequences
        depth[root] = 1;
        dfs1(root);
        dfs2(root, root);
        // Building segment tree, with minor modifications
        for (int i = 1; i <= n; i++)
            alt_arr[dfn[i]] = arr[i];
        st.build(n, alt_arr);
        return ;
    }
} graph;

int n, m;
int arr[maxn];
char str[64];

int main(int argc, char** argv)
{
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++)
        scanf("%d", &arr[i]);
    for (int i = 1, a, b; i <= n - 1; i++) {
        scanf("%d%d", &a, &b);
        graph.add_edge(a, b);
    }
    // Building graph with integrated functions
    graph.init(n, arr);
    // Answering queries
    for (int idx = 1; idx <= m; idx++) {
        scanf("%s", str);
        int a, b, c;
        if (str[0] == 'C') {
            scanf("%d%d%d", &a, &b, &c);
            graph.change(a, b, c);
        } else if (str[0] == 'Q') {
            scanf("%d%d", &a, &b);
            int res = graph.query(a, b);
            printf("%d\n", res);
        }
    }
    // Finished
    return 0;
}