问题描述

一个数字被称为好数字当他满足下列条件: 1. 它有 2*n 个数位, n 是正整数 (允许有前导 0). 2. 构成它的每个数字都在给定的数字集合 S 中. 3. 它前 n 位之和与后 n 位之和相等或者它奇数位之和与偶数位之和相等例如对于 n=2, S={1, 2}, 合法的好数字有 1111, 1122, 1212, 1221, 2112, 2121, 2211, 2222 这样 8 种. 已知 n, 求合法的好数字的个数 mod 999983.

输入格式

第一行一个数 n. 接下来一个长度不超过 10 的字符串, 表示给定的数字集合.

输出格式

一行一个数字表示合法的好数字的个数 mod 999983.

样例输入

2
0987654321

样例输出

1240

数据规模

对于 20% 的数据, n≤7. 对于 100% 的. 据, n≤1000,|S|≤10.

Explanation

我们用类分治的想法来解这一道题. 首先将整个数列分成四部分:

  1. 在前一半, 位置为奇数位的数构成的数列 (a[])
  2. 在前一半, 位置为偶数位的数构成的数列 (b[])
  3. 在后一半, 位置为奇数位的数构成的数列 (c[])
  4. 在后一半, 位置为偶数位的数构成的数列 (d[])

然后由于这四部分分别独立, 所以可以用 dp 在\(O(n^2)\) 时间内求出来排列它们的方案数. 进一步, 我们可以知道最终的方案数为 \(a + b = c + d\ or\ a + c = b + d\).

