前面5篇,上来就以实践的方式学了MNIST手写数字识别的代码,发现在深入MNIST之前,还需要强化一下基本功,就是补一下之前官方说的,TensorFlow计算的优点在哪里这个坑。也是从最基本了解TensorFlow的运作方式。
TensorFlow,TensorFlow。既然Google起了这么一个名字,总是有原因的。 先从Tensor说起。
引用官方一句话:TensorFlow 程序使用 tensor 数据结构来代表所有的数据, 计算图中, 操作间传递的数据都是 tensor。
一目了然,豁然开朗。那么什么是tensor呢?翻译来说就是张量,百度一波张量看完内心是崩溃的。于是我按照矩阵来理解,就是TensorFlow中一个简单的常数都是一个1×1的矩阵。(有说张量是几何定义,矩阵是数学定义====>才疏学浅,在当前不影响学习进度情况下我选择死亡忽视,也看到说0维的叫做常量;一维的叫做向量;二维的叫做矩阵;≥3维的叫做张量。 嗯,我想吃张亮麻辣烫了,(¯﹃¯))
官方文档介绍了一波TensorFlow的运作方式,看完怎么说呢……云里雾里。。。还是老办法,结合代码来说:
首先,TensorFlow在编程时分为两个阶段,构建阶段和执行阶段。
Section one:构建阶段
在构建阶段,我们把需要用到的数据(矩阵)都一个个摆好,代码如下:
import tensorflow as tf # 这一步没的说,要编写TensorFlow的程序,就要先调动相关的包
matrix1 = tf.constant([[3., 3.]]) # 这里,就是构建我们需要的数据了(以下三句都是)
matrix2 = tf.constant([[2.],[2.]])
product = tf.matmul(matrix1, matrix2)
需要说明的是,在TensorFlow中,我们构建的数据叫做一个节点(op),节点包括输入、输出和计算方式。emmm。。。我是按照机电传动的理论来理解这个节点的概念,想象成一个神经元也是可以的((* ̄︶ ̄))。
因此,上述matrix1、2与product分别是三个节点。
你可能会问,不是说需要输入吗?为什么matrix1、2没有?(produc中matrix1、2为输入内容)因为matrix为一个源节点,它们定义了两个常数的矩阵。那么tf.constant构造两个数组究竟长什么样呢?
我们在上述语句后面加入
sess = tf.Session()
print(sess.run(matrix1))
print(sess.run(matrix2))
【先忽视掉sess这句】编译执行,就发现输出结果为:
一个一行两列元素都为3的矩阵和一个两行一列元素都为2的矩阵。
【当然,如果用
,这里的shape类型我猜就是一行二列、二行一列,我的猜功是不是很强 (✺ω✺)】
而tf.matmul就比较简单了,是将输入的两个矩阵相乘的一个节点。
完成了构建阶段,我们进入执行阶段。
Section two: 执行阶段
TensorFlow的执行需在会话(Session)中进行(我也不知道为啥,还是从简,我将会话理解为一个块(block),一段代码在这块中运行,这个块中产生的结果对另一个块无效)。
于是我们先启动对话:
sess = tf.Session()
随后,引用官方文档,说的比较明白
# 调用 sess 的 'run()' 方法来执行矩阵乘法 op, 传入 'product' 作为该方法的参数.
# 上面提到, 'product' 代表了矩阵乘法 op 的输出, 传入它是向方法表明, 我们希望取回
# 矩阵乘法 op 的输出.
#
# 整个执行过程是自动化的, 会话负责传递 op 所需的全部输入. op 通常是并发执行的.
#
# 函数调用 'run(product)' 触发了图中三个 op (两个常量 op 和一个矩阵乘法 op) 的执行.
#
# 返回值 'result' 是一个 numpy `ndarray` 对象.
result = sess.run(product)
print result
# ==> [[ 12.]]
# 任务完成, 关闭会话.
sess.close()
至此,程序结束,算出来是[12.]这个结果,这里也是TensorFlow与众不同白里透红的地方,它的结果也还是个矩阵。
末尾,再次附上复制就可以运行的程序:
import tensorflow as tf
matrix1 = tf.constant([[3., 3.]])
matrix2 = tf.constant([[2.],[2.]])
product = tf.matmul(matrix1, matrix2)
sess = tf.Session()
print(sess.run(matrix1))
print(sess.run(matrix2))
print(sess.run([product]))
result = sess.run(product)
print(result)