forked from Uminosachi/sd-webui-inpaint-anything
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathia_threading.py
172 lines (127 loc) · 4.73 KB
/
ia_threading.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import gc
import threading
from contextlib import ContextDecorator
from functools import wraps
import torch
from modules import devices, safe, shared
from modules.sd_models import load_model, reload_model_weights
# State captured while the SD model is temporarily moved off its device or
# swapped for another checkpoint; each entry is reset to None once restored.
backup_sd_model = None  # model instance moved to CPU by pre_offload_model_weights
backup_device = None  # device that backup_sd_model is moved back to on restore
backup_ckpt_info = None  # checkpoint info to reload after a temporary swap
# Single-permit semaphore handed to every worker thread started in this
# module, so offload/reload operations on the shared model never overlap.
model_access_sem = threading.Semaphore(1)
def clear_cache():
    """Run Python's garbage collector, then webui's ``devices.torch_gc``.

    ``gc.collect`` goes first so unreachable Python objects are dropped before
    webui is asked to release device memory (presumably the torch CUDA cache;
    see ``modules.devices``).
    """
    gc.collect()
    release_device_memory = devices.torch_gc
    release_device_memory()
def webui_reload_model_weights(sd_model=None, info=None):
    """Load checkpoint ``info`` into ``sd_model``, falling back to a full load.

    webui's ``reload_model_weights`` is tried first. If it raises any
    ``Exception``, ``load_model(checkpoint_info=info)`` is called instead; an
    exception from that fallback propagates (chained to the first failure).
    """
    reload_kwargs = {"sd_model": sd_model, "info": info}
    try:
        reload_model_weights(**reload_kwargs)
    except Exception:
        # Fast path failed for whatever reason: do a full load of the checkpoint.
        load_model(checkpoint_info=info)
def pre_offload_model_weights(sem):
    """Move the current SD model to CPU, remembering it for a later restore.

    While holding ``sem``: if ``shared.sd_model`` is set, store it in
    ``backup_sd_model``, store its ``device`` attribute (or ``devices.device``
    when it has none) in ``backup_device``, move it to ``devices.cpu`` and
    clear caches. Does nothing when no model is loaded.
    """
    global backup_sd_model, backup_device, backup_ckpt_info
    with sem:
        if shared.sd_model is None:
            return
        backup_sd_model = shared.sd_model
        backup_device = getattr(backup_sd_model, "device", devices.device)
        backup_sd_model.to(devices.cpu)
        clear_cache()
def await_pre_offload_model_weights():
    """Run ``pre_offload_model_weights`` on a worker thread and wait for it.

    The worker receives the module-wide ``model_access_sem``. As with any
    ``threading.Thread``, an exception raised by the target is reported via
    ``threading.excepthook`` rather than re-raised here.
    """
    worker = threading.Thread(target=pre_offload_model_weights, args=(model_access_sem,))
    worker.start()
    worker.join()
def pre_reload_model_weights(sem):
    """Undo any pending offload and checkpoint swap, holding ``sem``.

    1. If a model was offloaded, move it back to ``backup_device`` and clear
       both ``backup_sd_model`` and ``backup_device``.
    2. If a checkpoint swap is pending (``backup_ckpt_info`` set) and a model
       is loaded, reload the saved checkpoint into ``shared.sd_model`` and
       clear ``backup_ckpt_info``.
    """
    global backup_sd_model, backup_device, backup_ckpt_info
    with sem:
        offload_pending = backup_sd_model is not None and backup_device is not None
        if offload_pending:
            backup_sd_model.to(backup_device)
            backup_sd_model = backup_device = None
        # Same short-circuit order as before: shared.sd_model is checked first.
        if shared.sd_model is None or backup_ckpt_info is None:
            return
        webui_reload_model_weights(sd_model=shared.sd_model, info=backup_ckpt_info)
        backup_ckpt_info = None
def await_pre_reload_model_weights():
    """Run ``pre_reload_model_weights`` on a worker thread and wait for it.

    The worker receives the module-wide ``model_access_sem``; an exception in
    the worker goes to ``threading.excepthook`` instead of being re-raised.
    """
    worker = threading.Thread(target=pre_reload_model_weights, args=(model_access_sem,))
    worker.start()
    worker.join()
def backup_reload_ckpt_info(sem, info):
    """Temporarily switch the loaded model to checkpoint ``info``, holding ``sem``.

    Any pending CPU offload is undone first (model moved back to
    ``backup_device``). Then, if a model is loaded, its current
    ``sd_checkpoint_info`` is saved in ``backup_ckpt_info`` so the reload
    helpers can restore it, and ``info`` is loaded into it.

    NOTE(review): ``backup_ckpt_info`` is overwritten unconditionally, so a
    second call before the restore runs loses the original checkpoint;
    confirm callers always restore in between.
    """
    global backup_sd_model, backup_device, backup_ckpt_info
    with sem:
        offload_pending = backup_sd_model is not None and backup_device is not None
        if offload_pending:
            backup_sd_model.to(backup_device)
            backup_sd_model = backup_device = None
        if shared.sd_model is None:
            return
        backup_ckpt_info = shared.sd_model.sd_checkpoint_info
        webui_reload_model_weights(sd_model=shared.sd_model, info=info)
def await_backup_reload_ckpt_info(info):
    """Run ``backup_reload_ckpt_info(model_access_sem, info)`` on a worker thread and wait.

    An exception in the worker goes to ``threading.excepthook`` instead of
    being re-raised to the caller.
    """
    worker_args = (model_access_sem, info)
    worker = threading.Thread(target=backup_reload_ckpt_info, args=worker_args)
    worker.start()
    worker.join()
