Skip to content

Commit

Permalink
feat: add json field to store prompt message content
Browse files Browse the repository at this point in the history
  • Loading branch information
alangsto committed Aug 24, 2023
1 parent 97e647b commit 0e393dd
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 24 deletions.
23 changes: 20 additions & 3 deletions learning_assistant/api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,28 @@
"""
Library for the learning_assistant app.
"""
import json

from learning_assistant.models import CoursePrompt


def get_prompt_by_course_id(course_id):
def get_deserialized_prompt_content_by_course_id(course_id):
"""
Return a deserialized prompt given a course_id
"""
json_prompt = CoursePrompt.get_json_prompt_content_by_course_id(course_id)
if json_prompt:
prompt_messages = json.loads(json_prompt)
return prompt_messages
return None


def get_setup_messages(course_id):
"""
Return a prompt associated with a given course id.
Return a list of setup messages given a course id
"""
return CoursePrompt.get_prompt_by_course_id(course_id)
message_content = get_deserialized_prompt_content_by_course_id(course_id)
if message_content:
setup_messages = [{'role': 'system', 'content': x} for x in message_content]
return setup_messages
return None
18 changes: 18 additions & 0 deletions learning_assistant/migrations/0002_courseprompt_json_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 3.2.20 on 2023-08-24 09:56

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('learning_assistant', '0001_initial'),
]

operations = [
migrations.AddField(
model_name='courseprompt',
name='json_prompt_content',
field=models.JSONField(null=True),
),
]
14 changes: 7 additions & 7 deletions learning_assistant/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ class CoursePrompt(TimeStampedModel):
# course ID with which the text prompt is associated
course_id = CourseKeyField(max_length=255, db_index=True, unique=True)

# text prompt, that may contain course related information
prompt = models.TextField(blank=True)
# a json representation of the prompt message content
json_prompt_content = models.JSONField(null=True)

@classmethod
def get_prompt_by_course_id(cls, course_id):
def get_json_prompt_content_by_course_id(cls, course_id):
"""
Return a text prompt for a given course id.
Return a json representation of a prompt for a given course id.
"""
try:
prompt_object = cls.objects.get(course_id=course_id)
prompt = prompt_object.prompt
json_prompt_content = prompt_object.json_prompt_content
except cls.DoesNotExist:
prompt = None
return prompt
json_prompt_content = None
return json_prompt_content
11 changes: 4 additions & 7 deletions learning_assistant/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# If the waffle flag is false, the endpoint will force an early return.
learning_assistant_is_active = False

from learning_assistant.api import get_prompt_by_course_id
from learning_assistant.api import get_setup_messages
from learning_assistant.serializers import MessageSerializer
from learning_assistant.utils import get_chat_response

Expand Down Expand Up @@ -66,8 +66,8 @@ def post(self, request, course_id):
data={'detail': 'Must be staff or have valid enrollment.'}
)

prompt_text = get_prompt_by_course_id(course_id)
if not prompt_text:
prompt_messages = get_setup_messages(course_id)
if not prompt_messages:
return Response(
status=http_status.HTTP_404_NOT_FOUND,
data={'detail': 'Learning assistant not enabled for course.'}
Expand All @@ -85,10 +85,7 @@ def post(self, request, course_id):
)

# append system message to beginning of message list
message_setup = [{
'role': 'system',
'content': prompt_text
}]
message_setup = prompt_messages

log.info(
'Attempting to retrieve chat response for user_id=%(user_id)s in course_id=%(course_id)s',
Expand Down
41 changes: 41 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
Test cases for the learning-assistant api module.
"""
import json
from django.test import TestCase

from learning_assistant.api import get_deserialized_prompt_content_by_course_id, get_setup_messages
from learning_assistant.models import CoursePrompt


class LearningAssistantAPITests(TestCase):
"""
Test suite for the api module
"""

def setUp(self):
self.course_id = 'course-v1:edx+test+23'
self.prompt = json.dumps('["This is a Prompt", "This is another Prompt"]')
self.course_prompt = CoursePrompt.objects.create(
course_id=self.course_id,
json_prompt_content=self.prompt,
)
return super().setUp()

def test_get_deserialized_prompt_valid_course_id(self):
prompt_content = get_deserialized_prompt_content_by_course_id(self.course_id)
expected_content = json.loads(self.prompt)
self.assertEqual(prompt_content, expected_content)

def test_get_deserialized_prompt_invalid_course_id(self):
prompt_content = get_deserialized_prompt_content_by_course_id('course-v1:edx+fake+19')
self.assertIsNone(prompt_content)

def test_get_setup_messages(self):
setup_messages = get_setup_messages(self.course_id)
expected_messages = [{'role': 'system', 'content': x} for x in json.loads(self.prompt)]
self.assertEqual(setup_messages, expected_messages)

def test_get_setup_messages_invalid_course_id(self):
setup_messages = get_setup_messages('course-v1:edx+fake+19')
self.assertIsNone(setup_messages)
10 changes: 5 additions & 5 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""
Tests for the `learning-assistant` models module.
"""

import json
from django.test import TestCase

from learning_assistant.models import CoursePrompt
Expand All @@ -15,23 +15,23 @@ class CoursePromptTests(TestCase):

def setUp(self):
self.course_id = 'course-v1:edx+test+23'
self.prompt = 'This is a prompt'
self.prompt = json.dumps('["This is a Prompt", "This is another Prompt"]')
self.course_prompt = CoursePrompt.objects.create(
course_id=self.course_id,
prompt=self.prompt,
json_prompt_content=self.prompt,
)
return super().setUp()

def test_get_prompt_by_course_id(self):
"""
Test that a prompt can be retrieved by course ID
"""
prompt = CoursePrompt.get_prompt_by_course_id(self.course_id)
prompt = CoursePrompt.get_json_prompt_content_by_course_id(self.course_id)
self.assertEqual(prompt, self.prompt)

def test_get_prompt_by_course_id_invalid(self):
"""
Test that None is returned if the given course ID does not exist
"""
prompt = CoursePrompt.get_prompt_by_course_id('course-v1:edx+fake+19')
prompt = CoursePrompt.get_json_prompt_content_by_course_id('course-v1:edx+fake+19')
self.assertIsNone(prompt)
4 changes: 2 additions & 2 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_invalid_messages(self, mock_role, mock_waffle):

CoursePrompt.objects.create(
course_id=self.course_id,
prompt='This is a prompt'
json_prompt_content=json.dumps('["This is a Prompt", "This is another Prompt"]')
)

test_data = [
Expand All @@ -140,7 +140,7 @@ def test_chat_response(self, mock_role, mock_waffle, mock_chat_response):

CoursePrompt.objects.create(
course_id=self.course_id,
prompt='This is a prompt'
json_prompt_content=json.dumps('["This is a Prompt", "This is another Prompt"]')
)

test_data = [
Expand Down

0 comments on commit 0e393dd

Please sign in to comment.