文档章节

TensorFlow学习笔记 --识别圆圈内的点

StanleySun
 StanleySun
发布于 2017/07/16 09:55
字数 850
阅读 34
收藏 1
点赞 0
评论 0

    在下面这个图上,找出哪些点在圆内,哪些在圆外,对我们来说非常简单。因为我们有眼睛,能看;有大脑,能想。 但是,如果让电脑来做这件事情,就没那么简单了。我们看一下TensorFlow是如何使用深度神经网络做到的。

介绍

在平面上画一个圆,表达式为x^2+y^2 = 100。 即以原点为中心,半径为100点圆。

在平面上随机生成一批点, 要求 -200<= x <=200, -200<= y <=200。如果点落在圆内(含边界上),则该点的label为0,即图中的实心圆点; 若落在圆外面,则该点label为1,即空心圆点.

要求:通过对数据的分析,生成模型,并对新数据的label进行预测。

步骤

  • 生成数据
  • 用TensorFlow训练模型
  • 预测新数据

1. 生成数据

我用的php代码,大家可以用任何自己喜欢但语言。 文件“generate.php”可以生成2个文件,训练数据training_data.csv和测试数据test_data.csv,代码如下:

<?php

$TRAINING_NUM = 200;//生成训练集坐标点的数量
$TEST_NUM = 100;//生成测试集坐标点的数量
$TRAINING_FILE = "training_data.csv";
$TEST_FILE = "test_data.csv";

generate_data($TRAINING_FILE,$TRAINING_NUM);
generate_data($TEST_FILE,$TEST_NUM);

function generate_data ($file, $num){
    unlink($file);
    file_put_contents($file,$num.',2,in,out'."\r\n",FILE_APPEND);
    $R = 100;
    $MIN_X = -200;
    $MAX_X = 200;
    $MIN_Y = -200;
    $MAX_Y = 200;
    for ($i=0; $i < $num; $i++) { 
        $x = rand($MIN_X,$MAX_X);
        $y = rand($MIN_Y,$MAX_Y);
        $label = 1;
        if (($x*$x + $y*$y) <= $R*$R){
            $label =0;
        }
        $line =  $x.','.$y.','.$label."\r\n";
        file_put_contents($file,$line,FILE_APPEND);
    }
}

运行

php generate.php

生成2个文件training_data.csv 和test_date.csv

内容类似下面这样:

200,2,in,out
-70,-81,0
-50,-198,0
169,-93,0
51,-78,1
...

第一行是header。第一行的第一个数字表示文件的总行数(不含header),第二个数字是特征数,本例中有2个特征: x坐标和y坐标。后面2个是label(可忽略)。从第二行开始,每行的三个数字分别是x,y和label。

2. 用TensorFlow训练模型 & 预测新样本

代码circle_dnn_classifier.py 如下:

#coding:utf-8

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np

# 数据集
TRAINING_FILE = "training_data.csv";
TEST_FILE = "test_data.csv";

# 加载数据
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=TRAINING_FILE,
    target_dtype=np.int,
    features_dtype=np.int)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=TEST_FILE,
    target_dtype=np.int,
    features_dtype=np.int)

# 确定所有的特征类型为real-value,特征数量为2
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=2)]

# 创建一个3层的深度神经网络, 分别有 10, 20, 10 个神经元.
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[10, 20, 10],
                                            n_classes=3,
                                            model_dir="model")

# 适配模型,训练2000步
classifier.fit(x=training_set.data,y=training_set.target,steps=2000)

# 评估结果
evaluate = classifier.evaluate(x=test_set.data,y=test_set.target)
print(evaluate)

# 对新样本进行预测
new_samples = np.array([[50, 12], [121, 20]], dtype=int)
y = list(classifier.predict(new_samples, as_iterable=True))

print('Predictions: {}'.format(str(y)))

运行代码:

python circle_dnn_classifier.py

结果

