文档章节

梯度下降法求多元线性回归及Java实现

 冷血狂魔
发布于 07/17 19:28
字数 1613
阅读 1683
收藏 37

 对于数据分析而言,我们总是极力找数学模型来描述数据发生的规律, 有的数据我们在二维空间就可以描述,有的数据则需要映射到更高维的空间。数据表现出来的分布可能是完全离散的,也可能是聚集成堆的,那么机器学习的任务就是让计算机自己在数据中学习到数据的规律。那么这个规律通常是可以用一些函数来描述,函数可能是线性的,也可能是非线性的,怎么找到这些函数,是机器学习的首要问题。

    本篇博客尝试用梯度下降法,找到线性函数的参数,来拟合一个数据集。

    假设我们有如下函数,其中x是一个三个维度,

  

    写一个java程序来,随机产生100笔数据作为训练集。


Random random = new Random();
double[] results = new double[100];
double[][] features = new double[100][3];
for (int i = 0; i < 100; i++) {
    for (int j = 0; j < features[i].length; j++) {
	features[i][j] = random.nextDouble();
    }
    results[i] = 3 * features[i][0] + 4 * features[i][1] + 5 * features[i][2] + 10;
}

上面的程序中results就是函数的值,features的第二维就是随机产生的3个x。

    有了训练集,我们的任务就变成了如何求出3个各种的系数3、4、5,以及偏移量10,系数和偏移量可以取任意值,那么我们就得到了一个函数集,任务转化一下就变成了找出一个函数作用于训练集之后,与真实值的误差最小,如何评判误差的大小呢?我们需要定义一个函数来评判,那么给这个函数取一个名字,叫损失函数。这里,损失函数定义为,其中为真实值,问题就转化为在训练集中求如下函数:

    如何求这个函数的极小值呢?如果我们计算能力无限大,直接穷举就完了,但是这不是高效的办法,这时候就说的了梯度下降法,我们来看看数学里对梯度的定义。

    在微积分里面,对多元函数的参数求∂偏导数,把求得的各个参数的偏导数以向量的形式写出来,就是梯度。比如函数f(x,y), 分别对x,y求偏导数,求得的梯度向量就是(∂f/∂x, ∂f/∂y)T,简称grad f(x,y)或者▽f(x,y)。

    梯度告诉我们两件事情:

    1、函数增大的方向

    2、我们走向增大的方向,应该走多大步幅

    求极小值,我们反方向走即可,加个负号,但是这个步幅有个问题,如果过大,参数就直接飞出去了,就很难在找到最小值,如果太小,则很有可能卡在局部极小值的地方。所以,我们设计了一个系数来调节步幅,我们叫它学习速率learningRate。

    好了,为了好描述,我们把上面的函数泛化一下,表示成如下公式:

    

    损失函数对每个参数求偏导数,根据偏导数值,当然求导的过程需要用到链式法则,,这里我们直接给出参数更新公式如下:

    对于BGD(批量梯度下降法):

    

    

    

    

    对于SGD(随机梯度下降法),SGD与BGD不同的是每笔数据,我们都更新一次参数,效率比较低下。公式和上面类似,去掉求和符号和除以N即可。

    下面是具体的代码实现


import java.util.Random;
 
public class LinearRegression {
 
	public static void main(String[] args) {
		// y=3*x1+4*x2+5*x3+10
		Random random = new Random();
		double[] results = new double[100];
		double[][] features = new double[100][3];
		for (int i = 0; i < 100; i++) {
			for (int j = 0; j < features[i].length; j++) {
				features[i][j] = random.nextDouble();
			}
			results[i] = 3 * features[i][0] + 4 * features[i][1] + 5 * features[i][2] + 10;
		}
		double[] parameters = new double[] { 1.0, 1.0, 1.0, 1.0 };
		double learningRate = 0.01;
		for (int i = 0; i < 30; i++) {
			SGD(features, results, learningRate, parameters);
		}
		parameters = new double[] { 1.0, 1.0, 1.0, 1.0 };
		System.out.println("==========================");
		for (int i = 0; i < 3000; i++) {
			BGD(features, results, learningRate, parameters);
		}
	}
 
