Skip to content

Commit

Permalink
Test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
aelaguiz committed Mar 21, 2024
1 parent 5053fa3 commit c23472f
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
2 changes: 0 additions & 2 deletions langdspy/prompt_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ def __init__(self, **kwargs):
hints = {} # New dictionary to hold hint fields

for name, attribute in self.__class__.__fields__.items():
print(f"Type of attribute.type_: {type(attribute.type_)}")
print(f"Class of attribute.type_: {attribute.type_.__class__}")
if issubclass(attribute.type_, InputField):
inputs[name] = attribute.default
elif issubclass(attribute.type_, OutputField):
Expand Down
46 changes: 45 additions & 1 deletion tests/test_model_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,51 @@
dotenv.load_dotenv()
import pytest
from unittest.mock import MagicMock
from examples.amazon.generate_slugs import ProductSlugGenerator, slug_similarity, get_llm
import langdspy
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity


class GenerateSlug(langdspy.PromptSignature):
hint_slug = langdspy.HintField(desc="Generate a URL-friendly slug based on the provided H1, title, and product copy. The slug should be lowercase, use hyphens to separate words, and not exceed 50 characters.")

h1 = langdspy.InputField(name="H1", desc="The H1 heading of the product page")
title = langdspy.InputField(name="Title", desc="The title of the product page")
product_copy = langdspy.InputField(name="Product Copy", desc="The product description or copy")

slug = langdspy.OutputField(name="Slug", desc="The generated URL-friendly slug")

class ProductSlugGenerator(langdspy.Model):
generate_slug = langdspy.PromptRunner(template_class=GenerateSlug, prompt_strategy=langdspy.DefaultPromptStrategy)

def invoke(self, input_dict, config):
h1 = input_dict['h1']
title = input_dict['title']
product_copy = input_dict['product_copy']

slug_res = self.generate_slug.invoke({'h1': h1, 'title': title, 'product_copy': product_copy}, config=config)

return slug_res.slug


def cosine_similarity_tfidf(true_slugs, predicted_slugs):
# Convert slugs to lowercase
true_slugs = [slug.lower() for slug in true_slugs]
predicted_slugs = [slug.lower() for slug in predicted_slugs]

# for i in range(len(true_slugs)):
# print(f"Actual Slug: {true_slugs[i]} Predicted: {predicted_slugs[i]}")

vectorizer = TfidfVectorizer()
true_vectors = vectorizer.fit_transform(true_slugs)
predicted_vectors = vectorizer.transform(predicted_slugs)
similarity_scores = cosine_similarity(true_vectors, predicted_vectors)
return similarity_scores.diagonal()

def slug_similarity(X, true_slugs, predicted_slugs):
similarity_scores = cosine_similarity_tfidf(true_slugs, predicted_slugs)
average_similarity = sum(similarity_scores) / len(similarity_scores)
return average_similarity

@pytest.fixture
def model():
Expand Down

0 comments on commit c23472f

Please sign in to comment.