Skip to content

Commit

Permalink
remove prompts shuffling
Browse files: browse the repository at this point in the history
  • Loading branch information
Muinez committed Nov 30, 2024
1 parent 6fa0af0 commit ab73314
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions train_scripts/train_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,16 +387,16 @@ def save_model(save_metric=True):

lm_time_start = time.time()
prompts = list(batch[2])
- shuffled_prompts = []
- for prompt in prompts:
- tags = prompt.split(",") # Split the string into a list of tags
- random.shuffle(tags) # Shuffle the tags
- shuffled_prompts.append(",".join(tags)) # Join them back into a string
+ # shuffled_prompts = []
+ # for prompt in prompts:
+ # tags = prompt.split(",") # Split the string into a list of tags
+ # random.shuffle(tags) # Shuffle the tags
+ # shuffled_prompts.append(",".join(tags)) # Join them back into a string

if "T5" in config.text_encoder.text_encoder_name:
with torch.no_grad():
txt_tokens = tokenizer(
- shuffled_prompts,
+ prompts,
max_length=max_length,
padding="max_length",
truncation=True,
Expand All @@ -408,10 +408,10 @@ def save_model(save_metric=True):
with torch.no_grad():
if not config.text_encoder.chi_prompt:
max_length_all = config.text_encoder.model_max_length
- prompt = shuffled_prompts
+ prompt = prompts
else:
chi_prompt = "\n".join(config.text_encoder.chi_prompt)
- prompt = [chi_prompt + i for i in shuffled_prompts]
+ prompt = [chi_prompt + i for i in prompts]
num_chi_prompt_tokens = len(tokenizer.encode(chi_prompt))
max_length_all = (
num_chi_prompt_tokens + config.text_encoder.model_max_length - 2
Expand Down

0 comments on commit ab73314

Please sign in to comment.