本文最后更新于:2024年5月7日 下午
在前文基础上,我们已经获得了数据、张量和损失函数, 本文介绍 Pytorch 的进行训练和评估的核心流程 。
参考 深入浅出PyTorch ,系统补齐基础知识。
本节目录
- PyTorch的训练/评估模式的开启
- 完整的训练/评估流程
模型模式
首先应该设置模型的状态:如果是训练状态,那么模型的参数应该支持反向传播的修改;如果是验证/测试状态,则不应该修改模型参数。在PyTorch中,模型的状态设置非常简便,如下的两个操作二选一即可:
1 |
|
model.train()
model.train()
的作用是启用 Batch Normalization 和 Dropout。
如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加 model.train()
。model.train()
是保证BN层能够用到每一批数据的均值和方差。对于Dropout,model.train()
是随机取一部分网络连接来训练更新参数。
model.eval()
model.eval()
的作用是不启用 Batch Normalization 和 Dropout。
如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()
。model.eval()
是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout,model.eval()
是利用到了所有网络连接,即不进行随机舍弃神经元。
训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval()
,否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有BN层和Dropout所带来的的性质。
训练流程
数据加载
我们前面在DataLoader构建完成后介绍了如何从中读取数据,在训练过程中使用类似的操作即可,区别在于此时要用for循环读取DataLoader中的全部数据。
1 |
|
之后将数据放到GPU上用于后续计算,此处以.cuda()为例
1 |
|
zero_grad
开始用当前批次数据做训练时,应当先将优化器的梯度置零:
1 |
|
函数会遍历模型的所有参数,通过内置方法截断反向传播的梯度流,再将每个参数的梯度值设为0,即上一次的梯度记录被清空。
loss
之后将data送入模型中训练:
1 |
|
根据预先定义的criterion计算损失函数:
1 |
|
backward
将loss反向传播回网络:
1 |
|
PyTorch的反向传播(即tensor.backward())是通过autograd包来实现的,autograd包会根据tensor进行过的数学运算来自动计算其对应的梯度。
具体来说,torch.tensor是autograd包的基础类,如果你设置tensor的requires_grads为True,就会开始跟踪这个tensor上面的所有运算,如果你做完运算后使用tensor.backward(),所有的梯度就会自动运算,tensor的梯度将会累加到它的.grad属性里面去。
更具体地说,损失函数loss是由模型的所有权重w经过一系列运算得到的,若某个w的requires_grads为True,则w的所有上层参数(后面层的权重w)的.grad_fn属性中就保存了对应的运算,然后在使用loss.backward()后,会一层层的反向传播计算每个w的梯度值,并保存到该w的.grad属性中。
如果没有进行tensor.backward()的话,梯度值将会是None,因此loss.backward()要写在optimizer.step()之前。
更新参数
使用优化器更新模型参数:
1 |
|
step()函数的作用是执行一次优化步骤,通过梯度下降法来更新参数的值。因为梯度下降是基于梯度的,所以在执行optimizer.step()函数前应先执行loss.backward()函数来计算梯度。
注意:optimizer只负责通过梯度下降进行优化,而不负责产生梯度,梯度是tensor.backward()方法产生的。
验证流程
验证/测试的流程基本与训练过程一致,不同点在于:
- 需要预先设置torch.no_grad,以及将model调至eval模式
- 不需要将优化器的梯度置零
- 不需要将loss反向回传到网络
- 不需要更新optimizer
示例代码
一个完整的图像分类的训练过程如下所示:
1 |
|
对应的,一个完整图像分类的验证过程如下所示:
1 |
|
参考资料
文章链接:
https://www.zywvvd.com/notes/study/deep-learning/pytorch/torch-learning/torch-learning-7/
“觉得不错的话,给点打赏吧 ୧(๑•̀⌄•́๑)૭”
微信支付
支付宝支付