	private static void SGD(double[][] features, double[] results, double learningRate, double[] parameters) {
		for (int j = 0; j < results.length; j++) {
			double gradient = (parameters[0] * features[j][0] + parameters[1] * features[j][1]
					+ parameters[2] * features[j][2] + parameters[3] - results[j]) * features[j][0];
			parameters[0] = parameters[0] - 2 * learningRate * gradient;
 
			gradient = (parameters[0] * features[j][0] + parameters[1] * features[j][1] + parameters[2] * features[j][2]
					+ parameters[3] - results[j]) * features[j][1];
			parameters[1] = parameters[1] - 2 * learningRate * gradient;
 
			gradient = (parameters[0] * features[j][0] + parameters[1] * features[j][1] + parameters[2] * features[j][2]
					+ parameters[3] - results[j]) * features[j][2];
			parameters[2] = parameters[2] - 2 * learningRate * gradient;
 
			gradient = (parameters[0] * features[j][0] + parameters[1] * features[j][1] + parameters[2] * features[j][2]
					+ parameters[3] - results[j]);
			parameters[3] = parameters[3] - 2 * learningRate * gradient;
		}
		
		double totalLoss = 0;
		for (int j = 0; j < results.length; j++) {
			totalLoss = totalLoss + Math.pow((parameters[0] * features[j][0] + parameters[1] * features[j][1]
					+ parameters[2] * features[j][2] + parameters[3] - results[j]), 2);
		}
		System.out.println(parameters[0] + " " + parameters[1] + " " + parameters[2] + " " + parameters[3]);
		System.out.println("totalLoss:" + totalLoss);
	}
 
	private static void BGD(double[][] features, double[] results, double learningRate, double[] parameters) {
		double sum = 0;
		for (int j = 0; j < results.length; j++) {
			sum = sum + (parameters[0] * features[j][0] + parameters[1] * features[j][1]
					+ parameters[2] * features[j][2] + parameters[3] - results[j]) * features[j][0];
		}
		double updateValue = 2 * learningRate * sum / results.length;
		parameters[0] = parameters[0] - updateValue;
 
		sum = 0;
		for (int j = 0; j < results.length; j++) {
			sum = sum + (parameters[0] * features[j][0] + parameters[1] * features[j][1]
					+ parameters[2] * features[j][2] + parameters[3] - results[j]) * features[j][1];
		}
		updateValue = 2 * learningRate * sum / results.length;
		parameters[1] = parameters[1] - updateValue;
 
		sum = 0;
		for (int j = 0; j < results.length; j++) {
			sum = sum + (parameters[0] * features[j][0] + parameters[1] * features[j][1]
					+ parameters[2] * features[j][2] + parameters[3] - results[j]) * features[j][2];
		}
		updateValue = 2 * learningRate * sum / results.length;
		parameters[2] = parameters[2] - updateValue;
 
		sum = 0;
		for (int j = 0; j < results.length; j++) {
			sum = sum + (parameters[0] * features[j][0] + parameters[1] * features[j][1]
					+ parameters[2] * features[j][2] + parameters[3] - results[j]);
		}
		updateValue = 2 * learningRate * sum / results.length;
		parameters[3] = parameters[3] - updateValue;
 
		double totalLoss = 0;
		for (int j = 0; j < results.length; j++) {
			totalLoss = totalLoss + Math.pow((parameters[0] * features[j][0] + parameters[1] * features[j][1]
					+ parameters[2] * features[j][2] + parameters[3] - results[j]), 2);
		}
		System.out.println(parameters[0] + " " + parameters[1] + " " + parameters[2] + " " + parameters[3]);
		System.out.println("totalLoss:" + totalLoss);
	}
}

运行结果如下:

同样是更新3000次参数。

1、SGD结果:

参数分别为:3.087332784857909 、4.075233812033048 、5.06020828348889、 9.89116046652793

totalLoss:0.13515381461776949

2、BGD结果:

参数分别为:3.0819123489025344 、4.064145151461403、5.046862571520019、 9.899847277313173

totalLoss:0.1050937019067582

可以看出,BGD有更好的表现。

快乐源于分享。

   此博客乃作者原创, 转载请注明出处

© 著作权归作者所有

共有 人打赏支持
粉丝 86
博文 36
码字总数 43730
作品 0
杭州
程序员
私信 提问
加载中

评论(12)

久永
久永

引用来自“久永”的评论

java跑,性能能跟得上吗?

引用来自“首席程序猿_默”的评论

java 不比Python快?

