祖孙询问

问题描述

已知一棵 n 个节点的有根树. 有 m 个询问. 每个询问给出了一对节点的编号 x 和 y, 询问 x 与 y 的祖孙关系.

输入格式

输入第一行包括一个整数 n 表示节点个数. 接下来 n 行每行一对整数对 a 和 b 表示 a 和 b 之间有连边. 如果 b 是-1, 那么 a 就是树的根. 第 n+2 行是一个整数 m 表示询问个数. 接下来 m 行, 每行两个正整数 x 和 y.

输出格式

对于每一个询问, 输出 1: 如果 x 是 y 的祖先, 输出 2: 如果 y 是 x 的祖先, 否则输出 0.

样例输入

10
234 -1
12 234
13 234
14 234
15 234
16 234
17 234
18 234
19 234
233 19
5
234 233
233 12
233 13
233 15
233 19

样例输出

1
0
0
0
2

数据规模

对于 30% 的数据, n, m≤1000. 对于 100% 的. 据, n, m≤40000, 每个节点的编号都不超过 40000.

Explanation

这道题可以用倍增 LCA 在 O (nlogn) 时间内做出来, 但是我还没想到写这道题, 所以暂时就弃坑了.

但是也可以用树链剖分做的说, 而且复杂度也是 O (nlogn) 啊, 所以就直接用树链剖分做了.

注意在判断是否是祖先的时候要注意, 最后深度与 chaintop[] 的关系.

Example Code


#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <queue>
#include <stack>

using namespace std;
const int maxn = 100010;

class TreeChainContainer
{
public:
    struct edge
    {
        int u, v;
        edge *next;
    };
    edge *edges[maxn], epool[maxn];
    int n, root, ecnt;
    int par[maxn], size[maxn], depth[maxn], maxson[maxn], chaintop[maxn];
    void addedge(int u, int v)
    {
        if (v == -1) {
            root = u;
            return ;
        }
        edge *p = &epool[ecnt++],
             *q = &epool[ecnt++];
        p->u = u; p->v = v; p->next = edges[u]; edges[u] = p;
        q->u = v; q->v = u; q->next = edges[v]; edges[v] = q;
        return ;
    }
    void init(void)
    {
        par[root] = 0;
        depth[root] = 1;
        // Procedure for dfs1.
        queue<int> que;
        stack<int> stk;
        que.push(root);
        stk.push(root);
        while (!que.empty()) {
            int p = que.front();
            que.pop();
            size[p] = 1;
            for (edge *ep = edges[p]; ep; ep = ep->next)
                if (depth[ep->v] == 0) {
                    depth[ep->v] = depth[p] + 1;
                    par[ep->v] = p;
                    que.push(ep->v);
                    stk.push(ep->v);
                }
        }
        while (!stk.empty()) {
            int p = stk.top();
            stk.pop();
            for (edge *ep = edges[p]; ep; ep = ep->next)
                if (par[ep->v] == p) {
                    size[p] += size[ep->v];
                    if (size[ep->v] > size[maxson[p]])
                        maxson[p] = ep->v;
                }
        }
        // Procedure for dfs2.
        chaintop[root] = root;
        que.push(root);
        while (!que.empty()) {
            int p = que.front();
            que.pop();
            if (!maxson[p])
                continue;
            chaintop[maxson[p]] = chaintop[p];
            que.push(maxson[p]);
            for (edge *ep = edges[p]; ep; ep = ep->next)
                if (par[ep->v] == p && ep->v != maxson[p]) {
                    chaintop[ep->v] = ep->v;
                    que.push(ep->v);
                }
        }
        return ;
    }
    int eval(int x, int y)
    {
        if (depth[x] == depth[y])
            return 0;
        else if (depth[x] > depth[y])
            return eval(y, x) ? 2 : 0;
        // y is deeper than x.
        // x = chaintop[x];
        while (y != x && y != root && chaintop[y] != chaintop[x]) {
            y = chaintop[y];
            if (y != root)
                y = par[y];
        }
        // printf("%d %d %d %d\n", x, y, chaintop[x], chaintop[y]);
        // int t = x;
        // while (t != root) {
        //     printf("%d -> %d\n", x, t);
        //     t = par[t];
        // }
        // t = y;
        // while (t != root) {
        //     printf("%d -> %d\n", y, t);
        //     t = par[t];
        // }
        // while (y != root && y != x)
        //     y = par[y];
        // if (y != x)
        //     return 0;
        // return 1;
        if (depth[x] > depth[y])
            return 0;
        return chaintop[x] == chaintop[y];
    }
} graph;

int n, m;

int main(int argc, char** argv)
{
    scanf("%d", &n);
    graph.n = n;
    for (int i = 1, a, b; i <= n; i++) {
        scanf("%d%d", &a, &b);
        graph.addedge(a, b);
    }
    graph.init();
    scanf("%d", &m);
    for (int i = 1, x, y; i <= m; i++) {
        scanf("%d%d", &x, &y);
        printf("%d\n", graph.eval(x, y));
    }
    return 0;
}