【题目链接】
【思路要点】
- 简单树上背包即可。
- 时间复杂度\(O(NK)\)。
【代码】
#include<bits/stdc++.h> using namespace std; const int MAXN = 1e5 + 5; const int MAXK = 105; const int P = 1e9 + 7; template <typename T> void read(T &x) { x = 0; int f = 1; char ch = getchar(); for (; !isdigit(ch); ch = getchar()) if (ch == '-') f = -f; for (; isdigit(ch); ch = getchar()) x = x * 10 + ch - '0'; x *= f; } template <typename T> void chkmax(T &x, T y) {x = max(x, y); } template <typename T> void chkmin(T &x, T y) {x = min(x, y); } vector <int> a[MAXN]; int depth[MAXN], dp[MAXN][4][MAXK]; int n, k, l[MAXN][4], r[MAXN][4]; //0 : 00, 1 : 10, 2 : 01, 3 : 11. void work(int pos, int fa) { depth[pos] = depth[fa] + 1; if (depth[pos] >= 1e3) { printf("%d\n", 0); exit(0); } dp[pos][0][0] = 1; l[pos][0] = r[pos][0] = 0; dp[pos][1][0] = 0; l[pos][1] = r[pos][1] = 0; dp[pos][2][1] = 1; l[pos][2] = r[pos][2] = 1; dp[pos][3][0] = 0; l[pos][3] = r[pos][3] = 1; for (unsigned i = 0; i < a[pos].size(); i++) if (a[pos][i] != fa) { work(a[pos][i], pos); int tmp = a[pos][i]; //from 1 + (1, 3) -> 1 int tl = min(l[tmp][1], l[tmp][3]); int tr = max(r[tmp][1], r[tmp][3]); for (int j = min(r[pos][1] + tr, k); j >= l[pos][1]; j--) { int tnp = dp[pos][1][j]; dp[pos][1][j] = 0; for (int k = max(l[pos][1], j - tr), q = j - k; k <= r[pos][1] && q >= tl; k++, q--) dp[pos][1][j] = (dp[pos][1][j] + 1ll * (k == j ? tnp : dp[pos][1][k]) * (0ll + dp[tmp][1][q] + dp[tmp][3][q])) % P; } l[pos][1] += tl; chkmin(l[pos][1], k); r[pos][1] += tr; chkmin(r[pos][1], k); while (dp[pos][1][r[pos][1]] == 0 && r[pos][1] > l[pos][1]) r[pos][1]--; while (dp[pos][1][l[pos][1]] == 0 && r[pos][1] > l[pos][1]) l[pos][1]++; //from 0 + 3 -> 1 for (int j = min(r[pos][0] + r[tmp][3], k); j >= l[pos][0] + l[tmp][3]; j--) { for (int k = max(l[pos][0], j - r[tmp][3]), q = j - k; k <= r[pos][0] && q >= l[tmp][3]; k++, q--) dp[pos][1][j] = (dp[pos][1][j] + 1ll * dp[pos][0][k] * dp[tmp][3][q]) % P; } chkmin(l[pos][1], l[pos][0] + l[tmp][3]); chkmax(r[pos][1], min(r[pos][0] + r[tmp][3], k)); while (dp[pos][1][r[pos][1]] == 0 && r[pos][1] > l[pos][1]) r[pos][1]--; while (dp[pos][1][l[pos][1]] == 0 && r[pos][1] > l[pos][1]) l[pos][1]++; //from 0 + 1 -> 0 for (int j = min(r[pos][0] + r[tmp][1], k); j >= l[pos][0]; j--) { int tnp = dp[pos][0][j]; dp[pos][0][j] = 0; for (int k = max(l[pos][0], j - r[tmp][1]), q = j - k; k <= r[pos][0] && q >= l[tmp][1]; k++, q--) dp[pos][0][j] = (dp[pos][0][j] + 1ll * (k == j ? tnp : dp[pos][0][k]) * dp[tmp][1][q]) % P; } l[pos][0] += l[tmp][1]; chkmin(l[pos][0], k); r[pos][0] += r[tmp][1]; chkmin(r[pos][0], k); while (dp[pos][0][r[pos][0]] == 0 && r[pos][0] > l[pos][0]) r[pos][0]--; while (dp[pos][0][l[pos][0]] == 0 && r[pos][0] > l[pos][0]) l[pos][0]++; //from 3 + (0, 1, 2, 3) -> 3 tl = min(min(l[tmp][0], l[tmp][1]), min(l[tmp][2], l[tmp][3])); tr = max(max(r[tmp][0], r[tmp][1]), max(r[tmp][2], r[tmp][3])); for (int j = min(r[pos][3] + tr, k); j >= l[pos][3]; j--) { int tnp = dp[pos][3][j]; dp[pos][3][j] = 0; for (int k = max(l[pos][3], j - tr), q = j - k; k <= r[pos][3] && q >= tl; k++, q--) dp[pos][3][j] = (dp[pos][3][j] + 1ll * (k == j ? tnp : dp[pos][3][k]) * (0ll + dp[tmp][0][q] + dp[tmp][1][q] + dp[tmp][2][q] + dp[tmp][3][q])) % P; } l[pos][3] += tl; chkmin(l[pos][3], k); r[pos][3] += tr; chkmin(r[pos][3], k); while (dp[pos][3][r[pos][3]] == 0 && r[pos][3] > l[pos][3]) r[pos][3]--; while (dp[pos][3][l[pos][3]] == 0 && r[pos][3] > l[pos][3]) l[pos][3]++; //from 2 + (2, 3) -> 3 tl = min(l[tmp][2], l[tmp][3]); tr = max(r[tmp][2], r[tmp][3]); for (int j = min(r[pos][2] + tr, k); j >= l[pos][2] + tl; j--) { for (int k = max(l[pos][2], j - tr), q = j - k; k <= r[pos][2] && q >= tl; k++, q--) dp[pos][3][j] = (dp[pos][3][j] + 1ll * dp[pos][2][k] * (0ll + dp[tmp][2][q] + dp[tmp][3][q])) % P; } chkmin(l[pos][3], l[pos][2] + tl); chkmax(r[pos][3], min(r[pos][2] + tr, k)); while (dp[pos][3][r[pos][3]] == 0 && r[pos][3] > l[pos][3]) r[pos][3]--; while (dp[pos][3][l[pos][3]] == 0 && r[pos][3] > l[pos][3]) l[pos][3]++; //from 2 + (0, 1) -> 2 tl = min(l[tmp][1], l[tmp][0]); tr = max(r[tmp][1], r[tmp][0]); for (int j = min(r[pos][2] + tr, k); j >= l[pos][2]; j--) { int tnp = dp[pos][2][j]; dp[pos][2][j] = 0; for (int k = max(l[pos][2], j - tr), q = j - k; k <= r[pos][2] && q >= tl; k++, q--) dp[pos][2][j] = (dp[pos][2][j] + 1ll * (k == j ? tnp : dp[pos][2][k]) * (0ll + dp[tmp][1][q] + dp[tmp][0][q])) % P; } l[pos][2] += tl; chkmin(l[pos][2], k); r[pos][2] += tr; chkmin(r[pos][2], k); while (dp[pos][2][r[pos][2]] == 0 && r[pos][2] > l[pos][2]) r[pos][2]--; while (dp[pos][2][l[pos][2]] == 0 && r[pos][2] > l[pos][2]) l[pos][2]++; } } int main() { read(n), read(k); for (int i = 1; i <= n - 1; i++) { int x, y; read(x), read(y); a[x].push_back(y); a[y].push_back(x); } work(1, 0); printf("%d\n", (dp[1][1][k] + dp[1][3][k]) % P); return 0; }