首页 > 分享 > CNN简单实战:PyTorch搭建CNN对猫狗图片进行分类

CNN简单实战:PyTorch搭建CNN对猫狗图片进行分类

在上一篇文章:CNN训练前的准备:PyTorch处理自己的图像数据(Dataset和Dataloader),大致介绍了怎么利用pytorch把猫狗图片处理成CNN需要的数据,今天就用该数据对自己定义的CNN模型进行训练及测试。

首先导入需要的包:

import torch from torch import optim import torch.nn as nn from torch.autograd import Variable from torchvision import transforms from torch.utils.data import Dataset, DataLoader from PIL import Image 1234567 定义自己的CNN网络

class cnn(nn.Module): def __init__(self): super(cnn, self).__init__() self.relu = nn.ReLU() self.sigmoid = nn.Sigmoid() self.conv1 = nn.Sequential( nn.Conv2d( in_channels=3, out_channels=16, kernel_size=3, stride=2, ), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(kernel_size=2), ) # self.conv2 = nn.Sequential( nn.Conv2d( in_channels=16, out_channels=32, kernel_size=3, stride=2, ), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(kernel_size=2), ) # self.conv3 = nn.Sequential( nn.Conv2d( in_channels=32, out_channels=64, kernel_size=3, stride=2, ), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(kernel_size=2), ) self.fc1 = nn.Linear(3 * 3 * 64, 64) self.fc2 = nn.Linear(64, 10) self.out = nn.Linear(10, 2) def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = self.conv3(x) # print(x.size()) x = x.view(x.shape[0], -1) x = self.relu(self.fc1(x)) x = self.relu(self.fc2(x)) x = self.out(x) # x = F.log_softmax(x, dim=1) return x 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455 训练(GPU)

def train(): Dtr, Val, Dte = load_data() print('train...') epoch_num = 30 best_model = None min_epochs = 5 min_val_loss = 5 model = cnn().to(device) optimizer = optim.Adam(model.parameters(), lr=0.0008) criterion = nn.CrossEntropyLoss().to(device) # criterion = nn.BCELoss().to(device) for epoch in tqdm(range(epoch_num), ascii=True): train_loss = [] for batch_idx, (data, target) in enumerate(Dtr, 0): data, target = Variable(data).to(device), Variable(target.long()).to(device) # target = target.view(target.shape[0], -1) # print(target) optimizer.zero_grad() output = model(data) # print(output) loss = criterion(output, target) loss.backward() optimizer.step() train_loss.append(loss.cpu().item()) # validation val_loss = get_val_loss(model, Val) model.train() if epoch + 1 > min_epochs and val_loss < min_val_loss: min_val_loss = val_loss best_model = copy.deepcopy(model) tqdm.write('Epoch {:03d} train_loss {:.5f} val_loss {:.5f}'.format(epoch, np.mean(train_loss), val_loss)) torch.save(best_model.state_dict(), "model/cnn.pkl") 12345678910111213141516171819202122232425262728293031323334

一共训练30轮,训练的步骤如下:

初始化模型:

model = cnn().to(device) 1 选择优化器以及优化算法,这里选择了Adam:

optimizer = optim.Adam(model.parameters(), lr=0.00005) 1 选择损失函数,这里选择了交叉熵:

criterion = nn.CrossEntropyLoss().to(device) 1 对每一个batch里的数据,先将它们转成能被GPU计算的类型:

data, target = Variable(data).to(device), Variable(target.long()).to(device) 1 梯度清零、前向传播、计算误差、反向传播、更新参数:

optimizer.zero_grad() # 梯度清0 output = model(data)[0] # 前向传播 loss = criterion(output, target) # 计算误差 loss.backward() # 反向传播 optimizer.step() # 更新参数 12345 测试(GPU)

def test(): Dtr, Val, Dte = load_data() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = cnn().to(device) model.load_state_dict(torch.load("model/cnn.pkl"), False) model.eval() total = 0 current = 0 for (data, target) in Dte: data, target = data.to(device), target.to(device) outputs = model(data) predicted = torch.max(outputs.data, 1)[1].data total += target.size(0) current += (predicted == target).sum() print('Accuracy:%d%%' % (100 * current / total)) 12345678910111213141516

结果:80%
在这里插入图片描述
如果需要更高的准确率,可以使用一些预训练的模型,详见:
PyTorch搭建预训练AlexNet、DenseNet、ResNet、VGG实现猫狗图片分类

完整代码:cnn-dogs-vs-cats。原创不易,下载时请给个follow和star!感谢!!

相关知识

CNN简单实战:PyTorch搭建CNN对猫狗图片进行分类
PyTorch深度学习:猫狗情感识别
PyTorch猫狗:深度学习在宠物识别中的应用
CNN参数设置经验
基于CNN的狗叫,猫叫语音分类
基于Python的图像分类 项目实践——图像分类项目
基于Pytorch框架的深度学习densenet121神经网络鸟类行为识别分类系统源码
深度学习的艺术:从理论到实践
web网页html版通过CNN卷积神经网络的宠物行为训练识别
推荐几个提供免费GPU计算资源的平台,助力你的AI之路

网址: CNN简单实战:PyTorch搭建CNN对猫狗图片进行分类 https://m.mcbbbk.com/newsview171061.html

所属分类:萌宠日常
上一篇: 基于Python的图像分类 项目
下一篇: iOS17宠物识别功能在iPho