41目标检测数据集
import os
import pandas as pd
import torch
import torchvision
import matplotlib.pylab as plt
from d2l import torch as d2l
# 数据集下载链接
# http://d2l-data.s3-accelerate.amazonaws.com/banana-detection.zip
# 读取数据集
#@save
def read_data_bananas(is_train=True):
"""读取香蕉检测数据集中的图像和标签"""
data_dir = '../data/banana-detection/'
csv_fname = os.path.join(data_dir, 'bananas_train' if is_train
else 'bananas_val', 'label.csv')
csv_data = pd.read_csv(csv_fname)
# 将 img_name 列设置为索引,以便后续操作中根据图片名称索引标签。
csv_data = csv_data.set_index('img_name')
images, targets = [], [] # images 用于存储图像,targets 用于存储标签。
for img_name, target in csv_data.iterrows():
images.append