Skip to content

Commit

Permalink
fix(cloudtrail): use dictionary instead of list (#3579)
Browse files Browse the repository at this point in the history
  • Loading branch information
MrCloudSec authored and jfagoagas committed Mar 21, 2024
1 parent c32f7ba commit 85d6d02
Show file tree
Hide file tree
Showing 16 changed files with 43 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def execute(self):
f"Lambda function {function.name} is not recorded by CloudTrail."
)
lambda_recorded_cloudtrail = False
for trail in cloudtrail_client.trails:
for trail in cloudtrail_client.trails.values():
for data_event in trail.data_events:
# classic event selectors
if not data_event.is_advanced:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class cloudtrail_bucket_requires_mfa_delete(Check):
def execute(self):
findings = []
for trail in cloudtrail_client.trails:
for trail in cloudtrail_client.trails.values():
if trail.is_logging:
trail_bucket_is_in_account = False
trail_bucket = trail.s3_bucket
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class cloudtrail_cloudwatch_logging_enabled(Check):
def execute(self):
findings = []
for trail in cloudtrail_client.trails:
for trail in cloudtrail_client.trails.values():
if trail.name:
report = Check_Report_AWS(self.metadata())
report.region = trail.region
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
class cloudtrail_insights_exist(Check):
def execute(self):
findings = []
for trail in cloudtrail_client.trails:
for trail in cloudtrail_client.trails.values():
if trail.is_logging:
report = Check_Report_AWS(self.metadata())
report.region = trail.region
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
class cloudtrail_kms_encryption_enabled(Check):
def execute(self):
findings = []
for trail in cloudtrail_client.trails:
for trail in cloudtrail_client.trails.values():
if trail.name:
report = Check_Report_AWS(self.metadata())
report.region = trail.region
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
class cloudtrail_log_file_validation_enabled(Check):
def execute(self):
findings = []
for trail in cloudtrail_client.trails:
for trail in cloudtrail_client.trails.values():
if trail.name:
report = Check_Report_AWS(self.metadata())
report.region = trail.region
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class cloudtrail_logs_s3_bucket_access_logging_enabled(Check):
def execute(self):
findings = []
for trail in cloudtrail_client.trails:
for trail in cloudtrail_client.trails.values():
if trail.name:
trail_bucket_is_in_account = False
trail_bucket = trail.s3_bucket
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class cloudtrail_logs_s3_bucket_is_not_publicly_accessible(Check):
def execute(self):
findings = []
for trail in cloudtrail_client.trails:
for trail in cloudtrail_client.trails.values():
if trail.name:
trail_bucket_is_in_account = False
trail_bucket = trail.s3_bucket
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ def execute(self):
for region in cloudtrail_client.regional_clients.keys():
report = Check_Report_AWS(self.metadata())
report.region = region
for trail in cloudtrail_client.trails:
if trail.region == region:
for trail in cloudtrail_client.trails.values():
if trail.region == region or trail.is_multiregion:
if trail.is_logging:
report.status = "PASS"
report.resource_id = trail.name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def execute(self):
report.resource_id = cloudtrail_client.audited_account
report.resource_arn = cloudtrail_client.trail_arn_template

