设为首页收藏本站
网站公告 | 这是第一条公告
     

 找回密码
 立即注册
缓存时间13 现在时间13 缓存数据 风骨神仙籍里人,诗狂酒圣且平生。开元一遇成何事,留得千秋万古名。

风骨神仙籍里人,诗狂酒圣且平生。开元一遇成何事,留得千秋万古名。 -- 杨花落尽子规啼

查看: 353|回复: 1

图文详解牛顿迭代算法原理及Python实现

[复制链接]

  离线 

TA的专栏

  • 打卡等级:热心大叔
  • 打卡总天数:205
  • 打卡月天数:0
  • 打卡总奖励:3134
  • 最近打卡:2023-08-27 08:03:57
等级头衔

等級:晓枫资讯-上等兵

在线时间
0 小时

积分成就
威望
0
贡献
405
主题
373
精华
0
金钱
4319
积分
810
注册时间
2022-12-24
最后登录
2025-5-31

发表于 2023-2-10 23:06:49 | 显示全部楼层 |阅读模式
1.引例

给定如图所示的某个函数,如何计算函数零点x0
000820m2jlaxla1024ak66.png

在数学上我们如何处理这个问题?
最简单的办法是解方程f(x)=0,在代数学上还有著名的零点判定定理
如果函数y=f(x)在区间[a,b]上的图象是连续不断的一条曲线,并且有f(a)⋅f(b)<0,那么函数y=f(x)在区间(a,b)内有零点,即至少存在一个c∈(a,b),使得f(c)=0,这个c也就是方程f(x)=0的根。
然而,数学上的方法并不一定适合工程应用,当函数形式复杂,例如出现超越函数形式;非解析形式,例如递推关系时,精确的方程解析一般难以进行,因为代数上还没发展出任意形式的求根公式。而零点判定定理求解效率也较低,需要不停试错。
因此,引入今天的主题——牛顿迭代法,服务于工程数值计算。

2.牛顿迭代算法求根

记第k轮迭代后,自变量更新为xk,令目标函数f(x)在x=xk泰勒展开:
f(x)=f(xk​)+f′(xk​)(x−xk​)+o(x)

我们希望下一次迭代到根点,忽略泰勒余项,令f(xk+1)=0,则
xk+1​=xk​−f(xk​)/f'(xk​)​

不断重复运算即可逼近根点。
在几何上,上面过程实际上是在做f(x)在x=xk处的切线,并求切线的零点,在工程上称为局部线性化。如图所示,若xk在x0的左侧,那么下一次迭代方向向右。
000820l2yspi88wxs087iz.png

若xk在x0的右侧,那么下一次迭代方向向左。
000820has6q9uuv5x6z900.png


3.牛顿迭代优化

将优化问题转化为求目标函数一阶导数零点的问题,即可运用上面说的牛顿迭代法。
具体地,记第k轮迭代后,自变量更新为xk ,令目标函数f(x)在x=xk泰勒展开:
f(x)=f(xk​)+f′(xk​)(x−xk​)+1/2​f′′(xk​)(x−xk​)2+o(x)

两边求导得
f′(x)=f′(xk​)+f′′(xk​)(x−xk​)

令f′(xk+1​)=f′(xk​)+f′′(xk​)(xk+1​−xk​)=0,从而得到
xk+1​=xk​−f′(xk​)/f'′(xk​)​

对于向量x=[x1​​ x2​​⋯​xd​​]T,将上述迭代公式推广为
xk+1​=xk​−[∇2f(xk​)]−1∇f(xk​)


其中∇2f(xk​)是Hessian矩阵,当其正定时可以保证牛顿优化算法往 减小的方向迭代
牛顿法的特点如下:
① 以二阶速率向最优点收敛,迭代次数远小于梯度下降法,优化速度快;
梯度下降法的解析参考图文详解梯度下降算法的原理及Python实现
②学习率为[∇2f(xk​)]−1 ,包含更多函数本身的信息,迭代步长可实现自动调整,可视为自适应梯度下降算法;
③ 耗费CPU计算资源多,每次迭代需要计算一次Hessian矩阵,且无法保证Hessian矩阵可逆且正定,因而无法保证一定向最优点收敛。
在实际应用中,牛顿迭代法一般不能直接使用,会引入改进来规避其缺陷,称为拟牛顿算法簇,其中包含大量不同的算法变种,例如共轭梯度法、DFP算法等等,今后都会介绍到。