但是这就带来一个问题, 常数大约为\(100\) 的 $O (n^2) 做法一定会 TLE (实测大概需要 4 秒左右, 但是 常数实在太大无法优化, 也就是已经达到了渐进复杂度极限. 所以我们现在要从另一个角度来看这件事. 我们对这个方程进行化简:

\[let\ dp[0]_i\ be\ a+b, dp[1]_i\ be\ c+d, dp[2]_i\ be\ a+c, dp[3]_i\ be\ b+d.\]

\[res = \sum_{i=0}^{m}(dp[0]_i \times dp[1]_i) + \sum_{i=0}^{m}(dp[2]_i \times dp[3]_i) - \sum_{i=0}^{m}a_i d_i \times \sum_{i=0}^{m}b_i c_i\]

\[= \sum_{i=0}^{m}(dp[0]_i \times dp[1]_i + dp[2]_i \times dp[3]_i) - \sum_{i=0}^{m}a_i d_i \times \sum_{i=0}^{m}b_i c_i\]

\[= \sum_{i=0}^{m}(\sum_{j=0}^{m}a_j b_{i-j} \times \sum_{j=0}^{m}c_j d_{i-j} + \sum_{j=0}^{m}a_j c_{i-j} \times \sum_{j=0}^{m}b_j d_{i-j}) - \sum_{i=0}^{m}a_i d_i \times \sum_{i=0}^{m}b_i c_i\]

\[\because a = d, b = c\]

\[\therefore res = 2\sum_{i=0}^{m}(\sum_{j=0}^{m} a_j b_{i-j})^2 - \sum_{i=0}^{m} a_i^2 - \sum_{i=0}^{m} b_i^2\]

\[\because \sum_{j=0}^{m} a_j b_{i-j} = dp_{n,i}\]

\[\therefore res = 2\sum_{i=0}^{m} dp[n][i] - \sum_{i=0}^{m} dp[\lceil \frac{n}{2} \rceil][i]^2 - \sum_{i=0}^{m} dp[\lfloor \frac{n}{2} \rfloor][i]^2\]

然后这就被优化成了一个常数为\(10\)\(O(n^2)\) 算法, 刚好踩着 0. 9 秒的时限能够 AC.

Example Code


#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <cstring>
#include <cmath>

using namespace std;
typedef long long lli;
const int maxn = 1010, maxm = 10100, maxx = 10;
const lli modr = 999983;
// #define USE_BRUTE_FORCE

bool avail[maxx]; // Available numbers
int n, availmax;
#ifdef USE_BRUTE_FORCE
lli dp[2][maxm]; // Rolling array, dp[i][j] = count(len = i, sum = j)
int dp_idx = 0;
#else
lli dp[maxn][maxm];
#endif

#ifdef USE_BRUTE_FORCE
lli proc_dp[4][maxm];
void proc_sub(int pos, lli a[], lli b[], int max_size)
{
    for (int i = 0; i <= max_size; i++) {
        proc_dp[pos][i] = 0;
        for (int j = 0; j <= i; j++)
            proc_dp[pos][i] += a[j] * b[i - j];
        proc_dp[pos][i] %= modr;
    }
    return ;
}
lli proc(lli a[], lli b[], lli c[], lli d[])
{
    // Returns a+b=c+d || a+c=b+d => a=d, b=c
    int max_size = n;
    for (int i = maxx - 1; i >= 0; i--)
        if (avail[i]) { max_size *= i; break; }
    // Preprocess initial data
    proc_sub(0, a, b, max_size);
    proc_sub(1, c, d, max_size);
    proc_sub(2, a, c, max_size);
    proc_sub(3, b, d, max_size);
    // Type 1 (+) : a+b=c+d
    lli res = 0, res_a = 0, res_b = 0;
    for (int i = 0; i <= max_size; i++)
        res_a += proc_dp[0][i] * proc_dp[1][i];
    res += res_a % modr;
    // Type 2 (+) : a+c=b+d
    for (int i = 0; i <= max_size; i++)
        res_b += proc_dp[2][i] * proc_dp[3][i];
    res += res_b % modr;
    // Type 3 (-) : a=d, b=c
    res_a = 0, res_b = 0, res %= modr;
    for (int i = 0; i <= max_size; i++)
        res_a += a[i] * d[i];
    for (int i = 0; i <= max_size; i++)
        res_b += b[i] * c[i];
    res -= ((res_a % modr) * (res_b % modr)) % modr;
    return res;
}
#else
#endif

int main(int argc, char** argv)
{
    scanf("%d", &n);
    char str[20];
    scanf("%s", str);
    for (unsigned int i = 0; i < strlen(str); i++) {
        avail[str[i] - '0'] = true;
        availmax = max(availmax, str[i] - '0');
    }
    // Done input, ready to dp.
    #ifdef USE_BRUTE_FORCE
    #define dp_now dp[dp_idx]
    #define dp_last dp[!dp_idx]
    int maxdp = 0;
    dp_now[0] = 1;
    int cl_n_2 = (n + 1) / 2;
    for (int i = 1; i <= cl_n_2; i++) {
        for (int j = 0; j <= maxdp + 9; j++)
            dp_last[j] = 0;
        for (int j = 0; j <= maxdp; j++)
            for (int k = 0; k < maxx; k++)
                if (avail[k])
                    dp_last[j + k] += dp_now[j];
        for (int j = 0; j <= maxdp + 10; j++)
            dp_last[j] %= modr;
        dp_idx ^= 1;
        maxdp += availmax;
    }
    dp_idx ^= 1;
    // Processing with distinction
    lli res = 0;
    if (n % 2 == 0)
        res = proc(dp_last, dp_last, dp_last, dp_last);
    else
        res = proc(dp_now, dp_last, dp_last, dp_now);
    #else
    // Do not use brute force, or rolling arrays
    dp[0][0] = 1;
    for (int i = 1; i <= n; i++) {
        for (int j = 0; j <= i * availmax; j++)
            for (int k = 0; k <= availmax; k++)
                if (avail[k])
                    dp[i][j + k] += dp[i - 1][j];
        for (int j = 0; j <= (i + 1) * availmax; j++)
            dp[i][j] %= modr;
    }
    lli res = 0, A = 0, B = 0;
    for (int i = 0; i <= n * availmax; i++)
        res += 2 * dp[n][i] * dp[n][i];
    res %= modr;
    int a = (n + 1) / 2, b = n - a;
    for (int i = 0; i <= a * availmax; i++)
        A += dp[a][i] * dp[a][i];
    for (int i = 0; i <= b * availmax; i++)
        B += dp[b][i] * dp[b][i];
    A %= modr, B %= modr;
    res = (res + modr - A * B % modr) % modr;
    #endif
    printf("%lld\n", res);
    return 0;
}