首页 > 分享 > 图像去模糊:跑通DeblurGAN

图像去模糊:跑通DeblurGAN

目录

一.环境的配置

二.跑通测试predict.py

三.跑通训练train.py

1.数据准备

2.数据增强方式

3.加载预训练模型

4.模型训练结果存储问题

工程:https://github.com/VITA-Group/DeblurGANv2

一.环境的配置

直接使用python train.py缺什么库,就安装什么库;

二.跑通测试predict.py

需要设置以下参数:

--img_pattern

/media/XXX/test/LR/2021-01-28_11-21-04_white.jpg

--mask_pattern

None

--weights_path

/media/XXX/deblur/DeblurGANv2-master/weights/fpn_inception.h5

--out_dir

submit/

--side_by_side

False

--video

False

三.跑通训练train.py

配置config中的config.yaml参数:

---

project: deblur_gan

experiment_desc: fpn #日志存储文件夹

train:

files_a: /media/XXX/7292a4b1-2584-4296-8caf-eb9788c2ffb9/data/deblur/deblurGANv2/20211209/train/LR/*.jpg #&FILES_A /datasets/my_dataset/**/*.jpg #low quality/blury images

files_b: /media/XXX/7292a4b1-2584-4296-8caf-eb9788c2ffb9/data/deblur/deblurGANv2/20211209/train/HR/*.jpg #*FILES_A #clean files

size: &SIZE 256

crop: random #裁剪方式选择,备选项为:center

preload: &PRELOAD false

preload_size: &PRELOAD_SIZE 0

bounds: [0, .9]

scope: geometric

corrupt: &CORRUPT

- name: cutout

prob: 0.5 #数据增强概率

num_holes: 3

max_h_size: 25

max_w_size: 25

- name: jpeg #增强方式选择,配合aug.py 函数def _resolve_aug_fn(name)中查看挑选需要的增强方式

quality_lower: 70

quality_upper: 90

- name: motion_blur #增强方式选择,配合aug.py 函数def _resolve_aug_fn(name)中查看挑选需要的增强方式

- name: median_blur

- name: gamma

- name: rgb_shift

- name: hsv_shift

- name: sharpen

val:

files_a: /media/XXX/7292a4b1-2584-4296-8caf-eb9788c2ffb9/data/deblur/deblurGANv2/20211209/test/LR/*.jpg #*FILES_A

files_b: /media/XXX/7292a4b1-2584-4296-8caf-eb9788c2ffb9/data/deblur/deblurGANv2/20211209/test/HR/*.jpg #*FILES_A

size: *SIZE

scope: geometric

crop: center

preload: *PRELOAD

preload_size: *PRELOAD_SIZE

bounds: [.9, 1]

corrupt: *CORRUPT

phase: train

warmup_num: 3

model:

g_name: fpn_inception

blocks: 9

d_name: double_gan # may be no_gan, patch_gan, double_gan, multi_scale

d_layers: 3

content_loss: perceptual

adv_lambda: 0.001

disc_loss: wgan-gp

learn_residual: True

norm_layer: instance

dropout: True

num_epochs: 200

train_batches_per_epoch: 1000 #训练进度条长度

val_batches_per_epoch: 100 #验证时进度条长度

batch_size: 1

image_size: [256, 256] #图像推理尺寸

optimizer:

name: adam

lr: 0.0001

scheduler:

name: linear

start_epoch: 50

min_lr: 0.0000001

1.数据准备

注意:训练时数据推理尺寸为256*256,为了防止图像变形,所以使用的训练连样本都是宽高相等的图片;

准备自己的数据时,HR和LR图像的尺寸要相等,这个有别于超分辨率准备的数据,当HR和LR图像尺寸不相等时,模型训练的精度会一直起不来,本人训练时PSNR一直在16徘徊,跑了一晚上才醒悟(训练有问题啊);

2.数据增强方式

该项目中使用的是albumentations库,结合config中的config.yaml进行参数配置;

目前有这么多种增强方式,可修改源码

选其一:aug.py def get_transform中

albu.HorizontalFlip(always_apply=True),

albu.ShiftScaleRotate(always_apply=True),

albu.Transpose(always_apply=True),

albu.OpticalDistortion(always_apply=True),

albu.ElasticTransform(always_apply=True)

albu.RandomCrop

albu.CenterCrop

选其一:aug.py def _resolve_aug_fn(name):

albu.Cutout,

albu.RGBShift,

albu.HueSaturationValue,

albu.MotionBlur,

albu.MedianBlur,

albu.RandomSnow,

albu.RandomShadow,

albu.RandomFog,

albu.RandomBrightnessContrast,

albu.RandomGamma,

albu.RandomSunFlare,

albu.Sharpen,

albu.ImageCompression,

albu.ToGray,

albu.Downscale,

3.加载预训练模型

train.py中的def _init_params(self):

self.criterionG, criterionD = get_loss(self.config['model'])

self.netG, netD = get_nets(self.config['model'])

self.netG.load_state_dict(torch.load("weights/fpn_inception.h5", map_location='cpu')['model'])

4.模型训练结果存储问题

按照原工程中的设置,日志文件存储在fpn文件夹下;

训练模型只存储最新一个和最好的一个模型,而且是存储在工程根目录下,没有另起一个文件夹存储,可以修改def train(self)中的代码:

原代码为:

if self.metric_counter.update_best_model():

torch.save({'model': self.netG.state_dict()},

'best_{}.h5'.format(self.config['experiment_desc']))

torch.save({'model': self.netG.state_dict()

}, 'last_{}.h5'.format(self.config['experiment_desc']))

修改为:

if self.metric_counter.update_best_model():

torch.save({'model': self.netG.state_dict()},self.config['experiment_desc']+'/best_{}.h5'.format(self.config['experiment_desc']))

torch.save({'model': self.netG.state_dict()}, self.config['experiment_desc']+'/last_{}.h5'.format(self.config['experiment_desc']))

if epoch // 50:

torch.save({'model': self.netG.state_dict()}, self.config['experiment_desc']+'/epoch_{}.h5'.format(epoch))

其他链接:

图像去模糊之DeblurGAN-v2_年轻即出发,-CSDN博客_deblurganv2

相关知识

图像去模糊:跑通DeblurGAN
评测 | 跑得快就是硬道理!跑绿通选TA准没错儿
宠物通心术txt下载
智能图像分析
基于Python的图像分类 项目实践——图像分类项目
医学图像处理教学系统
宠物DR拍摄流程详解:从设备准备到图像诊断的全方位指南
图像评估模型训练方法、图像处理方法、装置、计算机设备和可读存储介质
使用python实现图像对比度增强
基于深度学习的宠物猫排泄物图像分类及其在宠物猫智能家居系统的应用研究

网址: 图像去模糊:跑通DeblurGAN https://m.mcbbbk.com/newsview368263.html

所属分类:萌宠日常
上一篇: 狗粮低敏什么意思
下一篇: 【学习笔记】pytorch 深度