引用来自“只一人在风中丶”的评论

java仅次于c++哦
不,比C++还快——某些场合。
问题是,
江河君
江河君
打扰了
沧海一刀
沧海一刀
久仰久仰
沧海一刀
沧海一刀
厉害厉害
对方是你的微信好友

引用来自“久永”的评论

java跑,性能能跟得上吗?

引用来自“首席程序猿_默”的评论

java 不比Python快?
java仅次于c++哦
简洛-默
简洛-默

引用来自“久永”的评论

java跑,性能能跟得上吗?
java 不比Python快?
开源社区金三胖
开源社区金三胖
打扰了
OSC_tpNdpk
OSC_tpNdpk
打扰了,告辞!
撸脱臼
撸脱臼
打扰了
久永
久永
java跑,性能能跟得上吗?
ND4J求多元线性回归以及GPU和CPU计算性能对比

上一篇博客《梯度下降法求多元线性回归及Java实现》简单了介绍了梯度下降法,并用Java实现了一个梯度下降法求回归的例子。本篇博客,尝试用dl4j的张量运算库nd4j来实现梯度下降法求多元线性回...

冷血狂魔
07/17
0
0
优化算法——牛顿法(Newton Method)

一、牛顿法概述 除了前面说的梯度下降法,牛顿法也是机器学习中用的比较多的一种优化算法。牛顿法的基本思想是利用迭代点处的一阶导数(梯度)和二阶导数(Hessen矩阵)对目标函数进行二次函数近...

google19890102
2014/11/13
0
0
机器学习-线性回归LinearRegression

概述 今天要说一下机器学习中大多数书籍第一个讲的(有的可能是KNN)模型-线性回归。说起线性回归,首先要介绍一下机器学习中的两个常见的问题:回归任务和分类任务。那什么是回归任务和分类...

hiyoung
10/09
0
0
多元线性回归公式推导及R语言实现

多元线性回归 多元线性回归模型 实际中有很多问题是一个因变量与多个自变量成线性相关,我们可以用一个多元线性回归方程来表示。 为了方便计算,我们将上式写成矩阵形式: Y = XW 假设自变量维...

知然
08/28
0
0
吴恩达机器学习笔记(2)——多变量线性回归

上一篇我们提到了单变量的线性回归模型,但是我们实际遇到的问题,都会有多个变量影响,比如上篇的例子——房价问题,在实际情况下影响房价的一定不止房子的面积,房子的地理位置,采光度等等...

机智的神棍酱
06/18
0
0

没有更多内容

加载失败,请刷新页面

加载更多

起薪2万的爬虫工程师,Python需要学到什么程度才可以就业?

爬虫工程师的的薪资为20K起,当然,因为大数据,薪资也将一路上扬。那么,Python需要学到什么程度呢?今天我们来看看3位前辈的回答。 1、前段时间快要毕业,而我又不想找自己的老本行Java开发...

糖宝lsh
17分钟前
1
0
携手开发者共建云生态 首届腾讯云+社区开发者大会在京举办

本文由云+社区发表 北京时间12月15日,由腾讯云主办,极客邦科技、微信、腾讯TEG协办的首届腾讯云+社区开发者大会在北京朝阳悠唐皇冠假日酒店举办。在会上,腾讯云发布了重磅产品开发者平台以...

腾讯云加社区
38分钟前
1
0
人工智能时代员工如何证明其IT工作价值

机器人可以取代你的工作吗?你能帮助机器人完成它的工作吗?如果你正在考虑自己的职业生涯以及今后将如何发展,那么应该询问自己这些问题了。 机器人可以取代你的工作吗?你能帮助机器人完成它的...

Linux就该这么学
39分钟前
2
0
CPU性能过剩提升乏力影响未来行业发展吗?

虽然CPU仍然在不断发展,但是它的性能已经不再仅仅受限于单个处理器类型或制造工艺上了。和过去相比,CPU性能提升的步伐明显放缓了,接下来怎么办,成为横亘在整个行业面前的大问题。 自201...

linuxCool
50分钟前
2
0
使用Autowired和Qualifier解决多个相同类型的bean如何共存的问题

注意: 实现类UserServiceImpl,MyUserServiceImpl 需要区分:@Service("userServicel") @Service("myUserService") https://blog.csdn.net/russle/article/details/80287763......

qimh
今天
5
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部