上篇记录了k近邻算法原理及线性扫描实现。当样本量比较大时,计算比较耗时,为提高k近邻的搜索的效率,可以考虑使用特殊的结构存储训练数据,以减少计算距离的次数。方法之一就是kd树。
kd树是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。kd树是二叉树,表示对k维空间的一个划分。
构造平衡kd树的算法
输入:k维空间数据集T={x1,x2,...,xN}T = \{x_1, x_2,...,x_N\}T={x1,x2,...,xN},其中xi=(xi(1),xi(2),...,xi(k))Tx_i = (x_i^{(1)}, x_i^{(2)},...,x_i^{(k)})^Txi=(xi(1),xi(2),...,xi(k))T,i=1,2,...,Ni = 1, 2, ..., Ni=1,2,...,N
输出:kd树。
(1)开始:构造根结点,根结点对应于包含TTT的kkk维空间的超矩形区域。
选择x(1)x^{(1)}x(1)为坐标轴,以TTT中所有实例的x(1)x^{(1)}x(1)坐标的中位数为切分点,将根结点对应的超矩形区域切分为两个子区域。切分由通过切分点并与坐标轴x(1)x^{(1)}x(1)垂直的超平面实现。
由根结点生成深度为1的左右子结点:左子结点对应坐标x(1)x^{(1)}x(1)小于切分点的子区域,右子结点对应于坐标x(1)x^{(1)}x(1)大于切分点的子区域。
将落在切分超平面上的实例点保存在根结点。
(2)重复:对深度为jjj的结点,选择x(l)x^{(l)}x(l)为切点的坐标轴,l=j(mod k)+1l = j(mod\ k) + 1l=j(mod k)+1,以该结点的区域中所有实例的x(l)x^{(l)}x(l)坐标的中位数为切分点,将该结点对应的超矩形区域切分为两个子区域。切分由通过切分点并与坐标轴x(l)x^{(l)}x(l)垂直的超平面实现。
由该结点生成深度为j+1j + 1j+1的左、右子结点:左子结点对应坐标x(l)x^{(l)}x(l)小于切分点的子区域,右子结点对应坐标x(l)x^{(l)}x(l)大于切分点的子区域。
将落在切分超平面上的实例点保存在该结点。
(3)直到两个子区域没有实例存在时停止。从而形成kd树的区域划分。
构造kd树Python实现
def createTree(dataset, layer=0, featureNume=2, attribute='root'):
# 拷贝数据
datasetCopy = dataset[:]
length = len(datasetCopy)
# 选定某个特征
feature = layer % featureNume
layer += 1
if length == 0:
return None
elif length == 1:
return {'value':datasetCopy[0], 'right':None, 'left':None, 'layer':layer, 'feature':feature,'attribute':attribute}
else:
# 根据某个特征排序
datasetCopy.sort(key = lambda x: x[feature])
mid = length // 2
return {'value':datasetCopy[mid], 'right':createTree(datasetCopy[mid+1:],layer,attribute='right'), 'left':createTree(datasetCopy[:mid],layer,attribute='left'), 'layer':layer, 'feature':feature,'attribute':attribute}
为了展示方便,我用字典存储这棵树,也可以自己构造一个类用来存储这棵树:
class Node(object):
def __init__(self, value, left=None, right=None, layer=0, feature=0,attribute='root'):
self.value = value
self.left = left
self.right = right
self.layer = layer
self.feature = feature
self.attribute = attribute
我在这里统一用字典来处理了。
用书上p54页例3.2的例子来构造一个kd树。
将这棵树画出来:
from graphviz import Digraph
class DrawTree(object):
def __init__(self,tree):
self.num = 1
self.tree = tree
def drawTree(self, outfile='kd Tree'):
dot = Digraph(name="kd Tree", comment="kd Tree", format="png")
if self.tree:
dot.node(name=str(self.num), label=str(self.tree['value']))
def draw(node, dot):
p = str(self.num)
if node is None:
return dot
else:
nodeLeft = node['left']
nodeRight = node['right']
if nodeLeft:
dot.node(name=str(self.num + 1), label=str(nodeLeft['value']))
dot.edge(p, str(self.num+1), label=str(node['feature']))
self.num += 1
draw(nodeLeft, dot)
if nodeRight:
dot.node(name=str(self.num + 1), label=str(nodeRight['value']))
dot.edge(p, str(self.num+1), label=str(node['feature']))
self.num += 1
draw(nodeRight, dot)
draw(self.tree, dot)
dot.view(filename=outfile)
用kd树的最近邻搜索算法
输入:已构造的kd树,目标点x;
输出:x的最近邻。
(1)在kd树中找出包含目标点x的叶结点:从根结点出发,递归地向下访问kd树。若目标点x当前维的坐标小于切分点的坐标,则移动到左子结点,否则移动到右子结点。直到子结点为叶结点为止。
(2)以此叶结点为"当前最近点"。
(3)递归地向上回退,在每个结点进行以下操作:
(a)如果该结点保存的实例点比当前最近点距离目标点更近,则以该实例点为"当前最近点"。
(b)当前最近点一定存在于该结点一个子结点对应的区域。检查该子结点的父结点的另一子结点对应的区域是否有更近的点。
具体地,检查另一子结点对应的区域是否与以目标点为球心、以目标点与“当前最近点”间的距离为半径的超球体相交。
如果相交,可能在另一个子结点对应的区域内存在距目标点更近的点,移动到另一子结点。接着,递归地进行最近邻搜索;
如果不相交,向上回退。
(4)当回退到根结点时,搜索结束。最后的"当前最近点"即为x的最近邻点。
kd树最近邻搜索Python实现
首先实现点与点之间的距离
import numpy as np
def point_distance(x, y, p='inf'):
if isinstance(x, list):
x = np.array(x)
if isinstance(y, list):
y = np.array(y)
if p == 'inf':
return np.max(np.abs(x-y))
assert isinstance(p, int)
assert len(x.shape) == len(y.shape)
assert len(x.shape) ==1
return np.power(np.sum(np.power(np.abs(x-y), p)), 1.0/p)
然后实现最近邻搜索
def findNearest(tree, point):
nodes = []
attributes = []
while tree:
feature = tree['feature']
value = tree['value']
featureValue = point[feature]
nodes.append(tree)
if featureValue > value[feature]:
tree = tree['right']
else:
tree = tree['left']
nearlistNode = nodes.pop()
nearlistPoint = nearlistNode['value']
nearlistDistance = point_distance(nearlistPoint, point, 2)
nodeAttribute = nearlistNode['attribute']
print("最近节点:{0},最近距离:{1}".format(nearlistPoint,nearlistDistance))
while nodes:
# a如果该结点保存的实例点比当前最近点距离目标点更近,则以该实例点为"当前最近点"
thisNode = nodes.pop()
thisPoint = thisNode['value']
thisFeature = thisNode['feature']
print("当前节点:{}".format(thisPoint))
thisDistance = point_distance(thisPoint, point, 2)
if thisDistance < nearlistDistance:
nearlistDistance = thisDistance
nearlistPoint = thisPoint
print("更新最近节点:{0},最近距离:{1}".format(nearlistPoint,nearlistDistance))
if thisNode['left'] or thisNode['right']:
# b
facePoint = point.copy()
facePoint[thisFeature] = thisPoint[thisFeature]
faceDistance = point_distance(facePoint, point, 2)
print("点到超平面的距离:{}".format(faceDistance))
# 如果相交,进入另一子结点
if faceDistance < nearlistDistance:
nodes.append(thisNode)
if nodeAttribute == 'left':
nodes.append(thisNode['right'])
elif nodeAttribute == 'right':
nodes.append(thisNode['left'])
# 如果不相交,继续回退
else:
nodeAttribute = thisNode['attribute']
else:
print("继续回退,最近节点:{0},最近距离:{1}".format(nearlistPoint,nearlistDistance))
return nearlistPoint, nearlistDistance
在以上kd树中搜索[2,4.5]的最近邻点是[2,3],最近距离是1.5。