def post_reload_model_weights(sem):
    """Restore the model after an operation finishes, holding ``sem``.

    Performs the same steps as ``pre_reload_model_weights``: move an
    offloaded model back to ``backup_device``, then reload the checkpoint
    saved in ``backup_ckpt_info`` (if any, and if a model is loaded),
    clearing each piece of backup state once it has been restored.
    """
    global backup_sd_model, backup_device, backup_ckpt_info
    with sem:
        if backup_sd_model is not None and backup_device is not None:
            backup_sd_model.to(backup_device)
            backup_sd_model = None
            backup_device = None
        ckpt_restore_needed = shared.sd_model is not None and backup_ckpt_info is not None
        if ckpt_restore_needed:
            webui_reload_model_weights(sd_model=shared.sd_model, info=backup_ckpt_info)
            backup_ckpt_info = None
def async_post_reload_model_weights():
    """Start ``post_reload_model_weights`` on a background thread and return at once.

    The caller does not wait for the restore; ordering with other model
    operations relies on the worker taking ``model_access_sem``.
    """
    threading.Thread(
        target=post_reload_model_weights,
        args=(model_access_sem,),
    ).start()
def acquire_release_semaphore(sem):
    """Block until ``sem`` can be entered, then leave it immediately.

    Used as a barrier: returning means no other holder of ``sem`` is active
    at that moment.
    """
    with sem:
        return
def await_acquire_release_semaphore():
    """Wait, via a worker thread, until ``model_access_sem`` is free.

    Effectively waits for any in-flight offload/reload worker to finish
    before the caller proceeds.
    """
    worker = threading.Thread(target=acquire_release_semaphore, args=(model_access_sem,))
    worker.start()
    worker.join()
def clear_cache_decorator(func):
    """Wrap ``func`` so caches are cleared before it runs and after it finishes.

    The trailing ``clear_cache()`` runs in a ``finally`` block, so memory is
    also released when ``func`` raises; the exception still propagates and
    the return value is passed through unchanged.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        clear_cache()
        try:
            return func(*args, **kwargs)
        finally:
            # Previously skipped on exceptions, leaving caches uncleared.
            clear_cache()
    return wrapper
def clear_cache_yield_decorator(func):
    """Generator version of ``clear_cache_decorator``.

    Caches are cleared when iteration starts and again when the wrapped
    generator ends, whether it is exhausted, raises, or is closed early
    (e.g. the consumer stops iterating). The trailing clear is in ``finally``.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        clear_cache()
        try:
            yield from func(*args, **kwargs)
        finally:
            # Also runs on GeneratorExit from close()/GC, not just on exhaustion.
            clear_cache()
    return wrapper
def post_reload_decorator(func):
    """Wrap ``func`` so it runs after pending model work and restores afterwards.

    Before calling ``func``, wait until ``model_access_sem`` is free. Once
    ``func`` returns *or raises*, schedule ``async_post_reload_model_weights``
    so any offloaded model or swapped checkpoint is put back. Previously an
    exception skipped the restore and left the swap in place.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        await_acquire_release_semaphore()
        try:
            return func(*args, **kwargs)
        finally:
            async_post_reload_model_weights()
    return wrapper
def offload_reload_decorator(func):
    """Wrap ``func`` so the SD model is offloaded to CPU while it runs.

    Before calling ``func``, wait for ``await_pre_offload_model_weights`` to
    move the SD model off its device. Once ``func`` returns *or raises*,
    schedule ``async_post_reload_model_weights`` to move it back. Previously
    an exception left the model stranded on CPU.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        await_pre_offload_model_weights()
        try:
            return func(*args, **kwargs)
        finally:
            async_post_reload_model_weights()
    return wrapper
def offload_reload_yield_decorator(func):
    """Generator version of ``offload_reload_decorator``.

    The SD model is offloaded when iteration starts. A reload is scheduled
    when the wrapped generator ends, whether it is exhausted, raises, or is
    closed early by the consumer. The reload call is in ``finally``, so it is
    no longer skipped on errors or early ``close()``.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        await_pre_offload_model_weights()
        try:
            yield from func(*args, **kwargs)
        finally:
            async_post_reload_model_weights()
    return wrapper
class torch_default_load_cd(ContextDecorator):
    """Temporarily replace ``torch.load`` with webui's ``safe.unsafe_torch_load``.

    Usable as ``with torch_default_load_cd(): ...`` or as a decorator. Each
    ``__enter__`` pushes the current ``torch.load`` onto a stack and each
    ``__exit__`` pops it back. Nested entries of the same instance (e.g. a
    decorated function that recurses, since ``ContextDecorator`` reuses one
    instance) therefore restore the loader that was active before the
    outermost entry. Previously a single saved value was overwritten by the
    inner entry, which left ``torch.load`` permanently replaced.
    """

    def __init__(self):
        # Most recently saved loader; kept for backward compatibility and as
        # the fallback for an unmatched __exit__ (same as the old behavior).
        self.backup_load = safe.load
        # Loaders saved by currently active __enter__ calls (LIFO).
        self._saved_loads = []

    def __enter__(self):
        self.backup_load = torch.load
        self._saved_loads.append(self.backup_load)
        torch.load = safe.unsafe_torch_load
        return self

    def __exit__(self, *exc):
        if self._saved_loads:
            torch.load = self._saved_loads.pop()
        else:
            torch.load = self.backup_load
        # Never suppress exceptions raised inside the managed block.
        return False