From ab73314b64f1012b36c4e2cc3db6c5d1d0105c40 Mon Sep 17 00:00:00 2001
From: Muinez
Date: Sun, 1 Dec 2024 00:44:06 +0300
Subject: [PATCH] remove prompts shuffling

---
 train_scripts/train_local.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/train_scripts/train_local.py b/train_scripts/train_local.py
index c03231a..c46d22f 100644
--- a/train_scripts/train_local.py
+++ b/train_scripts/train_local.py
@@ -387,16 +387,16 @@ def save_model(save_metric=True):
         lm_time_start = time.time()
 
         prompts = list(batch[2])
-        shuffled_prompts = []
-        for prompt in prompts:
-            tags = prompt.split(",") # Split the string into a list of tags
-            random.shuffle(tags) # Shuffle the tags
-            shuffled_prompts.append(",".join(tags)) # Join them back into a string
+        # shuffled_prompts = []
+        # for prompt in prompts:
+        #     tags = prompt.split(",") # Split the string into a list of tags
+        #     random.shuffle(tags) # Shuffle the tags
+        #     shuffled_prompts.append(",".join(tags)) # Join them back into a string
 
         if "T5" in config.text_encoder.text_encoder_name:
             with torch.no_grad():
                 txt_tokens = tokenizer(
-                    shuffled_prompts,
+                    prompts,
                     max_length=max_length,
                     padding="max_length",
                     truncation=True,
@@ -408,10 +408,10 @@
             with torch.no_grad():
                 if not config.text_encoder.chi_prompt:
                     max_length_all = config.text_encoder.model_max_length
-                    prompt = shuffled_prompts
+                    prompt = prompts
                 else:
                     chi_prompt = "\n".join(config.text_encoder.chi_prompt)
-                    prompt = [chi_prompt + i for i in shuffled_prompts]
+                    prompt = [chi_prompt + i for i in prompts]
                     num_chi_prompt_tokens = len(tokenizer.encode(chi_prompt))
                     max_length_all = (
                         num_chi_prompt_tokens + config.text_encoder.model_max_length - 2
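
Note (editorial, not part of the patch): commenting the shuffle loop out disables the tag-shuffling augmentation, but leaves dead code that has to be re-edited to turn it back on. A minimal sketch of an alternative that keeps the behavior switchable behind a config flag; the flag name `config.train.shuffle_prompt_tags` and the helper `maybe_shuffle_tags` are hypothetical, not from this repository:

    import random

    def maybe_shuffle_tags(prompts, enabled=False):
        """Optionally shuffle the comma-separated tags within each prompt.

        With enabled=False this is a no-op, matching what the patch
        achieves by commenting the shuffle loop out.
        """
        if not enabled:
            return prompts
        shuffled = []
        for prompt in prompts:
            tags = prompt.split(",")         # split into individual tags
            random.shuffle(tags)             # randomize tag order in place
            shuffled.append(",".join(tags))  # reassemble the prompt string
        return shuffled

    # Hypothetical call site in the training loop:
    # prompts = maybe_shuffle_tags(list(batch[2]), config.train.shuffle_prompt_tags)

With a flag, both hunks of this patch reduce to a one-line change at the call site, and the downstream tokenizer calls can keep referencing a single `prompts` variable in both the shuffled and unshuffled cases.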