本文最后更新于:2024年1月14日 晚上
在Keras网络训练过程中,
fit-generator
为我们提供了很多便利。调用fit-generator
时,每个epoch训练结束后会使用验证数据检测模型性能,Keras使用model.evaluate_generator
提供该功能。然而我遇到了需要提取验证集y_pred的需求,在网上没有找到现有的功能实现方法,于是自己对源码进行了微调,实现了可配置提取验证集模型预测结果的功能,记录如下。
原理简介
通过查看源代码,发现Keras调用了
model.evaluate_generator
验证数据,该函数最终调用的是TensorFlow(我用的后端是tf)的TF_SessionRunCallable
函数,封装得很死,功能是以数据为输入,输出模型预测的结果并与真实标签比较并计算评价函数得到结果。过程中不保存、不返回预测结果,这部分没有办法修改,但可以在评价数据的同时对数据进行预测,得到结果并记录下来,传入到
epoch_logs
中,随后在回调函数的on_epoch_end
中尽情使用。
代码修改
- Keras版本 2.2.4 其他版本不保证一定使用相同的方法,但大体思路不变
model.fit_generator
找到
fit_generator
函数定义位置,加入控制参数get_predict
:
1 |
|
1 |
|
training_generator.fit_generator
找到
training_generator.fit_generator
定义位置,加入get_predict
:
1 |
|
修改 # Epoch finished. 注释后的模块,可以看到Keras中
fit_generator
就是用model.evaluate_generator
对验证集评估的:
1 |
|
model.evaluate_generator
进入
model.evaluate_generator
函数,加入get_predict
变量:
1 |
|
1 |
|
training_generator.evaluate_generator
进入
training_generator.evaluate_generator
,添加get_predict
变量,新建三个变量:
1 |
|
在核心循环
while steps_done < steps:
中加入预测变量的内容:
1 |
|
修改返回值:
1 |
|
至此核心的功能已经实现,但还有一个小问题。
keras.callbacks.TensorBoard._write_logs
Keras的Tensorboard会记录logs中的内容,但是他只认识 int, float 等数值格式,我们保存在log中的复杂字典他没办法写入tesnorboard,需要对
_write_logs
做微小的调整:
1 |
|
大功告成!
测试
随便写个带
on_epoch_end
的回调函数,将get_predict
设置为True,测试logs中是否有我们想要的数据:
1 |
|
回调函数设断点,输出logs:
1 |
|
之后这些结果任君处置了;
附
将
get_predict
设为 False 时则屏蔽了我们做出的所有修改,与原始Keras代码完全相同;目前没有发现其他的问题,有任何不对头可以随时交流。
“觉得不错的话,给点打赏吧 ୧(๑•̀⌄•́๑)૭”
微信支付
支付宝支付