forked from KoboldAI/KoboldAI-Client
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
305 lines (285 loc) · 13.9 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
from threading import Timer
import re
import shutil
import json
import subprocess
import tempfile
import requests
import requests.adapters
import time
from tqdm.auto import tqdm
import os
import itertools
from typing import Optional
# Shared settings object. It is None at import time; functions in this module
# read attributes from it (encodenewlines/decodenewlines read `newlinemode`,
# aria2_hook reads `aria2_port`), so presumably the main program assigns it
# before those are called -- TODO confirm against the caller.
# NOTE(review): this name shadows the builtin `vars()` inside this module.
vars = None
# NOTE(review): the names below are not read by any function visible in this
# file; presumably other modules use them to track model-loading progress --
# verify against callers before relying on that.
num_shards: Optional[int] = None
current_shard = 0
from_pretrained_model_name = ""
from_pretrained_index_filename: Optional[str] = None
from_pretrained_kwargs = {}
bar = None
# Presumably the default order in which samplers are applied -- TODO confirm.
default_sampler_order = [0, 1, 2, 3, 4, 5]
#==================================================================#
# Decorator to prevent a function's actions from being run until
# at least x seconds have passed without the function being called
#==================================================================#
def debounce(wait):
    """Build a decorator that postpones calls until they stop arriving.

    Each call to the decorated function cancels the timer armed by the
    previous call (if there was one) and arms a fresh ``threading.Timer``.
    The wrapped function therefore runs only once ``wait`` seconds have
    passed without another call, on the timer's thread, with the arguments
    of the most recent call.

    Args:
        wait: Quiet period in seconds, passed to ``threading.Timer``.

    Returns:
        A decorator. The decorated function returns ``None`` immediately;
        the wrapped function's own return value is discarded.
    """
    # Local import so this block does not depend on the file's import list.
    import functools

    def decorator(fun):
        # Preserve the wrapped function's name, docstring and __wrapped__.
        @functools.wraps(fun)
        def debounced(*args, **kwargs):
            def call_it():
                fun(*args, **kwargs)

            # The very first call has no pending timer attribute yet.
            try:
                debounced.t.cancel()
            except AttributeError:
                pass
            debounced.t = Timer(wait, call_it)
            debounced.t.start()

        return debounced

    return decorator
#==================================================================#
# Replace fancy quotes and apostrophes with standard ones
#==================================================================#
def fixquotes(txt):
    """Return ``txt`` with typographic quotes and backticks made plain.

    Left and right curly double quotes become ``"``; the right curly single
    quote and the backtick become ``'``. Every other character is kept.
    """
    plain_for = {
        "“": '"',
        "”": '"',
        "’": "'",
        "`": "'",
    }
    # All four substitutions are one character to one character, so a single
    # translate pass is equivalent to chaining the replacements.
    return txt.translate(str.maketrans(plain_for))
#==================================================================#
# Trim text back to the end of its last complete sentence
#==================================================================#
def trimincompletesentence(txt):
    """Cut ``txt`` off after its last sentence-ending punctuation mark.

    The last ``.``, ``!`` or ``?`` is located; if it is immediately followed
    by a double quote, that closing quote is kept as well. Text without any
    such punctuation is returned unchanged.

    Args:
        txt: The string to trim.

    Returns:
        The trimmed string, or ``txt`` itself when there is nothing to trim.
    """
    # Find last instance of punctuation (Borrowed from Clover-Edition by cloveranon)
    lastpunc = max(txt.rfind("."), txt.rfind("!"), txt.rfind("?"))
    # No punctuation at all: bail out before the quote check, otherwise an
    # index of -1 would make the check look at txt[0] and a text that merely
    # starts with a quote would be truncated to that single character.
    if lastpunc < 0:
        return txt
    # Is this the end of a quote? (slicing is safe at the end of the string)
    if txt[lastpunc + 1:lastpunc + 2] == '"':
        lastpunc += 1
    return txt[:lastpunc + 1]
#==================================================================#
# Replace each pair of consecutive newlines with a single newline
#==================================================================#
def replaceblanklines(txt):
    """Return ``txt`` with every ``"\\n\\n"`` pair replaced by one newline.

    This is a single left-to-right pass, so a run of three newlines comes
    out as two rather than one.
    """
    pieces = txt.split("\n\n")
    return "\n".join(pieces)
