-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
79 lines (69 loc) · 2.35 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import hydra
import jax.random as random
import optax
import wandb
from configs import (
AnimeDatasetConfig,
DiffusionConfig,
MainConfig,
ModelConfig,
TrainerConfig,
)
from omegaconf import DictConfig, OmegaConf
from src.dataset import ImageDataset
from src.diffusion import scheduler
from src.model import UViT
from src.trainer import train
@hydra.main(config_path="configs", config_name="default", version_base="1.1")
def main(dict_config: DictConfig) -> None:
    """Entry point: build typed configs, data, model and optimizer, then train.

    Hydra loads ``configs/default`` (plus CLI overrides) and passes the raw
    config here; it is immediately validated into typed config objects.

    Args:
        dict_config: Raw hydra configuration.

    Raises:
        ValueError: If the dataset image size is not divisible by the model
            patch size (required by the patch/unpatch operations).
    """
    # Convert the untyped DictConfig into typed config objects up front so
    # that missing/misspelled keys fail here rather than deep in training.
    config = MainConfig(
        dataset=AnimeDatasetConfig(**dict_config.dataset),
        diffusion=DiffusionConfig(**dict_config.diffusion),
        model=ModelConfig(**dict_config.model),
        trainer=TrainerConfig(**dict_config.trainer),
        mode=dict_config.mode,
    )

    # Explicit exception instead of `assert`: asserts are stripped under
    # `python -O`, and this check guards a hard requirement of the model.
    if config.dataset.image_size % config.model.patch_size != 0:
        raise ValueError(
            "The image size should be divisible by the patch size (for patch and unpatch operations)."
        )

    dataset = ImageDataset.from_folder(
        folder_path=config.dataset.dir_path,
        image_size=config.dataset.image_size,
        preload=config.trainer.preload_data,
    )
    # 80/20 train/test split, seeded for reproducibility.
    train_dataset, test_dataset = dataset.split(
        split_ratio=0.80,
        key=random.key(config.dataset.seed),
    )

    # Diffusion noise schedule; its length defines the model's timestep range.
    schedule = scheduler(config.diffusion.steps)
    model = UViT(
        num_channels=config.dataset.n_channels,
        # One position per non-overlapping patch of the square image.
        num_positions=(config.dataset.image_size // config.model.patch_size) ** 2,
        num_timesteps=len(schedule),
        patch_size=config.model.patch_size,
        d_model=config.model.d_model,
        num_heads=config.model.num_heads,
        num_layers=config.model.num_layers,
        key=random.key(config.model.seed),
    )
    optimizer = optax.adamw(config.trainer.learning_rate)

    # `wandb.init` as a context manager guarantees the run is finished
    # (flushed and closed) even if training raises.
    with wandb.init(
        project="anime-diffusion",
        config=OmegaConf.to_container(dict_config),
        entity="pierrotlc",
        mode=config.mode,
    ) as run:
        train(
            model=model,
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            schedule=schedule,
            optimizer=optimizer,
            batch_size=config.trainer.batch_size,
            total_iters=config.trainer.total_iters,
            evaluate_steps=config.trainer.evaluate_steps,
            key=random.key(config.trainer.seed),
            logger=run,
        )
# Hydra parses the CLI and injects the resolved config when `main()` is called.
if __name__ == "__main__":
    main()