diff --git a/data_safe_haven/config/config_sections.py b/data_safe_haven/config/config_sections.py index 35b9570a7e..62bfec0833 100644 --- a/data_safe_haven/config/config_sections.py +++ b/data_safe_haven/config/config_sections.py @@ -57,7 +57,7 @@ class ConfigSectionSRE(BaseModel, validate_assignment=True): admin_email_address: EmailAddress admin_ip_addresses: list[IpAddress] = [] databases: UniqueList[DatabaseSystem] = [] - data_provider_ip_addresses: list[IpAddress] = [] + data_provider_ip_addresses: list[IpAddress] | AzureServiceTag = [] remote_desktop: ConfigSubsectionRemoteDesktopOpts research_user_ip_addresses: list[IpAddress] | AzureServiceTag = [] storage_quota_gb: ConfigSubsectionStorageQuotaGB @@ -67,8 +67,6 @@ class ConfigSectionSRE(BaseModel, validate_assignment=True): @field_validator( "admin_ip_addresses", - "data_provider_ip_addresses", - # "research_user_ip_addresses", mode="after", ) @classmethod @@ -81,6 +79,7 @@ def ensure_non_overlapping(cls, v: list[IpAddress]) -> list[IpAddress]: return v @field_validator( + "data_provider_ip_addresses", "research_user_ip_addresses", mode="after", ) diff --git a/data_safe_haven/infrastructure/components/wrapped/nfsv3_storage_account.py b/data_safe_haven/infrastructure/components/wrapped/nfsv3_storage_account.py index 181839e71d..e259de4806 100644 --- a/data_safe_haven/infrastructure/components/wrapped/nfsv3_storage_account.py +++ b/data_safe_haven/infrastructure/components/wrapped/nfsv3_storage_account.py @@ -4,6 +4,7 @@ from pulumi_azure_native import storage from data_safe_haven.external import AzureIPv4Range +from data_safe_haven.types import AzureServiceTag class WrappedNFSV3StorageAccount(storage.StorageAccount): @@ -24,17 +25,35 @@ def __init__( resource_name: str, *, account_name: Input[str], - allowed_ip_addresses: Input[Sequence[str]], + allowed_ip_addresses: Input[Sequence[str]] | None, + allowed_service_tag: AzureServiceTag | None, location: Input[str], resource_group_name: Input[str], subnet_id: Input[str], opts: ResourceOptions, tags: Input[Mapping[str, Input[str]]], ): + if allowed_service_tag == AzureServiceTag.INTERNET: + default_action = storage.DefaultAction.ALLOW + ip_rules = [] + else: + default_action = storage.DefaultAction.DENY + ip_rules = Output.from_input(allowed_ip_addresses).apply( + lambda ip_ranges: [ + storage.IPRuleArgs( + action=storage.Action.ALLOW, + i_p_address_or_range=str(ip_address), + ) + for ip_range in sorted(ip_ranges) + for ip_address in AzureIPv4Range.from_cidr(ip_range).all_ips() + ] + ) + self.resource_group_name_ = Output.from_input(resource_group_name) super().__init__( resource_name, account_name=account_name, + allow_blob_public_access=False, enable_https_traffic_only=True, enable_nfs_v3=True, encryption=self.encryption_args, @@ -44,23 +63,15 @@ def __init__( minimum_tls_version=storage.MinimumTlsVersion.TLS1_2, network_rule_set=storage.NetworkRuleSetArgs( bypass=storage.Bypass.AZURE_SERVICES, - default_action=storage.DefaultAction.DENY, - ip_rules=Output.from_input(allowed_ip_addresses).apply( - lambda ip_ranges: [ - storage.IPRuleArgs( - action=storage.Action.ALLOW, - i_p_address_or_range=str(ip_address), - ) - for ip_range in sorted(ip_ranges) - for ip_address in AzureIPv4Range.from_cidr(ip_range).all_ips() - ] - ), + default_action=default_action, + ip_rules=ip_rules, virtual_network_rules=[ storage.VirtualNetworkRuleArgs( virtual_network_resource_id=subnet_id, ) ], ), + public_network_access=storage.PublicNetworkAccess.ENABLED, resource_group_name=resource_group_name, sku=storage.SkuArgs(name=storage.SkuName.PREMIUM_ZRS), opts=opts, diff --git a/data_safe_haven/infrastructure/programs/sre/data.py b/data_safe_haven/infrastructure/programs/sre/data.py index 9e18666277..711b76139f 100644 --- a/data_safe_haven/infrastructure/programs/sre/data.py +++ b/data_safe_haven/infrastructure/programs/sre/data.py @@ -35,7 +35,7 @@ SSLCertificateProps, WrappedNFSV3StorageAccount, ) -from data_safe_haven.types import AzureDnsZoneNames +from data_safe_haven.types import AzureDnsZoneNames, AzureServiceTag class SREDataProps: @@ -46,7 +46,7 @@ def __init__( admin_email_address: Input[str], admin_group_id: Input[str], admin_ip_addresses: Input[Sequence[str]], - data_provider_ip_addresses: Input[Sequence[str]], + data_provider_ip_addresses: Input[list[str]] | AzureServiceTag, dns_private_zones: Input[dict[str, network.PrivateZone]], dns_record: Input[network.RecordSet], dns_server_admin_password: Input[pulumi_random.RandomPassword], @@ -64,13 +64,7 @@ def __init__( self.admin_email_address = admin_email_address self.admin_group_id = admin_group_id self.data_configuration_ip_addresses = admin_ip_addresses - self.data_private_sensitive_ip_addresses = Output.all( - admin_ip_addresses, data_provider_ip_addresses - ).apply( - lambda address_lists: { - ip for address_list in address_lists for ip in address_list - } - ) + self.data_provider_ip_addresses = data_provider_ip_addresses self.dns_private_zones = dns_private_zones self.dns_record = dns_record self.password_dns_server_admin = dns_server_admin_password @@ -112,6 +106,19 @@ def __init__( child_opts = ResourceOptions.merge(opts, ResourceOptions(parent=self)) child_tags = {"component": "data"} | (tags if tags else {}) + if isinstance(props.data_provider_ip_addresses, list): + data_private_sensitive_service_tag = None + data_private_sensitive_ip_addresses = Output.all( + props.data_configuration_ip_addresses, props.data_provider_ip_addresses + ).apply( + lambda address_lists: { + ip for address_list in address_lists for ip in address_list + } + ) + else: + data_private_sensitive_ip_addresses = None + data_private_sensitive_service_tag = props.data_provider_ip_addresses + # Define Key Vault reader identity_key_vault_reader = managedidentity.UserAssignedIdentity( f"{self._name}_id_key_vault_reader", @@ -466,7 +473,8 @@ def __init__( account_name=alphanumeric( f"{''.join(truncate_tokens(stack_name.split('-'), 11))}sensitivedata{sha256hash(self._name)}" )[:24], - allowed_ip_addresses=props.data_private_sensitive_ip_addresses, + allowed_ip_addresses=data_private_sensitive_ip_addresses, + allowed_service_tag=data_private_sensitive_service_tag, location=props.location, subnet_id=props.subnet_data_private_id, resource_group_name=props.resource_group_name, diff --git a/data_safe_haven/infrastructure/programs/sre/desired_state.py b/data_safe_haven/infrastructure/programs/sre/desired_state.py index 73466d6c5b..c4392f5210 100644 --- a/data_safe_haven/infrastructure/programs/sre/desired_state.py +++ b/data_safe_haven/infrastructure/programs/sre/desired_state.py @@ -108,6 +108,7 @@ def __init__( f"{''.join(truncate_tokens(stack_name.split('-'), 11))}desiredstate{sha256hash(self._name)}" )[:24], allowed_ip_addresses=props.admin_ip_addresses, + allowed_service_tag=None, location=props.location, resource_group_name=props.resource_group_name, subnet_id=props.subnet_desired_state_id, diff --git a/data_safe_haven/validators/validators.py b/data_safe_haven/validators/validators.py index 27507d26b4..dd4458ec57 100644 --- a/data_safe_haven/validators/validators.py +++ b/data_safe_haven/validators/validators.py @@ -124,7 +124,7 @@ def ip_address(ip_address: str) -> str: try: return str(ipaddress.ip_network(ip_address)) except Exception as exc: - msg = "Expected valid IPv4 address, for example '1.1.1.1', or 'Internet'." + msg = "Expected valid IPv4 address, for example '1.1.1.1'." raise ValueError(msg) from exc diff --git a/tests/config/test_config_sections.py b/tests/config/test_config_sections.py index 6528b130fa..7d9a0ba873 100644 --- a/tests/config/test_config_sections.py +++ b/tests/config/test_config_sections.py @@ -170,6 +170,24 @@ def test_all_databases_must_be_unique(self) -> None: databases=[DatabaseSystem.POSTGRESQL, DatabaseSystem.POSTGRESQL], ) + def test_data_provider_tag_internet( + self, + config_subsection_remote_desktop: ConfigSubsectionRemoteDesktopOpts, + config_subsection_storage_quota_gb: ConfigSubsectionStorageQuotaGB, + ): + sre_config = ConfigSectionSRE( + admin_email_address="admin@example.com", + remote_desktop=config_subsection_remote_desktop, + storage_quota_gb=config_subsection_storage_quota_gb, + data_provider_ip_addresses="Internet", + ) + assert isinstance(sre_config.data_provider_ip_addresses, AzureServiceTag) + assert sre_config.data_provider_ip_addresses == "Internet" + + def test_data_provider_tag_invalid(self): + with pytest.raises(ValueError, match="Input should be 'Internet'"): + ConfigSectionSRE(data_provider_ip_addresses="Not a tag") + def test_ip_overlap_admin(self): with pytest.raises(ValueError, match="IP addresses must not overlap."): ConfigSectionSRE( diff --git a/tests/validators/test_validators.py b/tests/validators/test_validators.py index 18d2fd31b5..c8447ab441 100644 --- a/tests/validators/test_validators.py +++ b/tests/validators/test_validators.py @@ -111,7 +111,7 @@ def test_ip_address(self, ip_address, output): def test_ip_address_fail(self, ip_address): with pytest.raises( ValueError, - match="Expected valid IPv4 address, for example '1.1.1.1', or 'Internet'.", + match="Expected valid IPv4 address, for example '1.1.1.1'.", ): validators.ip_address(ip_address)