Constructing Ranches(点分治模板)

fyh 2021年11月03日 67次浏览

题意:给一颗带点权的树,求树上满足所有点权能构成多边形的路径数量。
分析:若$a_1, a_2 , ... , a_n$能构成多边形,则有$\sum _{i = 1} n a _i \gt 2 \times max _ { a_i } $。点分治分为经过重心的贡献和子树递归的贡献,记$calc(u)$为计算$u$子树内不考虑重复的贡献,若$v$为$u$的出边,那么经过$u$的贡献即为$calc(u) -\sum calc(v)$。对于$calc(u)$,考虑从$u$开始$dfs$,记$sum(v)$为$u$到$v$路径上权值之和,$mx(v)$为$u$到$v$路径上最大的权值,对于一颗子树$u$的所有$sum$和$mx$,对$mx$从小到大排序后,假设$j \lt i$,那么有$mx(j) \lt mx(i)$,那么有$sum(i) + sum(j) + a(u) \gt 2 \times max(mx(i), mx(j))$,转换后为$2 \times mx(i) - sum(i) + a(u) \lt sum(j)$。所以对一棵树的贡献只需要用树状数组维护$sum$即可。剩下的就是点分治的模板了。

#include <bits/stdc++.h>

using namespace std;

const int N = 4e5 + 10;

int n, a[N];

vector<int> e[N];

bool vis[N];

int get_siz(int u, int fa) {
  if (vis[u]) return 0;
  int res = 1;
  for (int v: e[u]) {
    if (v == fa) continue;
    res += get_siz(v, u);
  }
  return res;
}

int get_cg(int u, int fa, int tot, int &cg) {
  if (vis[u]) return 0;
  int siz = 1, mx = 0;
  for (int v: e[u]) {
    if (v == fa) continue;
    int son = get_cg(v, u, tot, cg);
    mx = max(mx, son);
    siz += son;
  }
  mx = max(mx, tot - siz);
  if (mx * 2 <= tot) cg = u;
  return siz;
}

long long ans;

vector<long long> nums;

inline int get_low(long long x) {
  return lower_bound(nums.begin(), nums.end(), x) - nums.begin() + 1;
}

inline int get_up(long long x) {
  return upper_bound(nums.begin(), nums.end(), x) - nums.begin() + 1;
}

vector<pair<int, long long> > val;

int tr[N];

void add(int p, int v) {
  for (int i = p; i <= n; i += (i & -i)) {
    tr[i] += v;
  }
}

int query(int p) {
  int res = 0;
  while (p) {
    res += tr[p];
    p -= (p & -p);
  }
  return res;
}

void dfs(int u, int fa, long long sum, int w) {
  sum += a[u];
  w = max(w, a[u]);
  val.push_back(make_pair(w, sum));
  nums.push_back(sum);
  for (int v: e[u]) {
    if (v == fa || vis[v]) continue;
    dfs(v, u, sum, w);
  }
}

/*
j < i, mx[j] < mx[i]
sum[i] + sum[j] - a[rt] > 2 * max(mx[i], mx[j])
2 * mx[i] < sum[i] + sum[j] - a[rt]
2 * mx[i] - sum[i] + a[rt] < sum[j]
*/

long long calc(int u, int w) {
  val.clear();
  nums.clear();
  dfs(u, 0, w, w);
  sort(nums.begin(), nums.end());
  nums.erase(unique(nums.begin(), nums.end()), nums.end());
  sort(val.begin(), val.end());
  int x = 0;
  long long res = 0;
  for (auto pi: val) {
    long long now = 2ll * pi.first - pi.second + (w ? w : a[u]);
    res += x;
    res -= query(get_up(now) - 1);
    add(get_low(pi.second), 1);
    x++;
  }
  for (auto pi: val) add(get_low(pi.second), -1);
  return res;
}

void solve(int u) {
  if (vis[u]) return;
  get_cg(u, 0, get_siz(u, 0), u);
  vis[u] = true;
  ans += calc(u, 0);
  for (int v: e[u]) {
    if (vis[v]) continue;
    ans -= calc(v, a[u]);
    solve(v);
  }
}

int main() {
  int tt;
  scanf("%d",&tt);
  while (tt--) {
    scanf("%d",&n);
    ans = 0;
    for (int i = 1; i <= n; i++) e[i].clear(), vis[i] = false;
    for (int i = 1; i <= n; i++) scanf("%d",&a[i]);
    for (int i = 1; i < n; i++) {
      int u, v;
      scanf("%d%d",&u,&v);
      e[u].push_back(v);
      e[v].push_back(u);
    }
    solve(1);
    printf("%lld\n",ans);    
  }
  return 0;
}