From a3ab134620eabc8a602c8e3bf4d00dbb1d02da40 Mon Sep 17 00:00:00 2001 From: Mathias Ertl Date: Sun, 20 Oct 2024 10:04:45 +0200 Subject: [PATCH] apply manual fixes --- ca/django_ca/tests/acme/views/base.py | 3 +- ca/django_ca/tests/acme/views/conftest.py | 22 +- .../tests/acme/views/test_authorization.py | 5 +- .../tests/acme/views/test_challenge.py | 5 +- .../tests/acme/views/test_new_account.py | 9 +- .../tests/acme/views/test_new_order.py | 8 +- ca/django_ca/tests/acme/views/test_order.py | 5 +- .../tests/acme/views/test_order_finalize.py | 9 +- .../tests/acme/views/test_revocation.py | 11 +- .../tests/acme/views/test_update_account.py | 5 +- .../tests/acme/views/test_view_cert.py | 7 +- ca/django_ca/tests/admin/base.py | 8 +- ca/django_ca/tests/admin/conftest.py | 6 +- ca/django_ca/tests/admin/test_actions.py | 38 +- ca/django_ca/tests/admin/test_add_cert.py | 508 ++++++++---------- ca/django_ca/tests/admin/test_admin_ca.py | 14 +- ca/django_ca/tests/admin/test_extra_views.py | 43 +- ca/django_ca/tests/base/assertions.py | 37 +- ca/django_ca/tests/base/conftest_helpers.py | 20 +- ca/django_ca/tests/base/fixtures.py | 80 ++- ca/django_ca/tests/base/mixins.py | 108 +--- ca/django_ca/tests/commands/test_dump_ca.py | 2 +- ca/django_ca/tests/commands/test_dump_crl.py | 2 +- ca/django_ca/tests/commands/test_init_ca.py | 4 +- ca/django_ca/tests/commands/test_list_cas.py | 126 ++--- .../tests/commands/test_list_certs.py | 4 +- ca/django_ca/tests/commands/test_notify.py | 28 +- .../commands/test_regenerate_ocsp_keys.py | 14 +- .../tests/commands/test_resign_cert.py | 282 +++++----- .../tests/commands/test_revoke_cert.py | 36 +- ca/django_ca/tests/conftest.py | 3 +- .../tests/extensions/test_admin_html.py | 2 +- .../extensions/test_unknown_extension.py | 10 +- ca/django_ca/tests/extensions/test_utils.py | 4 +- .../tests/key_backends/hsm/test_backend.py | 4 +- .../tests/key_backends/hsm/test_models.py | 6 +- .../tests/key_backends/hsm/test_session.py | 2 +- .../tests/key_backends/test_storages.py | 2 +- ca/django_ca/tests/models/test_certificate.py | 4 +- .../models/test_certificate_authority.py | 2 +- ca/django_ca/tests/pydantic/base.py | 2 +- .../tests/pydantic/test_extensions.py | 83 +-- .../tests/pydantic/test_general_name.py | 8 +- ca/django_ca/tests/pydantic/test_name.py | 8 +- .../tests/pydantic/test_type_aliases.py | 14 +- .../tests/pydantic/test_validators.py | 32 +- ca/django_ca/tests/test_acme.py | 43 +- ca/django_ca/tests/test_base.py | 22 +- ca/django_ca/tests/test_checks.py | 12 +- ca/django_ca/tests/test_fields.py | 36 +- ca/django_ca/tests/test_management_actions.py | 257 ++++----- ca/django_ca/tests/test_migration_helpers.py | 12 +- ca/django_ca/tests/test_models.py | 247 ++++----- ca/django_ca/tests/test_querysets.py | 35 +- ca/django_ca/tests/test_settings.py | 24 +- ca/django_ca/tests/test_sphinx_extensions.py | 4 +- ca/django_ca/tests/test_tasks.py | 200 ++++--- ca/django_ca/tests/test_typehints.py | 2 +- ca/django_ca/tests/test_utils.py | 316 ++++++----- ca/django_ca/tests/test_views_ocsp.py | 216 ++++---- .../tests/utils/test_get_crl_cache_key.py | 2 +- ca/django_ca/tests/utils/test_othername.py | 9 +- .../tests/utils/test_parse_general_name.py | 4 +- .../tests/utils/test_parse_name_rfc4514.py | 6 +- .../tests/utils/test_parse_name_x509.py | 12 +- ca/django_ca/tests/utils/test_split_str.py | 13 +- .../tests/utils/test_validate_hostname.py | 2 +- .../test_certificate_revocation_list_view.py | 2 +- ca/django_ca/views.py | 2 - 69 files changed, 1425 insertions(+), 1688 deletions(-) diff --git a/ca/django_ca/tests/acme/views/base.py b/ca/django_ca/tests/acme/views/base.py index bdeee2fe1..0574c602e 100644 --- a/ca/django_ca/tests/acme/views/base.py +++ b/ca/django_ca/tests/acme/views/base.py @@ -15,7 +15,6 @@ import abc import typing -from collections.abc import Iterator from http import HTTPStatus from typing import Optional, Union from unittest import mock @@ -255,7 +254,7 @@ class AcmeWithAccountViewTestCaseMixin( """Mixin that also adds accounts to the database.""" @pytest.fixture - def main_account(self, account: AcmeAccount) -> Iterator[AcmeAccount]: + def main_account(self, account: AcmeAccount) -> AcmeAccount: """Return the main account to be used for this test case. This is overwritten by the revocation test case. diff --git a/ca/django_ca/tests/acme/views/conftest.py b/ca/django_ca/tests/acme/views/conftest.py index 6c2bb947b..967fbaa47 100644 --- a/ca/django_ca/tests/acme/views/conftest.py +++ b/ca/django_ca/tests/acme/views/conftest.py @@ -15,8 +15,6 @@ # pylint: disable=redefined-outer-name -from collections.abc import Iterator - from django.test import Client import pytest @@ -36,32 +34,32 @@ @pytest.fixture -def account_slug() -> Iterator[str]: +def account_slug() -> str: """Fixture for an account slug.""" return acme_slug() @pytest.fixture -def order_slug() -> Iterator[str]: +def order_slug() -> str: """Fixture for an order slug.""" return acme_slug() @pytest.fixture -def acme_cert_slug() -> Iterator[str]: +def acme_cert_slug() -> str: """Fixture for an ACME certificate slug.""" return acme_slug() @pytest.fixture -def client(client: Client) -> Iterator[Client]: +def client(client: Client) -> Client: """Override client fixture to set the default server name.""" client.defaults["SERVER_NAME"] = SERVER_NAME return client @pytest.fixture -def account(root: CertificateAuthority, account_slug: str, kid: str) -> Iterator[AcmeAccount]: +def account(root: CertificateAuthority, account_slug: str, kid: str) -> AcmeAccount: """Fixture for an account.""" return AcmeAccount.objects.create( ca=root, @@ -75,25 +73,25 @@ def account(root: CertificateAuthority, account_slug: str, kid: str) -> Iterator @pytest.fixture -def kid(root: CertificateAuthority, account_slug: str) -> Iterator[str]: +def kid(root: CertificateAuthority, account_slug: str) -> str: """Fixture for a full KID.""" return absolute_acme_uri(":acme-account", serial=root.serial, slug=account_slug) @pytest.fixture -def order(account: AcmeAccount, order_slug: str) -> Iterator[AcmeOrder]: +def order(account: AcmeAccount, order_slug: str) -> AcmeOrder: """Fixture for an order.""" return AcmeOrder.objects.create(account=account, slug=order_slug) @pytest.fixture -def authz(order: AcmeOrder) -> Iterator[AcmeAuthorization]: +def authz(order: AcmeOrder) -> AcmeAuthorization: """Fixture for an authorization.""" return AcmeAuthorization.objects.create(order=order, value=HOST_NAME) @pytest.fixture -def challenge(authz: AcmeAuthorization) -> Iterator[AcmeChallenge]: +def challenge(authz: AcmeAuthorization) -> AcmeChallenge: """Fixture for a challenge.""" challenge = authz.get_challenges()[0] challenge.token = "foobar" @@ -102,6 +100,6 @@ def challenge(authz: AcmeAuthorization) -> Iterator[AcmeChallenge]: @pytest.fixture -def acme_cert(root_cert: Certificate, order: AcmeOrder, acme_cert_slug: str) -> Iterator[AcmeCertificate]: +def acme_cert(root_cert: Certificate, order: AcmeOrder, acme_cert_slug: str) -> AcmeCertificate: """Fixture for an ACME certificate.""" return AcmeCertificate.objects.create(order=order, cert=root_cert, slug=acme_cert_slug) diff --git a/ca/django_ca/tests/acme/views/test_authorization.py b/ca/django_ca/tests/acme/views/test_authorization.py index dc657e270..1e1912c51 100644 --- a/ca/django_ca/tests/acme/views/test_authorization.py +++ b/ca/django_ca/tests/acme/views/test_authorization.py @@ -15,7 +15,6 @@ # pylint: disable=redefined-outer-name # because of fixtures -from collections.abc import Iterator from http import HTTPStatus import josepy as jose @@ -41,13 +40,13 @@ @pytest.fixture -def url(authz: AcmeAuthorization) -> Iterator[str]: +def url(authz: AcmeAuthorization) -> str: """URL under test.""" return root_reverse("acme-authz", slug=authz.slug) @pytest.fixture -def message() -> Iterator[bytes]: +def message() -> bytes: """Yield an empty bytestring, since this is a POST-AS-GET request.""" return b"" diff --git a/ca/django_ca/tests/acme/views/test_challenge.py b/ca/django_ca/tests/acme/views/test_challenge.py index 504113207..283fd87c1 100644 --- a/ca/django_ca/tests/acme/views/test_challenge.py +++ b/ca/django_ca/tests/acme/views/test_challenge.py @@ -16,7 +16,6 @@ # pylint: disable=redefined-outer-name # because of fixtures import unittest -from collections.abc import Iterator from http import HTTPStatus from typing import Optional from unittest import mock @@ -42,13 +41,13 @@ @pytest.fixture -def url(challenge: AcmeChallenge) -> Iterator[str]: +def url(challenge: AcmeChallenge) -> str: """URL under test.""" return root_reverse("acme-challenge", slug=challenge.slug) @pytest.fixture -def message() -> Iterator[bytes]: +def message() -> bytes: """Yield an empty bytestring, since this is a POST-AS-GET request.""" return b"" diff --git a/ca/django_ca/tests/acme/views/test_new_account.py b/ca/django_ca/tests/acme/views/test_new_account.py index d29b114bc..c06f3420e 100644 --- a/ca/django_ca/tests/acme/views/test_new_account.py +++ b/ca/django_ca/tests/acme/views/test_new_account.py @@ -15,7 +15,6 @@ # pylint: disable=redefined-outer-name # because of fixtures -from collections.abc import Iterator from http import HTTPStatus from unittest import mock @@ -53,19 +52,19 @@ @pytest.fixture -def url() -> Iterator[str]: +def url() -> str: """URL under test.""" return root_reverse("acme-new-account") @pytest.fixture -def message() -> Iterator[Registration]: +def message() -> Registration: """Default message sent to the server.""" return Registration(contact=(CONTACT,), terms_of_service_agreed=True) @pytest.fixture -def kid() -> Iterator[None]: +def kid() -> None: """Request requires no kid, yield None.""" return @@ -228,7 +227,7 @@ def test_unsupported_contact(client: Client, url: str, root: CertificateAuthorit @pytest.mark.parametrize( - "value,expected", + ("value", "expected"), ( ('mailto:"with spaces"@example.com', "Quoted local part in email is not allowed."), ("mailto:user@example.com,user@example.net", "More than one addr-spec is not allowed."), diff --git a/ca/django_ca/tests/acme/views/test_new_order.py b/ca/django_ca/tests/acme/views/test_new_order.py index 984d30bcf..986ab4c49 100644 --- a/ca/django_ca/tests/acme/views/test_new_order.py +++ b/ca/django_ca/tests/acme/views/test_new_order.py @@ -15,7 +15,6 @@ # pylint: disable=redefined-outer-name # because of fixtures -from collections.abc import Iterator from datetime import timedelta, timezone as tz from http import HTTPStatus from typing import Any @@ -46,13 +45,13 @@ @pytest.fixture -def url() -> Iterator[str]: +def url() -> str: """URL under test.""" return root_reverse("acme-new-order") @pytest.fixture -def message() -> Iterator[NewOrder]: +def message() -> NewOrder: """Default message sent to the server.""" return NewOrder(identifiers=[{"type": "dns", "value": SERVER_NAME}]) @@ -187,7 +186,7 @@ def test_no_identifiers(client: Client, url: str, root: CertificateAuthority, ki @pytest.mark.usefixtures("account") @pytest.mark.parametrize( - "values,expected", + ("values", "expected"), ( ({"not_before": now - timedelta(days=1)}, "Certificate cannot be valid before now."), ({"not_after": now + timedelta(days=3650)}, "Certificate cannot be valid that long."), @@ -201,7 +200,6 @@ def test_invalid_not_before_after( client: Client, url: str, root: CertificateAuthority, kid: str, values: dict[str, Any], expected: str ) -> None: """Test invalid not_before/not_after dates.""" - print(values) message = NewOrder(identifiers=[{"type": "dns", "value": SERVER_NAME}], **values) resp = acme_request(client, url, root, message, kid=kid) assert_malformed(resp, root, expected) diff --git a/ca/django_ca/tests/acme/views/test_order.py b/ca/django_ca/tests/acme/views/test_order.py index c113c582b..3eafd57fd 100644 --- a/ca/django_ca/tests/acme/views/test_order.py +++ b/ca/django_ca/tests/acme/views/test_order.py @@ -15,7 +15,6 @@ # pylint: disable=redefined-outer-name # because of fixtures -from collections.abc import Iterator from http import HTTPStatus from typing import Optional from unittest import mock @@ -51,13 +50,13 @@ @pytest.fixture -def url(order: AcmeOrder) -> Iterator[str]: +def url(order: AcmeOrder) -> str: """URL under test.""" return root_reverse("acme-order", slug=order.slug) @pytest.fixture -def message() -> Iterator[bytes]: +def message() -> bytes: """Yield an empty bytestring, since this is a POST-AS-GET request.""" return b"" diff --git a/ca/django_ca/tests/acme/views/test_order_finalize.py b/ca/django_ca/tests/acme/views/test_order_finalize.py index 0fc30e9f0..78d929e00 100644 --- a/ca/django_ca/tests/acme/views/test_order_finalize.py +++ b/ca/django_ca/tests/acme/views/test_order_finalize.py @@ -15,7 +15,6 @@ # pylint: disable=redefined-outer-name -from collections.abc import Iterator from http import HTTPStatus from typing import Optional from unittest import mock @@ -64,7 +63,7 @@ @pytest.fixture -def order(order: AcmeOrder) -> Iterator[AcmeOrder]: +def order(order: AcmeOrder) -> AcmeOrder: """Override the module-level fixture to set the status to ready.""" order.status = AcmeOrder.STATUS_READY order.save() @@ -72,7 +71,7 @@ def order(order: AcmeOrder) -> Iterator[AcmeOrder]: @pytest.fixture -def authz(authz: AcmeAuthorization) -> Iterator[AcmeAuthorization]: +def authz(authz: AcmeAuthorization) -> AcmeAuthorization: """Override the module-level fixture to set the status to valid.""" authz.status = AcmeAuthorization.STATUS_VALID authz.save() @@ -80,13 +79,13 @@ def authz(authz: AcmeAuthorization) -> Iterator[AcmeAuthorization]: @pytest.fixture -def url(order: AcmeOrder) -> Iterator[str]: +def url(order: AcmeOrder) -> str: """URL under test.""" return root_reverse("acme-order-finalize", slug=order.slug) @pytest.fixture -def message() -> Iterator[CertificateRequest]: +def message() -> CertificateRequest: """Default message sent to the server.""" req = X509Req.from_cryptography(CSR) return CertificateRequest(csr=jose.util.ComparableX509(req)) diff --git a/ca/django_ca/tests/acme/views/test_revocation.py b/ca/django_ca/tests/acme/views/test_revocation.py index 2945db3a6..172757d6a 100644 --- a/ca/django_ca/tests/acme/views/test_revocation.py +++ b/ca/django_ca/tests/acme/views/test_revocation.py @@ -16,7 +16,6 @@ # pylint: disable=redefined-outer-name # because of fixtures import unittest -from collections.abc import Iterator from datetime import datetime from http import HTTPStatus from typing import Any, Optional, Union @@ -58,13 +57,13 @@ @pytest.fixture -def url() -> Iterator[str]: +def url() -> str: """URL under test.""" return root_reverse("acme-revoke") @pytest.fixture -def message() -> Iterator[Revocation]: +def message() -> Revocation: """Default message sent to the server.""" default_certificate = CERT_DATA["root-cert"]["pub"]["parsed"] return Revocation(certificate=jose.util.ComparableX509(X509.from_cryptography(default_certificate))) @@ -107,7 +106,7 @@ def acme( return acme_request(client, url, ca, message, kid=kid) @pytest.mark.parametrize( - "use_tz, timestamp", + ("use_tz", "timestamp"), ((True, TIMESTAMPS["everything_valid"]), (False, TIMESTAMPS["everything_valid_naive"])), ) def test_basic( @@ -218,7 +217,7 @@ class TestAcmeCertificateRevocationWithAuthorizationsView(TestAcmeCertificateRev CHILD_SLUG = acme_slug() @pytest.fixture - def child_kid_fixture(self, root: CertificateAuthority) -> Iterator[str]: + def child_kid_fixture(self, root: CertificateAuthority) -> str: """Fixture to set compute the child KID.""" return self.absolute_uri(":acme-account", serial=root.serial, slug=self.CHILD_SLUG) @@ -287,7 +286,7 @@ def test_wrong_url(self) -> None: # type: ignore[override] pass @pytest.fixture - def kid(self, child_kid_fixture: str) -> Iterator[Optional[str]]: + def kid(self, child_kid_fixture: str) -> Optional[str]: """Override kid to return the child kid.""" return child_kid_fixture diff --git a/ca/django_ca/tests/acme/views/test_update_account.py b/ca/django_ca/tests/acme/views/test_update_account.py index f7d649145..1a854299f 100644 --- a/ca/django_ca/tests/acme/views/test_update_account.py +++ b/ca/django_ca/tests/acme/views/test_update_account.py @@ -16,7 +16,6 @@ # pylint: disable=redefined-outer-name # because of fixtures import unittest -from collections.abc import Iterator from http import HTTPStatus from acme.messages import IDENTIFIER_FQDN, Identifier, Registration @@ -39,13 +38,13 @@ @pytest.fixture -def url(account_slug: str) -> Iterator[str]: +def url(account_slug: str) -> str: """URL under test.""" return root_reverse("acme-account", slug=account_slug) @pytest.fixture -def message() -> Iterator[Registration]: +def message() -> Registration: """Default message sent to the server.""" return Registration() diff --git a/ca/django_ca/tests/acme/views/test_view_cert.py b/ca/django_ca/tests/acme/views/test_view_cert.py index 6138418a5..70cea026b 100644 --- a/ca/django_ca/tests/acme/views/test_view_cert.py +++ b/ca/django_ca/tests/acme/views/test_view_cert.py @@ -15,7 +15,6 @@ # pylint: disable=redefined-outer-name # for to fixtures -from collections.abc import Iterator from http import HTTPStatus from typing import Optional @@ -37,7 +36,7 @@ @pytest.fixture -def order(order: AcmeOrder) -> Iterator[AcmeOrder]: +def order(order: AcmeOrder) -> AcmeOrder: """Override to set status to valid.""" order.status = AcmeOrder.STATUS_VALID order.save() @@ -45,13 +44,13 @@ def order(order: AcmeOrder) -> Iterator[AcmeOrder]: @pytest.fixture -def url(acme_cert_slug: str) -> Iterator[str]: +def url(acme_cert_slug: str) -> str: """URL under test.""" return root_reverse("acme-cert", slug=acme_cert_slug) @pytest.fixture -def message() -> Iterator[bytes]: +def message() -> bytes: """Yield an empty bytestring, since this is a POST-AS-GET request.""" return b"" diff --git a/ca/django_ca/tests/admin/base.py b/ca/django_ca/tests/admin/base.py index afe5b7b06..893157f7d 100644 --- a/ca/django_ca/tests/admin/base.py +++ b/ca/django_ca/tests/admin/base.py @@ -60,16 +60,16 @@ def setUp(self) -> None: def assertModified(self) -> None: # pylint: disable=invalid-name """Assert that the field was modified.""" - self.assertEqual(self.key_value_field.get_attribute("data-modified"), "true") + assert self.key_value_field.get_attribute("data-modified") == "true" def assertNotModified(self) -> None: # pylint: disable=invalid-name """Assert that the field was not modified.""" - self.assertNotEqual(self.key_value_field.get_attribute("data-modified"), "true") + assert self.key_value_field.get_attribute("data-modified") != "true" def assertChapterHasValue(self, chapter: WebElement, value: Any) -> None: # pylint: disable=invalid-name """Assert that the given chapter has the given value.""" loaded_value = json.loads(chapter.get_attribute("data-value")) # type: ignore[arg-type] - self.assertEqual(loaded_value, value) + assert loaded_value == value def initialize(self) -> None: """Load the page and find core elements. @@ -94,7 +94,7 @@ def displayed_value(self) -> list[dict[str, str]]: """Load the currently displayed value from the key/value list.""" selects = self.key_value_list.find_elements(By.CSS_SELECTOR, "select") inputs = self.key_value_list.find_elements(By.CSS_SELECTOR, "input") - self.assertEqual(len(selects), len(inputs)) + assert len(selects) == len(inputs) return [ { diff --git a/ca/django_ca/tests/admin/conftest.py b/ca/django_ca/tests/admin/conftest.py index d5ab1d0e9..8b91e1974 100644 --- a/ca/django_ca/tests/admin/conftest.py +++ b/ca/django_ca/tests/admin/conftest.py @@ -13,8 +13,6 @@ """Extra fixtures for tests for the admin interface.""" -from collections.abc import Iterator - from django.test import Client from django.urls import reverse @@ -25,13 +23,13 @@ @pytest.fixture(params=["name_to_rfc4514"]) -def extra_view_url(request: "SubRequest") -> Iterator[str]: +def extra_view_url(request: "SubRequest") -> str: """Parametrized fixture providing reversed extra view URLs.""" return reverse(f"admin:django_ca_certificate_{request.param}") @pytest.fixture -def staff_client(user: "User", user_client: Client) -> Iterator[Client]: +def staff_client(user: "User", user_client: Client) -> Client: """Client with a staff user with no extra permissions.""" user.is_staff = True user.save() diff --git a/ca/django_ca/tests/admin/test_actions.py b/ca/django_ca/tests/admin/test_actions.py index 00dcdec8a..c5582c50e 100644 --- a/ca/django_ca/tests/admin/test_actions.py +++ b/ca/django_ca/tests/admin/test_actions.py @@ -39,7 +39,7 @@ from django_ca.models import Certificate, X509CertMixin from django_ca.pydantic.general_name import GeneralNameModelList from django_ca.signals import post_issue_cert, post_revoke_cert, pre_revoke_cert, pre_sign_cert -from django_ca.tests.base.assertions import assert_revoked +from django_ca.tests.base.assertions import assert_extension_equal, assert_revoked from django_ca.tests.base.constants import TIMESTAMPS from django_ca.tests.base.mixins import AdminTestCaseMixin from django_ca.tests.base.mocks import mock_signal @@ -80,7 +80,7 @@ def test_user_is_staff_only(self) -> None: for obj in self.get_objects(): response = self.client.post(self.changelist_url, self.data) - self.assertEqual(response.status_code, HTTPStatus.FORBIDDEN) + assert response.status_code == HTTPStatus.FORBIDDEN self.assertFailedRequest(response, obj) def test_insufficient_permissions(self) -> None: @@ -117,7 +117,7 @@ def test_insufficient_permissions(self) -> None: for obj in self.get_objects(): response = self.client.post(self.changelist_url, self.data) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK self.assertFailedRequest(response, obj) def test_required_permissions(self) -> None: @@ -166,7 +166,7 @@ def assertForbidden( # pylint: disable=invalid-name self, response: "HttpResponse", obj: Optional[DjangoCAModelTypeVar] = None ) -> None: """Assert that the action returned HTTP 403 (Forbidden).""" - self.assertEqual(response.status_code, HTTPStatus.FORBIDDEN) + assert response.status_code == HTTPStatus.FORBIDDEN self.assertFailedRequest(response, obj=obj) @contextmanager @@ -206,7 +206,7 @@ def test_get(self) -> None: for obj in self.get_objects(): with self.assertNoSignals(): response = self.client.get(self.get_url(obj=obj)) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK def test_anonymous(self) -> None: """Test performing action as anonymous user.""" @@ -303,9 +303,9 @@ def assertFormValidationError( # pylint: disable=invalid-name ) -> None: """Assert that the form validation failed with the given errors.""" self.assertNotRevoked(cert) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK self.assertTemplateUsed("admin/django_ca/certificate/revoke_form.html") - self.assertEqual(response.context["form"].errors, errors) + assert response.context["form"].errors == errors def assertSuccessfulRequest( self, @@ -399,7 +399,7 @@ class ResignChangeActionTestCase(AdminChangeActionTestCaseMixin[Certificate], We def assertFailedRequest(self, response: "HttpResponse", obj: Optional[Certificate] = None) -> None: obj = obj or self.cert - self.assertEqual(self.model.objects.filter(cn=obj.cn).count(), 1) + assert self.model.objects.filter(cn=obj.cn).count() == 1 def assertSuccessfulRequest( self, @@ -410,13 +410,13 @@ def assertSuccessfulRequest( obj.refresh_from_db() resigned = Certificate.objects.filter(cn=obj.cn).exclude(pk=obj.pk).get() - self.assertFalse(resigned.revoked) - self.assertFalse(obj.revoked) - self.assertEqual(obj.cn, resigned.cn) - self.assertEqual(obj.csr, resigned.csr) - self.assertEqual(obj.profile, resigned.profile) - self.assertEqual(obj.cn, resigned.cn) - self.assertEqual(obj.algorithm, resigned.algorithm) + assert not resigned.revoked + assert not obj.revoked + assert obj.cn == resigned.cn + assert obj.csr == resigned.csr + assert obj.profile == resigned.profile + assert obj.cn == resigned.cn + assert obj.algorithm == resigned.algorithm for oid in [ ExtensionOID.EXTENDED_KEY_USAGE, @@ -424,11 +424,11 @@ def assertSuccessfulRequest( ExtensionOID.KEY_USAGE, ExtensionOID.SUBJECT_ALTERNATIVE_NAME, ]: - self.assertEqual(obj.extensions.get(oid), resigned.extensions.get(oid)) + assert_extension_equal(obj.extensions.get(oid), resigned.extensions.get(oid)) # Some properties are obviously *not* equal - self.assertNotEqual(obj.pub, resigned.pub) - self.assertNotEqual(obj.serial, resigned.serial) + assert obj.pub != resigned.pub + assert obj.serial != resigned.serial @property def data(self) -> dict[str, Any]: # type: ignore[override] @@ -488,7 +488,7 @@ def test_no_profile(self) -> None: form.submit().follow() resigned = Certificate.objects.filter(cn=self.cert.cn).exclude(pk=self.cert.pk).get() - self.assertEqual(resigned.profile, model_settings.CA_DEFAULT_PROFILE) + assert resigned.profile == model_settings.CA_DEFAULT_PROFILE @override_tmpcadir() def test_webtest_basic(self) -> None: diff --git a/ca/django_ca/tests/admin/test_add_cert.py b/ca/django_ca/tests/admin/test_add_cert.py index eb0d0d53e..3c1584469 100644 --- a/ca/django_ca/tests/admin/test_add_cert.py +++ b/ca/django_ca/tests/admin/test_add_cert.py @@ -55,6 +55,7 @@ from django_ca.tests.admin.base import AddCertificateSeleniumTestCase, CertificateModelAdminTestCaseMixin from django_ca.tests.base.assertions import ( assert_authority_key_identifier, + assert_count_equal, assert_create_cert_signals, assert_extensions, assert_post_issue_cert, @@ -174,16 +175,13 @@ def add_cert(self, cname: str, ca: CertificateAuthority, algorithm: str = "SHA-2 cert = Certificate.objects.get(cn=cname) assert_post_issue_cert(post, cert) - self.assertEqual( - cert.pub.loaded.subject, - x509.Name( - [ - x509.NameAttribute(oid=NameOID.COUNTRY_NAME, value="US"), - x509.NameAttribute(oid=NameOID.COMMON_NAME, value=cname), - ] - ), + assert cert.pub.loaded.subject == x509.Name( + [ + x509.NameAttribute(oid=NameOID.COUNTRY_NAME, value="US"), + x509.NameAttribute(oid=NameOID.COMMON_NAME, value=cname), + ] ) - self.assertIssuer(ca, cert) + assert cert.issuer == ca.subject assert_extensions( cert, [ @@ -203,24 +201,24 @@ def add_cert(self, cname: str, ca: CertificateAuthority, algorithm: str = "SHA-2 ), ], ) - self.assertEqual(cert.ca, ca) - self.assertEqual(cert.csr.pem, CSR) - self.assertEqual(cert.profile, "webserver") + assert cert.ca == ca + assert cert.csr.pem == CSR + assert cert.profile == "webserver" # Some extensions are NOT set - self.assertNotIn(ExtensionOID.ISSUER_ALTERNATIVE_NAME, cert.extensions) + assert ExtensionOID.ISSUER_ALTERNATIVE_NAME not in cert.extensions # Test that we can view the certificate response = self.client.get(cert.admin_change_url) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK def _test_get(self) -> "HttpResponse": """Do a basic get request (to test CSS etc).""" response = self.client.get(self.add_url) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK templates = [t.name for t in response.templates] - self.assertIn("admin/django_ca/certificate/change_form.html", templates) - self.assertIn("admin/change_form.html", templates) + assert "admin/django_ca/certificate/change_form.html" in templates + assert "admin/change_form.html" in templates assert_css(response, "django_ca/admin/css/base.css") assert_css(response, "django_ca/admin/css/certificateadmin.css") return response @@ -247,14 +245,14 @@ def test_default_ca_key_does_not_exist(self) -> None: """View add form when the ca key does not exist.""" storages["django-ca"].delete(self.ca.key_backend_options["path"]) response = self.client.get(self.add_url) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK form = response.context_data["adminform"].form # type: ignore[attr-defined] # false positive field = form.fields["ca"] bound_field = field.get_bound_field(form, "ca") - self.assertNotEqual(bound_field.initial, self.ca) - self.assertIsInstance(bound_field.initial, CertificateAuthority) + assert bound_field.initial != self.ca + assert isinstance(bound_field.initial, CertificateAuthority) @override_tmpcadir(CA_DEFAULT_CA=CERT_DATA["child"]["serial"]) def test_cas_expired(self) -> None: @@ -284,7 +282,7 @@ def test_get_profiles(self) -> None: field = form.fields["ocsp_no_check"] bound_field = field.get_bound_field(form, "ocsp_no_check") - self.assertEqual(bound_field.initial, ocsp_no_check(critical=True)) + assert bound_field.initial == ocsp_no_check(critical=True) @override_tmpcadir(CA_DEFAULT_SUBJECT=tuple()) def test_add(self) -> None: @@ -306,10 +304,9 @@ def test_empty_subject(self) -> None: cert: Certificate = Certificate.objects.get(cn="") assert_post_issue_cert(post, cert) - self.assertEqual(cert.subject, x509.Name([])) - self.assertEqual( - cert.extensions[ExtensionOID.SUBJECT_ALTERNATIVE_NAME], - subject_alternative_name(dns(self.hostname)), + assert cert.subject == x509.Name([]) + assert cert.extensions[ExtensionOID.SUBJECT_ALTERNATIVE_NAME] == subject_alternative_name( + dns(self.hostname) ) @override_tmpcadir(CA_DEFAULT_SUBJECT=tuple()) @@ -335,16 +332,13 @@ def test_subject_with_multiple_org_units(self) -> None: cert: Certificate = Certificate.objects.get(cn=self.hostname) assert_post_issue_cert(post, cert) - self.assertEqual( - cert.subject, - x509.Name( - [ - x509.NameAttribute(oid=NameOID.COUNTRY_NAME, value="US"), - x509.NameAttribute(oid=NameOID.ORGANIZATIONAL_UNIT_NAME, value="OU-1"), - x509.NameAttribute(oid=NameOID.ORGANIZATIONAL_UNIT_NAME, value="OU-2"), - x509.NameAttribute(oid=NameOID.COMMON_NAME, value=self.hostname), - ] - ), + assert cert.subject == x509.Name( + [ + x509.NameAttribute(oid=NameOID.COUNTRY_NAME, value="US"), + x509.NameAttribute(oid=NameOID.ORGANIZATIONAL_UNIT_NAME, value="OU-1"), + x509.NameAttribute(oid=NameOID.ORGANIZATIONAL_UNIT_NAME, value="OU-2"), + x509.NameAttribute(oid=NameOID.COMMON_NAME, value=self.hostname), + ] ) @override_tmpcadir(CA_DEFAULT_SUBJECT=tuple()) @@ -363,17 +357,14 @@ def test_add_no_common_name_and_no_subject_alternative_name(self) -> None: "subject_alternative_name_1": True, }, ) - self.assertEqual(response.status_code, HTTPStatus.OK) - self.assertFalse(response.context["adminform"].form.is_valid()) - self.assertEqual( - response.context["adminform"].form.errors, - { - "subject_alternative_name": [ - "Subject Alternative Name is required if the subject does not contain a Common Name." - ] - }, - ) - self.assertEqual(cert_count, Certificate.objects.all().count()) + assert response.status_code == HTTPStatus.OK + assert not response.context["adminform"].form.is_valid() + assert response.context["adminform"].form.errors == { + "subject_alternative_name": [ + "Subject Alternative Name is required if the subject does not contain a Common Name." + ] + } + assert cert_count == Certificate.objects.all().count() @override_tmpcadir(CA_DEFAULT_SUBJECT=tuple()) def test_subject_with_multiple_country_codes(self) -> None: @@ -394,10 +385,10 @@ def test_subject_with_multiple_country_codes(self) -> None: ), }, ) - self.assertFalse(response.context["adminform"].form.is_valid()) + assert not response.context["adminform"].form.is_valid() msg = "Value error, attribute of type countryName must not occur more then once in a name." - self.assertEqual(response.context["adminform"].form.errors, {"subject": [msg]}) + assert response.context["adminform"].form.errors == {"subject": [msg]} @override_tmpcadir(CA_DEFAULT_SUBJECT=tuple()) def test_subject_with_invalid_country_code(self) -> None: @@ -417,12 +408,11 @@ def test_subject_with_invalid_country_code(self) -> None: ), }, ) - self.assertEqual(response.status_code, HTTPStatus.OK) - self.assertFalse(response.context["adminform"].form.is_valid()) - self.assertEqual( - response.context["adminform"].form.errors, - {"subject": ["Value error, FOO: Must have exactly two characters"]}, - ) + assert response.status_code == HTTPStatus.OK + assert not response.context["adminform"].form.is_valid() + assert response.context["adminform"].form.errors == { + "subject": ["Value error, FOO: Must have exactly two characters"] + } @override_tmpcadir(CA_DEFAULT_SUBJECT=tuple()) def test_add_no_key_usage(self) -> None: @@ -437,11 +427,11 @@ def test_add_no_key_usage(self) -> None: self.assertRedirects(response, self.changelist_url) cert = Certificate.objects.get(cn=self.hostname) - self.assertNotIn(ExtensionOID.KEY_USAGE, cert.extensions) # KeyUsage is not set! + assert ExtensionOID.KEY_USAGE not in cert.extensions # KeyUsage is not set! # Test that we can view the certificate response = self.client.get(cert.admin_change_url) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK @override_tmpcadir(CA_DEFAULT_SUBJECT=tuple(), CA_PASSWORDS={}) def test_add_with_password(self) -> None: @@ -531,14 +521,13 @@ def test_invalid_csr(self) -> None: with assert_create_cert_signals(False, False): response = self.client.post(self.add_url, data=self.form_data("whatever", ca)) - self.assertEqual(response.status_code, HTTPStatus.OK) - self.assertFalse(response.context["adminform"].form.is_valid()) - self.assertEqual( - response.context["adminform"].form.errors, - {"csr": [CertificateSigningRequestField.simple_validation_error]}, - ) + assert response.status_code == HTTPStatus.OK + assert not response.context["adminform"].form.is_valid() + assert response.context["adminform"].form.errors == { + "csr": [CertificateSigningRequestField.simple_validation_error] + } - with self.assertRaises(Certificate.DoesNotExist): + with pytest.raises(Certificate.DoesNotExist): Certificate.objects.get(cn=self.hostname) @override_tmpcadir() @@ -557,16 +546,16 @@ def test_unparsable_csr(self) -> None: "-----BEGIN CERTIFICATE REQUEST-----\nwrong-----END CERTIFICATE REQUEST-----", ca ), ) - self.assertEqual(response.status_code, HTTPStatus.OK, response.content) - self.assertFalse(response.context["adminform"].form.is_valid()) + assert response.status_code == HTTPStatus.OK, response.content + assert not response.context["adminform"].form.is_valid() # Not testing exact error message here, as it the one from cryptography. Instead, just check that # there is exactly one message for the "csr" field. form = response.context["adminform"].form - self.assertEqual(len(form.errors), 1, form.errors) - self.assertEqual(len(form.errors["csr"]), 1, form.errors["csr"]) + assert len(form.errors) == 1, form.errors + assert len(form.errors["csr"]) == 1, form.errors["csr"] - with self.assertRaises(Certificate.DoesNotExist): + with pytest.raises(Certificate.DoesNotExist): Certificate.objects.get(cn=self.hostname) @override_tmpcadir() @@ -579,15 +568,14 @@ def test_not_after_in_the_past(self) -> None: response = self.client.post( self.add_url, data={**self.form_data(CSR, ca), "not_after": expires.strftime("%Y-%m-%d")} ) - self.assertEqual(response.status_code, HTTPStatus.OK) - self.assertIn("Certificate cannot expire in the past.", response.content.decode("utf-8")) - self.assertFalse(response.context["adminform"].form.is_valid()) - self.assertEqual( - response.context["adminform"].form.errors, - {"not_after": ["Certificate cannot expire in the past."]}, - ) + assert response.status_code == HTTPStatus.OK + assert "Certificate cannot expire in the past." in response.content.decode("utf-8") + assert not response.context["adminform"].form.is_valid() + assert response.context["adminform"].form.errors == { + "not_after": ["Certificate cannot expire in the past."] + } - with self.assertRaises(Certificate.DoesNotExist): + with pytest.raises(Certificate.DoesNotExist): Certificate.objects.get(cn=self.hostname) @override_tmpcadir() @@ -602,12 +590,12 @@ def test_expires_too_late(self) -> None: response = self.client.post( self.add_url, data={**self.form_data(CSR, ca), "not_after": expires.strftime("%Y-%m-%d")} ) - self.assertEqual(response.status_code, HTTPStatus.OK) - self.assertIn(error, response.content.decode("utf-8")) - self.assertFalse(response.context["adminform"].form.is_valid()) - self.assertEqual(response.context["adminform"].form.errors, {"not_after": [error]}) + assert response.status_code == HTTPStatus.OK + assert error in response.content.decode("utf-8") + assert not response.context["adminform"].form.is_valid() + assert response.context["adminform"].form.errors == {"not_after": [error]} - with self.assertRaises(Certificate.DoesNotExist): + with pytest.raises(Certificate.DoesNotExist): Certificate.objects.get(cn=self.hostname) @override_tmpcadir() @@ -629,11 +617,10 @@ def test_invalid_signature_hash_algorithm(self) -> None: "not_after": self.default_expires, }, ) - self.assertFalse(response.context["adminform"].form.is_valid(), response) - self.assertEqual( - response.context["adminform"].form.errors, - {"algorithm": ["Ed448-based certificate authorities do not use a signature hash algorithm."]}, - ) + assert not response.context["adminform"].form.is_valid(), response + assert response.context["adminform"].form.errors == { + "algorithm": ["Ed448-based certificate authorities do not use a signature hash algorithm."] + } # Test with Ed25519 CA csr = CERT_DATA["ed25519-cert"]["csr"]["parsed"].public_bytes(Encoding.PEM).decode("utf-8") @@ -651,11 +638,10 @@ def test_invalid_signature_hash_algorithm(self) -> None: "not_after": self.default_expires, }, ) - self.assertFalse(response.context["adminform"].form.is_valid(), response) - self.assertEqual( - response.context["adminform"].form.errors, - {"algorithm": ["Ed25519-based certificate authorities do not use a signature hash algorithm."]}, - ) + assert not response.context["adminform"].form.is_valid(), response + assert response.context["adminform"].form.errors == { + "algorithm": ["Ed25519-based certificate authorities do not use a signature hash algorithm."] + } # Test with DSA CA csr = CERT_DATA["dsa-cert"]["csr"]["parsed"].public_bytes(Encoding.PEM).decode("utf-8") @@ -673,11 +659,10 @@ def test_invalid_signature_hash_algorithm(self) -> None: "not_after": self.default_expires, }, ) - self.assertFalse(response.context["adminform"].form.is_valid(), response) - self.assertEqual( - response.context["adminform"].form.errors, - {"algorithm": ["DSA-based certificate authorities require a SHA-256 signature hash algorithm."]}, - ) + assert not response.context["adminform"].form.is_valid(), response + assert response.context["adminform"].form.errors == { + "algorithm": ["DSA-based certificate authorities require a SHA-256 signature hash algorithm."] + } # Test with RSA CA with assert_create_cert_signals(False, False): @@ -694,11 +679,10 @@ def test_invalid_signature_hash_algorithm(self) -> None: "not_after": self.default_expires, }, ) - self.assertFalse(response.context["adminform"].form.is_valid(), response) - self.assertEqual( - response.context["adminform"].form.errors, - {"algorithm": ["RSA-based certificate authorities require a signature hash algorithm."]}, - ) + assert not response.context["adminform"].form.is_valid(), response + assert response.context["adminform"].form.errors == { + "algorithm": ["RSA-based certificate authorities require a signature hash algorithm."] + } @override_tmpcadir(CA_DEFAULT_SUBJECT=tuple()) def test_certificate_policies_with_invalid_oid(self) -> None: @@ -725,13 +709,12 @@ def test_certificate_policies_with_invalid_oid(self) -> None: "certificate_policies_0": "abc", }, ) - self.assertEqual(response.status_code, HTTPStatus.OK) - self.assertFalse(response.context["adminform"].form.is_valid()) - self.assertEqual( - response.context["adminform"].form.errors, - {"certificate_policies": ["abc: The given OID is invalid."]}, - ) - self.assertEqual(cert_count, Certificate.objects.all().count()) + assert response.status_code == HTTPStatus.OK + assert not response.context["adminform"].form.is_valid() + assert response.context["adminform"].form.errors == { + "certificate_policies": ["abc: The given OID is invalid."] + } + assert cert_count == Certificate.objects.all().count() def test_add_no_cas(self) -> None: """Test adding when all CAs are disabled.""" @@ -843,18 +826,18 @@ def assertProfile( # pylint: disable=invalid-name ku_expected = self.get_expected(profile, ExtensionOID.KEY_USAGE, []) ku_selected = [o.get_attribute("value") for o in ku_select.all_selected_options] - self.assertCountEqual(ku_expected["value"], ku_selected) - self.assertEqual(ku_expected["critical"], ku_critical.is_selected()) + assert_count_equal(ku_expected["value"], ku_selected) + assert ku_expected["critical"] == ku_critical.is_selected() eku_expected = self.get_expected(profile, ExtensionOID.EXTENDED_KEY_USAGE, []) eku_selected = [o.get_attribute("value") for o in eku_select.all_selected_options] - self.assertCountEqual(eku_expected["value"], eku_selected) - self.assertEqual(eku_expected["critical"], eku_critical.is_selected()) + assert_count_equal(eku_expected["value"], eku_selected) + assert eku_expected["critical"] == eku_critical.is_selected() tf_selected = [o.get_attribute("value") for o in tf_select.all_selected_options] tf_expected = self.get_expected(profile, ExtensionOID.TLS_FEATURE, []) - self.assertCountEqual(tf_expected.get("value", []), tf_selected) - self.assertEqual(tf_expected.get("critical", False), tf_critical.is_selected()) + assert_count_equal(tf_expected.get("value", []), tf_selected) + assert tf_expected.get("critical", False) == tf_critical.is_selected() def clear_form( self, @@ -901,10 +884,9 @@ def test_select_profile(self) -> None: tf_critical = self.find("input#id_tls_feature_1") # test that the default profile is preselected - self.assertEqual( - [model_settings.CA_DEFAULT_PROFILE], - [o.get_attribute("value") for o in select.all_selected_options], - ) + assert [model_settings.CA_DEFAULT_PROFILE] == [ + o.get_attribute("value") for o in select.all_selected_options + ] # assert that the values from the default profile are preloaded self.assertProfile( @@ -993,22 +975,22 @@ def test_subject_field(self) -> None: ] # Test the initial state - self.assertEqual(self.value, expected_initial_subject) - self.assertEqual(self.displayed_value, expected_initial_subject) + assert self.value == expected_initial_subject + assert self.displayed_value == expected_initial_subject # Add a row and confirm that it's initially empty and the field is thus not yet modified self.key_value_field.find_element(By.CLASS_NAME, "add-row-btn").click() self.assertNotModified() new_select = Select(self.key_value_list.find_elements(By.CSS_SELECTOR, "select")[-1]) new_input = self.key_value_list.find_elements(By.CSS_SELECTOR, "input")[-1] - self.assertEqual(new_select.all_selected_options, []) - self.assertEqual(new_input.get_attribute("value"), "") + assert new_select.all_selected_options == [] + assert new_input.get_attribute("value") == "" # Enter a value. This marks the field as modified, but the hidden input is *not* updated, as there is # no key/OID selected yet new_input.send_keys(self.hostname) self.assertModified() - self.assertEqual(self.value, expected_initial_subject) + assert self.value == expected_initial_subject # Now select common name, and the subject is also updated new_select.select_by_value(NameOID.COMMON_NAME.dotted_string) @@ -1016,15 +998,15 @@ def test_subject_field(self) -> None: *expected_initial_subject, {"oid": NameOID.COMMON_NAME.dotted_string, "value": self.hostname}, ] - self.assertEqual(self.value, new_subject) - self.assertEqual(self.displayed_value, new_subject) # just to be sure + assert self.value == new_subject + assert self.displayed_value == new_subject # just to be sure # Remove the second row, check the update self.key_value_list.find_elements(By.CSS_SELECTOR, ".remove-row-btn")[1].click() new_subject.pop(1) - self.assertEqual(len(new_subject), 2) - self.assertEqual(self.value, new_subject) - self.assertEqual(self.displayed_value, new_subject) + assert len(new_subject) == 2 + assert self.value == new_subject + assert self.displayed_value == new_subject @override_tmpcadir() def test_csr_integration(self) -> None: @@ -1123,16 +1105,16 @@ def test_paste_csr_no_subject(self) -> None: ) # Check that the right parts of the CSR chapter is displayed - self.assertIs(no_csr.is_displayed(), False) - self.assertIs(has_content.is_displayed(), False) - self.assertIs(no_content.is_displayed(), True) + assert no_csr.is_displayed() is False + assert has_content.is_displayed() is False + assert no_content.is_displayed() is True self.assertNotModified() # Click the clear button and validate that the subject is cleared csr_chapter.find_element(By.CSS_SELECTOR, ".clear-button").click() self.assertModified() - self.assertEqual(self.value, []) - self.assertEqual(self.displayed_value, []) + assert self.value == [] + assert self.displayed_value == [] @override_tmpcadir() def test_paste_csr_missing_delimiters(self) -> None: @@ -1152,9 +1134,9 @@ def test_paste_csr_missing_delimiters(self) -> None: csr_field.send_keys(csr.public_bytes(Encoding.PEM).decode("ascii")[1:]) # Check that the right parts of the CSR chapter is displayed - self.assertIs(no_csr.is_displayed(), True) # this is displayed as we haven't pasted a CSR - self.assertIs(has_content.is_displayed(), False) - self.assertIs(no_content.is_displayed(), False) + assert no_csr.is_displayed() is True # this is displayed as we haven't pasted a CSR + assert has_content.is_displayed() is False + assert no_content.is_displayed() is False self.assertNotModified() @override_tmpcadir() @@ -1174,9 +1156,9 @@ def test_paste_invalid_csr(self) -> None: csr.send_keys("-----BEGIN CERTIFICATE REQUEST-----\nXXX\n-----END CERTIFICATE REQUEST-----") # Check that the right parts of the CSR chapter is displayed - self.assertIs(no_csr.is_displayed(), True) # this is displayed as we haven't pasted a CSR - self.assertIs(has_content.is_displayed(), False) - self.assertIs(no_content.is_displayed(), False) + assert no_csr.is_displayed() is True # this is displayed as we haven't pasted a CSR + assert has_content.is_displayed() is False + assert no_content.is_displayed() is False self.assertNotModified() @override_tmpcadir( @@ -1224,10 +1206,10 @@ def test_profile_integration(self) -> None: # Test the initial state (webserver subject, since it's the default profile self.assertNotModified() self.assertChapterHasValue(chapter, webserver_subject) - self.assertEqual(self.value, webserver_subject) - self.assertEqual(self.displayed_value, webserver_subject) - self.assertIs(has_content.is_displayed(), True) - self.assertIs(no_content.is_displayed(), False) + assert self.value == webserver_subject + assert self.displayed_value == webserver_subject + assert has_content.is_displayed() is True + assert no_content.is_displayed() is False profile_select = Select(self.selenium.find_element(By.ID, "id_profile")) @@ -1235,10 +1217,10 @@ def test_profile_integration(self) -> None: profile_select.select_by_value("client") self.assertNotModified() self.assertChapterHasValue(chapter, client_subject) - self.assertEqual(self.value, client_subject) - self.assertEqual(self.displayed_value, client_subject) - self.assertIs(has_content.is_displayed(), True) - self.assertIs(no_content.is_displayed(), False) + assert self.value == client_subject + assert self.displayed_value == client_subject + assert has_content.is_displayed() is True + assert no_content.is_displayed() is False # Change one field and check modification st_input = self.key_value_list.find_elements(By.CSS_SELECTOR, "input")[1] @@ -1247,24 +1229,24 @@ def test_profile_integration(self) -> None: new_subject = deepcopy(client_subject) new_subject[1]["value"] = "Styria" self.assertModified() - self.assertEqual(self.value, new_subject) - self.assertEqual(self.displayed_value, new_subject) + assert self.value == new_subject + assert self.displayed_value == new_subject # Switch back to the old profile. Since you made changes, it's not automatically updated profile_select.select_by_value("webserver") self.assertModified() self.assertChapterHasValue(chapter, webserver_subject) - self.assertEqual(self.value, new_subject) - self.assertEqual(self.displayed_value, new_subject) + assert self.value == new_subject + assert self.displayed_value == new_subject # Copy the profile subject and check the state chapter.find_element(By.CLASS_NAME, "copy-button").click() self.assertNotModified() self.assertChapterHasValue(chapter, webserver_subject) - self.assertEqual(self.value, webserver_subject) - self.assertEqual(self.displayed_value, webserver_subject) - self.assertIs(has_content.is_displayed(), True) - self.assertIs(no_content.is_displayed(), False) + assert self.value == webserver_subject + assert self.displayed_value == webserver_subject + assert has_content.is_displayed() is True + assert no_content.is_displayed() is False # Modify subject again (so that we can check the modified flag of the clear button) st_input = self.key_value_list.find_elements(By.CSS_SELECTOR, "input")[1] @@ -1272,24 +1254,24 @@ def test_profile_integration(self) -> None: st_input.send_keys("Styria") new_subject[2]["value"] = "webserver" self.assertModified() - self.assertEqual(self.value, new_subject) - self.assertEqual(self.displayed_value, new_subject) + assert self.value == new_subject + assert self.displayed_value == new_subject # Switch to the profile with no subject and check the state profile_select.select_by_value("no-subject") self.assertModified() self.assertChapterHasValue(chapter, []) - self.assertEqual(self.value, new_subject) - self.assertEqual(self.displayed_value, new_subject) - self.assertIs(has_content.is_displayed(), False) - self.assertIs(no_content.is_displayed(), True) + assert self.value == new_subject + assert self.displayed_value == new_subject + assert has_content.is_displayed() is False + assert no_content.is_displayed() is True # Click the clear button chapter.find_element(By.CLASS_NAME, "clear-button").click() self.assertNotModified() self.assertChapterHasValue(chapter, []) - self.assertEqual(self.value, []) - self.assertEqual(self.displayed_value, []) + assert self.value == [] + assert self.displayed_value == [] @freeze_time(TIMESTAMPS["everything_valid"]) @@ -1318,7 +1300,7 @@ def test_empty_form_and_empty_cert(self) -> None: form["authority_information_access_0"] = "[]" form["authority_information_access_1"] = "[]" response = form.submit() - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 # Fill in the bare minimum fields form = response.forms["certificate_form"] @@ -1331,18 +1313,15 @@ def test_empty_form_and_empty_cert(self) -> None: # Submit the form response = form.submit().follow() - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 cert = Certificate.objects.get(cn="test-empty-form.example.com") # Cert has minimal extensions, since we cleared the form earlier - self.assertEqual( - cert.sorted_extensions, - [ - cert.ca.get_authority_key_identifier_extension(), - basic_constraints(), - subject_key_identifier(cert), - ], - ) + assert cert.sorted_extensions == [ + cert.ca.get_authority_key_identifier_extension(), + basic_constraints(), + subject_key_identifier(cert), + ] @override_tmpcadir( CA_PROFILES={ @@ -1361,20 +1340,17 @@ def test_none_extension_and_subject_alternative_name_extension(self) -> None: form["csr"] = CSR form["subject"] = json.dumps([{"oid": NameOID.COMMON_NAME.dotted_string, "value": self.hostname}]) response = form.submit().follow() - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 cert: Certificate = Certificate.objects.get(cn=self.hostname) - self.assertEqual( - cert.sorted_extensions, - [ - cert.ca.sign_authority_information_access, - cert.ca.get_authority_key_identifier_extension(), - basic_constraints(), - cert.ca.sign_crl_distribution_points, - subject_alternative_name(dns("example.com")), - subject_key_identifier(cert), - ], - ) + assert cert.sorted_extensions == [ + cert.ca.sign_authority_information_access, + cert.ca.get_authority_key_identifier_extension(), + basic_constraints(), + cert.ca.sign_crl_distribution_points, + subject_alternative_name(dns("example.com")), + subject_key_identifier(cert), + ] @override_tmpcadir(CA_PROFILES={"nothing": {}}, CA_DEFAULT_PROFILE="nothing") def test_only_ca_prefill(self) -> None: @@ -1516,8 +1492,8 @@ def test_full_profile_prefill(self) -> None: fields would not show up in the signed certificate. """ # Make sure that the CA has sign_* field values set. - self.assertIsNotNone(self.ca.sign_authority_information_access) - self.assertIsNotNone(self.ca.sign_crl_distribution_points) + assert self.ca.sign_authority_information_access is not None + assert self.ca.sign_crl_distribution_points is not None self.ca.sign_certificate_policies = certificate_policies( x509.PolicyInformation( policy_identifier=CertificatePoliciesOID.CPS_QUALIFIER, policy_qualifiers=None @@ -1536,56 +1512,50 @@ def test_full_profile_prefill(self) -> None: form["csr"] = CSR form["subject"] = json.dumps([{"oid": NameOID.COMMON_NAME.dotted_string, "value": self.hostname}]) response = form.submit().follow() - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 # Check that we get all the extensions from the CA cert = Certificate.objects.get(cn=self.hostname) - self.assertEqual(cert.profile, "everything") - self.assertEqual( - cert.sorted_extensions, - [ - authority_information_access( - ca_issuers=[uri("http://profile.issuers.example.com")], - ocsp=[ - uri("http://profile.ocsp.example.com"), - uri("http://profile.ocsp-backup.example.com"), - ], - critical=False, - ), - cert.ca.get_authority_key_identifier_extension(), - basic_constraints(), - crl_distribution_points( - distribution_point( - full_name=[uri("http://crl.profile.example.com")], - crl_issuer=[uri("http://crl-issuer.profile.example.com")], - ), - critical=True, - ), - certificate_policies( - x509.PolicyInformation( - policy_identifier=CertificatePoliciesOID.CPS_USER_NOTICE, policy_qualifiers=["text1"] - ), - critical=True, - ), - extended_key_usage( - ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH, critical=True + assert cert.profile == "everything" + assert cert.sorted_extensions == [ + authority_information_access( + ca_issuers=[uri("http://profile.issuers.example.com")], + ocsp=[uri("http://profile.ocsp.example.com"), uri("http://profile.ocsp-backup.example.com")], + critical=False, + ), + cert.ca.get_authority_key_identifier_extension(), + basic_constraints(), + crl_distribution_points( + distribution_point( + full_name=[uri("http://crl.profile.example.com")], + crl_issuer=[uri("http://crl-issuer.profile.example.com")], ), - freshest_crl( - distribution_point( - full_name=[uri("http://freshest-crl.profile.example.com")], - crl_issuer=[uri("http://freshest-crl-issuer.profile.example.com")], - ), - critical=False, + critical=True, + ), + certificate_policies( + x509.PolicyInformation( + policy_identifier=CertificatePoliciesOID.CPS_USER_NOTICE, policy_qualifiers=["text1"] ), - issuer_alternative_name( - uri("http://ian1.example.com"), uri("http://ian2.example.com"), critical=True + critical=True, + ), + extended_key_usage( + ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH, critical=True + ), + freshest_crl( + distribution_point( + full_name=[uri("http://freshest-crl.profile.example.com")], + crl_issuer=[uri("http://freshest-crl-issuer.profile.example.com")], ), - key_usage(key_agreement=True, key_cert_sign=True), - ocsp_no_check(critical=True), - subject_key_identifier(cert), - tls_feature(x509.TLSFeatureType.status_request, critical=True), - ], - ) + critical=False, + ), + issuer_alternative_name( + uri("http://ian1.example.com"), uri("http://ian2.example.com"), critical=True + ), + key_usage(key_agreement=True, key_cert_sign=True), + ocsp_no_check(critical=True), + subject_key_identifier(cert), + tls_feature(x509.TLSFeatureType.status_request, critical=True), + ] @override_tmpcadir( CA_PROFILES={ @@ -1640,13 +1610,10 @@ def test_multiple_distribution_points(self) -> None: with self.assertLogs("django_ca") as logcm: response = self.app.get(self.add_url, user=self.user.username) - self.assertEqual( - logcm.output, - [ - "WARNING:django_ca.widgets:Received multiple DistributionPoints, only the first can be " - "changed in the web interface." - ], - ) + assert logcm.output == [ + "WARNING:django_ca.widgets:Received multiple DistributionPoints, only the first can be " + "changed in the web interface." + ] form = response.forms["certificate_form"] # default value for form field is on import time, so override settings does not change # profile field @@ -1655,41 +1622,38 @@ def test_multiple_distribution_points(self) -> None: form["subject"] = json.dumps([{"oid": NameOID.COMMON_NAME.dotted_string, "value": cn}]) response = form.submit() response = response.follow() - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 # Check that we get all the extensions from the CA cert: Certificate = Certificate.objects.get(cn="test-only-ca.example.com") - self.assertEqual(cert.profile, "everything") - self.assertEqual( - cert.sorted_extensions, - [ - cert.ca.get_authority_key_identifier_extension(), - basic_constraints(), - x509.Extension( - oid=ExtensionOID.CRL_DISTRIBUTION_POINTS, - critical=True, - value=x509.CRLDistributionPoints( - [ - x509.DistributionPoint( - full_name=[uri("http://crl.profile.example.com")], - relative_name=None, - reasons=None, - crl_issuer=[uri("http://crl-issuer.profile.example.com")], - ), - x509.DistributionPoint( - full_name=[uri("http://crl2.profile.example.com")], - relative_name=None, - reasons=None, - crl_issuer=[uri("http://crl-issuer2.profile.example.com")], - ), - ] - ), - ), - self.freshest_crl( - [uri("http://freshest-crl.profile.example.com")], - crl_issuer=[uri("http://freshest-crl-issuer.profile.example.com")], - critical=False, + assert cert.profile == "everything" + assert cert.sorted_extensions == [ + cert.ca.get_authority_key_identifier_extension(), + basic_constraints(), + x509.Extension( + oid=ExtensionOID.CRL_DISTRIBUTION_POINTS, + critical=True, + value=x509.CRLDistributionPoints( + [ + x509.DistributionPoint( + full_name=[uri("http://crl.profile.example.com")], + relative_name=None, + reasons=None, + crl_issuer=[uri("http://crl-issuer.profile.example.com")], + ), + x509.DistributionPoint( + full_name=[uri("http://crl2.profile.example.com")], + relative_name=None, + reasons=None, + crl_issuer=[uri("http://crl-issuer2.profile.example.com")], + ), + ] ), - subject_key_identifier(cert), - ], - ) + ), + self.freshest_crl( + [uri("http://freshest-crl.profile.example.com")], + crl_issuer=[uri("http://freshest-crl-issuer.profile.example.com")], + critical=False, + ), + subject_key_identifier(cert), + ] diff --git a/ca/django_ca/tests/admin/test_admin_ca.py b/ca/django_ca/tests/admin/test_admin_ca.py index 438fc04f4..54871e28e 100644 --- a/ca/django_ca/tests/admin/test_admin_ca.py +++ b/ca/django_ca/tests/admin/test_admin_ca.py @@ -54,7 +54,7 @@ def test_complex_sign_certificate_policies(self) -> None: # This test is only meaningful if the CA does **not** have the Certificate Policies extension in its # own extensions. We (can) only test for the used template after viewing, and the template would be # used for that extension. - self.assertNotIn(ExtensionOID.CERTIFICATE_POLICIES, ca.extensions) + assert ExtensionOID.CERTIFICATE_POLICIES not in ca.extensions ca.sign_certificate_policies = certificate_policies( x509.PolicyInformation( @@ -80,7 +80,7 @@ def test_complex_sign_certificate_policies(self) -> None: response = self.get_change_view(ca) assert_change_response(response) templates = [t.name for t in response.templates] - self.assertIn("django_ca/admin/extensions/2.5.29.32.html", templates) + assert "django_ca/admin/extensions/2.5.29.32.html" in templates class CADownloadBundleTestCase(AdminTestCaseMixin[CertificateAuthority], TestCase): @@ -111,13 +111,13 @@ def test_child(self) -> None: def test_invalid_format(self) -> None: """Test downloading the bundle in an invalid format.""" response = self.client.get(f"{self.url}?format=INVALID") - self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) - self.assertEqual(response.content, b"") + assert response.status_code == HTTPStatus.BAD_REQUEST + assert response.content == b"" # DER is not supported for bundles response = self.client.get(f"{self.url}?format=DER") - self.assertEqual(response.status_code, 400) - self.assertEqual(response.content, b"DER/ASN.1 certificates cannot be downloaded as a bundle.") + assert response.status_code == 400 + assert response.content == b"DER/ASN.1 certificates cannot be downloaded as a bundle." def test_permission_denied(self) -> None: """Test downloading without permissions fails.""" @@ -125,7 +125,7 @@ def test_permission_denied(self) -> None: self.user.save() response = self.client.get(f"{self.url}?format=PEM") - self.assertEqual(response.status_code, HTTPStatus.FORBIDDEN) + assert response.status_code == HTTPStatus.FORBIDDEN def test_unauthorized(self) -> None: """Test viewing as unauthorized viewer.""" diff --git a/ca/django_ca/tests/admin/test_extra_views.py b/ca/django_ca/tests/admin/test_extra_views.py index 78d736aca..c339b009e 100644 --- a/ca/django_ca/tests/admin/test_extra_views.py +++ b/ca/django_ca/tests/admin/test_extra_views.py @@ -38,7 +38,7 @@ @pytest.mark.parametrize( - "data,expected", + ("data", "expected"), ( ([], ""), ([{"oid": NameOID.COMMON_NAME.dotted_string, "value": "example.com"}], "CN=example.com"), @@ -107,12 +107,11 @@ def test_basic(self) -> None: response = self.client.post( self.url, data=json.dumps({"csr": csr}), content_type="application/json" ) - self.assertEqual(response.status_code, 200, response.json()) + assert response.status_code == 200, response.json() csr_subject = cert_data["csr"]["parsed"].subject - self.assertEqual( - response.json(), - {"subject": [{"oid": s.oid.dotted_string, "value": s.value} for s in csr_subject]}, - ) + assert response.json() == { + "subject": [{"oid": s.oid.dotted_string, "value": s.value} for s in csr_subject] + } def test_fields(self) -> None: """Test fetching a CSR with all subject fields.""" @@ -129,7 +128,7 @@ def test_fields(self) -> None: response = self.client.post( self.url, data=json.dumps({"csr": csr_pem}), content_type="application/json" ) - self.assertEqual(response.status_code, 200, response.json()) + assert response.status_code == 200, response.json() expected = [ {"oid": NameOID.USER_ID.dotted_string, "value": "test-uid"}, {"oid": NameOID.DOMAIN_COMPONENT.dotted_string, "value": "test-domainComponent"}, @@ -169,12 +168,12 @@ def test_fields(self) -> None: {"oid": NameOID.ORGANIZATION_IDENTIFIER.dotted_string, "value": "test-organizationIdentifier"}, ] - self.assertEqual(json.loads(response.content.decode("utf-8")), {"subject": expected}) + assert json.loads(response.content.decode("utf-8")) == {"subject": expected} def test_bad_request(self) -> None: """Test posting bogus data.""" response = self.client.post(self.url, data={"csr": "foobar"}) - self.assertEqual(response.status_code, 400) + assert response.status_code == 400 def test_anonymous(self) -> None: """Try downloading as anonymous user.""" @@ -192,7 +191,7 @@ def test_no_perms(self) -> None: self.user.is_superuser = False self.user.save() response = self.client.post(self.url, data={"csr": self.csr_pem}) - self.assertEqual(response.status_code, 403) + assert response.status_code == 403 def test_no_staff(self) -> None: """Try downloading as user that has permissions but is not staff.""" @@ -217,22 +216,22 @@ def test_der(self) -> None: """Download a certificate in DER format.""" filename = f"{self.cert.serial}.der" response = self.client.get(self.get_url(self.cert), {"format": "DER"}) - self.assertEqual(response.status_code, HTTPStatus.OK) - self.assertEqual(response["Content-Type"], "application/pkix-cert") - self.assertEqual(response["Content-Disposition"], f"attachment; filename={filename}") - self.assertEqual(response.content, self.cert.pub.der) + assert response.status_code == HTTPStatus.OK + assert response["Content-Type"] == "application/pkix-cert" + assert response["Content-Disposition"] == f"attachment; filename={filename}" + assert response.content == self.cert.pub.der def test_not_found(self) -> None: """Try downloading a certificate that does not exist.""" url = reverse("admin:django_ca_certificate_download", kwargs={"pk": "123"}) response = self.client.get(f"{url}?format=DER") - self.assertEqual(response.status_code, HTTPStatus.NOT_FOUND) + assert response.status_code == HTTPStatus.NOT_FOUND def test_bad_format(self) -> None: """Try downloading an unknown format.""" response = self.client.get(self.get_url(self.cert), {"format": "bad"}) - self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) - self.assertEqual(response.content, b"") + assert response.status_code == HTTPStatus.BAD_REQUEST + assert response.content == b"" def test_anonymous(self) -> None: """Try an anonymous download.""" @@ -249,7 +248,7 @@ def test_no_perms(self) -> None: self.user.is_superuser = False self.user.save() response = self.client.get(self.get_url(self.cert)) - self.assertEqual(response.status_code, HTTPStatus.FORBIDDEN) + assert response.status_code == HTTPStatus.FORBIDDEN def test_no_staff(self) -> None: """Try downloading with right permissions but not as staff user.""" @@ -283,10 +282,10 @@ def test_invalid_format(self) -> None: """Try downloading an invalid format.""" url = self.get_url(self.cert) response = self.client.get(f"{url}?format=INVALID") - self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) - self.assertEqual(response.content, b"") + assert response.status_code == HTTPStatus.BAD_REQUEST + assert response.content == b"" # DER is not supported for bundles response = self.client.get(f"{url}?format=DER") - self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) - self.assertEqual(response.content, b"DER/ASN.1 certificates cannot be downloaded as a bundle.") + assert response.status_code == HTTPStatus.BAD_REQUEST + assert response.content == b"DER/ASN.1 certificates cannot be downloaded as a bundle." diff --git a/ca/django_ca/tests/base/assertions.py b/ca/django_ca/tests/base/assertions.py index 37463dda7..4d61363ed 100644 --- a/ca/django_ca/tests/base/assertions.py +++ b/ca/django_ca/tests/base/assertions.py @@ -13,13 +13,14 @@ """:py:mod:`django_ca.tests.base.assertions` collects assertions used throughout the entire test suite.""" +import collections import io import re import typing from collections.abc import Iterable, Iterator from contextlib import contextmanager from datetime import datetime, timedelta, timezone as tz -from typing import AnyStr, Optional, Union +from typing import Any, AnyStr, Optional, Union from unittest.mock import Mock from cryptography import x509 @@ -157,6 +158,12 @@ def assert_command_error(msg: str, returncode: int = 1) -> Iterator[None]: assert exc_info.value.returncode == returncode +def assert_count_equal(first: Iterable[Any], second: Iterable[Any]) -> None: + """Roughly equivalent version of unittests assertCountEqual().""" + first, second = list(first), list(second) + assert collections.Counter(first) == collections.Counter(second) + + @contextmanager def assert_create_ca_signals(pre: bool = True, post: bool = True) -> Iterator[tuple[Mock, Mock]]: """Context manager asserting that the `pre_create_ca`/`post_create_ca` signals are (not) called.""" @@ -305,6 +312,34 @@ def assert_e2e_error( raise NotImplementedError +def assert_extension_equal( + first: Optional[x509.Extension[x509.ExtensionType]], second: Optional[x509.Extension[x509.ExtensionType]] +) -> None: + """Compare two extensions for equality (or if both are None). + + This assertion overrides comparison for iterable extension and should be used only when order of these + extension values cannot be guaranteed. For example, two ExtendedKeyUsage extension will pass as equal + regardless of order of the extended key usages in the extensions. + """ + # If both are None that's still okay. + if first is None and second is None: + return + if first is None or second is None: # pragma: no cover + raise AssertionError("One of the values is None.") + + if second.oid in ( + ExtensionOID.EXTENDED_KEY_USAGE, + ExtensionOID.TLS_FEATURE, + ExtensionOID.SUBJECT_ALTERNATIVE_NAME, + ExtensionOID.ISSUER_ALTERNATIVE_NAME, + ): + assert first.oid == second.oid + assert first.critical == second.critical + assert_count_equal(first.value, second.value) # type: ignore[arg-type] + else: + assert first == second + + def assert_extensions( cert: Union[X509CertMixin, x509.Certificate], extensions: Iterable[x509.Extension[x509.ExtensionType]], diff --git a/ca/django_ca/tests/base/conftest_helpers.py b/ca/django_ca/tests/base/conftest_helpers.py index cde721765..340cf6889 100644 --- a/ca/django_ca/tests/base/conftest_helpers.py +++ b/ca/django_ca/tests/base/conftest_helpers.py @@ -146,24 +146,24 @@ def setup_pragmas(cov: coverage.Coverage) -> None: exclude_versions(cov, "cryptography", cg_version, ver, version_str) -def generate_pub_fixture(name: str) -> typing.Callable[[], Iterator[x509.Certificate]]: +def generate_pub_fixture(name: str) -> typing.Callable[[], x509.Certificate]: """Generate fixture for a loaded public key (root_pub, root_cert_pub, ...).""" @pytest.fixture(scope="session") - def fixture() -> Iterator[x509.Certificate]: + def fixture() -> x509.Certificate: return load_pub(name) return fixture -def generate_ca_fixture(name: str) -> typing.Callable[["SubRequest", Any], Iterator[CertificateAuthority]]: +def generate_ca_fixture(name: str) -> typing.Callable[["SubRequest", Any], CertificateAuthority]: """Function to generate CA fixtures (root, child, ...).""" @pytest.fixture def fixture( request: "SubRequest", db: Any, # pylint: disable=unused-argument # usefixtures does not work for fixtures - ) -> Iterator[CertificateAuthority]: + ) -> CertificateAuthority: data = CERT_DATA[name] ca_fixture_name = f"{name}_pub" if data["cat"] == "sphinx-contrib": @@ -187,27 +187,25 @@ def fixture( return fixture -def generate_usable_ca_fixture( - name: str, -) -> typing.Callable[["SubRequest", Path], Iterator[CertificateAuthority]]: +def generate_usable_ca_fixture(name: str) -> typing.Callable[["SubRequest", Path], CertificateAuthority]: """Function to generate CA fixtures (root, child, ...).""" @pytest.fixture - def fixture(request: "SubRequest", tmpcadir: Path) -> Iterator[CertificateAuthority]: + def fixture(request: "SubRequest", tmpcadir: Path) -> CertificateAuthority: ca = request.getfixturevalue(name) # load the CA into the database data = CERT_DATA[name] shutil.copy(os.path.join(FIXTURES_DIR, data["key_filename"]), tmpcadir) - return ca + return ca # type: ignore[no-any-return] return fixture -def generate_cert_fixture(name: str) -> typing.Callable[["SubRequest"], Iterator[Certificate]]: +def generate_cert_fixture(name: str) -> typing.Callable[["SubRequest"], Certificate]: """Function to generate cert fixtures (root_cert, all_extensions, no_extensions, ...).""" @pytest.fixture - def fixture(request: "SubRequest") -> Iterator[Certificate]: + def fixture(request: "SubRequest") -> Certificate: sanitized_name = name.replace("-", "_") data = CERT_DATA[name] diff --git a/ca/django_ca/tests/base/fixtures.py b/ca/django_ca/tests/base/fixtures.py index fe12d68bb..49e3cb21c 100644 --- a/ca/django_ca/tests/base/fixtures.py +++ b/ca/django_ca/tests/base/fixtures.py @@ -58,15 +58,15 @@ @pytest.fixture(params=all_cert_names) -def any_cert(request: "SubRequest") -> Iterator[Certificate]: +def any_cert(request: "SubRequest") -> Certificate: """Parametrized fixture for absolutely *any* certificate name.""" - return request.param + return request.param # type: ignore[no-any-return] @pytest.fixture -def ca_name(request: "SubRequest") -> Iterator[str]: +def ca_name(request: "SubRequest") -> str: """Fixture for a name suitable for a CA.""" - return request.node.name + return request.node.name # type: ignore[no-any-return] @pytest.fixture( @@ -163,7 +163,7 @@ def ca_name(request: "SubRequest") -> Iterator[str]: ], ) ) -def certificate_policies_value(request: "SubRequest") -> Iterator[x509.CertificatePolicies]: +def certificate_policies_value(request: "SubRequest") -> x509.CertificatePolicies: """Parametrized fixture with different :py:class:`~cg:cryptography.x509.CertificatePolicies` objects.""" return x509.CertificatePolicies(policies=request.param) @@ -171,7 +171,7 @@ def certificate_policies_value(request: "SubRequest") -> Iterator[x509.Certifica @pytest.fixture(params=(True, False)) def certificate_policies( request: "SubRequest", certificate_policies_value: x509.CertificatePolicies -) -> Iterator[x509.Extension[x509.CertificatePolicies]]: +) -> x509.Extension[x509.CertificatePolicies]: """Parametrized fixture yielding different ``x509.Extension[x509.CertificatePolicies]`` objects.""" return x509.Extension( critical=request.param, oid=ExtensionOID.CERTIFICATE_POLICIES, value=certificate_policies_value @@ -186,13 +186,13 @@ def clear_cache() -> Iterator[None]: @pytest.fixture(params=("ed448", "ed25519")) -def ed_ca(request: "SubRequest") -> Iterator[CertificateAuthority]: +def ed_ca(request: "SubRequest") -> CertificateAuthority: """Parametrized fixture for CAs with an Edwards-curve algorithm (ed448, ed25519).""" - return request.getfixturevalue(f"{request.param}") + return request.getfixturevalue(f"{request.param}") # type: ignore[no-any-return] @pytest.fixture -def hostname(ca_name: str) -> Iterator[str]: +def hostname(ca_name: str) -> str: """Fixture for a hostname. The value is unique for each test, and it includes the CA name, which includes the test name. @@ -201,30 +201,30 @@ def hostname(ca_name: str) -> Iterator[str]: @pytest.fixture(params=interesting_certificate_names) -def interesting_cert(request: "SubRequest") -> Iterator[Certificate]: +def interesting_cert(request: "SubRequest") -> Certificate: """Parametrized fixture for "interesting" certificates. A function using this fixture will be called once for each certificate with unusual extensions. """ - return request.getfixturevalue(request.param.replace("-", "_")) + return request.getfixturevalue(request.param.replace("-", "_")) # type: ignore[no-any-return] @pytest.fixture -def key_backend(request: "SubRequest") -> Iterator[StoragesBackend]: +def key_backend(request: "SubRequest") -> StoragesBackend: """Return a :py:class:`~django_ca.key_backends.storages.StoragesBackend` for creating a new CA.""" request.getfixturevalue("tmpcadir") - return key_backends[model_settings.CA_DEFAULT_KEY_BACKEND] # type: ignore[misc] + return key_backends[model_settings.CA_DEFAULT_KEY_BACKEND] # type: ignore[return-value] @pytest.fixture(params=precertificate_signed_certificate_timestamps_cert_names) -def precertificate_signed_certificate_timestamps_pub(request: "SubRequest") -> Iterator[x509.Certificate]: +def precertificate_signed_certificate_timestamps_pub(request: "SubRequest") -> x509.Certificate: """Parametrized fixture for certificates that have a PrecertSignedCertificateTimestamps extension.""" name = request.param.replace("-", "_") - return request.getfixturevalue(f"contrib_{name}_pub") + return request.getfixturevalue(f"contrib_{name}_pub") # type: ignore[no-any-return] @pytest.fixture -def rfc4514_subject(subject: x509.Name) -> Iterator[str]: +def rfc4514_subject(subject: x509.Name) -> str: """Fixture for an RFC 4514 formatted name to use for a subject. The common name is based on :py:func:`~django_ca.tests.base.fixtures.hostname` and identical to @@ -234,7 +234,7 @@ def rfc4514_subject(subject: x509.Name) -> Iterator[str]: @pytest.fixture -def root_crl(root: CertificateAuthority) -> Iterator[CertificateRevocationList]: +def root_crl(root: CertificateAuthority) -> CertificateRevocationList: """Fixture for the global CRL object for the Root CA.""" with open(constants.FIXTURES_DIR / "root.crl", "rb") as stream: crl_data = stream.read() @@ -248,7 +248,7 @@ def root_crl(root: CertificateAuthority) -> Iterator[CertificateRevocationList]: @pytest.fixture -def root_ca_crl(root: CertificateAuthority) -> Iterator[CertificateRevocationList]: +def root_ca_crl(root: CertificateAuthority) -> CertificateRevocationList: """Fixture for the user CRL object for the Root CA.""" with open(constants.FIXTURES_DIR / "root.ca.crl", "rb") as stream: crl_data = stream.read() @@ -267,7 +267,7 @@ def root_ca_crl(root: CertificateAuthority) -> Iterator[CertificateRevocationLis @pytest.fixture -def root_user_crl(root: CertificateAuthority) -> Iterator[CertificateRevocationList]: +def root_user_crl(root: CertificateAuthority) -> CertificateRevocationList: """Fixture for the user CRL object for the Root CA.""" with open(constants.FIXTURES_DIR / "root.user.crl", "rb") as stream: crl_data = stream.read() @@ -286,7 +286,7 @@ def root_user_crl(root: CertificateAuthority) -> Iterator[CertificateRevocationL @pytest.fixture -def root_attribute_crl(root: CertificateAuthority) -> Iterator[CertificateRevocationList]: +def root_attribute_crl(root: CertificateAuthority) -> CertificateRevocationList: """Fixture for the attribute CRL object for the Root CA.""" with open(constants.FIXTURES_DIR / "root.attribute.crl", "rb") as stream: crl_data = stream.read() @@ -305,30 +305,28 @@ def root_attribute_crl(root: CertificateAuthority) -> Iterator[CertificateRevoca @pytest.fixture -def secondary_backend(request: "SubRequest") -> Iterator[StoragesBackend]: +def secondary_backend(request: "SubRequest") -> StoragesBackend: """Return a :py:class:`~django_ca.key_backends.storages.StoragesBackend` for the secondary key backend.""" request.getfixturevalue("tmpcadir") - return key_backends["secondary"] # type: ignore[misc] + return key_backends["secondary"] # type: ignore[return-value] @pytest.fixture(params=signed_certificate_timestamp_cert_names) -def signed_certificate_timestamp_pub(request: "SubRequest") -> Iterator[x509.Certificate]: +def signed_certificate_timestamp_pub(request: "SubRequest") -> x509.Certificate: """Parametrized fixture for certificates that have any SCT extension.""" name = request.param.replace("-", "_") - return request.getfixturevalue(f"contrib_{name}_pub") + return request.getfixturevalue(f"contrib_{name}_pub") # type: ignore[no-any-return] @pytest.fixture(params=signed_certificate_timestamps_cert_names) -def signed_certificate_timestamps_pub( - request: "SubRequest", -) -> Iterator[x509.Certificate]: # pragma: no cover +def signed_certificate_timestamps_pub(request: "SubRequest") -> x509.Certificate: # pragma: no cover """Parametrized fixture for certificates that have a SignedCertificateTimestamps extension. .. NOTE:: There are no certificates with this extension right now, so this fixture is in fact never run. """ name = request.param.replace("-", "_") - return request.getfixturevalue(f"{name}_pub") + return request.getfixturevalue(f"{name}_pub") # type: ignore[no-any-return] @pytest.fixture @@ -370,7 +368,7 @@ def softhsm_setup(tmp_path: Path) -> Iterator[Path]: # pragma: hsm def softhsm_token( # pragma: hsm request: "SubRequest", settings: SettingsWrapper, -) -> Iterator[str]: +) -> str: """Get a unique token for the current test.""" request.getfixturevalue("softhsm_setup") token = settings.PKCS11_TOKEN_LABEL @@ -390,7 +388,7 @@ def softhsm_token( # pragma: hsm if lib := SessionPool._lib_pool.get(settings.PKCS11_PATH): # pylint: disable=protected-access lib.reinitialize() - return token + return token # type: ignore[no-any-return] @pytest.fixture @@ -404,7 +402,7 @@ def hsm_backend(request: "SubRequest") -> Iterator[HSMBackend]: # pragma: hsm @pytest.fixture(params=HSMBackend.supported_key_types) def usable_hsm_ca( # pragma: hsm request: "SubRequest", ca_name: str, subject: x509.Name, hsm_backend: HSMBackend -) -> Iterator[CertificateAuthority]: +) -> CertificateAuthority: """Parametrized fixture yielding a certificate authority for every key type.""" request.getfixturevalue("db") key_type = request.param @@ -428,7 +426,7 @@ def usable_hsm_ca( # pragma: hsm @pytest.fixture -def subject(hostname: str) -> Iterator[x509.Name]: +def subject(hostname: str) -> x509.Name: """Fixture for a :py:class:`~cg:cryptography.x509.Name` to use for a subject. The common name is based on :py:func:`~django_ca.tests.base.fixtures.hostname` and identical to @@ -474,28 +472,28 @@ def tmpcadir(tmp_path: Path, settings: SettingsWrapper) -> Iterator[Path]: @pytest.fixture(params=all_ca_names) -def ca(request: "SubRequest") -> Iterator[CertificateAuthority]: +def ca(request: "SubRequest") -> CertificateAuthority: """Parametrized fixture for all certificate authorities known to the test suite.""" fixture_name = request.param if CERT_DATA[fixture_name]["cat"] in ("contrib", "sphinx-contrib"): fixture_name = f"contrib_{fixture_name}" - return request.getfixturevalue(fixture_name) + return request.getfixturevalue(fixture_name) # type: ignore[no-any-return] @pytest.fixture(params=usable_ca_names) -def usable_ca_name(request: "SubRequest") -> Iterator[CertificateAuthority]: +def usable_ca_name(request: "SubRequest") -> CertificateAuthority: """Parametrized fixture for the name of every usable CA.""" - return request.param + return request.param # type: ignore[no-any-return] @pytest.fixture(params=usable_ca_names) -def usable_ca(request: "SubRequest") -> Iterator[CertificateAuthority]: +def usable_ca(request: "SubRequest") -> CertificateAuthority: """Parametrized fixture for every usable CA (with usable private key).""" - return request.getfixturevalue(f"usable_{request.param}") + return request.getfixturevalue(f"usable_{request.param}") # type: ignore[no-any-return] @pytest.fixture -def usable_cas(request: "SubRequest") -> Iterator[list[CertificateAuthority]]: +def usable_cas(request: "SubRequest") -> list[CertificateAuthority]: """Fixture for all usable CAs as a list.""" cas = [] for name in usable_ca_names: @@ -504,7 +502,7 @@ def usable_cas(request: "SubRequest") -> Iterator[list[CertificateAuthority]]: @pytest.fixture(params=usable_cert_names) -def usable_cert(request: "SubRequest") -> Iterator[Certificate]: +def usable_cert(request: "SubRequest") -> Certificate: """Parametrized fixture for every ``{ca}-cert`` certificate. The name of the certificate can be retrieved from the non-standard `test_name` property of the @@ -514,4 +512,4 @@ def usable_cert(request: "SubRequest") -> Iterator[Certificate]: cert = request.getfixturevalue(name.replace("-", "_")) cert.test_name = name request.getfixturevalue(f"usable_{cert.ca.name}") - return cert + return cert # type: ignore[no-any-return] diff --git a/ca/django_ca/tests/base/mixins.py b/ca/django_ca/tests/base/mixins.py index 9e11ac154..cae1f42fb 100644 --- a/ca/django_ca/tests/base/mixins.py +++ b/ca/django_ca/tests/base/mixins.py @@ -68,13 +68,6 @@ class TestCaseMixin(TestCaseProtocol): re_false_password = r"^Could not decrypt private key - bad password\?$" def setUp(self) -> None: - # Add custom equality functions - self.addTypeEqualityFunc(x509.AuthorityInformationAccess, self.assertAuthorityInformationAccessEqual) - self.addTypeEqualityFunc(x509.ExtendedKeyUsage, self.assertExtendedKeyUsageEqual) - self.addTypeEqualityFunc(x509.Extension, self.assertCryptographyExtensionEqual) - self.addTypeEqualityFunc(x509.KeyUsage, self.assertKeyUsageEqual) - self.addTypeEqualityFunc(x509.TLSFeature, self.assertTLSFeatureEqual) - super().setUp() cache.clear() @@ -152,95 +145,18 @@ def absolute_uri(self, name: str, hostname: Optional[str] = None, **kwargs: Any) name = f"django_ca{name}" return f"http://{hostname}{reverse(name, kwargs=kwargs)}" - def assertAuthorityInformationAccessEqual( # pylint: disable=invalid-name - self, - first: x509.AuthorityInformationAccess, - second: x509.AuthorityInformationAccess, - msg: Optional[str] = None, - ) -> None: - """Type equality function for x509.AuthorityInformationAccess.""" - - def sorter(ad: x509.AccessDescription) -> tuple[str, str]: - return ad.access_method.dotted_string, ad.access_location.value - - self.assertEqual(sorted(first, key=sorter), sorted(second, key=sorter), msg=msg) - - def assertCryptographyExtensionEqual( # pylint: disable=invalid-name - self, - first: x509.Extension[x509.ExtensionType], - second: x509.Extension[x509.ExtensionType], - msg: Optional[str] = None, - ) -> None: - """Type equality function for x509.Extension.""" - # NOTE: Cryptography in name comes from overriding class in AbstractExtensionTestMixin - # remove once old wrapper classes are removed - self.assertEqual(first.oid, second.oid, msg=msg) - self.assertEqual(first.critical, second.critical, msg="critical is unequal.") - self.assertEqual(first.value, second.value, msg=msg) - - def assertExtendedKeyUsageEqual( # pylint: disable=invalid-name - self, first: x509.ExtendedKeyUsage, second: x509.ExtendedKeyUsage, msg: Optional[str] = None - ) -> None: - """Type equality function for x509.ExtendedKeyUsage.""" - self.assertEqual(set(first), set(second), msg=msg) - - def assertKeyUsageEqual( # pylint: disable=invalid-name - self, first: x509.KeyUsage, second: x509.KeyUsage, msg: Optional[str] = None - ) -> None: - """Type equality function for x509.KeyUsage.""" - diffs = [] - for usage in [ - "content_commitment", - "crl_sign", - "data_encipherment", - "decipher_only", - "digital_signature", - "encipher_only", - "key_agreement", - "key_cert_sign", - "key_encipherment", - ]: - try: - first_val = getattr(first, usage) - except ValueError: - first_val = False - try: - second_val = getattr(second, usage) - except ValueError: - second_val = False - - if first_val != second_val: # pragma: no cover # would only be run in case of error - diffs.append(f" * {usage}: {first_val} -> {second_val}") - - if msg is None: - msg = "KeyUsage extensions differ:" - if diffs: # pragma: no cover # would only be run in case of error - raise self.failureException(msg + "\n" + "\n".join(diffs)) - - def assertTLSFeatureEqual( # pylint: disable=invalid-name - self, first: x509.TLSFeature, second: x509.TLSFeature, msg: Optional[str] = None - ) -> None: - """Type equality function for x509.TLSFeature.""" - self.assertEqual(set(first), set(second), msg=msg) - - def assertIssuer( # pylint: disable=invalid-name - self, issuer: CertificateAuthority, cert: X509CertMixin - ) -> None: - """Assert that the issuer for `cert` matches the subject of `issuer`.""" - self.assertEqual(cert.issuer, issuer.subject) - def assertMessages( # pylint: disable=invalid-name self, response: "HttpResponse", expected: list[str] ) -> None: """Assert given Django messages for `response`.""" messages = [str(m) for m in list(get_messages(response.wsgi_request))] - self.assertEqual(messages, expected) + assert messages == expected def assertNotRevoked(self, cert: X509CertMixin) -> None: # pylint: disable=invalid-name """Assert that the certificate is not revoked.""" cert.refresh_from_db() - self.assertFalse(cert.revoked) - self.assertEqual(cert.revoked_reason, "") + assert not cert.revoked + assert cert.revoked_reason == "" def assertPostRevoke(self, post: mock.Mock, cert: Certificate) -> None: # pylint: disable=invalid-name """Assert that the post_revoke_cert signal was called.""" @@ -381,13 +297,13 @@ def mute_celery(self, *calls: Any) -> Iterator[mock.MagicMock]: # Make sure that all invocations are JSON serializable for invocation in mocked.call_args_list: # invocation apply_async() has task args as arg[0] and arg[1] - self.assertIsInstance(json.dumps(invocation.args[0]), str) - self.assertIsInstance(json.dumps(invocation.args[1]), str) + assert isinstance(json.dumps(invocation.args[0]), str) + assert isinstance(json.dumps(invocation.args[1]), str) # Make sure that task was called the right number of times - self.assertEqual(len(calls), len(mocked.call_args_list)) + assert len(calls) == len(mocked.call_args_list) for expected, actual in zip(calls, mocked.call_args_list): - self.assertEqual(expected, actual, actual) + assert expected == actual, actual @contextmanager def patch(self, *args: Any, **kwargs: Any) -> Iterator[mock.MagicMock]: @@ -445,10 +361,10 @@ def assertBundle( # pylint: disable=invalid-name expected_content = "\n".join([e.pub.pem.strip() for e in expected]) + "\n" response = self.client.get(url, {"format": "PEM"}) - self.assertEqual(response.status_code, HTTPStatus.OK) - self.assertEqual(response["Content-Type"], "application/pkix-cert") - self.assertEqual(response["Content-Disposition"], f"attachment; filename={filename}") - self.assertEqual(response.content.decode("utf-8"), expected_content) + assert response.status_code == HTTPStatus.OK + assert response["Content-Type"] == "application/pkix-cert" + assert response["Content-Disposition"] == f"attachment; filename={filename}" + assert response.content.decode("utf-8") == expected_content def assertRequiresLogin( # pylint: disable=invalid-name self, response: "HttpResponse", **kwargs: Any @@ -511,7 +427,7 @@ def get_changelists( def test_model_count(self) -> None: """Test that the implementing TestCase actually creates some instances.""" - self.assertGreater(self.model._default_manager.all().count(), 0) + assert self.model._default_manager.all().count() > 0 def test_changelist_view(self) -> None: """Test that the changelist view works.""" diff --git a/ca/django_ca/tests/commands/test_dump_ca.py b/ca/django_ca/tests/commands/test_dump_ca.py index da410fbfc..4baff39e6 100644 --- a/ca/django_ca/tests/commands/test_dump_ca.py +++ b/ca/django_ca/tests/commands/test_dump_ca.py @@ -41,7 +41,7 @@ def test_basic(root: CertificateAuthority) -> None: assert stdout.decode() == root.pub.pem -@pytest.mark.parametrize("encoding", [Encoding.PEM, Encoding.DER]) +@pytest.mark.parametrize("encoding", (Encoding.PEM, Encoding.DER)) def test_format(root: CertificateAuthority, encoding: Encoding) -> None: """Test encoding parameter.""" stdout = dump_ca(root.serial, format=encoding) diff --git a/ca/django_ca/tests/commands/test_dump_crl.py b/ca/django_ca/tests/commands/test_dump_crl.py index 26e5852a2..63dae914b 100644 --- a/ca/django_ca/tests/commands/test_dump_crl.py +++ b/ca/django_ca/tests/commands/test_dump_crl.py @@ -144,7 +144,7 @@ def test_disabled(usable_root: CertificateAuthority) -> None: assert_crl(stdout, signer=usable_root, algorithm=usable_root.algorithm) -@pytest.mark.parametrize("reason", [x509.ReasonFlags.unspecified, x509.ReasonFlags.key_compromise]) +@pytest.mark.parametrize("reason", (x509.ReasonFlags.unspecified, x509.ReasonFlags.key_compromise)) def test_revoked_with_reason( usable_root: CertificateAuthority, root_cert: Certificate, reason: x509.ReasonFlags ) -> None: diff --git a/ca/django_ca/tests/commands/test_init_ca.py b/ca/django_ca/tests/commands/test_init_ca.py index 262ed2ab9..20ae06e9f 100644 --- a/ca/django_ca/tests/commands/test_init_ca.py +++ b/ca/django_ca/tests/commands/test_init_ca.py @@ -867,7 +867,7 @@ def test_password(ca_name: str, key_backend: StoragesBackend) -> None: key_backend.get_key(parent, use_options) # Wrong password doesn't work either - with pytest.raises(ValueError): + with pytest.raises(ValueError): # noqa: PT011 # cryptography controls the error message # NOTE: cryptography is notoriously unstable when it comes to the error message here, so we only # check the exception class. key_backend.get_key(parent, StoragesUsePrivateKeyOptions(password=b"wrong")) @@ -1483,7 +1483,7 @@ def test_key_size_with_unsupported_key_type(ca_name: str, key_type: str) -> None @pytest.mark.skipif(CRYPTOGRAPHY_VERSION < (43,), reason="cryptography check was added in version 43") @pytest.mark.parametrize( - "value,msg", + ("value", "msg"), ( ("", r"Attribute's length must be >= 1 and <= 64, but it was 0"), ("X" * 65, r"Attribute's length must be >= 1 and <= 64, but it was 65"), diff --git a/ca/django_ca/tests/commands/test_list_cas.py b/ca/django_ca/tests/commands/test_list_cas.py index e8bc999a1..4d1da9f36 100644 --- a/ca/django_ca/tests/commands/test_list_cas.py +++ b/ca/django_ca/tests/commands/test_list_cas.py @@ -48,7 +48,7 @@ def assertOutput( # pylint: disable=invalid-name context.update(CERT_DATA) for ca_name in self.cas: context.setdefault(f"{ca_name}_state", "") - self.assertEqual(output, expected.format(**context)) + assert output == expected.format(**context) def test_all_cas(self) -> None: """Test list with all CAs.""" @@ -56,54 +56,58 @@ def test_all_cas(self) -> None: self.load_ca(name) stdout, stderr = cmd("list_cas") - self.assertEqual( - stdout, - f"""{CERT_DATA['letsencrypt_x1']['serial_colons']} - {CERT_DATA['letsencrypt_x1']['name']} -{CERT_DATA['letsencrypt_x3']['serial_colons']} - {CERT_DATA['letsencrypt_x3']['name']} -{CERT_DATA['dst_root_x3']['serial_colons']} - {CERT_DATA['dst_root_x3']['name']} -{CERT_DATA['google_g3']['serial_colons']} - {CERT_DATA['google_g3']['name']} -{CERT_DATA['globalsign_r2_root']['serial_colons']} - {CERT_DATA['globalsign_r2_root']['name']} -{CERT_DATA['trustid_server_a52']['serial_colons']} - {CERT_DATA['trustid_server_a52']['name']} -{CERT_DATA['rapidssl_g3']['serial_colons']} - {CERT_DATA['rapidssl_g3']['name']} -{CERT_DATA['geotrust']['serial_colons']} - {CERT_DATA['geotrust']['name']} -{CERT_DATA['startssl_class2']['serial_colons']} - {CERT_DATA['startssl_class2']['name']} -{CERT_DATA['digicert_sha2']['serial_colons']} - {CERT_DATA['digicert_sha2']['name']} -{CERT_DATA['globalsign_dv']['serial_colons']} - {CERT_DATA['globalsign_dv']['name']} -{CERT_DATA['dsa']['serial_colons']} - {CERT_DATA['dsa']['name']} -{CERT_DATA['ec']['serial_colons']} - {CERT_DATA['ec']['name']} -{CERT_DATA['ed25519']['serial_colons']} - {CERT_DATA['ed25519']['name']} -{CERT_DATA['ed448']['serial_colons']} - {CERT_DATA['ed448']['name']} -{CERT_DATA['pwd']['serial_colons']} - {CERT_DATA['pwd']['name']} -{CERT_DATA['root']['serial_colons']} - {CERT_DATA['root']['name']} -{CERT_DATA['child']['serial_colons']} - {CERT_DATA['child']['name']} -{CERT_DATA['comodo_ev']['serial_colons']} - {CERT_DATA['comodo_ev']['name']} -{CERT_DATA['globalsign']['serial_colons']} - {CERT_DATA['globalsign']['name']} -{CERT_DATA['digicert_ha_intermediate']['serial_colons']} - {CERT_DATA['digicert_ha_intermediate']['name']} -{CERT_DATA['comodo_dv']['serial_colons']} - {CERT_DATA['comodo_dv']['name']} -{CERT_DATA['startssl_class3']['serial_colons']} - {CERT_DATA['startssl_class3']['name']} -{CERT_DATA['godaddy_g2_intermediate']['serial_colons']} - {CERT_DATA['godaddy_g2_intermediate']['name']} -{CERT_DATA['digicert_ev_root']['serial_colons']} - {CERT_DATA['digicert_ev_root']['name']} -{CERT_DATA['digicert_global_root']['serial_colons']} - {CERT_DATA['digicert_global_root']['name']} -{CERT_DATA['identrust_root_1']['serial_colons']} - {CERT_DATA['identrust_root_1']['name']} -{CERT_DATA['startssl_root']['serial_colons']} - {CERT_DATA['startssl_root']['name']} -{CERT_DATA['godaddy_g2_root']['serial_colons']} - {CERT_DATA['godaddy_g2_root']['name']} -{CERT_DATA['comodo']['serial_colons']} - {CERT_DATA['comodo']['name']} -""", + assert ( + stdout + == f"{CERT_DATA['letsencrypt_x1']['serial_colons']} - {CERT_DATA['letsencrypt_x1']['name']}\n" + f"{CERT_DATA['letsencrypt_x3']['serial_colons']} - {CERT_DATA['letsencrypt_x3']['name']}\n" + f"{CERT_DATA['dst_root_x3']['serial_colons']} - {CERT_DATA['dst_root_x3']['name']}\n" + f"{CERT_DATA['google_g3']['serial_colons']} - {CERT_DATA['google_g3']['name']}\n" + f"{CERT_DATA['globalsign_r2_root']['serial_colons']}" + f" - {CERT_DATA['globalsign_r2_root']['name']}\n" + f"{CERT_DATA['trustid_server_a52']['serial_colons']}" + f" - {CERT_DATA['trustid_server_a52']['name']}\n" + f"{CERT_DATA['rapidssl_g3']['serial_colons']} - {CERT_DATA['rapidssl_g3']['name']}\n" + f"{CERT_DATA['geotrust']['serial_colons']} - {CERT_DATA['geotrust']['name']}\n" + f"{CERT_DATA['startssl_class2']['serial_colons']} - {CERT_DATA['startssl_class2']['name']}\n" + f"{CERT_DATA['digicert_sha2']['serial_colons']} - {CERT_DATA['digicert_sha2']['name']}\n" + f"{CERT_DATA['globalsign_dv']['serial_colons']} - {CERT_DATA['globalsign_dv']['name']}\n" + f"{CERT_DATA['dsa']['serial_colons']} - {CERT_DATA['dsa']['name']}\n" + f"{CERT_DATA['ec']['serial_colons']} - {CERT_DATA['ec']['name']}\n" + f"{CERT_DATA['ed25519']['serial_colons']} - {CERT_DATA['ed25519']['name']}\n" + f"{CERT_DATA['ed448']['serial_colons']} - {CERT_DATA['ed448']['name']}\n" + f"{CERT_DATA['pwd']['serial_colons']} - {CERT_DATA['pwd']['name']}\n" + f"{CERT_DATA['root']['serial_colons']} - {CERT_DATA['root']['name']}\n" + f"{CERT_DATA['child']['serial_colons']} - {CERT_DATA['child']['name']}\n" + f"{CERT_DATA['comodo_ev']['serial_colons']} - {CERT_DATA['comodo_ev']['name']}\n" + f"{CERT_DATA['globalsign']['serial_colons']} - {CERT_DATA['globalsign']['name']}\n" + f"{CERT_DATA['digicert_ha_intermediate']['serial_colons']}" + f" - {CERT_DATA['digicert_ha_intermediate']['name']}\n" + f"{CERT_DATA['comodo_dv']['serial_colons']} - {CERT_DATA['comodo_dv']['name']}\n" + f"{CERT_DATA['startssl_class3']['serial_colons']} - {CERT_DATA['startssl_class3']['name']}\n" + f"{CERT_DATA['godaddy_g2_intermediate']['serial_colons']}" + f" - {CERT_DATA['godaddy_g2_intermediate']['name']}\n" + f"{CERT_DATA['digicert_ev_root']['serial_colons']} - {CERT_DATA['digicert_ev_root']['name']}\n" + f"{CERT_DATA['digicert_global_root']['serial_colons']}" + f" - {CERT_DATA['digicert_global_root']['name']}\n" + f"{CERT_DATA['identrust_root_1']['serial_colons']} - {CERT_DATA['identrust_root_1']['name']}\n" + f"{CERT_DATA['startssl_root']['serial_colons']} - {CERT_DATA['startssl_root']['name']}\n" + f"{CERT_DATA['godaddy_g2_root']['serial_colons']} - {CERT_DATA['godaddy_g2_root']['name']}\n" + f"{CERT_DATA['comodo']['serial_colons']} - {CERT_DATA['comodo']['name']}\n" ) - self.assertEqual(stderr, "") + assert stderr == "" def test_no_cas(self) -> None: """Test the command if no CAs are defined.""" CertificateAuthority.objects.all().delete() stdout, stderr = cmd("list_cas") - self.assertEqual(stdout, "") - self.assertEqual(stderr, "") + assert stdout == "" + assert stderr == "" def test_basic(self) -> None: """Basic test of the command.""" stdout, stderr = cmd("list_cas") self.assertOutput(stdout, EXPECTED) - self.assertEqual(stderr, "") + assert stderr == "" def test_disabled(self) -> None: """Test the command if some CA is disabled.""" @@ -112,7 +116,7 @@ def test_disabled(self) -> None: stdout, stderr = cmd("list_cas") self.assertOutput(stdout, EXPECTED, child_state=" (disabled)") - self.assertEqual(stderr, "") + assert stderr == "" @freeze_time(TIMESTAMPS["everything_valid"]) def test_tree(self) -> None: @@ -121,18 +125,16 @@ def test_tree(self) -> None: NOTE: freeze_time b/c we create some fake CA objects and order in the tree depends on validity. """ stdout, stderr = cmd("list_cas", tree=True) - self.assertEqual( - stdout, - f"""{CERT_DATA['dsa']['serial_colons']} - {CERT_DATA['dsa']['name']} -{CERT_DATA['ec']['serial_colons']} - {CERT_DATA['ec']['name']} -{CERT_DATA['ed25519']['serial_colons']} - {CERT_DATA['ed25519']['name']} -{CERT_DATA['ed448']['serial_colons']} - {CERT_DATA['ed448']['name']} -{CERT_DATA['pwd']['serial_colons']} - {CERT_DATA['pwd']['name']} -{CERT_DATA['root']['serial_colons']} - {CERT_DATA['root']['name']} -└───{CERT_DATA['child']['serial_colons']} - {CERT_DATA['child']['name']} -""", + assert ( + stdout == f"{CERT_DATA['dsa']['serial_colons']} - {CERT_DATA['dsa']['name']}\n" + f"{CERT_DATA['ec']['serial_colons']} - {CERT_DATA['ec']['name']}\n" + f"{CERT_DATA['ed25519']['serial_colons']} - {CERT_DATA['ed25519']['name']}\n{ + CERT_DATA['ed448']['serial_colons']} - {CERT_DATA['ed448']['name']}\n" + f"{CERT_DATA['pwd']['serial_colons']} - {CERT_DATA['pwd']['name']}\n" + f"{CERT_DATA['root']['serial_colons']} - {CERT_DATA['root']['name']}\n" + f"└───{CERT_DATA['child']['serial_colons']} - {CERT_DATA['child']['name']}\n" ) - self.assertEqual(stderr, "") + assert stderr == "" # manually create Certificate objects not_after = timezone.now() + timedelta(days=3) @@ -155,17 +157,15 @@ def test_tree(self) -> None: ) stdout, stderr = cmd("list_cas", tree=True) - self.assertEqual( - stdout, - f"""{CERT_DATA['dsa']['serial_colons']} - {CERT_DATA['dsa']['name']} -{CERT_DATA['ec']['serial_colons']} - {CERT_DATA['ec']['name']} -{CERT_DATA['ed25519']['serial_colons']} - {CERT_DATA['ed25519']['name']} -{CERT_DATA['ed448']['serial_colons']} - {CERT_DATA['ed448']['name']} -{CERT_DATA['pwd']['serial_colons']} - {CERT_DATA['pwd']['name']} -{CERT_DATA['root']['serial_colons']} - {CERT_DATA['root']['name']} -│───ch:il:d3 - child3 -│ └───ch:il:d3:.1 - child3.1 -│───ch:il:d4 - child4 -└───{CERT_DATA['child']['serial_colons']} - {CERT_DATA['child']['name']} -""", + assert ( + stdout == f"{CERT_DATA['dsa']['serial_colons']} - {CERT_DATA['dsa']['name']}\n" + f"{CERT_DATA['ec']['serial_colons']} - {CERT_DATA['ec']['name']}\n" + f"{CERT_DATA['ed25519']['serial_colons']} - {CERT_DATA['ed25519']['name']}\n" + f"{CERT_DATA['ed448']['serial_colons']} - {CERT_DATA['ed448']['name']}\n" + f"{CERT_DATA['pwd']['serial_colons']} - {CERT_DATA['pwd']['name']}\n" + f"{CERT_DATA['root']['serial_colons']} - {CERT_DATA['root']['name']}\n" + f"│───ch:il:d3 - child3\n" + f"│ └───ch:il:d3:.1 - child3.1\n" + f"│───ch:il:d4 - child4\n" + f"└───{CERT_DATA['child']['serial_colons']} - {CERT_DATA['child']['name']}\n" ) diff --git a/ca/django_ca/tests/commands/test_list_certs.py b/ca/django_ca/tests/commands/test_list_certs.py index 2e9613721..758af9ca1 100644 --- a/ca/django_ca/tests/commands/test_list_certs.py +++ b/ca/django_ca/tests/commands/test_list_certs.py @@ -49,8 +49,8 @@ def assertCerts(self, *certs: Certificate, **kwargs: Any) -> None: # pylint: di """Assert that command outputs the given certs.""" stdout, stderr = cmd("list_certs", **kwargs) sorted_certs = sorted(certs, key=lambda c: (c.not_after, c.cn, c.serial)) - self.assertEqual(stdout, "".join([f"{self._line(c)}\n" for c in sorted_certs])) - self.assertEqual(stderr, "") + assert stdout == "".join([f"{self._line(c)}\n" for c in sorted_certs]) + assert stderr == "" @freeze_time(TIMESTAMPS["everything_valid"]) def test_basic(self) -> None: diff --git a/ca/django_ca/tests/commands/test_notify.py b/ca/django_ca/tests/commands/test_notify.py index aca3e7512..c8786c890 100644 --- a/ca/django_ca/tests/commands/test_notify.py +++ b/ca/django_ca/tests/commands/test_notify.py @@ -37,18 +37,18 @@ class NotifyExpiringCertsTestCase(TestCaseMixin, TestCase): def test_no_certs(self) -> None: """Try notify command when all certs are still valid.""" stdout, stderr = cmd("notify_expiring_certs") - self.assertEqual(stdout, "") - self.assertEqual(stderr, "") - self.assertEqual(len(mail.outbox), 0) + assert stdout == "" + assert stderr == "" + assert len(mail.outbox) == 0 @freeze_time(TIMESTAMPS["ca_certs_expiring"]) def test_no_watchers(self) -> None: """Try expiring certs, but with no watchers.""" # certs have no watchers by default, so we get no mails stdout, stderr = cmd("notify_expiring_certs") - self.assertEqual(stdout, "") - self.assertEqual(stderr, "") - self.assertEqual(len(mail.outbox), 0) + assert stdout == "" + assert stderr == "" + assert len(mail.outbox) == 0 @freeze_time(TIMESTAMPS["ca_certs_expiring"]) def test_one_watcher(self) -> None: @@ -59,11 +59,11 @@ def test_one_watcher(self) -> None: timestamp = self.cert.not_after.strftime("%Y-%m-%d") stdout, stderr = cmd("notify_expiring_certs") - self.assertEqual(stdout, "") - self.assertEqual(stderr, "") - self.assertEqual(len(mail.outbox), 1) - self.assertEqual(mail.outbox[0].subject, f"Certificate expiration for {self.cert.cn} on {timestamp}") - self.assertEqual(mail.outbox[0].to, [email]) + assert stdout == "" + assert stderr == "" + assert len(mail.outbox) == 1 + assert mail.outbox[0].subject == f"Certificate expiration for {self.cert.cn} on {timestamp}" + assert mail.outbox[0].to == [email] def test_notification_days(self) -> None: """Test that user gets multiple notifications of expiring certs.""" @@ -74,8 +74,8 @@ def test_notification_days(self) -> None: with freeze_time(self.cert.not_after - timedelta(days=20)) as frozen_time: for _i in reversed(range(0, 20)): stdout, stderr = cmd("notify_expiring_certs", days=14) - self.assertEqual(stdout, "") - self.assertEqual(stderr, "") + assert stdout == "" + assert stderr == "" frozen_time.tick(timedelta(days=1)) - self.assertEqual(len(mail.outbox), 4) + assert len(mail.outbox) == 4 diff --git a/ca/django_ca/tests/commands/test_regenerate_ocsp_keys.py b/ca/django_ca/tests/commands/test_regenerate_ocsp_keys.py index e99675b84..a7a7c0093 100644 --- a/ca/django_ca/tests/commands/test_regenerate_ocsp_keys.py +++ b/ca/django_ca/tests/commands/test_regenerate_ocsp_keys.py @@ -75,14 +75,14 @@ def assertKey( # pylint: disable=invalid-name priv = typing.cast( CertificateIssuerPrivateKeyTypes, load_der_private_key(read_file(priv_path), password) ) - self.assertIsInstance(priv, key_type) + assert isinstance(priv, key_type) if isinstance(priv, (dsa.DSAPrivateKey, rsa.RSAPrivateKey)): - self.assertEqual(priv.key_size, key_size) + assert priv.key_size == key_size if isinstance(priv, ec.EllipticCurvePrivateKey): - self.assertIsInstance(priv.curve, elliptic_curve) + assert isinstance(priv.curve, elliptic_curve) cert = x509.load_pem_x509_certificate(read_file(cert_path)) - self.assertIsInstance(cert, x509.Certificate) + assert isinstance(cert, x509.Certificate) cert_qs = Certificate.objects.filter(ca=ca).exclude(pk__in=self.existing_certs) @@ -105,7 +105,7 @@ def assertKey( # pylint: disable=invalid-name if ad.access_method == AuthorityInformationAccessOID.CA_ISSUERS ), ) - self.assertEqual(aia, expected_aia) + assert aia == expected_aia return priv, cert @@ -232,8 +232,8 @@ def test_overwrite(self) -> None: new_priv, new_cert = self.assertKey(self.cas["root"], excludes=excludes) # Key/Cert should now be different - self.assertNotEqual(priv, new_priv) - self.assertNotEqual(cert, new_cert) + assert priv != new_priv + assert cert != new_cert @override_tmpcadir() def test_wrong_serial(self) -> None: diff --git a/ca/django_ca/tests/commands/test_resign_cert.py b/ca/django_ca/tests/commands/test_resign_cert.py index 228261550..ca22fe220 100644 --- a/ca/django_ca/tests/commands/test_resign_cert.py +++ b/ca/django_ca/tests/commands/test_resign_cert.py @@ -124,7 +124,7 @@ def test_basic(self) -> None: new = Certificate.objects.get(pub=stdout) assert_resigned(self.cert, new) assert_equal_ext(self.cert, new) - self.assertIsInstance(new.algorithm, type(self.cert.algorithm)) + assert isinstance(new.algorithm, type(self.cert.algorithm)) @override_tmpcadir() def test_dsa_ca_resign(self) -> None: @@ -136,7 +136,7 @@ def test_dsa_ca_resign(self) -> None: new = Certificate.objects.get(pub=stdout) assert_resigned(self.certs["dsa-cert"], new) assert_equal_ext(self.certs["dsa-cert"], new) - self.assertIsInstance(new.algorithm, hashes.SHA256) + assert isinstance(new.algorithm, hashes.SHA256) @override_tmpcadir() def test_all_extensions_certificate(self) -> None: @@ -148,20 +148,19 @@ def test_all_extensions_certificate(self) -> None: new = Certificate.objects.get(pub=stdout) assert_resigned(orig, new) - self.assertIsInstance(new.algorithm, hashes.SHA256) + assert isinstance(new.algorithm, hashes.SHA256) expected = orig.extensions actual = new.extensions - self.assertEqual( - sorted(expected.values(), key=lambda e: e.oid.dotted_string), - sorted(actual.values(), key=lambda e: e.oid.dotted_string), + assert sorted(expected.values(), key=lambda e: e.oid.dotted_string) == sorted( + actual.values(), key=lambda e: e.oid.dotted_string ) @override_tmpcadir() def test_test_all_extensions_cert_with_overrides(self) -> None: """Test resigning a certificate with adding new extensions.""" - self.assertIsNotNone(self.ca.sign_authority_information_access) - self.assertIsNotNone(self.ca.sign_crl_distribution_points) + assert self.ca.sign_authority_information_access is not None + assert self.ca.sign_crl_distribution_points is not None self.ca.sign_certificate_policies = certificate_policies( x509.PolicyInformation( policy_identifier=CertificatePoliciesOID.CPS_QUALIFIER, policy_qualifiers=None @@ -216,104 +215,91 @@ def test_test_all_extensions_cert_with_overrides(self) -> None: new = Certificate.objects.get(pub=stdout) assert_resigned(orig, new) - self.assertIsInstance(new.algorithm, hashes.SHA256) + assert isinstance(new.algorithm, hashes.SHA256) extensions = new.extensions # Test Authority Information Access extension - self.assertEqual( - extensions[ExtensionOID.AUTHORITY_INFORMATION_ACCESS], - x509.Extension( - oid=ExtensionOID.AUTHORITY_INFORMATION_ACCESS, - critical=False, - value=x509.AuthorityInformationAccess( - [ - x509.AccessDescription( - access_method=AuthorityInformationAccessOID.OCSP, - access_location=uri("http://ocsp.example.com/1"), - ), - x509.AccessDescription( - access_method=AuthorityInformationAccessOID.OCSP, - access_location=uri("http://ocsp.example.com/2"), - ), - x509.AccessDescription( - access_method=AuthorityInformationAccessOID.CA_ISSUERS, - access_location=uri("http://issuer.example.com/1"), - ), - x509.AccessDescription( - access_method=AuthorityInformationAccessOID.CA_ISSUERS, - access_location=uri("http://issuer.example.com/2"), - ), - ] - ), + assert extensions[ExtensionOID.AUTHORITY_INFORMATION_ACCESS] == x509.Extension( + oid=ExtensionOID.AUTHORITY_INFORMATION_ACCESS, + critical=False, + value=x509.AuthorityInformationAccess( + [ + x509.AccessDescription( + access_method=AuthorityInformationAccessOID.OCSP, + access_location=uri("http://ocsp.example.com/1"), + ), + x509.AccessDescription( + access_method=AuthorityInformationAccessOID.OCSP, + access_location=uri("http://ocsp.example.com/2"), + ), + x509.AccessDescription( + access_method=AuthorityInformationAccessOID.CA_ISSUERS, + access_location=uri("http://issuer.example.com/1"), + ), + x509.AccessDescription( + access_method=AuthorityInformationAccessOID.CA_ISSUERS, + access_location=uri("http://issuer.example.com/2"), + ), + ] ), ) # Test Certificate Policies extension - self.assertEqual( - extensions[ExtensionOID.CERTIFICATE_POLICIES], - x509.Extension( - oid=ExtensionOID.CERTIFICATE_POLICIES, - critical=False, - value=x509.CertificatePolicies( - policies=[ - x509.PolicyInformation( - policy_identifier=x509.ObjectIdentifier("1.2.3"), - policy_qualifiers=[ - "https://example.com/overwritten/", - x509.UserNotice( - notice_reference=None, explicit_text="overwritten user notice text" - ), - ], - ) - ] - ), + assert extensions[ExtensionOID.CERTIFICATE_POLICIES] == x509.Extension( + oid=ExtensionOID.CERTIFICATE_POLICIES, + critical=False, + value=x509.CertificatePolicies( + policies=[ + x509.PolicyInformation( + policy_identifier=x509.ObjectIdentifier("1.2.3"), + policy_qualifiers=[ + "https://example.com/overwritten/", + x509.UserNotice( + notice_reference=None, explicit_text="overwritten user notice text" + ), + ], + ) + ] ), ) # Test CRL Distribution Points extension - self.assertEqual( - extensions[ExtensionOID.CRL_DISTRIBUTION_POINTS], - self.crl_distribution_points([uri("http://crl.example.com"), uri("http://crl.example.net")]), + assert extensions[ExtensionOID.CRL_DISTRIBUTION_POINTS] == self.crl_distribution_points( + [uri("http://crl.example.com"), uri("http://crl.example.net")] ) # Test Extended Key Usage extension - self.assertEqual( - extensions[ExtensionOID.EXTENDED_KEY_USAGE], - extended_key_usage(ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH), + assert extensions[ExtensionOID.EXTENDED_KEY_USAGE] == extended_key_usage( + ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH ) # Test Issuer Alternative Name extension - self.assertEqual( - extensions[ExtensionOID.ISSUER_ALTERNATIVE_NAME], - issuer_alternative_name(dns("ian-override.example.com"), uri("http://ian-override.example.com")), + assert extensions[ExtensionOID.ISSUER_ALTERNATIVE_NAME] == issuer_alternative_name( + dns("ian-override.example.com"), uri("http://ian-override.example.com") ) # Test KeyUsage extension - self.assertEqual( - extensions[ExtensionOID.KEY_USAGE], - key_usage(key_agreement=True, key_encipherment=True, critical=False), + assert extensions[ExtensionOID.KEY_USAGE] == key_usage( + key_agreement=True, key_encipherment=True, critical=False ) # Test OCSP No Check extension - self.assertEqual(extensions[ExtensionOID.OCSP_NO_CHECK], ocsp_no_check(critical=True)) + assert extensions[ExtensionOID.OCSP_NO_CHECK] == ocsp_no_check(critical=True) # Test Subject Alternative Name extension - self.assertEqual( - extensions[x509.SubjectAlternativeName.oid], - subject_alternative_name(dns("override.example.net")), + assert extensions[x509.SubjectAlternativeName.oid] == subject_alternative_name( + dns("override.example.net") ) # Test TLSFeature extension - self.assertEqual( - extensions[ExtensionOID.TLS_FEATURE], tls_feature(x509.TLSFeatureType.status_request) - ) + assert extensions[ExtensionOID.TLS_FEATURE] == tls_feature(x509.TLSFeatureType.status_request) @override_tmpcadir() def test_test_no_extensions_cert_with_overrides(self) -> None: """Test resigning a certificate with adding new extensions.""" - self.assertIsNotNone(self.ca.sign_authority_information_access) - self.assertIsNotNone(self.ca.sign_crl_distribution_points) + assert self.ca.sign_authority_information_access is not None + assert self.ca.sign_crl_distribution_points is not None self.ca.sign_certificate_policies = certificate_policies( x509.PolicyInformation( policy_identifier=CertificatePoliciesOID.CPS_QUALIFIER, policy_qualifiers=None @@ -361,68 +347,55 @@ def test_test_no_extensions_cert_with_overrides(self) -> None: new = Certificate.objects.get(pub=stdout) assert_resigned(orig, new) - self.assertIsInstance(new.algorithm, hashes.SHA256) + assert isinstance(new.algorithm, hashes.SHA256) extensions = new.extensions # Test Certificate Policies extension - self.assertEqual( - extensions[ExtensionOID.CERTIFICATE_POLICIES], - certificate_policies( - x509.PolicyInformation( - policy_identifier=x509.ObjectIdentifier("1.2.3"), - policy_qualifiers=[ - "https://example.com/overwritten/", - x509.UserNotice(notice_reference=None, explicit_text="overwritten user notice text"), - ], - ) - ), + assert extensions[ExtensionOID.CERTIFICATE_POLICIES] == certificate_policies( + x509.PolicyInformation( + policy_identifier=x509.ObjectIdentifier("1.2.3"), + policy_qualifiers=[ + "https://example.com/overwritten/", + x509.UserNotice(notice_reference=None, explicit_text="overwritten user notice text"), + ], + ) ) # Test CRL Distribution Points extension - self.assertEqual( - extensions[ExtensionOID.CRL_DISTRIBUTION_POINTS], - crl_distribution_points( - distribution_point([uri("http://crl.example.com"), uri("http://crl.example.net")]) - ), + assert extensions[ExtensionOID.CRL_DISTRIBUTION_POINTS] == crl_distribution_points( + distribution_point([uri("http://crl.example.com"), uri("http://crl.example.net")]) ) # Test Extended Key Usage extension - self.assertEqual( - extensions[ExtensionOID.EXTENDED_KEY_USAGE], - extended_key_usage(ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH), + assert extensions[ExtensionOID.EXTENDED_KEY_USAGE] == extended_key_usage( + ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH ) # Test Issuer Alternative Name extension - self.assertEqual( - extensions[ExtensionOID.ISSUER_ALTERNATIVE_NAME], - issuer_alternative_name(dns("ian-override.example.com"), uri("http://ian-override.example.com")), + assert extensions[ExtensionOID.ISSUER_ALTERNATIVE_NAME] == issuer_alternative_name( + dns("ian-override.example.com"), uri("http://ian-override.example.com") ) # Test Key Usage extension - self.assertEqual( - extensions[ExtensionOID.KEY_USAGE], key_usage(key_agreement=True, key_encipherment=True) - ) + assert extensions[ExtensionOID.KEY_USAGE] == key_usage(key_agreement=True, key_encipherment=True) # Test OCSP No Check extension - self.assertEqual(extensions[ExtensionOID.OCSP_NO_CHECK], ocsp_no_check()) + assert extensions[ExtensionOID.OCSP_NO_CHECK] == ocsp_no_check() # Test Subject Alternative Name extension - self.assertEqual( - extensions[x509.SubjectAlternativeName.oid], - subject_alternative_name(dns("override.example.net")), + assert extensions[x509.SubjectAlternativeName.oid] == subject_alternative_name( + dns("override.example.net") ) # Test TLSFeature extension - self.assertEqual( - extensions[ExtensionOID.TLS_FEATURE], tls_feature(x509.TLSFeatureType.status_request) - ) + assert extensions[ExtensionOID.TLS_FEATURE] == tls_feature(x509.TLSFeatureType.status_request) @override_tmpcadir() def test_test_no_extensions_cert_with_overrides_with_non_default_critical(self) -> None: """Test resigning a certificate with adding new extensions with non-default critical values.""" - self.assertIsNotNone(self.ca.sign_authority_information_access) - self.assertIsNotNone(self.ca.sign_crl_distribution_points) + assert self.ca.sign_authority_information_access is not None + assert self.ca.sign_crl_distribution_points is not None self.ca.sign_certificate_policies = certificate_policies( x509.PolicyInformation( policy_identifier=CertificatePoliciesOID.CPS_QUALIFIER, policy_qualifiers=None @@ -469,67 +442,55 @@ def test_test_no_extensions_cert_with_overrides_with_non_default_critical(self) new = Certificate.objects.get(pub=stdout) assert_resigned(orig, new) - self.assertIsInstance(new.algorithm, hashes.SHA256) + assert isinstance(new.algorithm, hashes.SHA256) extensions = new.extensions # Test Certificate Policies extension - self.assertEqual( - extensions[ExtensionOID.CERTIFICATE_POLICIES], - x509.Extension( - oid=ExtensionOID.CERTIFICATE_POLICIES, - critical=True, - value=x509.CertificatePolicies( - policies=[ - x509.PolicyInformation( - policy_identifier=x509.ObjectIdentifier("1.2.3"), - policy_qualifiers=[ - "https://example.com/overwritten/", - x509.UserNotice( - notice_reference=None, explicit_text="overwritten user notice text" - ), - ], - ) - ] - ), + assert extensions[ExtensionOID.CERTIFICATE_POLICIES] == x509.Extension( + oid=ExtensionOID.CERTIFICATE_POLICIES, + critical=True, + value=x509.CertificatePolicies( + policies=[ + x509.PolicyInformation( + policy_identifier=x509.ObjectIdentifier("1.2.3"), + policy_qualifiers=[ + "https://example.com/overwritten/", + x509.UserNotice( + notice_reference=None, explicit_text="overwritten user notice text" + ), + ], + ) + ] ), ) # Test CRL Distribution Points extension - self.assertEqual( - extensions[ExtensionOID.CRL_DISTRIBUTION_POINTS], - self.crl_distribution_points( - [uri("http://crl.example.com"), uri("http://crl.example.net")], critical=True - ), + assert extensions[ExtensionOID.CRL_DISTRIBUTION_POINTS] == self.crl_distribution_points( + [uri("http://crl.example.com"), uri("http://crl.example.net")], critical=True ) # Test Extended Key Usage extension - self.assertEqual( - extensions[ExtensionOID.EXTENDED_KEY_USAGE], - extended_key_usage( - ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH, critical=True - ), + assert extensions[ExtensionOID.EXTENDED_KEY_USAGE] == extended_key_usage( + ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH, critical=True ) # Test Key Usage extension - self.assertEqual( - extensions[ExtensionOID.KEY_USAGE], - key_usage(key_agreement=True, key_encipherment=True, critical=False), + assert extensions[ExtensionOID.KEY_USAGE] == key_usage( + key_agreement=True, key_encipherment=True, critical=False ) # Test OCSP No Check extension - self.assertEqual(extensions[ExtensionOID.OCSP_NO_CHECK], ocsp_no_check(True)) + assert extensions[ExtensionOID.OCSP_NO_CHECK] == ocsp_no_check(True) # Test Subject Alternative Name extension - self.assertEqual( - extensions[x509.SubjectAlternativeName.oid], - subject_alternative_name(dns("override.example.net"), critical=True), + assert extensions[x509.SubjectAlternativeName.oid] == subject_alternative_name( + dns("override.example.net"), critical=True ) # Test TLSFeature extension - self.assertEqual( - extensions[ExtensionOID.TLS_FEATURE], - tls_feature(x509.TLSFeatureType.status_request, critical=True), + assert extensions[ExtensionOID.TLS_FEATURE] == tls_feature( + x509.TLSFeatureType.status_request, critical=True ) @override_tmpcadir() @@ -542,7 +503,7 @@ def test_custom_algorithm(self) -> None: new = Certificate.objects.get(pub=stdout) assert_resigned(self.cert, new) assert_equal_ext(self.cert, new) - self.assertIsInstance(new.algorithm, hashes.SHA512) + assert isinstance(new.algorithm, hashes.SHA512) @override_tmpcadir() def test_different_ca(self) -> None: @@ -589,31 +550,28 @@ def test_overwrite(self) -> None: new = Certificate.objects.get(pub=stdout) assert_resigned(self.cert, new) - self.assertEqual(new.subject, x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, cname)])) - self.assertEqual(list(new.watchers.all()), [Watcher.objects.get(mail=watcher)]) + assert new.subject == x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, cname)]) + assert list(new.watchers.all()) == [Watcher.objects.get(mail=watcher)] # assert overwritten extensions extensions = new.extensions # Test Extended Key Usage extension - self.assertEqual( - extensions[ExtensionOID.EXTENDED_KEY_USAGE], - extended_key_usage(ExtendedKeyUsageOID.EMAIL_PROTECTION, critical=True), + assert extensions[ExtensionOID.EXTENDED_KEY_USAGE] == extended_key_usage( + ExtendedKeyUsageOID.EMAIL_PROTECTION, critical=True ) # Test Key Usage extension - self.assertEqual(extensions[ExtensionOID.KEY_USAGE], key_usage(crl_sign=True, critical=False)) + assert extensions[ExtensionOID.KEY_USAGE] == key_usage(crl_sign=True, critical=False) # Test Subject Alternative Name extension - self.assertEqual( - extensions[ExtensionOID.SUBJECT_ALTERNATIVE_NAME], - subject_alternative_name(dns("subject-alternative-name.example.com")), + assert extensions[ExtensionOID.SUBJECT_ALTERNATIVE_NAME] == subject_alternative_name( + dns("subject-alternative-name.example.com") ) # Test TLSFeature extension - self.assertEqual( - extensions[ExtensionOID.TLS_FEATURE], - tls_feature(x509.TLSFeatureType.status_request_v2, critical=True), + assert extensions[ExtensionOID.TLS_FEATURE] == tls_feature( + x509.TLSFeatureType.status_request_v2, critical=True ) @override_tmpcadir( @@ -627,7 +585,7 @@ def test_set_profile(self) -> None: assert stderr == "" new = Certificate.objects.get(pub=stdout) - self.assertEqual(new.not_after.date(), timezone.now().date() + timedelta(days=200)) + assert new.not_after.date() == timezone.now().date() + timedelta(days=200) assert_resigned(self.cert, new) assert_equal_ext(self.cert, new) @@ -645,7 +603,7 @@ def test_cert_profile(self) -> None: assert stderr == "" new = Certificate.objects.get(pub=stdout) - self.assertEqual(new.not_after.date(), timezone.now().date() + timedelta(days=200)) + assert new.not_after.date() == timezone.now().date() + timedelta(days=200) assert_resigned(self.cert, new) assert_equal_ext(self.cert, new) diff --git a/ca/django_ca/tests/commands/test_revoke_cert.py b/ca/django_ca/tests/commands/test_revoke_cert.py index 3b9c4c3a4..a4239a36e 100644 --- a/ca/django_ca/tests/commands/test_revoke_cert.py +++ b/ca/django_ca/tests/commands/test_revoke_cert.py @@ -47,24 +47,24 @@ def revoke( with mock_signal(pre_revoke_cert) as pre, mock_signal(post_revoke_cert) as post: stdout, stderr = cmd_e2e(["revoke_cert", cert.serial, *arguments]) - self.assertEqual(stdout, "") - self.assertEqual(stderr, "") + assert stdout == "" + assert stderr == "" cert.refresh_from_db() - self.assertEqual(pre.call_count, 1) + assert pre.call_count == 1 self.assertPostRevoke(post, cert) - self.assertTrue(cert.revoked) - self.assertTrue(cert.revoked_date is not None) - self.assertEqual(cert.revoked_reason, reason) + assert cert.revoked + assert cert.revoked_date is not None + assert cert.revoked_reason == reason def test_no_arguments(self) -> None: """Test revoking without a reason.""" - self.assertFalse(self.cert.revoked) + assert not self.cert.revoked self.revoke(self.cert) def test_with_reason(self) -> None: """Test revoking with a reason.""" - self.assertFalse(self.cert.revoked) + assert not self.cert.revoked for reason in ReasonFlags: self.revoke(self.cert, ["--reason", reason.name], reason=reason.name) @@ -79,26 +79,26 @@ def test_with_compromised(self) -> None: """Test revoking the certificate with a compromised date.""" now = datetime.now(tz=tz.utc) self.revoke(self.cert, arguments=["--compromised", now.isoformat()]) - self.assertEqual(self.cert.compromised, now) + assert self.cert.compromised == now def test_with_compromised_with_use_tz_is_false(self) -> None: """Test revoking the certificate with a compromised date with USE_TZ=False.""" with self.settings(USE_TZ=False): now = datetime.now(tz=tz.utc) self.revoke(self.cert, arguments=["--compromised", now.isoformat()]) - self.assertEqual(self.cert.compromised, timezone.make_naive(now)) + assert self.cert.compromised == timezone.make_naive(now) def test_revoked(self) -> None: """Test revoking a cert that is already revoked.""" - self.assertFalse(self.cert.revoked) + assert not self.cert.revoked with mock_signal(pre_revoke_cert) as pre, mock_signal(post_revoke_cert) as post: cmd("revoke_cert", self.cert.serial) cert = Certificate.objects.get(serial=self.cert.serial) - self.assertEqual(pre.call_count, 1) + assert pre.call_count == 1 self.assertPostRevoke(post, cert) - self.assertEqual(cert.revoked_reason, ReasonFlags.unspecified.name) + assert cert.revoked_reason == ReasonFlags.unspecified.name with ( assert_command_error(rf"^{self.cert.serial}: Certificate is already revoked\.$"), @@ -106,13 +106,13 @@ def test_revoked(self) -> None: mock_signal(post_revoke_cert) as post, ): cmd("revoke_cert", self.cert.serial, reason=ReasonFlags.key_compromise) - self.assertFalse(pre.called) - self.assertFalse(post.called) + assert not pre.called + assert not post.called cert = Certificate.objects.get(serial=self.cert.serial) - self.assertTrue(cert.revoked) - self.assertTrue(cert.revoked_date is not None) - self.assertEqual(cert.revoked_reason, ReasonFlags.unspecified.name) + assert cert.revoked + assert cert.revoked_date is not None + assert cert.revoked_reason == ReasonFlags.unspecified.name def test_compromised_with_naive_datetime(self) -> None: """Test passing a naive datetime (which is an error).""" diff --git a/ca/django_ca/tests/conftest.py b/ca/django_ca/tests/conftest.py index e01d77f24..6345ed3db 100644 --- a/ca/django_ca/tests/conftest.py +++ b/ca/django_ca/tests/conftest.py @@ -18,7 +18,6 @@ import importlib.metadata import os import sys -from collections.abc import Iterator from typing import Any import coverage @@ -143,7 +142,7 @@ def user( @pytest.fixture -def user_client(user: "User", client: Client) -> Iterator[Client]: +def user_client(user: "User", client: Client) -> Client: """A Django test client logged in as a normal user.""" client.force_login(user) # type: ignore[arg-type] # django-stubs 5.1.0 thinks user is AbstractUser return client diff --git a/ca/django_ca/tests/extensions/test_admin_html.py b/ca/django_ca/tests/extensions/test_admin_html.py index e5c769e9e..558710c2d 100644 --- a/ca/django_ca/tests/extensions/test_admin_html.py +++ b/ca/django_ca/tests/extensions/test_admin_html.py @@ -690,7 +690,7 @@ def _set_distribution_point_extension( def assertAdminHTML(self, name: str, cert: X509CertMixin) -> None: # pylint: disable=invalid-name """Assert that the actual extension HTML is equivalent to the expected HTML.""" for oid, ext in cert.extensions.items(): - self.assertIn(oid, self.admin_html[name], (name, oid)) + assert oid in self.admin_html[name], (name, oid) admin_html = self.admin_html[name][oid] admin_html = f'\n
{admin_html}
' actual = extension_as_admin_html(ext) diff --git a/ca/django_ca/tests/extensions/test_unknown_extension.py b/ca/django_ca/tests/extensions/test_unknown_extension.py index a1ebb5ce5..ce4210746 100644 --- a/ca/django_ca/tests/extensions/test_unknown_extension.py +++ b/ca/django_ca/tests/extensions/test_unknown_extension.py @@ -17,6 +17,8 @@ from django.test import TestCase +import pytest + from django_ca.extensions import extension_as_text, parse_extension @@ -39,17 +41,17 @@ def public_bytes(self) -> bytes: def test_parse_unknown_key(self) -> None: """Test exception for parsing an extension with an unsupported key.""" - with self.assertRaisesRegex(ValueError, r"^wrong_key: Unknown extension key\.$"): + with pytest.raises(ValueError, match=r"^wrong_key: Unknown extension key\.$"): parse_extension("wrong_key", {}) def test_no_extension_as_text(self) -> None: """Test textualizing an extension that is not an extension type.""" - with self.assertRaisesRegex(TypeError, r"^bytes: Not a cryptography\.x509\.ExtensionType\.$"): + with pytest.raises(TypeError, match=r"^bytes: Not a cryptography\.x509\.ExtensionType\.$"): extension_as_text(b"foo") # type: ignore[arg-type] def test_unknown_extension_type_as_text(self) -> None: """Test textualizing an extension of unknown type.""" - with self.assertRaisesRegex( - TypeError, r"^UnknownExtensionType \(oid: 1\.2\.3\): Unknown extension type\.$" + with pytest.raises( + TypeError, match=r"^UnknownExtensionType \(oid: 1\.2\.3\): Unknown extension type\.$" ): extension_as_text(self.ext_type) diff --git a/ca/django_ca/tests/extensions/test_utils.py b/ca/django_ca/tests/extensions/test_utils.py index 979fb6e86..06804b452 100644 --- a/ca/django_ca/tests/extensions/test_utils.py +++ b/ca/django_ca/tests/extensions/test_utils.py @@ -27,11 +27,11 @@ class CertificatePoliciesIsSimpleTestCase(TestCase): def assertIsSimple(self, *policies: x509.PolicyInformation) -> None: # pylint: disable=invalid-name """Assert that a Certificate Policies extension with the given policies is simple.""" - self.assertTrue(certificate_policies_is_simple(self.certificate_policy(*policies))) + assert certificate_policies_is_simple(self.certificate_policy(*policies)) def assertIsNotSimple(self, *policies: x509.PolicyInformation) -> None: # pylint: disable=invalid-name """Assert that a Certificate Policies extension with the given policies is *not* simple.""" - self.assertFalse(certificate_policies_is_simple(self.certificate_policy(*policies))) + assert not certificate_policies_is_simple(self.certificate_policy(*policies)) def certificate_policy(self, *policies: x509.PolicyInformation) -> x509.CertificatePolicies: """Create a Certificate Policy object from the given policies.""" diff --git a/ca/django_ca/tests/key_backends/hsm/test_backend.py b/ca/django_ca/tests/key_backends/hsm/test_backend.py index 496ab6ec0..889cca1f9 100644 --- a/ca/django_ca/tests/key_backends/hsm/test_backend.py +++ b/ca/django_ca/tests/key_backends/hsm/test_backend.py @@ -35,7 +35,7 @@ def test_session_with_session_read_only_exception(hsm_backend: HSMBackend) -> None: """Test exception message when SessionReadOnly() is raised.""" - with pytest.raises(pkcs11.PKCS11Error, match=r"^Attempting to write to a read-only session\.$"): + with pytest.raises(pkcs11.PKCS11Error, match=r"^Attempting to write to a read-only session\.$"): # noqa: PT012 with hsm_backend.session(so_pin=None, user_pin=settings.PKCS11_USER_PIN) as session: with patch.object(session, "get_key", side_effect=pkcs11.SessionReadOnly()): session.get_key() @@ -43,7 +43,7 @@ def test_session_with_session_read_only_exception(hsm_backend: HSMBackend) -> No def test_session_with_unknown_pkcs11_exception(hsm_backend: HSMBackend) -> None: """Test exception message when a generic PKCS11 error is raised.""" - with pytest.raises(pkcs11.PKCS11Error, match=r"^Unknown pkcs11 error \(SessionCount\)\.$"): + with pytest.raises(pkcs11.PKCS11Error, match=r"^Unknown pkcs11 error \(SessionCount\)\.$"): # noqa: PT012 with hsm_backend.session(so_pin=None, user_pin=settings.PKCS11_USER_PIN) as session: with patch.object(session, "get_key", side_effect=pkcs11.SessionCount()): session.get_key() diff --git a/ca/django_ca/tests/key_backends/hsm/test_models.py b/ca/django_ca/tests/key_backends/hsm/test_models.py index bd746cbba..e5e947234 100644 --- a/ca/django_ca/tests/key_backends/hsm/test_models.py +++ b/ca/django_ca/tests/key_backends/hsm/test_models.py @@ -25,7 +25,7 @@ ) -@pytest.mark.parametrize("so_pin,user_pin", (("so-pin-value", None), (None, "user-pin-value"))) +@pytest.mark.parametrize(("so_pin", "user_pin"), (("so-pin-value", None), (None, "user-pin-value"))) def test_pins(so_pin: Optional[str], user_pin: Optional[str]) -> None: """Test valid pin configurations.""" model = HSMUsePrivateKeyOptions(so_pin=so_pin, user_pin=user_pin) @@ -34,7 +34,7 @@ def test_pins(so_pin: Optional[str], user_pin: Optional[str]) -> None: @pytest.mark.parametrize( - "so_pin,user_pin,error", + ("so_pin", "user_pin", "error"), ( (None, None, r"Provide one of so_pin or user_pin\."), ("so-pin-value", "user-pin-value", r"Provide either so_pin or user_pin\."), @@ -79,6 +79,6 @@ def test_with_no_context(caplog: LogCaptureFixture) -> None: def test_with_no_backend_in_context(caplog: LogCaptureFixture) -> None: """Test creating a Model with loading the pins from the context.""" - with pytest.raises(ValueError): + with pytest.raises(ValueError): # noqa: PT011 # pydantic controls the message HSMUsePrivateKeyOptions.model_validate({}, context={"foo": "bar"}) assert "Did not receive backend in context." in caplog.text diff --git a/ca/django_ca/tests/key_backends/hsm/test_session.py b/ca/django_ca/tests/key_backends/hsm/test_session.py index c28bb0295..b48a16069 100644 --- a/ca/django_ca/tests/key_backends/hsm/test_session.py +++ b/ca/django_ca/tests/key_backends/hsm/test_session.py @@ -34,7 +34,7 @@ @pytest.fixture -def pool_key(softhsm_token: str) -> Iterator[PoolKeyType]: +def pool_key(softhsm_token: str) -> PoolKeyType: """Minor fixture to return the pool key for the default settings.""" return settings.PKCS11_PATH, softhsm_token, None, settings.PKCS11_USER_PIN diff --git a/ca/django_ca/tests/key_backends/test_storages.py b/ca/django_ca/tests/key_backends/test_storages.py index 238adbb66..d8086cce0 100644 --- a/ca/django_ca/tests/key_backends/test_storages.py +++ b/ca/django_ca/tests/key_backends/test_storages.py @@ -40,7 +40,7 @@ def test_private_key_options_key_size(key_size: int) -> None: @pytest.mark.parametrize("key_size", (-2048, -1, 0, 1, 1023, 1025, 2047, 2049, 8191, 8193, 1000, 2000, 3000)) def test_private_key_options_with_invalid_key_size(key_size: int) -> None: """Test invalid key sizes for private key options.""" - with pytest.raises(ValueError): + with pytest.raises(ValueError): # noqa: PT011 # pydantic controls the message StoragesCreatePrivateKeyOptions( key_type="RSA", password=None, path=Path("/does/not/exist"), key_size=key_size ) diff --git a/ca/django_ca/tests/models/test_certificate.py b/ca/django_ca/tests/models/test_certificate.py index 2c11547cd..a3e14d0e4 100644 --- a/ca/django_ca/tests/models/test_certificate.py +++ b/ca/django_ca/tests/models/test_certificate.py @@ -46,7 +46,7 @@ def test_revocation() -> None: # Never really happens in real life, but should still be checked cert = Certificate(revoked=False) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=r"^Certificate is not revoked\.$"): cert.get_revocation() @@ -129,7 +129,7 @@ def test_validate_past(root_cert: Certificate) -> None: root_cert.full_clean() -@pytest.mark.parametrize("name,algorithm", (("sha256", hashes.SHA256()), ("sha512", hashes.SHA512()))) +@pytest.mark.parametrize(("name", "algorithm"), (("sha256", hashes.SHA256()), ("sha512", hashes.SHA512()))) def test_get_fingerprint(name: str, algorithm: hashes.HashAlgorithm, usable_cert: Certificate) -> None: """Test getting the fingerprint value.""" cert_name = usable_cert.test_name # type: ignore[attr-defined] diff --git a/ca/django_ca/tests/models/test_certificate_authority.py b/ca/django_ca/tests/models/test_certificate_authority.py index a2748e6a1..f5e9261be 100644 --- a/ca/django_ca/tests/models/test_certificate_authority.py +++ b/ca/django_ca/tests/models/test_certificate_authority.py @@ -450,7 +450,7 @@ def test_serial(usable_ca: CertificateAuthority) -> None: assert usable_ca.serial == CERT_DATA[usable_ca.name].get("serial") -@pytest.mark.parametrize("name,algorithm", (("sha256", hashes.SHA256()), ("sha512", hashes.SHA512()))) +@pytest.mark.parametrize(("name", "algorithm"), (("sha256", hashes.SHA256()), ("sha512", hashes.SHA512()))) def test_get_fingerprint(name: str, algorithm: hashes.HashAlgorithm, usable_ca: CertificateAuthority) -> None: """Test getting the fingerprint value.""" assert usable_ca.get_fingerprint(algorithm) == CERT_DATA[usable_ca.name][name] diff --git a/ca/django_ca/tests/pydantic/base.py b/ca/django_ca/tests/pydantic/base.py index 21c7e9b1c..5c9ff1a2c 100644 --- a/ca/django_ca/tests/pydantic/base.py +++ b/ca/django_ca/tests/pydantic/base.py @@ -61,7 +61,7 @@ def assert_validation_errors( expected_errors: ExpectedErrors, ) -> None: """Assertion method to test validation errors.""" - with pytest.raises(ValidationError) as ex_info: + with pytest.raises(ValidationError) as ex_info: # noqa: PT012 if isinstance(parameters, list): model_class(parameters) # type: ignore[call-arg] # ruled out with overload else: diff --git a/ca/django_ca/tests/pydantic/test_extensions.py b/ca/django_ca/tests/pydantic/test_extensions.py index eca8ac096..75b263862 100644 --- a/ca/django_ca/tests/pydantic/test_extensions.py +++ b/ca/django_ca/tests/pydantic/test_extensions.py @@ -183,7 +183,7 @@ def test_critical_validation() -> None: @pytest.mark.parametrize( - "parameters,expected", + ("parameters", "expected"), ( ( { @@ -212,7 +212,7 @@ def test_access_description_model(parameters: dict[str, Any], expected: x509.Acc @pytest.mark.parametrize( - "parameters,expected", + ("parameters", "expected"), ( ( {"full_name": [GENERAL_NAME]}, @@ -252,7 +252,7 @@ def test_distribution_point(parameters: dict[str, Any], expected: x509.Distribut @pytest.mark.parametrize( - "parameters,expected_errors", + ("parameters", "expected_errors"), ( ( { @@ -279,7 +279,7 @@ def test_distribution_point_errors(parameters: dict[str, Any], expected_errors: @pytest.mark.parametrize( - "parameters,expected", + ("parameters", "expected"), ( ( {"full_name": [GENERAL_NAME]}, @@ -359,9 +359,9 @@ def test_signed_certificate_timestamp(signed_certificate_timestamp_pub: x509.Cer @pytest.mark.parametrize("critical", (True, False, None)) -@pytest.mark.parametrize("general_names,parsed_general_names", (([GENERAL_NAME], [dns("example.com")]),)) +@pytest.mark.parametrize(("general_names", "parsed_general_names"), (([GENERAL_NAME], [dns("example.com")]),)) @pytest.mark.parametrize( - "model,extension_type", + ("model", "extension_type"), ( (SubjectAlternativeNameModel, x509.SubjectAlternativeName), (IssuerAlternativeNameModel, x509.IssuerAlternativeName), @@ -381,7 +381,7 @@ def test_alternative_name_extensions( @pytest.mark.parametrize("model", (SubjectAlternativeNameModel, IssuerAlternativeNameModel)) @pytest.mark.parametrize( - "parameters,expected_errors", + ("parameters", "expected_errors"), ( ([], [("value_error", ("value",), "Value error, value must not be empty")]), ([GENERAL_NAME] * 2, [("value_error", ("value",), re.compile("value must be unique$"))]), @@ -398,7 +398,7 @@ def test_alternative_name_extensions_errors( @pytest.mark.parametrize("critical", (False, None)) @pytest.mark.parametrize( - "parameters,descriptions", + ("parameters", "descriptions"), ( ( [{"access_method": "ocsp", "access_location": GENERAL_NAME}], @@ -425,7 +425,7 @@ def test_authority_information_access( @pytest.mark.parametrize( - "parameters,expected_errors", + ("parameters", "expected_errors"), ( ({"value": []}, [("value_error", ("value",), "Value error, value must not be empty")]), ( @@ -461,7 +461,7 @@ def test_authority_information_access_errors( @pytest.mark.parametrize("critical", (False, None)) @pytest.mark.parametrize( - "parameters,extension", + ("parameters", "extension"), ( ( {"key_identifier": b"MTIz"}, @@ -503,7 +503,7 @@ def test_authority_key_identifier( @pytest.mark.parametrize( - "parameters,expected_errors", + ("parameters", "expected_errors"), ( ( { @@ -566,7 +566,7 @@ def test_authority_key_identifier_errors(parameters: dict[str, Any], expected_er @pytest.mark.parametrize("critical", (True, False, None)) @pytest.mark.parametrize( - "parameters,extension", + ("parameters", "extension"), ( ({"ca": False, "path_length": None}, x509.BasicConstraints(ca=False, path_length=None)), ({"ca": True, "path_length": None}, x509.BasicConstraints(ca=True, path_length=None)), @@ -587,7 +587,7 @@ def test_basic_constraints( @pytest.mark.parametrize( - "parameters,expected_errors", + ("parameters", "expected_errors"), ( ( {}, @@ -620,7 +620,7 @@ def test_basic_constraints_errors(parameters: dict[str, Any], expected_errors: E @pytest.mark.parametrize("critical", (True, False, None)) @pytest.mark.parametrize( - "parameters,policies", + ("parameters", "policies"), ( ( ( @@ -745,7 +745,7 @@ def test_certificate_policies( @pytest.mark.parametrize( - "parameters,expected_errors", + ("parameters", "expected_errors"), ( ( [], @@ -807,7 +807,7 @@ def test_certificate_policies_errors(parameters: dict[str, Any], expected_errors @pytest.mark.parametrize("critical", (True, False, None)) @pytest.mark.parametrize( - "parameters,distribution_points", + ("parameters", "distribution_points"), DISTRIBUTION_POINTS_PARAMETERS, ) def test_crl_distribution_points( @@ -823,7 +823,7 @@ def test_crl_distribution_points( @pytest.mark.parametrize("model", (CRLDistributionPointsModel, FreshestCRLModel)) @pytest.mark.parametrize( - "parameters,expected_errors", + ("parameters", "expected_errors"), ( ( [], @@ -851,14 +851,14 @@ def test_distribution_point_extension_errors( @pytest.mark.parametrize("critical", (False, None)) -@pytest.mark.parametrize("crl_number", [0, 1]) +@pytest.mark.parametrize("crl_number", (0, 1)) def test_crl_number(critical: Optional[bool], crl_number: int) -> None: """Test the CRLNumberModel.""" assert_extension_model(CRLNumberModel, crl_number, x509.CRLNumber(crl_number), critical) @pytest.mark.parametrize( - "parameters,expected_errors", + ("parameters", "expected_errors"), ( ({"value": -1}, [("greater_than_equal", ("value",), "Input should be greater than or equal to 0")]), ({"value": 0, "critical": True}, [MUST_BE_NON_CRITICAL_ERROR]), @@ -870,14 +870,14 @@ def test_crl_number_errors(parameters: dict[str, Any], expected_errors: Expected @pytest.mark.parametrize("critical", (True, None)) -@pytest.mark.parametrize("crl_number", [0, 1, 2]) +@pytest.mark.parametrize("crl_number", (0, 1, 2)) def test_delta_crl_indicator(critical: Optional[bool], crl_number: int) -> None: """Test the DeltaCRLModel.""" assert_extension_model(DeltaCRLIndicatorModel, crl_number, x509.DeltaCRLIndicator(crl_number), critical) @pytest.mark.parametrize( - "parameters,expected_errors", + ("parameters", "expected_errors"), ( ({"value": -1}, [("greater_than_equal", ("value",), "Input should be greater than or equal to 0")]), ({"value": 0, "critical": False}, [MUST_BE_CRITICAL_ERROR]), @@ -893,7 +893,7 @@ def test_delta_crl_indicator_errors( @pytest.mark.parametrize("critical", (True, False, None)) @pytest.mark.parametrize( - "usages,extension", + ("usages", "extension"), ( ( [ExtendedKeyUsageOID.CLIENT_AUTH.dotted_string, ExtendedKeyUsageOID.SERVER_AUTH.dotted_string], @@ -916,7 +916,7 @@ def test_extended_key_usage( @pytest.mark.parametrize( - "parameters,expected_errors", + ("parameters", "expected_errors"), ( ( [], @@ -954,7 +954,7 @@ def test_extended_key_usage_errors(parameters: dict[str, Any], expected_errors: @pytest.mark.parametrize("critical", (False, None)) -@pytest.mark.parametrize("parameters,distribution_points", DISTRIBUTION_POINTS_PARAMETERS) +@pytest.mark.parametrize(("parameters", "distribution_points"), DISTRIBUTION_POINTS_PARAMETERS) def test_freshest_crl( critical: Optional[bool], parameters: list[dict[str, Any]], @@ -977,14 +977,14 @@ def test_freshest_crl_critical_error() -> None: @pytest.mark.parametrize("critical", (True, None)) -@pytest.mark.parametrize("skip_certs", [0, 1]) +@pytest.mark.parametrize("skip_certs", (0, 1)) def test_inhibit_any_policy(critical: Optional[bool], skip_certs: int) -> None: """Test the InhibitAnyPolicyModel.""" assert_extension_model(InhibitAnyPolicyModel, skip_certs, x509.InhibitAnyPolicy(skip_certs), critical) @pytest.mark.parametrize( - "parameters,expected_errors", + ("parameters", "expected_errors"), ( ({"value": -1}, [("greater_than_equal", ("value",), "Input should be greater than or equal to 0")]), ({"value": 0, "critical": False}, [MUST_BE_CRITICAL_ERROR]), @@ -997,7 +997,7 @@ def test_inhibit_any_policy_errors(parameters: dict[str, Any], expected_errors: @pytest.mark.parametrize("critical", (True, None)) @pytest.mark.parametrize( - "parameters,issuing_distribution_point", + ("parameters", "issuing_distribution_point"), ( ( {"full_name": [GENERAL_NAME]}, @@ -1023,7 +1023,7 @@ def test_issuing_distribution_point( @pytest.mark.parametrize( - "parameters,expected_errors", + ("parameters", "expected_errors"), ( ({"value": {}}, [("value_error", ("value",), "Value error, cannot create empty extension")]), ( @@ -1068,7 +1068,7 @@ def test_issuing_distribution_point_errors( @pytest.mark.parametrize("critical", (True, False, None)) @pytest.mark.parametrize( - "parameters,extension", + ("parameters", "extension"), ( (["crl_sign"], key_usage(crl_sign=True).value), ( @@ -1083,7 +1083,7 @@ def test_key_usage(critical: Optional[bool], parameters: dict[str, bool], extens @pytest.mark.parametrize( - "parameters,expected_errors", + ("parameters", "expected_errors"), ( ( [], @@ -1128,7 +1128,7 @@ def test_key_usage_errors(parameters: dict[str, bool], expected_errors: Expected @pytest.mark.parametrize("critical", (True, False)) @pytest.mark.parametrize( - "parameters,extension", + ("parameters", "extension"), ( ( {"template_id": NameOID.COMMON_NAME.dotted_string}, @@ -1157,7 +1157,7 @@ def test_ms_certificate_template( @pytest.mark.parametrize("critical", (True, None)) @pytest.mark.parametrize( - "parameters,extension", + ("parameters", "extension"), ( ( {"permitted_subtrees": [GENERAL_NAME]}, @@ -1186,7 +1186,7 @@ def test_name_constraints( @pytest.mark.parametrize( - "parameters,expected_errors", + ("parameters", "expected_errors"), ( ( {"value": {}}, @@ -1238,7 +1238,8 @@ def test_name_constraints_errors(parameters: dict[str, bool], expected_errors: E @pytest.mark.parametrize("critical", (True, None)) @pytest.mark.parametrize( - "require_explicit_policy,inhibit_policy_mapping", ((0, 0), (1, 1), (0, 5), (5, 0), (None, 0), (0, None)) + ("require_explicit_policy", "inhibit_policy_mapping"), + ((0, 0), (1, 1), (0, 5), (5, 0), (None, 0), (0, None)), ) def test_policy_constraints( critical: Optional[bool], require_explicit_policy: int, inhibit_policy_mapping: int @@ -1255,7 +1256,7 @@ def test_policy_constraints( @pytest.mark.parametrize( - "parameters,expected_errors", + ("parameters", "expected_errors"), ( ( {"value": {"require_explicit_policy": None, "inhibit_policy_mapping": None}}, @@ -1333,7 +1334,7 @@ def test_signed_certificate_timestamps(signed_certificate_timestamps_pub: x509.C @pytest.mark.parametrize( - "parameters,extension", + ("parameters", "extension"), ( ( [{"access_method": "ca_repository", "access_location": GENERAL_NAME}], @@ -1352,7 +1353,7 @@ def test_subject_information_access( @pytest.mark.parametrize( - "parameters,expected_errors", + ("parameters", "expected_errors"), ( ({"value": []}, [("value_error", ("value",), "Value error, value must not be empty")]), ( @@ -1392,7 +1393,7 @@ def test_subject_information_access_errors( @pytest.mark.parametrize( - "digest,extension", + ("digest", "extension"), ( # (b"123", x509.SubjectKeyIdentifier(b"123")), (b"kA==", x509.SubjectKeyIdentifier(b"\x90")), @@ -1416,7 +1417,7 @@ def test_subject_key_identifier_errors() -> None: @pytest.mark.parametrize( - "parameters,features", + ("parameters", "features"), ( (["status_request"], [x509.TLSFeatureType.status_request]), (["OCSPMustStaple"], [x509.TLSFeatureType.status_request]), @@ -1449,7 +1450,7 @@ def test_tls_feature( @pytest.mark.parametrize( - "parameters,expected_errors", + ("parameters", "expected_errors"), ( ( {"value": []}, @@ -1471,7 +1472,7 @@ def test_tls_feature_errors(parameters: dict[str, bool], expected_errors: Expect @pytest.mark.parametrize( - "parameters,extension_type", + ("parameters", "extension_type"), ( ( {"value": b"MTIz", "oid": "1.2.3"}, diff --git a/ca/django_ca/tests/pydantic/test_general_name.py b/ca/django_ca/tests/pydantic/test_general_name.py index 0611e3aba..c285de8f5 100644 --- a/ca/django_ca/tests/pydantic/test_general_name.py +++ b/ca/django_ca/tests/pydantic/test_general_name.py @@ -41,7 +41,7 @@ def test_doctests() -> None: @pytest.mark.parametrize( - "typ,value,encoded", + ("typ", "value", "encoded"), ( ("UTF8", "example", b"\x0c\x07example"), ("UTF8String", "example", b"\x0c\x07example"), @@ -126,7 +126,7 @@ def test_other_name_octetstring_type_errors() -> None: @pytest.mark.parametrize( - "value,match", + ("value", "match"), ( (b"123", r"Value error, could not parse asn1 data: .*"), (b"\x03\x02\x04P", "3: Unknown otherName type found."), @@ -145,7 +145,7 @@ def test_othername_general_errors(value: bytes, match: str) -> None: @pytest.mark.parametrize( - "parameters,name,discriminated", + ("parameters", "name", "discriminated"), ( ({"type": "DNS", "value": "example.com"}, dns("example.com"), str), # 0 ({"type": "DNS", "value": "xn--exmple-cua.com"}, dns("xn--exmple-cua.com"), str), # 1 @@ -186,7 +186,7 @@ def test_general_name(parameters: dict[str, Any], name: x509.GeneralName, discri @pytest.mark.parametrize( - "typ,value,errors", + ("typ", "value", "errors"), ( ("URI", 123, [("string_type", ("value", "str"), "Input should be a valid string")]), ("email", 123, [("string_type", ("value", "str"), "Input should be a valid string")]), diff --git a/ca/django_ca/tests/pydantic/test_name.py b/ca/django_ca/tests/pydantic/test_name.py index caa700e39..c50db0102 100644 --- a/ca/django_ca/tests/pydantic/test_name.py +++ b/ca/django_ca/tests/pydantic/test_name.py @@ -33,7 +33,7 @@ def test_doctests() -> None: @pytest.mark.parametrize( - "parameters,name_attr", + ("parameters", "name_attr"), ( ( {"oid": NameOID.COMMON_NAME.dotted_string, "value": "example.com"}, @@ -81,7 +81,7 @@ def test_name_attribute(parameters: dict[str, Any], name_attr: x509.NameAttribut @pytest.mark.parametrize( - "parameters,errors", + ("parameters", "errors"), ( ( {"oid": "foo", "value": "example.com"}, @@ -123,7 +123,7 @@ def test_name_attribute_empty_common_name(oid: Any) -> None: @pytest.mark.parametrize( - "serialized,expected", + ("serialized", "expected"), ( ([], x509.Name([])), ( @@ -150,7 +150,7 @@ def test_name(serialized: list[dict[str, Any]], expected: list[x509.NameAttribut @pytest.mark.parametrize( - "value,errors", + ("value", "errors"), ( ( [ diff --git a/ca/django_ca/tests/pydantic/test_type_aliases.py b/ca/django_ca/tests/pydantic/test_type_aliases.py index 3a5813a9e..f733ac21c 100644 --- a/ca/django_ca/tests/pydantic/test_type_aliases.py +++ b/ca/django_ca/tests/pydantic/test_type_aliases.py @@ -53,7 +53,7 @@ class SerialModel(BaseModel): value: Serial -@pytest.mark.parametrize("name,curve_cls", constants.ELLIPTIC_CURVE_TYPES.items()) +@pytest.mark.parametrize(("name", "curve_cls"), constants.ELLIPTIC_CURVE_TYPES.items()) def test_elliptic_curve(name: str, curve_cls: type[ec.EllipticCurve]) -> None: """Test EllipticCurveTypeAliasModel.""" model = EllipticCurveTypeAliasModel(value=name) @@ -79,11 +79,11 @@ def test_elliptic_curve(name: str, curve_cls: type[ec.EllipticCurve]) -> None: @pytest.mark.parametrize("value", ("", "wrong", True, 42, ec.SECP224R1)) def test_elliptic_curve_errors(value: str) -> None: """Test invalid values for EllipticCurveTypeAliasModel.""" - with pytest.raises(ValueError): + with pytest.raises(ValueError): # noqa: PT011 # pydantic controls the message EllipticCurveTypeAliasModel(value=value) -@pytest.mark.parametrize("name,hash_cls", constants.HASH_ALGORITHM_TYPES.items()) +@pytest.mark.parametrize(("name", "hash_cls"), constants.HASH_ALGORITHM_TYPES.items()) def test_hash_algorithm(name: str, hash_cls: type[hashes.HashAlgorithm]) -> None: """Test EllipticCurveTypeAliasModel.""" model = HashAlgorithmTypeAliasModel(value=name) @@ -110,12 +110,12 @@ def test_hash_algorithm(name: str, hash_cls: type[hashes.HashAlgorithm]) -> None @pytest.mark.parametrize("hash_obj", (hashes.SM3(), hashes.BLAKE2b(64), hashes.BLAKE2s(32))) def test_hash_algorithm_unsupported_types(hash_obj: hashes.HashAlgorithm) -> None: """Test that unsupported hash algorithm instances throw an error.""" - with pytest.raises(ValueError): + with pytest.raises(ValueError): # noqa: PT011 # pydantic controls the message HashAlgorithmTypeAliasModel(value=hash_obj) @pytest.mark.parametrize( - "value,encoded", + ("value", "encoded"), ( (b"\xb5\xee\x0e\x01\x10U", "te4OARBV"), (b"\xb5\xee\x0e\x01\x10U\xaa", "te4OARBVqg=="), @@ -140,7 +140,7 @@ def test_json_serializable_bytes(value: bytes, encoded: str) -> None: @pytest.mark.parametrize( - "value,validated", + ("value", "validated"), ( ("a", "A"), ("abc", "ABC"), @@ -176,5 +176,5 @@ def test_serial(value: str, validated: str) -> None: ) def test_serial_errors(value: str) -> None: """Test invalid values for the Serial type alias.""" - with pytest.raises(ValueError): + with pytest.raises(ValueError): # noqa: PT011 # pydantic controls the message SerialModel(value=value) diff --git a/ca/django_ca/tests/pydantic/test_validators.py b/ca/django_ca/tests/pydantic/test_validators.py index 96ed2335c..653d1ee74 100644 --- a/ca/django_ca/tests/pydantic/test_validators.py +++ b/ca/django_ca/tests/pydantic/test_validators.py @@ -26,8 +26,8 @@ def test_doctests() -> None: @pytest.mark.parametrize( - "name,validated", - [ + ("name", "validated"), + ( ("example.com", "example.com"), ("er.tl", "er.tl"), ("exämple.com", "xn--exmple-cua.com"), @@ -38,7 +38,7 @@ def test_doctests() -> None: # Examples from Wikipedia: ("ουτοπία.δπθ.gr", "xn--kxae4bafwg.xn--pxaix.gr"), ("bücher.example", "xn--bcher-kva.example"), - ], + ), ) def test_dns_validator(name: str, validated: str) -> None: """Test :py:func:`django_ca.pydantic.validators.dns_validator`.""" @@ -46,8 +46,8 @@ def test_dns_validator(name: str, validated: str) -> None: @pytest.mark.parametrize( - "name,error", - [("example com", "^Invalid domain: example com:"), ("@example.com", r"^Invalid domain: @example.com:")], + ("name", "error"), + (("example com", "^Invalid domain: example com:"), ("@example.com", r"^Invalid domain: @example.com:")), ) def test_dns_validator_errors(name: str, error: str) -> None: """Test errors for :py:func:`django_ca.pydantic.validators.dns_validator`.""" @@ -56,8 +56,8 @@ def test_dns_validator_errors(name: str, error: str) -> None: @pytest.mark.parametrize( - "email,validated", - [("user@example.com", "user@example.com"), ("user@exämple.com", "user@xn--exmple-cua.com")], + ("email", "validated"), + (("user@example.com", "user@example.com"), ("user@exämple.com", "user@xn--exmple-cua.com")), ) def test_email_validator(email: str, validated: str) -> None: """Test :py:func:`django_ca.pydantic.validators.email_validator`.""" @@ -65,13 +65,13 @@ def test_email_validator(email: str, validated: str) -> None: @pytest.mark.parametrize( - "email,error", - [ + ("email", "error"), + ( ("user@example com", "^Invalid domain: example com"), ("user", "^Invalid email address: user$"), ("example.com", r"^Invalid email address: example\.com$"), ("@example.com", r"^@example.com: node part is empty$"), - ], + ), ) def test_email_validator_errors(email: str, error: str) -> None: """Test errors for :py:func:`django_ca.pydantic.validators.email_validator`.""" @@ -80,8 +80,8 @@ def test_email_validator_errors(email: str, error: str) -> None: @pytest.mark.parametrize( - "url,validated", - [ + ("url", "validated"), + ( ("http://example.com", "http://example.com"), ("http://exämple.com", "http://xn--exmple-cua.com"), ("https://www.example.net", "https://www.example.net"), @@ -91,7 +91,7 @@ def test_email_validator_errors(email: str, error: str) -> None: ("https://www.exämple.net:443", "https://www.xn--exmple-cua.net:443"), ("https://www.example.net:443/", "https://www.example.net:443/"), ("https://www.example.net:443/test", "https://www.example.net:443/test"), - ], + ), ) def test_url_validator(url: str, validated: str) -> None: """Test py:func:`django_ca.pydantic.validators.url_validator`.""" @@ -99,15 +99,15 @@ def test_url_validator(url: str, validated: str) -> None: @pytest.mark.parametrize( - "url,error", - [ + ("url", "error"), + ( ("https://example com", "^Invalid domain: example com: "), ("https://example com:80", "^Invalid domain: example com: "), ("example.com", r"^URL requires scheme and network location: example\.com$"), ("https://[abc", r"^Could not parse URL: https://\[abc: "), # urlsplit() raises an error for this ("https://example.com:abc", r"^Invalid port: https://example\.com:abc: "), # reading port... ("https://example.com:-1", r"^Invalid port: https://example\.com:-1: "), - ], + ), ) def test_url_validator_errors(url: str, error: str) -> None: """Test errors for :py:func:`django_ca.pydantic.validators.url_validator`.""" diff --git a/ca/django_ca/tests/test_acme.py b/ca/django_ca/tests/test_acme.py index 3181198c9..a342c1da8 100644 --- a/ca/django_ca/tests/test_acme.py +++ b/ca/django_ca/tests/test_acme.py @@ -35,6 +35,7 @@ from django_ca.acme import validation from django_ca.acme.constants import IdentifierType, Status from django_ca.models import AcmeAccount, AcmeAuthorization, AcmeChallenge, AcmeOrder +from django_ca.tests.base.assertions import assert_count_equal from django_ca.tests.base.mixins import TestCaseMixin urlpatterns = [ @@ -61,12 +62,12 @@ class TestConstantsTestCase(TestCase): def test_status_enum(self) -> None: """Test that the Status Enum is equivalent to the main ACME library.""" expected = [*acme.messages.Status.POSSIBLE_NAMES, "expired"] - self.assertCountEqual(expected, [s.value for s in Status]) + assert_count_equal(expected, [s.value for s in Status]) def test_identifier_enum(self) -> None: """Test that the IdentifierType Enum is equivalent to the main ACME library.""" actual = list(acme.messages.IdentifierType.POSSIBLE_NAMES) - self.assertCountEqual(actual, [s.value for s in IdentifierType]) + assert_count_equal(actual, [s.value for s in IdentifierType]) class Dns01ValidationTestCase(TestCaseMixin, TestCase): @@ -98,7 +99,7 @@ def assertLogMessages( # pylint: disable=invalid-name # unittest standard if challenge is None: challenge = self.chall - self.assertEqual(logcm.output, [self.get_log_message(challenge), *messages]) + assert logcm.output == [self.get_log_message(challenge), *messages] def get_log_message(self, chall: AcmeChallenge) -> str: """Get the default log message for DNS-01 validation.""" @@ -121,7 +122,7 @@ def mock_response(self, domain: str, *responses: Iterable[bytes]) -> Iterator[mo # Note: Only assert the first two parameters, as otherwise we'd test dnspython internals resolve_mock.assert_called_once() expected = (f"_acme_challenge.{domain}", "TXT") - self.assertEqual(resolve_mock.call_args_list[0].args[:2], expected) + assert resolve_mock.call_args_list[0].args[:2] == expected @contextmanager def resolve(self, side_effect: Any) -> Iterator[mock.Mock]: @@ -136,16 +137,16 @@ def to_txt_record(self, values: Iterable[bytes]) -> TXTBase: def test_validation(self) -> None: """Test successful DNS-01 validation.""" with self.mock_response(self.domain, [self.chall.expected]), self.assertLogMessages(): - self.assertTrue(validation.validate_dns_01(self.chall)) + assert validation.validate_dns_01(self.chall) with self.mock_response(self.domain, [self.chall.expected, b"foo"]), self.assertLogMessages(): - self.assertTrue(validation.validate_dns_01(self.chall)) + assert validation.validate_dns_01(self.chall) with self.mock_response(self.domain, [b"data"], [self.chall.expected]), self.assertLogMessages(): - self.assertTrue(validation.validate_dns_01(self.chall)) + assert validation.validate_dns_01(self.chall) with ( self.mock_response(self.domain, [b"data"], [b"multiple", self.chall.expected]), self.assertLogMessages(), ): - self.assertTrue(validation.validate_dns_01(self.chall)) + assert validation.validate_dns_01(self.chall) def test_precomputed(self) -> None: """Runa test with pre-computed values to test basic behavior.""" @@ -161,33 +162,33 @@ def test_precomputed(self) -> None: expected = chall.expected with self.mock_response(self.domain, [chall.expected]), self.assertLogMessages(challenge=chall): - self.assertTrue(validation.validate_dns_01(chall)) + assert validation.validate_dns_01(chall) with self.mock_response(self.domain, [expected, b"foo"]), self.assertLogMessages(challenge=chall): - self.assertTrue(validation.validate_dns_01(chall)) + assert validation.validate_dns_01(chall) with self.mock_response(self.domain, [b"data"], [expected]), self.assertLogMessages(challenge=chall): - self.assertTrue(validation.validate_dns_01(chall)) + assert validation.validate_dns_01(chall) with ( self.mock_response(self.domain, [b"data"], [b"foo", expected]), self.assertLogMessages(challenge=chall), ): - self.assertTrue(validation.validate_dns_01(chall)) + assert validation.validate_dns_01(chall) def test_wrong_txt_response(self) -> None: """Test failing a challenge via the wrong DNS response.""" with self.mock_response(self.domain, [b"foo"]), self.assertLogMessages(): - self.assertFalse(validation.validate_dns_01(self.chall)) + assert not validation.validate_dns_01(self.chall) with self.mock_response(self.domain, [b"foo"], [b"bar"]), self.assertLogMessages(): - self.assertFalse(validation.validate_dns_01(self.chall)) + assert not validation.validate_dns_01(self.chall) with self.mock_response(self.domain, [b"foo", b"bar"], [b"bar"]), self.assertLogMessages(): - self.assertFalse(validation.validate_dns_01(self.chall)) + assert not validation.validate_dns_01(self.chall) def test_dns_exception(self) -> None: """Mock resolver throwing a DNS exception.""" with self.resolve(side_effect=dns.exception.DNSException) as resolve, self.assertLogs() as logcm: - self.assertFalse(validation.validate_dns_01(self.chall)) + assert not validation.validate_dns_01(self.chall) resolve.assert_called_once_with(f"_acme_challenge.{self.domain}", "TXT", lifetime=1, search=False) - self.assertEqual(len(logcm.output), 2) - self.assertIn("dns.exception.DNSException", logcm.output[1]) + assert len(logcm.output) == 2 + assert "dns.exception.DNSException" in logcm.output[1] def test_nxdomain(self) -> None: """Test validating a domain where the record simply does not exist.""" @@ -197,12 +198,12 @@ def test_nxdomain(self) -> None: f"DEBUG:django_ca.acme.validation:TXT _acme_challenge.{self.domain}: record does not exist." ), ): - self.assertFalse(validation.validate_dns_01(self.chall)) + assert not validation.validate_dns_01(self.chall) resolve.assert_called_once_with(f"_acme_challenge.{self.domain}", "TXT", lifetime=1, search=False) def test_wrong_acme_challenge(self) -> None: """Test passing an ACME challenge of the wrong type.""" - with self.assertRaisesRegex(ValueError, r"^This function can only validate DNS-01 challenges$"): + with pytest.raises(ValueError, match=r"^This function can only validate DNS-01 challenges$"): validation.validate_dns_01(AcmeChallenge(type=AcmeChallenge.TYPE_HTTP_01)) - with self.assertRaisesRegex(ValueError, r"^This function can only validate DNS-01 challenges$"): + with pytest.raises(ValueError, match=r"^This function can only validate DNS-01 challenges$"): validation.validate_dns_01(AcmeChallenge(type=AcmeChallenge.TYPE_TLS_ALPN_01)) diff --git a/ca/django_ca/tests/test_base.py b/ca/django_ca/tests/test_base.py index a2a31145d..63f1292d8 100644 --- a/ca/django_ca/tests/test_base.py +++ b/ca/django_ca/tests/test_base.py @@ -23,6 +23,8 @@ from django.conf import settings from django.test import TestCase +import pytest + from django_ca.tests.base.assertions import assert_extensions from django_ca.tests.base.mixins import TestCaseMixin from django_ca.tests.base.utils import cmd, cmd_e2e, override_tmpcadir @@ -39,7 +41,7 @@ def test_pragmas(self) -> None: @override_tmpcadir() def test_override_tmpcadir(self) -> None: """Test override_tmpcadir as decorator.""" - self.assertTrue(settings.CA_DIR.startswith(tempfile.gettempdir())) + assert settings.CA_DIR.startswith(tempfile.gettempdir()) @override_tmpcadir() def test_assert_extensions(self) -> None: @@ -92,25 +94,25 @@ class OverrideCaDirForFuncTestCase(TestCaseMixin, TestCase): @override_tmpcadir() def test_a(self) -> None: # add three tests to make sure that every test case sees a different dir - self.assertTrue(settings.CA_DIR.startswith(tempfile.gettempdir())) - self.assertNotIn(settings.CA_DIR, self.seen_dirs) + assert settings.CA_DIR.startswith(tempfile.gettempdir()) + assert settings.CA_DIR not in self.seen_dirs self.seen_dirs.add(settings.CA_DIR) @override_tmpcadir() def test_b(self) -> None: - self.assertTrue(settings.CA_DIR.startswith(tempfile.gettempdir())) - self.assertNotIn(settings.CA_DIR, self.seen_dirs) + assert settings.CA_DIR.startswith(tempfile.gettempdir()) + assert settings.CA_DIR not in self.seen_dirs self.seen_dirs.add(settings.CA_DIR) @override_tmpcadir() def test_c(self) -> None: - self.assertTrue(settings.CA_DIR.startswith(tempfile.gettempdir())) - self.assertNotIn(settings.CA_DIR, self.seen_dirs) + assert settings.CA_DIR.startswith(tempfile.gettempdir()) + assert settings.CA_DIR not in self.seen_dirs self.seen_dirs.add(settings.CA_DIR) def test_no_classes(self) -> None: msg = r"^Only functions can use override_tmpcadir\(\)$" - with self.assertRaisesRegex(ValueError, msg): + with pytest.raises(ValueError, match=msg): @override_tmpcadir() class Foo: # pylint: disable=missing-class-docstring,unused-variable @@ -126,8 +128,8 @@ def test_basic(self) -> None: """Trivial basic test.""" stdout, stderr = cmd_e2e(["list_cas"]) serial = add_colons(self.ca.serial) - self.assertEqual(stdout, f"{serial} - {self.ca.name}\n") - self.assertEqual(stderr, "") + assert stdout == f"{serial} - {self.ca.name}\n" + assert stderr == "" class TypingTestCase(TestCaseMixin): # never executed as it's not actually a subclass of TestCase diff --git a/ca/django_ca/tests/test_checks.py b/ca/django_ca/tests/test_checks.py index 5d22df5c8..d9b7e222a 100644 --- a/ca/django_ca/tests/test_checks.py +++ b/ca/django_ca/tests/test_checks.py @@ -34,11 +34,11 @@ def test_no_cache(self) -> None: ) with self.settings(CACHES={}): errors = check_cache([app_config]) - self.assertEqual(errors, [expected]) + assert errors == [expected] with self.settings(CACHES={}): errors = check_cache(None) - self.assertEqual(errors, [expected]) + assert errors == [expected] def test_loc_mem_cache(self) -> None: """Test what happens if LocMemCache is used.""" @@ -55,16 +55,16 @@ def test_loc_mem_cache(self) -> None: } with self.settings(CACHES=setting): errors = check_cache([app_config]) - self.assertEqual(errors, [expected]) + assert errors == [expected] with self.settings(CACHES=setting): errors = check_cache(None) - self.assertEqual(errors, [expected]) + assert errors == [expected] def test_django_ca_not_checked(self) -> None: """Test that no checks are run if django_ca is not checked.""" app_config = apps.get_app_config("auth") errors = check_cache([app_config]) - self.assertEqual(errors, []) + assert not errors def test_redis_cache(self) -> None: """Test if redis cache backend is used.""" @@ -76,4 +76,4 @@ def test_redis_cache(self) -> None: } with self.settings(CACHES=setting): errors = check_cache([app_config]) - self.assertEqual(errors, []) + assert not errors diff --git a/ca/django_ca/tests/test_fields.py b/ca/django_ca/tests/test_fields.py index f095d63d3..c0bfbe917 100644 --- a/ca/django_ca/tests/test_fields.py +++ b/ca/django_ca/tests/test_fields.py @@ -11,11 +11,11 @@ # You should have received a copy of the GNU General Public License along with django-ca. If not, see # . -# TYPEHINT NOTE: mypy-django typehints assertFieldOutput complete wrong. -# type: ignore - """Test custom Django form fields.""" +# TYPEHINT NOTE: mypy-django typehints assertFieldOutput completely wrong. +# mypy: ignore-errors + import html import json from typing import Any @@ -69,21 +69,21 @@ def assertRequiredError(self, value) -> None: # pylint: disable=invalid-name field = self.field_class(required=True) error_required = [field.error_messages["required"]] - with self.assertRaises(ValidationError) as context_manager: + with pytest.raises(ValidationError) as context_manager: field.clean(value) - self.assertEqual(context_manager.exception.messages, error_required) + assert context_manager.exception.messages == error_required @pytest.mark.parametrize("critical", (True, False)) @pytest.mark.parametrize("required", (True, False)) @pytest.mark.parametrize( - "field_class,extension_type", + ("field_class", "extension_type"), ( (fields.IssuerAlternativeNameField, x509.IssuerAlternativeName), (fields.SubjectAlternativeNameField, x509.SubjectAlternativeName), ), ) -@pytest.mark.parametrize("value,general_names", (([SER_D1], [DNS1]), ([SER_D1, SER_D2], [DNS1, DNS2]))) +@pytest.mark.parametrize(("value", "general_names"), (([SER_D1], [DNS1]), ([SER_D1, SER_D2], [DNS1, DNS2]))) def test_alternative_name_fields( critical: bool, required: bool, @@ -101,14 +101,14 @@ def test_alternative_name_fields( @pytest.mark.parametrize("critical", (True, False)) @pytest.mark.parametrize("required", (True, False)) @pytest.mark.parametrize( - "field_class,extension_type", + ("field_class", "extension_type"), ( (fields.CRLDistributionPointField, x509.CRLDistributionPoints), (fields.FreshestCRLField, x509.FreshestCRL), ), ) @pytest.mark.parametrize( - "value,dpoint", + ("value", "dpoint"), ( (([SER_D1], "", [], ()), distribution_point([DNS1])), (([SER_D1, SER_D2], "", [], ()), (distribution_point([DNS1, DNS2]))), @@ -169,7 +169,7 @@ def test_distribution_point_fields( @pytest.mark.parametrize("critical", (True, False)) @pytest.mark.parametrize("required", (True, False)) @pytest.mark.parametrize( - "invalid,error", + ("invalid", "error"), ( (([SER_D1], f"CN={D1}", [], ()), r"You cannot provide both full_name and relative_name\."), ( @@ -210,7 +210,7 @@ def test_crl_distribution_points_field_with_empty_input( # Test how the field is rendered name = "field-name" - raw_html = field.widget.render(name, None) + raw_html = field.widget.render(name=name, value=None) assertInHTML( f'', raw_html ) @@ -228,8 +228,8 @@ def test_crl_distribution_points_field_rendering() -> None: field = fields.CRLDistributionPointField() reasons = frozenset([x509.ReasonFlags.key_compromise, x509.ReasonFlags.certificate_hold]) raw_html = field.widget.render( - name, - crl_distribution_points(distribution_point([DNS1], crl_issuer=[DNS2], reasons=reasons)), + name=name, + value=crl_distribution_points(distribution_point([DNS1], crl_issuer=[DNS2], reasons=reasons)), ) full_name_value = html.escape(json.dumps([SER_D1])) @@ -298,7 +298,7 @@ def test_crl_distribution_points_field_rendering_with_multiple_dps() -> None: @pytest.mark.parametrize("critical", (True, False)) @pytest.mark.parametrize("required", (True, False)) @pytest.mark.parametrize( - "ser_ca_issuers,ser_ocsp,ca_issuers,ocsp", + ("ser_ca_issuers", "ser_ocsp", "ca_issuers", "ocsp"), ( ((SER_D1,), (), (DNS1,), ()), ((), (SER_D2,), (), (DNS2,)), @@ -324,7 +324,7 @@ def test_authority_information_access_field( @pytest.mark.parametrize("critical", (True, False)) # make sure that critical flag has no effect @pytest.mark.parametrize("required", (True, False)) @pytest.mark.parametrize( - "ser_ca_issuers,ser_ocsp", + ("ser_ca_issuers", "ser_ocsp"), (("", ""), ("[]", "[]"), (None, None)), ) def test_authority_information_access_field_with_empty_value( @@ -336,7 +336,7 @@ def test_authority_information_access_field_with_empty_value( @pytest.mark.parametrize( - "ser_ca_issuers,ser_ocsp,error", + ("ser_ca_issuers", "ser_ocsp", "error"), ( (({"type": "DNS", "value": "http://example.com"},), (), ""), (({"type": "IP", "value": "example.com"},), (), "example.com: Could not parse IP address"), @@ -397,7 +397,7 @@ def test_rendering(self) -> None: name = "field-name" field = self.field_class() - raw_html = field.widget.render(name, None) + raw_html = field.widget.render(name=name, value=None) for choice, text in self.field_class.choices: self.assertInHTML(f'', raw_html) @@ -412,7 +412,7 @@ def test_rendering_profiles(self) -> None: choices = [key_usage_choices[choice] for choice in choices] ext = key_usage(**{choice: True for choice in choices}) - raw_html = field.widget.render("unused", ext) + raw_html = field.widget.render(name="unused", value=ext) for choice, text in self.field_class.choices: if choice in choices: diff --git a/ca/django_ca/tests/test_management_actions.py b/ca/django_ca/tests/test_management_actions.py index 68942c19c..130f4616e 100644 --- a/ca/django_ca/tests/test_management_actions.py +++ b/ca/django_ca/tests/test_management_actions.py @@ -108,12 +108,12 @@ def setUp(self) -> None: def assertValue(self, namespace: argparse.Namespace, value: Any) -> None: # pylint: disable=invalid-name """Assert a given extension value.""" extension = x509.Extension(oid=x509.SubjectAlternativeName.oid, critical=False, value=value) - self.assertEqual(namespace.alt, extension) + assert namespace.alt == extension def test_basic(self) -> None: """Test basic functionality.""" namespace = self.parser.parse_args([]) - self.assertEqual(namespace.alt, None) + assert namespace.alt is None namespace = self.parser.parse_args(["--alt", "example.com"]) self.assertValue(namespace, x509.SubjectAlternativeName([dns("example.com")])) @@ -138,15 +138,10 @@ def test_add_cps(self) -> None: oid = "1.2.3" cps = "http://example.com/cps" namespace = self.parser.parse_args(["--pi", oid, "--cps", cps]) - self.assertEqual( - namespace.pi, - x509.CertificatePolicies( - policies=[ - x509.PolicyInformation( - policy_identifier=x509.ObjectIdentifier(oid), policy_qualifiers=[cps] - ) - ] - ), + assert namespace.pi == x509.CertificatePolicies( + policies=[ + x509.PolicyInformation(policy_identifier=x509.ObjectIdentifier(oid), policy_qualifiers=[cps]) + ] ) def test_add_multiple_cps(self) -> None: @@ -155,15 +150,12 @@ def test_add_multiple_cps(self) -> None: cps1 = "http://example.com/cps1" cps2 = "http://example.com/cps2" namespace = self.parser.parse_args(["--pi", oid, "--cps", cps1, "--cps", cps2]) - self.assertEqual( - namespace.pi, - x509.CertificatePolicies( - policies=[ - x509.PolicyInformation( - policy_identifier=x509.ObjectIdentifier(oid), policy_qualifiers=[cps1, cps2] - ) - ] - ), + assert namespace.pi == x509.CertificatePolicies( + policies=[ + x509.PolicyInformation( + policy_identifier=x509.ObjectIdentifier(oid), policy_qualifiers=[cps1, cps2] + ) + ] ) def test_add_multiple_cps_to_different_policy_identifiers(self) -> None: @@ -173,18 +165,15 @@ def test_add_multiple_cps_to_different_policy_identifiers(self) -> None: cps1 = "http://example.com/cps1" cps2 = "http://example.com/cps2" namespace = self.parser.parse_args(["--pi", oid1, "--cps", cps1, "--pi", oid2, "--cps", cps2]) - self.assertEqual( - namespace.pi, - x509.CertificatePolicies( - policies=[ - x509.PolicyInformation( - policy_identifier=x509.ObjectIdentifier(oid1), policy_qualifiers=[cps1] - ), - x509.PolicyInformation( - policy_identifier=x509.ObjectIdentifier(oid2), policy_qualifiers=[cps2] - ), - ] - ), + assert namespace.pi == x509.CertificatePolicies( + policies=[ + x509.PolicyInformation( + policy_identifier=x509.ObjectIdentifier(oid1), policy_qualifiers=[cps1] + ), + x509.PolicyInformation( + policy_identifier=x509.ObjectIdentifier(oid2), policy_qualifiers=[cps2] + ), + ] ) def test_missing_policy_identifier(self) -> None: @@ -217,21 +206,21 @@ def setUp(self) -> None: def test_basic(self) -> None: """Test basic functionality of action.""" namespace = self.parser.parse_args([]) - self.assertIsNone(namespace.eku) + assert namespace.eku is None namespace = self.parser.parse_args(["--eku", "clientAuth"]) - self.assertEqual(x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]), namespace.eku) + assert x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]) == namespace.eku namespace = self.parser.parse_args(["--eku", "clientAuth", "serverAuth"]) - self.assertEqual( - x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH]), - namespace.eku, + assert ( + x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH]) + == namespace.eku ) def test_dotted_string_value(self) -> None: """Test passing a dotted string.""" namespace = self.parser.parse_args(["--eku", "1.3.6.1.5.5.7.3.2"]) - self.assertEqual(x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]), namespace.eku) + assert x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]) == namespace.eku def test_duplicate_values(self) -> None: """Test wrong option values.""" @@ -265,13 +254,10 @@ def test_policy_identifier(self) -> None: """Basic test for adding a policy identifier.""" oid = "1.2.3" namespace = self.parser.parse_args(["--pi", oid]) - self.assertEqual( - namespace.pi, - x509.CertificatePolicies( - policies=[ - x509.PolicyInformation(policy_identifier=x509.ObjectIdentifier(oid), policy_qualifiers=[]) - ] - ), + assert namespace.pi == x509.CertificatePolicies( + policies=[ + x509.PolicyInformation(policy_identifier=x509.ObjectIdentifier(oid), policy_qualifiers=[]) + ] ) def test_multiple_policy_identifiers(self) -> None: @@ -279,18 +265,11 @@ def test_multiple_policy_identifiers(self) -> None: oid1 = "1.2.3" oid2 = "1.2.4" namespace = self.parser.parse_args(["--pi", oid1, "--pi", oid2]) - self.assertEqual( - namespace.pi, - x509.CertificatePolicies( - policies=[ - x509.PolicyInformation( - policy_identifier=x509.ObjectIdentifier(oid1), policy_qualifiers=[] - ), - x509.PolicyInformation( - policy_identifier=x509.ObjectIdentifier(oid2), policy_qualifiers=[] - ), - ] - ), + assert namespace.pi == x509.CertificatePolicies( + policies=[ + x509.PolicyInformation(policy_identifier=x509.ObjectIdentifier(oid1), policy_qualifiers=[]), + x509.PolicyInformation(policy_identifier=x509.ObjectIdentifier(oid2), policy_qualifiers=[]), + ] ) def test_any_policy_value_disallowed(self) -> None: @@ -309,15 +288,12 @@ def test_any_policy_value(self) -> None: oid = "anyPolicy" namespace = parser.parse_args(["--pi", oid]) - self.assertEqual( - namespace.pi, - x509.CertificatePolicies( - policies=[ - x509.PolicyInformation( - policy_identifier=x509.ObjectIdentifier("2.5.29.32.0"), policy_qualifiers=[] - ) - ] - ), + assert namespace.pi == x509.CertificatePolicies( + policies=[ + x509.PolicyInformation( + policy_identifier=x509.ObjectIdentifier("2.5.29.32.0"), policy_qualifiers=[] + ) + ] ) def test_invalid_dotted_string(self) -> None: @@ -337,16 +313,16 @@ def test_no_min_no_max(self) -> None: """Test action with no min/max values.""" parser = argparse.ArgumentParser() parser.add_argument("--value", action=actions.IntegerRangeAction) - self.assertEqual(parser.parse_args(["--value=0"]).value, 0) - self.assertEqual(parser.parse_args(["--value=1"]).value, 1) - self.assertEqual(parser.parse_args(["--value=-1"]).value, -1) + assert parser.parse_args(["--value=0"]).value == 0 + assert parser.parse_args(["--value=1"]).value == 1 + assert parser.parse_args(["--value=-1"]).value == -1 def test_min_values(self) -> None: """Test the min value for the action.""" self.parser = argparse.ArgumentParser() self.parser.add_argument("--value", action=actions.IntegerRangeAction, min=0) - self.assertEqual(self.parser.parse_args(["--value=0"]).value, 0) - self.assertEqual(self.parser.parse_args(["--value=1"]).value, 1) + assert self.parser.parse_args(["--value=0"]).value == 0 + assert self.parser.parse_args(["--value=1"]).value == 1 assert_parser_error( self.parser, ["--value=-1"], @@ -358,8 +334,8 @@ def test_max_values(self) -> None: """Test the max value for the action.""" self.parser = argparse.ArgumentParser() self.parser.add_argument("--value", action=actions.IntegerRangeAction, max=0) - self.assertEqual(self.parser.parse_args(["--value=0"]).value, 0) - self.assertEqual(self.parser.parse_args(["--value=-1"]).value, -1) + assert self.parser.parse_args(["--value=0"]).value == 0 + assert self.parser.parse_args(["--value=-1"]).value == -1 assert_parser_error( self.parser, ["--value=1"], @@ -379,12 +355,10 @@ def setUp(self) -> None: def test_basic(self) -> None: """Test basic functionality of action.""" namespace = self.parser.parse_args(["--key-usage", "keyCertSign"]) - self.assertEqual(key_usage(key_cert_sign=True, critical=False).value, namespace.key_usage) + assert key_usage(key_cert_sign=True, critical=False).value == namespace.key_usage namespace = self.parser.parse_args(["--key-usage", "keyCertSign", "keyAgreement"]) - self.assertEqual( - key_usage(key_cert_sign=True, key_agreement=True, critical=False).value, namespace.key_usage - ) + assert key_usage(key_cert_sign=True, key_agreement=True, critical=False).value == namespace.key_usage def test_invalid_values(self) -> None: """Test passing invalid values.""" @@ -448,15 +422,15 @@ def setUp(self) -> None: def test_basic(self) -> None: """Test basic functionality of action.""" namespace = self.parser.parse_args(["--tls-feature", "status_request"]) - self.assertEqual(x509.TLSFeature([x509.TLSFeatureType.status_request]), namespace.tls_feature) + assert x509.TLSFeature([x509.TLSFeatureType.status_request]) == namespace.tls_feature namespace = self.parser.parse_args(["--tls-feature", "status_request_v2"]) - self.assertEqual(x509.TLSFeature([x509.TLSFeatureType.status_request_v2]), namespace.tls_feature) + assert x509.TLSFeature([x509.TLSFeatureType.status_request_v2]) == namespace.tls_feature namespace = self.parser.parse_args(["--tls-feature", "status_request", "status_request_v2"]) - self.assertEqual( - x509.TLSFeature([x509.TLSFeatureType.status_request, x509.TLSFeatureType.status_request_v2]), - namespace.tls_feature, + assert ( + x509.TLSFeature([x509.TLSFeatureType.status_request, x509.TLSFeatureType.status_request_v2]) + == namespace.tls_feature ) def test_error(self) -> None: @@ -483,16 +457,13 @@ def test_add_notice(self) -> None: oid = "1.2.3" notice = "notice text" namespace = self.parser.parse_args(["--pi", oid, "--notice", notice]) - self.assertEqual( - namespace.pi, - x509.CertificatePolicies( - policies=[ - x509.PolicyInformation( - policy_identifier=x509.ObjectIdentifier(oid), - policy_qualifiers=[x509.UserNotice(notice_reference=None, explicit_text=notice)], - ) - ] - ), + assert namespace.pi == x509.CertificatePolicies( + policies=[ + x509.PolicyInformation( + policy_identifier=x509.ObjectIdentifier(oid), + policy_qualifiers=[x509.UserNotice(notice_reference=None, explicit_text=notice)], + ) + ] ) def test_add_multiple_notices(self) -> None: @@ -501,19 +472,16 @@ def test_add_multiple_notices(self) -> None: notice1 = "notice text one" notice2 = "notice text two" namespace = self.parser.parse_args(["--pi", oid, "--notice", notice1, "--notice", notice2]) - self.assertEqual( - namespace.pi, - x509.CertificatePolicies( - policies=[ - x509.PolicyInformation( - policy_identifier=x509.ObjectIdentifier(oid), - policy_qualifiers=[ - x509.UserNotice(notice_reference=None, explicit_text=notice1), - x509.UserNotice(notice_reference=None, explicit_text=notice2), - ], - ) - ] - ), + assert namespace.pi == x509.CertificatePolicies( + policies=[ + x509.PolicyInformation( + policy_identifier=x509.ObjectIdentifier(oid), + policy_qualifiers=[ + x509.UserNotice(notice_reference=None, explicit_text=notice1), + x509.UserNotice(notice_reference=None, explicit_text=notice2), + ], + ) + ] ) def test_add_multiple_cps_to_different_policy_identifiers(self) -> None: @@ -525,20 +493,17 @@ def test_add_multiple_cps_to_different_policy_identifiers(self) -> None: namespace = self.parser.parse_args( ["--pi", oid1, "--notice", notice1, "--pi", oid2, "--notice", notice2] ) - self.assertEqual( - namespace.pi, - x509.CertificatePolicies( - policies=[ - x509.PolicyInformation( - policy_identifier=x509.ObjectIdentifier(oid1), - policy_qualifiers=[x509.UserNotice(notice_reference=None, explicit_text=notice1)], - ), - x509.PolicyInformation( - policy_identifier=x509.ObjectIdentifier(oid2), - policy_qualifiers=[x509.UserNotice(notice_reference=None, explicit_text=notice2)], - ), - ] - ), + assert namespace.pi == x509.CertificatePolicies( + policies=[ + x509.PolicyInformation( + policy_identifier=x509.ObjectIdentifier(oid1), + policy_qualifiers=[x509.UserNotice(notice_reference=None, explicit_text=notice1)], + ), + x509.PolicyInformation( + policy_identifier=x509.ObjectIdentifier(oid2), + policy_qualifiers=[x509.UserNotice(notice_reference=None, explicit_text=notice2)], + ), + ] ) def test_missing_policy_identifier(self) -> None: @@ -571,13 +536,13 @@ def setUp(self) -> None: def test_basic(self) -> None: """Test basic functionality of action.""" args = self.parser.parse_args(["--action=DER"]) - self.assertEqual(args.action, Encoding.DER) + assert args.action == Encoding.DER args = self.parser.parse_args(["--action=ASN1"]) - self.assertEqual(args.action, Encoding.DER) + assert args.action == Encoding.DER args = self.parser.parse_args(["--action=PEM"]) - self.assertEqual(args.action, Encoding.PEM) + assert args.action == Encoding.PEM def test_error(self) -> None: """Test false option values.""" @@ -601,13 +566,13 @@ def setUp(self) -> None: def test_basic(self) -> None: """Test basic functionality of action.""" args = self.parser.parse_args(["--curve=sect409k1"]) - self.assertIsInstance(args.curve, ec.SECT409K1) + assert isinstance(args.curve, ec.SECT409K1) args = self.parser.parse_args(["--curve=sect409r1"]) - self.assertIsInstance(args.curve, ec.SECT409R1) + assert isinstance(args.curve, ec.SECT409R1) args = self.parser.parse_args(["--curve=brainpoolP512r1"]) - self.assertIsInstance(args.curve, ec.BrainpoolP512R1) + assert isinstance(args.curve, ec.BrainpoolP512R1) def test_error(self) -> None: """Test false option values.""" @@ -633,10 +598,10 @@ def setUp(self) -> None: def test_basic(self) -> None: """Test basic functionality of action.""" args = self.parser.parse_args(["--algo=SHA-256"]) - self.assertIsInstance(args.algo, hashes.SHA256) + assert isinstance(args.algo, hashes.SHA256) args = self.parser.parse_args(["--algo=SHA-512"]) - self.assertIsInstance(args.algo, hashes.SHA512) + assert isinstance(args.algo, hashes.SHA512) def test_error(self) -> None: """Test false option values.""" @@ -663,10 +628,10 @@ def setUp(self) -> None: def test_basic(self) -> None: """Test basic functionality of action.""" args = self.parser.parse_args(["--size=2048"]) - self.assertEqual(args.size, 2048) + assert args.size == 2048 args = self.parser.parse_args(["--size=4096"]) - self.assertEqual(args.size, 4096) + assert args.size == 4096 def test_no_power_two(self) -> None: """Test giving values that are not the power of two.""" @@ -708,12 +673,12 @@ def setUp(self) -> None: def test_none(self) -> None: """Test passing no password option at all.""" args = self.parser.parse_args([]) - self.assertIsNone(args.password) + assert args.password is None def test_given(self) -> None: """Test giving a password on the command line.""" args = self.parser.parse_args(["--password=foobar"]) - self.assertEqual(args.password, b"foobar") + assert args.password == b"foobar" @mock.patch("getpass.getpass", spec_set=True, return_value="prompted") def test_output(self, getpass: mock.MagicMock) -> None: @@ -722,7 +687,7 @@ def test_output(self, getpass: mock.MagicMock) -> None: parser = argparse.ArgumentParser() parser.add_argument("--password", nargs="?", action=actions.PasswordAction, prompt=prompt) args = parser.parse_args(["--password"]) - self.assertEqual(args.password, b"prompted") + assert args.password == b"prompted" getpass.assert_called_once_with(prompt=prompt) @mock.patch("getpass.getpass", spec_set=True, return_value="prompted") @@ -731,7 +696,7 @@ def test_prompt(self, getpass: mock.MagicMock) -> None: parser = argparse.ArgumentParser() parser.add_argument("--password", nargs="?", action=actions.PasswordAction) args = parser.parse_args(["--password"]) - self.assertEqual(args.password, b"prompted") + assert args.password == b"prompted" getpass.assert_called_once() @@ -750,12 +715,12 @@ def test_basic(self) -> None: """Test basic functionality of action.""" for name, cert in self.certs.items(): args = self.parser.parse_args([CERT_DATA[name]["serial"]]) - self.assertEqual(args.cert, cert) + assert args.cert == cert def test_abbreviation(self) -> None: """Test using an abbreviation.""" args = self.parser.parse_args([CERT_DATA["root-cert"]["serial"][:6]]) - self.assertEqual(args.cert, self.certs["root-cert"]) + assert args.cert == self.certs["root-cert"] def test_missing(self) -> None: """Test giving an unknown cert.""" @@ -801,13 +766,13 @@ def test_basic(self) -> None: """Test basic functionality of action.""" for name, ca in self.usable_cas: args = self.parser.parse_args([CERT_DATA[name]["serial"]]) - self.assertEqual(args.ca, ca) + assert args.ca == ca @override_tmpcadir() def test_abbreviation(self) -> None: """Test using an abbreviation.""" args = self.parser.parse_args([CERT_DATA["ec"]["serial"][:6]]) - self.assertEqual(args.ca, self.cas["ec"]) + assert args.ca == self.cas["ec"] def test_missing(self) -> None: """Test giving an unknown CA.""" @@ -850,7 +815,7 @@ def test_disabled(self) -> None: parser.add_argument("ca", action=actions.CertificateAuthorityAction, allow_disabled=True) args = parser.parse_args([self.ca.serial]) - self.assertEqual(args.ca, self.ca) + assert args.ca == self.ca # TODO: re-enable with better checks # def test_private_key_does_not_exists(self) -> None: @@ -870,7 +835,7 @@ def test_disabled(self) -> None: def test_password(self) -> None: """Test that the action works with a password-encrypted CA.""" args = self.parser.parse_args([CERT_DATA["pwd"]["serial"]]) - self.assertEqual(args.ca, self.cas["pwd"]) + assert args.ca == self.cas["pwd"] class URLActionTestCase(ParserTestCaseMixin, TestCase): @@ -885,7 +850,7 @@ def test_basic(self) -> None: """Test basic functionality of action.""" for url in ["http://example.com", "https://www.example.org"]: args = self.parser.parse_args([f"--url={url}"]) - self.assertEqual(args.url, url) + assert args.url == url def test_error(self) -> None: """Test false option values.""" @@ -908,7 +873,7 @@ def test_basic(self) -> None: """Test basic functionality of action.""" expires = timedelta(days=30) args = self.parser.parse_args(["--expires=30"]) - self.assertEqual(args.expires, expires) + assert args.expires == expires def test_default(self) -> None: """Test using the default value.""" @@ -916,7 +881,7 @@ def test_default(self) -> None: parser = argparse.ArgumentParser() parser.add_argument("--expires", action=actions.ExpiresAction, default=delta) args = parser.parse_args([]) - self.assertEqual(args.expires, delta) + assert args.expires == delta def test_negative(self) -> None: """Test passing a negative value.""" @@ -950,7 +915,7 @@ def setUp(self) -> None: def test_basic(self) -> None: """Test basic functionality of action.""" args = self.parser.parse_args([ReasonFlags.unspecified.name]) - self.assertEqual(args.reason, ReasonFlags.unspecified) + assert args.reason == ReasonFlags.unspecified def test_error(self) -> None: """Test false option values.""" @@ -986,17 +951,17 @@ def test_basic(self) -> None: parser.add_argument("--url", action=actions.MultipleURLAction) args = parser.parse_args([f"--url={url}"]) - self.assertEqual(args.url, [url]) + assert args.url == [url] parser = argparse.ArgumentParser() parser.add_argument("--url", action=actions.MultipleURLAction) args = parser.parse_args([f"--url={urls[0]}", f"--url={urls[1]}"]) - self.assertEqual(args.url, urls) + assert args.url == urls def test_none(self) -> None: """Test passing no value at all.""" args = self.parser.parse_args([]) - self.assertEqual(args.url, []) + assert args.url == [] def test_error(self) -> None: """Test false option values.""" diff --git a/ca/django_ca/tests/test_migration_helpers.py b/ca/django_ca/tests/test_migration_helpers.py index 4bfed2065..7a7adacc7 100644 --- a/ca/django_ca/tests/test_migration_helpers.py +++ b/ca/django_ca/tests/test_migration_helpers.py @@ -26,7 +26,7 @@ @pytest.mark.parametrize( - "crl_url,full_name", + ("crl_url", "full_name"), ( ("https://example.com", [uri("https://example.com")]), ( @@ -53,7 +53,7 @@ def test_0040_crl_url_to_sign_crl_distribution_points( @pytest.mark.parametrize( - "issuer_alt_name,general_names", + ("issuer_alt_name", "general_names"), ( ("https://example.com", [uri("https://example.com")]), ("URI:https://example.com", [uri("https://example.com")]), @@ -75,7 +75,7 @@ def test_0040_issuer_alt_name_to_sign_issuer_alternative_name( @pytest.mark.parametrize( - "issuer_url,ocsp_url,access_descriptions", + ("issuer_url", "ocsp_url", "access_descriptions"), ( ( "https://issuer.example.com", @@ -133,7 +133,7 @@ def test_0040_ocsp_url_and_issuer_url_to_sign_authority_information_access( @pytest.mark.parametrize( - "distribution_points,crl_url", + ("distribution_points", "crl_url"), ( ([distribution_point([uri("https://example.com")])], "https://example.com"), ( @@ -189,7 +189,7 @@ def test_0040_backwards_sign_crl_distribution_points_to_crl_url( @pytest.mark.parametrize( - "issuer_alt_name,general_names", + ("issuer_alt_name", "general_names"), ( ("URI:https://example.com", [uri("https://example.com")]), # issuer_alt_name was a CharField, values where comma-separated. @@ -210,7 +210,7 @@ def test_0040_backwards_sign_issuer_alternative_name_to_issuer_url( @pytest.mark.parametrize( - "issuer_url,ocsp_url,access_descriptions", + ("issuer_url", "ocsp_url", "access_descriptions"), ( ( "https://issuer.example.com", diff --git a/ca/django_ca/tests/test_models.py b/ca/django_ca/tests/test_models.py index 318a601d2..32a7f2a9c 100644 --- a/ca/django_ca/tests/test_models.py +++ b/ca/django_ca/tests/test_models.py @@ -27,6 +27,7 @@ from django.test import RequestFactory, TestCase, override_settings from django.utils import timezone +import pytest from freezegun import freeze_time from django_ca.key_backends.storages import StoragesUsePrivateKeyOptions @@ -57,8 +58,8 @@ def test_from_addr(self) -> None: name = "Firstname Lastname" watcher = Watcher.from_addr(f"{name} <{mail}>") - self.assertEqual(watcher.mail, mail) - self.assertEqual(watcher.name, name) + assert watcher.mail == mail + assert watcher.name == name def test_spaces(self) -> None: """Test that ``from_addr() is agnostic to spaces.""" @@ -66,18 +67,18 @@ def test_spaces(self) -> None: name = "Firstname Lastname" watcher = Watcher.from_addr(f"{name} <{mail}>") - self.assertEqual(watcher.mail, mail) - self.assertEqual(watcher.name, name) + assert watcher.mail == mail + assert watcher.name == name watcher = Watcher.from_addr(f"{name}<{mail}>") - self.assertEqual(watcher.mail, mail) - self.assertEqual(watcher.name, name) + assert watcher.mail == mail + assert watcher.name == name def test_error(self) -> None: """Test some validation errors.""" - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Watcher.from_addr("foobar ") - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Watcher.from_addr("foobar @") def test_update(self) -> None: @@ -88,8 +89,8 @@ def test_update(self) -> None: Watcher.from_addr(f"{name} <{mail}>") watcher = Watcher.from_addr(f"{newname} <{mail}>") - self.assertEqual(watcher.mail, mail) - self.assertEqual(watcher.name, newname) + assert watcher.mail == mail + assert watcher.name == newname def test_str(self) -> None: """Test the str function.""" @@ -97,10 +98,10 @@ def test_str(self) -> None: name = "Firstname Lastname" watcher = Watcher(mail=mail) - self.assertEqual(str(watcher), mail) + assert str(watcher) == mail watcher.name = name - self.assertEqual(str(watcher), f"{name} <{mail}>") + assert str(watcher) == f"{name} <{mail}>" class ModelfieldsTests(TestCaseMixin, TestCase): @@ -121,14 +122,14 @@ def test_create_pem_bytes(self) -> None: not_after=timezone.now(), not_before=timezone.now(), ) - self.assertEqual(cert.pub, pub) - self.assertEqual(cert.csr, csr) + assert cert.pub == pub + assert cert.csr == csr # Refresh, so that we get lazy values cert.refresh_from_db() - self.assertEqual(cert.pub.loaded, self.pub["parsed"]) - self.assertEqual(cert.csr.loaded, self.csr["parsed"]) + assert cert.pub.loaded == self.pub["parsed"] + assert cert.csr.loaded == self.csr["parsed"] def test_create_bytearray(self) -> None: """Test creating with bytes-encoded PEM.""" @@ -141,14 +142,14 @@ def test_create_bytearray(self) -> None: not_after=timezone.now(), not_before=timezone.now(), ) - self.assertEqual(cert.pub, pub) - self.assertEqual(cert.csr, csr) + assert cert.pub == pub + assert cert.csr == csr # Refresh, so that we get lazy values cert.refresh_from_db() - self.assertEqual(cert.pub.loaded, self.pub["parsed"]) - self.assertEqual(cert.csr.loaded, self.csr["parsed"]) + assert cert.pub.loaded == self.pub["parsed"] + assert cert.csr.loaded == self.csr["parsed"] def test_create_memoryview(self) -> None: """Test creating with bytes-encoded PEM.""" @@ -161,20 +162,20 @@ def test_create_memoryview(self) -> None: not_after=timezone.now(), not_before=timezone.now(), ) - self.assertEqual(cert.pub, pub) - self.assertEqual(cert.csr, csr) + assert cert.pub == pub + assert cert.csr == csr # Refresh, so that we get lazy values cert.refresh_from_db() - self.assertEqual(cert.pub.loaded, self.pub["parsed"]) - self.assertEqual(cert.csr.loaded, self.csr["parsed"]) + assert cert.pub.loaded == self.pub["parsed"] + assert cert.csr.loaded == self.csr["parsed"] def test_create_from_instance(self) -> None: """Test creating a certificate from LazyField instances.""" loaded = self.load_named_cert("root-cert") - self.assertIsInstance(loaded.pub, LazyCertificate) - self.assertIsInstance(loaded.csr, LazyCertificateSigningRequest) + assert isinstance(loaded.pub, LazyCertificate) + assert isinstance(loaded.csr, LazyCertificateSigningRequest) cert = Certificate.objects.create( pub=loaded.pub, csr=loaded.csr, @@ -182,12 +183,12 @@ def test_create_from_instance(self) -> None: not_after=timezone.now(), not_before=timezone.now(), ) - self.assertEqual(loaded.pub, cert.pub) - self.assertEqual(loaded.csr, cert.csr) + assert loaded.pub == cert.pub + assert loaded.csr == cert.csr reloaded = Certificate.objects.get(pk=cert.pk) - self.assertEqual(loaded.pub, reloaded.pub) - self.assertEqual(loaded.csr, reloaded.csr) + assert loaded.pub == reloaded.pub + assert loaded.csr == reloaded.csr def test_repr(self) -> None: """Test ``repr()`` for custom modelfields.""" @@ -201,8 +202,8 @@ def test_repr(self) -> None: cert.refresh_from_db() subject = "CN=root-cert.example.com,OU=Django CA Testsuite,O=Django CA,L=Vienna,ST=Vienna,C=AT" - self.assertEqual(repr(cert.pub), f"") - self.assertEqual(repr(cert.csr), "") + assert repr(cert.pub) == f"" + assert repr(cert.csr) == "" def test_none_value(self) -> None: """Test that nullable fields work.""" @@ -213,9 +214,9 @@ def test_none_value(self) -> None: not_after=timezone.now(), not_before=timezone.now(), ) - self.assertIsNone(cert.csr) + assert cert.csr is None cert.refresh_from_db() - self.assertIsNone(cert.csr) + assert cert.csr is None def test_filter(self) -> None: """Test that we can use various representations for filtering.""" @@ -229,8 +230,8 @@ def test_filter(self) -> None: for prop in ["parsed", "pem", "der"]: qs = Certificate.objects.filter(pub=self.pub[prop]) - self.assertCountEqual(qs, [cert]) - self.assertEqual(qs[0].pub.der, self.pub["der"]) + assert list(qs) == [cert] + assert qs[0].pub.der == self.pub["der"] def test_full_clean(self) -> None: """Test the full_clean() method, which invokes ``to_python()`` on the field.""" @@ -244,8 +245,8 @@ def test_full_clean(self) -> None: serial="0", ) cert.full_clean() - self.assertEqual(cert.pub.loaded, self.pub["parsed"]) - self.assertEqual(cert.csr.loaded, self.csr["parsed"]) + assert cert.pub.loaded == self.pub["parsed"] + assert cert.csr.loaded == self.csr["parsed"] cert = Certificate( pub=cert.pub, @@ -257,8 +258,8 @@ def test_full_clean(self) -> None: serial="0", ) cert.full_clean() - self.assertEqual(cert.pub.loaded, self.pub["parsed"]) - self.assertEqual(cert.csr.loaded, self.csr["parsed"]) + assert cert.pub.loaded == self.pub["parsed"] + assert cert.csr.loaded == self.csr["parsed"] def test_empty_csr(self) -> None: """Test an empty CSR.""" @@ -272,12 +273,12 @@ def test_empty_csr(self) -> None: serial="0", ) cert.full_clean() - self.assertEqual(cert.pub.loaded, self.pub["parsed"]) - self.assertIsNone(cert.csr) + assert cert.pub.loaded == self.pub["parsed"] + assert cert.csr is None def test_invalid_value(self) -> None: """Test passing invalid values.""" - with self.assertRaisesRegex(ValueError, r"^True: Could not parse CertificateSigningRequest$"): + with pytest.raises(ValueError, match=r"^True: Could not parse CertificateSigningRequest$"): Certificate.objects.create( pub=CERT_DATA["child-cert"]["pub"]["parsed"], csr=True, # type: ignore[misc] # what we test @@ -286,7 +287,7 @@ def test_invalid_value(self) -> None: not_before=timezone.now(), ) - with self.assertRaisesRegex(ValueError, r"^True: Could not parse Certificate$"): + with pytest.raises(ValueError, match=r"^True: Could not parse Certificate$"): Certificate.objects.create( csr=CERT_DATA["child-cert"]["csr"]["parsed"], pub=True, # type: ignore[misc] # what we test @@ -329,56 +330,56 @@ def setUp(self) -> None: def test_str(self) -> None: """Test str() function.""" - self.assertEqual(str(self.account1), "user@example.com") - self.assertEqual(str(self.account2), "user@example.net") - self.assertEqual(str(AcmeAccount()), "") + assert str(self.account1) == "user@example.com" + assert str(self.account2) == "user@example.net" + assert str(AcmeAccount()) == "" def test_serial(self) -> None: """Test the ``serial`` property.""" - self.assertEqual(self.account1.serial, self.cas["root"].serial) - self.assertEqual(self.account2.serial, self.cas["child"].serial) + assert self.account1.serial == self.cas["root"].serial + assert self.account2.serial == self.cas["child"].serial # pylint: disable=no-member; false positive: pylint does not detect RelatedObjectDoesNotExist member - with self.assertRaisesRegex(AcmeAccount.ca.RelatedObjectDoesNotExist, r"^AcmeAccount has no ca\.$"): + with pytest.raises(AcmeAccount.ca.RelatedObjectDoesNotExist, match=r"^AcmeAccount has no ca\.$"): AcmeAccount().serial # noqa: B018 @freeze_time(TIMESTAMPS["everything_valid"]) def test_usable(self) -> None: """Test the ``usable`` property.""" - self.assertTrue(self.account1.usable) - self.assertFalse(self.account2.usable) + assert self.account1.usable + assert not self.account2.usable # Try states that make an account **unusable** self.account1.status = AcmeAccount.STATUS_DEACTIVATED - self.assertFalse(self.account1.usable) + assert not self.account1.usable self.account1.status = AcmeAccount.STATUS_REVOKED - self.assertFalse(self.account1.usable) + assert not self.account1.usable # Make the account usable again self.account1.status = AcmeAccount.STATUS_VALID - self.assertTrue(self.account1.usable) + assert self.account1.usable # TOS not agreed, but CA does not have any self.account1.terms_of_service_agreed = False - self.assertTrue(self.account1.usable) + assert self.account1.usable # TOS not agreed, but CA does have them, so account is now unusable self.cas["root"].terms_of_service = "http://tos.example.com" self.cas["root"].save() - self.assertFalse(self.account1.usable) + assert not self.account1.usable # Make the account usable again self.account1.terms_of_service_agreed = True - self.assertTrue(self.account1.usable) + assert self.account1.usable # If the CA is not usable, neither is the account self.account1.ca.enabled = False - self.assertFalse(self.account1.usable) + assert not self.account1.usable def test_unique_together(self) -> None: """Test that a thumbprint must be unique for the given CA.""" msg = r"^UNIQUE constraint failed: django_ca_acmeaccount\.ca_id, django_ca_acmeaccount\.thumbprint$" - with transaction.atomic(), self.assertRaisesRegex(IntegrityError, msg): + with transaction.atomic(), pytest.raises(IntegrityError, match=msg): AcmeAccount.objects.create(ca=self.account1.ca, thumbprint=self.account1.thumbprint) # Works, because CA is different @@ -390,9 +391,9 @@ def test_set_kid(self) -> None: hostname = settings.ALLOWED_HOSTS[0] req = RequestFactory().get("/foobar", HTTP_HOST=hostname) self.account1.set_kid(req) - self.assertEqual( - self.account1.kid, - f"http://{hostname}/django_ca/acme/{self.account1.serial}/acct/{self.account1.slug}/", + assert ( + self.account1.kid + == f"http://{hostname}/django_ca/acme/{self.account1.serial}/acct/{self.account1.slug}/" ) def test_validate_pem(self) -> None: @@ -428,35 +429,33 @@ def setUp(self) -> None: def test_str(self) -> None: """Test the str function.""" - self.assertEqual(str(self.order1), f"{self.order1.slug} ({self.account})") + assert str(self.order1) == f"{self.order1.slug} ({self.account})" def test_acme_url(self) -> None: """Test the acme url function.""" - self.assertEqual( - self.order1.acme_url, f"/django_ca/acme/{self.account.ca.serial}/order/{self.order1.slug}/" - ) + assert self.order1.acme_url == f"/django_ca/acme/{self.account.ca.serial}/order/{self.order1.slug}/" def test_acme_finalize_url(self) -> None: """Test the acme finalize url function.""" - self.assertEqual( - self.order1.acme_finalize_url, - f"/django_ca/acme/{self.account.ca.serial}/order/{self.order1.slug}/finalize/", + assert ( + self.order1.acme_finalize_url + == f"/django_ca/acme/{self.account.ca.serial}/order/{self.order1.slug}/finalize/" ) def test_add_authorizations(self) -> None: """Test the add_authorizations method.""" identifier = messages.Identifier(typ=messages.IDENTIFIER_FQDN, value="example.com") auths = self.order1.add_authorizations([identifier]) - self.assertEqual(auths[0].type, "dns") - self.assertEqual(auths[0].value, "example.com") + assert auths[0].type == "dns" + assert auths[0].value == "example.com" msg = r"^UNIQUE constraint failed: django_ca_acmeauthorization\.order_id, django_ca_acmeauthorization\.type, django_ca_acmeauthorization\.value$" # NOQA: E501 - with transaction.atomic(), self.assertRaisesRegex(IntegrityError, msg): + with transaction.atomic(), pytest.raises(IntegrityError, match=msg): self.order1.add_authorizations([identifier]) def test_serial(self) -> None: """Test getting the serial of the associated CA.""" - self.assertEqual(self.order1.serial, self.cas["root"].serial) + assert self.order1.serial == self.cas["root"].serial class AcmeAuthorizationTestCase(TestCaseMixin, AcmeValuesMixin, TestCase): @@ -484,58 +483,52 @@ def setUp(self) -> None: def test_str(self) -> None: """Test the __str__ method.""" - self.assertEqual(str(self.auth1), "dns: example.com") - self.assertEqual(str(self.auth2), "dns: example.net") + assert str(self.auth1) == "dns: example.com" + assert str(self.auth2) == "dns: example.net" def test_account_property(self) -> None: """Test the account property.""" - self.assertEqual(self.auth1.account, self.account) - self.assertEqual(self.auth2.account, self.account) + assert self.auth1.account == self.account + assert self.auth2.account == self.account def test_acme_url(self) -> None: """Test acme_url property.""" - self.assertEqual( - self.auth1.acme_url, - f"/django_ca/acme/{self.cas['root'].serial}/authz/{self.auth1.slug}/", - ) - self.assertEqual( - self.auth2.acme_url, - f"/django_ca/acme/{self.cas['root'].serial}/authz/{self.auth2.slug}/", - ) + assert self.auth1.acme_url == f"/django_ca/acme/{self.cas['root'].serial}/authz/{self.auth1.slug}/" + assert self.auth2.acme_url == f"/django_ca/acme/{self.cas['root'].serial}/authz/{self.auth2.slug}/" def test_expires(self) -> None: """Test the expires property.""" - self.assertEqual(self.auth1.expires, self.order.expires) - self.assertEqual(self.auth2.expires, self.order.expires) + assert self.auth1.expires == self.order.expires + assert self.auth2.expires == self.order.expires def test_identifier(self) -> None: """Test the identifier property.""" - self.assertEqual( - self.auth1.identifier, messages.Identifier(typ=messages.IDENTIFIER_FQDN, value=self.auth1.value) + assert self.auth1.identifier == messages.Identifier( + typ=messages.IDENTIFIER_FQDN, value=self.auth1.value ) - self.assertEqual( - self.auth2.identifier, messages.Identifier(typ=messages.IDENTIFIER_FQDN, value=self.auth2.value) + assert self.auth2.identifier == messages.Identifier( + typ=messages.IDENTIFIER_FQDN, value=self.auth2.value ) def test_identifier_unknown_type(self) -> None: """Test that an identifier with an unknown type raises a ValueError.""" self.auth1.type = "foo" - with self.assertRaisesRegex(ValueError, r"^Unknown identifier type: foo$"): + with pytest.raises(ValueError, match=r"^Unknown identifier type: foo$"): self.auth1.identifier # noqa: B018 def test_subject_alternative_name(self) -> None: """Test the subject_alternative_name property.""" - self.assertEqual(self.auth1.subject_alternative_name, "dns:example.com") - self.assertEqual(self.auth2.subject_alternative_name, "dns:example.net") + assert self.auth1.subject_alternative_name == "dns:example.com" + assert self.auth2.subject_alternative_name == "dns:example.net" def test_get_challenges(self) -> None: """Test the get_challenges() method.""" chall_qs = self.auth1.get_challenges() - self.assertIsInstance(chall_qs[0], AcmeChallenge) - self.assertIsInstance(chall_qs[1], AcmeChallenge) + assert isinstance(chall_qs[0], AcmeChallenge) + assert isinstance(chall_qs[1], AcmeChallenge) - self.assertEqual(self.auth1.get_challenges(), chall_qs) - self.assertEqual(AcmeChallenge.objects.all().count(), 2) + assert self.auth1.get_challenges() == chall_qs + assert AcmeChallenge.objects.all().count() == 2 class AcmeChallengeTestCase(TestCaseMixin, AcmeValuesMixin, TestCase): @@ -563,17 +556,17 @@ def assertChallenge( # pylint: disable=invalid-name self, challenge: ChallengeTypeVar, typ: str, token: bytes, cls: type[ChallengeTypeVar] ) -> None: """Test that the ACME challenge is of the given type.""" - self.assertIsInstance(challenge, cls) - self.assertEqual(challenge.typ, typ) - self.assertEqual(challenge.token, token) + assert isinstance(challenge, cls) + assert challenge.typ == typ + assert challenge.token == token def test_str(self) -> None: """Test the __str__ method.""" - self.assertEqual(str(self.chall), f"{self.hostname} ({self.chall.type})") + assert str(self.chall) == f"{self.hostname} ({self.chall.type})" def test_acme_url(self) -> None: """Test acme_url property.""" - self.assertEqual(self.chall.acme_url, f"/django_ca/acme/{self.chall.serial}/chall/{self.chall.slug}/") + assert self.chall.acme_url == f"/django_ca/acme/{self.chall.serial}/chall/{self.chall.slug}/" def test_acme_challenge(self) -> None: """Test acme_challenge property.""" @@ -590,67 +583,67 @@ def test_acme_challenge(self) -> None: ) self.chall.type = "foo" - with self.assertRaisesRegex(ValueError, r"^foo: Unsupported challenge type\.$"): + with pytest.raises(ValueError, match=r"^foo: Unsupported challenge type\.$"): self.chall.acme_challenge # noqa: B018 @freeze_time(TIMESTAMPS["everything_valid"]) def test_acme_validated(self) -> None: """Test acme_validated property.""" # preconditions for checks (might change them in setUp without realising it might affect this test) - self.assertNotEqual(self.chall.status, AcmeChallenge.STATUS_VALID) - self.assertIsNone(self.chall.validated) + assert self.chall.status != AcmeChallenge.STATUS_VALID + assert self.chall.validated is None - self.assertIsNone(self.chall.acme_validated) + assert self.chall.acme_validated is None self.chall.status = AcmeChallenge.STATUS_VALID - self.assertIsNone(self.chall.acme_validated) # still None (no validated timestamp) + assert self.chall.acme_validated is None # still None (no validated timestamp) self.chall.validated = timezone.now() - self.assertEqual(self.chall.acme_validated, TIMESTAMPS["everything_valid"]) + assert self.chall.acme_validated == TIMESTAMPS["everything_valid"] # We return a UTC timestamp, even if timezone support is disabled. with self.settings(USE_TZ=False): self.chall.validated = timezone.now() - self.assertEqual(self.chall.acme_validated, TIMESTAMPS["everything_valid"]) + assert self.chall.acme_validated == TIMESTAMPS["everything_valid"] def test_encoded(self) -> None: """Test the encoded property.""" self.chall.token = "ADwFxCAXrnk47rcCnnbbtGYSo_l61MCYXqtBziPt26mk7-QzpYNNKnTsKjbBYPzD" self.chall.save() - self.assertEqual( - self.chall.encoded_token, - b"QUR3RnhDQVhybms0N3JjQ25uYmJ0R1lTb19sNjFNQ1lYcXRCemlQdDI2bWs3LVF6cFlOTktuVHNLamJCWVB6RA", + assert ( + self.chall.encoded_token + == b"QUR3RnhDQVhybms0N3JjQ25uYmJ0R1lTb19sNjFNQ1lYcXRCemlQdDI2bWs3LVF6cFlOTktuVHNLamJCWVB6RA" ) def test_expected(self) -> None: """Test the expected property.""" self.chall.token = "ADwFxCAXrnk47rcCnnbbtGYSo_l61MCYXqtBziPt26mk7-QzpYNNKnTsKjbBYPzD" self.chall.save() - self.assertEqual( - self.chall.expected, self.chall.encoded_token + b"." + self.account.thumbprint.encode("utf-8") + assert self.chall.expected == self.chall.encoded_token + b"." + self.account.thumbprint.encode( + "utf-8" ) self.chall.type = AcmeChallenge.TYPE_DNS_01 self.chall.save() - self.assertEqual(self.chall.expected, b"LoNgngEeuLw4rWDFpplPA0XBp9dd9spzuuqbsRFcKug") + assert self.chall.expected == b"LoNgngEeuLw4rWDFpplPA0XBp9dd9spzuuqbsRFcKug" self.chall.type = AcmeChallenge.TYPE_TLS_ALPN_01 self.chall.save() - with self.assertRaisesRegex(ValueError, r"^tls-alpn-01: Unsupported challenge type\.$"): + with pytest.raises(ValueError, match=r"^tls-alpn-01: Unsupported challenge type\.$"): self.chall.expected # noqa: B018 def test_get_challenge(self) -> None: """Test the get_challenge() function.""" body = self.chall.get_challenge(RequestFactory().get("/")) - self.assertIsInstance(body, messages.ChallengeBody) - self.assertEqual(body.chall, self.chall.acme_challenge) - self.assertEqual(body.status, self.chall.status) - self.assertEqual(body.validated, self.chall.acme_validated) - self.assertEqual(body.uri, f"http://testserver{self.chall.acme_url}") + assert isinstance(body, messages.ChallengeBody) + assert body.chall == self.chall.acme_challenge + assert body.status == self.chall.status + assert body.validated == self.chall.acme_validated + assert body.uri == f"http://testserver{self.chall.acme_url}" def test_serial(self) -> None: """Test the serial property.""" - self.assertEqual(self.chall.serial, self.chall.auth.order.account.ca.serial) + assert self.chall.serial == self.chall.auth.order.account.ca.serial class AcmeCertificateTestCase(TestCaseMixin, AcmeValuesMixin, TestCase): @@ -673,13 +666,11 @@ def setUp(self) -> None: def test_acme_url(self) -> None: """Test the acme_url property.""" - self.assertEqual( - self.acme_cert.acme_url, f"/django_ca/acme/{self.order.serial}/cert/{self.acme_cert.slug}/" - ) + assert self.acme_cert.acme_url == f"/django_ca/acme/{self.order.serial}/cert/{self.acme_cert.slug}/" def test_parse_csr(self) -> None: """Test the parse_csr property.""" self.acme_cert.csr = ( CERT_DATA["root-cert"]["csr"]["parsed"].public_bytes(Encoding.PEM).decode("utf-8") ) - self.assertIsInstance(self.acme_cert.parse_csr(), x509.CertificateSigningRequest) + assert isinstance(self.acme_cert.parse_csr(), x509.CertificateSigningRequest) diff --git a/ca/django_ca/tests/test_querysets.py b/ca/django_ca/tests/test_querysets.py index 9da5c1ce4..276d160ee 100644 --- a/ca/django_ca/tests/test_querysets.py +++ b/ca/django_ca/tests/test_querysets.py @@ -31,6 +31,7 @@ Certificate, CertificateAuthority, ) +from django_ca.tests.base.assertions import assert_count_equal from django_ca.tests.base.constants import TIMESTAMPS from django_ca.tests.base.mixins import AcmeValuesMixin, TestCaseMixin @@ -42,7 +43,7 @@ def assertQuerySet( # pylint: disable=invalid-name; unittest standard self, qs: "models.QuerySet[models.Model]", *items: models.Model ) -> None: """Minor shortcut to test querysets.""" - self.assertCountEqual(qs, items) + assert_count_equal(qs, items) @contextmanager def attr(self, obj: models.Model, attr: str, value: Any) -> Iterator[None]: @@ -67,42 +68,42 @@ def test_enabled_disabled(self) -> None: """Test enabled/disabled filter.""" self.load_named_cas("__usable__") - self.assertCountEqual(CertificateAuthority.objects.enabled(), self.cas.values()) - self.assertCountEqual(CertificateAuthority.objects.disabled(), []) + assert_count_equal(CertificateAuthority.objects.enabled(), self.cas.values()) + assert not CertificateAuthority.objects.disabled() self.ca.enabled = False self.ca.save() - self.assertCountEqual( + assert_count_equal( CertificateAuthority.objects.enabled(), [c for c in self.cas.values() if c.name != self.ca.name], ) - self.assertCountEqual(CertificateAuthority.objects.disabled(), [self.ca]) + assert_count_equal(CertificateAuthority.objects.disabled(), [self.ca]) def test_valid(self) -> None: """Test valid/usable/invalid filters.""" self.load_named_cas("__usable__") with freeze_time(TIMESTAMPS["before_cas"]): - self.assertCountEqual(CertificateAuthority.objects.valid(), []) - self.assertCountEqual(CertificateAuthority.objects.usable(), []) - self.assertCountEqual(CertificateAuthority.objects.invalid(), self.cas.values()) + assert not CertificateAuthority.objects.valid() + assert not CertificateAuthority.objects.usable() + assert_count_equal(CertificateAuthority.objects.invalid(), self.cas.values()) with freeze_time(TIMESTAMPS["before_child"]): valid = [c for c in self.cas.values() if c.name != "child"] - self.assertCountEqual(CertificateAuthority.objects.valid(), valid) - self.assertCountEqual(CertificateAuthority.objects.usable(), valid) - self.assertCountEqual(CertificateAuthority.objects.invalid(), [self.cas["child"]]) + assert_count_equal(CertificateAuthority.objects.valid(), valid) + assert_count_equal(CertificateAuthority.objects.usable(), valid) + assert_count_equal(CertificateAuthority.objects.invalid(), [self.cas["child"]]) with freeze_time(TIMESTAMPS["after_child"]): - self.assertCountEqual(CertificateAuthority.objects.valid(), self.cas.values()) - self.assertCountEqual(CertificateAuthority.objects.usable(), self.cas.values()) - self.assertCountEqual(CertificateAuthority.objects.invalid(), []) + assert_count_equal(CertificateAuthority.objects.valid(), self.cas.values()) + assert_count_equal(CertificateAuthority.objects.usable(), self.cas.values()) + assert not CertificateAuthority.objects.invalid() with freeze_time(TIMESTAMPS["cas_expired"]): - self.assertCountEqual(CertificateAuthority.objects.valid(), []) - self.assertCountEqual(CertificateAuthority.objects.usable(), []) - self.assertCountEqual(CertificateAuthority.objects.invalid(), self.cas.values()) + assert not CertificateAuthority.objects.valid() + assert not CertificateAuthority.objects.usable() + assert_count_equal(CertificateAuthority.objects.invalid(), self.cas.values()) class CertificateQuerysetTestCase(QuerySetTestCaseMixin, TestCase): diff --git a/ca/django_ca/tests/test_settings.py b/ca/django_ca/tests/test_settings.py index c93d900cb..2f18fcee1 100644 --- a/ca/django_ca/tests/test_settings.py +++ b/ca/django_ca/tests/test_settings.py @@ -337,7 +337,7 @@ def test_ca_acme_cert_validity_timedelta_settings_as_int(settings: SettingsWrapp @pytest.mark.parametrize("setting", ("CA_ACME_DEFAULT_CERT_VALIDITY", "CA_ACME_MAX_CERT_VALIDITY")) @pytest.mark.parametrize( - "value,message", + ("value", "message"), ( (0.9, "Input should be greater than or equal to 1 day"), (timedelta(seconds=1), "Input should be greater than or equal to 1 day"), @@ -356,7 +356,7 @@ def test_ca_acme_cert_validity_limits( @pytest.mark.parametrize( - "value,message", + ("value", "message"), ( (timedelta(seconds=59), "Input should be greater than or equal to 1 minute"), (timedelta(days=2), "Input should be less than or equal to 1 day"), @@ -414,7 +414,7 @@ def test_ca_crl_profiles_with_deprecated_scope(settings: SettingsWrapper, scope: @pytest.mark.parametrize( - "value,parsed", + ("value", "parsed"), ( ("0a:bc", "ABC"), # leading zero is stripped ("0", "0"), # single zero is *not* stripped @@ -481,7 +481,7 @@ def test_ca_default_name_order(settings: SettingsWrapper) -> None: @pytest.mark.parametrize( - "value,msg", + ("value", "msg"), ( (True, r"Input should be a valid tuple"), (("invalid-oid",), "invalid-oid: Invalid object identifier"), @@ -499,7 +499,7 @@ def test_ca_default_profile_not_defined(settings: SettingsWrapper) -> None: settings.CA_DEFAULT_PROFILE = "foo" -@pytest.mark.parametrize("value,expected", (("SHA-224", hashes.SHA224), ("SHA3/384", hashes.SHA3_384))) +@pytest.mark.parametrize(("value", "expected"), (("SHA-224", hashes.SHA224), ("SHA3/384", hashes.SHA3_384))) def test_ca_default_signature_hash_algorithm( settings: SettingsWrapper, value: Any, expected: type[hashes.HashAlgorithm] ) -> None: @@ -516,7 +516,7 @@ def test_ca_default_signature_hash_algorithm_with_invalid_value(settings: Settin @pytest.mark.parametrize( - "value,expected", + ("value", "expected"), ( # Serialized version ( @@ -541,7 +541,7 @@ def test_ca_default_subject(settings: SettingsWrapper, value: Any, expected: x50 @pytest.mark.parametrize( - "value,msg", + ("value", "msg"), ( ((("CN", ""),), r"Value error, Attribute's length must be >= 1 and <= 64, but it was 0"), ((("CN", "X" * 65),), r"Value error, Attribute's length must be >= 1 and <= 64, but it was 65"), @@ -554,7 +554,7 @@ def test_ca_default_subject_with_invalid_values(settings: SettingsWrapper, value @pytest.mark.parametrize( - "value,expected", + ("value", "expected"), ( ([("CN", "example.com")], x509.Name([cn("example.com")])), ((("C", "AT"), ("CN", "example.com")), x509.Name([country("AT"), cn("example.com")])), @@ -576,7 +576,7 @@ def test_ca_default_subject_with_deprecated_values( @pytest.mark.parametrize( - "value,msg", + ("value", "msg"), ( ([("invalid", "wrong")], "invalid: Invalid object identifier"), ([["one-element"]], r"Must be lists/tuples with two items, got 1\."), @@ -646,7 +646,7 @@ def test_ca_profiles_update_description(settings: SettingsWrapper) -> None: @pytest.mark.parametrize( - "subject,expected", + ("subject", "expected"), ( (False, False), ([], x509.Name([])), @@ -669,7 +669,7 @@ def test_ca_profiles_override_subject_with_deprecated_values(settings: SettingsW @pytest.mark.parametrize( - "value,msg", + ("value", "msg"), ( ("foo", "Input should be a valid dictionary"), # whole setting is invalid ({"client": {"subject": "foo"}}, r"Value error, foo: Must be a list or tuple\."), @@ -734,7 +734,7 @@ def test_ca_crl_profiles_invalid_scope(settings: SettingsWrapper) -> None: @pytest.mark.parametrize( - "base,override", + ("base", "override"), ( ("only_contains_ca_certs", "only_contains_user_certs"), ("only_contains_user_certs", "only_contains_ca_certs"), diff --git a/ca/django_ca/tests/test_sphinx_extensions.py b/ca/django_ca/tests/test_sphinx_extensions.py index 8c287149f..ab96138ce 100644 --- a/ca/django_ca/tests/test_sphinx_extensions.py +++ b/ca/django_ca/tests/test_sphinx_extensions.py @@ -24,13 +24,13 @@ class CommandLineTextWrapperTestCase(TestCase): def assertWraps(self, command: str, expected: list[str]) -> None: # pylint: disable=invalid-name """Assert that the given command wraps to the expected full text.""" wrapper = CommandLineTextWrapper(width=12) - self.assertEqual(wrapper.wrap(command), expected) + assert wrapper.wrap(command) == expected def assertSplits(self, command: str, expected: list[str]) -> None: # pylint: disable=invalid-name """Assert that the given command splits into the expected tokens.""" wrapper = CommandLineTextWrapper() # PYLINT note: this is the function that we override - self.assertEqual(wrapper._split(command), expected) # pylint: disable=protected-access + assert wrapper._split(command) == expected # pylint: disable=protected-access def test_split(self) -> None: """Test the overwritten split function.""" diff --git a/ca/django_ca/tests/test_tasks.py b/ca/django_ca/tests/test_tasks.py index 1a5ebc785..4f7a2e54b 100644 --- a/ca/django_ca/tests/test_tasks.py +++ b/ca/django_ca/tests/test_tasks.py @@ -55,12 +55,12 @@ class TestBasic(TestCaseMixin, TestCase): def test_missing_celery(self) -> None: """Test that we work even if celery is not installed.""" # negative assertion to make sure that the IsInstance assertion below is actually meaningful - self.assertNotIsInstance(tasks.cache_crl, types.FunctionType) + assert not isinstance(tasks.cache_crl, types.FunctionType) try: with mock.patch.dict("sys.modules", celery=None): importlib.reload(tasks) - self.assertIsInstance(tasks.cache_crl, types.FunctionType) + assert isinstance(tasks.cache_crl, types.FunctionType) finally: # Make sure that module is reloaded, or any failed test in the try block will cause *all other # tests* to fail, because the celery import would be cached to *not* work @@ -71,7 +71,7 @@ def test_run_task(self) -> None: # run_task() without celery with self.settings(CA_USE_CELERY=False), self.patch("django_ca.tasks.cache_crls") as task_mock: tasks.run_task(tasks.cache_crls) - self.assertEqual(task_mock.call_count, 1) + assert task_mock.call_count == 1 # finally, run_task() with celery with self.settings(CA_USE_CELERY=True), self.mute_celery((((), {}), {})): @@ -117,16 +117,16 @@ def refresh_from_db(self) -> None: def assertInvalid(self) -> None: # pylint: disable=invalid-name; unittest standard """Assert that the challenge validation failed.""" self.refresh_from_db() - self.assertEqual(self.chall.status, AcmeChallenge.STATUS_INVALID) - self.assertEqual(self.auth.status, AcmeAuthorization.STATUS_INVALID) - self.assertEqual(self.order.status, AcmeOrder.STATUS_INVALID) + assert self.chall.status == AcmeChallenge.STATUS_INVALID + assert self.auth.status == AcmeAuthorization.STATUS_INVALID + assert self.order.status == AcmeOrder.STATUS_INVALID def assertValid(self, order_state: str = AcmeOrder.STATUS_READY) -> None: # pylint: disable=invalid-name """Assert that the challenge is valid.""" self.refresh_from_db() - self.assertEqual(self.chall.status, AcmeChallenge.STATUS_VALID) - self.assertEqual(self.auth.status, AcmeAuthorization.STATUS_VALID) - self.assertEqual(self.order.status, order_state) + assert self.chall.status == AcmeChallenge.STATUS_VALID + assert self.auth.status == AcmeAuthorization.STATUS_VALID + assert self.order.status == order_state @contextmanager def mock_challenge( @@ -144,7 +144,7 @@ def test_acme_disabled(self) -> None: """Test invoking task when ACME support is not enabled.""" with self.settings(CA_ENABLE_ACME=False), self.assertLogs() as logcm: tasks.acme_validate_challenge(self.chall.pk) - self.assertEqual(logcm.output, ["ERROR:django_ca.tasks:ACME is not enabled."]) + assert logcm.output == ["ERROR:django_ca.tasks:ACME is not enabled."] def test_unknown_challenge(self) -> None: """Test invoking task with an unknown challenge.""" @@ -152,7 +152,7 @@ def test_unknown_challenge(self) -> None: with self.assertLogs() as logcm: tasks.acme_validate_challenge(self.chall.pk) - self.assertEqual(logcm.output, [f"ERROR:django_ca.tasks:Challenge with id={self.chall.pk} not found"]) + assert logcm.output == [f"ERROR:django_ca.tasks:Challenge with id={self.chall.pk} not found"] def test_status_not_processing(self) -> None: """Test invoking task where the status is not "processing".""" @@ -162,9 +162,9 @@ def test_status_not_processing(self) -> None: with self.assertLogs() as logcm: tasks.acme_validate_challenge(self.chall.pk) - self.assertEqual( - logcm.output, [f"ERROR:django_ca.tasks:{self.chall}: pending: Invalid state (must be processing)"] - ) + assert logcm.output == [ + f"ERROR:django_ca.tasks:{self.chall}: pending: Invalid state (must be processing)" + ] def test_unusable_auth(self) -> None: """Test invoking task with an unusable authentication.""" @@ -174,7 +174,7 @@ def test_unusable_auth(self) -> None: with self.assertLogs() as logcm: tasks.acme_validate_challenge(self.chall.pk) - self.assertEqual(logcm.output, [f"ERROR:django_ca.tasks:{self.chall}: Authentication is not usable"]) + assert logcm.output == [f"ERROR:django_ca.tasks:{self.chall}: Authentication is not usable"] def test_response_wrong_content(self) -> None: """Test the server returning the wrong content in the response.""" @@ -184,12 +184,9 @@ def test_response_wrong_content(self) -> None: ): tasks.acme_validate_challenge(self.chall.pk) self.assertInvalid() - self.assertEqual( - logcm.output, - [ - f"INFO:django_ca.tasks:{self.chall!s} is invalid", - ], - ) + assert logcm.output == [ + f"INFO:django_ca.tasks:{self.chall!s} is invalid", + ] def test_unsupported_challenge(self) -> None: """Test what happens when challenge type is not supported.""" @@ -202,13 +199,10 @@ def test_unsupported_challenge(self) -> None: ): tasks.acme_validate_challenge(self.chall.pk) self.assertInvalid() - self.assertEqual( - logcm.output, - [ - f"ERROR:django_ca.tasks:{self.chall!s}: Challenge type is not supported.", - f"INFO:django_ca.tasks:{self.chall!s} is invalid", - ], - ) + assert logcm.output == [ + f"ERROR:django_ca.tasks:{self.chall!s}: Challenge type is not supported.", + f"INFO:django_ca.tasks:{self.chall!s} is invalid", + ] def test_basic(self) -> None: """Test validation actually working.""" @@ -270,7 +264,7 @@ def mock_challenge( matcher = req_mock.get(url, raw=HTTPResponse(body=content, status=status, preload_content=False)) yield req_mock - self.assertEqual(matcher.call_count, call_count) + assert matcher.call_count == call_count def test_response_not_ok(self) -> None: """Test the server not returning a HTTP status code 200.""" @@ -284,10 +278,10 @@ def test_request_exception(self) -> None: with self.patch("requests.get", side_effect=Exception(val)) as req_mock, self.assertLogs() as logcm: tasks.acme_validate_challenge(self.chall.pk) self.assertInvalid() - self.assertEqual(req_mock.mock_calls, [((self.url,), {"timeout": 1, "stream": True})]) - self.assertEqual(len(logcm.output), 2) - self.assertIn(val, logcm.output[0]) - self.assertEqual(logcm.output[1], f"INFO:django_ca.tasks:{self.chall!s} is invalid") + assert req_mock.mock_calls == [((self.url,), {"timeout": 1, "stream": True})] + assert len(logcm.output) == 2 + assert val in logcm.output[0] + assert logcm.output[1] == f"INFO:django_ca.tasks:{self.chall!s} is invalid" @freeze_time(TIMESTAMPS["everything_valid"]) @@ -332,7 +326,7 @@ def mock_challenge( # Note: Only assert the first two parameters, as otherwise we'd test dnspython internals resolve_cm.assert_called_once() expected = (f"_acme_challenge.{domain}", "TXT") - self.assertEqual(resolve_cm.call_args_list[0].args[:2], expected) + assert resolve_cm.call_args_list[0].args[:2] == expected def test_nxdomain(self) -> None: """Test a ACME validation where the domain does not exist.""" @@ -348,14 +342,11 @@ def test_nxdomain(self) -> None: exp = self.chall.expected.decode("ascii") acme_domain = f"_acme_challenge.{domain}" logger = "django_ca.acme.validation" - self.assertEqual( - logcm.output, - [ - f"INFO:{logger}:DNS-01 validation of {domain}: Expect {exp} on {acme_domain}", - f"DEBUG:{logger}:TXT {acme_domain}: record does not exist.", - f"INFO:django_ca.tasks:{self.chall!s} is invalid", - ], - ) + assert logcm.output == [ + f"INFO:{logger}:DNS-01 validation of {domain}: Expect {exp} on {acme_domain}", + f"DEBUG:{logger}:TXT {acme_domain}: record does not exist.", + f"INFO:django_ca.tasks:{self.chall!s} is invalid", + ] @freeze_time(TIMESTAMPS["everything_valid"]) @@ -385,7 +376,7 @@ def test_acme_disabled(self) -> None: """Test invoking task when ACME support is not enabled.""" with self.settings(CA_ENABLE_ACME=False), self.assertLogs() as logcm: tasks.acme_issue_certificate(self.acme_cert.pk) - self.assertEqual(logcm.output, ["ERROR:django_ca.tasks:ACME is not enabled."]) + assert logcm.output == ["ERROR:django_ca.tasks:ACME is not enabled."] def test_unknown_certificate(self) -> None: """Test invoking task with an unknown cert.""" @@ -393,9 +384,7 @@ def test_unknown_certificate(self) -> None: with self.assertLogs() as logcm: tasks.acme_issue_certificate(self.acme_cert.pk) - self.assertEqual( - logcm.output, [f"ERROR:django_ca.tasks:Certificate with id={self.acme_cert.pk} not found"] - ) + assert logcm.output == [f"ERROR:django_ca.tasks:Certificate with id={self.acme_cert.pk} not found"] def test_unusable_cert(self) -> None: """Test invoking task where the order is not usable.""" @@ -405,9 +394,9 @@ def test_unusable_cert(self) -> None: with self.assertLogs() as logcm: tasks.acme_issue_certificate(self.acme_cert.pk) - self.assertEqual( - logcm.output, [f"ERROR:django_ca.tasks:{self.order}: Cannot issue certificate for this order"] - ) + assert logcm.output == [ + f"ERROR:django_ca.tasks:{self.order}: Cannot issue certificate for this order" + ] @override_tmpcadir() def test_basic(self) -> None: @@ -415,22 +404,20 @@ def test_basic(self) -> None: with self.assertLogs() as logcm: tasks.acme_issue_certificate(self.acme_cert.pk) - self.assertEqual( - logcm.output, [f"INFO:django_ca.tasks:{self.order}: Issuing certificate for dns:{self.hostname}"] - ) + assert logcm.output == [ + f"INFO:django_ca.tasks:{self.order}: Issuing certificate for dns:{self.hostname}" + ] + self.acme_cert.refresh_from_db() assert self.acme_cert.cert is not None, "Check to make mypy happy" self.order.refresh_from_db() - self.assertEqual(self.order.status, AcmeOrder.STATUS_VALID) - self.assertEqual( - self.acme_cert.cert.extensions[ExtensionOID.SUBJECT_ALTERNATIVE_NAME], - subject_alternative_name(x509.DNSName(self.hostname)), - ) - self.assertEqual( - self.acme_cert.cert.not_after, timezone.now() + model_settings.CA_ACME_DEFAULT_CERT_VALIDITY - ) - self.assertEqual(self.acme_cert.cert.cn, self.hostname) - self.assertEqual(self.acme_cert.cert.profile, model_settings.CA_DEFAULT_PROFILE) + assert self.order.status == AcmeOrder.STATUS_VALID + assert self.acme_cert.cert.extensions[ + ExtensionOID.SUBJECT_ALTERNATIVE_NAME + ] == subject_alternative_name(x509.DNSName(self.hostname)) + assert self.acme_cert.cert.not_after == timezone.now() + model_settings.CA_ACME_DEFAULT_CERT_VALIDITY + assert self.acme_cert.cert.cn == self.hostname + assert self.acme_cert.cert.profile == model_settings.CA_DEFAULT_PROFILE @override_settings(USE_TZ=False) def test_basic_without_timezone_support(self) -> None: @@ -449,15 +436,13 @@ def test_two_hostnames(self) -> None: self.acme_cert.refresh_from_db() assert self.acme_cert.cert is not None, "Check to make mypy happy" self.order.refresh_from_db() - self.assertEqual(self.order.status, AcmeOrder.STATUS_VALID) - self.assertEqual( - self.acme_cert.cert.extensions[ExtensionOID.SUBJECT_ALTERNATIVE_NAME], - subject_alternative_name(x509.DNSName(self.hostname), x509.DNSName(hostname2)), - ) - self.assertEqual( - self.acme_cert.cert.not_after, timezone.now() + model_settings.CA_ACME_DEFAULT_CERT_VALIDITY - ) - self.assertIn(self.acme_cert.cert.cn, [self.hostname, hostname2]) + assert self.order.status == AcmeOrder.STATUS_VALID + assert self.acme_cert.cert.extensions[ + ExtensionOID.SUBJECT_ALTERNATIVE_NAME + ] == subject_alternative_name(x509.DNSName(self.hostname), x509.DNSName(hostname2)) + + assert self.acme_cert.cert.not_after == timezone.now() + model_settings.CA_ACME_DEFAULT_CERT_VALIDITY + assert self.acme_cert.cert.cn in [self.hostname, hostname2] @override_tmpcadir() def test_not_after(self) -> None: @@ -469,19 +454,20 @@ def test_not_after(self) -> None: with self.assertLogs() as logcm: tasks.acme_issue_certificate(self.acme_cert.pk) - self.assertEqual( - logcm.output, [f"INFO:django_ca.tasks:{self.order}: Issuing certificate for dns:{self.hostname}"] - ) + assert logcm.output == [ + f"INFO:django_ca.tasks:{self.order}: Issuing certificate for dns:{self.hostname}" + ] + self.acme_cert.refresh_from_db() assert self.acme_cert.cert is not None, "Check to make mypy happy" self.order.refresh_from_db() - self.assertEqual(self.order.status, AcmeOrder.STATUS_VALID) - self.assertEqual( - self.acme_cert.cert.extensions[ExtensionOID.SUBJECT_ALTERNATIVE_NAME], - subject_alternative_name(x509.DNSName(self.hostname)), - ) - self.assertEqual(self.acme_cert.cert.not_after, not_after) - self.assertEqual(self.acme_cert.cert.cn, self.hostname) + assert self.order.status == AcmeOrder.STATUS_VALID + assert self.acme_cert.cert.extensions[ + ExtensionOID.SUBJECT_ALTERNATIVE_NAME + ] == subject_alternative_name(x509.DNSName(self.hostname)) + + assert self.acme_cert.cert.not_after == not_after + assert self.acme_cert.cert.cn == self.hostname def test_not_after_with_use_tz_is_false(self) -> None: """Test not_after with USE_TZ=False.""" @@ -498,22 +484,22 @@ def test_profile(self) -> None: with self.assertLogs() as logcm: tasks.acme_issue_certificate(self.acme_cert.pk) - self.assertEqual( - logcm.output, [f"INFO:django_ca.tasks:{self.order}: Issuing certificate for dns:{self.hostname}"] - ) + assert logcm.output == [ + f"INFO:django_ca.tasks:{self.order}: Issuing certificate for dns:{self.hostname}" + ] + self.acme_cert.refresh_from_db() assert self.acme_cert.cert is not None, "Check to make mypy happy" self.order.refresh_from_db() - self.assertEqual(self.order.status, AcmeOrder.STATUS_VALID) - self.assertEqual( - self.acme_cert.cert.extensions[ExtensionOID.SUBJECT_ALTERNATIVE_NAME], - subject_alternative_name(x509.DNSName(self.hostname)), - ) - self.assertEqual( - self.acme_cert.cert.not_after, timezone.now() + model_settings.CA_ACME_DEFAULT_CERT_VALIDITY - ) - self.assertEqual(self.acme_cert.cert.cn, self.hostname) - self.assertEqual(self.acme_cert.cert.profile, "client") + assert self.order.status == AcmeOrder.STATUS_VALID + assert self.acme_cert.cert.extensions[ + ExtensionOID.SUBJECT_ALTERNATIVE_NAME + ] == subject_alternative_name(x509.DNSName(self.hostname)) + + assert self.acme_cert.cert.not_after == timezone.now() + model_settings.CA_ACME_DEFAULT_CERT_VALIDITY + + assert self.acme_cert.cert.cn == self.hostname + assert self.acme_cert.cert.profile == "client" @freeze_time(TIMESTAMPS["everything_valid"]) @@ -546,27 +532,27 @@ def test_basic(self) -> None: """Basic test.""" tasks.acme_cleanup() # does nothing if nothing is expired - self.assertEqual(self.acme_cert, AcmeCertificate.objects.get(pk=self.acme_cert.pk)) - self.assertEqual(self.order, AcmeOrder.objects.get(pk=self.order.pk)) - self.assertEqual(self.auth, AcmeAuthorization.objects.get(pk=self.auth.pk)) - self.assertEqual(self.account, AcmeAccount.objects.get(pk=self.account.pk)) + assert self.acme_cert == AcmeCertificate.objects.get(pk=self.acme_cert.pk) + assert self.order == AcmeOrder.objects.get(pk=self.order.pk) + assert self.auth == AcmeAuthorization.objects.get(pk=self.auth.pk) + assert self.account == AcmeAccount.objects.get(pk=self.account.pk) with self.freeze_time(timezone.now() + timedelta(days=3)): tasks.acme_cleanup() - self.assertEqual(AcmeOrder.objects.all().count(), 0) - self.assertEqual(AcmeAuthorization.objects.all().count(), 0) - self.assertEqual(AcmeChallenge.objects.all().count(), 0) - self.assertEqual(AcmeCertificate.objects.all().count(), 0) + assert AcmeOrder.objects.all().count() == 0 + assert AcmeAuthorization.objects.all().count() == 0 + assert AcmeChallenge.objects.all().count() == 0 + assert AcmeCertificate.objects.all().count() == 0 def test_acme_disabled(self) -> None: """Test task when ACME is disabled.""" with self.settings(CA_ENABLE_ACME=False), self.assertLogs() as logcm: with self.freeze_time(timezone.now() + timedelta(days=3)): tasks.acme_cleanup() - self.assertEqual(logcm.output, ["INFO:django_ca.tasks:ACME is not enabled, not doing anything."]) + assert logcm.output == ["INFO:django_ca.tasks:ACME is not enabled, not doing anything."] - self.assertEqual(AcmeOrder.objects.all().count(), 1) - self.assertEqual(AcmeAuthorization.objects.all().count(), 1) - self.assertEqual(AcmeChallenge.objects.all().count(), 1) - self.assertEqual(AcmeCertificate.objects.all().count(), 1) + assert AcmeOrder.objects.all().count() == 1 + assert AcmeAuthorization.objects.all().count() == 1 + assert AcmeChallenge.objects.all().count() == 1 + assert AcmeCertificate.objects.all().count() == 1 diff --git a/ca/django_ca/tests/test_typehints.py b/ca/django_ca/tests/test_typehints.py index 43cc16c88..d84c906e3 100644 --- a/ca/django_ca/tests/test_typehints.py +++ b/ca/django_ca/tests/test_typehints.py @@ -37,7 +37,7 @@ def test_end_entity_certificate_extension_keys() -> None: @pytest.mark.parametrize( - "extension_types,extensions", + ("extension_types", "extensions"), ( (typehints.ConfigurableExtensionType, typehints.ConfigurableExtension), (typehints.EndEntityCertificateExtensionType, typehints.EndEntityCertificateExtension), diff --git a/ca/django_ca/tests/test_utils.py b/ca/django_ca/tests/test_utils.py index dd944113d..632cb32f7 100644 --- a/ca/django_ca/tests/test_utils.py +++ b/ca/django_ca/tests/test_utils.py @@ -78,8 +78,8 @@ def test_read_file(tmpcadir: Path) -> None: @pytest.mark.parametrize( - "attributes,expected", - [ + ("attributes", "expected"), + ( ([(NameOID.COMMON_NAME, "example.com")], [cn("example.com")]), ( [(NameOID.COUNTRY_NAME, "AT"), (NameOID.COMMON_NAME, "example.com")], @@ -89,7 +89,7 @@ def test_read_file(tmpcadir: Path) -> None: [(NameOID.X500_UNIQUE_IDENTIFIER, "65:78:61:6D:70:6C:65")], [x509.NameAttribute(NameOID.X500_UNIQUE_IDENTIFIER, b"example", _type=_ASN1Type.BitString)], ), - ], + ), ) def test_parse_serialized_name_attributes( attributes: list[tuple[x509.ObjectIdentifier, str]], expected: list[x509.NameAttribute] @@ -107,26 +107,26 @@ class GeneratePrivateKeyTestCase(TestCase): def test_key_types(self) -> None: """Test generating various private key types.""" ec_key = generate_private_key(None, "EC", ec.BrainpoolP256R1()) - self.assertIsInstance(ec_key, ec.EllipticCurvePrivateKey) - self.assertIsInstance(ec_key.curve, ec.BrainpoolP256R1) + assert isinstance(ec_key, ec.EllipticCurvePrivateKey) + assert isinstance(ec_key.curve, ec.BrainpoolP256R1) ed448_key = generate_private_key(None, "Ed448", None) - self.assertIsInstance(ed448_key, ed448.Ed448PrivateKey) + assert isinstance(ed448_key, ed448.Ed448PrivateKey) def test_dsa_default_key_size(self) -> None: """Test the default DSA key size.""" key = generate_private_key(None, "DSA", None) - self.assertIsInstance(key, dsa.DSAPrivateKey) - self.assertEqual(key.key_size, model_settings.CA_DEFAULT_KEY_SIZE) + assert isinstance(key, dsa.DSAPrivateKey) + assert key.key_size == model_settings.CA_DEFAULT_KEY_SIZE def test_invalid_type(self) -> None: """Test passing an invalid key type.""" - with self.assertRaisesRegex(ValueError, r"^FOO: Unknown key type\.$"): + with pytest.raises(ValueError, match=r"^FOO: Unknown key type\.$"): generate_private_key(16, "FOO", None) # type: ignore[call-overload] @pytest.mark.parametrize( - "general_name,expected", + ("general_name", "expected"), ( (dns("example.com"), "DNS:example.com"), (x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")), "IP:127.0.0.1"), @@ -147,14 +147,11 @@ class SerializeName(TestCase): def test_name(self) -> None: """Test passing a standard Name.""" - self.assertEqual( - serialize_name(x509.Name([cn("example.com")])), - [{"oid": "2.5.4.3", "value": "example.com"}], - ) - self.assertEqual( - serialize_name(x509.Name([country("AT"), cn("example.com")])), - [{"oid": "2.5.4.6", "value": "AT"}, {"oid": "2.5.4.3", "value": "example.com"}], - ) + assert serialize_name(x509.Name([cn("example.com")])) == [{"oid": "2.5.4.3", "value": "example.com"}] + assert serialize_name(x509.Name([country("AT"), cn("example.com")])) == [ + {"oid": "2.5.4.6", "value": "AT"}, + {"oid": "2.5.4.3", "value": "example.com"}, + ] @unittest.skipIf(CRYPTOGRAPHY_VERSION < (37, 0), "cg<36 does not yet have bytes.") def test_bytes(self) -> None: @@ -162,19 +159,12 @@ def test_bytes(self) -> None: name = x509.Name( [x509.NameAttribute(NameOID.X500_UNIQUE_IDENTIFIER, b"example.com", _type=_ASN1Type.BitString)] ) - self.assertEqual( - serialize_name(name), [{"oid": "2.5.4.45", "value": "65:78:61:6D:70:6C:65:2E:63:6F:6D"}] - ) + assert serialize_name(name) == [{"oid": "2.5.4.45", "value": "65:78:61:6D:70:6C:65:2E:63:6F:6D"}] @pytest.mark.parametrize( - "value,expected", - ( - ("PEM", Encoding.PEM), - ("DER", Encoding.DER), - ("ASN1", Encoding.DER), - ("OpenSSH", Encoding.OpenSSH), - ), + ("value", "expected"), + (("PEM", Encoding.PEM), ("DER", Encoding.DER), ("ASN1", Encoding.DER), ("OpenSSH", Encoding.OpenSSH)), ) def test_parse_encoding(value: Any, expected: Encoding) -> None: """Test :py:func:`django_ca.utils.parse_encoding`.""" @@ -192,36 +182,36 @@ class AddColonsTestCase(TestCase): def test_basic(self) -> None: """Some basic tests.""" - self.assertEqual(utils.add_colons(""), "") - self.assertEqual(utils.add_colons("a"), "0a") - self.assertEqual(utils.add_colons("ab"), "ab") - self.assertEqual(utils.add_colons("abc"), "0a:bc") - self.assertEqual(utils.add_colons("abcd"), "ab:cd") - self.assertEqual(utils.add_colons("abcde"), "0a:bc:de") - self.assertEqual(utils.add_colons("abcdef"), "ab:cd:ef") - self.assertEqual(utils.add_colons("abcdefg"), "0a:bc:de:fg") + assert utils.add_colons("") == "" + assert utils.add_colons("a") == "0a" + assert utils.add_colons("ab") == "ab" + assert utils.add_colons("abc") == "0a:bc" + assert utils.add_colons("abcd") == "ab:cd" + assert utils.add_colons("abcde") == "0a:bc:de" + assert utils.add_colons("abcdef") == "ab:cd:ef" + assert utils.add_colons("abcdefg") == "0a:bc:de:fg" def test_pad(self) -> None: """Test padding.""" - self.assertEqual(utils.add_colons("a", pad="z"), "za") - self.assertEqual(utils.add_colons("ab", pad="z"), "ab") - self.assertEqual(utils.add_colons("abc", pad="z"), "za:bc") + assert utils.add_colons("a", pad="z") == "za" + assert utils.add_colons("ab", pad="z") == "ab" + assert utils.add_colons("abc", pad="z") == "za:bc" def test_no_pad(self) -> None: """Test disabling padding.""" - self.assertEqual(utils.add_colons("a", pad=""), "a") - self.assertEqual(utils.add_colons("ab", pad=""), "ab") - self.assertEqual(utils.add_colons("abc", pad=""), "ab:c") + assert utils.add_colons("a", pad="") == "a" + assert utils.add_colons("ab", pad="") == "ab" + assert utils.add_colons("abc", pad="") == "ab:c" def test_zero_padding(self) -> None: """Test when there is no padding.""" - self.assertEqual( - utils.add_colons("F570A555BC5000FA301E8C75FFB31684FCF64436"), - "F5:70:A5:55:BC:50:00:FA:30:1E:8C:75:FF:B3:16:84:FC:F6:44:36", + assert ( + utils.add_colons("F570A555BC5000FA301E8C75FFB31684FCF64436") + == "F5:70:A5:55:BC:50:00:FA:30:1E:8C:75:FF:B3:16:84:FC:F6:44:36" ) - self.assertEqual( - utils.add_colons("85BDA79A857379A4C9E910DAEA21C896D16394"), - "85:BD:A7:9A:85:73:79:A4:C9:E9:10:DA:EA:21:C8:96:D1:63:94", + assert ( + utils.add_colons("85BDA79A857379A4C9E910DAEA21C896D16394") + == "85:BD:A7:9A:85:73:79:A4:C9:E9:10:DA:EA:21:C8:96:D1:63:94" ) @@ -230,75 +220,75 @@ class IntToHexTestCase(TestCase): def test_basic(self) -> None: """Test the first view numbers.""" - self.assertEqual(utils.int_to_hex(0), "0") - self.assertEqual(utils.int_to_hex(1), "1") - self.assertEqual(utils.int_to_hex(2), "2") - self.assertEqual(utils.int_to_hex(3), "3") - self.assertEqual(utils.int_to_hex(4), "4") - self.assertEqual(utils.int_to_hex(5), "5") - self.assertEqual(utils.int_to_hex(6), "6") - self.assertEqual(utils.int_to_hex(7), "7") - self.assertEqual(utils.int_to_hex(8), "8") - self.assertEqual(utils.int_to_hex(9), "9") - self.assertEqual(utils.int_to_hex(10), "A") - self.assertEqual(utils.int_to_hex(11), "B") - self.assertEqual(utils.int_to_hex(12), "C") - self.assertEqual(utils.int_to_hex(13), "D") - self.assertEqual(utils.int_to_hex(14), "E") - self.assertEqual(utils.int_to_hex(15), "F") - self.assertEqual(utils.int_to_hex(16), "10") - self.assertEqual(utils.int_to_hex(17), "11") - self.assertEqual(utils.int_to_hex(18), "12") - self.assertEqual(utils.int_to_hex(19), "13") - self.assertEqual(utils.int_to_hex(20), "14") - self.assertEqual(utils.int_to_hex(21), "15") - self.assertEqual(utils.int_to_hex(22), "16") - self.assertEqual(utils.int_to_hex(23), "17") - self.assertEqual(utils.int_to_hex(24), "18") - self.assertEqual(utils.int_to_hex(25), "19") - self.assertEqual(utils.int_to_hex(26), "1A") - self.assertEqual(utils.int_to_hex(27), "1B") - self.assertEqual(utils.int_to_hex(28), "1C") - self.assertEqual(utils.int_to_hex(29), "1D") - self.assertEqual(utils.int_to_hex(30), "1E") - self.assertEqual(utils.int_to_hex(31), "1F") - self.assertEqual(utils.int_to_hex(32), "20") - self.assertEqual(utils.int_to_hex(33), "21") - self.assertEqual(utils.int_to_hex(34), "22") - self.assertEqual(utils.int_to_hex(35), "23") - self.assertEqual(utils.int_to_hex(36), "24") - self.assertEqual(utils.int_to_hex(37), "25") - self.assertEqual(utils.int_to_hex(38), "26") - self.assertEqual(utils.int_to_hex(39), "27") - self.assertEqual(utils.int_to_hex(40), "28") - self.assertEqual(utils.int_to_hex(41), "29") - self.assertEqual(utils.int_to_hex(42), "2A") - self.assertEqual(utils.int_to_hex(43), "2B") - self.assertEqual(utils.int_to_hex(44), "2C") - self.assertEqual(utils.int_to_hex(45), "2D") - self.assertEqual(utils.int_to_hex(46), "2E") - self.assertEqual(utils.int_to_hex(47), "2F") - self.assertEqual(utils.int_to_hex(48), "30") - self.assertEqual(utils.int_to_hex(49), "31") + assert utils.int_to_hex(0) == "0" + assert utils.int_to_hex(1) == "1" + assert utils.int_to_hex(2) == "2" + assert utils.int_to_hex(3) == "3" + assert utils.int_to_hex(4) == "4" + assert utils.int_to_hex(5) == "5" + assert utils.int_to_hex(6) == "6" + assert utils.int_to_hex(7) == "7" + assert utils.int_to_hex(8) == "8" + assert utils.int_to_hex(9) == "9" + assert utils.int_to_hex(10) == "A" + assert utils.int_to_hex(11) == "B" + assert utils.int_to_hex(12) == "C" + assert utils.int_to_hex(13) == "D" + assert utils.int_to_hex(14) == "E" + assert utils.int_to_hex(15) == "F" + assert utils.int_to_hex(16) == "10" + assert utils.int_to_hex(17) == "11" + assert utils.int_to_hex(18) == "12" + assert utils.int_to_hex(19) == "13" + assert utils.int_to_hex(20) == "14" + assert utils.int_to_hex(21) == "15" + assert utils.int_to_hex(22) == "16" + assert utils.int_to_hex(23) == "17" + assert utils.int_to_hex(24) == "18" + assert utils.int_to_hex(25) == "19" + assert utils.int_to_hex(26) == "1A" + assert utils.int_to_hex(27) == "1B" + assert utils.int_to_hex(28) == "1C" + assert utils.int_to_hex(29) == "1D" + assert utils.int_to_hex(30) == "1E" + assert utils.int_to_hex(31) == "1F" + assert utils.int_to_hex(32) == "20" + assert utils.int_to_hex(33) == "21" + assert utils.int_to_hex(34) == "22" + assert utils.int_to_hex(35) == "23" + assert utils.int_to_hex(36) == "24" + assert utils.int_to_hex(37) == "25" + assert utils.int_to_hex(38) == "26" + assert utils.int_to_hex(39) == "27" + assert utils.int_to_hex(40) == "28" + assert utils.int_to_hex(41) == "29" + assert utils.int_to_hex(42) == "2A" + assert utils.int_to_hex(43) == "2B" + assert utils.int_to_hex(44) == "2C" + assert utils.int_to_hex(45) == "2D" + assert utils.int_to_hex(46) == "2E" + assert utils.int_to_hex(47) == "2F" + assert utils.int_to_hex(48) == "30" + assert utils.int_to_hex(49) == "31" def test_high(self) -> None: """Test some high numbers.""" - self.assertEqual(utils.int_to_hex(1513282098), "5A32DA32") - self.assertEqual(utils.int_to_hex(1513282099), "5A32DA33") - self.assertEqual(utils.int_to_hex(1513282100), "5A32DA34") - self.assertEqual(utils.int_to_hex(1513282101), "5A32DA35") - self.assertEqual(utils.int_to_hex(1513282102), "5A32DA36") - self.assertEqual(utils.int_to_hex(1513282103), "5A32DA37") - self.assertEqual(utils.int_to_hex(1513282104), "5A32DA38") - self.assertEqual(utils.int_to_hex(1513282105), "5A32DA39") - self.assertEqual(utils.int_to_hex(1513282106), "5A32DA3A") - self.assertEqual(utils.int_to_hex(1513282107), "5A32DA3B") - self.assertEqual(utils.int_to_hex(1513282108), "5A32DA3C") - self.assertEqual(utils.int_to_hex(1513282109), "5A32DA3D") - self.assertEqual(utils.int_to_hex(1513282110), "5A32DA3E") - self.assertEqual(utils.int_to_hex(1513282111), "5A32DA3F") - self.assertEqual(utils.int_to_hex(1513282112), "5A32DA40") - self.assertEqual(utils.int_to_hex(1513282113), "5A32DA41") + assert utils.int_to_hex(1513282098) == "5A32DA32" + assert utils.int_to_hex(1513282099) == "5A32DA33" + assert utils.int_to_hex(1513282100) == "5A32DA34" + assert utils.int_to_hex(1513282101) == "5A32DA35" + assert utils.int_to_hex(1513282102) == "5A32DA36" + assert utils.int_to_hex(1513282103) == "5A32DA37" + assert utils.int_to_hex(1513282104) == "5A32DA38" + assert utils.int_to_hex(1513282105) == "5A32DA39" + assert utils.int_to_hex(1513282106) == "5A32DA3A" + assert utils.int_to_hex(1513282107) == "5A32DA3B" + assert utils.int_to_hex(1513282108) == "5A32DA3C" + assert utils.int_to_hex(1513282109) == "5A32DA3D" + assert utils.int_to_hex(1513282110) == "5A32DA3E" + assert utils.int_to_hex(1513282111) == "5A32DA3F" + assert utils.int_to_hex(1513282112) == "5A32DA40" + assert utils.int_to_hex(1513282113) == "5A32DA41" class BytesToHexTestCase(TestCase): @@ -306,11 +296,11 @@ class BytesToHexTestCase(TestCase): def test_basic(self) -> None: """Some basic test cases.""" - self.assertEqual(bytes_to_hex(b"test"), "74:65:73:74") - self.assertEqual(bytes_to_hex(b"foo"), "66:6F:6F") - self.assertEqual(bytes_to_hex(b"bar"), "62:61:72") - self.assertEqual(bytes_to_hex(b""), "") - self.assertEqual(bytes_to_hex(b"a"), "61") + assert bytes_to_hex(b"test") == "74:65:73:74" + assert bytes_to_hex(b"foo") == "66:6F:6F" + assert bytes_to_hex(b"bar") == "62:61:72" + assert bytes_to_hex(b"") == "" + assert bytes_to_hex(b"a") == "61" class SanitizeSerialTestCase(TestCase): @@ -318,23 +308,23 @@ class SanitizeSerialTestCase(TestCase): def test_already_sanitized(self) -> None: """Test some already sanitized input.""" - self.assertEqual(utils.sanitize_serial("A"), "A") - self.assertEqual(utils.sanitize_serial("5A32DA3B"), "5A32DA3B") - self.assertEqual(utils.sanitize_serial("1234567890ABCDEF"), "1234567890ABCDEF") + assert utils.sanitize_serial("A") == "A" + assert utils.sanitize_serial("5A32DA3B") == "5A32DA3B" + assert utils.sanitize_serial("1234567890ABCDEF") == "1234567890ABCDEF" def test_sanitized(self) -> None: """Test some input that can be correctly sanitized.""" - self.assertEqual(utils.sanitize_serial("5A:32:DA:3B"), "5A32DA3B") - self.assertEqual(utils.sanitize_serial("0A:32:DA:3B"), "A32DA3B") - self.assertEqual(utils.sanitize_serial("0a:32:da:3b"), "A32DA3B") + assert utils.sanitize_serial("5A:32:DA:3B") == "5A32DA3B" + assert utils.sanitize_serial("0A:32:DA:3B") == "A32DA3B" + assert utils.sanitize_serial("0a:32:da:3b") == "A32DA3B" def test_zero(self) -> None: """An imported CA might have a serial of just a ``0``, so it must not be stripped.""" - self.assertEqual(utils.sanitize_serial("0"), "0") + assert utils.sanitize_serial("0") == "0" def test_invalid_input(self) -> None: """Test some input that raises an exception.""" - with self.assertRaisesRegex(ValueError, r"^ABCXY: Serial has invalid characters$"): + with pytest.raises(ValueError, match=r"^ABCXY: Serial has invalid characters$"): utils.sanitize_serial("ABCXY") @@ -364,13 +354,13 @@ def test_str(self) -> None: ("CN", "example.com"), ("emailAddress", "user@example.com"), ] - self.assertEqual(x509_name(subject), self.name) + assert x509_name(subject) == self.name def test_multiple_other(self) -> None: """Test multiple other tokens (only OUs work).""" - with self.assertRaisesRegex(ValueError, '^Subject contains multiple "countryName" fields$'): + with pytest.raises(ValueError, match='^Subject contains multiple "countryName" fields$'): x509_name([("C", "AT"), ("C", "DE")]) - with self.assertRaisesRegex(ValueError, '^Subject contains multiple "commonName" fields$'): + with pytest.raises(ValueError, match='^Subject contains multiple "commonName" fields$'): x509_name([("CN", "AT"), ("CN", "FOO")]) @@ -400,7 +390,7 @@ def assertMerged( # pylint: disable=invalid-name # unittest standard base_name = x509.Name(base) update_name = x509.Name(update) merged_name = x509.Name(merged) - self.assertEqual(merge_x509_names(base_name, update_name), merged_name) + assert merge_x509_names(base_name, update_name) == merged_name def test_full_merge(self) -> None: """Test a basic merge.""" @@ -442,9 +432,9 @@ def test_unsortable_values(self) -> None: """Test merging unsortable values.""" sortable = x509.Name([self.cc1, self.common_name1]) unsortable = x509.Name([self.cc1, x509.NameAttribute(NameOID.INN, "unsortable")]) - with self.assertRaisesRegex(ValueError, r"Unsortable name"): + with pytest.raises(ValueError, match=r"Unsortable name"): merge_x509_names(unsortable, sortable) - with self.assertRaisesRegex(ValueError, r"Unsortable name"): + with pytest.raises(ValueError, match=r"Unsortable name"): merge_x509_names(sortable, unsortable) @@ -462,47 +452,43 @@ def test_basic(self) -> None: # pylint: disable=protected-access; only way to test builder attributes after = datetime(2020, 10, 23, 11, 21, tzinfo=tz.utc) builder = get_cert_builder(after) - self.assertEqual(builder._not_valid_before, datetime(2018, 11, 3, 11, 21)) - self.assertEqual(builder._not_valid_after, datetime(2020, 10, 23, 11, 21)) - self.assertIsInstance(builder._serial_number, int) + assert builder._not_valid_before == datetime(2018, 11, 3, 11, 21) + assert builder._not_valid_after == datetime(2020, 10, 23, 11, 21) + assert isinstance(builder._serial_number, int) @freeze_time("2021-01-23 14:42:11.1234") def test_datetime(self) -> None: """Basic tests.""" expires = datetime.now(tz.utc) + timedelta(days=10) - self.assertNotEqual(expires.second, 0) - self.assertNotEqual(expires.microsecond, 0) + assert expires.second != 0 + assert expires.microsecond != 0 expires_expected = datetime(2021, 2, 2, 14, 42) builder = get_cert_builder(expires) - self.assertEqual(builder._not_valid_after, expires_expected) # pylint: disable=protected-access - self.assertIsInstance(builder._serial_number, int) # pylint: disable=protected-access + assert builder._not_valid_after == expires_expected # pylint: disable=protected-access + assert isinstance(builder._serial_number, int) # pylint: disable=protected-access @freeze_time("2021-01-23 14:42:11.1234") def test_serial(self) -> None: """Test manually setting a serial.""" after = datetime(2022, 10, 23, 11, 21, tzinfo=tz.utc) builder = get_cert_builder(after, serial=123) - self.assertEqual(builder._serial_number, 123) # pylint: disable=protected-access - self.assertEqual( - builder._not_valid_after, # pylint: disable=protected-access - datetime(2022, 10, 23, 11, 21), - ) + assert builder._serial_number == 123 # pylint: disable=protected-access + assert builder._not_valid_after == datetime(2022, 10, 23, 11, 21) # pylint: disable=protected-access @freeze_time("2021-01-23 14:42:11") def test_negative_datetime(self) -> None: """Test passing a datetime in the past.""" - msg = r"^not_after must be in the future$" - with self.assertRaisesRegex(ValueError, msg): + with pytest.raises(ValueError, match=r"^not_after must be in the future$"): get_cert_builder(datetime.now(tz.utc) - timedelta(seconds=60)) def test_invalid_type(self) -> None: """Test passing an invalid type.""" - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): get_cert_builder("a string") # type: ignore[arg-type] def test_naive_datetime(self) -> None: """Test passing a naive datetime.""" - with self.assertRaisesRegex(ValueError, r"^not_after must not be a naive datetime$"): + with pytest.raises(ValueError, match=r"^not_after must not be a naive datetime$"): get_cert_builder(datetime.now()) @@ -530,40 +516,40 @@ def test_default_parameters(self) -> None: def test_valid_parameters(self) -> None: """Test valid parameters.""" - self.assertEqual((8192, None), validate_private_key_parameters("RSA", 8192, None)) - self.assertEqual((8192, None), validate_private_key_parameters("DSA", 8192, None)) + assert validate_private_key_parameters("RSA", 8192, None) == (8192, None) + assert validate_private_key_parameters("DSA", 8192, None) == (8192, None) key_size, elliptic_curve = validate_private_key_parameters("EC", None, ec.BrainpoolP384R1()) - self.assertIsNone(key_size) - self.assertIsInstance(elliptic_curve, ec.BrainpoolP384R1) + assert key_size is None + assert isinstance(elliptic_curve, ec.BrainpoolP384R1) def test_wrong_values(self) -> None: """Test validating various bogus values.""" key_size = model_settings.CA_DEFAULT_KEY_SIZE elliptic_curve = model_settings.CA_DEFAULT_ELLIPTIC_CURVE - with self.assertRaisesRegex(ValueError, "^FOOBAR: Unknown key type$"): + with pytest.raises(ValueError, match="^FOOBAR: Unknown key type$"): validate_private_key_parameters("FOOBAR", 4096, None) # type: ignore[call-overload] - with self.assertRaisesRegex(ValueError, r"^foo: Key size must be an int\.$"): + with pytest.raises(ValueError, match=r"^foo: Key size must be an int\.$"): validate_private_key_parameters("RSA", "foo", None) # type: ignore[call-overload] - with self.assertRaisesRegex(ValueError, "^4000: Key size must be a power of two$"): + with pytest.raises(ValueError, match="^4000: Key size must be a power of two$"): validate_private_key_parameters("RSA", 4000, None) - with self.assertRaisesRegex(ValueError, "^16: Key size must be least 1024 bits$"): + with pytest.raises(ValueError, match="^16: Key size must be least 1024 bits$"): validate_private_key_parameters("RSA", 16, None) - with self.assertRaisesRegex(ValueError, r"^Key size is not supported for EC keys\.$"): + with pytest.raises(ValueError, match=r"^Key size is not supported for EC keys\.$"): validate_private_key_parameters("EC", key_size, elliptic_curve) - with self.assertRaisesRegex(ValueError, r"^secp192r1: Must be a subclass of ec\.EllipticCurve$"): + with pytest.raises(ValueError, match=r"^secp192r1: Must be a subclass of ec\.EllipticCurve$"): validate_private_key_parameters("EC", None, "secp192r1") # type: ignore for key_type in ("Ed448", "Ed25519"): - with self.assertRaisesRegex(ValueError, rf"^Key size is not supported for {key_type} keys\.$"): + with pytest.raises(ValueError, match=rf"^Key size is not supported for {key_type} keys\.$"): validate_private_key_parameters(key_type, key_size, None) # type: ignore - with self.assertRaisesRegex( - ValueError, rf"^Elliptic curves are not supported for {key_type} keys\.$" + with pytest.raises( + ValueError, match=rf"^Elliptic curves are not supported for {key_type} keys\.$" ): validate_private_key_parameters(key_type, None, elliptic_curve) # type: ignore @@ -581,14 +567,14 @@ def test_valid_parameters(self) -> None: def test_invalid_parameters(self) -> None: """Test invalid parameters.""" - with self.assertRaisesRegex(ValueError, "^FOOBAR: Unknown key type$"): + with pytest.raises(ValueError, match="^FOOBAR: Unknown key type$"): validate_public_key_parameters("FOOBAR", None) # type: ignore[arg-type] for key_type in ("RSA", "DSA", "EC"): msg = rf"^{key_type}: algorithm must be an instance of hashes.HashAlgorithm\.$" - with self.assertRaisesRegex(ValueError, msg): + with pytest.raises(ValueError, match=msg): validate_public_key_parameters(key_type, True) # type: ignore[arg-type] for key_type in ("Ed448", "Ed25519"): msg = rf"^{key_type} keys do not allow an algorithm for signing\.$" - with self.assertRaisesRegex(ValueError, msg): + with pytest.raises(ValueError, match=msg): validate_public_key_parameters(key_type, hashes.SHA256()) # type: ignore[arg-type] diff --git a/ca/django_ca/tests/test_views_ocsp.py b/ca/django_ca/tests/test_views_ocsp.py index dcf347574..61b8572a3 100644 --- a/ca/django_ca/tests/test_views_ocsp.py +++ b/ca/django_ca/tests/test_views_ocsp.py @@ -42,13 +42,7 @@ from django_ca.key_backends.storages import StoragesUsePrivateKeyOptions from django_ca.modelfields import LazyCertificate from django_ca.models import Certificate, CertificateAuthority -from django_ca.tests.base.constants import ( - CERT_DATA, - CRYPTOGRAPHY_VERSION, - FIXTURES_DATA, - FIXTURES_DIR, - TIMESTAMPS, -) +from django_ca.tests.base.constants import CERT_DATA, FIXTURES_DATA, FIXTURES_DIR, TIMESTAMPS from django_ca.tests.base.mixins import TestCaseMixin from django_ca.tests.base.typehints import HttpResponse from django_ca.tests.base.utils import override_tmpcadir @@ -190,12 +184,14 @@ def assertOCSPSignature( # pylint: disable=invalid-name if isinstance(public_key, rsa.RSAPublicKey): hash_algorithm = typing.cast(hashes.HashAlgorithm, hash_algorithm) # to make mypy happy - self.assertIsNone( + assert ( public_key.verify(response.signature, tbs_response, padding.PKCS1v15(), hash_algorithm) + is None ) + elif isinstance(public_key, ec.EllipticCurvePublicKey): hash_algorithm = typing.cast(hashes.HashAlgorithm, hash_algorithm) # to make mypy happy - self.assertIsNone(public_key.verify(response.signature, tbs_response, ec.ECDSA(hash_algorithm))) + assert public_key.verify(response.signature, tbs_response, ec.ECDSA(hash_algorithm)) is None elif isinstance(public_key, dsa.DSAPublicKey): hash_algorithm = typing.cast(hashes.HashAlgorithm, hash_algorithm) # to make mypy happy public_key.verify(response.signature, tbs_response, hash_algorithm) @@ -213,22 +209,13 @@ def assertCertificateStatus( # pylint: disable=invalid-name ) -> None: """Check information related to the certificate status.""" if certificate.revoked is False: - self.assertEqual(response.certificate_status, ocsp.OCSPCertStatus.GOOD) - if CRYPTOGRAPHY_VERSION < (43,): # pragma: only cg<=42 - self.assertIsNone(response.revocation_time) - else: - self.assertIsNone(response.revocation_time_utc) - self.assertIsNone(response.revocation_reason) + assert response.certificate_status == ocsp.OCSPCertStatus.GOOD + assert response.revocation_time_utc is None + assert response.revocation_reason is None else: - self.assertEqual(response.certificate_status, ocsp.OCSPCertStatus.REVOKED) - self.assertEqual(response.revocation_reason, certificate.get_revocation_reason()) - if CRYPTOGRAPHY_VERSION < (43,): # pragma: only cg<=42 - self.assertEqual( - response.revocation_time.replace(tzinfo=timezone.utc), # type: ignore[union-attr] - certificate.get_revocation_time(), - ) - else: - self.assertEqual(response.revocation_time_utc, certificate.get_revocation_time()) + assert response.certificate_status == ocsp.OCSPCertStatus.REVOKED + assert response.revocation_reason == certificate.get_revocation_reason() + assert response.revocation_time_utc == certificate.get_revocation_time() def assertOCSPSingleResponse( # pylint: disable=invalid-name self, @@ -241,8 +228,8 @@ def assertOCSPSingleResponse( # pylint: disable=invalid-name Note that `hash_algorithm` cannot be ``None``, as it must match the algorithm of the OCSP request. """ self.assertCertificateStatus(certificate, response) - self.assertEqual(response.serial_number, certificate.pub.loaded.serial_number) - self.assertIsInstance(response.hash_algorithm, hash_algorithm) + assert response.serial_number == certificate.pub.loaded.serial_number + assert isinstance(response.hash_algorithm, hash_algorithm) def assertOCSPResponse( # pylint: disable=invalid-name self, @@ -260,47 +247,43 @@ def assertOCSPResponse( # pylint: disable=invalid-name if responder_certificate is None: responder_certificate = self.certs["profile-ocsp"] - self.assertEqual(http_response["Content-Type"], "application/ocsp-response") + assert http_response["Content-Type"] == "application/ocsp-response" response = ocsp.load_der_ocsp_response(http_response.content) - self.assertEqual(response.response_status, response_status) + assert response.response_status == response_status if signature_hash_algorithm is None: - self.assertIsNone(response.signature_hash_algorithm) + assert response.signature_hash_algorithm is None else: - self.assertIsInstance(response.signature_hash_algorithm, signature_hash_algorithm) - self.assertEqual(response.signature_algorithm_oid, signature_algorithm_oid) - self.assertEqual(response.certificates, [responder_certificate.pub.loaded]) # responder certificate! - self.assertIsNone(response.responder_name) - self.assertIsInstance(response.responder_key_hash, bytes) # TODO: Validate responder id + assert isinstance(response.signature_hash_algorithm, signature_hash_algorithm) + assert response.signature_algorithm_oid == signature_algorithm_oid + assert response.certificates == [responder_certificate.pub.loaded] # responder certificate! + assert response.responder_name is None + assert isinstance(response.responder_key_hash, bytes) # TODO: Validate responder id # TODO: validate issuer_key_hash, issuer_name_hash # Check TIMESTAMPS # self.assertEqual(response.produced_at, datetime.now()) - if CRYPTOGRAPHY_VERSION < (43,): # pragma: only cg<=42 - self.assertEqual(response.this_update, datetime.now()) - self.assertEqual(response.next_update, datetime.now() + timedelta(seconds=expires)) - else: - now = datetime.now(tz=timezone.utc) - self.assertEqual(response.this_update_utc, now) - self.assertEqual(response.next_update_utc, now + timedelta(seconds=expires)) + now = datetime.now(tz=timezone.utc) + assert response.this_update_utc == now + assert response.next_update_utc == now + timedelta(seconds=expires) # Check nonce if passed if nonce is None: - self.assertEqual(len(response.extensions), 0) + assert len(response.extensions) == 0 else: nonce_extension = response.extensions.get_extension_for_oid(OCSPExtensionOID.NONCE) - self.assertIs(nonce_extension.critical, False) - self.assertEqual(nonce_extension.value.nonce, nonce) # type: ignore[attr-defined] + assert nonce_extension.critical is False + assert nonce_extension.value.nonce == nonce # type: ignore[attr-defined] - self.assertEqual(response.serial_number, requested_certificate.pub.loaded.serial_number) + assert response.serial_number == requested_certificate.pub.loaded.serial_number # Check the certificate status self.assertCertificateStatus(requested_certificate, response) # Assert single response single_responses = list(response.responses) # otherwise it has no len()/index - self.assertEqual(len(single_responses), 1) + assert len(single_responses) == 1 self.assertOCSPSingleResponse( requested_certificate, single_responses[0], single_response_hash_algorithm ) @@ -345,7 +328,7 @@ def ocsp_get( }, ) response = self.client.get(url) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK return response @@ -363,7 +346,7 @@ def test_get(self) -> None: """Basic GET test.""" data = base64.b64encode(req1).decode("utf-8") response = self.client.get(reverse("get", kwargs={"data": data})) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK self.assertOCSPResponse( response, requested_certificate=self.cert, @@ -375,9 +358,9 @@ def test_get(self) -> None: def test_bad_query(self) -> None: """Test sending a bad query.""" response = self.client.get(reverse("get", kwargs={"data": "XXX"})) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK ocsp_response = ocsp.load_der_ocsp_response(response.content) - self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.MALFORMED_REQUEST) + assert ocsp_response.response_status == ocsp.OCSPResponseStatus.MALFORMED_REQUEST def test_raises_exception(self) -> None: """Generic test if the handling function throws any uncaught exception.""" @@ -389,26 +372,26 @@ def test_raises_exception(self) -> None: with mock.patch(view_path, side_effect=ex), self.assertLogs() as logcm: response = self.client.get(reverse("get", kwargs={"data": data})) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK ocsp_response = ocsp.load_der_ocsp_response(response.content) - self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR) - self.assertEqual(len(logcm.output), 1) - self.assertIn(exception_str, logcm.output[0]) + assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR + assert len(logcm.output) == 1 + assert exception_str in logcm.output[0] # also do a post request with mock.patch(view_path, side_effect=ex), self.assertLogs() as logcm: response = self.client.post(reverse("post"), req1, content_type="application/ocsp-request") - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK ocsp_response = ocsp.load_der_ocsp_response(response.content) - self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR) - self.assertEqual(len(logcm.output), 1) - self.assertIn(exception_str, logcm.output[0]) + assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR + assert len(logcm.output) == 1 + assert exception_str in logcm.output[0] @override_tmpcadir() def test_post(self) -> None: """Test the post request.""" response = self.client.post(reverse("post"), req1, content_type="application/ocsp-request") - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK self.assertOCSPResponse( response, requested_certificate=self.cert, @@ -423,7 +406,7 @@ def test_post(self) -> None: content_type="application/ocsp-request", single_response_hash_algorithm=hashes.SHA1, ) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK self.assertOCSPResponse( response, requested_certificate=self.cert, @@ -433,7 +416,7 @@ def test_post(self) -> None: ) response = self.client.post(reverse("post-full-pem"), req1, content_type="application/ocsp-request") - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK self.assertOCSPResponse( response, requested_certificate=self.cert, @@ -448,7 +431,7 @@ def test_loaded_cryptography_cert(self) -> None: response = self.client.post( reverse("post-loaded-cryptography"), req1, content_type="application/ocsp-request" ) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK self.assertOCSPResponse( response, requested_certificate=self.cert, @@ -463,7 +446,7 @@ def test_revoked(self) -> None: self.cert.revoke() response = self.client.post(reverse("post"), req1, content_type="application/ocsp-request") - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK self.assertOCSPResponse( response, requested_certificate=self.cert, @@ -474,7 +457,7 @@ def test_revoked(self) -> None: self.cert.revoke(ReasonFlags.affiliation_changed) response = self.client.post(reverse("post"), req1, content_type="application/ocsp-request") - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK self.assertOCSPResponse( response, requested_certificate=self.cert, @@ -494,7 +477,7 @@ def test_ca_ocsp(self) -> None: data = base64.b64encode(req1).decode("utf-8") response = self.client.get(reverse("get-ca", kwargs={"data": data})) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK self.assertOCSPResponse( response, requested_certificate=ca, @@ -508,31 +491,22 @@ def test_bad_ca(self) -> None: data = base64.b64encode(req1).decode("utf-8") with self.assertLogs() as logcm: response = self.client.get(reverse("unknown", kwargs={"data": data})) - self.assertEqual( - logcm.output, - [ - "ERROR:django_ca.views:unknown: Certificate Authority could not be found.", - ], - ) + assert logcm.output == ["ERROR:django_ca.views:unknown: Certificate Authority could not be found."] - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK ocsp_response = ocsp.load_der_ocsp_response(response.content) - self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR) + assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR def test_unknown(self) -> None: """Test fetching data for an unknown certificate.""" data = base64.b64encode(unknown_req).decode("utf-8") with self.assertLogs() as logcm: response = self.client.get(reverse("get", kwargs={"data": data})) - self.assertEqual( - logcm.output, - [ - "WARNING:django_ca.views:7B: OCSP request for unknown cert received.", - ], - ) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert logcm.output == ["WARNING:django_ca.views:7B: OCSP request for unknown cert received."] + + assert response.status_code == HTTPStatus.OK ocsp_response = ocsp.load_der_ocsp_response(response.content) - self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR) + assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR @override_tmpcadir() def test_unknown_ca(self) -> None: @@ -541,12 +515,11 @@ def test_unknown_ca(self) -> None: with self.assertLogs() as logcm: response = self.client.get(reverse("get-ca", kwargs={"data": data})) serial = self.certs["child-cert"].serial - self.assertEqual( - logcm.output, [f"WARNING:django_ca.views:{serial}: OCSP request for unknown CA received."] - ) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert logcm.output, [f"WARNING:django_ca.views:{serial}: OCSP request for unknown CA received."] + + assert response.status_code == HTTPStatus.OK ocsp_response = ocsp.load_der_ocsp_response(response.content) - self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR) + assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR @override_tmpcadir() def test_bad_private_key_type(self) -> None: @@ -563,14 +536,11 @@ def test_bad_private_key_type(self) -> None: ): response = self.client.get(reverse("get", kwargs={"data": data})) ocsp_response = ocsp.load_der_ocsp_response(response.content) - self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR) - self.assertEqual( - logcm.output, - [ - "ERROR:django_ca.views:: Unsupported private key type.", - "ERROR:django_ca.views:Could not read responder key/cert.", - ], - ) + assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR + assert logcm.output == [ + "ERROR:django_ca.views:: Unsupported private key type.", + "ERROR:django_ca.views:Could not read responder key/cert.", + ] def test_bad_responder_cert(self) -> None: """Test the error when the private key cannot be read. @@ -581,32 +551,32 @@ def test_bad_responder_cert(self) -> None: with self.assertLogs() as logcm: response = self.client.get(reverse("get", kwargs={"data": data})) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK ocsp_response = ocsp.load_der_ocsp_response(response.content) - self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR) - self.assertEqual(logcm.output, ["ERROR:django_ca.views:Could not read responder key/cert."]) + assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR + assert logcm.output == ["ERROR:django_ca.views:Could not read responder key/cert."] def test_bad_request(self) -> None: """Try making a bad request.""" data = base64.b64encode(b"foobar").decode("utf-8") with self.assertLogs() as logcm: response = self.client.get(reverse("get", kwargs={"data": data})) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK ocsp_response = ocsp.load_der_ocsp_response(response.content) - self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.MALFORMED_REQUEST) - self.assertEqual(len(logcm.output), 1) - self.assertIn("ValueError: error parsing asn1 value", logcm.output[0], logcm.output[0]) + assert ocsp_response.response_status == ocsp.OCSPResponseStatus.MALFORMED_REQUEST + assert len(logcm.output) == 1 + assert "ValueError: error parsing asn1 value" in logcm.output[0], logcm.output[0] def test_multiple(self) -> None: """Try making multiple OCSP requests (not currently supported).""" data = base64.b64encode(multiple_req).decode("utf-8") with self.assertLogs() as logcm: response = self.client.get(reverse("get", kwargs={"data": data})) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK ocsp_response = ocsp.load_der_ocsp_response(response.content) - self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.MALFORMED_REQUEST) - self.assertEqual(len(logcm.output), 1) - self.assertIn("OCSP request contains more than one request", logcm.output[0]) + assert ocsp_response.response_status == ocsp.OCSPResponseStatus.MALFORMED_REQUEST + assert len(logcm.output) == 1 + assert "OCSP request contains more than one request" in logcm.output[0] @override_tmpcadir() def test_bad_ca_cert(self) -> None: @@ -618,11 +588,11 @@ def test_bad_ca_cert(self) -> None: data = base64.b64encode(req1).decode("utf-8") with self.assertLogs() as logcm: response = self.client.get(reverse("get", kwargs={"data": data})) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK ocsp_response = ocsp.load_der_ocsp_response(response.content) - self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR) - self.assertEqual(len(logcm.output), 1) - self.assertIn("ValueError: ", logcm.output[0]) + assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR + assert len(logcm.output) == 1 + assert "ValueError: " in logcm.output[0] @override_tmpcadir() def test_bad_responder_key(self) -> None: @@ -631,10 +601,10 @@ def test_bad_responder_key(self) -> None: with self.assertLogs() as logcm: response = self.client.get(reverse("false-key", kwargs={"data": data})) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert response.status_code == HTTPStatus.OK ocsp_response = ocsp.load_der_ocsp_response(response.content) - self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR) - self.assertEqual(logcm.output, ["ERROR:django_ca.views:Could not read responder key/cert."]) + assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR + assert logcm.output == ["ERROR:django_ca.views:Could not read responder key/cert."] @override_tmpcadir() def test_bad_responder_pem(self) -> None: @@ -644,16 +614,16 @@ def test_bad_responder_pem(self) -> None: with self.assertLogs() as logcm: response = self.client.get(reverse("false-pem-serial", kwargs={"data": data})) - self.assertEqual(logcm.output, [msg]) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert logcm.output == [msg] + assert response.status_code == HTTPStatus.OK ocsp_response = ocsp.load_der_ocsp_response(response.content) - self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR) + assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR with self.assertLogs() as logcm: response = self.client.get(reverse("false-pem-full", kwargs={"data": data})) - self.assertEqual(logcm.output, [msg]) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert logcm.output == [msg] + assert response.status_code == HTTPStatus.OK ocsp_response = ocsp.load_der_ocsp_response(response.content) - self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR) + assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR @override_settings(ROOT_URLCONF=__name__) @@ -810,10 +780,10 @@ def test_invalid_responder_key(self) -> None: with self.assertLogs() as logcm: response = self.ocsp_get(self.cert, hash_algorithm=hashes.SHA512) - self.assertEqual(logcm.output, ["ERROR:django_ca.views:Could not read responder key/cert."]) - self.assertEqual(response.status_code, HTTPStatus.OK) + assert logcm.output == ["ERROR:django_ca.views:Could not read responder key/cert."] + assert response.status_code == HTTPStatus.OK ocsp_response = ocsp.load_der_ocsp_response(response.content) - self.assertEqual(ocsp_response.response_status, ocsp.OCSPResponseStatus.INTERNAL_ERROR) + assert ocsp_response.response_status == ocsp.OCSPResponseStatus.INTERNAL_ERROR @override_tmpcadir() def test_ed25519_certificate_authority(self) -> None: @@ -837,8 +807,8 @@ def test_cert_method_not_allowed(self) -> None: """Try HTTP methods that are not allowed.""" url = reverse("django_ca:ocsp-cert-post", kwargs={"serial": "00AA"}) response = self.client.get(url) - self.assertEqual(response.status_code, 405) + assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED # 405 url = reverse("django_ca:ocsp-cert-get", kwargs={"serial": "00AA", "data": "irrelevant"}) response = self.client.post(url, req1, content_type="application/ocsp-request") - self.assertEqual(response.status_code, 405) + assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED # 405 diff --git a/ca/django_ca/tests/utils/test_get_crl_cache_key.py b/ca/django_ca/tests/utils/test_get_crl_cache_key.py index 22b293c1e..e35e361c0 100644 --- a/ca/django_ca/tests/utils/test_get_crl_cache_key.py +++ b/ca/django_ca/tests/utils/test_get_crl_cache_key.py @@ -33,7 +33,7 @@ @pytest.mark.parametrize( - "kwargs,expected", + ("kwargs", "expected"), ( (DEFAULT_KWARGS, "crl_123_DER_False_False_False_None"), ({**DEFAULT_KWARGS, "encoding": Encoding.PEM}, "crl_123_PEM_False_False_False_None"), diff --git a/ca/django_ca/tests/utils/test_othername.py b/ca/django_ca/tests/utils/test_othername.py index b0d07db9a..967a2b89d 100644 --- a/ca/django_ca/tests/utils/test_othername.py +++ b/ca/django_ca/tests/utils/test_othername.py @@ -23,7 +23,7 @@ @pytest.mark.parametrize( - "value,expected,normalized", + ("value", "expected", "normalized"), ( ("UNIVERSALSTRING:ex", b"\x1c\x08\x00\x00\x00e\x00\x00\x00x", True), ("UNIV:ex", b"\x1c\x08\x00\x00\x00e\x00\x00\x00x", False), @@ -59,7 +59,8 @@ def test_parse_and_format_othername(value: str, expected: bytes, normalized: boo @pytest.mark.parametrize("typ", ("UTF8", "UTF8String")) @pytest.mark.parametrize( - "value,expected", (("example", b"\x0c\x07example"), ("example;wrong:val", b"\x0c\x11example;wrong:val")) + ("value", "expected"), + (("example", b"\x0c\x07example"), ("example;wrong:val", b"\x0c\x11example;wrong:val")), ) def test_othername_with_utf8(typ: str, value: str, expected: bytes) -> None: """Test UTF8 values.""" @@ -88,7 +89,7 @@ def test_othername_with_boolean_false(typ: str, value: str) -> None: @pytest.mark.parametrize("typ", ("INT", "INTEGER")) @pytest.mark.parametrize( - "raw_value,expected_bytes,formatted_value", + ("raw_value", "expected_bytes", "formatted_value"), ( ("0", b"\x02\x01\x00", "0"), ("1", b"\x02\x01\x01", "1"), @@ -104,7 +105,7 @@ def test_othername_integer(typ: str, raw_value: str, expected_bytes: bytes, form @pytest.mark.parametrize( - "value,expected", + ("value", "expected"), ( ( "2.4.5.3;BOOL:WRONG", diff --git a/ca/django_ca/tests/utils/test_parse_general_name.py b/ca/django_ca/tests/utils/test_parse_general_name.py index 5e9651ccd..fd7c61e96 100644 --- a/ca/django_ca/tests/utils/test_parse_general_name.py +++ b/ca/django_ca/tests/utils/test_parse_general_name.py @@ -27,7 +27,7 @@ @pytest.mark.parametrize("prefix", ("", "ip:")) @pytest.mark.parametrize( - "value,expected", + ("value", "expected"), ( ("1.2.3.4", IPv4Address("1.2.3.4")), ("1.2.3.0/24", IPv4Network("1.2.3.0/24")), @@ -47,7 +47,7 @@ def test_ip( @pytest.mark.parametrize("prefix", ("", "DNS:")) @pytest.mark.parametrize( - "value,expected", + ("value", "expected"), ( ("example.com", dns("example.com")), (".example.com", dns(".example.com")), diff --git a/ca/django_ca/tests/utils/test_parse_name_rfc4514.py b/ca/django_ca/tests/utils/test_parse_name_rfc4514.py index 5445fc7e5..b4ccfccf8 100644 --- a/ca/django_ca/tests/utils/test_parse_name_rfc4514.py +++ b/ca/django_ca/tests/utils/test_parse_name_rfc4514.py @@ -24,7 +24,7 @@ @pytest.mark.parametrize( - "value,expected", + ("value", "expected"), ( ("CN=example.com", x509.Name([cn("example.com")])), (f"{NameOID.COMMON_NAME.dotted_string}=example.com", x509.Name([cn("example.com")])), @@ -37,7 +37,7 @@ def test_parse_name_rfc4514(value: str, expected: x509.Name) -> None: @pytest.mark.parametrize( - "value,expected", + ("value", "expected"), ( ( "C=FOO", @@ -57,7 +57,7 @@ def test_parse_name_rfc4514_with_error(value: str, expected: str) -> None: @pytest.mark.skipif(CRYPTOGRAPHY_VERSION < (43,), reason="cryptography check was added in version 43") @pytest.mark.parametrize( - "value,expected", + ("value", "expected"), ( ("CN=", r"^Attribute's length must be >= 1 and <= 64, but it was 0$"), (f"CN={'x' * 65}", r"^Attribute's length must be >= 1 and <= 64, but it was 65$"), diff --git a/ca/django_ca/tests/utils/test_parse_name_x509.py b/ca/django_ca/tests/utils/test_parse_name_x509.py index 03edf5a96..dd6353dc3 100644 --- a/ca/django_ca/tests/utils/test_parse_name_x509.py +++ b/ca/django_ca/tests/utils/test_parse_name_x509.py @@ -22,7 +22,7 @@ @pytest.mark.parametrize( - "value,expected", + ("value", "expected"), ( ("/CN=example.com", [(NameOID.COMMON_NAME, "example.com")]), # leading or trailing spaces are always ok: @@ -93,16 +93,6 @@ ("/O=/OU=", [(NameOID.ORGANIZATION_NAME, ""), (NameOID.ORGANIZATIONAL_UNIT_NAME, "")]), # no slash at start works: ("CN=example.com", [(NameOID.COMMON_NAME, "example.com")]), - # test multiple OUs - ( - "/C=AT/OU=foo/OU=bar/CN=example.com", - [ - (NameOID.COUNTRY_NAME, "AT"), - (NameOID.ORGANIZATIONAL_UNIT_NAME, "foo"), - (NameOID.ORGANIZATIONAL_UNIT_NAME, "bar"), - (NameOID.COMMON_NAME, "example.com"), - ], - ), ( "/OU=foo/OU=bar", [(NameOID.ORGANIZATIONAL_UNIT_NAME, "foo"), (NameOID.ORGANIZATIONAL_UNIT_NAME, "bar")], diff --git a/ca/django_ca/tests/utils/test_split_str.py b/ca/django_ca/tests/utils/test_split_str.py index b60e7a89a..ed4dcecc0 100644 --- a/ca/django_ca/tests/utils/test_split_str.py +++ b/ca/django_ca/tests/utils/test_split_str.py @@ -19,7 +19,7 @@ @pytest.mark.parametrize( - "value,seperator,expected", + ("value", "seperator", "expected"), ( ("foo", "/", ["foo"]), ("foo bar", "/", ["foo bar"]), @@ -37,10 +37,7 @@ ("/foo/bar", "/", ["foo", "bar"]), ("/foo/bar/", "/", ["foo", "bar"]), ("/C=AT/CN=example.com/", "/", ["C=AT", "CN=example.com"]), - (r"foo/bar", "/", ["foo", "bar"]), # test quoting - (r"foo'/'bar", "/", ["foo/bar"]), - (r'foo"/"bar', "/", ["foo/bar"]), (r'fo"o/b"ar', "/", ["foo/bar"]), (r'"foo\"bar"', "/", ['foo"bar']), # escape quotes inside quotes # Test the escape character @@ -60,12 +57,6 @@ (r'"foo\\xbar"', "/", [r"foo\xbar"]), # ... but in single quote it's not an escape -> double backslash in result (r"'foo\\xbar'", "/", [r"foo\\xbar"]), - # No quotes, single backslash preceeding "/" --> "/" is escaped - (r"foo\/bar", "/", ["foo/bar"]), - # No quotes, but *double* backslash preceeding "/" --> backslash itself is escaped, slash is delimiter - (r"foo\\/bar", "/", ["foo\\", "bar"]), - # With quotes/double quotes, no backslashes -> slash is inside quoted string -> it's not a delimiter - ('"foo/bar"/bla', "/", ["foo/bar", "bla"]), ("'foo/bar'/bla", "/", ["foo/bar", "bla"]), # With quotes/double quotes, with one backslash (r'"foo\/bar"/bla', "/", [r"foo\/bar", "bla"]), @@ -98,7 +89,7 @@ def test_basic(value: str, seperator: str, expected: list[str]) -> None: @pytest.mark.parametrize( - "value,match", + ("value", "match"), ( (r"'foo\'bar'", "^No closing quotation$"), (r"foo'bar", "^No closing quotation$"), diff --git a/ca/django_ca/tests/utils/test_validate_hostname.py b/ca/django_ca/tests/utils/test_validate_hostname.py index 4bd2efdd2..06eeeed19 100644 --- a/ca/django_ca/tests/utils/test_validate_hostname.py +++ b/ca/django_ca/tests/utils/test_validate_hostname.py @@ -55,7 +55,7 @@ def test_no_allow_port(value: str) -> None: @pytest.mark.parametrize( - "value,error", + ("value", "error"), ( ("localhost:no-int", "^no-int: Port must be an integer$"), ("localhost:0", "^0: Port must be between 1 and 65535$"), diff --git a/ca/django_ca/tests/views/test_certificate_revocation_list_view.py b/ca/django_ca/tests/views/test_certificate_revocation_list_view.py index e33a92db5..497e3dfd7 100644 --- a/ca/django_ca/tests/views/test_certificate_revocation_list_view.py +++ b/ca/django_ca/tests/views/test_certificate_revocation_list_view.py @@ -92,7 +92,7 @@ @pytest.fixture -def default_url(root: CertificateAuthority) -> Iterator[str]: +def default_url(root: CertificateAuthority) -> str: """Fixture for the default URL for the root CA.""" return reverse("default", kwargs={"serial": root.serial}) diff --git a/ca/django_ca/views.py b/ca/django_ca/views.py index c8a00b4ea..90cc09fcb 100644 --- a/ca/django_ca/views.py +++ b/ca/django_ca/views.py @@ -147,9 +147,7 @@ def get_key_backend_options(self, ca: CertificateAuthority) -> BaseModel: def fetch_crl(self, ca: CertificateAuthority, encoding: CertificateRevocationListEncodings) -> bytes: """Actually fetch the CRL (nested function so that we can easily catch any exception).""" - print(self.scope) if self.scope is not _NOT_SET: - print(2) warnings.warn( "The scope parameter is deprecated and will be removed in django-ca 2.3.0, use " "`only_contains_{ca,user,attribute}_cert` instead.",