Hi there - first of all thanks for sharing your work!
I was trying to follow the README.md but am running into the following error:
```
[LightningDiT-Sampling 2025-01-06 18:31:29]: Using ckpt: LightningDiT/lightningdit-xl-imagenet256-800ep.pt
[LightningDiT-Sampling 2025-01-06 18:31:37]: ckpt_path= LightningDiT/lightningdit-xl-imagenet256-800ep.pt
[LightningDiT-Sampling 2025-01-06 18:31:37]: cfg_scale= 9.0
[LightningDiT-Sampling 2025-01-06 18:31:37]: cfg_interval_start= 0
[LightningDiT-Sampling 2025-01-06 18:31:37]: timestep_shift= 0
[LightningDiT-Sampling 2025-01-06 18:31:37]: Starting rank=0, seed=0, world_size=1.
[LightningDiT-Sampling 2025-01-06 18:31:48]: Loaded VAE model
[LightningDiT-Sampling 2025-01-06 18:31:48]: Using cfg: True
[LightningDiT-Sampling 2025-01-06 18:31:48]: Total number of images that will be sampled: 50000
[LightningDiT-Sampling 2025-01-06 18:31:48]: Using latent normalization
0it [00:00, ?it/s]
Traceback (most recent call last):
  File "/LightningDiT/inference.py", line 278, in <module>
    sample_folder_dir = do_sample(train_config, accelerator, ckpt_path=ckpt_dir, model=model, demo_sample_mode=args.demo)
  File "/LightningDiT/inference.py", line 157, in do_sample
    dataset = ImgLatentDataset(
  File "/LightningDiT/datasets/img_latent_dataset.py", line 26, in __init__
    self._latent_mean, self._latent_std = self.get_latent_stats()
  File "/LightningDiT/datasets/img_latent_dataset.py", line 46, in get_latent_stats
    latent_stats = self.compute_latent_stats()
  File "/LightningDiT/datasets/img_latent_dataset.py", line 63, in compute_latent_stats
    latents = torch.cat(latents, dim=0)
RuntimeError: torch.cat(): expected a non-empty list of Tensors
```
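The final frame shows `torch.cat` receiving an empty list, which means the dataset's file scan under `data_path` found no latent files at all. A minimal, stdlib-only sketch of that failure mode (the `find_latent_files` helper and the `.safetensors` naming are illustrative assumptions, not code from the repo):

```python
import glob
import os
import tempfile

def find_latent_files(data_path):
    # Hypothetical stand-in for the dataset's file scan:
    # glob the data directory for extracted latent files.
    return sorted(glob.glob(os.path.join(data_path, "*.safetensors")))

with tempfile.TemporaryDirectory() as d:
    open(os.path.join(d, "latents_stats.pt"), "w").close()
    open(os.path.join(d, "latents_rank00.safetensors"), "w").close()  # assumed naming

    # Pointing data_path at the stats *file* matches nothing -> empty list,
    # which is what makes torch.cat() raise downstream.
    assert find_latent_files(os.path.join(d, "latents_stats.pt")) == []

    # Pointing data_path at the *directory* finds the latent shards.
    assert len(find_latent_files(d)) == 1
```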
I'm using the 800ep_cfg config, which looks like this:

```yaml
# we recommend to read config_details.yaml first.
ckpt_path: '/LightningDiT/lightningdit-xl-imagenet256-800ep.pt' # <---- download our pre-trained lightningdit or your own checkpoint

data:
  data_path: '/LightningDiT/latents_stats.pt' # <---- path to your data. it is generated by extract_features.py.
                                              # if you just want to inference, download our latent_stats.pt and give its path here is ok.
  fid_reference_file: 'path/to/your/VIRTUAL_imagenet256_labeled.npz' # <---- path to your fid_reference_file.npz. download it from ADM
  # recommend to use default settings
  image_size: 256
  num_classes: 1000
  num_workers: 8
  latent_norm: true
  latent_multiplier: 1.0

# recommend to use default settings (we will release codes with SD-VAE later)
vae:
  model_name: 'vavae_f16d32'
  downsample_ratio: 16

# recommend to use default settings
model:
  model_type: LightningDiT-XL/1
  use_qknorm: false
  use_swiglu: true
  use_rope: true
  use_rmsnorm: true
  wo_shift: false
  in_chans: 32

# recommend to use default settings
train:
  max_steps: 80000
  global_batch_size: 1024
  global_seed: 0
  output_dir: 'output'
  exp_name: 'lightningdit_xl_vavae_f16d32_V2' # <---- experiment name, set as you like
  ckpt: null
  log_every: 100
  ckpt_every: 20000

optimizer:
  lr: 0.0002
  beta2: 0.95

# recommend to use default settings
transport:
  path_type: Linear
  prediction: velocity
  loss_weight: null
  sample_eps: null
  train_eps: null
  use_cosine_loss: true
  use_lognorm: true

# recommend to use default settings
sample:
  mode: ODE
  sampling_method: euler
  atol: 0.000001
  rtol: 0.001
  reverse: false
  likelihood: false
  num_sampling_steps: 250
  cfg_scale: 6.7 # <---- cfg scale, for 800 epoch performance with FID=1.35 cfg_scale=6.7
                 # for 64 epoch performance with FID=2.11 cfg_scale=10.0
  # you may find we use a large cfg_scale, this is because of 2 reasons:
  # we find a high-dimensional latent space requires a larger cfg_scale to get good performance than f8d4 SD-VAE
  # we enable cfg interval, which reduces the negative effects of large cfg on high-noise parts. This means larger cfg can be utilized
  # recommend to use default settings
  per_proc_batch_size: 4
  fid_num: 50000
  cfg_interval_start: 0.125
  timestep_shift: 0.3
```
I also updated the tokenizer config at tokenizer/configs/vavae_f16d32.yaml - probably worth including that in the readme.
Am I missing anything obvious here?
Try changing data_path from /LightningDiT/latents_stats.pt to /LightningDiT/. data_path indicates the training dataset directory, i.e. the folder containing the latents extracted by extract_features.py, not the stats file itself.
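In other words, only the data.data_path entry in the config above needs to change; a sketch of the corrected data section (the exact directory contents are an assumption based on the traceback, not verified against the repo):

```yaml
data:
  # point at the directory holding the extracted latents, not at the stats file
  data_path: '/LightningDiT/'
```

With that change, ImgLatentDataset should find the latent files and load (or compute) the latent statistics instead of failing on an empty list.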