Skip to content

Commit

Permalink
Merge pull request #103 from invoke-ai/update-ti-sd1-format
Browse files Browse the repository at this point in the history
Update SD1.5 TI embedding output format
  • Loading branch information
RyanJDick authored Apr 11, 2024
2 parents c017568 + 85cf4f6 commit 11160dc
Showing 1 changed file with 1 addition and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def _save_ti_embeddings(
text_encoder: CLIPTextModel,
placeholder_token_ids: list[int],
accelerator: Accelerator,
placeholder_token: str,
logger: logging.Logger,
checkpoint_tracker: CheckpointTracker,
):
Expand All @@ -59,7 +58,7 @@ def _save_ti_embeddings(
.get_input_embeddings()
.weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1]
)
learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
learned_embeds_dict = {"emb_params": learned_embeds.detach().cpu()}

save_state_dict(learned_embeds_dict, save_path)

Expand Down Expand Up @@ -366,7 +365,6 @@ def train(config: SdTextualInversionConfig): # noqa: C901
text_encoder=text_encoder,
placeholder_token_ids=placeholder_token_ids,
accelerator=accelerator,
placeholder_token=config.placeholder_token,
logger=logger,
checkpoint_tracker=step_checkpoint_tracker,
)
Expand Down Expand Up @@ -406,7 +404,6 @@ def train(config: SdTextualInversionConfig): # noqa: C901
text_encoder=text_encoder,
placeholder_token_ids=placeholder_token_ids,
accelerator=accelerator,
placeholder_token=config.placeholder_token,
logger=logger,
checkpoint_tracker=epoch_checkpoint_tracker,
)
Expand Down

0 comments on commit 11160dc

Please sign in to comment.