fix: update sentence transformer udf (#854)
👋 Thanks for submitting a Pull Request to EvaDB!

🙌 We want to make contributing to EvaDB as easy and transparent as
possible. Here are a few tips to get you started:

- 🔍 Search existing EvaDB
[PRs](https://github.com/georgia-tech-db/eva/pulls) to see if a similar
PR already exists.
- 🔗 Link this PR to an EvaDB
[issue](https://github.com/georgia-tech-db/eva/issues) to help us
understand what bug fix or feature is being implemented.
- 📈 Provide before and after profiling results to help us quantify the
improvement your PR provides (if applicable).

👉 Please see our ✅ [Contributing
Guide](https://evadb.readthedocs.io/en/stable/source/contribute/index.html)
for more details.
jarulraj authored Jun 11, 2023
1 parent 95c501d commit e45d509
Showing 1 changed file with 8 additions and 34 deletions.
42 changes: 8 additions & 34 deletions evadb/udfs/sentence_feature_extractor.py
@@ -14,9 +14,7 @@
 # limitations under the License.
 import numpy as np
 import pandas as pd
-import torch
-import torch.nn.functional as F
-from transformers import AutoModel, AutoTokenizer
+from sentence_transformers import SentenceTransformer
 
 from evadb.catalog.catalog_type import NdArrayType
 from evadb.udfs.abstract.abstract_udf import AbstractUDF
@@ -25,30 +23,25 @@
 from evadb.udfs.gpu_compatible import GPUCompatible
 
 
-class SentenceFeatureExtractor(AbstractUDF, GPUCompatible):
+class SentenceTransformerFeatureExtractor(AbstractUDF, GPUCompatible):
     @setup(cacheable=False, udf_type="FeatureExtraction", batchable=False)
     def setup(self):
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            "sentence-transformers/all-MiniLM-L6-v2"
-        )
-        self.model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
-        self.model_device = None
+        self.model = SentenceTransformer("all-MiniLM-L6-v2")
 
     def to_device(self, device: str) -> GPUCompatible:
-        self.model_device = device
         self.model = self.model.to(device)
         return self
 
     @property
     def name(self) -> str:
-        return "SentenceFeatureExtractor"
+        return "SentenceTransformerFeatureExtractor"
 
     @forward(
         input_signatures=[
             PandasDataframe(
                 columns=["data"],
                 column_types=[NdArrayType.STR],
-                column_shapes=[(None, 1)],
+                column_shapes=[(1)],
             )
         ],
         output_signatures=[
@@ -61,28 +54,9 @@ def name(self) -> str:
     )
     def forward(self, df: pd.DataFrame) -> pd.DataFrame:
         def _forward(row: pd.Series) -> np.ndarray:
-            sentence = row[0]
-
-            encoded_input = self.tokenizer(
-                [sentence], padding=True, truncation=True, return_tensors="pt"
-            )
-            if self.model_device is not None:
-                encoded_input.to(self.model_device)
-            with torch.no_grad():
-                model_output = self.model(**encoded_input)
-
-            attention_mask = encoded_input["attention_mask"]
-            token_embedding = model_output[0]
-            input_mask_expanded = (
-                attention_mask.unsqueeze(-1).expand(token_embedding.size()).float()
-            )
-            sentence_embedding = torch.sum(
-                token_embedding * input_mask_expanded, 1
-            ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
-            sentence_embedding = F.normalize(sentence_embedding, p=2, dim=1)
-
-            sentence_embedding_np = sentence_embedding.cpu().numpy()
-            return sentence_embedding_np
+            data = row
+            embedded_list = self.model.encode(data)
+            return embedded_list
 
         ret = pd.DataFrame()
         ret["features"] = df.apply(_forward, axis=1)
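For comparison, the deleted branch computed the embedding by hand against `transformers` and `torch`. A condensed, standalone version of that removed path (illustrative only; variable names follow the deleted UDF code):

```python
# Condensed version of the deleted code path: masked mean pooling over token
# embeddings followed by L2 normalization.
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

encoded_input = tokenizer(
    ["EvaDB runs AI models inside the database."],
    padding=True,
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    model_output = model(**encoded_input)

# Zero out padding positions, average the remaining token embeddings, normalize.
attention_mask = encoded_input["attention_mask"]
token_embedding = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embedding.size()).float()
sentence_embedding = torch.sum(token_embedding * input_mask_expanded, 1) / torch.clamp(
    input_mask_expanded.sum(1), min=1e-9
)
sentence_embedding = F.normalize(sentence_embedding, p=2, dim=1)
print(sentence_embedding.shape)  # torch.Size([1, 384])
```

Both paths should produce numerically close 384-dimensional unit vectors for the same input; the library call simply hides the pooling and normalization steps and keeps the UDF focused on the DataFrame plumbing.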
