1、题目
给定一棵二叉树的根节点 root,任何两个节点之间都存在距离,返回整棵二叉树的最大距离。
2、分析
两个节点之间的距离就是从 A 节点到 B 节点沿途的节点数(包含 A 和 B 本身,例如同一个节点到自己的距离为 1)。
可以有多种解法,这里使用二叉树的递归套路:
①目标:以 X 为根节点的情况下,整棵树的最大距离;
②可能性:可能和 X 有关,也可能和 X 无关。
- 和 X 无关的情况有两种可能性:(1)X 左树的最大距离;(2)X 右树的最大距离;
- 和 X 有关的情况只有一种可能性:"X 左树与 X 最远 + X 右树与 X 最远 + 1",而 X 左树与 X 最远就是左树的高度,X 右树与 X 最远就是右树的高度,1 是 X 节点本身。
所以,需要向左树和右树索要的信息是:(1)树的最大距离;(2)树的高度;
最终结果就是在三种可能性中求最大值。
3、实现
C++ 版
/*************************************************************************
> File Name: 037.返回二叉树的最大距离.cpp
> Author: Maureen
> Mail: Maureen@qq.com
> Created Time: 五 7/ 1 17:13:32 2022
************************************************************************/
#include <iostream>
#include <ctime>
#include <unordered_map>
#include <unordered_set>
#include <vector>
using namespace std;
// Binary tree node carrying an int payload and raw pointers to its children.
// The node does not own or free its children.
class TreeNode {
public:
    int value;
    TreeNode *left;
    TreeNode *right;
    // Both child pointers start out as nullptr so that a freshly constructed
    // node is a valid leaf; previously they were left uninitialized.
    TreeNode(int data) : value(data), left(nullptr), right(nullptr) {}
};
// Method 1: brute force. Collect the nodes in pre-order, then enumerate the
// distance between every pair of nodes.
//
// Appends every node of the tree rooted at `root` to `arr` in pre-order
// (node, then left subtree, then right subtree). A null root appends nothing.
void fillPrelist(TreeNode *root, vector<TreeNode*> &arr) {
    vector<TreeNode*> pending;  // explicit stack of nodes still to visit
    if (root != nullptr) {
        pending.push_back(root);
    }
    while (!pending.empty()) {
        TreeNode *node = pending.back();
        pending.pop_back();
        arr.push_back(node);
        // Push right before left so the left subtree is popped (visited) first.
        if (node->right != nullptr) {
            pending.push_back(node->right);
        }
        if (node->left != nullptr) {
            pending.push_back(node->left);
        }
    }
}
// Returns the nodes of the tree rooted at `root` as a new vector in pre-order.
vector<TreeNode *> getPrelist(TreeNode *root) {
    vector<TreeNode*> preorder;
    fillPrelist(root, preorder);
    return preorder;
}
// Records the parent of every node below `root` into `parentMap`
// (child -> parent). `root` itself is not given an entry here and must be
// non-null, because it is dereferenced unconditionally.
void fillParentMap(TreeNode *root, unordered_map<TreeNode*, TreeNode*> &parentMap) {
    TreeNode *children[2] = { root->left, root->right };
    for (TreeNode *child : children) {
        if (child == nullptr) {
            continue;
        }
        parentMap[child] = root;
        fillParentMap(child, parentMap);
    }
}
// Builds a child -> parent map for the whole tree. The root is mapped to
// nullptr, which is what marks the top of the tree for upward walks.
unordered_map<TreeNode *, TreeNode *> getParentMap(TreeNode *root) {
    unordered_map<TreeNode*, TreeNode*> parents;
    parents[root] = nullptr;
    fillParentMap(root, parents);
    return parents;
}
// Returns the distance between o1 and o2, measured as the number of nodes on
// the path that connects them (both endpoints included, so a node is at
// distance 1 from itself). `parentMap` maps each node to its parent, with the
// root mapped to nullptr.
int distance(unordered_map<TreeNode*, TreeNode*> &parentMap, TreeNode *o1, TreeNode *o2) {
    // Collect o1 together with all of its ancestors up to the root.
    unordered_set<TreeNode*> ancestorsOfO1;
    for (TreeNode *node = o1; ; node = parentMap[node]) {
        ancestorsOfO1.insert(node);
        if (parentMap[node] == nullptr) {
            break;
        }
    }
    // Climb from o2 until reaching a node in that set: the lowest common ancestor.
    TreeNode *lowestAncestor = o2;
    while (ancestorsOfO1.find(lowestAncestor) == ancestorsOfO1.end()) {
        lowestAncestor = parentMap[lowestAncestor];
    }
    // Start at 1 for the common ancestor itself, then add one node per upward
    // step taken from each endpoint.
    int nodesOnPath = 1;
    for (TreeNode *node = o1; node != lowestAncestor; node = parentMap[node]) {
        nodesOnPath++;
    }
    for (TreeNode *node = o2; node != lowestAncestor; node = parentMap[node]) {
        nodesOnPath++;
    }
    return nodesOnPath;
}
// Brute-force reference solution: returns the largest distance (in nodes)
// between any two nodes of the tree, or 0 for an empty tree.
int maxDistance1(TreeNode *root) {
    if (!root) {
        return 0;
    }
    vector<TreeNode*> nodes = getPrelist(root);
    unordered_map<TreeNode *, TreeNode *> parents = getParentMap(root);
    int best = 0;
    for (auto first = nodes.begin(); first != nodes.end(); ++first) {
        // The inner scan starts at `first` so a node paired with itself
        // (distance 1) is also considered.
        for (auto second = first; second != nodes.end(); ++second) {
            int candidate = distance(parents, *first, *second);
            if (candidate > best) {
                best = candidate;
            }
        }
    }
    return best;
}
// Method 2: recursive solution. Info is the summary computed for one subtree.
class Info {
public:
    int maxDistance;  // largest node-to-node distance inside the subtree
    int height;       // height of the subtree
    Info(int m, int h) : maxDistance(m), height(h) {}
};
// Computes the summary (max distance and height) of the subtree rooted at x.
// An empty subtree yields Info(0, 0).
//
// The returned Info is allocated with `new`; the caller owns it and must
// delete it. The Info objects obtained for the two children are released
// here once their values have been combined (they used to be leaked).
Info *process(TreeNode *x) {
    if (x == nullptr) {
        return new Info(0, 0);
    }
    Info *leftInfo = process(x->left);
    Info *rightInfo = process(x->right);
    // Height of this subtree: the taller child plus one level for x itself.
    int height = max(leftInfo->height, rightInfo->height) + 1;
    // Case 1: the longest path lies entirely inside the left subtree.
    int p1 = leftInfo->maxDistance;
    // Case 2: the longest path lies entirely inside the right subtree.
    int p2 = rightInfo->maxDistance;
    // Case 3: the longest path passes through x: deepest left chain + x + deepest right chain.
    int p3 = leftInfo->height + rightInfo->height + 1;
    int maxDistance = max(max(p1, p2), p3);
    delete leftInfo;
    delete rightInfo;
    return new Info(maxDistance, height);
}
// Returns the largest distance (in nodes) between any two nodes of the tree,
// using the recursive summary computed by process(). Returns 0 for an empty tree.
int maxDistance2(TreeNode *root) {
    Info *info = process(root);
    int answer = info->maxDistance;
    // process() hands back a heap-allocated Info; release it instead of leaking it.
    delete info;
    return answer;
}
// Test helper: builds a random binary tree. Returns nullptr once `level`
// exceeds `maxl`, and otherwise also returns nullptr about half the time
// based on rand(). Node values are rand() % maxv. Nodes are allocated with
// `new` and are not freed here.
TreeNode *generate(int level, int maxl, int maxv) {
    if (level > maxl) {
        return nullptr;
    }
    const double coin = rand() % 100 / (double)101;
    if (coin < 0.5) {
        return nullptr;
    }
    TreeNode *node = new TreeNode(rand() % maxv);
    node->left = generate(level + 1, maxl, maxv);
    node->right = generate(level + 1, maxl, maxv);
    return node;
}
// Test helper: returns a random binary tree with at most `level` levels and
// values in [0, value). Despite the name, no search-tree ordering is imposed
// on the values.
TreeNode *generateRandomBST(int level, int value) {
    const int firstLevel = 1;
    return generate(firstLevel, level, value);
}
int main() {
srand(time(0));
int level = 5;
int value = 100;
int testTimes = 1000001;
cout << "Test begin:" << endl;
for (int i = 0; i < testTimes; i++) {
TreeNode *root = generateRandomBST(level, value);
if (maxDistance1(root) != maxDistance2(root)) {
cout << "Failed!" << endl;
return 0;
}
if (i && i % 100 == 0) cout << i << " cases passed!" << endl;
}
cout << "Success!" << endl;
return 0;
}
Java 版
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import javax.security.auth.x500.X500PrivateCredential;
/**
 * Computes the maximum distance between any two nodes of a binary tree, where the
 * distance between two nodes is the number of nodes on the path connecting them
 * (both endpoints included).
 *
 * <p>Two implementations are provided: {@link #maxDistance1(Node)}, a brute-force
 * reference, and {@link #maxDistance2(Node)}, a single recursive pass.
 * {@link #main(String[])} cross-checks the two on randomly generated trees.
 */
