《统计学习方法》笔记——k近邻(二):kd树

本文介绍了kd树在k近邻算法中的应用,通过构建平衡kd树并演示Python实现,展示了如何通过结构化存储减少计算距离次数,以优化大规模数据的近邻查找。重点讲解了kd树的构造过程和搜索算法,包括如何在kd树中找到目标点的最近邻点。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

上篇记录了k近邻算法原理及线性扫描实现。当样本量比较大时,计算比较耗时,为提高k近邻的搜索的效率,可以考虑使用特殊的结构存储训练数据,以减少计算距离的次数。方法之一就是kd树。
kd树是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。kd树是二叉树,表示对k维空间的一个划分。

构造平衡kd树的算法

输入:k维空间数据集T={x1,x2,...,xN}T = \{x_1, x_2,...,x_N\}T={x1,x2,...,xN},其中xi=(xi(1),xi(2),...,xi(k))Tx_i = (x_i^{(1)}, x_i^{(2)},...,x_i^{(k)})^Txi=(xi(1),xi(2),...,xi(k))Ti=1,2,...,Ni = 1, 2, ..., Ni=1,2,...,N
输出:kd树。
(1)开始:构造根结点,根结点对应于包含TTTkkk维空间的超矩形区域。
   选择x(1)x^{(1)}x(1)为坐标轴,以TTT中所有实例的x(1)x^{(1)}x(1)坐标的中位数为切分点,将根结点对应的超矩形区域切分为两个子区域。切分由通过切分点并与坐标轴x(1)x^{(1)}x(1)垂直的超平面实现。
   由根结点生成深度为1的左右子结点:左子结点对应坐标x(1)x^{(1)}x(1)小于切分点的子区域,右子结点对应于坐标x(1)x^{(1)}x(1)大于切分点的子区域。
   将落在切分超平面上的实例点保存在根结点。
(2)重复:对深度为jjj的结点,选择x(l)x^{(l)}x(l)为切点的坐标轴,l=j(mod k)+1l = j(mod\ k) + 1l=j(mod k)+1,以该结点的区域中所有实例的x(l)x^{(l)}x(l)坐标的中位数为切分点,将该结点对应的超矩形区域切分为两个子区域。切分由通过切分点并与坐标轴x(l)x^{(l)}x(l)垂直的超平面实现。
   由该结点生成深度为j+1j + 1j+1的左、右子结点:左子结点对应坐标x(l)x^{(l)}x(l)小于切分点的子区域,右子结点对应坐标x(l)x^{(l)}x(l)大于切分点的子区域。
   将落在切分超平面上的实例点保存在该结点。
(3)直到两个子区域没有实例存在时停止。从而形成kd树的区域划分。

构造kd树Python实现

def createTree(dataset, layer=0, featureNume=2, attribute='root'):   
    # 拷贝数据
    datasetCopy = dataset[:]
    length = len(datasetCopy)
    # 选定某个特征
    feature = layer % featureNume
    layer += 1
    if length == 0:
        return None
    elif length == 1:
        return {'value':datasetCopy[0], 'right':None, 'left':None, 'layer':layer, 'feature':feature,'attribute':attribute}
    else:          
        # 根据某个特征排序
        datasetCopy.sort(key = lambda x: x[feature])
        mid = length // 2      
        return {'value':datasetCopy[mid], 'right':createTree(datasetCopy[mid+1:],layer,attribute='right'), 'left':createTree(datasetCopy[:mid],layer,attribute='left'), 'layer':layer, 'feature':feature,'attribute':attribute}

为了展示方便,我用字典存储这棵树,也可以自己构造一个类用来存储这棵树:

class Node(object):
    def __init__(self, value, left=None, right=None, layer=0, feature=0,attribute='root'):
        self.value = value
        self.left = left
        self.right = right
        self.layer = layer
        self.feature = feature
        self.attribute = attribute

我在这里统一用字典来处理了。
用书上p54页例3.2的例子来构造一个kd树。
在这里插入图片描述

将这棵树画出来:

