fix bugs for embedder finetune
- m3 embedder: the fix_encoder parameter was previously unused; it now freezes everything except the colbert_linear and sparse_linear heads
- decoder-only embedder: trust_remote_code is now passed through to the AutoConfig, AutoModel, and AutoTokenizer calls
hanhainebula committed Jan 13, 2025
1 parent 808b6c8 commit 78fb4df
Showing 5 changed files with 30 additions and 10 deletions.
16 changes: 11 additions & 5 deletions FlagEmbedding/finetune/embedder/decoder_only/base/load_model.py
@@ -51,13 +51,15 @@ def get_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir: str, re
         config = AutoConfig.from_pretrained(
             model_args.config_name,
             token=model_args.token,
-            cache_dir=model_args.cache_dir
+            cache_dir=model_args.cache_dir,
+            trust_remote_code=model_args.trust_remote_code,
         )
     elif model_args.model_name_or_path:
         config = AutoConfig.from_pretrained(
             model_args.model_name_or_path,
             token=model_args.token,
-            cache_dir=model_args.cache_dir
+            cache_dir=model_args.cache_dir,
+            trust_remote_code=model_args.trust_remote_code,
         )
     else:
         raise ValueError(
@@ -74,6 +76,7 @@ def get_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir: str, re
             cache_dir=model_args.cache_dir,
             from_tf=bool(".ckpt" in model_args.model_name_or_path),
             config=config,
+            trust_remote_code=model_args.trust_remote_code,
         )
     else:
         logger.info("Training new model from scratch")
@@ -129,13 +132,15 @@ def save_merged_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir:
         config = AutoConfig.from_pretrained(
             model_args.config_name,
             token=model_args.token,
-            cache_dir=model_args.cache_dir
+            cache_dir=model_args.cache_dir,
+            trust_remote_code=model_args.trust_remote_code,
         )
     elif model_args.model_name_or_path:
         config = AutoConfig.from_pretrained(
             model_args.model_name_or_path,
             token=model_args.token,
-            cache_dir=model_args.cache_dir
+            cache_dir=model_args.cache_dir,
+            trust_remote_code=model_args.trust_remote_code,
         )
     else:
         raise ValueError(
@@ -152,6 +157,7 @@ def save_merged_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir:
             cache_dir=model_args.cache_dir,
             from_tf=bool(".ckpt" in model_args.model_name_or_path),
             config=config,
+            trust_remote_code=model_args.trust_remote_code,
         )
     else:
         model = model_args.from_config(config)
@@ -173,5 +179,5 @@ def save_merged_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir:
 
     model.save_pretrained(os.path.join(output_dir, 'merged_model'))
 
-    tokenizer = AutoTokenizer.from_pretrained(output_dir)
+    tokenizer = AutoTokenizer.from_pretrained(output_dir, trust_remote_code=model_args.trust_remote_code)
     tokenizer.save_pretrained(os.path.join(output_dir, 'merged_model'))
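
Why the flag has to appear at each call site: the Hugging Face Auto classes only execute a checkpoint's custom modeling code when trust_remote_code=True is passed to that specific from_pretrained call, so a decoder-only embedder whose architecture lives in the model repo fails to load wherever the flag is dropped. A minimal sketch of the pattern the patch enforces; the repo id is a placeholder, not a model named in this commit:

    from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

    name_or_path = "some-org/custom-embedder"   # placeholder repo id
    trust = True                                # stands in for model_args.trust_remote_code

    # The flag must reach the config, model, and tokenizer loads independently.
    config = AutoConfig.from_pretrained(name_or_path, trust_remote_code=trust)
    model = AutoModelForCausalLM.from_pretrained(name_or_path, config=config, trust_remote_code=trust)
    tokenizer = AutoTokenizer.from_pretrained(name_or_path, trust_remote_code=trust)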
3 changes: 2 additions & 1 deletion FlagEmbedding/finetune/embedder/decoder_only/base/runner.py
@@ -41,7 +41,8 @@ def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsEmbedderMode
             token=self.model_args.token,
             cache_dir=self.model_args.cache_dir,
             use_fast=False,
-            add_eos_token=True
+            add_eos_token=True,
+            trust_remote_code=self.model_args.trust_remote_code,
         )
 
         if tokenizer.pad_token is None:
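
The runner needs the same flag when it builds the tokenizer, next to add_eos_token=True (decoder-only embedders typically pool the representation from the final EOS position, so every sequence must end with one); the ICL runner below receives the identical fix. A rough sketch of the loading step; the pad-token fallback is a common convention assumed here, not quoted from the collapsed context:

    from transformers import AutoTokenizer

    name_or_path = "some-org/custom-embedder"   # placeholder repo id, as above
    tokenizer = AutoTokenizer.from_pretrained(
        name_or_path,
        use_fast=False,
        add_eos_token=True,   # guarantee an EOS token for last-token pooling
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        # Many causal LMs ship without a pad token; reusing EOS is the usual
        # fallback (an assumption, not part of this diff).
        tokenizer.pad_token = tokenizer.eos_token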
10 changes: 7 additions & 3 deletions FlagEmbedding/finetune/embedder/decoder_only/icl/load_model.py
@@ -51,13 +51,15 @@ def get_model(model_args: DecoderOnlyEmbedderICLModelArguments, output_dir: str,
         config = AutoConfig.from_pretrained(
             model_args.config_name,
             token=model_args.token,
-            cache_dir=model_args.cache_dir
+            cache_dir=model_args.cache_dir,
+            trust_remote_code=model_args.trust_remote_code,
         )
     elif model_args.model_name_or_path:
         config = AutoConfig.from_pretrained(
             model_args.model_name_or_path,
             token=model_args.token,
-            cache_dir=model_args.cache_dir
+            cache_dir=model_args.cache_dir,
+            trust_remote_code=model_args.trust_remote_code,
         )
     else:
         raise ValueError(
@@ -74,6 +76,7 @@ def get_model(model_args: DecoderOnlyEmbedderICLModelArguments, output_dir: str,
             cache_dir=model_args.cache_dir,
             from_tf=bool(".ckpt" in model_args.model_name_or_path),
             config=config,
+            trust_remote_code=model_args.trust_remote_code,
         )
     else:
         logger.info("Training new model from scratch")
@@ -152,6 +155,7 @@ def save_merged_model(model_args: DecoderOnlyEmbedderICLModelArguments, output_d
             cache_dir=model_args.cache_dir,
             from_tf=bool(".ckpt" in model_args.model_name_or_path),
             config=config,
+            trust_remote_code=model_args.trust_remote_code,
         )
     else:
         model = model_args.from_config(config)
@@ -173,5 +177,5 @@ def save_merged_model(model_args: DecoderOnlyEmbedderICLModelArguments, output_d
 
     model.save_pretrained(os.path.join(output_dir, 'merged_model'))
 
-    tokenizer = AutoTokenizer.from_pretrained(output_dir)
+    tokenizer = AutoTokenizer.from_pretrained(output_dir, trust_remote_code=model_args.trust_remote_code)
     tokenizer.save_pretrained(os.path.join(output_dir, 'merged_model'))
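
The final hunk closes the loop for save_merged_model: a checkpoint whose tokenizer class is defined in the model repo could previously be fine-tuned and merged but then crash on the intermediate AutoTokenizer.from_pretrained(output_dir) reload. A sketch of consuming the merged artifacts afterwards, with illustrative paths:

    import os
    from transformers import AutoModelForCausalLM, AutoTokenizer

    output_dir = "./finetuned-embedder"                    # illustrative training output dir
    merged_dir = os.path.join(output_dir, 'merged_model')  # path written by the code above

    # Downstream loading mirrors the flag; without it, custom tokenizer or
    # model classes typically raise a ValueError asking for trust_remote_code.
    tokenizer = AutoTokenizer.from_pretrained(merged_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(merged_dir, trust_remote_code=True)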
3 changes: 2 additions & 1 deletion FlagEmbedding/finetune/embedder/decoder_only/icl/runner.py
@@ -45,7 +45,8 @@ def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsEmbedderMode
             token=self.model_args.token,
             cache_dir=self.model_args.cache_dir,
             use_fast=False,
-            add_eos_token=True
+            add_eos_token=True,
+            trust_remote_code=self.model_args.trust_remote_code,
         )
 
         if tokenizer.pad_token is None:
8 changes: 8 additions & 0 deletions FlagEmbedding/finetune/embedder/encoder_only/m3/runner.py
@@ -142,6 +142,14 @@ def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsEmbedderMode
                 if "position_embeddings" in k:
                     logging.info(f"Freeze the parameters for {k}")
                     v.requires_grad = False
+
+        if self.training_args.fix_encoder:
+            for k, v in model.named_parameters():
+                if "colbert_linear" in k or 'sparse_linear' in k:
+                    logging.info(f"train the parameters for {k}")
+                else:
+                    v.requires_grad = False
+
         return tokenizer, model
 
     def load_trainer(self) -> EncoderOnlyEmbedderM3Trainer:
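
This block implements the first bullet of the commit message: fix_encoder previously had no effect, and it now freezes the backbone so that only the colbert_linear and sparse_linear projection heads keep training. A quick sanity check, assuming model is the M3 embedder returned by load_tokenizer_and_model with fix_encoder=True:

    trainable = [k for k, v in model.named_parameters() if v.requires_grad]
    # Only the two projection heads should survive the freeze.
    assert all("colbert_linear" in k or "sparse_linear" in k for k in trainable)
    print(f"{len(trainable)} trainable tensors remain")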