public class MaxDistance {

    /** Binary tree node holding an int value and references to its two children. */
    public static class Node {
        public int value;
        public Node left;
        public Node right;

        public Node(int data) {
            this.value = data;
        }
    }

    /**
     * Method 1 (brute force): collects all nodes in pre-order and takes the largest
     * distance over every pair of nodes.
     *
     * @param head root of the tree, may be {@code null}
     * @return the maximum distance in nodes, or 0 for an empty tree
     */
    public static int maxDistance1(Node head) {
        if (head == null) {
            return 0;
        }
        ArrayList<Node> arr = getPrelist(head);
        HashMap<Node, Node> parentMap = getParentMap(head);
        int max = 0;
        for (int i = 0; i < arr.size(); i++) {
            // j starts at i so that a node paired with itself (distance 1) is considered.
            for (int j = i; j < arr.size(); j++) {
                max = Math.max(max, distance(parentMap, arr.get(i), arr.get(j)));
            }
        }
        return max;
    }

    /**
     * Returns the nodes of the tree in pre-order.
     *
     * @param head root of the tree, may be {@code null}
     * @return a new list of the nodes; empty when {@code head} is {@code null}
     */
    public static ArrayList<Node> getPrelist(Node head) {
        ArrayList<Node> arr = new ArrayList<>();
        fillPrelist(head, arr);
        return arr;
    }

