当前位置:文档之家› 零基础入门深度学习(2) - 线性单元和梯度下降

零基础入门深度学习(2) - 线性单元和梯度下降

对比此前我们讲过的感知器
这样替换了激活函数之后,线性单元将返回一个实数值线性单元的模型
首先,我们随便选择一个点开始,比如上图的点。

接下来,每次迭代修改
你可能要问了,为啥每次修改的值,都能往函数最小值那个方向前进呢?这里的奥秘在于,我们每次都是向函数
向来修改。

什么是梯度呢?翻开大学高数课的课本,我们会发现梯度
然就是函数值下降最快的方向了。

我们每次沿着梯度相反方向去修改
最小值那个点,是因为我们每次移动的步长不会那么恰到好处,有可能最后一次迭代走远了越过了最小值那个点。

步长的选择是门手艺,如果选择小了,那么就会迭代很多轮才能走到最小值附近;如果选择大了,那可能就会越过最小值很远,收敛不到一个好的点上。

如上图,椭圆表示的是函数值的等高线,椭圆中心是函数的最小值点。

红色是BGD的逼近曲线,而紫色是SGD的逼近曲线。

我们可以看到BGD是一直向着最低点前进的,而SGD明显躁动了许多,但总体上仍然是向最低点逼近的。

最后需要说明的是,SGD不仅仅效率高,而且随机性有时候反而是好事。

今天的目标函数是一个『凸函数』,沿着梯度反方向就能找到全局唯一的最小值。

然而对于非凸函数来说,存在许多局部最小值。

随机性有助于我们逃离某些很糟糕的局部最小值,从而获得一个更好的模型。

实现线性单元
接下来,让我们撸一把代码。

因为我们已经写了感知器的代码,因此我们先比较一下感知器模型和线性单元模型,看看哪些代码能够复用。

算法感知器线性单元
模型
训练规则
比较的结果令人震惊,原来除了激活函数不同之外,两者的模型和训练规则是一样的(在上表中,线性单元的优化算法是SGD算法)。

那么,我们只需要把感知器的激活函数进行替换即可。

感知器的代码请参考上一篇文章零基础入门深度学习(1) - 感知器,这里就不再重复了。

对于一个养成良好习惯的程序员来说,重复代码是不可忍受的。

大家应该把代码保存在一个代码库中(比如git)。

1.f r o m p e r c e p t r o n i m p o r t P e r c e p t r o n
2.
3.#定义激活函数f
4.f=l a m b d a x:x
5.
6.c l a s s L i n e a r U n i t(P e r c e p t r o n):
7.d e f__i n i t__(s e l f,i n p u t_n u m):
8.'''初始化线性单元,设置输入参数的个数'''
9.P e r c e p t r o n.__i n i t__(s e l f,i n p u t_n u m,f)
通过继承Perceptron,我们仅用几行代码就实现了线性单元。

这再次证明了面向对象编程范式的强大。

接下来,我们用简单的数据进行一下测试。

1.d e f g e t_t r a i n i n g_d a t a s e t():
2.'''
3.捏造5个人的收入数据
4.'''
5.#构建训练数据
6.#输入向量列表,每一项是工作年限
7.i n p u t_v e c s=[[5],[3],[8],[1.4],[10.1]]
8.#期望的输出列表,月薪,注意要与输入一一对应
9.l a b e l s=[5500,2300,7600,1800,11400]
10.r e t u r n i n p u t_v e c s,l a b e l s
11.
12.
13.d e f t r a i n_l i n e a r_u n i t():
14.'''
15.使用数据训练线性单元
16.'''
17.#创建感知器,输入参数的特征数为1(工作年限)
18.l u=L i n e a r U n i t(1)
19.#训练,迭代10轮,学习速率为0.01
20.i n p u t_v e c s,l a b e l s=g e t_t r a i n i n g_d a t a s e t()
21.l u.t r a i n(i n p u t_v e c s,l a b e l s,10,0.01)
22.#返回训练好的线性单元
23.r e t u r n l u
24.
25.
26.i f__n a m e__=='__m a i n__':
27.'''训练线性单元'''
28.l i n e a r_u n i t=t r a i n_l i n e a r_u n i t()
29.#打印训练获得的权重
30.p r i n t l i n e a r_u n i t
31.#测试
32.p r i n t'W o r k3.4y e a r s,m o n t h l y s a l a r y=%.2f'%l i n e a r_u n i t.p r e d i c t([3.4])
33.p r i n t'W o r k15y e a r s,m o n t h l y s a l a r y=%.2f'%l i n e a r_u n i t.p r e d i c t([15])
34.p r i n t'W o r k1.5y e a r s,m o n t h l y s a l a r y=%.2f'%l i n e a r_u n i t.p r e d i c t([1.5])
35.p r i n t'W o r k6.3y e a r s,m o n t h l y s a l a r y=%.2f'%l i n e a r_u n i t.p r e d i c t([6.3])
拟合的直线如下图。

相关主题