Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow 'Internet' for data providers IP #2247

Merged
merged 8 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions data_safe_haven/config/config_sections.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class ConfigSectionSRE(BaseModel, validate_assignment=True):
admin_email_address: EmailAddress
admin_ip_addresses: list[IpAddress] = []
databases: UniqueList[DatabaseSystem] = []
data_provider_ip_addresses: list[IpAddress] = []
data_provider_ip_addresses: list[IpAddress] | AzureServiceTag = []
remote_desktop: ConfigSubsectionRemoteDesktopOpts
research_user_ip_addresses: list[IpAddress] | AzureServiceTag = []
storage_quota_gb: ConfigSubsectionStorageQuotaGB
Expand All @@ -67,8 +67,6 @@ class ConfigSectionSRE(BaseModel, validate_assignment=True):

@field_validator(
"admin_ip_addresses",
"data_provider_ip_addresses",
# "research_user_ip_addresses",
mode="after",
)
@classmethod
Expand All @@ -81,6 +79,7 @@ def ensure_non_overlapping(cls, v: list[IpAddress]) -> list[IpAddress]:
return v

@field_validator(
"data_provider_ip_addresses",
"research_user_ip_addresses",
mode="after",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pulumi_azure_native import storage

from data_safe_haven.external import AzureIPv4Range
from data_safe_haven.types import AzureServiceTag


class WrappedNFSV3StorageAccount(storage.StorageAccount):
Expand All @@ -24,17 +25,35 @@ def __init__(
resource_name: str,
*,
account_name: Input[str],
allowed_ip_addresses: Input[Sequence[str]],
allowed_ip_addresses: Input[Sequence[str]] | None,
allowed_service_tag: AzureServiceTag | None,
location: Input[str],
resource_group_name: Input[str],
subnet_id: Input[str],
opts: ResourceOptions,
tags: Input[Mapping[str, Input[str]]],
):
if allowed_service_tag == AzureServiceTag.INTERNET:
default_action = storage.DefaultAction.ALLOW
ip_rules = []
else:
default_action = storage.DefaultAction.DENY
ip_rules = Output.from_input(allowed_ip_addresses).apply(
lambda ip_ranges: [
storage.IPRuleArgs(
action=storage.Action.ALLOW,
i_p_address_or_range=str(ip_address),
)
for ip_range in sorted(ip_ranges)
for ip_address in AzureIPv4Range.from_cidr(ip_range).all_ips()
]
)

self.resource_group_name_ = Output.from_input(resource_group_name)
super().__init__(
resource_name,
account_name=account_name,
allow_blob_public_access=False,
enable_https_traffic_only=True,
enable_nfs_v3=True,
encryption=self.encryption_args,
Expand All @@ -44,23 +63,15 @@ def __init__(
minimum_tls_version=storage.MinimumTlsVersion.TLS1_2,
network_rule_set=storage.NetworkRuleSetArgs(
bypass=storage.Bypass.AZURE_SERVICES,
default_action=storage.DefaultAction.DENY,
ip_rules=Output.from_input(allowed_ip_addresses).apply(
lambda ip_ranges: [
storage.IPRuleArgs(
action=storage.Action.ALLOW,
i_p_address_or_range=str(ip_address),
)
for ip_range in sorted(ip_ranges)
for ip_address in AzureIPv4Range.from_cidr(ip_range).all_ips()
]
),
default_action=default_action,
ip_rules=ip_rules,
virtual_network_rules=[
storage.VirtualNetworkRuleArgs(
virtual_network_resource_id=subnet_id,
)
],
),
public_network_access=storage.PublicNetworkAccess.ENABLED,
resource_group_name=resource_group_name,
sku=storage.SkuArgs(name=storage.SkuName.PREMIUM_ZRS),
opts=opts,
Expand Down
28 changes: 18 additions & 10 deletions data_safe_haven/infrastructure/programs/sre/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
SSLCertificateProps,
WrappedNFSV3StorageAccount,
)
from data_safe_haven.types import AzureDnsZoneNames
from data_safe_haven.types import AzureDnsZoneNames, AzureServiceTag


class SREDataProps:
Expand All @@ -46,7 +46,7 @@ def __init__(
admin_email_address: Input[str],
admin_group_id: Input[str],
admin_ip_addresses: Input[Sequence[str]],
data_provider_ip_addresses: Input[Sequence[str]],
data_provider_ip_addresses: Input[list[str]] | AzureServiceTag,
dns_private_zones: Input[dict[str, network.PrivateZone]],
dns_record: Input[network.RecordSet],
dns_server_admin_password: Input[pulumi_random.RandomPassword],
Expand All @@ -64,13 +64,7 @@ def __init__(
self.admin_email_address = admin_email_address
self.admin_group_id = admin_group_id
self.data_configuration_ip_addresses = admin_ip_addresses
self.data_private_sensitive_ip_addresses = Output.all(
admin_ip_addresses, data_provider_ip_addresses
).apply(
lambda address_lists: {
ip for address_list in address_lists for ip in address_list
}
)
self.data_provider_ip_addresses = data_provider_ip_addresses
self.dns_private_zones = dns_private_zones
self.dns_record = dns_record
self.password_dns_server_admin = dns_server_admin_password
Expand Down Expand Up @@ -112,6 +106,19 @@ def __init__(
child_opts = ResourceOptions.merge(opts, ResourceOptions(parent=self))
child_tags = {"component": "data"} | (tags if tags else {})

if isinstance(props.data_provider_ip_addresses, list):
data_private_sensitive_service_tag = None
data_private_sensitive_ip_addresses = Output.all(
props.data_configuration_ip_addresses, props.data_provider_ip_addresses
).apply(
lambda address_lists: {
ip for address_list in address_lists for ip in address_list
}
)
else:
data_private_sensitive_ip_addresses = None
data_private_sensitive_service_tag = props.data_provider_ip_addresses

# Define Key Vault reader
identity_key_vault_reader = managedidentity.UserAssignedIdentity(
f"{self._name}_id_key_vault_reader",
Expand Down Expand Up @@ -466,7 +473,8 @@ def __init__(
account_name=alphanumeric(
f"{''.join(truncate_tokens(stack_name.split('-'), 11))}sensitivedata{sha256hash(self._name)}"
)[:24],
allowed_ip_addresses=props.data_private_sensitive_ip_addresses,
allowed_ip_addresses=data_private_sensitive_ip_addresses,
allowed_service_tag=data_private_sensitive_service_tag,
location=props.location,
subnet_id=props.subnet_data_private_id,
resource_group_name=props.resource_group_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
f"{''.join(truncate_tokens(stack_name.split('-'), 11))}desiredstate{sha256hash(self._name)}"
)[:24],
allowed_ip_addresses=props.admin_ip_addresses,
allowed_service_tag=None,
location=props.location,
resource_group_name=props.resource_group_name,
subnet_id=props.subnet_desired_state_id,
Expand Down
2 changes: 1 addition & 1 deletion data_safe_haven/validators/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def ip_address(ip_address: str) -> str:
try:
return str(ipaddress.ip_network(ip_address))
except Exception as exc:
msg = "Expected valid IPv4 address, for example '1.1.1.1', or 'Internet'."
msg = "Expected valid IPv4 address, for example '1.1.1.1'."
raise ValueError(msg) from exc


Expand Down
18 changes: 18 additions & 0 deletions tests/config/test_config_sections.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,24 @@ def test_all_databases_must_be_unique(self) -> None:
databases=[DatabaseSystem.POSTGRESQL, DatabaseSystem.POSTGRESQL],
)

def test_data_provider_tag_internet(
self,
config_subsection_remote_desktop: ConfigSubsectionRemoteDesktopOpts,
config_subsection_storage_quota_gb: ConfigSubsectionStorageQuotaGB,
):
sre_config = ConfigSectionSRE(
admin_email_address="[email protected]",
remote_desktop=config_subsection_remote_desktop,
storage_quota_gb=config_subsection_storage_quota_gb,
data_provider_ip_addresses="Internet",
)
assert isinstance(sre_config.data_provider_ip_addresses, AzureServiceTag)
assert sre_config.data_provider_ip_addresses == "Internet"

def test_data_provider_tag_invalid(self):
with pytest.raises(ValueError, match="Input should be 'Internet'"):
ConfigSectionSRE(data_provider_ip_addresses="Not a tag")

def test_ip_overlap_admin(self):
with pytest.raises(ValueError, match="IP addresses must not overlap."):
ConfigSectionSRE(
Expand Down
2 changes: 1 addition & 1 deletion tests/validators/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_ip_address(self, ip_address, output):
def test_ip_address_fail(self, ip_address):
with pytest.raises(
ValueError,
match="Expected valid IPv4 address, for example '1.1.1.1', or 'Internet'.",
match="Expected valid IPv4 address, for example '1.1.1.1'.",
):
validators.ip_address(ip_address)

Expand Down
Loading