文档章节

Keras猫狗大战八:resnet50预训练模型迁移学习,图片先做归一化预处理,精度提高到97.5%

o
 osc_zoa3moe9
发布于 2019/12/07 22:10
字数 566
阅读 9
收藏 0

精选30+云产品,助力企业轻松上云!>>>

上一篇的基础上,对数据调用keras图片预处理函数preprocess_input做归一化预处理,进行训练。

导入preprocess_input:

import os

from keras import layers, optimizers, models
from keras.applications.resnet50 import ResNet50, preprocess_input
from keras.layers import *    
from keras.models import Model

数据生成添加preprocessing_function=preprocess_input

from keras.preprocessing.image import ImageDataGenerator

batch_size = 64

train_datagen = ImageDataGenerator(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
    preprocessing_function=preprocess_input)

test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)


train_generator = train_datagen.flow_from_directory(
        # This is the target directory
        train_dir,
        # All images will be resized to 150x150
        target_size=(150, 150),
        batch_size=batch_size,
        # Since we use binary_crossentropy loss, we need binary labels
        class_mode='binary')

validation_generator = test_datagen.flow_from_directory(
        validation_dir,
        target_size=(150, 150),
        batch_size=batch_size,
        class_mode='binary')

训练25epoch,学习率从1e-3下降到4e-5:

Epoch 1/100
281/281 [==============================] - 152s 540ms/step - loss: 0.2849 - acc: 0.8846 - lr: 0.0010 - val_loss: 0.1195 - val_acc: 0.9694 - val_lr: 0.0010
Epoch 2/100
281/281 [==============================] - 79s 282ms/step - loss: 0.2234 - acc: 0.9079 - lr: 0.0010 - val_loss: 0.1105 - val_acc: 0.9673 - val_lr: 0.0010
Epoch 3/100
281/281 [==============================] - 80s 285ms/step - loss: 0.2070 - acc: 0.9135 - lr: 0.0010 - val_loss: 0.1061 - val_acc: 0.9716 - val_lr: 0.0010
Epoch 4/100
281/281 [==============================] - 80s 283ms/step - loss: 0.1939 - acc: 0.9203 - lr: 0.0010 - val_loss: 0.0998 - val_acc: 0.9748 - val_lr: 0.0010
Epoch 5/100
......
Epoch 22/100
281/281 [==============================] - 80s 284ms/step - loss: 0.1368 - acc: 0.9470 - lr: 4.0000e-05 - val_loss: 0.0943 - val_acc: 0.9777 - val_lr: 4.0000e-05
Epoch 23/100
281/281 [==============================] - 80s 283ms/step - loss: 0.1346 - acc: 0.9479 - lr: 4.0000e-05 - val_loss: 0.1046 - val_acc: 0.9720 - val_lr: 4.0000e-05
Epoch 24/100
281/281 [==============================] - 79s 283ms/step - loss: 0.1320 - acc: 0.9476 - lr: 4.0000e-05 - val_loss: 0.0938 - val_acc: 0.9759 - val_lr: 4.0000e-05
Epoch 25/100
281/281 [==============================] - 79s 282ms/step - loss: 0.1356 - acc: 0.9476 - lr: 4.0000e-05 - val_loss: 0.1063 - val_acc: 0.9745 - val_lr: 4.0000e-05

在测试图片时也需要进行归一化预处理:
def get_input_xy(src=[]):
    pre_x = []
    true_y = []

    class_indices = {'cat': 0, 'dog': 1}

    for s in src:
        input = cv2.imread(s)
        input = cv2.resize(input, (150, 150))
        input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
        pre_x.append(preprocess_input(input))

        _, fn = os.path.split(s)
        y = class_indices.get(fn[:3])
        true_y.append(y)

    pre_x = np.array(pre_x)

    return pre_x, true_y

    
