TensorFlow 单元测试的一个简单例子

本文介绍了如何利用TensorFlow的tf.test.TestCase进行单元测试,包括assertAllEqual方法和session运行计算图节点的功能。通过示例展示测试全连接层和张量扩张塑形函数的测试用例,验证测试结果成功。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

TensorFlow中的tf.test.TestCase类继承了unittest.TestCase类,用于对tensorflow代码进行单元测试。

tf.test.TestCase 提供了 assertAllEqual 用于判断两个numpy array具有完全相同的值,session方法来运行计算图结点,以及其他方法,具体请看链接

现在我们有如下的两个函数:

# Python3
import tensorflow as tf


def dense_layer(x, W, bias, activation=None):
  y = x @ W + bias
  if activation:
    return activation(y)
  else:
    return y


def expand_reshape_tensor(x, high, width):
  return tf.reshape(x, (high, width, 1, 1))

第一个函数就是一个全连接层,第二个函数用于对张量进行扩张塑形操作。

接下来我们创建UtilsTests类,继承tf.test.TestCase类,定义test_dense_layer方法对第一个函数进行测试,定义test_expand_reshape_tensor方法对第二个函数进行测试。

import tensorflow as tf
import utils


class UtilsTests(tf.test.TestCase):

  def test_dense_layer(self):
    x = tf.reshape(tf.range(27), (9, 3)) - 13
    W = tf.reshape(tf.range(9), (3, 3)) - 4
    bias = tf.range(3) - 1
    y1 = utils.dense_layer(x, W, bias)
    y2 = utils.dense_layer(x, W, bias, tf.nn.relu)
    with self.session() as sess:
      y1, y2 = sess.run((y1, y2))

    self.assertAllEqual(
      [[-13, -12, -11],
       [-10,  -9,  -8],
       [ -7,  -6,  -5],
       [ -4,  -3,  -2],
       [ -1,   0,   1],
       [  2,   3,   4],
       [  5,   6,   7],
       [  8,   9,  10],
       [ 11,  12,  13]], x) # 对x的值进行验证

    self.assertAllEqual(
      [[-4, -3, -2],
       [-1,  0,  1],
       [ 2,  3,  4]], W) # 对W的值进行验证

    self.assertAllEqual([-1,  0,  1], bias) # 验证bias的值

    self.assertAllEqual(
      [[ 41,   6, -29],
       [ 32,   6, -20],
       [ 23,   6, -11],
       [ 14,   6,  -2],
       [  5,   6,   7],
       [ -4,   6,  16],
       [-13,   6,  25],
       [-22,   6,  34],
       [-31,   6,  43]], y1) # 验证无激活函数的情况
    self.assertAllEqual(
      [[41,  6,  0],
       [32,  6,  0],
       [23,  6,  0],
       [14,  6,  0],
       [ 5,  6,  7],
       [ 0,  6, 16],
       [ 0,  6, 25],
       [ 0,  6, 34],
       [ 0,  6, 43]], y2) # 验证有激活函数的情况

  def test_expand_reshape_tensor(self):
    x = tf.range(9)
    y = utils.expand_reshape_tensor(x, 3, 3)
    shape = tf.shape(y)
    with self.session() as sess:
      shape = sess.run(shape)
    self.assertAllEqual(shape, (3, 3, 1, 1)) # 验证是否塑形成功


if __name__ == "__main__":
  tf.test.main()  # 运行测试样例

接下来,我们得到这样的结果,测试了三个函数,其中一个跳过,也就是测试了两个函数。测试成功。

----------------------------------------------------------------------
Ran 3 tests in 1.167s

OK (skipped=1)
[Finished in 4.6s]

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值