Keras: getting the validation data's y_true and y_pred in fit_generator

Last updated: July 4, 2022 (morning)

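When training a network in Keras, fit_generator is very convenient. Each time an epoch finishes, it checks the model's performance on the validation data, and Keras does this with model.evaluate_generator. I needed to extract y_pred for the validation set, though, and could not find an existing way to do it online. So I made small changes to the Keras source to add a configurable option that returns the model's predictions on the validation set. The changes are recorded below.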

How it works

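Reading the source shows that Keras validates with model.evaluate_generator, which ultimately calls TensorFlow's TF_SessionRunCallable function (my backend is tf). That call is tightly wrapped: it takes the data as input, computes the model's predictions, compares them with the true labels, and returns only the values of the loss and metric functions.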

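The predictions are neither stored nor returned along the way, and that part cannot be changed. What we can do is run a prediction on each batch while it is being evaluated, record the results, pass them into epoch_logs, and then use them freely in a callback's on_epoch_end.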
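The idea can be sketched without Keras at all. In the snippet below, DummyModel and evaluate_and_collect are hypothetical stand-ins invented for illustration (they are not Keras code); only the method names test_on_batch and predict_on_batch mirror the real model methods used later in this post.

```python
class DummyModel:
    """Hypothetical stand-in for a compiled Keras model (not Keras code)."""

    def predict_on_batch(self, x):
        # Pretend prediction: echo each input value, clipped to [0, 1].
        return [min(max(v, 0.0), 1.0) for v in x]

    def test_on_batch(self, x, y):
        # Pretend metric: mean absolute error over the batch.
        preds = self.predict_on_batch(x)
        return sum(abs(p - t) for p, t in zip(preds, y)) / len(y)


def evaluate_and_collect(model, batches):
    """Evaluate batch by batch while also recording y_true and y_pred."""
    losses, y_true, y_pred = [], [], []
    for x, y in batches:
        losses.append(model.test_on_batch(x, y))  # what evaluation already does
        y_pred.append(model.predict_on_batch(x))  # the extra prediction pass
        y_true.append(list(y))
    # Plain mean for brevity; Keras weights the average by batch size.
    mean_loss = sum(losses) / len(losses)
    return mean_loss, {'y_true': y_true, 'y_pred': y_pred}


val_batches = [([0.2, 0.9], [0.0, 1.0]), ([1.5, -0.3], [1.0, 0.0])]
epoch_logs = {}
val_loss, gts_and_preds = evaluate_and_collect(DummyModel(), val_batches)
epoch_logs['val_loss'] = val_loss
epoch_logs['val_gts_and_preds'] = gts_and_preds  # available to on_epoch_end
```

The price of this approach is that every validation batch passes through the model twice, once for the metrics and once for the predictions.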

Code changes

  • Keras version 2.2.4. Other versions may need different edits, but the general approach is the same.

model.fit_generator

Find the definition of fit_generator and add the control parameter get_predict:

def fit_generator(self, generator,
                  steps_per_epoch=None,
                  epochs=1,
                  verbose=1,
                  callbacks=None,
                  validation_data=None,
                  validation_steps=None,
                  class_weight=None,
                  max_queue_size=10,
                  workers=1,
                  use_multiprocessing=False,
                  shuffle=True,
                  initial_epoch=0,
                  get_predict=False):  # added: get_predict

    return training_generator.fit_generator(
        self, generator,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_steps=validation_steps,
        class_weight=class_weight,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        get_predict=get_predict)  # added: get_predict

training_generator.fit_generator

Find the definition of training_generator.fit_generator and add get_predict:

def fit_generator(model,
                  generator,
                  steps_per_epoch=None,
                  epochs=1,
                  verbose=1,
                  callbacks=None,
                  validation_data=None,
                  validation_steps=None,
                  class_weight=None,
                  max_queue_size=10,
                  workers=1,
                  use_multiprocessing=False,
                  shuffle=True,
                  initial_epoch=0,
                  get_predict=False):  # added: get_predict

Modify the block that follows the "# Epoch finished." comment. This is where you can see that fit_generator evaluates the validation set with model.evaluate_generator:

# Epoch finished.
if steps_done >= steps_per_epoch and do_validation:
    if val_gen:
        if get_predict:
            # With get_predict=True, evaluate_generator
            # also returns gts_and_preds.
            val_outs, gts_and_preds = model.evaluate_generator(
                val_enqueuer_gen,
                validation_steps,
                workers=0,
                get_predict=get_predict)
        else:
            val_outs = model.evaluate_generator(
                val_enqueuer_gen,
                validation_steps,
                workers=0)
    else:
        # No need for try/except because
        # data has already been validated.
        val_outs = model.evaluate(
            val_x, val_y,
            batch_size=batch_size,
            sample_weight=val_sample_weights,
            verbose=0)
    val_outs = to_list(val_outs)
    # Same labels assumed.
    for l, o in zip(out_labels, val_outs):
        epoch_logs['val_' + l] = o

    # Store gts_and_preds in the logs. Only the generator branch
    # above produces it, so val_gen is checked as well.
    if get_predict and val_gen:
        epoch_logs['val_gts_and_preds'] = gts_and_preds

if callback_model.stop_training:
    break

model.evaluate_generator

Go into model.evaluate_generator and add the get_predict parameter:

def evaluate_generator(self, generator,
                       steps=None,
                       max_queue_size=10,
                       workers=1,
                       use_multiprocessing=False,
                       verbose=0,
                       get_predict=False):  # added: get_predict

    return training_generator.evaluate_generator(
        self, generator,
        steps=steps,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        verbose=verbose,
        get_predict=get_predict)  # added: get_predict

training_generator.evaluate_generator

Go into training_generator.evaluate_generator, add the get_predict parameter, and create three new variables:

def evaluate_generator(model, generator,
                       steps=None,
                       max_queue_size=10,
                       workers=1,
                       use_multiprocessing=False,
                       verbose=0,
                       get_predict=False):  # added: get_predict
    """See docstring for `Model.evaluate_generator`."""
    model._make_test_function()

    if hasattr(model, 'metrics'):
        for m in model.stateful_metric_functions:
            m.reset_states()
        stateful_metric_indices = [
            i for i, name in enumerate(model.metrics_names)
            if str(name) in model.stateful_metric_names]
    else:
        stateful_metric_indices = []

    steps_done = 0
    wait_time = 0.01
    outs_per_batch = []
    batch_sizes = []

    if get_predict:
        preds_dict = {}    # new: dict that holds the results
        gt_per_batch = []  # new: list of y_true, one entry per batch
        pr_per_batch = []  # new: list of y_pred, one entry per batch

In the core loop, while steps_done < steps:, add the prediction logic:

while steps_done < steps:
    generator_output = next(output_generator)
    if not hasattr(generator_output, '__len__'):
        raise ValueError('Output of generator should be a tuple '
                         '(x, y, sample_weight) '
                         'or (x, y). Found: ' +
                         str(generator_output))
    if len(generator_output) == 2:
        x, y = generator_output
        sample_weight = None
    elif len(generator_output) == 3:
        x, y, sample_weight = generator_output
    else:
        raise ValueError('Output of generator should be a tuple '
                         '(x, y, sample_weight) '
                         'or (x, y). Found: ' +
                         str(generator_output))
    outs = model.test_on_batch(x, y, sample_weight=sample_weight)
    outs = to_list(outs)
    outs_per_batch.append(outs)

    # Added: also predict on this batch and keep preds and y_true
    if get_predict:
        preds = model.predict_on_batch(x)
        gt_per_batch.append(y.tolist())
        pr_per_batch.append(preds.tolist())

    if x is None or len(x) == 0:
        # Handle data tensors support when no input given
        # step-size = 1 for data tensors
        batch_size = 1
    elif isinstance(x, list):
        batch_size = x[0].shape[0]
    elif isinstance(x, dict):
        batch_size = list(x.values())[0].shape[0]
    else:
        batch_size = x.shape[0]
    if batch_size == 0:
        raise ValueError('Received an empty batch. '
                         'Batches should contain '
                         'at least one item.')
    steps_done += 1
    batch_sizes.append(batch_size)
    if verbose == 1:
        progbar.update(steps_done)

# Added: save the collected results in the dict
if get_predict:
    preds_dict['y_true'] = gt_per_batch
    preds_dict['y_pred'] = pr_per_batch

Change the return value:

if get_predict:
    return unpack_singleton(averages), preds_dict
else:
    return unpack_singleton(averages)

The core functionality is now in place, but one small problem remains.

keras.callbacks.TensorBoard._write_logs

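Keras's TensorBoard callback writes out the contents of logs, but it only understands numeric values such as int and float. It cannot write the nested dict we stored in the logs to TensorBoard, so _write_logs needs a small adjustment: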

def _write_logs(self, logs, index):
    for name, value in logs.items():
        if name in ['batch', 'size']:
            continue
        summary = tf.Summary()
        summary_value = summary.value.add()
        if isinstance(value, np.ndarray):
            summary_value.simple_value = value.item()
        # Added: skip the dict we generated (no scalar value is set for it)
        elif isinstance(value, dict):
            pass
        else:
            summary_value.simple_value = value
        summary_value.tag = name
        self.writer.add_summary(summary, index)
    self.writer.flush()
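The filtering rule behind that change can be shown in isolation. The helper below is a hypothetical illustration, not part of Keras: it assumes that anything that is a plain number, or a one-element array-like object with an item() method, is safe to write as a scalar, and it sets everything else aside.

```python
import numbers


def split_scalar_logs(logs):
    """Separate scalar entries (safe to write as scalars) from the rest."""
    scalars, others = {}, {}
    for name, value in logs.items():
        if isinstance(value, numbers.Number):
            # Plain ints and floats
            scalars[name] = float(value)
        elif hasattr(value, 'item') and getattr(value, 'size', None) == 1:
            # One-element arrays, e.g. a NumPy array holding a single loss
            scalars[name] = float(value.item())
        else:
            # Dicts, lists and other structures such as val_gts_and_preds
            others[name] = value
    return scalars, others


example_logs = {
    'val_loss': 0.25,
    'val_acc': 1,
    'val_gts_and_preds': {'y_true': [[[0.0, 1.0]]], 'y_pred': [[[0.1, 0.9]]]},
}
scalars, others = split_scalar_logs(example_logs)
```

Here scalars keeps val_loss and val_acc, while val_gts_and_preds ends up in others.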

All done!

Testing

Write any callback with an on_epoch_end method, set get_predict to True, and check whether logs contains the data we want:

model.fit_generator(
    generator=train_data_generator,
    steps_per_epoch=10,
    epochs=config.Epochs,
    verbose=1,
    use_multiprocessing=False,
    validation_data=val_data_generator,
    validation_steps=10,
    callbacks=callbacks,
    get_predict=True
)
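A callback for this test might look like the sketch below. ValPredsLogger is a hypothetical name chosen here, not something from Keras; the only things it relies on are the on_epoch_end(epoch, logs) hook and the val_gts_and_preds key added by the patch above. The import fallback is there only so the snippet can be tried without Keras installed.

```python
try:
    from keras.callbacks import Callback
except ImportError:  # lets the sketch run without Keras installed
    Callback = object


class ValPredsLogger(Callback):
    """Hypothetical callback that picks up the patched-in validation results."""

    def __init__(self):
        super(ValPredsLogger, self).__init__()
        self.collected = []  # one (epoch, result) pair per epoch

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        result = logs.get('val_gts_and_preds')
        if result is None:
            # get_predict was False, or validation did not run this epoch
            return
        self.collected.append((epoch, result))
        print('epoch %d: collected %d validation batches'
              % (epoch, len(result['y_true'])))


callbacks = [ValPredsLogger()]
```

Passing this callbacks list to the fit_generator call above should be enough to confirm that the key arrives in logs.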

Set a breakpoint in the callback and print logs:

logs['val_gts_and_preds']
{'y_pred': [[[2.5419962184969336e-05, 0.9999746084213257],
[0.6694663763046265, 0.33053362369537354],
[0.3561754524707794, 0.643824577331543]],
[[5.548826155499231e-12, 1.0],
[2.701560219975363e-08, 1.0],
[4.0011427699937485e-06, 0.9999959468841553]],
[[7.97858723533551e-11, 1.0],
[2.3924835659272503e-06, 0.999997615814209],
[3.359668880875688e-07, 0.9999996423721313]],
[[0.06622887402772903, 0.9337711930274963],
[4.1211248458239425e-07, 0.9999996423721313],
[8.561290087527595e-06, 0.9999914169311523]],
[[9.313887403550325e-07, 0.9999990463256836],
[2.614793537247806e-08, 1.0],
[8.66139725985704e-06, 0.9999912977218628]],
[[7.047830763440288e-09, 1.0],
[0.010548637248575687, 0.9894513487815857],
[1.8744471252940542e-10, 1.0]],
[[8.760089875714527e-11, 1.0],
[0.0015734446933493018, 0.9984265565872192],
[1.5642463040421717e-06, 0.9999984502792358]],
[[0.004750440828502178, 0.9952495098114014],
[6.984401466070267e-07, 0.9999992847442627],
[0.00013592069444712251, 0.9998641014099121]],
[[7.22906318140204e-11, 1.0],
[2.402198795437016e-08, 1.0],
[9.673745138272238e-10, 1.0]],
[[3.1848256298872e-07, 0.9999996423721313],
[0.0035940599627792835, 0.9964058995246887],
[1.9458911912351162e-11, 1.0]]],
'y_true': [[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],
[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],
[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],
[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],
[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],
[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],
[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],
[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],
[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],
[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]]]}

From here on, the results are yours to use however you like.
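As one example of what to do with them, the sketch below computes accuracy from the logged structure. It assumes one-hot labels and per-class scores nested as [batch][sample][class], as in the output above; the helper names are made up for this illustration.

```python
def flatten_batches(batches):
    """Turn [batch][sample][class] nesting into a flat [sample][class] list."""
    return [sample for batch in batches for sample in batch]


def argmax(row):
    """Index of the largest value in a list."""
    return max(range(len(row)), key=row.__getitem__)


def accuracy_from_logs(result):
    """Fraction of samples whose predicted class matches the true class."""
    y_true = [argmax(r) for r in flatten_batches(result['y_true'])]
    y_pred = [argmax(r) for r in flatten_batches(result['y_pred'])]
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / float(len(y_true))


# First batch of the output above, with values rounded for brevity
sample = {
    'y_true': [[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]]],
    'y_pred': [[[0.00003, 0.99997], [0.66947, 0.33053], [0.35618, 0.64382]]],
}
acc = accuracy_from_logs(sample)
```

For anything beyond a quick check, converting the nested lists to NumPy arrays is probably more convenient.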

With get_predict set to False, none of our changes take effect and the behavior is identical to the original Keras code.

I have not found any other problems so far. If anything looks off, feel free to get in touch.


https://www.zywvvd.com/notes/study/deep-learning/keras/get-gts-and-preds-from-evaluator/get-gts-and-preds-from-evaluator/
Author
Yiwei Zhang
Published
June 10, 2020
License