def plot_sonfusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    print(tick_marks, type(tick_marks))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks([-0.5,1.5], classes)

    print(cm)
    ok_num = 0
    for k in range(cm.shape[0]):
        print(cm[k,k]/np.sum(cm[k,:]))
        ok_num += cm[k,k]
        
    print(ok_num/np.sum(cm))
        
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    thresh = cm.max() / 2.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j], horizontalalignment='center', color='white' if cm[i, j] > thresh else 'black')

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predict label')

测试结果为97.5%,较前面提高了1.3%:

[[1225   25]
 [  38 1212]]
0.98
0.9696
0.9748
猫的准确度为98%,狗的为97%,总的准确度为97.5%。混淆矩阵图:

 

o
粉丝 1
博文 500
码字总数 0
作品 0
私信 提问
加载中
请先登录后再评论。
Keras深度学习实践3—计算机视觉问题:猫vs狗

内容参考以及代码整理自“深度学习四大名“著之一《Python深度学习》 一、卷积神经网络 卷积神经网络,也叫convnet,它是计算机视觉应用几乎都在使用的一种深度学习模型。 我们先来看一个简单...

小可哥哥V
2019/05/07
0
0
Tensorflow2.0学习(七):猫狗大战2、训练与保存模型

熟悉TF1.X,对Record数据情有独钟,那么TF2.0如何训练record数据呢? 在Tensorflow2.0学习(六):猫狗大战1、制作与读取record数据 已经介绍了如何生成与读取tfRecord数据,这篇介绍训练record数据...

hjxu2016
03/31
0
0
TensorFlow从1到2(九)迁移学习

迁移学习基本概念 迁移学习是这两年比较火的一个话题,主要原因是在当前的机器学习中,样本数据的获取是成本最高的一块。而迁移学习可以有效的把原有的学习经验(对于模型就是模型本身及其训...

osc_gd4rlfym
2019/04/28
1
0
keras 数据集学习笔记 3/3

keras 数据集的学习笔记 3/3 深度学习需要有大量的数据集供机器来学习,本次就学习如何定义自己的数据集。 各种常用的数据集 数据集如何使用 定义自己的数据集 虽然该数据的准备和预处理工作...

gaoshine
2017/10/06
0
0
深度学习自学记录(2)——Keras迁移学习(升级版)+模型融合实现详解

最近时间充裕在学习深度学习,用博客记录一下自己的理解,毕竟好记性不如烂笔头。如果有错误的地方,希望大家指正,一起进步。如果这篇博客对你有帮助,点赞支持一下,码字不易。。。。。 迁...

qq_40784418
04/12
0
0

没有更多内容

加载失败,请刷新页面

加载更多

什么是token及怎样生成token

什么是token   Token是服务端生成的一串字符串,以作客户端进行请求的一个令牌,当第一次登录后,服务器生成一个Token便将此Token返回给客户端,以后客户端只需带上这个Token前来请求数据即...

osc_zzg7fpke
50分钟前
10
0
往事不堪回首

开局一张图,内容全靠编 从12年大学毕业到如今,兜兜转转,依然在码工,码农,码代码的路上徘徊着,从最初的用asp.net写站点,写内部的CRM,内部管理系统,内部的XXX,很难想象内部的系统居然...

osc_9os5791s
51分钟前
23
0
java 事件监听器

package com.qimh.springbootfiledemo.listener;/** * 事件监听器 * 监听person 事件源的eat 和 sleep 的方法 * @author */public interface PersonListener { void doEa...

qimh
52分钟前
21
0
[原创.数据可视化系列之四]跨平台,多格式的等值线和等值面的生成

这些年做项目的时候,碰到等值面基本都是arcgis server来支撑的,通过构建GP服务,一般的都能满足做等值线和等值面的需求。可是突然有一天,我发现如果没有arcgis server 的话,我既然没法生...

osc_bvincwvq
52分钟前
16
0
个人作业——软件工程实践总结&个人技术博客

这个作业属于哪个课程 2020春|S班(福州大学) 这个作业要求在哪里 个人作业——软件工程实践总结&个人技术博客 这个作业的目标 总结软件工程课程以及实践中的收获 作业正文 其他参考文献 一...

osc_9piujk2x
52分钟前
18
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部