    /**
     * Appends the nodes of the subtree rooted at {@code head} to {@code arr} in
     * pre-order (node, left subtree, right subtree).
     *
     * @param head root of the subtree, may be {@code null}
     * @param arr  list that receives the nodes
     */
    public static void fillPrelist(Node head, ArrayList<Node> arr) {
        if (head == null) {
            return;
        }
        arr.add(head);
        fillPrelist(head.left, arr);
        fillPrelist(head.right, arr);
    }

    /**
     * Builds a child-to-parent map for the whole tree; the root is mapped to
     * {@code null}.
     *
     * @param head root of the tree; must not be {@code null}
     * @return a new map from each node to its parent
     */
    public static HashMap<Node, Node> getParentMap(Node head) {
        HashMap<Node, Node> map = new HashMap<>();
        map.put(head, null);
        fillParentMap(head, map);
        return map;
    }

    /**
     * Records the parent of every node below {@code head} into {@code parentMap}.
     *
     * @param head      root of the subtree; must not be {@code null}
     * @param parentMap map that receives the child-to-parent entries
     */
    public static void fillParentMap(Node head, HashMap<Node, Node> parentMap) {
        if (head.left != null) {
            parentMap.put(head.left, head);
            fillParentMap(head.left, parentMap);
        }
        if (head.right != null) {
            parentMap.put(head.right, head);
            fillParentMap(head.right, parentMap);
        }
    }