4 代码实战:Logistic回归
  1. import pandas as pd
  2. import numpy as np
  3. import os
  4. import matplotlib.pyplot as plt
  5. import matplotlib as mpl
  6. from Logit import Logit

  7. '''
  8. * @breif: 从CSV中加载指定数据
  9. * @param[in]: file -> 文件名
  10. * @param[in]: colName -> 要加载的列名
  11. * @param[in]: mode -> 加载模式, set: 列名与该列数据组成的字典, df: df类型
  12. * @retval: mode模式下的返回值
  13. '''
  14. def loadCsvData(file, colName, mode='df'):
  15.     assert mode in ('set', 'df')
  16.     df = pd.read_csv(file, encoding='utf-8-sig', usecols=colName)
  17.     if mode == 'df':
  18.         return df
  19.     if mode == 'set':
  20.         res = {}
  21.         for col in colName:
  22.             res[col] = df[col].values
  23.         return res

  24. if __name__ == '__main__':
  25.     # ============================
  26.     # 读取CSV数据
  27.     # ============================
  28.     csvPath = os.path.abspath(os.path.join(__file__, "../../data/dataset3.0alpha.csv"))
  29.     dataX = loadCsvData(csvPath, ["含糖率", "密度"], 'df')
  30.     dataY = loadCsvData(csvPath, ["好瓜"], 'df')
  31.     label = np.array([
  32.         1 if i == "是" else 0
  33.         for i in list(map(lambda s: s.strip(), list(dataY['好瓜'])))
  34.     ])

  35.     # ============================
  36.     # 绘制样本点
  37.     # ============================
  38.     line_x = np.array([np.min(dataX['密度']), np.max(dataX['密度'])])
  39.     mpl.rcParams['font.sans-serif'] = [u'SimHei']
  40.     plt.title('对数几率回归模拟\nLogistic Regression Simulation')
  41.     plt.xlabel('density')
  42.     plt.ylabel('sugarRate')
  43.     plt.scatter(dataX['密度'][label==0],
  44.                 dataX['含糖率'][label==0],
  45.                 marker='^',
  46.                 color='k',
  47.                 s=100,
  48.                 label='坏瓜')
  49.     plt.scatter(dataX['密度'][label==1],
  50.                 dataX['含糖率'][label==1],
  51.                 marker='^',
  52.                 color='r',
  53.                 s=100,
  54.                 label='好瓜')

  55.     # ============================
  56.     # 实例化对数几率回归模型
  57.     # ============================
  58.     logit = Logit(dataX, label)

  59.     # 采用牛顿迭代法
  60.     logit.logitRegression(logit.newtomMethod)
  61.     line_y = -logit.w[0, 0] / logit.w[1, 0] * line_x - logit.w[2, 0] / logit.w[1, 0]
  62.     plt.plot(line_x, line_y, 'g-', label="牛顿迭代法")

  63.     # 绘图
  64.     plt.legend(loc='upper left')
  65.     plt.show()
复制代码
其中更新权重代码为
  1.     '''
  2.     * @breif: 牛顿迭代法更新权重
  3.     * @param[in]: None
  4.     * @retval: 优化参数的增量dw
  5.     '''
  6.     def newtomMethod(self):
  7.         wTx = np.dot(self.w.T, self.X).reshape(-1, 1)
  8.         p = Logit.sigmod(wTx)
  9.         dw_1 = -self.X.dot(self.y - p)
  10.         dw_2 = self.X.dot(np.diag((p * (1 - p)).reshape(self.N))).dot(self.X.T)
  11.         dw = np.linalg.inv(dw_2).dot(dw_1)
  12.         return dw
复制代码
000820k1d6484s36q8a796.png

到此这篇关于图文详解牛顿迭代算法原理及Python实现的文章就介绍到这了,更多相关Python牛顿迭代算法内容请搜索晓枫资讯以前的文章或继续浏览下面的相关文章希望大家以后多多支持晓枫资讯!

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
晓枫资讯-科技资讯社区-免责声明
免责声明:以上内容为本网站转自其它媒体,相关信息仅为传递更多信息之目的,不代表本网观点,亦不代表本网站赞同其观点或证实其内容的真实性。
      1、注册用户在本社区发表、转载的任何作品仅代表其个人观点,不代表本社区认同其观点。
      2、管理员及版主有权在不事先通知或不经作者准许的情况下删除其在本社区所发表的文章。
      3、本社区的文章部分内容可能来源于网络,仅供大家学习与参考,如有侵权,举报反馈:点击这里给我发消息进行删除处理。
      4、本社区一切资源不代表本站立场,并不代表本站赞同其观点和对其真实性负责。
      5、以上声明内容的最终解释权归《晓枫资讯-科技资讯社区》所有。
http://bbs.yzwlo.com 晓枫资讯--游戏IT新闻资讯~~~

  离线 

TA的专栏

等级头衔

等級:晓枫资讯-列兵

在线时间
0 小时

积分成就
威望
0
贡献
0
主题
0
精华
0
金钱
12
积分
4
注册时间
2022-12-24
最后登录
2022-12-24

发表于 3 天前 | 显示全部楼层
路过,支持一下
http://bbs.yzwlo.com 晓枫资讯--游戏IT新闻资讯~~~
严禁发布广告,淫秽、色情、赌博、暴力、凶杀、恐怖、间谍及其他违反国家法律法规的内容。!晓枫资讯-社区
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

1楼
2楼

手机版|晓枫资讯--科技资讯社区 本站已运行

CopyRight © 2022-2025 晓枫资讯--科技资讯社区 ( BBS.yzwlo.com ) . All Rights Reserved .

晓枫资讯--科技资讯社区

本站内容由用户自主分享和转载自互联网,转载目的在于传递更多信息,并不代表本网赞同其观点和对其真实性负责。

如有侵权、违反国家法律政策行为,请联系我们,我们会第一时间及时清除和处理! 举报反馈邮箱:点击这里给我发消息

Powered by Discuz! X3.5

快速回复 返回顶部 返回列表