-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathseed_everything.py
36 lines (25 loc) · 1.15 KB
/
seed_everything.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# basic random seed
import os
import random
import numpy as np
# torch random seed
import torch
DEFAULT_RANDOM_SEED = 2022
def seedBasic(seed=DEFAULT_RANDOM_SEED):
random.seed(seed) # python 내장함수 random에 seed를 추가합니다.
os.environ["PYTHONHASHSEED"] = str(seed) # python hash에 seed를 추가합니다.
np.random.seed(seed) # numpy library에 seed를 추가합니다.
def seedTorch(seed=DEFAULT_RANDOM_SEED): # pytorch를 위한 seed를 추가합니다.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # if use multi-GPU
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# basic + torch
def seedEverything(seed=DEFAULT_RANDOM_SEED): # python 내장과 torch 모두에 seed를 추가합니다.
seedBasic(seed)
seedTorch(seed)
def _init_fn(worker_id): # DataLoader에 seed를 추가하기 위한 함수입니다.
np.random.seed(
int(DEFAULT_RANDOM_SEED + worker_id)
) # default seed에 worker id를 더하여, 2개 이상의 worker에도 각각 seed를 부여하고, data loading이 재현가능하도록 합니다.