#==================================================================#
# Remove special characters from text
#==================================================================#
def removespecialchars(txt, vars=None):
    """Delete a fixed set of special characters from ``txt``.

    Args:
        txt: The string to clean.
        vars: Optional settings object. When it is given and its
            ``actionmode`` is not 0, ``<`` and ``>`` are left in place;
            otherwise they are removed along with the other characters.

    Returns:
        The cleaned string.
    """
    keep_angle_brackets = vars is not None and vars.actionmode != 0
    if keep_angle_brackets:
        pattern = r"[#/@%{}+=~|\^]"
    else:
        pattern = r"[#/@%<>{}+=~|\^]"
    return re.sub(pattern, "", txt)
#==================================================================#
# If the next action follows a sentence closure, add a space
#==================================================================#
def addsentencespacing(txt, vars):
    """Prefix ``txt`` with a space when the preceding text ends a clause.

    The preceding text is the last entry of ``vars.actions`` (looked up via
    ``vars.actions.get_last_key()``), or ``vars.prompt`` when there are no
    actions. A space is prepended when that text ends in one of
    ``. ! ? , ; :``.

    Args:
        txt: The new submission.
        vars: Settings object providing ``actions`` and ``prompt``.

    Returns:
        ``txt``, possibly with one leading space added. ``txt`` is returned
        untouched when it is empty, already starts with whitespace, or the
        last action is blank.
    """
    # Nothing to do for an empty submission or one with leading whitespace.
    if not txt or txt != txt.lstrip():
        return txt

    if len(vars.actions) > 0:
        previous = vars.actions[vars.actions.get_last_key()]
        if len(previous) == 0:
            # Last action is blank, this should never happen, but
            # since it did let's bail out.
            return txt
    else:
        previous = vars.prompt

    lastchar = previous[-1] if len(previous) else ""
    # A tuple (not a string) so that an empty lastchar never matches.
    if lastchar in (".", "!", "?", ",", ";", ":"):
        return " " + txt
    return txt
def singlelineprocessing(txt, vars):
    """Apply ``vars.regex_sl`` to ``txt`` and terminate it with a newline.

    Matches of ``vars.regex_sl`` are removed from ``txt``. A newline is then
    appended unless the preceding text already ends with one. The preceding
    text is the last entry of ``vars.actions``, or ``vars.prompt`` when there
    are no actions.

    Args:
        txt: The text to process.
        vars: Settings object providing ``regex_sl``, ``actions`` and
            ``prompt``.

    Returns:
        The processed text. If the last action is blank, the text is
        returned right after the regex step with no newline added.
    """
    stripped = vars.regex_sl.sub('', txt)

    if len(vars.actions) > 0:
        previous = vars.actions[vars.actions.get_last_key()]
        if len(previous) == 0:
            # Last action is blank, this should never happen, but
            # since it did let's bail out.
            return stripped
    else:
        previous = vars.prompt

    lastchar = previous[-1] if len(previous) else ""
    if lastchar == "\n":
        return stripped
    return stripped + "\n"
#==================================================================#
# Cleans string for use in file name
#==================================================================#
def cleanfilename(filename):
    """Return ``filename`` without path separators or trailing whitespace.

    Forward slashes and backslashes are dropped wherever they occur, then
    whitespace at the end of the result is stripped.
    """
    without_separators = filename.replace("/", "").replace("\\", "")
    return without_separators.rstrip()
#==================================================================#
# Newline substitution for fairseq models
#==================================================================#
def encodenewlines(txt):
    """Encode newlines in ``txt`` according to the module-level settings.

    When ``vars.newlinemode`` is ``"s"`` every newline is replaced by the
    literal ``</s>``; in any other mode ``txt`` is returned unchanged.
    """
    uses_eos_newlines = vars.newlinemode == "s"
    return txt.replace('\n', "</s>") if uses_eos_newlines else txt
