图像分类:从迁移学习到实用技巧
立即解锁
发布时间: 2025-09-01 01:57:36 阅读量: 2 订阅数: 11 AIGC 

### 图像分类:从迁移学习到实用技巧
#### 1. 年龄与性别预测及 `torch_snippets` 库
在图像分类任务中,年龄和性别预测是一个有趣的应用场景。之前的代码虽然能实现一次性对年龄和性别进行预测,但存在不稳定性,年龄值会因图像的不同方向和光照条件而有较大变化。在这种情况下,数据增强就显得尤为重要。此外,我们还可以通过只提取面部区域来训练模型,避免背景信息对年龄和性别的计算产生干扰。
在学习迁移学习、预训练架构以及如何在不同用例中应用它们的过程中,我们发现代码往往较长,需要手动导入大量包、创建空列表来记录指标,还需不断读取和显示图像进行调试。为了解决这个问题,`torch_snippets` 库应运而生。
以下是使用 `torch_snippets` 库进行年龄和性别预测的详细步骤:
1. **安装并加载库**:
```python
!pip install torch_snippets
from torch_snippets import *
```
该库能让我们加载所有重要的 `torch` 模块和实用工具,如 `NumPy`、`pandas`、`Matplotlib`、`Glob`、`Os` 等。
2. **下载数据并创建数据集类**:
```python
class GenderAgeClass(Dataset):
...
def __getitem__(self, ix):
...
age = f.age
im = read(file, 1)
return im, age, gen
def preprocess_image(self, im):
im = resize(im, IMAGE_SIZE)
im = torch.tensor(im).permute(2,0,1)
...
```
这里的 `im = read(file, 1)` 将 `cv2.imread` 和 `cv2.COLOR_BGR2RGB` 封装成了一个函数调用,“1” 表示以彩色图像读取,若不指定则默认加载黑白图像。同时,还有封装了 `cv2.resize` 的 `resize` 函数。
3. **指定训练和验证数据集并查看样本图像**:
```python
trn = GenderAgeClass(trn_df)
val = GenderAgeClass(val_df)
train_loader = DataLoader(trn, batch_size=32, shuffle=rue, \
drop_last=True, collate_fn=trn.collate_fn)
test_loader = DataLoader(val, batch_siz=32, collate_fn=val.collate_fn)
im, gen, age = trn[0]
show(im, title=f'Gender: {gen}\nAge: {age}', sz=5)
```
`show` 函数将 `import matplotlib.pyplot as plt` 和 `plt.imshow` 封装成了一个函数,它可以处理 GPU 上的 `torch` 数组,无论图像的通道维度是在第一个还是最后一个。`title` 关键字可为图像添加标题,`sz` 关键字可根据传入的整数值调整图像大小。
4. **创建数据加载器并检查张量**:
```python
train_loader = DataLoader(trn, batch_size=32, shuffleTrue, \
drop_last=True, collate_fn=trn.collate_fn)
test_loader = DataLoader(val, batch_sie=32, \
collate_fn=val.collate_fn)
ims, gens, ages = next(iter(train_loader))
inspect(ims, gens, ages)
```
`inspect` 函数可以检查张量的数据类型、最小值、平均值、最大值和形状,它可以接受任意数量的张量输入。输出示例如下:
```plaintext
============================================================
Tensor Shape: torch.Size([32, 3, 224, 224]) Min: -2.118 Max: 2.640 Mean: 0.133 dtype: torch.float32
============================================================
Tensor Shape: torch.Size([32]) Min: 0.000 Max: 1.000 Mean: 0.594 dtype: torch.float32
============================================================
Tensor Shape: torch.Size([32]) Min: 0.087 Max: 0.925 Mean: 0.400 dtype: torch.float32
============================================================
```
5. **创建模型、优化器、损失函数以及训练和验证批次的函数**:这一步由于每个深度学习实验都具有独特性,所以没有封装好的函数,需要按照常规方式创建。
6. **加载所有组件并开始训练**:
```python
model, criterion, optimizer = get_model()
n_epochs = 5
log = Report(n_epochs)
for epoch in range(n_epochs):
N = len(train_loader)
for ix, data in enumerate(train_loader):
total_loss,gender_loss,age_loss = train_batch(data,
model, optimizer, criterion)
log.record(epoch+(ix+1)/N, trn_loss=total_loss, end='\r')
N = len(test_loader)
for ix, data in enumerate(test_loader):
total_loss,gender_acc,age_mae = validate_batc(data, \
model, criterion)
gender_acc /= len(data[0])
age_mae /= len(data[0])
log.record(epoch+(ix+1)/N, val_loss=total_loss,
val_gender_acc=gender_acc,
val_age_mae=age_mae, end='\r')
log.report_avgs(epoch+1)
log.plot_epochs()
```
`Report` 类是一个非常实用的工具,它可以记录训练和验证过程中的各种指标,自动计算训练和验证的剩余时间,并在训练结束后绘制指标的折线图。
7. **加载样本图像并进行预测**:
```python
!wget -q https://www.dropbox.com/s/6kzr8/5_9.JPG
IM = read('/content/5_9.JPG', 1)
im = trn.preprocess_image(IM).to(device)
gender, age = model(im)
pred_gender = gender.to('cpu').detach().numpy()
pred_age = age.to('cpu').detach().numpy()
info = f'predicted gender: {np.where(pred_gender[0[0]<0.5, \
"Male","Female")}\n Predicted age {int(pred_age[0][0]*80)}'
show(IM, title=info, sz=10)
```
以下是 `torch_snippets` 库中一些重要的函数及其封装的原始函数:
| 函数 | 封装的原始函数 |
| ---- | ---- |
| `from torch_snippets import *` | - |
| `Glob` | `glob.glob` |
| `Choose` | `np.random.choice` |
| `Read` | `cv2.imread` |
| `Show` | `plt.imshow` |
| `Subplots` | `plt.subplots`(显示图像列表) |
| `Inspect` | `tensor.min, tensor.mean, tensor.max, tensor.shape, tensor.dtype`(多个张量的统计信息) |
| `Report` | 训练时记录所有指标并在训练后绘制指标图 |
#### 2. 生成类激活映射(CAMs)
在实际应用中,我们不仅希望模型能够做出准确的预测,还希望能够解释模型为什么会做出这样的预测。类激活映射(CAMs)就是一种非常有用的工具,它可以帮助我们理解模型在做出预测时所关注的图像区域。
##### 2.1 CAMs 的原理
特征图是卷积操作后的中间激活,通常其形状为 `n - 通道 x 高度
0
0
复制全文
相关推荐