    /**
     * Returns the distance between two nodes of the same tree, counted as the number
     * of nodes on the path between them (a node is at distance 1 from itself).
     *
     * @param parentMap child-to-parent map of the tree, with the root mapped to {@code null}
     * @param o1        first node
     * @param o2        second node
     * @return the number of nodes on the path from {@code o1} to {@code o2}
     */
    public static int distance(HashMap<Node, Node> parentMap, Node o1, Node o2) {
        // Collect o1 and all of its ancestors up to the root.
        HashSet<Node> o1Set = new HashSet<>();
        Node cur = o1;
        o1Set.add(cur);
        while (parentMap.get(cur) != null) {
            cur = parentMap.get(cur);
            o1Set.add(cur);
        }
        // Climb from o2 until hitting that set: this is the lowest common ancestor.
        cur = o2;
        while (!o1Set.contains(cur)) {
            cur = parentMap.get(cur);
        }
        Node lowestAncestor = cur;
        // Count the nodes from o1 up to and including the common ancestor.
        cur = o1;
        int distance1 = 1;
        while (cur != lowestAncestor) {
            cur = parentMap.get(cur);
            distance1++;
        }
        // Count the nodes from o2 up to and including the common ancestor.
        cur = o2;
        int distance2 = 1;
        while (cur != lowestAncestor) {
            cur = parentMap.get(cur);
            distance2++;
        }
        // The common ancestor was counted on both sides; count it only once.
        return distance1 + distance2 - 1;
    }

    /**
     * Method 2 (recursive): computes the maximum distance in a single traversal.
     *
     * @param head root of the tree, may be {@code null}
     * @return the maximum distance in nodes, or 0 for an empty tree
     */
    public static int maxDistance2(Node head) {
        return process(head).maxDistance;
    }

    /** Summary of one subtree as computed by {@link #process(Node)}. */
    public static class Info {
        public int maxDistance; // largest node-to-node distance inside the subtree
        public int height; // height of the subtree

        public Info(int m, int h) {
            this.maxDistance = m;
            this.height = h;
        }
    }

    /**
     * Computes the {@link Info} of the subtree rooted at {@code x}. An {@code Info}
     * is returned even for an empty subtree, so results can be combined without
     * null checks.
     *
     * @param x root of the subtree, may be {@code null}
     * @return the subtree's maximum distance and height; (0, 0) for an empty subtree
     */
    public static Info process(Node x) {
        // An empty tree has distance 0 and height 0.
        if (x == null) {
            return new Info(0, 0);
        }
        Info leftInfo = process(x.left);
        Info rightInfo = process(x.right);
        // Height of this subtree: the taller child plus one level for x itself.
        int height = Math.max(leftInfo.height, rightInfo.height) + 1;
        // Case 1: the path does not pass through x and lies in the left subtree.
        int p1 = leftInfo.maxDistance;
        // Case 2: the path does not pass through x and lies in the right subtree.
        int p2 = rightInfo.maxDistance;
        // Case 3: the path passes through x: left height + right height + 1 (x itself).
        int p3 = leftInfo.height + rightInfo.height + 1;
        int maxDistance = Math.max(Math.max(p1, p2), p3);
        return new Info(maxDistance, height);
    }

    /**
     * Test helper: generates a random binary tree with at most {@code maxLevel}
     * levels and values in {@code [0, maxValue)}. Despite the name, no search-tree
     * ordering is imposed on the values.
     */
    public static Node generateRandomBST(int maxLevel, int maxValue) {
        return generate(1, maxLevel, maxValue);
    }

    /**
     * Test helper: returns {@code null} once {@code level} exceeds {@code maxLevel},
     * and otherwise with probability 0.5; in the remaining cases it creates a node
     * and recursively generates both children.
     */
    public static Node generate(int level, int maxLevel, int maxValue) {
        if (level > maxLevel || Math.random() < 0.5) {
            return null;
        }
        Node head = new Node((int) (Math.random() * maxValue));
        head.left = generate(level + 1, maxLevel, maxValue);
        head.right = generate(level + 1, maxLevel, maxValue);
        return head;
    }

    /** Cross-checks both methods on random trees and prints "Oops!" for each mismatch. */
    public static void main(String[] args) {
        int maxLevel = 4;
        int maxValue = 100;
        int testTimes = 1000000;
        for (int i = 0; i < testTimes; i++) {
            Node head = generateRandomBST(maxLevel, maxValue);
            if (maxDistance1(head) != maxDistance2(head)) {
                System.out.println("Oops!");
            }
        }
        System.out.println("finish!");
    }
}