...
{'loss': 0.20674889, 'global_step': 2000, 'accuracy': 0.89999978} //测试数据监测准确率89.99%
...
Predictions: [0, 1]   //对新数据预测

可以看到,模型运行正常,准确率是89.99%。

两个新样本在图中的位置,label分别是0和1,TensorFlow识别正确。

 

可以通过一些简单的办法提高精度:

1.增加训练数据,比如将训练数据增加到5000条(相应地将测试集增加到1000)

2.增加训练次数,比如将step设置为8000

经测试,通过这样的优化,测试结果准确率提高到了99.4%!

大家有兴趣,可以用椭圆或者更加复杂的规则试试,看看TensorFlow训练的效果如何。

© 著作权归作者所有

共有 人打赏支持
StanleySun
粉丝 14
博文 35
码字总数 39262
作品 0
技术主管
机器学习Tensorflow笔记1:Hello World到MNIST实验

最近重新梳理了我职业生涯规划,其中人工智能是我最重要的一个职业方向,所以就开始了人工智能的学习,其中Tensorflow是机器学习中一个很热门的框架,是由Google开源的,是一个不错的方向。由...

ImWiki ⋅ 05/12 ⋅ 0

有道云笔记是如何使用TensorFlow Lite的?

文 / 有道技术团队 近年来,有道技术团队在移动端实时 AI 能力的研究上,做了很多探索及应用的工作。2017 年 11 月 Google 发布 TensorFlow Lite (TFLlite) 后,有道技术团队第一时间跟进 TF...

谷歌开发者 ⋅ 04/21 ⋅ 0

史上最全TensorFlow学习资源汇总

来源 悦动智能(公众号ID:aibbtcom) 本篇文章将为大家总结TensorFlow纯干货学习资源,非常适合新手学习,建议大家收藏。 ▌一 、TensorFlow教程资源 1)适合初学者的TensorFlow教程和代码示...

悦动智能 ⋅ 04/12 ⋅ 0

机器学习实战篇——用卷积神经网络算法在Kaggle上跑个分

