forked from smallcloudai/refact
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsetup.py
98 lines (86 loc) · 3.82 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
from copy import copy
from dataclasses import dataclass
from dataclasses import field
from typing import List, Mapping, Optional, Set

from setuptools import setup, find_packages
@dataclass
class PyPackage:
    """Build metadata for one top-level package shipped by this setup script.

    All three fields default to a fresh, per-instance empty list, so
    instances never share mutable state.
    """

    # Third-party pip requirement specifiers (e.g. "torch", "fastapi==0.100.0").
    requires: List[str] = field(default_factory=list)
    # Names of other in-repo packages this one depends on; each must itself be
    # a key of the package registry so it can be resolved transitively.
    requires_packages: List[str] = field(default_factory=list)
    # Glob patterns for non-Python files to ship via package_data.
    data: List[str] = field(default_factory=list)
# Registry of the in-repo top-level packages this setup script can ship,
# keyed by import name:
#   requires          -> pip requirement specifiers, merged into install_requires
#   requires_packages -> other keys of this dict, resolved transitively by
#                        find_required_packages()
#   data              -> glob patterns passed through as package_data
# Key order is significant: it determines the order of py_modules.
# NOTE(review): "dataclasses" is the PyPI backport of a module that has been in
# the standard library since Python 3.7; presumably it is only needed on 3.6 —
# consider an environment marker such as "dataclasses; python_version < '3.7'".
all_refact_packages = {
    "code_contrast": PyPackage(
        requires=["cdifflib", "termcolor", "numpy", "dataclasses"],
        requires_packages=["refact_encoding"]),
    "known_models_db": PyPackage(
        requires=["dataclasses", "dataclasses_json"],
        data=["refact_toolbox_db/htmls/*.html"]),
    "refact_encoding": PyPackage(
        requires=["tiktoken", "tokenizers", "sentencepiece", "termcolor"],
        data=["*.json"]),
    "refact_models": PyPackage(
        requires=["torch", "blobfile", "cloudpickle", "huggingface_hub",
                  "transformers", "dataclasses", "dataclasses_json"],
        requires_packages=["refact_encoding"]),
    "refact_scratchpads": PyPackage(
        requires=["termcolor", "torch"],
        requires_packages=["refact_encoding", "code_contrast", "refact_scratchpads_no_gpu"]),
    "refact_scratchpads_no_gpu": PyPackage(
        requires=["termcolor", "aiohttp", "tiktoken", "openai", "ujson", "setproctitle"]),
    # refact_data_pipeline and self_hosting_machinery list each other in
    # requires_packages, so the dependency graph contains a cycle; selecting
    # either one always pulls in both.
    "refact_data_pipeline": PyPackage(
        requires=["numpy", "tokenizers", "torch", "requests", "cloudpickle",
                  "tqdm", "dataclasses_json", "termcolor", 'more_itertools',
                  "cdifflib", "ujson", "zstandard", "scipy", "deepspeed==0.9.2",
                  "peft", "einops", "mpi4py", "matplotlib", "transformers",
                  "triton==2.0.0.post1", "torchinfo", "giturlparse", "jsonlines",
                  "binpacking", "filelock", "tables"],
        requires_packages=["refact_encoding", "code_contrast", "self_hosting_machinery",
                           "refact_models"],
        data=["git_command.exp"],
    ),
    "self_hosting_machinery": PyPackage(
        requires=["aiohttp", "cryptography", "fastapi==0.100.0", "giturlparse", "pydantic==1.10.10",
                  "starlette==0.27.0", "uvicorn", "uvloop", "python-multipart", "auto-gptq==0.4.2",
                  "torch", "transformers", "termcolor", "accelerate", "bitsandbytes"],
        requires_packages=["refact_scratchpads", "refact_scratchpads_no_gpu", "refact_models",
                           "known_models_db", "refact_data_pipeline"],
        data=["webgui/static/*", "webgui/static/js/*", "watchdog/watchdog.d/*"]),
}
def find_required_packages(
        packages: Set[str],
        registry: Optional[Mapping[str, "PyPackage"]] = None,
) -> Set[str]:
    """Return *packages* plus all of their transitive in-repo dependencies.

    Dependencies are read from each entry's ``requires_packages``. Cyclic
    dependencies are tolerated: every name is visited at most once.

    :param packages: package names to resolve. Any iterable of names is
        accepted (including a frozenset); it is never mutated.
    :param registry: mapping from package name to an object exposing a
        ``requires_packages`` list. Defaults to ``all_refact_packages``.
    :returns: a new set with every name reachable from *packages*.
    :raises ValueError: if a requested or depended-upon name is not a key of
        the registry.
    """
    if registry is None:
        registry = all_refact_packages
    resolved: Set[str] = set()
    pending = list(packages)
    # Worklist traversal: each package is expanded once, and recursion depth
    # no longer grows with the length of the dependency chain.
    while pending:
        name = pending.pop()
        if name in resolved:
            continue
        # Explicit raise instead of assert so the check survives `python -O`;
        # same exception type and message as the SETUP_PACKAGE validation.
        if name not in registry:
            raise ValueError(f"Package {name} not found in repo")
        resolved.add(name)
        pending.extend(registry[name].requires_packages)
    return resolved
# Optionally restrict the distribution to a single package plus its transitive
# in-repo dependencies, selected via the SETUP_PACKAGE environment variable.
# An unset or empty SETUP_PACKAGE means "ship every package in the registry".
setup_package = os.environ.get("SETUP_PACKAGE") or None
if setup_package is not None:
    if setup_package not in all_refact_packages:
        raise ValueError(f"Package {setup_package} not found in repo")
    # Resolve the dependency closure once, not once per registry entry.
    required_names = find_required_packages({setup_package})
    # Iterate the registry (not required_names) so the selected packages keep
    # the registry's declaration order.
    setup_packages = {
        name: py_package
        for name, py_package in all_refact_packages.items()
        if name in required_names
    }
else:
    setup_packages = all_refact_packages

setup(
    name="refact-self-hosting",
    version="0.9.0",
    # NOTE(review): the same names are passed as py_modules and matched as
    # packages below; if they are package directories rather than top-level
    # .py files, setuptools may warn about missing modules — confirm intent.
    py_modules=list(setup_packages.keys()),
    package_data={
        name: py_package.data
        for name, py_package in setup_packages.items()
        if py_package.data
    },
    packages=find_packages(include=[f"{name}*" for name in setup_packages]),
    # Sorted so the generated metadata is reproducible: iteration order of a
    # set of strings varies between processes due to hash randomization.
    install_requires=sorted({
        required_package
        for py_package in setup_packages.values()
        for required_package in py_package.requires
    }),
)