Skip to content

Commit

Permalink
fix: Added course run key to save_chat_message()
Browse files Browse the repository at this point in the history
  • Loading branch information
rijuma committed Nov 4, 2024
1 parent f0b0cf8 commit 4fa9bf5
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 6 deletions.
4 changes: 3 additions & 1 deletion learning_assistant/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def get_course_id(course_run_id):
return course_key


def save_chat_message(user_id, chat_role, message):
def save_chat_message(courserun_key, user_id, chat_role, message):
"""
Save the chat message to the database.
"""
Expand All @@ -203,9 +203,11 @@ def save_chat_message(user_id, chat_role, message):

# Save the user message to the database.
LearningAssistantMessage.objects.create(
course_id=courserun_key,
user=user,
role=chat_role,
content=message,

)


Expand Down
4 changes: 2 additions & 2 deletions learning_assistant/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def post(self, request, course_run_id):
user_id = request.user.id

if chat_history_enabled(courserun_key):
save_chat_message(user_id, LearningAssistantMessage.USER_ROLE, new_user_message['content'])
save_chat_message(courserun_key, user_id, LearningAssistantMessage.USER_ROLE, new_user_message['content'])

serializer = MessageSerializer(data=message_list, many=True)

Expand Down Expand Up @@ -126,7 +126,7 @@ def post(self, request, course_run_id):
status_code, message = get_chat_response(prompt_template, message_list)

if chat_history_enabled(courserun_key):
save_chat_message(user_id, LearningAssistantMessage.ASSISTANT_ROLE, message['content'])
save_chat_message(courserun_key, user_id, LearningAssistantMessage.ASSISTANT_ROLE, message['content'])

return Response(status=status_code, data=message)

Expand Down
4 changes: 3 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,17 +245,19 @@ def setUp(self):
super().setUp()

self.test_user = User.objects.create(username='username', password='password')
self.course_run_key = CourseKey.from_string('course-v1:edx+test+23')

@ddt.data(
(LearningAssistantMessage.USER_ROLE, 'What is the meaning of life, the universe and everything?'),
(LearningAssistantMessage.ASSISTANT_ROLE, '42'),
)
@ddt.unpack
def test_save_chat_message(self, chat_role, message):
save_chat_message(self.test_user.id, chat_role, message)
save_chat_message(self.course_run_key, self.test_user.id, chat_role, message)

row = LearningAssistantMessage.objects.all().last()

self.assertEqual(row.course_id, self.course_run_key)
self.assertEqual(row.role, chat_role)
self.assertEqual(row.content, message)

Expand Down
6 changes: 4 additions & 2 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from django.test import TestCase, override_settings
from django.test.client import Client
from django.urls import reverse
from opaque_keys.edx.keys import CourseKey

from learning_assistant.models import LearningAssistantMessage

Expand Down Expand Up @@ -85,6 +86,7 @@ class TestCourseChatView(LoggedInTestCase):
def setUp(self):
super().setUp()
self.course_id = 'course-v1:edx+test+23'
self.course_run_key = CourseKey.from_string(self.course_id)

self.patcher = patch(
'learning_assistant.api.get_cache_course_run_data',
Expand Down Expand Up @@ -209,8 +211,8 @@ def test_chat_response_default(

if enabled_flag:
mock_save_chat_message.assert_has_calls([
call(self.user.id, LearningAssistantMessage.USER_ROLE, test_data[-1]['content']),
call(self.user.id, LearningAssistantMessage.ASSISTANT_ROLE, 'Something else')
call(self.course_run_key, self.user.id, LearningAssistantMessage.USER_ROLE, test_data[-1]['content']),
call(self.course_run_key, self.user.id, LearningAssistantMessage.ASSISTANT_ROLE, 'Something else')
])
else:
mock_save_chat_message.assert_not_called()
Expand Down

0 comments on commit 4fa9bf5

Please sign in to comment.