题意:给一颗带点权的树,求树上满足所有点权能构成多边形的路径数量。
分析:若$a_1, a_2 , ... , a_n$能构成多边形,则有$\sum _{i = 1} n a _i \gt 2 \times max _ { a_i } $。点分治分为经过重心的贡献和子树递归的贡献,记$calc(u)$为计算$u$子树内不考虑重复的贡献,若$v$为$u$的出边,那么经过$u$的贡献即为$calc(u) -\sum calc(v)$。对于$calc(u)$,考虑从$u$开始$dfs$,记$sum(v)$为$u$到$v$路径上权值之和,$mx(v)$为$u$到$v$路径上最大的权值,对于一颗子树$u$的所有$sum$和$mx$,对$mx$从小到大排序后,假设$j \lt i$,那么有$mx(j) \lt mx(i)$,那么有$sum(i) + sum(j) + a(u) \gt 2 \times max(mx(i), mx(j))$,转换后为$2 \times mx(i) - sum(i) + a(u) \lt sum(j)$。所以对一棵树的贡献只需要用树状数组维护$sum$即可。剩下的就是点分治的模板了。
#include <bits/stdc++.h>
using namespace std;
const int N = 4e5 + 10;
int n, a[N];
vector<int> e[N];
bool vis[N];
int get_siz(int u, int fa) {
if (vis[u]) return 0;
int res = 1;
for (int v: e[u]) {
if (v == fa) continue;
res += get_siz(v, u);
}
return res;
}
int get_cg(int u, int fa, int tot, int &cg) {
if (vis[u]) return 0;
int siz = 1, mx = 0;
for (int v: e[u]) {
if (v == fa) continue;
int son = get_cg(v, u, tot, cg);
mx = max(mx, son);
siz += son;
}
mx = max(mx, tot - siz);
if (mx * 2 <= tot) cg = u;
return siz;
}
long long ans;
vector<long long> nums;
inline int get_low(long long x) {
return lower_bound(nums.begin(), nums.end(), x) - nums.begin() + 1;
}
inline int get_up(long long x) {
return upper_bound(nums.begin(), nums.end(), x) - nums.begin() + 1;
}
vector<pair<int, long long> > val;
int tr[N];
void add(int p, int v) {
for (int i = p; i <= n; i += (i & -i)) {
tr[i] += v;
}
}
int query(int p) {
int res = 0;
while (p) {
res += tr[p];
p -= (p & -p);
}
return res;
}
void dfs(int u, int fa, long long sum, int w) {
sum += a[u];
w = max(w, a[u]);
val.push_back(make_pair(w, sum));
nums.push_back(sum);
for (int v: e[u]) {
if (v == fa || vis[v]) continue;
dfs(v, u, sum, w);
}
}
/*
j < i, mx[j] < mx[i]
sum[i] + sum[j] - a[rt] > 2 * max(mx[i], mx[j])
2 * mx[i] < sum[i] + sum[j] - a[rt]
2 * mx[i] - sum[i] + a[rt] < sum[j]
*/
long long calc(int u, int w) {
val.clear();
nums.clear();
dfs(u, 0, w, w);
sort(nums.begin(), nums.end());
nums.erase(unique(nums.begin(), nums.end()), nums.end());
sort(val.begin(), val.end());
int x = 0;
long long res = 0;
for (auto pi: val) {
long long now = 2ll * pi.first - pi.second + (w ? w : a[u]);
res += x;
res -= query(get_up(now) - 1);
add(get_low(pi.second), 1);
x++;
}
for (auto pi: val) add(get_low(pi.second), -1);
return res;
}
void solve(int u) {
if (vis[u]) return;
get_cg(u, 0, get_siz(u, 0), u);
vis[u] = true;
ans += calc(u, 0);
for (int v: e[u]) {
if (vis[v]) continue;
ans -= calc(v, a[u]);
solve(v);
}
}
int main() {
int tt;
scanf("%d",&tt);
while (tt--) {
scanf("%d",&n);
ans = 0;
for (int i = 1; i <= n; i++) e[i].clear(), vis[i] = false;
for (int i = 1; i <= n; i++) scanf("%d",&a[i]);
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d%d",&u,&v);
e[u].push_back(v);
e[v].push_back(u);
}
solve(1);
printf("%lld\n",ans);
}
return 0;
}