最近在看yolov3 的源码,在看yolo_loss的时候遇到了一个卡点,就是将真实标注的box中点坐标转换
到anchor点的坐标
true_xy = true_xy * tf.cast(grid_size, tf.float32) - tf.cast(grid, tf.float32)
raw_true_xy = y_true[l][..., :2] * grid_shapes[l][:] - grid
import tensorflow as tf
import numpy as np
# Demo: convert a normalised ground-truth box centre into an offset relative
# to the grid cell that contains it (the `true_xy` step of the YOLOv3 loss).

# The feature map is divided into 13 x 13 cells.
grid_size = 13

# tf.meshgrid returns (X, Y); stacking them on the last axis makes
# grid[row, col] == (col, row), i.e. the (x, y) index of each cell.
# expand_dims inserts a length-1 axis at position 2 so the (13, 13, 1, 2)
# grid broadcasts against the (8, 13, 13, 3, 2) box-centre tensor below.
grid = tf.meshgrid(tf.range(grid_size), tf.range(grid_size))
grid = tf.expand_dims(tf.stack(grid, axis=-1), axis=2)

# batch_size=8, 13 x 13 cells, 3 boxes per cell for this anchor scale,
# and a 2-D centre coordinate (x, y) per box.
T_xy = np.zeros([8, 13, 13, 3, 2])

# Index all four leading axes: sample 0, row 6, col 4, anchor 2.
# (0.309, 0.46212122) * 13 ~= (4.017, 6.008), so the centre lies in cell
# x=4, y=6, which is grid[6, 4]; the resulting offset is ~(0.017, 0.008).
# Indexing with only three subscripts would address sample 6, cell
# (row 4, col 2) and overwrite all three anchors with an offset outside [0, 1).
T_xy[0, 6, 4, 2] = [0.309, 0.46212122]
T_xy = tf.constant(T_xy, dtype=tf.float32)

# Scale the normalised centre to grid units, then subtract each cell's own
# (x, y) index to get the offset inside that cell.
true_xy = T_xy * tf.cast(grid_size, tf.float32) - tf.cast(grid, tf.float32)
print(true_xy)
[[[[ 0. 0.]
[ 0. 0.]
[ 0. 0.]]
[[ -1. 0.]
[ -1. 0.]
[ -1. 0.]]
[[ -2. 0.]
[ -2. 0.]
[ -2. 0.]]
...
...
[[-10. -12.]
[-10. -12.]
[-10. -12.]]
[[-11. -12.]
[-11. -12.]
[-11. -12.]]
[[-12. -12.]
[-12. -12.]
[-12. -12.]]]]], shape=(8, 13, 13, 3, 2), dtype=float32)
xy_loss = obj_mask * box_loss_scale * \
tf.reduce_sum(tf.square(true_xy - pred_xy), axis=-1)
# Objectness mask: every entry is zero except a single 1 at index (0, 6, 4, 2).
# Presumably the axes are (sample, grid row, grid col, anchor), matching the
# box-centre tensor used for true_xy -- confirm against the loss code.
obj_mask = np.zeros((8, 13, 13, 3))
obj_mask[0, 6, 4, 2] = 1
# Convert to a float32 tensor so it can multiply the per-box loss terms.
obj_mask = tf.constant(obj_mask, dtype=tf.float32)