def decodenewlines(txt):
    """Decode ``</s>`` markers in ``txt`` according to the module settings.

    With ``vars.newlinemode`` equal to ``"s"`` each ``</s>`` becomes a
    newline; with ``"ns"`` each ``</s>`` is removed; in any other mode
    ``txt`` is returned unchanged.
    """
    mode = vars.newlinemode
    if mode == "s":
        replacement = '\n'
    elif mode == "ns":
        replacement = ''
    else:
        return txt
    return txt.replace("</s>", replacement)
#==================================================================#
# Returns number of layers given an HF model config
#==================================================================#
def num_layers(config):
    """Return the layer count stored on a model config object.

    The attributes ``num_layers`` and ``n_layer`` are tried in that order;
    if the config has neither, ``num_hidden_layers`` is read (and raises
    ``AttributeError`` if it is missing too).
    """
    for attribute in ("num_layers", "n_layer"):
        if hasattr(config, attribute):
            return getattr(config, attribute)
    return config.num_hidden_layers
#==================================================================#
# Downloads huggingface checkpoints using aria2c if possible
#==================================================================#
def aria2_hook(pretrained_model_name_or_path: str, force_download=False, cache_dir=None, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
    """Pre-download a model's PyTorch weights into the transformers cache with aria2c.

    The function returns without doing anything when aria2c is not on PATH,
    when ``local_files_only`` is set, when the model name is a local
    path or a remote URL, when ``proxies`` is set, when neither a
    ``pytorch_model.bin`` nor a sharded index can be found, or when every
    weight file is already cached (and ``force_download`` is false).

    Otherwise it launches aria2c with its RPC interface enabled on
    ``vars.aria2_port`` (``vars`` is the module-level settings object), polls
    it to drive a progress bar, and finally renames the downloaded files and
    writes the ``<name>.json`` metadata files next to them in the cache
    directory.

    Args:
        pretrained_model_name_or_path: Model identifier to download.
        force_download: Re-download files even if they are cached, deleting
            the existing cache entries first.
        cache_dir: Cache directory; defaults to
            ``transformers.TRANSFORMERS_CACHE``.
        proxies: If truthy, aria2 mode is disabled (a warning is printed).
        resume_download: Forwarded to ``cached_path`` for the shard index.
        local_files_only: If truthy, nothing is downloaded.
        use_auth_token: A token string, or a truthy value to use the token
            stored by ``huggingface_hub``.
        user_agent: Forwarded to ``http_user_agent``.
        revision: Forwarded to ``hf_bucket_url``.
        mirror: Forwarded to ``hf_bucket_url``.
        **kwargs: Accepted and ignored.

    Returns:
        None.

    Raises:
        EnvironmentError: ``use_auth_token`` is truthy but not a string and
            no stored token was found.
        OSError: aria2c exited with a non-zero code before the download was
            seen to finish.
    """
    import transformers
    import transformers.modeling_utils
    from huggingface_hub import HfFolder

    if shutil.which("aria2c") is None:  # Don't do anything if aria2 is not installed
        return
    if local_files_only:  # If local_files_only is true, we obviously don't need to download anything
        return
    if (
        os.path.isdir(pretrained_model_name_or_path)
        or os.path.isfile(pretrained_model_name_or_path)
        or os.path.isfile(pretrained_model_name_or_path + ".index")
        or transformers.modeling_utils.is_remote_url(pretrained_model_name_or_path)
    ):
        return
    if proxies:
        print("WARNING: KoboldAI does not support using aria2 to download models from huggingface.co through a proxy. Disabling aria2 download mode.")
        return

    # Resolve the actual token once; use_auth_token itself may just be True.
    token = None
    if use_auth_token:
        if isinstance(use_auth_token, str):
            token = use_auth_token
        else:
            token = HfFolder.get_token()
            if token is None:
                raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")

    _cache_dir = str(cache_dir) if cache_dir is not None else transformers.TRANSFORMERS_CACHE
    sharded = False
    request_headers = {"user-agent": transformers.file_utils.http_user_agent(user_agent)}
    if use_auth_token:
        # Use the resolved token: formatting use_auth_token directly would
        # send "Bearer True" when the caller passed use_auth_token=True.
        request_headers["authorization"] = f"Bearer {token}"

    def is_cached(url):
        """Return True if ``url`` is already present in the local cache."""
        try:
            transformers.file_utils.get_from_cache(url, cache_dir=cache_dir, local_files_only=True)
        except (FileNotFoundError, transformers.file_utils.EntryNotFoundError):
            return False
        return True

    while True:  # Try to get the huggingface.co URL of the model's pytorch_model.bin or pytorch_model.bin.index.json file
        try:
            filename = transformers.modeling_utils.WEIGHTS_INDEX_NAME if sharded else transformers.modeling_utils.WEIGHTS_NAME
        except AttributeError:
            # This transformers version has no WEIGHTS_INDEX_NAME.
            return
        url = transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, filename, revision=revision, mirror=mirror)
        if is_cached(url) or requests.head(url, allow_redirects=True, proxies=proxies, headers=request_headers):
            break
        if sharded:
            return
        sharded = True

    if not sharded:  # If the model has a pytorch_model.bin file, that's the only file to download
        filenames = [transformers.modeling_utils.WEIGHTS_NAME]
    else:  # Otherwise download the pytorch_model.bin.index.json and then let aria2 download all the pytorch_model-#####-of-#####.bin files mentioned inside it
        map_filename = transformers.file_utils.cached_path(url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, use_auth_token=use_auth_token, user_agent=user_agent)
        with open(map_filename) as f:
            map_data = json.load(f)
        filenames = set(map_data["weight_map"].values())

    urls = [transformers.file_utils.hf_bucket_url(pretrained_model_name_or_path, n, revision=revision, mirror=mirror) for n in filenames]
    if not force_download:
        urls = [u for u in urls if not is_cached(u)]
        if not urls:
            return

    # First pass: HEAD without following redirects to read the ETag.
    etags = []
    for u in urls:
        h = requests.head(u, headers=request_headers, allow_redirects=False, proxies=proxies, timeout=10).headers
        etags.append(h.get("X-Linked-Etag") or h.get("ETag"))
    # Second pass: HEAD following redirects to read each Content-Length.
    # Kept under its own name so the request headers dict is not rebound.
    response_headers = [requests.head(u, headers=request_headers, allow_redirects=True, proxies=proxies, timeout=10).headers for u in urls]
    filenames = [transformers.file_utils.url_to_filename(u, t) for u, t in zip(urls, etags)]

    # Clear out leftovers from earlier runs (and, when forcing a download,
    # the existing cache entries as well).
    for n in filenames:
        stale_names = ["kai-tempfile." + n + ".aria2", "kai-tempfile." + n]
        if force_download:
            stale_names += [n + ".json", n]
        for stale_name in stale_names:
            stale_path = os.path.join(_cache_dir, stale_name)
            if os.path.exists(stale_path):
                os.remove(stale_path)

    total_length = sum(int(h["Content-Length"]) for h in response_headers)
    lengths = {}
    aria2_config = "\n".join(f"{u}\n out=kai-tempfile.{n}" for u, n in zip(urls, filenames)).encode()
    s = requests.Session()
    s.mount("http://", requests.adapters.HTTPAdapter(max_retries=requests.adapters.Retry(total=120, backoff_factor=1)))
    bar = None
    done = False
    secret = os.urandom(17).hex()
    p = None
    config_path = None
    try:
        with tempfile.NamedTemporaryFile("w+b", delete=False) as f:
            # Remember the path immediately so cleanup works even if a later
            # step in this block raises.
            config_path = f.name
            f.write(aria2_config)
            f.flush()
            command = [
                "aria2c", "-x", "10", "-s", "10", "-j", "10",
                "--enable-rpc=true", f"--rpc-secret={secret}",
                "--rpc-listen-port", str(vars.aria2_port),
                "--disable-ipv6", "--file-allocation=trunc",
                "--allow-overwrite", "--auto-file-renaming=false",
                "-d", _cache_dir, "-i", f.name,
                "-U", transformers.file_utils.http_user_agent(user_agent),
            ]
            if not force_download:
                command.append("-c")
            if use_auth_token:
                # No shell is involved, so the value must not be wrapped in
                # quotes; they would become part of the header itself.
                # NOTE(review): the token is visible in the process list.
                command.append(f"--header=Authorization: Bearer {token}")
            p = subprocess.Popen(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            while p.poll() is None:
                r = s.post(f"http://localhost:{vars.aria2_port}/jsonrpc", json={"jsonrpc": "2.0", "id": "kai", "method": "aria2.tellActive", "params": [f"token:{secret}"]}).json()["result"]
                if not r:
                    # No active downloads reported: treat as finished.
                    s.close()
                    if bar is not None:
                        bar.n = bar.total
                        bar.close()
                    p.terminate()
                    done = True
                    break
                if bar is None:
                    bar = tqdm(total=total_length, desc="[aria2] Downloading model", unit="B", unit_scale=True, unit_divisor=1000)
                visited = set()
                for x in r:
                    download_path = x["files"][0]["path"]
                    lengths[download_path] = (int(x["completedLength"]), int(x["totalLength"]))
                    visited.add(download_path)
                # Files no longer listed as active are counted as complete.
                for k, v in lengths.items():
                    if k not in visited:
                        lengths[k] = (v[1], v[1])
                bar.n = sum(v[0] for v in lengths.values())
                bar.update()
                time.sleep(0.1)
    except Exception:
        # p is still None if the failure happened before aria2c was started.
        if p is not None:
            p.terminate()
        raise
    finally:
        # Release the HTTP session and progress bar on every path; closing
        # them a second time after the normal finish is harmless.
        s.close()
        if bar is not None:
            bar.close()
        if config_path is not None:
            try:
                os.remove(config_path)
            except OSError:
                pass
    code = p.wait()
    if not done and code:
        raise OSError(f"aria2 exited with exit code {code}")
    # Move the finished downloads into place and write the metadata files.
    for u, t, n in zip(urls, etags, filenames):
        os.rename(os.path.join(_cache_dir, "kai-tempfile." + n), os.path.join(_cache_dir, n))
        with open(os.path.join(_cache_dir, n + ".json"), "w") as f:
            json.dump({"url": u, "etag": t}, f)
#==================================================================#
# Given the path to a pytorch_model.bin.index.json, returns how many
# shards there are in the model
#==================================================================#
def get_num_shards(filename):
    """Count the distinct shard files named in a checkpoint index.

    Args:
        filename: Path to a JSON index file containing a ``"weight_map"``
            object.

    Returns:
        The number of distinct values in ``"weight_map"``.
    """
    with open(filename) as index_file:
        weight_map = json.load(index_file)["weight_map"]
    distinct_shards = {shard for shard in weight_map.values()}
    return len(distinct_shards)
#==================================================================#
# Given the name/path of a sharded model and the path to a
# pytorch_model.bin.index.json, returns a list of weight names in the
# sharded model. Requires lazy loader to be enabled to work properly
#==================================================================#
def get_sharded_checkpoint_num_tensors(pretrained_model_name_or_path, filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, use_auth_token=None, user_agent=None, revision=None, mirror=None, **kwargs):
    """Return the keys of every shard of a sharded checkpoint as one list.

    The shard file paths are obtained from
    ``transformers.modeling_utils.get_checkpoint_shard_files``; each shard is
    then loaded with ``torch.load`` and its keys are concatenated in shard
    order. Despite the function's name, the result is a list of keys, not a
    count.

    Args:
        pretrained_model_name_or_path: Forwarded to
            ``get_checkpoint_shard_files``.
        filename: Path of the shard index file, forwarded likewise.
        cache_dir, force_download, proxies, resume_download,
        local_files_only, use_auth_token, user_agent, revision, mirror:
            Forwarded unchanged to ``get_checkpoint_shard_files``.
        **kwargs: Accepted and ignored.

    Returns:
        A list containing the keys of all shards.
    """
    import transformers.modeling_utils
    import torch

    shard_paths, _shard_metadata = transformers.modeling_utils.get_checkpoint_shard_files(
        pretrained_model_name_or_path,
        filename,
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        resume_download=resume_download,
        local_files_only=local_files_only,
        use_auth_token=use_auth_token,
        user_agent=user_agent,
        revision=revision,
        mirror=mirror,
    )
    # NOTE(review): torch.load unpickles the shard files; only use this on
    # checkpoints from a trusted source.
    keys_per_shard = [torch.load(p, map_location="cpu").keys() for p in shard_paths]
    return list(itertools.chain.from_iterable(keys_per_shard))