[Update] Resume checkpoints during phase 2
yeungchenwa committed Mar 14, 2024
1 parent 71345a0 commit 6f670a0
Showing 6 changed files with 9 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,4 +1,5 @@
# Initially taken from GitHub's Python gitignore file
phase_1_ckpt/
outputs/
run_sh/
ckpt/
2 changes: 2 additions & 0 deletions README.md
@@ -136,10 +136,12 @@ sh train_phase_1.sh
- `drop_prob`: The classifier-free guidance training probability.

### Training - Phase 2
After phase 1 training, put the trained checkpoint files (`unet.pth`, `content_encoder.pth`, and `style_encoder.pth`) in the directory `phase_1_ckpt`. Phase 2 loads these weights and resumes training from them.
```bash
sh train_phase_2.sh
```
- `phase_2`: Flag that enables phase 2 training.
- `phase_1_ckpt_dir`: The directory containing the model checkpoints saved after phase 1 training.
- `scr_ckpt_path`: The checkpoint path of the pre-trained SCR module. You can download it from the 🔥Model Zoo section above.
- `sc_coefficient`: The coefficient of style contrastive loss for supervision.
- `num_neg`: The number of negative samples; defaults to `16`.
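
Before launching phase 2, it can help to confirm that the three phase 1 checkpoint files are actually in place. The sketch below is not part of the repository: the helper name `missing_phase_1_checkpoints` is hypothetical, and the file names are simply those listed in the README text above.

```python
from pathlib import Path

# File names as listed in the README text above.
PHASE_1_FILES = ("unet.pth", "content_encoder.pth", "style_encoder.pth")


def missing_phase_1_checkpoints(ckpt_dir):
    """Return the expected checkpoint file names that are absent from ckpt_dir.

    Hypothetical helper for illustration; not part of FontDiffuser.
    """
    root = Path(ckpt_dir)
    return [name for name in PHASE_1_FILES if not (root / name).is_file()]


if __name__ == "__main__":
    # "phase_1_ckpt" matches the directory name used in the README.
    missing = missing_phase_1_checkpoints("phase_1_ckpt")
    if missing:
        print("Missing phase 1 checkpoints:", ", ".join(missing))
    else:
        print("All phase 1 checkpoints found.")
```

An empty list means all three files are present; a directory that does not exist is reported as missing all three.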
17 changes: 0 additions & 17 deletions configs/collate_fn.py

This file was deleted.

1 change: 1 addition & 0 deletions configs/fontdiffuser.py
@@ -33,6 +33,7 @@ def get_parser():

    # Training
    parser.add_argument("--phase_2", action="store_true", help="Training in phase 2 using SCR module.")
    parser.add_argument("--phase_1_ckpt_dir", type=str, default=None, help="The trained ckpt directory during phase 1.")
    ## SCR
    parser.add_argument("--temperature", type=float, default=0.07)
    parser.add_argument("--mode", type=str, default="refinement")
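
Because `--phase_1_ckpt_dir` defaults to `None`, passing `--phase_2` on its own would make a path such as `f"{args.phase_1_ckpt_dir}/unet.pth"` resolve to `None/unet.pth`. A minimal sketch of an early check follows, assuming a stand-alone parser that mirrors only these two flags; `build_demo_parser` and `parse_and_validate` are hypothetical names, not functions from the repository.

```python
import argparse


def build_demo_parser():
    """Stand-alone parser mirroring only the two flags shown above (illustrative)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--phase_2", action="store_true",
                        help="Training in phase 2 using SCR module.")
    parser.add_argument("--phase_1_ckpt_dir", type=str, default=None,
                        help="The trained ckpt directory during phase 1.")
    return parser


def parse_and_validate(argv=None):
    """Parse argv and reject --phase_2 given without --phase_1_ckpt_dir."""
    parser = build_demo_parser()
    args = parser.parse_args(argv)
    if args.phase_2 and args.phase_1_ckpt_dir is None:
        # parser.error prints a usage message and exits with status 2.
        parser.error("--phase_2 requires --phase_1_ckpt_dir")
    return args
```

Failing at argument-parsing time gives a clearer message than a file-not-found error raised later while loading weights.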
1 change: 1 addition & 0 deletions scripts/train_phase_2.sh
@@ -5,6 +5,7 @@ accelerate launch train.py \
--output_dir="outputs/FontDiffuser" \
--report_to="tensorboard" \
--phase_2 \
--phase_1_ckpt_dir="phase_1_ckpt" \
--scr_ckpt_path="ckpt/scr_210000.pth" \
--sc_coefficient=0.01 \
--num_neg=16 \
4 changes: 4 additions & 0 deletions train.py
@@ -74,6 +74,10 @@ def main():
    style_encoder = build_style_encoder(args=args)
    content_encoder = build_content_encoder(args=args)
    noise_scheduler = build_ddpm_scheduler(args)
    if args.phase_2:
        unet.load_state_dict(torch.load(f"{args.phase_1_ckpt_dir}/unet.pth"))
        style_encoder.load_state_dict(torch.load(f"{args.phase_1_ckpt_dir}/style_encoder.pth"))
        content_encoder.load_state_dict(torch.load(f"{args.phase_1_ckpt_dir}/content_encoder.pth"))

    model = FontDiffuserModel(
        unet=unet,