Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Loading logic rework #52

Merged
merged 4 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions pylate/models/Dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import os

import torch
from safetensors import safe_open
from safetensors.torch import load_model as load_safetensors_model
from sentence_transformers.models import Dense as DenseSentenceTransformer
from sentence_transformers.util import import_from_string
from torch import nn
from transformers.utils import cached_file

__all__ = ["Dense"]

Expand Down Expand Up @@ -77,6 +79,61 @@ def from_sentence_transformers(dense: DenseSentenceTransformer) -> "Dense":
model.load_state_dict(dense.state_dict())
return model

@staticmethod
def from_stanford_weights(
model_name_or_path: str | os.PathLike,
cache_folder: str | os.PathLike | None = None,
revision: str | None = None,
local_files_only: bool | None = None,
token: str | bool | None = None,
use_auth_token: str | bool | None = None,
) -> "Dense":
"""Load the weight of the Dense layer using weights from a stanford-nlp checkpoint.

Parameters
----------
model_name_or_path (`str` or `os.PathLike`):
This can be either:
- a string, the *model id* of a model repo on huggingface.co.
- a path to a *directory* potentially containing the file.
cache_folder (`str` or `os.PathLike`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
cache should not be used.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
"""
# Check if the model is locally available
if not (os.path.exists(os.path.join(model_name_or_path))):
# Else download the model/use the cached version
model_name_or_path = cached_file(
model_name_or_path,
filename="model.safetensors",
cache_dir=cache_folder,
revision=revision,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
)
with safe_open(model_name_or_path, framework="pt", device="cpu") as f:
state_dict = {"linear.weight": f.get_tensor("linear.weight")}

# Determine input and output dimensions
in_features = state_dict["linear.weight"].shape[1]
out_features = state_dict["linear.weight"].shape[0]

# Create Dense layer instance
model = Dense(in_features=in_features, out_features=out_features, bias=False)

model.load_state_dict(state_dict)
return model

@staticmethod
def load(input_path) -> "Dense":
"""Load a Dense layer."""
Expand Down
39 changes: 27 additions & 12 deletions pylate/models/colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,6 @@ def __init__(
config_kwargs: dict | None = None,
model_card_data: Optional[SentenceTransformerModelCardData] = None,
) -> None:
model_kwargs = {} if model_kwargs is None else model_kwargs
model_kwargs["add_pooling_layer"] = False

self.query_prefix = query_prefix
self.document_prefix = document_prefix
self.query_length = query_length
Expand All @@ -248,20 +245,35 @@ def __init__(
config_kwargs=config_kwargs,
model_card_data=model_card_data,
)

hidden_size = self[0].get_word_embedding_dimension()

# Add a linear projection layer to the model in order to project the embeddings to the desired size.
if len(self) < 2:
# Add a linear projection layer to the model in order to project the embeddings to the desired size
embedding_size = embedding_size or 128
# If the model is a stanford-nlp ColBERT, load the weights of the dense layer
if self[0].auto_model.config.architectures[0] == "HF_ColBERT":
self.append(
Dense.from_stanford_weights(
model_name_or_path,
cache_folder,
revision,
local_files_only,
token,
use_auth_token,
)
)
logger.warning("Loaded the ColBERT model from Stanford NLP.")
else:
# Add a linear projection layer to the model in order to project the embeddings to the desired size
embedding_size = embedding_size or 128

logger.warning(
f"The checkpoint does not contain a linear projection layer. Adding one with output dimensions ({hidden_size}, {embedding_size})."
)
self.append(
Dense(in_features=hidden_size, out_features=embedding_size, bias=bias)
)
logger.warning(
f"The checkpoint does not contain a linear projection layer. Adding one with output dimensions ({hidden_size}, {embedding_size})."
)
self.append(
Dense(
in_features=hidden_size, out_features=embedding_size, bias=bias
)
)

elif (
embedding_size is not None
Expand All @@ -282,6 +294,9 @@ def __init__(
else:
logger.warning("Pylate model loaded successfully.")

if model_kwargs is not None and "torch_dtype" in model_kwargs:
self[1].to(model_kwargs["torch_dtype"])

self.to(device)
self.is_hpu_graph_enabled = False

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
long_description = fh.read()

base_packages = [
"sentence-transformers >= 3.0.1",
"sentence-transformers == 3.0.1",
"datasets >= 2.20.0",
"accelerate >= 0.31.0",
"voyager >= 2.0.9",
Expand Down
Loading