少样本图像分类(FewShot):CAML算法代码的实现与教学(一)

少样本机器学习(Few-Shot Learning):CAML算法代码实现与教学

引言

在机器学习领域,少样本学习(Few-Shot Learning, FSL)是一个备受关注的研究方向。与传统的监督学习不同,少样本学习旨在通过极少量的标注数据来训练模型,使其能够在新的类别上表现良好。这种能力对于那些标注数据稀缺或获取成本高昂的应用场景尤为重要。

本文将深入探讨少样本学习的核心概念,并重点介绍一种名为CAML(Context-Aware Meta-Learning的算法。我们通过改造原作者的代码,并进行扩展使用。

一、少样本学习常识概念

1.1 少样本学习要点

少样本学习是一种机器学习范式,其目标是通过极少量的训练样本(通常是每个类别1到5个样本)来学习一个模型,使其能够在新的类别上进行泛化。与传统的监督学习不同,少样本学习不需要大量的标注数据,因此非常适合那些数据稀缺的场景。

1.2 N-way K-shot

N-way K-shot 是机器学习和深度学习中用于描述小样本学习(Few-Shot Learning)任务的一个术语。它通常用于元学习(Meta-Learning)或小样本分类任务中。

  • N-way: 表示在任务中有多少个类别(或类别数)需要进行分类。例如,如果任务是 “5-way”,那么模型需要在5个不同的类别中进行分类。

  • K-shot: 表示每个类别中有多少个样本(或示例)用于训练或测试。例如,如果任务是 “5-shot”,那么每个类别有5个样本。

  • N-way K-shot 表示在一个任务中,模型需要在N个类别中进行分类,每个类别有K个样本。

在少样本学习(Few-Shot Learning)中,支持集(Support Set)和查询集(Query Set)是两个关键的概念,它们用于训练和评估模型在有限样本情况下的学习能力。

1.3 支持集(Support Set)

支持集是用于提供模型学习所需信息的样本集合。它通常包含少量(即“少样本”)的标记数据,这些数据用于帮助模型识别和理解特定任务或类别。支持集的规模通常很小,可能只有几个样本(例如,每个类别只有1到5个样本)。

  • 组成:支持集由多个类别的样本组成,每个类别通常有少量样本。
  • 作用:支持集用于模型学习如何区分不同的类别。在训练过程中,模型通过支持集来学习类别的特征和模式。

1.4 查询集(Query Set)

查询集是用于测试模型性能的样本集合。它包含未标记的样本,模型需要根据从支持集中学到的知识来对这些样本进行分类或预测。

  • 组成:查询集通常包含与支持集相同类别的样本,但模型之前没有见过这些样本。
  • 作用:查询集用于评估模型在少样本情况下的泛化能力。模型需要根据支持集中的少量样本,对查询集中的样本进行正确的分类。

二、 图像分类算法CAML

CAML(Context-Aware Meta-Learning)是一种基于元学习的少样本学习算法。CAML通过引入上下文感知的机制,实现小样本图像的识别分类。CAML的核心思想是通过元学习的方式,学习一个能够在少量数据上快速适应的模型。CAML的核心要点(原理细节可以点此看我空间另一文章):

  • 适合少样本:CAML算法可以支持每个类别个位数的图片,当前已训练好的模型,支持5个及以下(way<=5)类别的图像分类。
  • 可动态加载support样本:在推理过程中,把支持集(Support Set)实时编码进上下文。
  • 上下文感知:CAML通过引入上下文感知的机制,进一步提升了模型的表现。
  • 效果表现靠前:在2024年排行榜里,少样本图像分类榜单中,排前三,开源中排第一。(详细对比点此看我另一篇文章
    在这里插入图片描述

三、CAML代码实现

1.1 CAML源码预览

首先,我们看下原作者的代码,为了方便演示讲解,按当前时间我进行了代码克隆。克隆后的代码仓库如下:
https://github.com/1148270327/CAML
截图如下:
在这里插入图片描述

1.2 CAML源码BUG与不足

  1. 不足:通过命令行传递调用evaluation/test.py
python src/evaluation/test.py --model CAML --gpu 4 --eval_dataset pascal_paintings  --fe_type timm:vit_base_patch16_clip_224.openai:768
  1. bug:evaluation/utils.py第18行代码错误,修改如下:
  if args.model == 'CAML':
    if 'openai' in args.fe_type:
      model_path = 'caml_pretrained_models/CAML_CLIP/model.pth'
    elif 'laion' in args.fe_type:
      model_path = 'caml_pretrained_models/CAML_Laion2b/model.pth'
    elif 'resnet' in args.fe_type:
      model_path = 'caml_pretrained_models/CAML_ResNet34/model.pth'
  1. 不足:大部分人无法运行代码,存在以下问题:
    • huggface的timm模型无法在线下载。
    • 已训练的CAML模型文件,在google网盘,无法下载。
    • 代码没有实现推理,只有对数据集做自动化评估,而且是对1-5shot的多个评估。
    • 支持集、查询集都是随机抽取的,无法进行固定测试。
    • 代码不全,而且数据集dataset格式各一。

1.2 CAML核心源码位置讲解

python src/evaluation/test.py --model CAML --gpu 4 --eval_dataset pascal_paintings  --fe_type timm:vit_base_patch16_clip_224.openai:768

对该执行传参的解析:

  • model CAML: 调用models/CAML.py,除了这个,还有元学习算法中的原型网络等。
  • gpu 4: Pytorch指定显卡调用, 类似字符串device=torch.device(“cuda:4”)。
  • eval_dataset 就是指定评测集的格式,和什么公开集标注一样的方式去加载自己的支持集。
  • fe_type timm:vit_base_patch16_clip_224.openai:768 :这个是通过字符串在代码里做切割,主要用于输入图像的特征编码,包括编码和Transform操作。里面涵盖了几个信息,从timm库里下载“vit_base_patch16_clip_224.openai”这个模型文件,而768是模型的维度,其实就是下载的timm库模型文件的dim,一般在模型文件下载链接的config.json中有标记。huggface截图如下:
    在这里插入图片描述
    对evaluation/test.py代码分析如下(看注释):
# 核心代码注释
if __name__ == '__main__':
  #命令行传参过来的内容,默认设置5个类别
  way = 
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

LensonYuan

蚊子腿也是肉!感谢!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值