Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(azure): get locations with self session #5751

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 23 additions & 21 deletions prowler/providers/azure/azure_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@
)

# TODO: should we keep this here or within the identity?
self._locations = self.get_locations(self.session)
self._locations = self.get_locations()

# Audit Config
if config_content:
Expand Down Expand Up @@ -942,33 +942,35 @@

return identity

def get_locations(self, credentials) -> dict[str, list[str]]:
def get_locations(self) -> dict[str, list[str]]:
"""
Retrieves the locations available for each subscription using the provided credentials.

Args:
credentials: The credentials object used to authenticate the request.

Returns:
A dictionary containing the locations available for each subscription. The dictionary
jfagoagas marked this conversation as resolved.
Show resolved Hide resolved
has subscription display names as keys and lists of location names as values.

Examples:
>>> provider = AzureProvider(...)
>>> provider.get_locations()
{
'Subscription 1': ['eastus', 'eastus2', 'westus', 'westus2'],
'Subscription 2': ['eastus', 'eastus2', 'westus', 'westus2']
}
"""
locations = None
if credentials:
locations = {}
token = credentials.get_token("https://management.azure.com/.default").token
for display_name, subscription_id in self._identity.subscriptions.items():
locations.update({display_name: []})
url = f"https://management.azure.com/subscriptions/{subscription_id}/locations?api-version=2022-12-01"
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
response = requests.get(url, headers=headers)
if response.status_code == 200:
data = response.json()
for location in data["value"]:
locations[display_name].append(location["name"])
credentials = self.session
subscription_client = SubscriptionClient(credentials)
locations = {}

Check warning on line 963 in prowler/providers/azure/azure_provider.py

View check run for this annotation

Codecov / codecov/patch

prowler/providers/azure/azure_provider.py#L961-L963

Added lines #L961 - L963 were not covered by tests

for display_name, subscription_id in self._identity.subscriptions.items():
locations[display_name] = []

Check warning on line 966 in prowler/providers/azure/azure_provider.py

View check run for this annotation

Codecov / codecov/patch

prowler/providers/azure/azure_provider.py#L965-L966

Added lines #L965 - L966 were not covered by tests

# List locations for each subscription
for location in subscription_client.subscriptions.list_locations(

Check warning on line 969 in prowler/providers/azure/azure_provider.py

View check run for this annotation

Codecov / codecov/patch

prowler/providers/azure/azure_provider.py#L969

Added line #L969 was not covered by tests
subscription_id
):
locations[display_name].append(location.name)

Check warning on line 972 in prowler/providers/azure/azure_provider.py

View check run for this annotation

Codecov / codecov/patch

prowler/providers/azure/azure_provider.py#L972

Added line #L972 was not covered by tests

return locations

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def execute(self) -> list[Check_Report_Azure]:
report = Check_Report_Azure(self.metadata())
report.subscription = subscription
report.resource_name = "Network Watcher"
report.location = "Global"
report.location = "global"
report.resource_id = f"/subscriptions/{network_client.subscriptions[subscription]}/resourceGroups/NetworkWatcherRG/providers/Microsoft.Network/networkWatchers/NetworkWatcher_*"

missing_locations = set(network_client.locations[subscription]) - set(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_network_invalid_network_watchers(self):
assert result[0].subscription == AZURE_SUBSCRIPTION_NAME
assert result[0].resource_name == network_watcher_name
assert result[0].resource_id == network_watcher_id
assert result[0].location == "Global"
assert result[0].location == "global"

def test_network_valid_network_watchers(self):
network_client = mock.MagicMock
Expand Down Expand Up @@ -124,4 +124,4 @@ def test_network_valid_network_watchers(self):
assert result[0].subscription == AZURE_SUBSCRIPTION_NAME
assert result[0].resource_name == network_watcher_name
assert result[0].resource_id == network_watcher_id
assert result[0].location == "Global"
assert result[0].location == "global"