Description
给定一棵有 \(n\) 个节点的无根树和 \(m\) 个操作, 操作有 2 类:
- 将节点 \(a\) 到节点 \(b\) 路径上所有点都染成颜色 \(c\);
- 询问节点 \(a\) 到节点 \(b\) 路径上的颜色段数量 (连续相同颜色被认为是同一段), 例如 `112221` 由 3 段组成: `11`, `222`, `1`.
请你写一个程序依次完成这 \(m\) 个操作.
Input
第一行包含 \(2\) 个整数 \(n\) 和 \(m\), 分别表示节点数和操作数;
第二行包含 \(n\) 个正整数表示 \(n\) 个节点的初始颜色;
接下来 \(n-1\) 行每行包含两个整数 \(x\) 和 \(y\), 表示 \(x\) 和 \(y\) 之间有一条无向边.
下面 \(m\) 行每行描述一个操作:
- `C a b c`: 表示这是一个染色操作, 把节点 \(a\) 到节点 \(b\) 路径上所有点 (包括 \(a\) 和 \(b\)) 都染成颜色 \(c\);
- `Q a b`: 表示这是一个询问操作, 询问节点 \(a\) 到节点 \(b\) (包括 \(a\) 和 \(b\)) 路径上的颜色段数量.
Output
对于每个询问操作, 输出一行答案.
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
Data Range
对于所有数据, 保证:\(n \leq 10^5, m \leq 10^5\), 所有的颜色 \(c\) 为整数且在 \([0, 10^9]\) 之间.
Explanation
可以用树链剖分来维护这棵树, 在树链剖分上套一棵线段树.
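线段树的每个节点维护区间中的颜色段数, 以及区间左端和右端的颜色; 合并两个相邻区间时, 若左区间的右端颜色等于右区间的左端颜色, 则段数减 1. 同理, 在树上沿重链向上跳时, 若链顶与其父节点颜色相同, 也要把段数减 1.
每次操作经过 \(O(\log n)\) 条重链, 每条重链上做一次 \(O(\log n)\) 的线段树操作, 因此总时间复杂度为 \(O(n \log n + m \log^2 n)\); 路径两端的 LCA 可以用倍增求出.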
Source Code
#include <iostream>
#include <cstdlib>
#include <cstdio>
using namespace std;
typedef long long lli;
// Capacity limits shared by every structure below.
const int maxn = 100100;   // vertex-indexed arrays (vertices are 1-based)
const int maxm = 400100;   // edge-pool slots; each undirected edge consumes two
const int maxlog = 17;     // number of binary-lifting levels used by the LCA code
// Colour summary of a contiguous range of positions.
struct interval {
    int lc, rc; // colour of the leftmost / rightmost position
    int cols;   // number of maximal runs of equal colour in the range

    // Turn this into a range painted entirely with one colour (a single run).
    void set_colour(int col) {
        cols = 1;
        rc = col;
        lc = col;
    }
    interval() { set_colour(0); }
    interval(int col) { set_colour(col); }
    interval(int l, int r, int col) : lc(l), rc(r), cols(col) {}
};

// Summary of range a immediately followed by range b.
// If the colours meet at the seam, the two touching runs become one.
interval join(interval a, interval b) {
    const int seam = (a.rc == b.lc) ? 1 : 0;
    return interval(a.lc, b.rc, a.cols + b.cols - seam);
}
// Pointer-based segment tree over positions 1..n that supports
// "paint a range with one colour" and "summarise a range" (see interval).
// Nodes come from the static pool npool, so n must satisfy 2n < 2*maxn.
class SegmentTree
{
public:
    // A node covering positions [lb, rb]; leaves have lb == rb and no children.
    struct node
    {
        node *lc, *rc;
        int lb, mb, rb, lazy;   // bounds, split point (lb+rb)/2, pending paint colour (< 0 = none)
        interval val;           // colour summary of [lb, rb]
    } *root, npool[maxn<<1];
    int n, ncnt;                // positions covered; pool nodes handed out so far

    // Take the next node from the pool with no children and no pending tag.
    node* make_node(void)
    {
        node *p = &npool[++ncnt];
        p->lc = p->rc = NULL;
        p->lb = p->mb = p->rb = 0;
        p->lazy = -1;
        return p;
    }
    // Push p's pending paint tag (if any) down to both children.
    // Any negative lazy means "no tag", so real colours must be non-negative.
    void dispatch_lazy(node *p)
    {
        if (p->lazy < 0 || p->lb == p->rb)
            return ;
        p->lc->lazy = p->rc->lazy = p->lazy;
        p->lc->val.set_colour(p->lazy);
        p->rc->val.set_colour(p->lazy);
        p->lazy = -1;
    }
    // Paint positions [l, r] (which must lie inside p's range) with colour col.
    void change(node *p, int l, int r, int col)
    {
        if (p->lb == l && p->rb == r) {
            // Whole node covered: recolour it and defer the children.
            p->lazy = col;
            p->val.set_colour(col);
            return ;
        }
        dispatch_lazy(p);
        if (r <= p->mb) {
            change(p->lc, l, r, col);
        } else if (l > p->mb) {
            change(p->rc, l, r, col);
        } else {
            change(p->lc, l, p->mb, col);
            change(p->rc, p->mb + 1, r, col);
        }
        p->val = join(p->lc->val, p->rc->val);
    }
    // Paint positions [l, r] of the whole tree with colour col.
    void change(int l, int r, int col)
    {
        change(root, l, r, col);
    }
    // Summary of positions [l, r] (which must lie inside p's range).
    interval query(node *p, int l, int r)
    {
        if (p->lb == l && p->rb == r)
            return p->val;
        dispatch_lazy(p);
        if (r <= p->mb)
            return query(p->lc, l, r);
        if (l > p->mb)
            return query(p->rc, l, r);
        return join(query(p->lc, l, p->mb),
                    query(p->rc, p->mb + 1, r));
    }
    // Summary of positions [l, r] of the whole tree.
    interval query(int l, int r)
    {
        return query(root, l, r);
    }
    // Colour at the single position pos.
    int query(int pos)
    {
        node *p = root;
        while (p->lb < p->rb) {
            dispatch_lazy(p);
            p = (pos <= p->mb) ? p->lc : p->rc;
        }
        return p->val.lc;
    }
    // Build the subtree for positions [l, r] from arr[l..r].
    node* build_tree(int l, int r, int arr[])
    {
        node *p = make_node();
        int mid = (l + r) >> 1;
        p->lb = l; p->rb = r; p->mb = mid;
        if (p->lb == p->rb) {
            p->val = interval(arr[mid]);
        } else {
            p->lc = build_tree(l, mid, arr);
            p->rc = build_tree(mid + 1, r, arr);
            p->val = join(p->lc->val, p->rc->val);
        }
        return p;
    }
    // (Re)build the tree over arr[1..n]. The pool is recycled so repeated
    // builds cannot run past npool; n < 1 would otherwise recurse forever.
    void build(int n, int arr[])
    {
        ncnt = 0;
        this->n = n;
        root = (n >= 1) ? build_tree(1, n, arr) : NULL;
    }
} st;
class TreeChainPartition
{
public:
struct edge
{
int u, v;
edge *next;
};
int n, root, ecnt, dcnt;
int alt_arr[maxn];
edge *edges[maxn], epool[maxm];
void add_edge(int u, int v)
{
edge *p = &epool[++ecnt],
*q = &epool[++ecnt];
p->u = u; p->v = v;
q->u = v; q->v = u;
p->next = edges[u]; edges[u] = p;
q->next = edges[v]; edges[v] = q;
return ;
}
int size[maxn], par[maxn], depth[maxn];
int maxch[maxn], ctop[maxn], dfn[maxn];
int jump[maxn][maxlog+1]; // Reserved for LCA
void dfs1(int p)
{
size[p] = 1;
for (int i = 1; i < maxlog; i++) {
if (depth[p] < (1<<i))
break;
jump[p][i] = jump[jump[p][i-1]][i-1];
}
for (edge *ep = edges[p]; ep; ep = ep->next)
if (ep->v != par[p]) {
par[ep->v] = p;
depth[ep->v] = depth[p] + 1;
jump[ep->v][0] = p;
dfs1(ep->v);
size[p] += size[ep->v];
if (size[ep->v] > size[maxch[p]])
maxch[p] = ep->v;
}
return ;
}
void dfs2(int p, int chaintop)
{
dfn[p] = ++dcnt;
ctop[p] = chaintop;
if (maxch[p])
dfs2(maxch[p], chaintop);
for (edge *ep = edges[p]; ep; ep = ep->next)
if (depth[ep->v] == depth[p] + 1 && ep->v != maxch[p])
dfs2(ep->v, ep->v);
return ;
}
int lca(int x, int y)
{
if (depth[x] < depth[y])
swap(x, y);
// Ensured that x is deeper than y
int dist = depth[x] - depth[y];
// Letting x reach the depth par y
for (int i = 0; i < maxlog; i++)
if (dist & (1<<i))
x = jump[x][i];
// Syncing ancestors
for (int i = maxlog - 1; i >= 0; i--)
if (jump[x][i] != jump[y][i])
x = jump[x][i],
y = jump[y][i];
if (x == y)
return x;
return jump[x][0];
}
void __change(int x, int y, int colour)
{
while (ctop[x] != ctop[y]) {
st.change(dfn[ctop[x]], dfn[x], colour);
x = jump[ctop[x]][0];
}
st.change(dfn[y], dfn[x], colour);
return ;
}
int __query(int x, int y)
{
int res = 0;
while (ctop[x] != ctop[y]) {
int tmp = st.query(dfn[ctop[x]], dfn[x]).cols;
res += tmp;
if (st.query(dfn[jump[ctop[x]][0]]) == st.query(dfn[ctop[x]]))
res -= 1;
x = jump[ctop[x]][0];
}
int tmp = st.query(dfn[y], dfn[x]).cols;
res += tmp;
return res;
}
void change(int x, int y, int colour)
{
int z = lca(x, y);
__change(x, z, colour);
__change(y, z, colour);
return ;
}
int query(int x, int y)
{
int z = lca(x, y);
int res = __query(x, z)
+ __query(y, z) - 1;
return res;
}
void init(int n, int arr[])
{
this->n = n;
dcnt = 0;
this->root = 1;
// Generating DFS sequences
depth[root] = 1;
dfs1(root);
dfs2(root, root);
// Building segment tree, with minor modifications
for (int i = 1; i <= n; i++)
alt_arr[dfn[i]] = arr[i];
st.build(n, alt_arr);
return ;
}
} graph;
int n, m;
int arr[maxn];
char str[64];
int main(int argc, char** argv)
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++)
scanf("%d", &arr[i]);
for (int i = 1, a, b; i <= n - 1; i++) {
scanf("%d%d", &a, &b);
graph.add_edge(a, b);
}
// Building graph with integrated functions
graph.init(n, arr);
// Answering queries
for (int idx = 1; idx <= m; idx++) {
scanf("%s", str);
int a, b, c;
if (str[0] == 'C') {
scanf("%d%d%d", &a, &b, &c);
graph.change(a, b, c);
} else if (str[0] == 'Q') {
scanf("%d%d", &a, &b);
int res = graph.query(a, b);
printf("%d\n", res);
}
}
// Finished
return 0;
}