Pytorch是最近兴起的新的深度学习框架,不同于tensorflow的先构建计算图再进行运算,它是动态构建计算图,因此更为通俗易懂。
本文主要使用Pytorch来完成对三次函数的拟合任务,并比较集中不同的激活函数对预测结果的影响。
手动构造的三次函数如下(为了模拟真实场景,函数中加入了一些随机噪声)
在Pytorch中,为了完成模型的搭建,我们需要创建一个继承torch.nn.Module的类,类中我们需要实现forward函数和init函数。
代码如下:
import torchimport torch.nn.functional as Ffrom torch.autograd import Variableimport matplotlib.pyplot as pltx = torch.unsqueeze(torch.linspace(-1, 1, 300), dim=1) y = x.pow(3) - x.pow(2) + 0.2*torch.rand(x.size())plt.scatter(x.data.numpy(), y.data.numpy())plt.show()class Net(torch.nn.Module): # 继承 torch 的 Module def __init__(self, n_feature, n_hidden1, n_hidden2, n_output): super(Net, self).__init__() # 继承 __init__ 功能 # 定义每层用什么样的形式 self.hidden1 = torch.nn.Linear(n_feature, n_hidden1) # 隐藏层线性输出 self.hidden2 = torch.nn.Linear(n_hidden1, n_hidden2) # 隐藏层线性输出 self.predict = torch.nn.Linear(n_hidden2, n_output) # 输出层线性输出 def forward(self, x): # 这同时也是 Module 中的 forward 功能 # 正向传播输入值, 神经网络分析出输出值 x = F.relu(self.hidden1(x)) # 激励函数(隐藏层的线性值) x = F.relu(self.hidden2(x)) # 激励函数(隐藏层的线性值) x = self.predict(x) # 输出值 return x net = Net(n_feature=1, n_hidden1=10, n_hidden2=10, n_output=1)#创建一个类的对象optimizer = torch.optim.SGD(net.parameters(), lr=0.2) # 传入 net 的所有参数, 创建一个SGD优化器loss_func = torch.nn.MSELoss() # 定义损失函数plt.ion() # 画图plt.show()for t in range(300):#总共进行300代的优化 prediction = net(x) # 喂给 net 训练数据 x, 输出预测值 loss = loss_func(prediction, y) # 计算两者的误差 optimizer.zero_grad() # 清空上一步的残余更新参数值 loss.backward() # 误差反向传播, 计算参数更新值 optimizer.step() # 将参数更新值施加到 net 的 parameters 上 if t % 5 == 0: #每5步更新一次图像 plt.cla() plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'}) plt.pause(0.1)
比较一下使用不同激活函数的拟合效果:
RELU函数
softplus函数:
tanh函数:
对比三个激活函数的拟合效果,我们可以发现,经过relu函数拟合的曲线带有明显的“棱角”,这也是有relu函数自身不可导的特性所致。但是这种棱角特性在加深网络后就变得不再明显。同时,relu函数也是预测结果最好的激活函数,这符合我们的预期。
在过去的几十年间,大量的编程语言被发明、被取代、被修改或组合在一起。尽管人们多次试图创造一种通用的程序设计语言,却没有一次尝试是成功的。之所以有那么多种不同的编程语言存在的原因是,编写程序的初衷其实也各不相同;新手与老手之间技术的差距非常大,而且有许多语言对新手来说太难学;还有,不同程序之间的运行成本(runtime cost)各不相同。