from graphviz import Digraph
class DrawTree(object):
    def __init__(self,tree):
        self.num = 1
        self.tree = tree
        
    def drawTree(self, outfile='kd Tree'):
        dot = Digraph(name="kd Tree", comment="kd Tree", format="png")
        if self.tree:
            dot.node(name=str(self.num), label=str(self.tree['value']))

            def draw(node, dot):
                p = str(self.num)
                if node is None:
                    return dot
                else:
                    nodeLeft = node['left']
                    nodeRight = node['right']
                    if nodeLeft:
                        dot.node(name=str(self.num + 1), label=str(nodeLeft['value']))
                        dot.edge(p, str(self.num+1), label=str(node['feature']))
                        self.num += 1
                        draw(nodeLeft, dot)
                    if nodeRight:
                        dot.node(name=str(self.num + 1), label=str(nodeRight['value']))
                        dot.edge(p, str(self.num+1), label=str(node['feature']))
                        self.num += 1
                        draw(nodeRight, dot)                    
            draw(self.tree, dot)
        dot.view(filename=outfile)

在这里插入图片描述

用kd树的最近邻搜索算法

输入:已构造的kd树,目标点x;
输出:x的最近邻。
(1)在kd树中找出包含目标点x的叶结点:从根结点出发,递归地向下访问kd树。若目标点x当前维的坐标小于切分点的坐标,则移动到左子结点,否则移动到右子结点。直到子结点为叶结点为止。
(2)以此叶结点为"当前最近点"。
(3)递归地向上回退,在每个结点进行以下操作:
 (a)如果该结点保存的实例点比当前最近点距离目标点更近,则以该实例点为"当前最近点"。
 (b)当前最近点一定存在于该结点一个子结点对应的区域。检查该子结点的父结点的另一子结点对应的区域是否有更近的点。
 具体地,检查另一子结点对应的区域是否与以目标点为球心、以目标点与“当前最近点”间的距离为半径的超球体相交。
 如果相交,可能在另一个子结点对应的区域内存在距目标点更近的点,移动到另一子结点。接着,递归地进行最近邻搜索;
 如果不相交,向上回退。
(4)当回退到根结点时,搜索结束。最后的"当前最近点"即为x的最近邻点。

kd树最近邻搜索Python实现

首先实现点与点之间的距离

import numpy as np
def point_distance(x, y, p='inf'):
    if isinstance(x, list):
        x = np.array(x)
    if isinstance(y, list):
        y = np.array(y)
    if p == 'inf':
        return np.max(np.abs(x-y))   
    assert isinstance(p, int)
    assert len(x.shape) == len(y.shape)
    assert len(x.shape) ==1    
    return np.power(np.sum(np.power(np.abs(x-y), p)), 1.0/p)

然后实现最近邻搜索

def findNearest(tree, point): 
    nodes = []
    attributes = []
    while tree:
        feature = tree['feature']
        value = tree['value']
        featureValue = point[feature]
        nodes.append(tree)
        if featureValue > value[feature]:
            tree = tree['right']
        else:
            tree = tree['left'] 
    nearlistNode = nodes.pop()
    nearlistPoint = nearlistNode['value']
    nearlistDistance = point_distance(nearlistPoint, point, 2)
    nodeAttribute = nearlistNode['attribute']
    print("最近节点:{0},最近距离:{1}".format(nearlistPoint,nearlistDistance))
    while nodes:
        # a如果该结点保存的实例点比当前最近点距离目标点更近,则以该实例点为"当前最近点"
        thisNode = nodes.pop()
        thisPoint = thisNode['value']
        thisFeature = thisNode['feature']
        print("当前节点:{}".format(thisPoint))
        thisDistance = point_distance(thisPoint, point, 2)
        if thisDistance < nearlistDistance:
            nearlistDistance = thisDistance
            nearlistPoint = thisPoint
            print("更新最近节点:{0},最近距离:{1}".format(nearlistPoint,nearlistDistance))
            if thisNode['left'] or thisNode['right']:
                # b
                facePoint = point.copy()
                facePoint[thisFeature] = thisPoint[thisFeature]
                faceDistance = point_distance(facePoint, point, 2)
                print("点到超平面的距离:{}".format(faceDistance))             
                # 如果相交,进入另一子结点
                if faceDistance < nearlistDistance:
                    nodes.append(thisNode)
                    if nodeAttribute == 'left':
                        nodes.append(thisNode['right'])
                    elif nodeAttribute == 'right':
                        nodes.append(thisNode['left'])
                # 如果不相交,继续回退
                else:
                    nodeAttribute = thisNode['attribute']
        else:
            print("继续回退,最近节点:{0},最近距离:{1}".format(nearlistPoint,nearlistDistance))
    return nearlistPoint, nearlistDistance       

在这里插入图片描述
在以上kd树中搜索[2,4.5]的最近邻点是[2,3],最近距离是1.5。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值