Tensorflow tf.image.decode_image 返回无 shape

本文介绍在TensorFlow中使用tf.data.Dataset读取图像的方法,针对不同格式图像(如png和jpg)的处理技巧,包括decode_jpeg、decode_png及resize_images函数的正确应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在做语义分割的时候会经过读取图像的步骤,根据 TensorFlow 官方教程 我使用了 tf.data.Dataset 这个 API。
根据官方读取图像的例子,一开始我的代码如下:

def load_image(filename, resized_shape):
	'''
	:param filename: 图像文件名
	:param resized_shape: 缩放后图像大小
	'''
    image = tf.read_file(filename)
    image = tf.image.decode_png(image)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize_images(
        image, size=resized_shape, method=tf.image.ResizeMethod.AREA)

    return image

因为训练图像是 png 格式,因此使用了 decode_png
这里还使用了 resize_images 函数,因为在进行小范围测试时将图像缩小可以加快训练速度。

至此都是没有问题的,但是最近处理数据集的时候遇到了 jpg 格式的训练图像,之前看到 TensorFlow 有 decode_image 这个函数,好像可以自动判定图像格式然后 decode。

但是使用了之后报错了,这个错误是在 resize_images 的时候发生的:

ValueError: 'images' contains no shape.

根据 decode_image 的官方文档:

Returns:

Tensor with type uint8 with shape [height, width, num_channels] for BMP
, JPEG, and PNG images and shape [num_frames, height, width, 3] for GIF images.

返回的 Tensor 是有形状的,但是从调试中可以看到 shape 是 unknown 的,所以返回应该是没有形状,google 了一下也没有发现能说清这个问题的,因此这个函数暂时用不了了。
在这里插入图片描述
不过我发现了一种解决办法:

def load_image(filename, resized_shape):
	'''
	:param filename: 图像文件名
	:param resized_shape: 缩放后图像大小
	'''
    image = tf.read_file(filename)
    
    image = tf.cond(
        tf.image.is_jpeg(image),
        lambda: tf.image.decode_jpeg(image),
        lambda: tf.image.decode_png(image))
        
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize_images(
        image, size=resized_shape, method=tf.image.ResizeMethod.AREA)

    return image

就是使用 tf.cond() 函数,通过 tf.image.is_jpeg(image) 判断图像是不是 jpg 格式,如果是,就执行 decode_jpeg,如果不是,就执行 decode_png。因为在语义分割中,大部分训练图像都是 jpgpng,很少会有其他格式的图像,因此用一个条件就够了。

其实 decode_image 函数里面就是使用的 tf.cond 来判断的,判断之后 decode,然后再 convert_image_dtype,至于为什么返回没有 shape 我也不清楚。
在这里插入图片描述

decode_image 函数片段

还有一种方法是使用 tf.Tensor.set_shape ,我的代码里不方便使用这个方法, 所以就没有尝试。