之前的文章简单介绍了Kaggle平台以及如何用支撑向量(SVM)的机器学习算法识别手写数字图片。可见即使不用神经网络,传统的机器学习算法在图像识别的领域也能取得不错的成绩(我跑出来了97....

Hongtao洪滔 ⋅ 06/18 ⋅ 0

【干货】史上最全的Tensorflow学习资源汇总,速藏!

一 、Tensorflow教程资源: 1)适合初学者的Tensorflow教程和代码示例:(https://github.com/aymericdamien/TensorFlow-Examples)该教程不光提供了一些经典的数据集,更是从实现最简单的“Hel...

技术小能手 ⋅ 04/16 ⋅ 0

送书&优惠丨对深度学习感兴趣的你,不了解这些就太OUT了!

点击上方“程序人生”,选择“置顶公众号” 第一时间关注程序猿(媛)身边的故事 TensorFlow是什么? TensorFlow的前身是谷歌大脑(google brain)团队研发的DistBelief。自创建以来,它便被...

csdnsevenn ⋅ 05/03 ⋅ 0

机器学习Tensorflow笔记3:Python训练MNIST模型,在Android上实现评估

通常而言我们会通过Python编写代码训练Tensorflow,但是我们训练的数据需要实际应用起来,本文会介绍如何通过Python训练Tensorflow,训练的结果在Android上应用,当前也可以通过传输数据给服...

ImWiki ⋅ 05/16 ⋅ 0

机器学习Tensorflow笔记4:iOS通过Core ML使用Tensorflow训练模型

Tensorflow是Google推出的人工智能框架,而Core ML是苹果推出的人工智能框架,两者是有很大的区别,其中Tensorflow是包含了训练模型和评估模型,Core ML只支持在设备上评估模型,不能训练模型...

ImWiki ⋅ 05/16 ⋅ 0

Tensorflow官网教程:CIFAR-10分类代码阅读

1.题记 因为课程设计要用到TensorFlow,所以这几天在看TensorFlow官方给的几个示例代码,前面几个例子比较简单,看完之后试着重新写了一下,从CIFAR10开始我觉得需要做点笔记了。官网的例子是...

作死少女88 ⋅ 05/25 ⋅ 0

TensorFlow应用实战-18-Policy Gradient算法

Policy Gradient算法 policy Gradient算法不止一种。 有兴趣的话: 深度增强学习之Policy Gradient方法1 https://zhuanlan.zhihu.com/p/21725498 A3c实现3d赛车游戏: 成果展示 numworkders是 ...

天涯明月笙 ⋅ 06/15 ⋅ 0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

内存障碍: 软件黑客的硬件视图

此文为笔者近日有幸看到的一则关于计算机底层内存障碍的学术论文,并翻译(机译)而来[自认为翻译的还行],若读者想要英文原版的论文话,给我留言,我发给你。 内存障碍: 软件黑客的硬件视图...

Romane ⋅ 31分钟前 ⋅ 0

SpringCloud 微服务 (七) 服务通信 Feign

壹 继续第(六)篇RestTemplate篇 做到现在,本机上已经有注册中心: eureka, 服务:client、order、product 继续在order中实现通信向product服务,使用Feign方式 下面记录学习和遇到的问题 贰 or...

___大侠 ⋅ 48分钟前 ⋅ 0

001. 深入JVM学习—Java运行流程

1. Java运行流程图 2. Java运行时数据区 3. Java虚拟机栈 栈内存是线程私有的,其生命周期和线程相同; 虚拟机栈描述的是Java方法执行的内存模型:执行一个方法时会产生一个栈帧随后将其保存...

影狼 ⋅ 今天 ⋅ 0

gitee、github上issue标签方案

目录 [TOC] issue生命周期 st=>start: 开始e=>end: 结束op0=>operation: 新建issueop1=>operation: 评审issueop2=>operation: 任务负责人执行任务cond1=>condition: 是否通过?op3=>o......

lovewinner ⋅ 今天 ⋅ 0

浅谈mysql的索引设计原则以及常见索引的区别

索引定义:是一个单独的,存储在磁盘上的数据库结构,其包含着对数据表里所有记录的引用指针. 数据库索引的设计原则: 为了使索引的使用效率更高,在创建索引时,必须考虑在哪些字段上创建索...

屌丝男神 ⋅ 今天 ⋅ 0

String,StringBuilder,StringBuffer三者的区别

这三个类之间的区别主要是在两个方面,即运行速度和线程安全这两方面。 首先说运行速度,或者说是, 1.执行速度 在这方面运行速度快慢为:StringBuilder(线程不安全,可变) > StringBuffer...

时刻在奔跑 ⋅ 今天 ⋅ 0

java以太坊开发 - web3j使用钱包进行转账

首先载入钱包,然后利用账户凭证操作受控交易Transfer进行转账: Web3j web3 = Web3j.build(new HttpService()); // defaults to http://localhost:8545/Credentials credentials = Wallet......

以太坊教程 ⋅ 今天 ⋅ 0

Oracle全文检索配置与实践

Oracle全文检索配置与实践

微小宝 ⋅ 今天 ⋅ 0

mysql的分区和分表

1,什么是mysql分表,分区 什么是分表,从表面意思上看呢,就是把一张表分成N多个小表,具体请看mysql分表的3种方法 什么是分区,分区呢就是把一张表的数据分成N多个区块,这些区块可以在同一...

梦梦阁 ⋅ 今天 ⋅ 0

exception.ZuulException: Forwarding error

错误日志 com.netflix.zuul.exception.ZuulException: Forwarding error Caused by: com.netflix.hystrix.exception.HystrixRuntimeException: xxx timed-out and no fallback available. Ca......

jack_peng ⋅ 今天 ⋅ 0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

返回顶部
顶部