黑红树 (brtree)

背景

Mz 们在 czy 的生日送他一个黑红树种子...... czy 种下种子, 结果种子很快就长得飞快, 它的枝干伸入空中看不见了......

题目描述

Czy 发现黑红树具有一些独特的性质.

  1. 这是二叉树, 除根节点外每个节点都有红与黑之间的一种颜色.
  2. 每个节点的两个儿子节点都被染成恰好一个红色一个黑色.
  3. 这棵树你是望不到头的 (树的深度可以到无限大)
  4. 黑红树上的高度这样定义: h (根节点)=0, h [son]=h [father]+1.

Czy 想从树根顺着树往上爬. 他有 \(\frac{p}{q}\) 的概率到达红色的儿子节点, 有 \(1 - \frac{p}{q}\) 的概率到达黑色节点. 但是他知道如果自己经过的路径是不平衡的, 他会马上摔下来. 一条红黑树上的链是不平衡的, 当且仅当红色节点与黑色节点的个数之差大于 \(1\). 现在他想知道他刚好在高度为 \(h\) 的地方摔下来的概率的精确值 \(\frac{a}{b}, gcd(a, b) = 0\). 那可能很大, 所以他只要知道 \(a\),\(b\)\(K\) 取模的结果就可以了. 另外, czy 对输入数据加密: 第 \(i\) 个询问 \(Q_i\) 真正大小将是给定的 \(Q\) 减上一个询问 的第一个值 \(a\ mod\ K\).

格式

第一行四个数\(p, q, T, k\), 表示走红色节点概率是 \(\frac{p}{q}\), 以下 \(T\) 组询问, 答案对 \(K\) 取模. 接下来 \(T\) 行, 每行一个数 \(Q\), 表示 czy 想知道刚好在高度 \(Q\) 掉下 来的概率 (已加密)

输出 \(T\) 行, 每行两个整数, 表示要求的概率 \(\frac{a}{b}\)\(a\ mod\ K\)\(b\ mod\ K\) 的精确值. 如果这个概率就是 \(0\)\(1\), 直接输出 0 01 1 (中间有空格).

样例输入 1

2 3 2 100
1
2

样例输出 1

0 0
5 9

样例输入 2

2 3 2 20
4
6

样例输出 2

0 1
0 9

数据范围

对于 30% 数据,\(p, q \leq 5, T \leq 1000, K \leq 127\), 对于任意解密后的 Q, 有 \(Q \leq 30\)

对于 60% 数据,\(p, q \leq 20, T \leq 100000, K \leq 65535\), 对于任意解密后的 Q, 有 \(Q \leq 1000\)

对于 100% 数据,\(p, q \leq 100, T \leq 1000000, K \leq 1000000007\), 对于任意解密后的 Q, 有 \(Q \leq 1000000\)

对于 100% 数据, 有 \(q \gt p\), 即 \(0 \leq \frac{p}{q} \leq 1\)

Explanation

以下是官方题解, 我的题解和这个差不多:

把树分成两层两层考虑, 那么下面的一层显然不可能出现结束的状态, 因为取到红和黑的点数之差一定为 1. 每一层只有三种情况: 1 红 1 黑, 0 红 2 黑, 2 红 0 黑. 因为达到这一层的时候 一定取到红点的个数和取到黑点的个数之和一定是偶数, 因此红点和黑点的个数一定相等. 当取到 0 红 2 黑和 2 红 0 黑的时候在这一层就结束了, 否则 1 红 1 黑就走到下一层. 结束的概率是:

\[\frac{p^2 + (p - q)^2}{q^2}\]

\[= \frac{2 p^2 - 2 p q + q^2}{q^2}\]

不结束的概率是:

\[1 - \frac{p^2 + (p - q)^2}{q^2}\]

\[= \frac{2 p q - 2 p^2}{q^2}\]

令能结束的概率是 \(\frac{A}{B}\), 一轮不能结束的概率是 \(\frac{C}{D}\). 那么答案就是 \(\frac{C}{D}^{t - 1} \cdot \frac{A}{B}\). 因为还要约分, 考虑到对于 \(n \leq 10000\),\(2\) 的因子最多不会超过 \(20\) 个, 所以前 \(20\) 个直接用分解质因数搞一下就好, 后面直接乘起来.

Source Code


#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <cstring>

using namespace std;
typedef long long lli;
const int maxn = 1001000, maxf = 1030;
lli modr = 1000000007;

int T;
lli p, q, K;

lli div_1, div_2; // res = (div_1 / div_2) ^ n

lli gcd(lli a, lli b) {
    if (b == 0) return a;
    return gcd(b, a % b);
}

lli prime[maxf], primes;
void sort_primes(void) {
    primes = 0;
    prime[++primes] = 2;
    for (int i = 3; primes < maxf - 1; i++) {
        bool isprime = true;
        for (int j = 1; j <= primes; j++)
            if (i % prime[j] == 0)
                isprime = false;
        if (!isprime) continue;
        prime[++primes] = i;
    }
    return ;
}

struct fact_res
{ int facts[maxf]; };

fact_res factor(lli in) {
    fact_res out;
    memset(out.facts, 0, sizeof(out.facts));
    for (int i = 1; i < maxf; i++) {
        if (in == 1)
            break;
        while (in % prime[i] == 0) {
            in /= prime[i];
            out.facts[i]++;
        }
    }
    return out;
}

