首页 > 分享 > 狗猫分类数据集划分详解

狗猫分类数据集划分详解

数据集介绍

首先是要下载数据集,下载地址:https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition

数据解压之后会有两个文件夹,一个是“train”,一个是“test”,顾名思义一个是用来训练的,另一个是作为检验正确性的数据,也是网站要求提交标签的。

在train文件夹里边是一些已经命名好的图像,有猫也有狗

而在test文件夹中是只有编号名的图像

大致了解了数据集后,下边就开始划分数据集

代码

先放一段代码,这是从书中截取出来的:

import os

from PIL import Image

from torch.utils import data

import numpy as np

from torchvision import transforms as T

class DogCat(data.Dataset):

def __init__(self, root, transforms=None, train=True, test=False):

"""

主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据

"""

self.test = test

imgs = [os.path.join(root, img) for img in os.listdir(root)]

if self.test:

imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))

else:

imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))

imgs_num = len(imgs)

if self.test:

self.imgs = imgs

elif train:

self.imgs = imgs[:int(0.7 * imgs_num)]

else:

self.imgs = imgs[int(0.7 * imgs_num):]

if transforms is None:

normalize = T.Normalize(mean=[0.485, 0.456, 0.406],

std=[0.229, 0.224, 0.225])

if self.test or not train:

self.transforms = T.Compose([

T.Resize(224),

T.CenterCrop(224),

T.ToTensor(),

normalize

])

else:

self.transforms = T.Compose([

T.Resize(256),

T.RandomReSizedCrop(224),

T.RandomHorizontalFlip(),

T.ToTensor(),

normalize

])

def __getitem__(self, index):

"""

一次返回一张图片的数据

"""

img_path = self.imgs[index]

if self.test:

label = int(self.imgs[index].split('.')[-2].split('/')[-1])

else:

label = 1 if 'dog' in img_path.split('/')[-1] else 0

data = Image.open(img_path)

data = self.transforms(data)

return data, label

def __len__(self):

return len(self.imgs)

详解

这里建立了一个类,继承自data.Dataset,里边有三个方法是必须重写的:

class DogCat(data.Dataset):

def __init__(self, root, transforms=None, train=True, test=False):

"""

主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据

"""

def __getitem__(self, index):

"""

一次返回一张图片的数据

"""

def __len__(self):

下面开始解释每个方法中语句的功能

def __init__(self, root, transforms=None, train=True, test=False):

self.test = test

imgs = [os.path.join(root, img) for img in os.listdir(root)]

if self.test:

imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))

else:

imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))

imgs_num = len(imgs)

if self.test:

self.imgs = imgs

elif train:

self.imgs = imgs[:int(0.7 * imgs_num)]

else:

self.imgs = imgs[int(0.7 * imgs_num):]

if transforms is None:

normalize = T.Normalize(mean=[0.485, 0.456, 0.406],

std=[0.229, 0.224, 0.225])

if self.test or not train:

self.transforms = T.Compose([

T.Resize(224),

T.CenterCrop(224),

T.ToTensor(),

normalize

])

else:

self.transforms = T.Compose([

T.Resize(256),

T.RandomReSizedCrop(224),

T.RandomHorizontalFlip(),

T.ToTensor(),

normalize

])

'

def __getitem__(self, index):

"""

一次返回一张图片的数据

"""

img_path = self.imgs[index]

if self.test:

label = int(self.imgs[index].split('.')[-2].split('/')[-1])

else:

label = 1 if 'dog' in img_path.split('/')[-1] else 0

data = Image.open(img_path)

data = self.transforms(data)

return data, label

'

def __len__(self):

return len(self.imgs)

'

到此位置,数据集的划分与数据类已经完成

完整训练过程可以看我另一篇博客:

https://blog.csdn.net/qq_41685265/article/details/104898848

相关知识

狗猫分类数据集划分详解
【猫狗数据集】宠物品种分类 计算机视觉 人工智能 机器学习 (含数据集)
如何正确划分训练数据集和测试数据集
YOLO数据集划分教程:如何划分训练、验证和测试集
猫狗数据集:12000张图助力AI分类训练
Pytorch采用AlexNet实现猫狗数据集分类(训练与预测)
Python深度学习(4):猫狗分类
猫狗图片分类 03分析图片数据
个人笔记:OpenCV(一)图像分类——猫狗分类为例
机器学习:训练集与测试集的划分

网址: 狗猫分类数据集划分详解 https://m.mcbbbk.com/newsview640195.html

所属分类:萌宠日常
上一篇: 过年回家宠物怎么办?不用慌,这四
下一篇: 基于AI的猫狗图片识别与分类技术