PyTorch 读取 TFRecord 文件
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import transforms, datasets
from tqdm import tqdm
from tfrecord.torch.dataset import TFRecordDataset
from model import resnet18
def parse_tfrecord(single_record):
    """Convert one parsed TFRecord example into an ``(image, label)`` pair.

    Args:
        single_record: Mapping for a single example. It is indexed with the
            keys ``"image/encoded"`` and ``"image/label"``; any additional
            fields are ignored. NOTE(review): ``"image/encoded"`` is
            presumably the raw JPEG byte stream as a 1-D uint8 array --
            confirm against the dataset's feature description.

    Returns:
        A tuple ``(image, label)`` where ``image`` is the decoded JPEG cast
        to float and ``label`` is the label tensor with all size-1
        dimensions squeezed out.

    Raises:
        KeyError: For a dict record that lacks either expected key.
    """
    # Fields are looked up by key only. The record is deliberately not
    # tuple-unpacked: iterating a dict yields its keys, so an unpack would
    # raise ValueError for any record that does not have exactly two fields.
    encoded = torch.tensor(single_record["image/encoded"])
    label = torch.tensor(single_record["image/label"]).squeeze()

    # Decode the JPEG payload, then convert to float32 via ``.float()``.
    image = torchvision.io.decode_jpeg(encoded).float()
    return (image, label)
def main():
device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
print("using {} device.".format(device))
batch_s