From dc9179cbb1fdeaffa4f2b1b7e28e4b4abc479051 Mon Sep 17 00:00:00 2001 From: Alison Langston <46360176+alangsto@users.noreply.github.com> Date: Wed, 28 Feb 2024 11:20:13 -0500 Subject: [PATCH] feat: use course cache to retrieve course data (#72) --- CHANGELOG.rst | 4 ++ docs/decisions/0004-course-cache-use.rst | 50 ++++++++++++++++++++++++ learning_assistant/__init__.py | 2 +- learning_assistant/api.py | 23 +++++++++-- learning_assistant/constants.py | 2 +- learning_assistant/platform_imports.py | 16 +++++++- learning_assistant/utils.py | 19 ++------- learning_assistant/views.py | 20 +++++----- test_settings.py | 2 + tests/test_api.py | 19 ++++++--- tests/test_utils.py | 11 +----- tests/test_views.py | 22 +++++++---- 12 files changed, 136 insertions(+), 54 deletions(-) create mode 100644 docs/decisions/0004-course-cache-use.rst diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0c3b823..b4c7a19 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -14,6 +14,10 @@ Change Log Unreleased ********** +4.1.0 - 2024-02-26 +****************** +* Use course cache to inject course title and course skill names into prompt template. + 4.0.0 - 2024-02-21 ****************** * Remove use of course waffle flag. Use the django setting LEARNING_ASSISTANT_AVAILABLE diff --git a/docs/decisions/0004-course-cache-use.rst b/docs/decisions/0004-course-cache-use.rst new file mode 100644 index 0000000..13db245 --- /dev/null +++ b/docs/decisions/0004-course-cache-use.rst @@ -0,0 +1,50 @@ +4. Use of the LMS course caches +############################### + +Status +****** + +**Accepted** *2024-02-26* + +Context +******* +Each course run ID in edx-platform is associated with a course ID. While for many courses, the mapping between +course run ID and course ID is straightforward, e.g. edX+testX+2019 -> edX+testX, this is not the case for every +course on edX. 
The discovery service is the source of truth for mappings between course run IDs and course IDs, and +string manipulation cannot be relied on as an accurate way to map between the two forms of ID. + +The learning-assistant `CourseChatView`_ accepts a course run ID as a path parameter, but a number of API functions +in the learning assistant backend also require the course ID associated with a given course run. + +In our initial release, we also found that the set of courses currently available in the 2U Xpert Platform team's index, which was +being used to inject course skill names and course titles into the system prompt (see `System Prompt Design Changes`_ for +original details), was too limited. Courses included in that index were conditionally added depending on course +enrollment dates and additional fields from the discovery course API. While the 2U Xpert Platform team may work to address +the gap in product needs for their current course index, an alternative method for retrieving course skills and title should +be considered. + +Decision +******** +In order to determine the mapping between a course run ID and course ID in the learning-assistant app, we will make +use of an `existing course run cache that is defined in edx-platform`_. Similarly, to retrieve the skill names and title of a course, we will also use +an `existing course cache`_. Both caches store data from the discovery API for course runs and courses, respectively. +These are long-term caches with a TTL of 24 hours, and on a cache miss the discovery API will be called. + +Consequences +************ +* If the caches were to be removed, code in the learning-assistant repository would no longer function as expected. +* On a cache miss, the learning-assistant backend will incur additional performance cost on calls to the discovery API. 
+ +Rejected Alternatives +********************* +* Calling the discovery API directly from the learning-assistant backend + * This would require building a custom solution in the learning-assistant app to call the discovery service directly. + * Without a cache, this would impact performance on every call to the learning-assistant backend. +* Using string manipulation to map course run ID to course ID. + * If we do not use the discovery service as our source of truth for course run ID to course ID mappings, + we run the risk of being unable to support courses that do not fit the usual pattern mapping. + +.. _existing course run cache that is defined in edx-platform: https://github.com/openedx/edx-platform/blob/c61df904c1d2a5f523f1da44460c21e17ec087ee/openedx/core/djangoapps/catalog/utils.py#L801 +.. _CourseChatView: https://github.com/edx/learning-assistant/blob/fddf0bc27016bd4a1cabf82de7bcb80b51f3763b/learning_assistant/views.py#L29 +.. _System Prompt Design Changes: https://github.com/edx/learning-assistant/blob/main/docs/decisions/0002-system-prompt-design-changes.rst +.. _existing course cache: https://github.com/openedx/edx-platform/blob/3a2b6dd8fcc909fd9128f81750f52650ba8ff906/openedx/core/djangoapps/catalog/utils.py#L767 diff --git a/learning_assistant/__init__.py b/learning_assistant/__init__.py index b83582f..b46d757 100644 --- a/learning_assistant/__init__.py +++ b/learning_assistant/__init__.py @@ -2,6 +2,6 @@ Plugin for a learning assistant backend, intended for use within edx-platform. 
""" -__version__ = '4.0.0' +__version__ = '4.1.0' default_app_config = 'learning_assistant.apps.LearningAssistantConfig' # pylint: disable=invalid-name diff --git a/learning_assistant/api.py b/learning_assistant/api.py index 8b1a2f5..38ae1c0 100644 --- a/learning_assistant/api.py +++ b/learning_assistant/api.py @@ -15,6 +15,8 @@ from learning_assistant.platform_imports import ( block_get_children, block_leaf_filter, + get_cache_course_data, + get_cache_course_run_data, get_single_block, get_text_transcript, traverse_block_pre_order, @@ -101,19 +103,23 @@ def get_block_content(request, user_id, course_id, unit_usage_key): return cache_data['content_length'], cache_data['content_items'] -def render_prompt_template(request, user_id, course_id, unit_usage_key): +def render_prompt_template(request, user_id, course_run_id, unit_usage_key, course_id): """ Return a rendered prompt template, specified by the LEARNING_ASSISTANT_PROMPT_TEMPLATE setting. """ unit_content = '' - course_run_key = CourseKey.from_string(course_id) + course_run_key = CourseKey.from_string(course_run_id) if unit_usage_key and course_content_enabled(course_run_key): - _, unit_content = get_block_content(request, user_id, course_id, unit_usage_key) + _, unit_content = get_block_content(request, user_id, course_run_id, unit_usage_key) + + course_data = get_cache_course_data(course_id, ['skill_names', 'title']) + skill_names = course_data['skill_names'] + title = course_data['title'] template_string = getattr(settings, 'LEARNING_ASSISTANT_PROMPT_TEMPLATE', '') template = Environment(loader=BaseLoader).from_string(template_string) - data = template.render(unit_content=unit_content) + data = template.render(unit_content=unit_content, skill_names=skill_names, title=title) return data @@ -168,3 +174,12 @@ def set_learning_assistant_enabled(course_key, enabled): course_key=obj.course_id, enabled=obj.enabled ) + + +def get_course_id(course_run_id): + """ + Given a course run id (str), return the associated 
course key. + """ + course_data = get_cache_course_run_data(course_run_id, ['course']) + course_key = course_data['course'] + return course_key diff --git a/learning_assistant/constants.py b/learning_assistant/constants.py index 06097ff..7027a28 100644 --- a/learning_assistant/constants.py +++ b/learning_assistant/constants.py @@ -7,7 +7,7 @@ EXTERNAL_COURSE_KEY_PATTERN = r'([A-Za-z0-9-_:]+)' -COURSE_ID_PATTERN = rf'(?P<course_id>({INTERNAL_COURSE_KEY_PATTERN}|{EXTERNAL_COURSE_KEY_PATTERN}))' +COURSE_ID_PATTERN = rf'(?P<course_run_id>({INTERNAL_COURSE_KEY_PATTERN}|{EXTERNAL_COURSE_KEY_PATTERN}))' ACCEPTED_CATEGORY_TYPES = ['html', 'video'] CATEGORY_TYPE_MAP = { diff --git a/learning_assistant/platform_imports.py b/learning_assistant/platform_imports.py index 1d18664..8c2411c 100644 --- a/learning_assistant/platform_imports.py +++ b/learning_assistant/platform_imports.py @@ -51,14 +51,26 @@ def get_cache_course_run_data(course_run_id, fields): """ Return course run related data given a course run id. - This function makes use of the discovery course run cache, which is necessary because - only the discovery service stores the relation between courseruns and courses. + This function makes use of the course run cache in the LMS, which caches data from the discovery service. This is + necessary because only the discovery service stores the relation between courseruns and courses. """ # pylint: disable=import-error, import-outside-toplevel from openedx.core.djangoapps.catalog.utils import get_course_run_data return get_course_run_data(course_run_id, fields) +def get_cache_course_data(course_id, fields): + """ + Return course related data given a course id. + + This function makes use of the course cache in the LMS, which caches data from the discovery service. This is + necessary because only the discovery service stores course skills data. 
+ """ + # pylint: disable=import-error, import-outside-toplevel + from openedx.core.djangoapps.catalog.utils import get_course_data + return get_course_data(course_id, fields) + + def get_user_role(user, course_key): """ Return the role of the user on the edX platform. diff --git a/learning_assistant/utils.py b/learning_assistant/utils.py index 518fb01..7bd3fb8 100644 --- a/learning_assistant/utils.py +++ b/learning_assistant/utils.py @@ -10,8 +10,6 @@ from requests.exceptions import ConnectTimeout from rest_framework import status as http_status -from learning_assistant.platform_imports import get_cache_course_run_data - log = logging.getLogger(__name__) @@ -60,16 +58,7 @@ def get_reduced_message_list(prompt_template, message_list): return new_message_list -def get_course_id(course_run_id): - """ - Given a course run id (str), return the associated course key. - """ - course_data = get_cache_course_run_data(course_run_id, ['course']) - course_key = course_data['course'] - return course_key - - -def create_request_body(prompt_template, message_list, courserun_id): +def create_request_body(prompt_template, message_list, course_id): """ Form request body to be passed to the chat endpoint. """ @@ -77,7 +66,7 @@ def create_request_body(prompt_template, message_list, courserun_id): 'context': { 'content': prompt_template, 'render': { - 'doc_id': get_course_id(courserun_id), + 'doc_id': course_id, 'fields': ['skillNames', 'title'] } }, @@ -87,7 +76,7 @@ def create_request_body(prompt_template, message_list, courserun_id): return response_body -def get_chat_response(prompt_template, message_list, courserun_id): +def get_chat_response(prompt_template, message_list, course_id): """ Pass message list to chat endpoint, as defined by the CHAT_COMPLETION_API setting. 
""" @@ -98,7 +87,7 @@ def get_chat_response(prompt_template, message_list, courserun_id): connect_timeout = getattr(settings, 'CHAT_COMPLETION_API_CONNECT_TIMEOUT', 1) read_timeout = getattr(settings, 'CHAT_COMPLETION_API_READ_TIMEOUT', 15) - body = create_request_body(prompt_template, message_list, courserun_id) + body = create_request_body(prompt_template, message_list, course_id) try: response = requests.post( diff --git a/learning_assistant/views.py b/learning_assistant/views.py index 5db882a..ed76ce6 100644 --- a/learning_assistant/views.py +++ b/learning_assistant/views.py @@ -19,7 +19,7 @@ except ImportError: pass -from learning_assistant.api import learning_assistant_enabled, render_prompt_template +from learning_assistant.api import get_course_id, learning_assistant_enabled, render_prompt_template from learning_assistant.serializers import MessageSerializer from learning_assistant.utils import get_chat_response, user_role_is_staff @@ -34,9 +34,9 @@ class CourseChatView(APIView): authentication_classes = (SessionAuthentication, JwtAuthentication,) permission_classes = (IsAuthenticated,) - def post(self, request, course_id): + def post(self, request, course_run_id): """ - Given a course ID, retrieve a chat response for that course. + Given a course run ID, retrieve a chat response for that course. 
Expected POST data: { [ @@ -46,7 +46,7 @@ def post(self, request, course_id): } """ try: - courserun_key = CourseKey.from_string(course_id) + courserun_key = CourseKey.from_string(course_run_id) except InvalidKeyError: return Response( status=http_status.HTTP_400_BAD_REQUEST, @@ -89,11 +89,13 @@ def post(self, request, course_id): 'Attempting to retrieve chat response for user_id=%(user_id)s in course_id=%(course_id)s', { 'user_id': request.user.id, - 'course_id': course_id + 'course_id': course_run_id } ) - prompt_template = render_prompt_template(request, request.user.id, course_id, unit_id) + course_id = get_course_id(course_run_id) + + prompt_template = render_prompt_template(request, request.user.id, course_run_id, unit_id, course_id) status_code, message = get_chat_response(prompt_template, message_list, course_id) @@ -122,16 +124,16 @@ class LearningAssistantEnabledView(APIView): authentication_classes = (SessionAuthentication, JwtAuthentication,) permission_classes = (IsAuthenticated,) - def get(self, request, course_id): + def get(self, request, course_run_id): """ - Given a course ID, retrieve whether the Learning Assistant is enabled for the corresponding course. + Given a course run ID, retrieve whether the Learning Assistant is enabled for the corresponding course. The response will be in the following format. 
{'enabled': <bool>} """ try: - courserun_key = CourseKey.from_string(course_id) + courserun_key = CourseKey.from_string(course_run_id) except InvalidKeyError: return Response( status=http_status.HTTP_400_BAD_REQUEST, diff --git a/test_settings.py b/test_settings.py index 0302e54..78d99fc 100644 --- a/test_settings.py +++ b/test_settings.py @@ -83,6 +83,8 @@ def root(*args): "{{ unit_content }}" "\"" "{% endif %}" + "{{ skill_names }}" + "{{ title }}" ) LEARNING_ASSISTANT_AVAILABLE = True diff --git a/tests/test_api.py b/tests/test_api.py index 7042b0a..1f18d5f 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -82,7 +82,7 @@ def setUp(self): ] self.block = FakeBlock(self.children) - self.course_id = 'course-v1:edx+test+23' + self.course_run_id = 'course-v1:edx+test+23' @ddt.data( ('video', True), @@ -156,7 +156,7 @@ def test_get_block_content(self, mock_get_children_contents, mock_get_single_blo # args does not matter for this test right now, as the `get_single_block` function is entirely mocked. 
request = MagicMock() user_id = 1 - course_id = self.course_id + course_id = self.course_run_id unit_usage_key = 'block-v1:edX+A+B+type@vertical+block@verticalD' length, items = get_block_content(request, user_id, course_id, unit_usage_key) @@ -183,25 +183,34 @@ def test_get_block_content(self, mock_get_children_contents, mock_get_single_blo ('', False), ) @ddt.unpack + @patch('learning_assistant.api.get_cache_course_data') @patch('learning_assistant.toggles._is_learning_assistant_waffle_flag_enabled') @patch('learning_assistant.api.get_block_content') - def test_render_prompt_template(self, unit_content, flag_enabled, mock_get_content, mock_is_flag_enabled): + def test_render_prompt_template( + self, unit_content, flag_enabled, mock_get_content, mock_is_flag_enabled, mock_cache + ): mock_get_content.return_value = (len(unit_content), unit_content) mock_is_flag_enabled.return_value = flag_enabled + skills_content = ['skills'] + title = 'title' + mock_cache.return_value = {'skill_names': skills_content, 'title': title} # mock arguments that are passed through to `get_block_content` function. the value of these # args does not matter for this test right now, as the `get_block_content` function is entirely mocked. 
request = MagicMock() user_id = 1 - course_id = self.course_id + course_run_id = self.course_run_id unit_usage_key = 'block-v1:edX+A+B+type@vertical+block@verticalD' + course_id = 'edx+test' - prompt_text = render_prompt_template(request, user_id, course_id, unit_usage_key) + prompt_text = render_prompt_template(request, user_id, course_run_id, unit_usage_key, course_id) if unit_content and flag_enabled: self.assertIn(unit_content, prompt_text) else: self.assertNotIn('The following text is useful.', prompt_text) + self.assertIn(str(skills_content), prompt_text) + self.assertIn(title, prompt_text) @ddt.ddt diff --git a/tests/test_utils.py b/tests/test_utils.py index a410c1f..359f7ff 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -24,14 +24,7 @@ def setUp(self): self.prompt_template = 'This is a prompt.' self.message_list = [{'role': 'assistant', 'content': 'Hello'}, {'role': 'user', 'content': 'Goodbye'}] - self.course_id = 'course-v1:edx+test+23' - self.course_key = 'edx-test' - - self.patcher = patch( - 'learning_assistant.utils.get_cache_course_run_data', - return_value={'course': self.course_key} - ) - self.patcher.start() + self.course_id = 'edx+test' def get_response(self): return get_chat_response(self.prompt_template, self.message_list, self.course_id) @@ -99,7 +92,7 @@ def test_post_request_structure(self, mock_requests): 'context': { 'content': self.prompt_template, 'render': { - 'doc_id': self.course_key, + 'doc_id': self.course_id, 'fields': ['skillNames', 'title'] } }, diff --git a/tests/test_views.py b/tests/test_views.py index 24bf4b7..fbbe8b5 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -82,17 +82,23 @@ def setUp(self): super().setUp() self.course_id = 'course-v1:edx+test+23' + self.patcher = patch( + 'learning_assistant.api.get_cache_course_run_data', + return_value={'course': 'edx+test'} + ) + self.patcher.start() + @patch('learning_assistant.views.learning_assistant_enabled') def test_invalid_course_id(self, 
mock_learning_assistant_enabled): mock_learning_assistant_enabled.return_value = True - response = self.client.get(reverse('enabled', kwargs={'course_id': self.course_id+'+invalid'})) + response = self.client.get(reverse('enabled', kwargs={'course_run_id': self.course_id+'+invalid'})) self.assertEqual(response.status_code, 400) @patch('learning_assistant.views.learning_assistant_enabled') def test_course_waffle_inactive(self, mock_waffle): mock_waffle.return_value = False - response = self.client.post(reverse('chat', kwargs={'course_id': self.course_id})) + response = self.client.post(reverse('chat', kwargs={'course_run_id': self.course_id})) self.assertEqual(response.status_code, 403) @patch('learning_assistant.views.learning_assistant_enabled') @@ -105,7 +111,7 @@ def test_user_no_enrollment_not_staff(self, mock_mode, mock_enrollment, mock_rol mock_mode.VERIFIED_MODES = ['verified'] mock_enrollment.return_value = None - response = self.client.post(reverse('chat', kwargs={'course_id': self.course_id})) + response = self.client.post(reverse('chat', kwargs={'course_run_id': self.course_id})) self.assertEqual(response.status_code, 403) @patch('learning_assistant.views.learning_assistant_enabled') @@ -118,7 +124,7 @@ def test_user_audit_enrollment_not_staff(self, mock_mode, mock_enrollment, mock_ mock_mode.VERIFIED_MODES = ['verified'] mock_enrollment.return_value = MagicMock(mode='audit') - response = self.client.post(reverse('chat', kwargs={'course_id': self.course_id})) + response = self.client.post(reverse('chat', kwargs={'course_run_id': self.course_id})) self.assertEqual(response.status_code, 403) @patch('learning_assistant.views.render_prompt_template') @@ -137,7 +143,7 @@ def test_invalid_messages(self, mock_role, mock_waffle, mock_render): ] response = self.client.post( - reverse('chat', kwargs={'course_id': self.course_id})+f'?unit_id={test_unit_id}', + reverse('chat', kwargs={'course_run_id': self.course_id})+f'?unit_id={test_unit_id}', 
data=json.dumps(test_data), content_type='application/json' ) @@ -164,7 +170,7 @@ def test_chat_response(self, mock_mode, mock_enrollment, mock_role, mock_waffle, ] response = self.client.post( - reverse('chat', kwargs={'course_id': self.course_id})+f'?unit_id={test_unit_id}', + reverse('chat', kwargs={'course_run_id': self.course_id})+f'?unit_id={test_unit_id}', data=json.dumps(test_data), content_type='application/json' ) @@ -196,7 +202,7 @@ def setUp(self): @patch('learning_assistant.views.learning_assistant_enabled') def test_learning_assistant_enabled(self, mock_value, expected_value, mock_learning_assistant_enabled): mock_learning_assistant_enabled.return_value = mock_value - response = self.client.get(reverse('enabled', kwargs={'course_id': self.course_id})) + response = self.client.get(reverse('enabled', kwargs={'course_run_id': self.course_id})) self.assertEqual(response.status_code, 200) self.assertEqual(response.data, {'enabled': expected_value}) @@ -204,6 +210,6 @@ def test_learning_assistant_enabled(self, mock_value, expected_value, mock_learn @patch('learning_assistant.views.learning_assistant_enabled') def test_invalid_course_id(self, mock_learning_assistant_enabled): mock_learning_assistant_enabled.return_value = True - response = self.client.get(reverse('enabled', kwargs={'course_id': self.course_id+'+invalid'})) + response = self.client.get(reverse('enabled', kwargs={'course_run_id': self.course_id+'+invalid'})) self.assertEqual(response.status_code, 400)