1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
#include <bits/stdc++.h>
#include <algorithm>
#include <cmath>
#include <iostream>
#include <utility>
#include <vector>
// NOTE: kept at file scope because main() relies on unqualified std names.
using namespace std;

// Array capacity: vertex ids / Euler positions / colors are used as 1-based indices below N.
const int N = 2e5 + 10;

// n: number of vertices; m: Mo block size; num: Euler-tour clock;
// t: number of subtree intervals appended to q[];
// sum: number of distinct non-zero occurrence counts among colors in the current window;
// ans: number of windows (subtrees) for which sum == 1.
int n, m, num, t, sum, ans;
// c[v]: color of vertex v (used directly as an index into cnt[], so it must lie in [0, N)).
// siz[v]: subtree size; in[v]/out[v]: Euler-tour interval of v's subtree; id[k]: vertex with in-time k.
// pos[k]: Mo block of tour position k.
// cnt[col]: occurrences of color col in the window; tot[k]: number of colors occurring exactly k times.
int c[N], siz[N], in[N], out[N], id[N], pos[N], cnt[N], tot[N];
vector<int> g[N];  // g[v]: children of v, visited in insertion order

// A range query [l, r] over Euler-tour positions.
struct node { int l, r; } q[N];

// Mo's ordering: by block of l; inside a block, r ascending for odd blocks and
// descending for even blocks (reduces right-pointer travel between blocks).
bool cmp(const node &a, const node &b) {
    if (pos[a.l] != pos[b.l]) {
        return a.l < b.l;
    }
    return (pos[a.l] & 1) ? a.r < b.r : a.r > b.r;
}

// Euler-tours the subtree rooted at x. It is iterative so that a path-shaped tree
// with ~2e5 vertices cannot overflow the call stack. Results match a recursive
// pre-order DFS that visits children in g[] order:
// in[]/id[] are assigned on entry, out[]/siz[] on exit, and each vertex's
// interval {in, out} is appended to q[] in post-order.
void dfs(int x) {
    // Each frame: (vertex, index of the next child of that vertex to visit).
    vector<pair<int, size_t>> stk;
    auto enter = [&](int v) {
        in[v] = ++num;
        id[num] = v;
        stk.emplace_back(v, 0);
    };
    enter(x);
    while (!stk.empty()) {
        const int v = stk.back().first;
        if (stk.back().second < g[v].size()) {
            const int child = g[v][stk.back().second++];
            enter(child);  // may reallocate stk; no references into it are held here
        } else {
            stk.pop_back();
            out[v] = num;
            siz[v] = out[v] - in[v] + 1;  // subtree occupies a contiguous tour range
            q[++t] = {in[v], out[v]};
        }
    }
}

// Adds vertex x's color to the window, maintaining tot[] and sum.
void add(int x) {
    int &k = cnt[c[x]];
    if (k && --tot[k] == 0) {
        sum--;  // no color is left at the old count
    }
    ++k;
    if (++tot[k] == 1) {
        sum++;  // first color at the new count
    }
}

// Removes vertex x's color from the window, maintaining tot[] and sum.
void del(int x) {
    int &k = cnt[c[x]];
    if (--tot[k] == 0) {
        sum--;
    }
    --k;
    if (k && ++tot[k] == 1) {  // a count of 0 is never tracked in tot[]
        sum++;
    }
}
int main() { cin >> n; m = sqrt(n); for (int i = 1; i <= n; i++) { pos[i] = (i - 1) / m + 1; int fa; cin >> c[i] >> fa; g[fa].push_back(i); } dfs(1); sort(q + 1, q + 1 + n, cmp); for (int l = 1, r = 0, i = 1; i <= n; i++) { while (l > q[i].l) add(id[--l]); while (r < q[i].r) add(id[++r]); while (l < q[i].l) del(id[l++]); while (r > q[i].r) del(id[r--]); if (sum == 1) { ans++; } } cout << ans; return 0; }
|