lli res_a[maxn], res_b[maxn];
int main(int argc, char** argv)
{
    scanf("%lld%lld%d%lld", &p, &q, &T, &K);
    modr = K;
    // Pre-processing data
    if ((p == 1 && q == 1) || (p == 0 && q == 0)) {
        // Not entirely falling down, etc.
        res_a[2] = res_b[2] = 1;
    } else {
        sort_primes();
        lli c = p * p + (q - p) * (q - p),
            d = q * q,
            a = d - c,
            b = d; // a / b, c / d, two fractions.
        lli div_a_b = gcd(a, b),
            div_c_d = gcd(c, d);
        a /= div_a_b, b /= div_a_b;
        c /= div_c_d, d /= div_c_d;
        // Retrieved best case. Factoring expressions
        fact_res f_a = factor(a),
                 f_b = factor(b),
                 f_c = factor(c),
                 f_d = factor(d);
        // Setting initial state.
        res_a[2] = c % modr;
        res_b[2] = d % modr;
        // Prime counters
        int prmc[maxf];
        memset(prmc, 0, sizeof(prmc));
        for (int i = 1; i < maxf; i++)
            prmc[i] += f_c.facts[i] - f_d.facts[i];
        // Division
        for (int i = 4; i <= 40; i += 2) {
            lli sum_1 = 1, sum_2 = 1;
            for (int j = 1; j < maxf; j++)
                prmc[j] += f_a.facts[j] - f_b.facts[j];
            for (int j = 1; j < maxf; j++) {
                if (prmc[j] > 0)
                    for (int k = 1; k <= prmc[j]; k++)
                        sum_1 = (sum_1 * prime[j]) % modr;
                else if (prmc[j] < 0)
                    for (int k = 1; k <= -prmc[j]; k++)
                        sum_2 = (sum_2 * prime[j]) % modr;
            }
            res_a[i] = sum_1;
            res_b[i] = sum_2;
        }
        // The rest of them doesn't contain factor problems
        for (int i = 42; i <= 1000000; i += 2) {
            res_a[i] = (res_a[i - 2] * a) % modr;
            res_b[i] = (res_b[i - 2] * b) % modr;
        }
    }
    // Input sets of data and respond accordingly.
    lli last_enc = 0;
    for (int idx = 1; idx <= T; idx++) {
        lli Q = 0;
        scanf("%lld", &Q);
        Q -= last_enc; // Deciphering input
        printf("%lld %lld\n", res_a[Q], res_b[Q]);
        last_enc = res_a[Q];
    }
    return 0;
}

还有一半没写完的快速幂版本, 仅供部分参考:


#include <iostream>
#include <cstdlib>
#include <cstdio>

using namespace std;
typedef long long lli;
const int maxn = 1010;
lli modr = 1000000007;

class Matrix { public: lli dat[4][4]; };
class RowMatrix { public: lli dat[4]; };
Matrix operator * (Matrix a, Matrix b) {
    Matrix c;
    for (int i = 1; i <= 3; i++)
        for (int j = 1; j <= 3; j++) {
            c.dat[i][j] = 0;
            for (int k = 1; k <= 3; k++)
                c.dat[i][j] += a.dat[i][k] * b.dat[k][j];
            c.dat[i][j] %= modr;
        }
    return c; }
RowMatrix operator * (RowMatrix a, Matrix b) {
    RowMatrix c;
    for (int i = 1; i <= 3; i++) {
        c.dat[i] = 0;
        for (int k = 1; k <= 3; k++)
            c.dat[i] += a.dat[k] * b.dat[k][i];
        c.dat[i] %= modr;
    }
    return c; }
Matrix make_matrix(int _1, int _2, int _3, int __1, int __2, int __3, int _1_, int _2_, int _3_) {
    Matrix c;
    c.dat[1][1] = _1,  c.dat[1][2] = _2,  c.dat[1][3] = _3;
    c.dat[2][1] = __1, c.dat[2][2] = __2, c.dat[2][3] = __3;
    c.dat[3][1] = _1_, c.dat[3][2] = _3_, c.dat[3][3] = _3_;
    return c; }
RowMatrix make_row_matrix(int _1, int _2, int _3) {
    RowMatrix c;
    c.dat[1] = _1, c.dat[2] = _2, c.dat[3] = _3;
    return c; }

lli dp[maxn][3];
int p, q, T, K;

lli eval(lli height)
{
    lli tmp = 1;
    // Creating matrix of sequence
    Matrix m_base = make_matrix(
        0,     p,     0,
        q - p, 0,     p,
        0,     q - p, 0);
    Matrix m_calc = make_matrix(
        1, 0, 0,
        0, 1, 0,
        0, 0, 1);
    while (tmp <= height) {
        if (height & tmp)
            m_calc = m_calc * m_base;
        m_base = m_base * m_base;
    }
    // Applying matrix calculation results to number
    RowMatrix rm_c = make_row_matrix(
        0, 1, 0);
    return (rm_c.dat[1] + rm_c.dat[2] + rm_c.dat[3]) % modr;
}

int main(int argc, char** argv)
{
    scanf("%d%d%d%d", &p, &q, &T, &K);
    modr = (lli)K;
    for (int idx = 1; idx <= T; idx++) {
        // The rest of the code tampers with the reading and reacting.
        // I gurantee that this would WA most of the data.
        // So this chunk is not going to be maintained.
    }
    return 0;
}