获取XGBoost中树模型的最大深度

本文介绍如何在XGBoost中查看树模型深度,即使无get_depth方法,通过绘图或解析json输出来获取信息,包括树深度计算和可视化方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

本文讲无论是否设置max_depth,怎么样查看XGBoost中各棵树模型的深度。

查阅XGBoost文档可以发现,XGBoost并未提供像sklearn一样的get_depth()方法
所以解决方法如下:

1、绘图、查看图像得出(不够自动化)
import xgboost as xgb
from sklearn.datasets import make_classification
X,y = make_classification(random_state=99)

# 上面的数据因人而异,这里构建的模型才是我们需要的
clf = xgb.XGBModel(objective='binary:logistic',n_estimators = 3,max_depth = 7)
clf.fit(X,y)

然后我们提取对象并将树信息转换为数据表:

booster = clf.get_booster()			# 如果你的clf已经是Booster了那就不需要这句
tree_df = booster.trees_to_dataframe()

如果想查看第0棵树的信息,可以使用如下语句

tree_df[tree_df['Tree'] == 0]
    Tree    Node    ID  Feature Split   Yes No  Missing Gain    Cover
0   0   0   0-0 f11 -0.233068   0-1 0-2 0-1 48.161629   25.00
1   0   1   0-1 f1  -1.081945   0-3 0-4 0-3 0.054384    9.25
2   0   2   0-2 f14 0.480458    0-5 0-6 0-5 8.410727    15.75
3   0   3   0-3 Leaf    NaN NaN NaN NaN -0.150000   1.00
4   0   4   0-4 Leaf    NaN NaN NaN NaN -0.535135   8.25
5   0   5   0-5 f18 0.261421    0-7 0-8 0-7 5.638095    6.50
6   0   6   0-6 f9  -1.585489   0-9 0-10    0-9 0.727795    9.25
7   0   7   0-7 f18 -0.640538   0-11    0-12    0-11    4.342857    4.00
8   0   8   0-8 f0  0.072811    0-13    0-14    0-13    1.028571    2.50
9   0   9   0-9 Leaf    NaN NaN NaN NaN 0.163636    1.75
10  0   10  0-10    Leaf    NaN NaN NaN NaN 0.529412    7.50
11  0   11  0-11    Leaf    NaN NaN NaN NaN -0.120000   1.50
12  0   12  0-12    Leaf    NaN NaN NaN NaN 0.428571    2.50
13  0   13  0-13    Leaf    NaN NaN NaN NaN -0.000000   1.00
14  0   14  0-14    Leaf    NaN NaN NaN NaN -0.360000   1.50

可视化第0棵树。决策树的深度是从根到叶的分裂次数,所以这棵树的深度为 4 :
plot_tree需要安装graphviz,未安装可跳过该步骤,或参考这两篇《graphviz下载 安装》《GraphViz’s executables not found的解决方法》

xgb.plotting.plot_tree(booster, num_trees=0)
2、遍历 json 输出并计算每棵树的深度
def item_generator(json_input, lookup_key):
    if isinstance(json_input, dict):
        for k, v in json_input.items():
            if k == lookup_key:
                yield v
            else:
                yield from item_generator(v, lookup_key)
    elif isinstance(json_input, list):
        for item in json_input:
            yield from item_generator(item, lookup_key)

def tree_depth(json_text):
    json_input = json.loads(json_text)
    return max(list(item_generator(json_input, 'depth'))) + 1

import json		# json库一定要import进来
booster = clf.get_booster()			# 如果你的clf已经是Booster了那就不需要这句,下一句的booster改为clf就行
[tree_depth(x) for x in booster.get_dump(dump_format="json")]

参考:
《How to find the actual tree depth for XG-Boost regressor model in python?》
《recursive iteration through nested json for specific key in python》
《AttributeError: ‘Booster’ object has no attribute ‘get_booster’ #6449》
《NameError: name ‘json’ is not defined》

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值