首先是要下载数据集,下载地址:https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition
数据解压之后会有两个文件夹,一个是“train”,一个是“test”,顾名思义一个是用来训练的,另一个是作为检验正确性的数据,也是网站要求提交标签的。
在train文件夹里边是一些已经命名好的图像,有猫也有狗
而在test文件夹中是只有编号名的图像
大致了解了数据集后,下边就开始划分数据集
先放一段代码,这是从书中截取出来的:
import os
from PIL import Image
from torch.utils import data
import numpy as np
from torchvision import transforms as T
class DogCat(data.Dataset):
def __init__(self, root, transforms=None, train=True, test=False):
"""
主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据
"""
self.test = test
imgs = [os.path.join(root, img) for img in os.listdir(root)]
if self.test:
imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))
else:
imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))
imgs_num = len(imgs)
if self.test:
self.imgs = imgs
elif train:
self.imgs = imgs[:int(0.7 * imgs_num)]
else:
self.imgs = imgs[int(0.7 * imgs_num):]
if transforms is None:
normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
if self.test or not train:
self.transforms = T.Compose([
T.Resize(224),
T.CenterCrop(224),
T.ToTensor(),
normalize
])
else:
self.transforms = T.Compose([
T.Resize(256),
T.RandomReSizedCrop(224),
T.RandomHorizontalFlip(),
T.ToTensor(),
normalize
])
def __getitem__(self, index):
"""
一次返回一张图片的数据
"""
img_path = self.imgs[index]
if self.test:
label = int(self.imgs[index].split('.')[-2].split('/')[-1])
else:
label = 1 if 'dog' in img_path.split('/')[-1] else 0
data = Image.open(img_path)
data = self.transforms(data)
return data, label
def __len__(self):
return len(self.imgs)
这里建立了一个类,继承自data.Dataset,里边有三个方法是必须重写的:
class DogCat(data.Dataset):
def __init__(self, root, transforms=None, train=True, test=False):
"""
主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据
"""
def __getitem__(self, index):
"""
一次返回一张图片的数据
"""
def __len__(self):
下面开始解释每个方法中语句的功能
def __init__(self, root, transforms=None, train=True, test=False):
self.test = test
imgs = [os.path.join(root, img) for img in os.listdir(root)]
if self.test:
imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))
else:
imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))
imgs_num = len(imgs)
if self.test:
self.imgs = imgs
elif train:
self.imgs = imgs[:int(0.7 * imgs_num)]
else:
self.imgs = imgs[int(0.7 * imgs_num):]
if transforms is None:
normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
if self.test or not train:
self.transforms = T.Compose([
T.Resize(224),
T.CenterCrop(224),
T.ToTensor(),
normalize
])
else:
self.transforms = T.Compose([
T.Resize(256),
T.RandomReSizedCrop(224),
T.RandomHorizontalFlip(),
T.ToTensor(),
normalize
])
'def __getitem__(self, index):
"""
一次返回一张图片的数据
"""
img_path = self.imgs[index]
if self.test:
label = int(self.imgs[index].split('.')[-2].split('/')[-1])
else:
label = 1 if 'dog' in img_path.split('/')[-1] else 0
data = Image.open(img_path)
data = self.transforms(data)
return data, label
'def __len__(self):
return len(self.imgs)
'到此位置,数据集的划分与数据类已经完成
完整训练过程可以看我另一篇博客:
https://blog.csdn.net/qq_41685265/article/details/104898848
相关知识
狗猫分类数据集划分详解
【猫狗数据集】宠物品种分类 计算机视觉 人工智能 机器学习 (含数据集)
如何正确划分训练数据集和测试数据集
YOLO数据集划分教程:如何划分训练、验证和测试集
猫狗数据集:12000张图助力AI分类训练
Pytorch采用AlexNet实现猫狗数据集分类(训练与预测)
Python深度学习(4):猫狗分类
猫狗图片分类 03分析图片数据
个人笔记:OpenCV(一)图像分类——猫狗分类为例
机器学习:训练集与测试集的划分
网址: 狗猫分类数据集划分详解 https://m.mcbbbk.com/newsview640195.html
上一篇: 过年回家宠物怎么办?不用慌,这四 |
下一篇: 基于AI的猫狗图片识别与分类技术 |