网站首页 > 技术文章 正文
简单地说,找到给定方程的最小值被认为是优化。这在现实生活中有很多应用:路径规划,工作车间调度,空中交通管理等。优化一直是机器学习的支柱,人们期望这些算法从海量数据中提取知识。
优化在神经网络中起着重要作用,神经网络中有数百万个参数,目标是找到正确的参数集以正确表示数据。尽管优化器的性能已经有了很大的提高,但是还有一个优化所依赖的问题,即初始点。优化的轨迹在很大程度上取决于初始点。
在这篇文章中,我们将看到初始点是如何影响一些优化算法的性能的。虽然我们在这里使用的是一个二维问题(因为它很容易可视化),但当参数(神经网络)达到数百万时,初始化问题就会变得更加普遍。
目标
初始化x、y,使用梯度下降算法找到x、y的最优值,使Beale函数的值为零(或尽可能低)。
优化算法简介
我们将考虑三种流行的优化算法,因为我们更加关注初始化,这些对于我们的分析就足够了。
- 随机梯度下降:随机梯度下降(SGD)算法每次执行一次更新,计算每一步的梯度。。
- momentum:通过考虑梯度在一段时间内的动量,解决了随机梯度下降更新缓慢的问题。
- Adam:被认为是最流行的优化算法。。
我们将使用PyTorch的autograd功能来获得梯度,使用matplotlib来绘制轨迹。首先导入Python库:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.mplot3d
import Axes3D from matplotlib.colors
import LogNorm
import warnings
warnings.filterwarnings("ignore")
Beale函数
- Beale函数是在二维中定义的多峰非凸连续函数。
- 通常在(x,y)∈[-4.5,4.5]范围内进行评估。
- 该函数只有一个全局最小值(x,y)=(3,0.5)。
可视化Beale函数
由于Beale函数是一个介于-4.5和4.5之间的双变量函数,我们可以使用NumPy生成一个网格,将所有可能的x和y值传递给该函数。这使我们能够在每一个可能的点上得到Beale函数的输出,我们可以使用这些输出将函数可视化。
当我们将优化问题与神经网络联系起来时,我们将(x,y)称为(w1,w2)。当使用神经网络时,我们将目标函数称为损失函数,并将函数的输出称为损失。在这种情况下,我们将Beale函数称为损失函数,将输出称为损失。
# Defining function
f = lambda x, y: (1.5 - x + x*y)**2 + (2.25 - x + x*y**2)**2 + (2.625 - x + x*y**3)**2
# Defining the range of w1 and w2, step size
w1_min, w1_max, w1_step = -4.5, 4.5, .2
w2_min, w2_max, w2_step = -4.5, 4.5, .2
# Global minima of the function
minima_ = [3, 0.5]
# generating meshgrid
w1, w2 = np.meshgrid(np.arange(w1_min, w1_max + w1_step, w1_step),
np.arange(w2_min, w2_max + w2_step, w2_step))
losses = f(w1, w2)
现在,我们将使用以下Python代码绘制损失。
fig, ax = plt.subplots(figsize=(10, 6))
ax.contour(w1, w2, losses, levels=np.logspace(0, 5, 35),
norm=LogNorm(), cmap=plt.cm.jet, alpha = 0.8)
ax.plot(*minima_, 'r*', color='r',
markersize=10, alpha=0.7, label='minima')
ax.set_xlabel('w1')
ax.set_ylabel('w2')
ax.set_xlim((w1_min, w1_max))
ax.set_ylim((w2_min, w2_max))
ax.legend(bbox_to_anchor=(1.2, 1.))
ax.set_title("Beale Function")
fig.tight_layout(rect=[0, 0.03, 1, 0.95])
输出
如图所示,蓝色区域表示Beale函数的值较低,红色区域表示Beale函数的值较高。最小值(3,0.5)用星号表示。
设定参数
当我们使用PyTorch时,我们需要将要优化的参数放入nn.Module类中。__init__()以(X,Y)作为输入来初始化参数(W1,W2)。另外,我们将在forward函数中编写Beale方程。
class Net_Beale(torch.nn.Module):
def __init__(self, x, y):
super(Net_Beale, self).__init__()
self.w1 = torch.nn.Parameter(torch.tensor([x]))
self.w2 = torch.nn.Parameter(torch.tensor([y]))
def forward(self):
# Beale Function Equation
a = (1.5 - self.w1 + self.w1*self.w2)**2
b = (2.25 - self.w1 + self.w1*self.w2**2)**2
c = (2.625 - self.w1 + self.w1*self.w2**3)**2
return a+b+c
优化和保存轨迹
下面的函数初始化网络的参数,初始化优化器,并在收集参数路径的同时针对指定的步骤数运行优化。
def get_trajectory(x, y, optim, lr, epochs, interval=1):
# Initialize Network
net = Net_Beale(x,y)
# Initialize Optimizer
if optim == "sgd":
optim = torch.optim.SGD(net.parameters(), lr)
elif optim == "mom":
optim = torch.optim.SGD(net.parameters(), lr, momentum=0.9)
elif optim == "adam":
optim = torch.optim.Adam(net.parameters(), lr)
# Initialize Trackers
w_1s = []
w_2s = []
# Run Optimization
for i in range(epochs):
optim.zero_grad()
o = net()
o.backward()
if i % interval == 0:
# Append current w1 and w2 to trackers
w_1s.append(net.w1.item())
w_2s.append(net.w2.item())
optim.step()
w_1s.append(net.w1.item())
w_2s.append(net.w2.item())
# Join w1's and w2's into one array
trajectory = np.array([w_1s, w_2s])
return trajectory
轨迹之间的比较
下面的函数给出了初始位置、优化器列表以及相应的学习率和epoch,并绘制了具有指定设置的算法轨迹。
def compare_trajectories(x, y, epochs, optims, lrs):
colors = ['k', 'g', 'b', 'r', 'y', 'c', 'm']
trajectories = []
names = []
# Loop on all optimizers in list
for ep, optim, lr in zip(epochs, optims, lrs):
trajectory = get_trajectory(float(x), float(y), optim=optim, lr=lr, epochs=ep)
names.append(optim)
trajectories.append(trajectory)
# Plot the Contour plot of Beale Function and trajectories of optimizers
fig, ax = plt.subplots(figsize=(10, 6))
ax.contour(w1, w2, losses, levels=np.logspace(0, 5, 35),
norm=LogNorm(), cmap=plt.cm.jet, alpha = 0.5)
for i, trajectory in enumerate(trajectories):
ax.quiver(trajectory[0,:-1], trajectory[1,:-1], trajectory[0,1:]-trajectory[0,:-1],
trajectory[1,1:]-trajectory[1,:-1], scale_units='xy', angles='xy', scale=1,
color=colors[i], label=names[i], alpha=0.8)
start_ =[x,y]
ax.plot(*start_, 'r*', color='k',markersize=10, alpha=0.7, label='start')
ax.plot(*minima_, 'r*', color='r',markersize=10, alpha=0.7, label='minima')
ax.set_xlabel('w1')
ax.set_ylabel('w2')
ax.set_xlim((w1_min, w1_max))
ax.set_ylim((w2_min, w2_max))
ax.set_title("Initial point - ({},{})".format(x,y))
ax.legend(bbox_to_anchor=(1.2, 1.))
fig.suptitle("Optimization Trajectory")
fig.tight_layout(rect=[0, 0.03, 1, 0.95])
尝试不同的初始点
在设置好一切之后,我们现在准备用不同的初始点来比较这三种算法。
学习率设置:
- SGD— 0.0001
- momentum— 0.0001
- Adam-0.01
我们将对不同算法的初始点使用相同的学习率,以保持分析的简单性,因为我们没有进行超参数调优。
# Settings for optimizers
epochs = [10000] * 3
optims = ['sgd', 'mom', 'adam']
lrs = [0.0001, 0.0001, 0.01]
情况1:接近极小点
# A point closer to minima
x = 2.5
y = 2.
compare_trajectories(x, y, epochs, optims, lrs)
这三个都达到了全局最小值,让我们再进一步看看会发生什么。
情况2:离极小点远一点
# A little away in the same region
x = 1.5
y = 2.5
compare_trajectories(x, y, epochs, optims, lrs)
如上图所示,Adam优化器趋向于一个局部的最小值并停滞不前,而sgd和momentum则达到了全局最小值。需要注意的是,我们并没有改变这里的学习率,我们关注的是初始点对优化的影响。
情况3:远离极小值
# Lower left region
x = -4
y = -4
compare_trajectories(x, y, epochs, optims, lrs)
结论
初始点在优化问题中起着至关重要的作用。在这里,我们试图解决一个二维问题,与使用大型数据集和上百万个参数(维度)时找到最小值相比,这很容易。虽然我们在这里没有调优超参数,但是我们可以使用正确的超参数集有效地将优化推向正确的方向。
猜你喜欢
- 2024-10-14 Python之Matplotlib数据可视化一:简易线形图
- 2024-10-14 圆:circle-sin-cos动画的matplotlib
- 2024-10-14 python 100天 68 利用Python绘制两个波形正弦sin函数相关性
- 2024-10-14 画直线不简单!python-matplotlib告诉你为什么
- 2024-10-14 用Python下一场流星雨,女生看了都哭了
- 2024-10-14 手把手教你使用Numpy、Matplotlib、Scipy等5个Python库
- 2024-10-14 走进Matplotlib世界(一)(matplotlib.org)
- 2024-10-14 Python 数据分析——matplotlib 坐标变换和注释
- 2024-10-14 利用axe对象绘制地图局部缩放图(下面几种建模对象能通过基本实体工具直接绘制的是)
- 2024-10-14 Python动态绘图的方法(上)(canvas python动态绘图)
- 最近发表
- 标签列表
-
- cmd/c (64)
- c++中::是什么意思 (83)
- 标签用于 (65)
- 主键只能有一个吗 (66)
- c#console.writeline不显示 (75)
- pythoncase语句 (81)
- es6includes (73)
- sqlset (64)
- windowsscripthost (67)
- apt-getinstall-y (86)
- node_modules怎么生成 (76)
- chromepost (65)
- c++int转char (75)
- static函数和普通函数 (76)
- el-date-picker开始日期早于结束日期 (70)
- localstorage.removeitem (74)
- vector线程安全吗 (70)
- & (66)
- java (73)
- js数组插入 (83)
- linux删除一个文件夹 (65)
- mac安装java (72)
- eacces (67)
- 查看mysql是否启动 (70)
- 无效的列索引 (74)