<think>好的,我现在要仔细分析用户提供的这段TensorFlow代码,并逐行解释其功能和方法。首先,我需要理解这段代码的整体结构,它是用于生成输入数据的生产者,可能用于训练一个深度学习模型,尤其是涉及图像处理的任务。 首先,函数定义是input_producer,接收batch_size参数,默认10。接下来内部定义了一个read_data函数,负责读取图像数据。这里用到了tf.read_file和tf.image.decode_image来读取和解码图片。路径拼接使用了tf.string_join,可能将目录和文件名组合起来。然后调用preprocessing函数对图像进行处理。 preprocessing函数将图像转换为float32并归一化到[0,1],根据模型类型决定是否转换为灰度图。然后是随机裁剪,使用tf.random_crop,将两张图片堆叠后裁剪,再拆分成单独的图片。这一步可能需要检查crop_size和chns是否正确设置。 接下来是with tf.variable_scope('input')块,里面将数据列表转换为张量List_all,分割成输入列表和真实值列表。然后使用tf.train.slice_input_producer创建输入队列,capacity是队列容量。之后调用read_data获取处理后的图像,再通过tf.train.batch生成批次数据。 需要注意TensorFlow 1.x和2.x的区别,因为代码中使用的是tf.variable_scope和slice_input_producer,这在TF2中可能不再推荐使用。用户可能需要调整代码以适应版本差异,比如使用tf.data API代替旧的队列方式。 此外,decode_image返回的是未知形状的张量,可能需要设置形状以确保后续操作的正确性。预处理中的随机裁剪是否正确应用了相同的裁剪区域给两张图片,这对某些任务(如图像到图像的转换)很重要,确保对应区域一致。 最后,函数返回的是批处理后的输入和真实值张量,供训练使用。需要指出潜在的问题和改进建议,比如形状问题、TF版本兼容性,以及可能的预处理步骤调整。</think>以下是对代码的逐行解析及方法详解,按照执行流程分步说明: --- ### 一、函数定义 ```python def input_producer(self, batch_size=10): ``` - **作用**:定义数据输入管道,生成批量训练数据 - **参数**:`batch_size=10` 指定每个批次包含10个样本 --- ### 二、数据读取函数 ```python def read_data(): img_a = tf.image.decode_image(tf.read_file(tf.string_join(['./training_set/', self.data_queue[0]])), channels=3) img_b = tf.image.decode_image(tf.read_file(tf.string_join(['./training_set/', self.data_queue[1]])), channels=3) img_a, img_b = preprocessing([img_a, img_b]) return img_a, img_b ``` 1. `tf.string_join()`:拼接文件路径(例:`./training_set/filename.jpg`) 2. `tf.read_file()`:读取图片二进制数据 3. `tf.image.decode_image()`:解码JPEG/PNG图片为uint8张量 - `channels=3` 保留RGB三通道 4. `preprocessing()`:对两张图片进行联合预处理 --- ### 三、预处理函数 ```python def preprocessing(imgs): imgs = [tf.cast(img, tf.float32) / 255.0 for img in imgs] # 归一化到[0,1] if self.args.model != 'color': imgs = [tf.image.rgb_to_grayscale(img) for img in imgs] # 转为灰度图 img_crop = tf.unstack( tf.random_crop( tf.stack(imgs, axis=0), [2, self.crop_size, self.crop_size, self.chns] ), axis=0 ) return img_crop ``` 1. **归一化**:`tf.cast`转换数据类型后除以255 2. **灰度转换**:根据模型类型决定是否转为单通道 3. **联合随机裁剪**: - `tf.stack(imgs, axis=0)`:将两张图片沿第0维度堆叠(形状变为$[2, H, W, C]$) - `tf.random_crop()`:执行相同区域的随机裁剪 - `tf.unstack(axis=0)`:将堆叠后的张量拆分为两个独立张量 --- ### 四、输入管道构建 ```python with tf.variable_scope('input'): List_all = tf.convert_to_tensor(self.data_list, dtype=tf.string) gt_list = List_all[:, 0] # 真实值文件名列表 in_list = List_all[:, 1] # 输入文件名列表 self.data_queue = tf.train.slice_input_producer( [in_list, gt_list], capacity=20 ) image_in, image_gt = read_data() batch_in, batch_gt = tf.train.batch( [image_in, image_gt], batch_size=batch_size, num_threads=8, capacity=20 ) ``` 1. **数据列表处理**: - `self.data_list` 应为二维列表,每行包含一对文件名 `[gt_file, in_file]` 2. **队列创建**: - `slice_input_producer`:创建文件名队列,实现乱序读取 - `capacity=20`:队列最大容量 3. **批次生成**: - `tf.train.batch`:将预处理后的数据打包成批次 - `num_threads=8`:使用8个线程加速数据准备 --- ### 五、关键方法详解 1. **`tf.train.slice_input_producer`** - 功能:创建文件名输入队列,支持自动重排(shuffling) - 参数:`capacity` 控制预加载样本数量 2. **`tf.random_crop`** - 数学表达:从尺寸为$H \times W$的原图中随机裁剪子区域 - 输出尺寸:$[\text{crop\_size} \times \text{crop\_size} \times \text{chns}]$ 3. **`tf.train.batch`** - 队列机制:维持一个队列,当样本数达到`batch_size`时出队 - 动态填充:当队列元素不足时自动等待 --- ### 六、潜在问题及改进 1. **形状不确定性**:`decode_image`返回shape可能为`None`,建议添加`set_shape` 2. **TensorFlow版本**:此代码基于TF1.x设计,TF2.x需使用`tf.data` API重构 3. **数据增强**:可添加翻转、旋转等增强操作 代码整体实现了图像对的动态读取、联合预处理和批量生成,适用于图像翻译任务(如Pix2Pix)。
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值