Python - random 和 numpy.random 线程安全

本文最后更新于:2021年11月1日 中午

代码中经常会用到随机的部分,此时需要使用程序自带的伪随机数发生器,本文探讨python随机数发生器的线程安全相关内容。

对比内容

  • python 原生 random 库
  • numpy 中 random 包

随机数安全需求

  • 我们需要随机数,但是特定条件下需要稳定的随机

  • 这表示我们需要产生固定的随机数,在保证算法或程序正常运行的同时保证结果稳定可复现,对于调试程序是否有必要

  • 安全需求为:在多线程情况下仍然可以保证稳定的伪随机

random

random 确定随机序列的方法有 seed 和 state 两种

random.seed(n)

可以使得随机数发生器以 $n$ 为种子产生随后的序列

  • 当运行 random.seed() 时表明使用当前系统时间作为随机种子,也就是随机重置随机数发生器
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import random


def get_random_num(tag):
random.seed()
for _ in range(100):
random.random()
print(tag)
print(random.random())


if __name__ == '__main__':

for index in range(5):
get_random_num(index)

输出

1
2
3
4
5
6
7
8
9
10
0
0.08855079666960641
1
0.9249561135155114
2
0.847403937717389
3
0.9581127578680636
4
0.3559537092834082

这表明变化的seed条件会产生不同的随机序列

  • 当固定随机种子时
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import random


def get_random_num(tag):
random.seed(7) # 固定随机种子为 7
for _ in range(100):
random.random()
print(tag)
print(random.random())


if __name__ == '__main__':

for index in range(5):
get_random_num(index)

输出

1
2
3
4
5
6
7
8
9
10
0
0.17621772849037032
1
0.17621772849037032
2
0.17621772849037032
3
0.17621772849037032
4
0.17621772849037032

这表明固定的seed会产生相同的随机序列

random.setstate(state)

random.setstate 可以将随机数发生状态设置为特定的某个情况

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import random


def get_random_num(tag, state):
random.setstate(state)
for _ in range(100):
random.random()
print(tag)
print(random.random())


if __name__ == '__main__':
cur_state = random.getstate()
for index in range(5):
get_random_num(index, cur_state)

输出

1
2
3
4
5
6
7
8
9
10
0
0.7362097058247947
1
0.7362097058247947
2
0.7362097058247947
3
0.7362097058247947
4
0.7362097058247947

表明固定的state会产生相同的随机序列

random.seed 线程安全

我们设计一个稍微复杂一些的多线程随机数发生的情况

程序会使用单线程和多线程的方法产生随机数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import threading
import numpy as np
import random
import time


def get_random_num(tag):
random.seed(7)
for _ in range(random.randint(10, 20)):
time.sleep(0.1 * random.random())
random.random()

print(tag, '-', random.random())


if __name__ == '__main__':

print("############ 单线程 ###########")
for index in range(5):
get_random_num(index)

print("############ 多线程 ###########")
thread_list = list()
for index in range(5):
t = threading.Thread(target=get_random_num, args=(index, ))
thread_list.append(t)
t.start()

for t in thread_list:
t.join()
pass

输出

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
-> 第一次运行

############ 单线程 ###########
0 - 0.1878710267871435
1 - 0.1878710267871435
2 - 0.1878710267871435
3 - 0.1878710267871435
4 - 0.1878710267871435
############ 多线程 ###########
3 - 0.07031557615348971
2 - 0.9554680239214713
4 - 0.48806805903541084
1 - 0.4803951046156485
0 - 0.6920567688453093

-> 第二次运行

############ 单线程 ###########
0 - 0.1878710267871435
1 - 0.1878710267871435
2 - 0.1878710267871435
3 - 0.1878710267871435
4 - 0.1878710267871435
############ 多线程 ###########
4 - 0.125491512495977
3 - 0.4406268683247505
1 - 0.9554680239214713
2 - 0.4803951046156485
0 - 0.6920567688453093

-> 第三次运行

############ 单线程 ###########
0 - 0.1878710267871435
1 - 0.1878710267871435
2 - 0.1878710267871435
3 - 0.1878710267871435
4 - 0.1878710267871435
############ 多线程 ###########
4 - 0.19060953756680787
3 - 0.48806805903541084
0 - 0.4803951046156485
1 2 -- 0.74035122442809410.6920567688453093

可以看到多线程会打乱本来稳定的随机数发生器序列,产生不再那么稳定的随机数

random.setstate() 线程安全

我们将 random.seed 替换为 random.setstate

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import threading
import numpy as np
import random
import time


def get_random_num(tag, state):
random.setstate(state)
for _ in range(random.randint(10, 20)):
time.sleep(0.1 * random.random())
random.random()

print(tag, '-', random.random())


if __name__ == '__main__':

cur_state = random.getstate()

print("############ 单线程 ###########")
for index in range(5):
get_random_num(index, cur_state)

print("############ 多线程 ###########")
thread_list = list()
for index in range(5):
t = threading.Thread(target=get_random_num, args=(index, cur_state, ))
thread_list.append(t)
t.start()

for t in thread_list:
t.join()
pass

输出

