-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
Copy path_2049.java
101 lines (94 loc) · 3.97 KB
/
_2049.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
package com.fishercoder.solutions.thirdthousand;
import com.fishercoder.common.classes.TreeNode;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class _2049 {
    public static class Solution1 {
        /*
         * My completely original solution.
         * Practice makes perfect!
         *
         * Removing a node splits the tree into its child subtrees plus "everything above it",
         * which has n - size(node) nodes. A node's score is the product of the sizes of the
         * non-empty pieces.
         *
         * Subtree sizes are computed bottom-up without recursion. Leaves are processed first.
         * Each processed node pushes its final size into its parent. A parent becomes ready once
         * all of its children have reported. This keeps a chain-shaped tree of 10^5 nodes from
         * overflowing the call stack, which a recursive post-order traversal could do. It also
         * avoids the boxing and TreeNode allocations of a map-based tree. As a side benefit, it
         * handles nodes of any arity, not just binary trees.
         *
         * @param parents parents[i] is the parent of node i; the root has a negative parent
         *     (-1 in the problem statement)
         * @return how many nodes share the highest score (0 for an empty array)
         */
        public int countHighestScoreNodes(int[] parents) {
            int n = parents.length;

            // Number of children of each node whose subtree size is not yet final.
            int[] pendingChildren = new int[n];
            for (int v = 0; v < n; v++) {
                if (parents[v] >= 0) {
                    pendingChildren[parents[v]]++;
                }
            }

            int[] subtreeSize = new int[n];
            // Product of the subtree sizes of v's children (empty product = 1 for leaves).
            long[] childProduct = new long[n];
            // Array-backed FIFO; each node is enqueued exactly once, so n slots suffice.
            int[] queue = new int[n];
            int tail = 0;
            for (int v = 0; v < n; v++) {
                subtreeSize[v] = 1;
                childProduct[v] = 1L;
                if (pendingChildren[v] == 0) {
                    queue[tail++] = v;
                }
            }

            // Process nodes leaves-first; when v is dequeued its subtree size is final.
            for (int head = 0; head < tail; head++) {
                int v = queue[head];
                int parent = parents[v];
                if (parent < 0) {
                    continue; // root: nothing above it to report to
                }
                subtreeSize[parent] += subtreeSize[v];
                childProduct[parent] *= subtreeSize[v];
                if (--pendingChildren[parent] == 0) {
                    queue[tail++] = parent;
                }
            }

            // Single pass: track the maximum score and how many nodes reach it.
            long highestScore = 0L;
            int count = 0;
            for (int v = 0; v < n; v++) {
                // The piece above v is empty only for the root; an empty piece contributes 1.
                long above = Math.max((long) n - subtreeSize[v], 1L);
                long score = childProduct[v] * above;
                if (score > highestScore) {
                    highestScore = score;
                    count = 1;
                } else if (score == highestScore) {
                    count++;
                }
            }
            return count;
        }
    }
}