问题描述
一个数字被称为好数字当他满足下列条件: 1. 它有 2*n 个数位, n 是正整数 (允许有前导 0). 2. 构成它的每个数字都在给定的数字集合 S 中. 3. 它前 n 位之和与后 n 位之和相等或者它奇数位之和与偶数位之和相等例如对于 n=2, S={1, 2}, 合法的好数字有 1111, 1122, 1212, 1221, 2112, 2121, 2211, 2222 这样 8 种. 已知 n, 求合法的好数字的个数 mod 999983.
输入格式
第一行一个数 n. 接下来一个长度不超过 10 的字符串, 表示给定的数字集合.
输出格式
一行一个数字表示合法的好数字的个数 mod 999983.
样例输入
2
0987654321
样例输出
1240
数据规模
对于 20% 的数据, n≤7. 对于 100% 的. 据, n≤1000,|S|≤10.
Explanation
我们用类分治的想法来解这一道题. 首先将整个数列分成四部分:
- 在前一半, 位置为奇数位的数构成的数列 (
a[]
) - 在前一半, 位置为偶数位的数构成的数列 (
b[]
) - 在后一半, 位置为奇数位的数构成的数列 (
c[]
) - 在后一半, 位置为偶数位的数构成的数列 (
d[]
)
然后由于这四部分分别独立, 所以可以用 dp 在\(O(n^2)\) 时间内求出来排列它们的方案数. 进一步, 我们可以知道最终的方案数为 \(a + b = c + d\ or\ a + c = b + d\).
但是这就带来一个问题, 常数大约为\(100\) 的 $O (n^2) 做法一定会 TLE (实测大概需要 4 秒左右, 但是 常数实在太大无法优化, 也就是已经达到了渐进复杂度极限. 所以我们现在要从另一个角度来看这件事. 我们对这个方程进行化简:
\[let\ dp[0]_i\ be\ a+b, dp[1]_i\ be\ c+d, dp[2]_i\ be\ a+c, dp[3]_i\ be\ b+d.\]
\[res = \sum_{i=0}^{m}(dp[0]_i \times dp[1]_i) + \sum_{i=0}^{m}(dp[2]_i \times dp[3]_i) - \sum_{i=0}^{m}a_i d_i \times \sum_{i=0}^{m}b_i c_i\]
\[= \sum_{i=0}^{m}(dp[0]_i \times dp[1]_i + dp[2]_i \times dp[3]_i) - \sum_{i=0}^{m}a_i d_i \times \sum_{i=0}^{m}b_i c_i\]
\[= \sum_{i=0}^{m}(\sum_{j=0}^{m}a_j b_{i-j} \times \sum_{j=0}^{m}c_j d_{i-j} + \sum_{j=0}^{m}a_j c_{i-j} \times \sum_{j=0}^{m}b_j d_{i-j}) - \sum_{i=0}^{m}a_i d_i \times \sum_{i=0}^{m}b_i c_i\]
\[\because a = d, b = c\]
\[\therefore res = 2\sum_{i=0}^{m}(\sum_{j=0}^{m} a_j b_{i-j})^2 - \sum_{i=0}^{m} a_i^2 - \sum_{i=0}^{m} b_i^2\]
\[\because \sum_{j=0}^{m} a_j b_{i-j} = dp_{n,i}\]
\[\therefore res = 2\sum_{i=0}^{m} dp[n][i] - \sum_{i=0}^{m} dp[\lceil \frac{n}{2} \rceil][i]^2 - \sum_{i=0}^{m} dp[\lfloor \frac{n}{2} \rfloor][i]^2\]
然后这就被优化成了一个常数为\(10\) 的\(O(n^2)\) 算法, 刚好踩着 0. 9 秒的时限能够 AC.
Example Code
#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <cstring>
#include <cmath>
using namespace std;
typedef long long lli;
const int maxn = 1010, maxm = 10100, maxx = 10;
const lli modr = 999983;
// #define USE_BRUTE_FORCE
bool avail[maxx]; // Available numbers
int n, availmax;
#ifdef USE_BRUTE_FORCE
lli dp[2][maxm]; // Rolling array, dp[i][j] = count(len = i, sum = j)
int dp_idx = 0;
#else
lli dp[maxn][maxm];
#endif
#ifdef USE_BRUTE_FORCE
lli proc_dp[4][maxm];
void proc_sub(int pos, lli a[], lli b[], int max_size)
{
for (int i = 0; i <= max_size; i++) {
proc_dp[pos][i] = 0;
for (int j = 0; j <= i; j++)
proc_dp[pos][i] += a[j] * b[i - j];
proc_dp[pos][i] %= modr;
}
return ;
}
lli proc(lli a[], lli b[], lli c[], lli d[])
{
// Returns a+b=c+d || a+c=b+d => a=d, b=c
int max_size = n;
for (int i = maxx - 1; i >= 0; i--)
if (avail[i]) { max_size *= i; break; }
// Preprocess initial data
proc_sub(0, a, b, max_size);
proc_sub(1, c, d, max_size);
proc_sub(2, a, c, max_size);
proc_sub(3, b, d, max_size);
// Type 1 (+) : a+b=c+d
lli res = 0, res_a = 0, res_b = 0;
for (int i = 0; i <= max_size; i++)
res_a += proc_dp[0][i] * proc_dp[1][i];
res += res_a % modr;
// Type 2 (+) : a+c=b+d
for (int i = 0; i <= max_size; i++)
res_b += proc_dp[2][i] * proc_dp[3][i];
res += res_b % modr;
// Type 3 (-) : a=d, b=c
res_a = 0, res_b = 0, res %= modr;
for (int i = 0; i <= max_size; i++)
res_a += a[i] * d[i];
for (int i = 0; i <= max_size; i++)
res_b += b[i] * c[i];
res -= ((res_a % modr) * (res_b % modr)) % modr;
return res;
}
#else
#endif
int main(int argc, char** argv)
{
scanf("%d", &n);
char str[20];
scanf("%s", str);
for (unsigned int i = 0; i < strlen(str); i++) {
avail[str[i] - '0'] = true;
availmax = max(availmax, str[i] - '0');
}
// Done input, ready to dp.
#ifdef USE_BRUTE_FORCE
#define dp_now dp[dp_idx]
#define dp_last dp[!dp_idx]
int maxdp = 0;
dp_now[0] = 1;
int cl_n_2 = (n + 1) / 2;
for (int i = 1; i <= cl_n_2; i++) {
for (int j = 0; j <= maxdp + 9; j++)
dp_last[j] = 0;
for (int j = 0; j <= maxdp; j++)
for (int k = 0; k < maxx; k++)
if (avail[k])
dp_last[j + k] += dp_now[j];
for (int j = 0; j <= maxdp + 10; j++)
dp_last[j] %= modr;
dp_idx ^= 1;
maxdp += availmax;
}
dp_idx ^= 1;
// Processing with distinction
lli res = 0;
if (n % 2 == 0)
res = proc(dp_last, dp_last, dp_last, dp_last);
else
res = proc(dp_now, dp_last, dp_last, dp_now);
#else
// Do not use brute force, or rolling arrays
dp[0][0] = 1;
for (int i = 1; i <= n; i++) {
for (int j = 0; j <= i * availmax; j++)
for (int k = 0; k <= availmax; k++)
if (avail[k])
dp[i][j + k] += dp[i - 1][j];
for (int j = 0; j <= (i + 1) * availmax; j++)
dp[i][j] %= modr;
}
lli res = 0, A = 0, B = 0;
for (int i = 0; i <= n * availmax; i++)
res += 2 * dp[n][i] * dp[n][i];
res %= modr;
int a = (n + 1) / 2, b = n - a;
for (int i = 0; i <= a * availmax; i++)
A += dp[a][i] * dp[a][i];
for (int i = 0; i <= b * availmax; i++)
B += dp[b][i] * dp[b][i];
A %= modr, B %= modr;
res = (res + modr - A * B % modr) % modr;
#endif
printf("%lld\n", res);
return 0;
}