1
2
3
4
5
6
7
8
9
10
11
12
############  单线程  ###########
0 - 0.15376246243379788
1 - 0.15376246243379788
2 - 0.15376246243379788
3 - 0.15376246243379788
4 - 0.15376246243379788
############ 多线程 ###########
0 - 0.37956604279157746
4 - 0.5552055004170326
2 - 0.40568119200883823
3 - 0.09736679342311894
1 - 0.9874404365309796

可以看到多线程输出的还是纷乱的随机数,表明设置状态还是会受到多线程的干扰

得出综合结论: python自带 random 模块线程不安全

numpy.random

numpy 也存在 seed 和 state 两种随机数状态设定策略

二者固定时也可以确定随机数发生序列,我们直接进入线程安全实验

numpy.random.seed 线程安全

设置和random模块测试相同的程序,仅替换随机数产生器为numpy

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import threading
import numpy as np

import time


def get_random_num(tag):
np.random.seed(7)
for _ in range(np.random.randint(10, 20)):
time.sleep(0.1 * np.random.random())
np.random.random()

print(tag, '-', np.random.random())


if __name__ == '__main__':

print("############ 单线程 ###########")
for index in range(5):
get_random_num(index)

print("############ 多线程 ###########")
thread_list = list()
for index in range(5):
t = threading.Thread(target=get_random_num, args=(index, ))
thread_list.append(t)
t.start()

for t in thread_list:
t.join()
pass

输出

1
2
3
4
5
6
7
8
9
10
11
12
############  单线程  ###########
0 - 0.4677528597449807
1 - 0.4677528597449807
2 - 0.4677528597449807
3 - 0.4677528597449807
4 - 0.4677528597449807
############ 多线程 ###########
2 - 0.7969513574435644
1 - 0.27425859319442125
3 - 0.8767009301816353
0 - 0.8078354592960695
4 - 0.16471665877901287

numpy 的 seed 也没有抗住多线程测试

numpy.random.set_state(state) 线程安全

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import threading
import numpy as np

import time


def get_random_num(tag, state):
np.random.set_state(state)
for _ in range(np.random.randint(10, 20)):
time.sleep(0.1 * np.random.random())
np.random.random()

print(tag, '-', np.random.random())


if __name__ == '__main__':
state = np.random.get_state()
print("############ 单线程 ###########")
for index in range(5):
get_random_num(index, state)

print("############ 多线程 ###########")
thread_list = list()
for index in range(5):
t = threading.Thread(target=get_random_num, args=(index, state, ))
thread_list.append(t)
t.start()

for t in thread_list:
t.join()
pass

输出

1
2
3
4
5
6
7
8
9
10
11
12
############  单线程  ###########
0 - 0.08401109992998335
1 - 0.08401109992998335
2 - 0.08401109992998335
3 - 0.08401109992998335
4 - 0.08401109992998335
############ 多线程 ###########
3 - 0.3760344064206267
1 - 0.9126801411121602
4 - 0.3426880186919875
0 - 0.11892388333372905
2 - 0.8831013606778471

仍然不是线程安全

问题分析

  • 总结下来,random模块和numpy模块的 seed 和 state 系列方法都没有做到线程安全

  • 事实上setstate 一类的方法和 seed 方法原理相同,都是设置随机数发生器的初始状态,问题在于这种设置是全局的

  • 当多线程穿插使用时会打乱这个序列

  • 因此线程安全的随机数发生器必须做到相互隔离

  • 解决问题的终极方案为 numpy.random.RandomState

numpy.random.RandomState

RandomState方法之所以解决问题,在于它不仅设置了随机数发生器的初始状态,也会生成一个随机数发生器实例,产生一个独立的变量生成随机数

只要不是同一个实例,相互之间就不会产生影响

上代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import threading
import numpy as np

import time


def get_random_num(tag):
rand_obj = np.random.RandomState(7)
for _ in range(rand_obj.randint(10, 20)):
time.sleep(0.1 * rand_obj.random())
rand_obj.random()

print(tag, '-', rand_obj.random())


if __name__ == '__main__':

print("############ 单线程 ###########")
for index in range(5):
get_random_num(index)

print("############ 多线程 ###########")
thread_list = list()
for index in range(5):
t = threading.Thread(target=get_random_num, args=(index, ))
thread_list.append(t)
t.start()

for t in thread_list:
t.join()
pass

输出

1
2
3
4
5
6
7
8
9
10
11
12
############  单线程  ###########
0 - 0.4677528597449807
1 - 0.4677528597449807
2 - 0.4677528597449807
3 - 0.4677528597449807
4 - 0.4677528597449807
############ 多线程 ###########
2 - 0.4677528597449807
043 - - 0.46775285974498070.4677528597449807

1 - - 0.46775285974498070.4677528597449807

这里输出是乱的,解释一下,这不是我的笔误,是因为随机数完全相同,几个线程的运行时间相同,就会在同一时间向终端输出内容,导致输出有点乱

不过还是可以看出来每个发生器产生的随机数完全相同,证实了 RandomState 的线程安全性

结论

  1. seed , state 一类方法可以确定随机数发生序列,但这种全局配置的随机数确定序列做不到线程安全

  2. 线程安全需要确定序列的同时创建线程内的随机数发生器实例,保证线程之间互不影响,才会产生真正的随机序列

  3. numpy.random.RandomState —— YYDS

参考资料