From 4ef0935ee1d6a741ab5d13400fe8eacad5dbdf81 Mon Sep 17 00:00:00 2001 From: pedrooot Date: Wed, 13 Nov 2024 14:40:15 +0100 Subject: [PATCH 1/3] refactor(azure): get locations with self session --- prowler/providers/azure/azure_provider.py | 38 ++++++++++------------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/prowler/providers/azure/azure_provider.py b/prowler/providers/azure/azure_provider.py index de7d9bc42d8..86076fb293f 100644 --- a/prowler/providers/azure/azure_provider.py +++ b/prowler/providers/azure/azure_provider.py @@ -198,7 +198,7 @@ def __init__( ) # TODO: should we keep this here or within the identity? - self._locations = self.get_locations(self.session) + self._locations = self.get_locations() # Audit Config if config_content: @@ -942,33 +942,29 @@ async def get_azure_identity(): return identity - def get_locations(self, credentials) -> dict[str, list[str]]: + def get_locations(self) -> dict[str, list[str]]: """ Retrieves the locations available for each subscription using the provided credentials. - Args: - credentials: The credentials object used to authenticate the request. - Returns: A dictionary containing the locations available for each subscription. The dictionary has subscription display names as keys and lists of location names as values. """ - locations = None - if credentials: - locations = {} - token = credentials.get_token("https://management.azure.com/.default").token - for display_name, subscription_id in self._identity.subscriptions.items(): - locations.update({display_name: []}) - url = f"https://management.azure.com/subscriptions/{subscription_id}/locations?api-version=2022-12-01" - headers = { - "Authorization": f"Bearer {token}", - "Content-Type": "application/json", - } - response = requests.get(url, headers=headers) - if response.status_code == 200: - data = response.json() - for location in data["value"]: - locations[display_name].append(location["name"]) + credentials = self.session + locations = {} + token = credentials.get_token("https://management.azure.com/.default").token + for display_name, subscription_id in self._identity.subscriptions.items(): + locations.update({display_name: []}) + url = f"https://management.azure.com/subscriptions/{subscription_id}/locations?api-version=2022-12-01" + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + } + response = requests.get(url, headers=headers) + if response.status_code == 200: + data = response.json() + for location in data["value"]: + locations[display_name].append(location["name"]) return locations @staticmethod From c83a4f31d183c8b3de3334343410e9084bc504ca Mon Sep 17 00:00:00 2001 From: pedrooot Date: Wed, 13 Nov 2024 17:08:38 +0100 Subject: [PATCH 2/3] feat(azure): resolve comments --- prowler/providers/azure/azure_provider.py | 30 ++++++++++++++--------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/prowler/providers/azure/azure_provider.py b/prowler/providers/azure/azure_provider.py index 86076fb293f..375a2b0c78d 100644 --- a/prowler/providers/azure/azure_provider.py +++ b/prowler/providers/azure/azure_provider.py @@ -949,22 +949,28 @@ def get_locations(self) -> dict[str, list[str]]: Returns: A dictionary containing the locations available for each subscription. The dictionary has subscription display names as keys and lists of location names as values. + + Examples: + >>> provider = AzureProvider(...) + >>> provider.get_locations() + { + 'Subscription 1': ['eastus', 'eastus2', 'westus', 'westus2'], + 'Subscription 2': ['eastus', 'eastus2', 'westus', 'westus2'] + } """ credentials = self.session + subscription_client = SubscriptionClient(credentials) locations = {} - token = credentials.get_token("https://management.azure.com/.default").token + for display_name, subscription_id in self._identity.subscriptions.items(): - locations.update({display_name: []}) - url = f"https://management.azure.com/subscriptions/{subscription_id}/locations?api-version=2022-12-01" - headers = { - "Authorization": f"Bearer {token}", - "Content-Type": "application/json", - } - response = requests.get(url, headers=headers) - if response.status_code == 200: - data = response.json() - for location in data["value"]: - locations[display_name].append(location["name"]) + locations[display_name] = [] + + # List locations for each subscription + for location in subscription_client.subscriptions.list_locations( + subscription_id + ): + locations[display_name].append(location.name) + return locations @staticmethod From bfe4d6d5a47641f21d78035560b6303c106c7b6d Mon Sep 17 00:00:00 2001 From: pedrooot Date: Wed, 13 Nov 2024 18:11:02 +0100 Subject: [PATCH 3/3] feat(azure): remove typo --- .../network_watcher_enabled/network_watcher_enabled.py | 2 +- .../network_watcher_enabled/network_watcher_enabled_test.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/prowler/providers/azure/services/network/network_watcher_enabled/network_watcher_enabled.py b/prowler/providers/azure/services/network/network_watcher_enabled/network_watcher_enabled.py index 02694561f06..d16798c8f31 100644 --- a/prowler/providers/azure/services/network/network_watcher_enabled/network_watcher_enabled.py +++ b/prowler/providers/azure/services/network/network_watcher_enabled/network_watcher_enabled.py @@ -9,7 +9,7 @@ def execute(self) -> list[Check_Report_Azure]: report = Check_Report_Azure(self.metadata()) report.subscription = subscription report.resource_name = "Network Watcher" - report.location = "Global" + report.location = "global" report.resource_id = f"/subscriptions/{network_client.subscriptions[subscription]}/resourceGroups/NetworkWatcherRG/providers/Microsoft.Network/networkWatchers/NetworkWatcher_*" missing_locations = set(network_client.locations[subscription]) - set( diff --git a/tests/providers/azure/services/network/network_watcher_enabled/network_watcher_enabled_test.py b/tests/providers/azure/services/network/network_watcher_enabled/network_watcher_enabled_test.py index bfb066ff7a9..24a7f1ef4de 100644 --- a/tests/providers/azure/services/network/network_watcher_enabled/network_watcher_enabled_test.py +++ b/tests/providers/azure/services/network/network_watcher_enabled/network_watcher_enabled_test.py @@ -78,7 +78,7 @@ def test_network_invalid_network_watchers(self): assert result[0].subscription == AZURE_SUBSCRIPTION_NAME assert result[0].resource_name == network_watcher_name assert result[0].resource_id == network_watcher_id - assert result[0].location == "Global" + assert result[0].location == "global" def test_network_valid_network_watchers(self): network_client = mock.MagicMock @@ -124,4 +124,4 @@ def test_network_valid_network_watchers(self): assert result[0].subscription == AZURE_SUBSCRIPTION_NAME assert result[0].resource_name == network_watcher_name assert result[0].resource_id == network_watcher_id - assert result[0].location == "Global" + assert result[0].location == "global"