for trail in cloudtrail_client.trails:
for trail in cloudtrail_client.trails.values():
if trail.is_logging:
if trail.is_multiregion:
for event in trail.data_events:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class cloudtrail_s3_dataevents_read_enabled(Check):
def execute(self):
findings = []
for trail in cloudtrail_client.trails:
for trail in cloudtrail_client.trails.values():
for data_event in trail.data_events:
# classic event selectors
if not data_event.is_advanced:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class cloudtrail_s3_dataevents_write_enabled(Check):
def execute(self):
findings = []
for trail in cloudtrail_client.trails:
for trail in cloudtrail_client.trails.values():
for data_event in trail.data_events:
# Classic event selectors
if not data_event.is_advanced:
Expand Down
44 changes: 20 additions & 24 deletions prowler/providers/aws/services/cloudtrail/cloudtrail_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, audit_info):
# Call AWSService's __init__
super().__init__(__class__.__name__, audit_info)
self.trail_arn_template = f"arn:{self.audited_partition}:cloudtrail:{self.region}:{self.audited_account}:trail"
self.trails = []
self.trails = {}
self.__threading_call__(self.__get_trails__)
self.__get_trail_status__()
self.__get_insight_selectors__()
Expand Down Expand Up @@ -45,27 +45,23 @@ def __get_trails__(self, regional_client):
kms_key_id = trail["KmsKeyId"]
if "CloudWatchLogsLogGroupArn" in trail:
log_group_arn = trail["CloudWatchLogsLogGroupArn"]
self.trails.append(
Trail(
name=trail["Name"],
is_multiregion=trail["IsMultiRegionTrail"],
home_region=trail["HomeRegion"],
arn=trail["TrailARN"],
region=regional_client.region,
is_logging=False,
log_file_validation_enabled=trail[
"LogFileValidationEnabled"
],
latest_cloudwatch_delivery_time=None,
s3_bucket=trail["S3BucketName"],
kms_key=kms_key_id,
log_group_arn=log_group_arn,
data_events=[],
has_insight_selectors=trail.get("HasInsightSelectors"),
)
self.trails[trail["TrailARN"]] = Trail(
name=trail["Name"],
is_multiregion=trail["IsMultiRegionTrail"],
home_region=trail["HomeRegion"],
arn=trail["TrailARN"],
region=regional_client.region,
is_logging=False,
log_file_validation_enabled=trail["LogFileValidationEnabled"],
latest_cloudwatch_delivery_time=None,
s3_bucket=trail["S3BucketName"],
kms_key=kms_key_id,
log_group_arn=log_group_arn,
data_events=[],
has_insight_selectors=trail.get("HasInsightSelectors"),
)
if trails_count == 0:
self.trails.append(
self.trails[self.__get_trail_arn_template__(regional_client.region)] = (
Trail(
region=regional_client.region,
)
Expand All @@ -79,7 +75,7 @@ def __get_trails__(self, regional_client):
def __get_trail_status__(self):
logger.info("Cloudtrail - Getting trail status")
try:
for trail in self.trails:
for trail in self.trails.values():
for region, client in self.regional_clients.items():
if trail.region == region and trail.name:
status = client.get_trail_status(Name=trail.arn)
Expand All @@ -97,7 +93,7 @@ def __get_trail_status__(self):
def __get_event_selectors__(self):
logger.info("Cloudtrail - Getting event selector")
try:
for trail in self.trails:
for trail in self.trails.values():
for region, client in self.regional_clients.items():
if trail.region == region and trail.name:
data_events = client.get_event_selectors(TrailName=trail.arn)
Expand Down Expand Up @@ -131,7 +127,7 @@ def __get_insight_selectors__(self):
logger.info("Cloudtrail - Getting trail insight selectors...")

try:
for trail in self.trails:
for trail in self.trails.values():
for region, client in self.regional_clients.items():
if trail.region == region and trail.name:
insight_selectors = None
Expand Down Expand Up @@ -180,7 +176,7 @@ def __get_insight_selectors__(self):
def __list_tags_for_resource__(self):
logger.info("CloudTrail - List Tags...")
try:
for trail in self.trails:
for trail in self.trails.values():
# Check if trails are in this account and region
if (
trail.region == trail.home_region
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def check_cloudwatch_log_metric_filter(
):
# 1. Iterate for CloudWatch Log Group in CloudTrail trails
log_groups = []
for trail in trails:
for trail in trails.values():
if trail.log_group_arn:
log_groups.append(trail.log_group_arn.split(":")[6])
# 2. Describe metric filters for previous log groups
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_trails_sending_logs_during_and_not_last_day(self):
cloudtrail_cloudwatch_logging_enabled,
)

for trail in service_client.trails:
for trail in service_client.trails.values():
if trail.name == trail_name_us:
trail.latest_cloudwatch_delivery_time = datetime.now().replace(
tzinfo=timezone.utc
Expand Down Expand Up @@ -174,7 +174,7 @@ def test_multi_region_and_single_region_logging_and_not(self):
cloudtrail_cloudwatch_logging_enabled,
)

for trail in service_client.trails:
for trail in service_client.trails.values():
if trail.name == trail_name_us:
trail.latest_cloudwatch_delivery_time = datetime.now().replace(
tzinfo=timezone.utc
Expand All @@ -190,8 +190,8 @@ def test_multi_region_and_single_region_logging_and_not(self):

check = cloudtrail_cloudwatch_logging_enabled()
result = check.execute()
# len of result should be 3 -> (1 multiregion entry per region + 1 entry because of single region trail)
assert len(result) == 3
# len of result should be 2 -> (1 per trail)
assert len(result) == 2
for report in result:
if report.resource_id == trail_name_us:
assert report.resource_id == trail_name_us
Expand Down Expand Up @@ -262,7 +262,7 @@ def test_trails_sending_and_not_sending_logs(self):
cloudtrail_cloudwatch_logging_enabled,
)

for trail in service_client.trails:
for trail in service_client.trails.values():
if trail.name == trail_name_us:
trail.latest_cloudwatch_delivery_time = datetime.now().replace(
tzinfo=timezone.utc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_describe_trails(self):
)
cloudtrail = Cloudtrail(audit_info)
assert len(cloudtrail.trails) == 2
for trail in cloudtrail.trails:
for trail in cloudtrail.trails.values():
if trail.name:
assert trail.name == trail_name_us or trail.name == trail_name_eu
assert not trail.is_multiregion
Expand Down Expand Up @@ -145,7 +145,7 @@ def test_status_trails(self):
)
cloudtrail = Cloudtrail(audit_info)
assert len(cloudtrail.trails) == len(audit_info.audited_regions)
for trail in cloudtrail.trails:
for trail in cloudtrail.trails.values():
if trail.name:
if trail.name == trail_name_us:
assert not trail.is_multiregion
Expand Down Expand Up @@ -189,7 +189,7 @@ def test_get_classic_event_selectors(self):
)
cloudtrail = Cloudtrail(audit_info)
assert len(cloudtrail.trails) == len(audit_info.audited_regions)
for trail in cloudtrail.trails:
for trail in cloudtrail.trails.values():
if trail.name:
if trail.name == trail_name_us:
assert not trail.is_multiregion
Expand Down Expand Up @@ -237,7 +237,7 @@ def test_get_advanced_event_selectors(self):
)
cloudtrail = Cloudtrail(audit_info)
assert len(cloudtrail.trails) == len(audit_info.audited_regions)
for trail in cloudtrail.trails:
for trail in cloudtrail.trails.values():
if trail.name:
if trail.name == trail_name_us:
assert not trail.is_multiregion
Expand Down

0 comments on commit 85d6d02

Please sign in to comment.