From 0e393dd29c85d5b5687c8015a9a65b153d652fe4 Mon Sep 17 00:00:00 2001 From: Alie Langston Date: Thu, 24 Aug 2023 11:44:40 -0400 Subject: [PATCH] feat: add json field to store prompt message content --- learning_assistant/api.py | 23 +++++++++-- .../0002_courseprompt_json_prompt.py | 18 ++++++++ learning_assistant/models.py | 14 +++---- learning_assistant/views.py | 11 ++--- tests/test_api.py | 41 +++++++++++++++++++ tests/test_models.py | 10 ++--- tests/test_views.py | 4 +- 7 files changed, 97 insertions(+), 24 deletions(-) create mode 100644 learning_assistant/migrations/0002_courseprompt_json_prompt.py create mode 100644 tests/test_api.py diff --git a/learning_assistant/api.py b/learning_assistant/api.py index 1288a8b..5df3c35 100644 --- a/learning_assistant/api.py +++ b/learning_assistant/api.py @@ -1,11 +1,28 @@ """ Library for the learning_assistant app. """ +import json + from learning_assistant.models import CoursePrompt -def get_prompt_by_course_id(course_id): +def get_deserialized_prompt_content_by_course_id(course_id): + """ + Return a deserialized prompt given a course_id + """ + json_prompt = CoursePrompt.get_json_prompt_content_by_course_id(course_id) + if json_prompt: + prompt_messages = json.loads(json_prompt) + return prompt_messages + return None + + +def get_setup_messages(course_id): """ - Return a prompt associated with a given course id. + Return a list of setup messages given a course id """ - return CoursePrompt.get_prompt_by_course_id(course_id) + message_content = get_deserialized_prompt_content_by_course_id(course_id) + if message_content: + setup_messages = [{'role': 'system', 'content': x} for x in message_content] + return setup_messages + return None diff --git a/learning_assistant/migrations/0002_courseprompt_json_prompt.py b/learning_assistant/migrations/0002_courseprompt_json_prompt.py new file mode 100644 index 0000000..bd285e1 --- /dev/null +++ b/learning_assistant/migrations/0002_courseprompt_json_prompt.py @@ -0,0 +1,18 @@ +# Generated by Django 3.2.20 on 2023-08-24 09:56 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('learning_assistant', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='courseprompt', + name='json_prompt_content', + field=models.JSONField(null=True), + ), + ] diff --git a/learning_assistant/models.py b/learning_assistant/models.py index c6318cc..f4313f0 100644 --- a/learning_assistant/models.py +++ b/learning_assistant/models.py @@ -16,17 +16,17 @@ class CoursePrompt(TimeStampedModel): # course ID with which the text prompt is associated course_id = CourseKeyField(max_length=255, db_index=True, unique=True) - # text prompt, that may contain course related information - prompt = models.TextField(blank=True) + # a json representation of the prompt message content + json_prompt_content = models.JSONField(null=True) @classmethod - def get_prompt_by_course_id(cls, course_id): + def get_json_prompt_content_by_course_id(cls, course_id): """ - Return a text prompt for a given course id. + Return a json representation of a prompt for a given course id. """ try: prompt_object = cls.objects.get(course_id=course_id) - prompt = prompt_object.prompt + json_prompt_content = prompt_object.json_prompt_content except cls.DoesNotExist: - prompt = None - return prompt + json_prompt_content = None + return json_prompt_content diff --git a/learning_assistant/views.py b/learning_assistant/views.py index 85ea4c6..bff3dff 100644 --- a/learning_assistant/views.py +++ b/learning_assistant/views.py @@ -20,7 +20,7 @@ # If the waffle flag is false, the endpoint will force an early return. learning_assistant_is_active = False -from learning_assistant.api import get_prompt_by_course_id +from learning_assistant.api import get_setup_messages from learning_assistant.serializers import MessageSerializer from learning_assistant.utils import get_chat_response @@ -66,8 +66,8 @@ def post(self, request, course_id): data={'detail': 'Must be staff or have valid enrollment.'} ) - prompt_text = get_prompt_by_course_id(course_id) - if not prompt_text: + prompt_messages = get_setup_messages(course_id) + if not prompt_messages: return Response( status=http_status.HTTP_404_NOT_FOUND, data={'detail': 'Learning assistant not enabled for course.'} @@ -85,10 +85,7 @@ def post(self, request, course_id): ) # append system message to beginning of message list - message_setup = [{ - 'role': 'system', - 'content': prompt_text - }] + message_setup = prompt_messages log.info( 'Attempting to retrieve chat response for user_id=%(user_id)s in course_id=%(course_id)s', diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..941b420 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,41 @@ +""" +Test cases for the learning-assistant api module. +""" +import json +from django.test import TestCase + +from learning_assistant.api import get_deserialized_prompt_content_by_course_id, get_setup_messages +from learning_assistant.models import CoursePrompt + + +class LearningAssistantAPITests(TestCase): + """ + Test suite for the api module + """ + + def setUp(self): + self.course_id = 'course-v1:edx+test+23' + self.prompt = json.dumps('["This is a Prompt", "This is another Prompt"]') + self.course_prompt = CoursePrompt.objects.create( + course_id=self.course_id, + json_prompt_content=self.prompt, + ) + return super().setUp() + + def test_get_deserialized_prompt_valid_course_id(self): + prompt_content = get_deserialized_prompt_content_by_course_id(self.course_id) + expected_content = json.loads(self.prompt) + self.assertEqual(prompt_content, expected_content) + + def test_get_deserialized_prompt_invalid_course_id(self): + prompt_content = get_deserialized_prompt_content_by_course_id('course-v1:edx+fake+19') + self.assertIsNone(prompt_content) + + def test_get_setup_messages(self): + setup_messages = get_setup_messages(self.course_id) + expected_messages = [{'role': 'system', 'content': x} for x in json.loads(self.prompt)] + self.assertEqual(setup_messages, expected_messages) + + def test_get_setup_messages_invalid_course_id(self): + setup_messages = get_setup_messages('course-v1:edx+fake+19') + self.assertIsNone(setup_messages) diff --git a/tests/test_models.py b/tests/test_models.py index f8b1ed3..35b0958 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -2,7 +2,7 @@ """ Tests for the `learning-assistant` models module. """ - +import json from django.test import TestCase from learning_assistant.models import CoursePrompt @@ -15,10 +15,10 @@ class CoursePromptTests(TestCase): def setUp(self): self.course_id = 'course-v1:edx+test+23' - self.prompt = 'This is a prompt' + self.prompt = json.dumps('["This is a Prompt", "This is another Prompt"]') self.course_prompt = CoursePrompt.objects.create( course_id=self.course_id, - prompt=self.prompt, + json_prompt_content=self.prompt, ) return super().setUp() @@ -26,12 +26,12 @@ def test_get_prompt_by_course_id(self): """ Test that a prompt can be retrieved by course ID """ - prompt = CoursePrompt.get_prompt_by_course_id(self.course_id) + prompt = CoursePrompt.get_json_prompt_content_by_course_id(self.course_id) self.assertEqual(prompt, self.prompt) def test_get_prompt_by_course_id_invalid(self): """ Test that None is returned if the given course ID does not exist """ - prompt = CoursePrompt.get_prompt_by_course_id('course-v1:edx+fake+19') + prompt = CoursePrompt.get_json_prompt_content_by_course_id('course-v1:edx+fake+19') self.assertIsNone(prompt) diff --git a/tests/test_views.py b/tests/test_views.py index 0562874..106569d 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -115,7 +115,7 @@ def test_invalid_messages(self, mock_role, mock_waffle): CoursePrompt.objects.create( course_id=self.course_id, - prompt='This is a prompt' + json_prompt_content=json.dumps('["This is a Prompt", "This is another Prompt"]') ) test_data = [ @@ -140,7 +140,7 @@ def test_chat_response(self, mock_role, mock_waffle, mock_chat_response): CoursePrompt.objects.create( course_id=self.course_id, - prompt='This is a prompt' + json_prompt_content=json.dumps('["This is a Prompt", "This is another Prompt"]') ) test_data = [