TensorFlow分布式训练、调试、TPU应用及计算机视觉入门
立即解锁
发布时间: 2025-08-30 00:50:15 阅读量: 15 订阅数: 26 AIGC 


Python构建AI高级指南
### TensorFlow 分布式训练、调试、TPU 应用及计算机视觉入门
#### 1. TensorFlow 分布式模型
在 TensorFlow 集群中进行分布式训练时,监督对象(supervisor object)的创建有所不同,有两个额外的初始化函数。以下是同步更新时监督对象的初始化代码:
```python
# SYNC: sv is initialized differently for sync update
sv = tf.train.Supervisor(is_chief=is_chief,
init_op = tf.global_variables_initializer(),
local_init_op = local_init_op,
ready_for_local_init_op = optimizer.ready_for_local_init_op,
global_step=global_step)
```
在训练的会话块中,如果是主工作任务,需要初始化同步变量并启动队列运行器:
```python
# SYNC: if block added to make it sync update
if is_chief:
mts.run(init_token_op)
sv.start_queue_runners(mts, [chief_queue_runner])
```
其余代码与异步更新相同。TensorFlow 支持分布式训练的库和函数在不断发展,使用时需留意新功能的添加或函数签名的更改。
#### 2. TensorFlow 模型调试
在构建和训练 TensorFlow 模型时,可能会遇到各种错误或模型表现不如预期的情况,例如损失和指标输出出现 NaN,或者损失在多次迭代后仍无改善。以下是 TensorFlow 提供的调试工具和技术:
- **使用 `tf.Session.run()` 获取张量值**:可以使用 `tf.Session.run()` 获取想要打印的张量值,这些值以 NumPy 数组的形式返回,可使用 Python 语句打印或记录。但该方法的最大缺点是计算图会执行从获取的张量开始的所有依赖路径,如果这些路径包含训练操作,会使训练前进一步或一个 epoch。大多数情况下,会执行整个图并获取所有需要调试和不需要调试的张量。`tf.Session.partial_run()` 可用于执行部分图,但它是一个高度实验性的 API,不适合生产使用。
- **使用 `tf.Print()` 打印张量值**:可以将张量包装在 `tf.Print()` 中,当包含 `tf.Print()` 节点的路径执行时,会在标准错误控制台打印其值。`tf.Print()` 函数的签名如下:
```python
tf.Print(
input_,
data,
message=None,
first_n=None,
summarize=None,
name=None
)
```
各参数含义如下:
| 参数 | 含义 |
| ---- | ---- |
| `input_` | 函数返回的张量,不做任何处理 |
| `data` | 要打印的张量列表 |
| `message` | 打印输出的前缀字符串 |
| `first_n` | 打印输出的步数;如果为负数,则只要路径执行就始终打印 |
| `summarize` | 从张量中打印的元素数量;默认只打印三个元素 |
以下是修改 MNIST MLP 模型添加打印语句的示例:
```python
model = tf.Print(input_=model,
data=[tf.argmax(model,1)],
message='y_hat=',
summarize=10,
first_n=5
)
```
运行代码后,在 Jupyter 控制台会输出类似如下内容:
```
I tensorflow/core/kernels/logging_ops.cc:79] y_hat=[0 0 0 7 0 0 0 0 0 0...]
I tensorflow/core/kernels/logging_ops.cc:79] y_hat=[0 7 7 1 8 7 2 7 7 0...]
I tensorflow/core/kernels/logging_ops.cc:79] y_hat=[4 8 0 6 1 8 1 0 7 0...]
I tensorflow/core/kernels/logging_ops.cc:79] y_hat=[0 0 1 0 0 0 0 5 7 5...]
I tensorflow/core/kernels/logging_ops.cc:79] y_hat=[9 2 2 8 8 6 6 1 7 7...]
```
`tf.Print()` 的唯一缺点是格式化功能有限。
- **使用 `tf.Assert()` 进行条件断言**:`tf.Assert()` 函数接受一个条件,如果条件为假,则打印给定的张量列表并抛出 `tf.errors.InvalidArgumentError`。其签名如下:
```python
tf.Assert(
condition,
data,
summarize=None,
name=None
)
```
断言操作不像 `tf.Print()` 函数那样位于图的路径中,为确保 `tf.Assert()` 操作执行,需要将其添加到依赖项中。例如,定义一个断言来检查所有输入是否为正数:
```python
assert_op = tf.Assert(tf.reduce_all(tf.greater_equal(x,0)),[x])
```
在定义模型时将 `assert_op` 添加到依赖项中:
```python
with tf.control_dependencies([assert_op]):
# x is input layer
layer = x
# add hidden layers
for i in range(num_layers):
layer = tf.nn.relu(tf.matmul(layer, w[i]) + b[i])
# add output layer
layer = tf.matmul(layer, w[num_layers]) + b[num_layers]
```
为测试代码,在第 5 个 epoch 后引入杂质:
```python
if epoch > 5:
X_batch = np.copy(X_batch)
X_batch[0,0]=-2
```
代码在前五个 epoch 运行正常,之后会抛出错误:
```
epoch: 0000 loss = 6.975991
epoch: 0001 loss = 2.246228
epoch: 0002 loss = 1.924571
epoch: 0003 loss = 1.745509
epoch: 0004 loss = 1.616791
epoch: 0005 loss = 1.520804
--------------------------------------------------------------
---
InvalidArgumentError Traceback (most recent call last)
...
InvalidArgumentError: assertion failed: [[-2 0 0]...]
...
```
除了 `tf.Assert()` 函数,TensorFlow 还提供了一些检查特定条件且语法简单的断言操作,如 `assert_equal`、`assert_greater` 等。
- **使用 TensorFlow 调试器(tfdbg)**:TensorFlow 调试器(tfdbg)的工作原理与其他流行的调试器(如 pdb 和 gdb)类似。使用调试器的一般过程如下:
1. 在代码中想要中断并检查变量的位置设置断点。
2. 以调试模式运行代码。
3. 当代码在断点处中断时,检查并继续下一步。
使用 tfdbg 的步骤如下:
1. 导入所需模块并将会话包装在调试器包装器中:
```python
from tensorflow.python import debug as tfd
with tfd.LocalCLIDebugWrapperSession(tf.Session()) as tfs:
```
2. 为会话对象附加一个过滤器,这相当于在其他调
0
0
复制全文
相关推荐









