diff --git a/lms/services/lti_role_service.py b/lms/services/lti_role_service.py index f6008bfba9..c656a6b846 100644 --- a/lms/services/lti_role_service.py +++ b/lms/services/lti_role_service.py @@ -11,15 +11,17 @@ class LTIRoleService: def __init__(self, db_session: Session): self._db = db_session - def get_roles(self, role_description: str) -> list[LTIRole]: + def get_roles(self, role_description: str | list[str]) -> list[LTIRole]: """ Get a list of role objects for the provided strings. - :param role_description: A comma delimited set of role strings + :param role_description: A comma delimited set of role strings for LTI1.1 or a list of string for LTI1.3 """ - role_strings = [role.strip() for role in role_description.split(",")] + if isinstance(role_description, str): + role_strings = [role.strip() for role in role_description.split(",")] + else: + role_strings = role_description - # Pylint is confused about the `in_` for some reason roles = self._db.query(LTIRole).filter(LTIRole.value.in_(role_strings)).all() for role in roles: diff --git a/tests/unit/lms/services/lti_role_service_test.py b/tests/unit/lms/services/lti_role_service_test.py index af4e3eddc2..63a10c9213 100644 --- a/tests/unit/lms/services/lti_role_service_test.py +++ b/tests/unit/lms/services/lti_role_service_test.py @@ -10,7 +10,8 @@ class TestLTIRoleService: - def test_get_roles(self, svc, existing_roles): + @pytest.mark.parametrize("roles_as_string", [True, False]) + def test_get_roles(self, svc, existing_roles, roles_as_string): existing_role_strings = [role.value for role in existing_roles] new_roles = [ "http://purl.imsglobal.org/vocab/lis/v2/system/person#SysSupport", @@ -22,7 +23,9 @@ def test_get_roles(self, svc, existing_roles): role_descriptions.append(existing_roles[0].value) role_descriptions.extend(new_roles) - roles = svc.get_roles(", ".join(role_descriptions)) + if roles_as_string: + role_descriptions = ", ".join(role_descriptions) + roles = svc.get_roles(role_descriptions) expected_new_roles = [ Any.instance_of(LTIRole).with_attrs({"value": value}) for value in new_roles