From 3a12949d3e6881d073460f1c6b444451b7e78a8e Mon Sep 17 00:00:00 2001
From: Guillaume Andreu Sabater <guillaume.andreu.sabater@gmail.com>
Date: Tue, 23 Nov 2021 14:33:25 +0100
Subject: [PATCH] Added Django 4.0 compatibility (#317)

* added django 4.0 compat
* dropped django 3.0
---
 .github/workflows/python-package.yml |  2 +-
 djangosaml2/signals.py               |  6 ++----
 djangosaml2/tests/__init__.py        | 25 +++++++++++++++++--------
 djangosaml2/utils.py                 |  7 +++++--
 djangosaml2/views.py                 | 10 ++++------
 setup.py                             |  2 +-
 tox.ini                              |  4 ++--
 7 files changed, 32 insertions(+), 24 deletions(-)

diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index bb607137..8c8c7ab2 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -16,7 +16,7 @@ jobs:
     strategy:
       matrix:
         python-version: [3.8, 3.9]
-        django-version: ["2.2", "3.0", "3.1", "3.2"]
+        django-version: ["2.2", "3.1", "3.2", "4.0"]
 
     steps:
     - uses: actions/checkout@v2
diff --git a/djangosaml2/signals.py b/djangosaml2/signals.py
index 1c006626..ac5d30a3 100644
--- a/djangosaml2/signals.py
+++ b/djangosaml2/signals.py
@@ -14,7 +14,5 @@
 
 import django.dispatch
 
-pre_user_save = django.dispatch.Signal(
-    providing_args=['attributes', 'user_modified'])
-post_authenticated = django.dispatch.Signal(
-    providing_args=['session_info', 'request'])
+pre_user_save = django.dispatch.Signal()
+post_authenticated = django.dispatch.Signal()
diff --git a/djangosaml2/tests/__init__.py b/djangosaml2/tests/__init__.py
index 1a7c3693..867c05c7 100644
--- a/djangosaml2/tests/__init__.py
+++ b/djangosaml2/tests/__init__.py
@@ -21,6 +21,7 @@
 from unittest import mock
 from urllib.parse import parse_qs, urlparse
 
+from django import http
 from django.conf import settings
 from django.contrib.auth import SESSION_KEY, get_user_model
 from django.contrib.auth.models import AnonymousUser
@@ -28,7 +29,6 @@
 from django.test import Client, TestCase, override_settings
 from django.test.client import RequestFactory
 from django.urls import reverse, reverse_lazy
-from django.utils.encoding import force_text
 from djangosaml2 import views
 from djangosaml2.cache import OutstandingQueriesCache
 from djangosaml2.conf import get_config
@@ -57,6 +57,15 @@ def dummy_loader(request):
     return 'dummy_loader'
 
 
+def dummy_get_response(request: http.HttpRequest):
+    """
+    Return a basic HttpResponse.
+
+    Function needed to instantiate SamlSessionMiddleware.
+    """
+    return http.HttpResponse("Session test")
+
+
 non_callable = 'just a string'
 
 
@@ -433,7 +442,7 @@ def test_assertion_consumer_service(self):
 
         # as the RelayState is empty we have redirect to ACS_DEFAULT_REDIRECT_URL
         self.assertRedirects(response, '/dashboard/')
-        self.assertEqual(force_text(new_user.id), client.session[SESSION_KEY])
+        self.assertEqual(str(new_user.id), client.session[SESSION_KEY])
 
     @override_settings(ACS_DEFAULT_REDIRECT_URL='testprofiles:dashboard')
     def test_assertion_consumer_service_default_relay_state(self):
@@ -458,7 +467,7 @@ def test_assertion_consumer_service_default_relay_state(self):
 
         # The RelayState is missing, redirect to ACS_DEFAULT_REDIRECT_URL
         self.assertRedirects(response, '/dashboard/')
-        self.assertEqual(force_text(new_user.id), self.client.session[SESSION_KEY])
+        self.assertEqual(str(new_user.id), self.client.session[SESSION_KEY])
 
     def test_assertion_consumer_service_already_logged_in_allowed(self):
         self.client.force_login(User.objects.create(
@@ -566,7 +575,7 @@ def test_echo_view_no_saml_session(self):
         request.COOKIES = self.client.cookies
         request.user = User.objects.last()
 
-        middleware = SamlSessionMiddleware()
+        middleware = SamlSessionMiddleware(dummy_get_response)
         middleware.process_request(request)
 
         response = EchoAttributesView.as_view()(request)
@@ -585,7 +594,7 @@ def test_echo_view_success(self):
         request = RequestFactory().get('/')
         request.user = User.objects.last()
 
-        middleware = SamlSessionMiddleware()
+        middleware = SamlSessionMiddleware(dummy_get_response)
         middleware.process_request(request)
 
         saml_session_name = getattr(
@@ -808,7 +817,7 @@ def test_custom_conf_loader_from_view(self):
         config_loader_path = 'djangosaml2.tests.test_config_loader_with_real_conf'
         request = RequestFactory().get('/login/')
         request.user = AnonymousUser()
-        middleware = SamlSessionMiddleware()
+        middleware = SamlSessionMiddleware(dummy_get_response)
         middleware.process_request(request)
 
         saml_session_name = getattr(
@@ -855,7 +864,7 @@ def test_middleware_cookie_expireatbrowserclose(self):
             request = RequestFactory().get('/login/')
             request.user = AnonymousUser()
             request.session = session
-            middleware = SamlSessionMiddleware()
+            middleware = SamlSessionMiddleware(dummy_get_response)
             middleware.process_request(request)
 
             saml_session_name = getattr(
@@ -882,7 +891,7 @@ def test_middleware_cookie_with_expiry(self):
             request = RequestFactory().get('/login/')
             request.user = AnonymousUser()
             request.session = session
-            middleware = SamlSessionMiddleware()
+            middleware = SamlSessionMiddleware(dummy_get_response)
             middleware.process_request(request)
 
             saml_session_name = getattr(
diff --git a/djangosaml2/utils.py b/djangosaml2/utils.py
index 5f011d2c..ae392c0a 100644
--- a/djangosaml2/utils.py
+++ b/djangosaml2/utils.py
@@ -22,7 +22,10 @@
 from django.core.exceptions import ImproperlyConfigured
 from django.http import HttpResponse, HttpResponseRedirect
 from django.shortcuts import resolve_url
-from django.utils.http import is_safe_url
+try:
+    from django.utils.http import url_has_allowed_host_and_scheme
+except ImportError:  # django 2.2
+    from django.utils.http import is_safe_url as url_has_allowed_host_and_scheme
 from saml2.config import SPConfig
 from saml2.s_utils import UnknownSystemEntity
 
@@ -96,7 +99,7 @@ def validate_referral_url(request, url):
     saml_allowed_hosts = set(
         getattr(settings, 'SAML_ALLOWED_HOSTS', [request.get_host()]))
 
-    if not is_safe_url(url=url, allowed_hosts=saml_allowed_hosts):
+    if not url_has_allowed_host_and_scheme(url=url, allowed_hosts=saml_allowed_hosts):
         return get_fallback_login_redirect_url()
     return url
 
diff --git a/djangosaml2/views.py b/djangosaml2/views.py
index ce232280..0e382c94 100644
--- a/djangosaml2/views.py
+++ b/djangosaml2/views.py
@@ -16,6 +16,7 @@
 import base64
 import logging
 import saml2
+from urllib.parse import quote
 
 from django.conf import settings
 from django.contrib import auth
@@ -28,7 +29,6 @@
 from django.template import TemplateDoesNotExist
 from django.urls import reverse
 from django.utils.decorators import method_decorator
-from django.utils.http import urlquote
 from django.views.decorators.csrf import csrf_exempt
 from django.views.generic import View
 from django.utils.module_loading import import_string
@@ -208,13 +208,11 @@ def get(self, request, *args, **kwargs):
                 logger.debug(("A discovery process is needed trough a"
                               "Discovery Service: {}").format(discovery_service))
                 login_url = request.build_absolute_uri(reverse('saml2_login'))
-                login_url = '{0}?next={1}'.format(login_url,
-                                                  urlquote(next_path, safe=''))
+                login_url = '{0}?next={1}'.format(login_url, quote(next_path, safe=''))
                 ds_url = '{0}?entityID={1}&return={2}&returnIDParam=idp'
                 ds_url = ds_url.format(discovery_service,
-                                       urlquote(
-                                           getattr(conf, 'entityid'), safe=''),
-                                       urlquote(login_url, safe=''))
+                                       quote(getattr(conf, 'entityid'), safe=''),
+                                       quote(login_url, safe=''))
                 return HttpResponseRedirect(ds_url)
 
             elif len(configured_idps) > 1:
diff --git a/setup.py b/setup.py
index 6d9cdec0..da106760 100644
--- a/setup.py
+++ b/setup.py
@@ -33,9 +33,9 @@ def read(*rnames):
         "Environment :: Web Environment",
         "Framework :: Django",
         "Framework :: Django :: 2.2",
-        "Framework :: Django :: 3.0",
         "Framework :: Django :: 3.1",
         "Framework :: Django :: 3.2",
+        "Framework :: Django :: 4.0",
         "Intended Audience :: Developers",
         "License :: OSI Approved :: Apache Software License",
         "Operating System :: OS Independent",
diff --git a/tox.ini b/tox.ini
index 948c9817..11b964a2 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,6 +1,6 @@
 [tox]
 envlist =
-    py{3.6,3.7,3.8,3.9}-django{2.2,3.0,3.1,3.2}
+    py{3.6,3.7,3.8,3.9}-django{2.2,3.1,3.2,4.0}
 
 [testenv]
 commands =
@@ -8,9 +8,9 @@ commands =
 
 deps =
     django2.2: django~=2.2
-    django3.0: django~=3.0
     django3.1: django~=3.1
     django3.2: django~=3.2
+    django4.0: django==4.0rc1
     djangomaster: https://github.com/django/django/archive/master.tar.gz
     .[test]