NOIp 2018 day1 t3 road

这题考场上直接看错了啊……一顿狂码没过样例发现不太对劲……

最后半小时反应过来然后打了暴力回头10min淦完t2….

day1自然就爆炸了…..

题意

给定一棵N个点的树 边有边权(大于0) 现在要划分出K条边不相交的链 使得最短的链最长….

数据范围5e4….

Solution

极大化最小值上来先二分答案 二分的值是链的长度 看看能不能划分出K条链来

接下来就是如何检验K条链是不是能被划分出来。

随机选点为根 然后自底部向根合并链。

现在有点$x$和对应其儿子$son[x][k]$

显然对于一条链来说 有两种选择

从儿子$son[x][k]$延伸上来的链

一可以通过$e[i].w$和$e[j].w$这两条边和另一个儿子$son[x][q]$延伸上来的链合并

二呢可以通过$e[i].w$延伸到x然后通过$e[pre].w$延伸到x的父亲fa

其实不需要dp

定义$f[x]$为点x的子树内向上延伸到x的合法最长链。

合法的定义是 x的子树内两两合并完了以后剩下的里面选一条最大的

有cnt记录到达二分的长度$mid$的链条数

对于$x$子树内的链 扔到一个平衡树里会比较好处理

每次把最小的挑出来pop掉然后求$mid - now.w$的后继 如果有后继的话后继也pop掉

$now.w$是你每次挑的最小链的长度….

然后求个最大 就是f[x]

把整棵树做完以后发现cnt 还是<K 就return 0

否则在做的过程中就return 1 然后清空所有数据…

这题就做完了。

比day2t1 简单多啦…..

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#include<cstdio>
#include<algorithm>
#include<cstring>

inline int max (int a, int b) {return a > b ? a : b;}
inline int min (int a, int b) {return a > b ? b : a;}
const int N = 5e4 + 7;
const int inf = 1e9 + 7;
int last[N], tot = 0, ans, dp[N], mid;
int srtNode;
struct Edge {int next, to, w;}e[N * 2];
inline void add (int u, int v, int w) {
e[++tot].next = last[u], e[tot].w = w, e[tot].to = v, last[u] = tot;
}
int n, m;
struct Splay {
int root, cnt, ch[N][2], size[N], A[N], f[N];
inline void pushup (int u) {size[u] = size[ch[u][0]] + size[ch[u][1]] + 1;}
inline int getwho (int x) {return ch[f[x]][0] == x ? 0 : 1;}
inline void Rotate (int x) {
int who = getwho (x), fa = f[x], gra = f[f[x]];
ch[gra][getwho (fa)] = x, f[x] = gra;
ch[fa][who] = ch[x][who ^ 1], f[ch[x][who ^ 1]] = fa;
ch[x][who ^ 1] = fa, f[fa] = x;
pushup (fa), pushup (x);
}
inline void splay (int x, int tar) {
while (f[x] != tar) {
if (f[f[x]] != tar) {
Rotate (getwho (f[x]) == getwho (x) ? f[x] : x);
} Rotate (x);
} if (!tar) root = x;
}
inline int findkth (int key) { int o = root;
while (o) if (size[ch[o][0]] + 1 == key) {splay (o, 0); srtNode = o;return A[o];}
else if (key <= size[ch[o][0]]) o = ch[o][0];
else key -= size[ch[o][0]] + 1, o = ch[o][1];
return 0;
}
inline void merge (int x, int y) {
if (!x) {root = y; return;}
if (!y) {root = x; return;}
while (ch[y][0]) size[y] += size[x], y = ch[y][0];
size[y] += size[x], f[x] = y, ch[y][0] = x, splay (x, 0);
}
inline void del (int x) { splay (x, 0);
f[ch[x][0]] = f[ch[x][1]] = 0; merge (ch[x][0], ch[x][1]);
}
inline void insert (int x) { A[++cnt] = x, size[cnt] = 1, ch[cnt][0] = ch[cnt][1] = f[cnt] = 0;
if (!root) {root = cnt; return;}
int last, o = root;
while (o) ++size[last = o], o = x < A[o] ? ch[o][0] : ch[o][1];
f[cnt] = last, ch[last][x < A[last] ? 0 : 1] = cnt, splay (cnt, 0);
}
inline int getNext (int x) { int o = root, k = inf, last = 0;
while (o) if (A[o] >= x) k = A[o], o = ch[last = o][0], srtNode = o;
else o = ch[last = o][1];
splay (last, 0); return k;
}
inline int find (int x) { int o = root;
while (o) if (A[o] == x) {splay (o, 0); return o;}
else o = A[o] > x ? ch[o][0] : ch[o][1];
return 0;
}
}t;
inline void dfs (int x, int fa) {
dp[x] = 0;
for (int i = last[x]; i; i = e[i].next) { int to = e[i].to;
if (to == fa) continue ;
dfs (to, x);
}
for (int i = last[x]; i; i = e[i].next) { int to = e[i].to;
if (to == fa) continue ;
if (e[i].w + dp[to] >= mid) {
ans++;
if (ans >= m) {
while (t.size[t.root]) t.del (t.root);
return ;
}
continue;
} else t.insert (e[i].w + dp[to]);
}
while (t.size[t.root]) {
if (ans >= m) {
while (t.size[t.root]) t.del (t.root);
return ;
}
if (t.size[t.root] == 1) {
int k = t.findkth(1);
dp[x] = max (k, dp[x]);
t.del(t.root);
return ;
}

int k = t.findkth (1);

t.del (srtNode);

int nowx = t.getNext (mid - k );
// tot++;
// printf ("%d ", tot);
if (nowx != inf) {
ans++, t.del (t.find(nowx));
} else dp[x] = max (dp[x], k);
}
}
inline int check () {
ans = 0;
t.cnt = t.root = 0;
dfs (1, 0);
if (ans >= m) return 1;
return 0;
}

int main () {
scanf ("%d%d", &n, &m);
int minx = inf, upmax = 0;
for (int i = 1; i < n; i++) {
int x, y, z; scanf ("%d%d%d", &x, &y, &z);
add (x, y, z), add (y, x, z);
minx = min (z, minx), upmax += z;
} tot = 0;
int l = minx, r = upmax;
while (l < r) {
mid = (l + r + 1) >> 1;
if (check ()) l = mid;
else r = mid - 1;
} printf ("%d", l);
return 0;
}