From bbc421e18424ed776e86f2400d948e177e570ad9 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 25 Oct 2024 10:28:06 +0500 Subject: [PATCH 1/6] Support custom networks on Azure --- .../_internal/core/backends/azure/compute.py | 141 ++++++++++++++---- .../_internal/core/backends/azure/config.py | 6 + .../core/backends/azure/resources.py | 73 ++++++++- .../_internal/core/models/backends/azure.py | 4 + .../services/backends/configurators/azure.py | 70 ++++++++- .../_internal/server/services/config.py | 51 +++++-- 6 files changed, 292 insertions(+), 53 deletions(-) diff --git a/src/dstack/_internal/core/backends/azure/compute.py b/src/dstack/_internal/core/backends/azure/compute.py index 0d98cc2a2..d99795bbf 100644 --- a/src/dstack/_internal/core/backends/azure/compute.py +++ b/src/dstack/_internal/core/backends/azure/compute.py @@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Tuple from azure.core.credentials import TokenCredential -from azure.core.exceptions import ResourceExistsError +from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError from azure.mgmt import compute as compute_mgmt from azure.mgmt import network as network_mgmt from azure.mgmt.compute.models import ( @@ -33,6 +33,7 @@ from dstack import version from dstack._internal import settings +from dstack._internal.core.backends.azure import resources as azure_resources from dstack._internal.core.backends.azure import utils as azure_utils from dstack._internal.core.backends.azure.config import AzureConfig from dstack._internal.core.backends.base.compute import ( @@ -110,6 +111,19 @@ def create_instance( ssh_pub_keys = instance_config.get_public_keys() disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024) + allocate_public_ip = self.config.allocate_public_ips + network_resource_group, network, subnet = get_resource_group_network_subnet_or_error( + network_client=self._network_client, + resource_group=self.config.resource_group, + vpc_ids=self.config.vpc_ids, + location=location, + allocate_public_ip=allocate_public_ip, + ) + network_security_group = azure_utils.get_default_network_security_group_name( + resource_group=self.config.resource_group, + location=location, + ) + tags = { "owner": "dstack", "dstack_project": instance_config.project_name, @@ -122,18 +136,9 @@ def create_instance( subscription_id=self.config.subscription_id, location=location, resource_group=self.config.resource_group, - network_security_group=azure_utils.get_default_network_security_group_name( - resource_group=self.config.resource_group, - location=location, - ), - network=azure_utils.get_default_network_name( - resource_group=self.config.resource_group, - location=location, - ), - subnet=azure_utils.get_default_subnet_name( - resource_group=self.config.resource_group, - location=location, - ), + network_security_group=network_security_group, + network=network, + subnet=subnet, managed_identity=None, image_reference=_get_image_ref( compute_client=self._compute_client, @@ -149,6 +154,8 @@ def create_instance( spot=instance_offer.instance.resources.spot, disk_size=disk_size, computer_name="runnervm", + allocate_public_ip=allocate_public_ip, + network_resource_group=network_resource_group, tags=tags, ) logger.info("Request succeeded") @@ -157,11 +164,14 @@ def create_instance( resource_group=self.config.resource_group, vm=vm, ) + hostname = public_ip + if allocate_public_ip: + hostname = private_ip return JobProvisioningData( backend=instance_offer.backend, instance_type=instance_offer.instance, instance_id=vm.name, - hostname=public_ip, + hostname=hostname, internal_ip=private_ip, region=location, price=instance_offer.price, @@ -211,6 +221,18 @@ def create_gateway( configuration.region, ) + network_resource_group, network, subnet = get_resource_group_network_subnet_or_error( + network_client=self._network_client, + resource_group=self.config.resource_group, + vpc_ids=self.config.vpc_ids, + location=configuration.region, + allocate_public_ip=self.config.allocate_public_ips, + ) + network_security_group = azure_utils.get_default_network_security_group_name( + resource_group=self.config.resource_group, + location=configuration.region, + ) + tags = { "Name": configuration.instance_name, "owner": "dstack", @@ -225,18 +247,9 @@ def create_gateway( subscription_id=self.config.subscription_id, location=configuration.region, resource_group=self.config.resource_group, - network_security_group=azure_utils.get_gateway_network_security_group_name( - resource_group=self.config.resource_group, - location=configuration.region, - ), - network=azure_utils.get_default_network_name( - resource_group=self.config.resource_group, - location=configuration.region, - ), - subnet=azure_utils.get_default_subnet_name( - resource_group=self.config.resource_group, - location=configuration.region, - ), + network_security_group=network_security_group, + network=network, + subnet=subnet, managed_identity=None, image_reference=_get_gateway_image_ref(), vm_size="Standard_B1s", @@ -246,6 +259,7 @@ def create_gateway( spot=False, disk_size=30, computer_name="gatewayvm", + network_resource_group=network_resource_group, tags=tags, ) logger.info("Request succeeded") @@ -273,6 +287,57 @@ def terminate_gateway( ) +def get_resource_group_network_subnet_or_error( + network_client: network_mgmt.NetworkManagementClient, + resource_group: Optional[str], + vpc_ids: Optional[Dict[str, str]], + location: str, + allocate_public_ip: bool, +) -> Tuple[str, str, str]: + if vpc_ids is not None: + vpc_id = vpc_ids.get(location) + if vpc_id is None: + raise ComputeError(f"Network not configured for location {location}") + try: + resource_group, network_name = _parse_config_vpc_id(vpc_id) + except Exception: + raise ComputeError( + "Network specified in incorrect format." + " Supported format for `vps_ids` values: 'networkResourceGroupName/networkName'" + ) + elif resource_group is not None: + network_name = azure_utils.get_default_network_name(resource_group, location) + else: + raise ComputeError("`resource_group` or `vpc_ids` must be specified") + + try: + subnets = azure_resources.get_network_subnets( + network_client=network_client, + resource_group=resource_group, + network_name=network_name, + private=not allocate_public_ip, + ) + except ResourceNotFoundError: + raise ComputeError( + f"Network {network_name} not found in location {location} in resource group {resource_group}" + ) + + if len(subnets) == 0: + if not allocate_public_ip: + raise ComputeError( + f"Failed to find private subnets with outbound internet connectivity in network {network_name}" + ) + raise ComputeError(f"Failed to find subnets in network {network_name}") + + subnet_name = subnets[0] + return resource_group, network_name, subnet_name + + +def _parse_config_vpc_id(vpc_id: str) -> Tuple[str, str]: + resource_group, network_name = vpc_id.split("/") + return resource_group, network_name + + class VMImageVariant(enum.Enum): GRID = enum.auto() CUDA = enum.auto() @@ -396,10 +461,19 @@ def _launch_instance( spot: bool, disk_size: int, computer_name: str, + allocate_public_ip: bool = True, + network_resource_group: Optional[str] = None, tags: Optional[Dict[str, str]] = None, ) -> VirtualMachine: if tags is None: tags = {} + if network_resource_group is None: + network_resource_group = resource_group + public_ip_address_configuration = None + if allocate_public_ip: + public_ip_address_configuration = VirtualMachinePublicIPAddressConfiguration( + name="public_ip_config", + ) try: poller = compute_client.virtual_machines.begin_create_or_update( resource_group, @@ -451,14 +525,12 @@ def _launch_instance( subnet=SubResource( id=azure_utils.get_subnet_id( subscription_id, - resource_group, + network_resource_group, network, subnet, ) ), - public_ip_address_configuration=VirtualMachinePublicIPAddressConfiguration( - name="public_ip_config", - ), + public_ip_address_configuration=public_ip_address_configuration, ) ], ) @@ -505,18 +577,21 @@ def _get_vm_public_private_ips( network_client: network_mgmt.NetworkManagementClient, resource_group: str, vm: VirtualMachine, -) -> Tuple[str, str]: +) -> Tuple[Optional[str], str]: nic_id = vm.network_profile.network_interfaces[0].id nic_name = azure_utils.get_resource_name_from_resource_id(nic_id) nic = network_client.network_interfaces.get( resource_group_name=resource_group, network_interface_name=nic_name, ) + + private_ip = nic.ip_configurations[0].private_ip_address + if nic.ip_configurations[0].public_ip_address is None: + return None, private_ip + public_ip_id = nic.ip_configurations[0].public_ip_address.id public_ip_name = azure_utils.get_resource_name_from_resource_id(public_ip_id) public_ip = network_client.public_ip_addresses.get(resource_group, public_ip_name) - - private_ip = nic.ip_configurations[0].private_ip_address return public_ip.ip_address, private_ip diff --git a/src/dstack/_internal/core/backends/azure/config.py b/src/dstack/_internal/core/backends/azure/config.py index 4e7cff268..7a25bb91a 100644 --- a/src/dstack/_internal/core/backends/azure/config.py +++ b/src/dstack/_internal/core/backends/azure/config.py @@ -4,3 +4,9 @@ class AzureConfig(AzureStoredConfig, BackendConfig): creds: AnyAzureCreds + + @property + def allocate_public_ips(self) -> bool: + if self.public_ips is not None: + return self.public_ips + return True diff --git a/src/dstack/_internal/core/backends/azure/resources.py b/src/dstack/_internal/core/backends/azure/resources.py index e03ba484d..722d74154 100644 --- a/src/dstack/_internal/core/backends/azure/resources.py +++ b/src/dstack/_internal/core/backends/azure/resources.py @@ -1,9 +1,80 @@ import re -from typing import Dict +from typing import Dict, List + +from azure.mgmt import network as network_mgmt +from azure.mgmt.network.models import Subnet from dstack._internal.core.errors import ComputeError +def get_network_subnets( + network_client: network_mgmt.NetworkManagementClient, + resource_group: str, + network_name: str, + private: bool, +) -> List[str]: + res = [] + subnets = network_client.subnets.list( + resource_group_name=resource_group, virtual_network_name=network_name + ) + for subnet in subnets: + if private: + if _is_eligible_private_subnet( + network_client=network_client, + resource_group=resource_group, + network_name=network_name, + subnet=subnet, + ): + res.append(subnet.name) + else: + if _is_eligible_public_subnet( + network_client=network_client, + resource_group=resource_group, + network_name=network_name, + subnet=subnet, + ): + res.append(subnet.name) + return res + + +def _is_eligible_public_subnet( + network_client: network_mgmt.NetworkManagementClient, + resource_group: str, + network_name: str, + subnet: Subnet, +) -> bool: + # Apparently, in Azure practically any subnet can be used + # to provision instances with public IPs + return True + + +def _is_eligible_private_subnet( + network_client: network_mgmt.NetworkManagementClient, + resource_group: str, + network_name: str, + subnet: Subnet, +) -> bool: + # Azure provides default outbound connectivity but it's deprecated + # and does not work with Flexible orchestration used in dstack, + # so we require an explicit outbound method such as NAT Gateway. + + if subnet.nat_gateway is not None: + return True + + vnet_peerings = list( + network_client.virtual_network_peerings.list( + resource_group_name=resource_group, + virtual_network_name=network_name, + ) + ) + if len(vnet_peerings) > 0: + # We currently assume that any peering can provide outbound connectivity. + # There can be a more elaborate check of the peering configuration. + return True + + return False + + def validate_tags(tags: Dict[str, str]): for k, v in tags.items(): if not _is_valid_tag(k, v): diff --git a/src/dstack/_internal/core/models/backends/azure.py b/src/dstack/_internal/core/models/backends/azure.py index 85a660cc1..21b5b41a0 100644 --- a/src/dstack/_internal/core/models/backends/azure.py +++ b/src/dstack/_internal/core/models/backends/azure.py @@ -12,6 +12,8 @@ class AzureConfigInfo(CoreModel): tenant_id: str subscription_id: str locations: Optional[List[str]] = None + vpc_ids: Optional[Dict[str, str]] = None + public_ips: Optional[bool] = None tags: Optional[Dict[str, str]] = None @@ -47,6 +49,8 @@ class AzureConfigInfoWithCredsPartial(CoreModel): tenant_id: Optional[str] subscription_id: Optional[str] locations: Optional[List[str]] + vpc_ids: Optional[Dict[str, str]] + public_ips: Optional[bool] tags: Optional[Dict[str, str]] diff --git a/src/dstack/_internal/server/services/backends/configurators/azure.py b/src/dstack/_internal/server/services/backends/configurators/azure.py index 0ef3155ad..e81e41f4c 100644 --- a/src/dstack/_internal/server/services/backends/configurators/azure.py +++ b/src/dstack/_internal/server/services/backends/configurators/azure.py @@ -1,5 +1,5 @@ import json -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import List, Optional, Tuple from azure.core.credentials import TokenCredential @@ -20,8 +20,15 @@ from dstack._internal.core.backends.azure import AzureBackend, auth, resources from dstack._internal.core.backends.azure import utils as azure_utils +from dstack._internal.core.backends.azure.auth import AzureCredential +from dstack._internal.core.backends.azure.compute import get_resource_group_network_subnet_or_error from dstack._internal.core.backends.azure.config import AzureConfig -from dstack._internal.core.errors import BackendAuthError, ComputeError, ServerClientError +from dstack._internal.core.errors import ( + BackendAuthError, + BackendError, + ComputeError, + ServerClientError, +) from dstack._internal.core.models.backends.azure import ( AnyAzureConfigInfo, AzureClientCreds, @@ -47,6 +54,7 @@ Configurator, raise_invalid_credentials_error, ) +from dstack._internal.utils.common import get_or_error LOCATIONS = [ ("(US) Central US", "centralus"), @@ -139,7 +147,7 @@ def get_config_values(self, config: AzureConfigInfoWithCredsPartial) -> AzureCon config_values.locations = self._get_locations_element( selected=config.locations or DEFAULT_LOCATIONS ) - self._check_config(config) + self._check_config(config=config, credential=credential) return config_values def create_backend( @@ -311,8 +319,9 @@ def func(location: str): for location in locations: executor.submit(func, location) - def _check_config(self, config: AzureConfigInfoWithCredsPartial): + def _check_config(self, config: AzureConfigInfoWithCredsPartial, credential: AzureCredential): self._check_tags_config(config) + self._check_vpc_config(config=config, credential=credential) def _check_tags_config(self, config: AzureConfigInfoWithCredsPartial): if not config.tags: @@ -326,6 +335,55 @@ def _check_tags_config(self, config: AzureConfigInfoWithCredsPartial): except ComputeError as e: raise ServerClientError(e.args[0]) + def _check_vpc_config( + self, config: AzureConfigInfoWithCredsPartial, credential: AzureCredential + ): + subscription_id = get_or_error(config.subscription_id) + allocate_public_ip = config.public_ips if config.public_ips is not None else True + if config.public_ips is False and config.vpc_ids is None: + raise ServerClientError(msg="`vpc_ids` must be specified if `public_ips: false`.") + locations = config.locations + if locations is None: + locations = DEFAULT_LOCATIONS + if config.vpc_ids is not None: + vpc_ids_locations = list(config.vpc_ids.keys()) + not_configured_locations = [loc for loc in locations if loc not in vpc_ids_locations] + if len(not_configured_locations) > 0: + if config.locations is None: + raise ServerClientError( + f"`vpc_ids` not configured for regions {not_configured_locations}. " + "Configure `vpc_ids` for all regions or specify `regions`." + ) + raise ServerClientError( + f"`vpc_ids` not configured for regions {not_configured_locations}. " + "Configure `vpc_ids` for all regions specified in `regions`." + ) + network_client = network_mgmt.NetworkManagementClient( + credential=credential, + subscription_id=subscription_id, + ) + with ThreadPoolExecutor(max_workers=8) as executor: + futures = [] + for location in locations: + future = executor.submit( + get_resource_group_network_subnet_or_error, + network_client=network_client, + resource_group=None, + vpc_ids=config.vpc_ids, + location=location, + allocate_public_ip=allocate_public_ip, + ) + futures.append(future) + for future in as_completed(futures): + try: + future.result() + except BackendError as e: + raise ServerClientError(e.args[0]) + + +def _get_resource_group_name(project_name: str) -> str: + return f"dstack-{project_name}" + class ResourceManager: def __init__(self, credential: TokenCredential, subscription_id: str): @@ -347,10 +405,6 @@ def create_resource_group( return resource_group.name -def _get_resource_group_name(project_name: str) -> str: - return f"dstack-{project_name}" - - class NetworkManager: def __init__(self, credential: TokenCredential, subscription_id: str): self.network_client = network_mgmt.NetworkManagementClient( diff --git a/src/dstack/_internal/server/services/config.py b/src/dstack/_internal/server/services/config.py index 94f9fd8fb..29113f93a 100644 --- a/src/dstack/_internal/server/services/config.py +++ b/src/dstack/_internal/server/services/config.py @@ -70,7 +70,10 @@ class AWSConfig(CoreModel): vpc_ids: Annotated[ Optional[Dict[str, str]], Field( - description="The mapping from AWS regions to VPC IDs. If `default_vpcs: true`, omitted regions will use default VPCs" + description=( + "The mapping from AWS regions to VPC IDs." + " If `default_vpcs: true`, omitted regions will use default VPCs" + ) ), ] = None default_vpcs: Annotated[ @@ -86,7 +89,12 @@ class AWSConfig(CoreModel): public_ips: Annotated[ Optional[bool], Field( - description="A flag to enable/disable public IP assigning on instances. Defaults to `true`" + description=( + "A flag to enable/disable public IP assigning on instances." + " `public_ips: false` requires at least one private subnet with outbound internet connection" + " provided by a NAT Gateway or a Transit Gateway." + " Defaults to `true`" + ) ), ] = None tags: Annotated[ @@ -101,7 +109,28 @@ class AzureConfig(CoreModel): tenant_id: Annotated[str, Field(description="The tenant ID")] subscription_id: Annotated[str, Field(description="The subscription ID")] regions: Annotated[ - Optional[List[str]], Field(description="The list of Azure regions (locations)") + Optional[List[str]], + Field(description="The list of Azure regions (locations)"), + ] = None + vpc_ids: Annotated[ + Optional[Dict[str, str]], + Field( + description=( + "The mapping from configured Azure locations to network IDs." + " A network ID must have a format `networkResourceGroup/networkName`" + " If not specified, `dstack` will create a new network for every configured region" + ) + ), + ] = None + public_ips: Annotated[ + Optional[bool], + Field( + description=( + "A flag to enable/disable public IP assigning on instances." + " `public_ips: false` requires `vpc_ids` that specifies custom networks with NAT gateway." + " Defaults to `true`" + ) + ), ] = None tags: Annotated[ Optional[Dict[str, str]], @@ -132,9 +161,9 @@ class GCPServiceAccountCreds(CoreModel): Optional[str], Field( description=( - "The contents of the service account file. " - "When configuring via `server/config.yml`, it's automatically filled from `filename`. " - "When configuring via UI, it has to be specified explicitly" + "The contents of the service account file." + " When configuring via `server/config.yml`, it's automatically filled from `filename`." + " When configuring via UI, it has to be specified explicitly" ) ), ] = None @@ -216,9 +245,9 @@ class KubeconfigConfig(CoreModel): Optional[str], Field( description=( - "The contents of the kubeconfig file. " - "When configuring via `server/config.yml`, it's automatically filled from `filename`. " - "When configuring via UI, it has to be specified explicitly" + "The contents of the kubeconfig file." + " When configuring via `server/config.yml`, it's automatically filled from `filename`." + " When configuring via UI, it has to be specified explicitly" ) ), ] = None @@ -312,8 +341,8 @@ class OCIConfig(CoreModel): Optional[str], Field( description=( - "Compartment where `dstack` will create all resources. " - "Omit to instruct `dstack` to create a new compartment" + "Compartment where `dstack` will create all resources." + " Omit to instruct `dstack` to create a new compartment" ) ), ] = None From 032913eb37b3138577e4cee003a3e47d28b81d9b Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 25 Oct 2024 10:36:23 +0500 Subject: [PATCH 2/6] Fix tests --- .../server/services/backends/configurators/azure.py | 6 +++--- src/tests/_internal/server/routers/test_backends.py | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/dstack/_internal/server/services/backends/configurators/azure.py b/src/dstack/_internal/server/services/backends/configurators/azure.py index e81e41f4c..8f2dd813c 100644 --- a/src/dstack/_internal/server/services/backends/configurators/azure.py +++ b/src/dstack/_internal/server/services/backends/configurators/azure.py @@ -54,7 +54,6 @@ Configurator, raise_invalid_credentials_error, ) -from dstack._internal.utils.common import get_or_error LOCATIONS = [ ("(US) Central US", "centralus"), @@ -338,7 +337,8 @@ def _check_tags_config(self, config: AzureConfigInfoWithCredsPartial): def _check_vpc_config( self, config: AzureConfigInfoWithCredsPartial, credential: AzureCredential ): - subscription_id = get_or_error(config.subscription_id) + if config.subscription_id is None: + return None allocate_public_ip = config.public_ips if config.public_ips is not None else True if config.public_ips is False and config.vpc_ids is None: raise ServerClientError(msg="`vpc_ids` must be specified if `public_ips: false`.") @@ -360,7 +360,7 @@ def _check_vpc_config( ) network_client = network_mgmt.NetworkManagementClient( credential=credential, - subscription_id=subscription_id, + subscription_id=config.subscription_id, ) with ThreadPoolExecutor(max_workers=8) as executor: futures = [] diff --git a/src/tests/_internal/server/routers/test_backends.py b/src/tests/_internal/server/routers/test_backends.py index 9395e378f..04debbbb0 100644 --- a/src/tests/_internal/server/routers/test_backends.py +++ b/src/tests/_internal/server/routers/test_backends.py @@ -309,7 +309,9 @@ async def test_returns_config_on_valid_creds( "dstack._internal.core.backends.azure.auth.authenticate" ) as authenticate_mock, patch( "azure.mgmt.subscription.SubscriptionClient" - ) as SubscriptionClientMock: + ) as SubscriptionClientMock, patch( + "dstack._internal.core.backends.azure.compute.get_resource_group_network_subnet_or_error" + ): default_creds_available_mock.return_value = False authenticate_mock.return_value = None, "test_tenant" client_mock = SubscriptionClientMock.return_value From 705e88e67da55fa2493ba697b75b82822abd2262 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 25 Oct 2024 10:43:09 +0500 Subject: [PATCH 3/6] Refactor resources errors --- .../_internal/core/backends/aws/resources.py | 4 ++-- .../_internal/core/backends/azure/resources.py | 4 ++-- .../_internal/core/backends/gcp/resources.py | 4 ++-- .../server/services/backends/configurators/aws.py | 10 +++++++--- .../services/backends/configurators/azure.py | 15 +++++++-------- .../server/services/backends/configurators/gcp.py | 6 +++--- .../_internal/core/backends/aws/test_resources.py | 4 ++-- .../core/backends/azure/test_resources.py | 4 ++-- .../_internal/core/backends/gcp/test_resources.py | 4 ++-- 9 files changed, 29 insertions(+), 26 deletions(-) diff --git a/src/dstack/_internal/core/backends/aws/resources.py b/src/dstack/_internal/core/backends/aws/resources.py index 4e3914503..729697cf1 100644 --- a/src/dstack/_internal/core/backends/aws/resources.py +++ b/src/dstack/_internal/core/backends/aws/resources.py @@ -5,7 +5,7 @@ import botocore.exceptions import dstack.version as version -from dstack._internal.core.errors import ComputeError, ComputeResourceNotFoundError +from dstack._internal.core.errors import BackendError, ComputeError, ComputeResourceNotFoundError def get_image_id(ec2_client: botocore.client.BaseClient, cuda: bool) -> str: @@ -408,7 +408,7 @@ def make_tags(tags: Dict[str, str]) -> List[Dict[str, str]]: def validate_tags(tags: Dict[str, str]): for k, v in tags.items(): if not _is_valid_tag(k, v): - raise ComputeError( + raise BackendError( "Invalid resource tags. " "See tags restrictions: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/Using_Tags.html#tag-restrictions" ) diff --git a/src/dstack/_internal/core/backends/azure/resources.py b/src/dstack/_internal/core/backends/azure/resources.py index 722d74154..dff3b2db2 100644 --- a/src/dstack/_internal/core/backends/azure/resources.py +++ b/src/dstack/_internal/core/backends/azure/resources.py @@ -4,7 +4,7 @@ from azure.mgmt import network as network_mgmt from azure.mgmt.network.models import Subnet -from dstack._internal.core.errors import ComputeError +from dstack._internal.core.errors import BackendError def get_network_subnets( @@ -78,7 +78,7 @@ def _is_eligible_private_subnet( def validate_tags(tags: Dict[str, str]): for k, v in tags.items(): if not _is_valid_tag(k, v): - raise ComputeError( + raise BackendError( "Invalid Azure resource tags. " "See tags restrictions: https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/tag-resources#limitations" ) diff --git a/src/dstack/_internal/core/backends/gcp/resources.py b/src/dstack/_internal/core/backends/gcp/resources.py index c34f376eb..e7f8d2ea8 100644 --- a/src/dstack/_internal/core/backends/gcp/resources.py +++ b/src/dstack/_internal/core/backends/gcp/resources.py @@ -11,7 +11,7 @@ from google.cloud import tpu_v2 import dstack.version as version -from dstack._internal.core.errors import ComputeError +from dstack._internal.core.errors import BackendError, ComputeError from dstack._internal.core.models.instances import Gpu from dstack._internal.utils.common import remove_prefix from dstack._internal.utils.logging import get_logger @@ -314,7 +314,7 @@ def get_accelerators( def validate_labels(labels: Dict[str, str]): for k, v in labels.items(): if not _is_valid_label(k, v): - raise ComputeError( + raise BackendError( "Invalid resource labels. " "See labels restrictions: https://cloud.google.com/compute/docs/labeling-resources#requirements" ) diff --git a/src/dstack/_internal/server/services/backends/configurators/aws.py b/src/dstack/_internal/server/services/backends/configurators/aws.py index 743517fb4..beeb00d09 100644 --- a/src/dstack/_internal/server/services/backends/configurators/aws.py +++ b/src/dstack/_internal/server/services/backends/configurators/aws.py @@ -6,7 +6,11 @@ from dstack._internal.core.backends.aws import AWSBackend, auth, compute, resources from dstack._internal.core.backends.aws.config import AWSConfig -from dstack._internal.core.errors import BackendAuthError, ComputeError, ServerClientError +from dstack._internal.core.errors import ( + BackendAuthError, + BackendError, + ServerClientError, +) from dstack._internal.core.models.backends.aws import ( AnyAWSConfigInfo, AWSAccessKeyCreds, @@ -144,7 +148,7 @@ def _check_tags_config(self, config: AWSConfigInfoWithCredsPartial): ) try: resources.validate_tags(config.tags) - except ComputeError as e: + except BackendError as e: raise ServerClientError(e.args[0]) def _check_vpc_config(self, session: Session, config: AWSConfigInfoWithCredsPartial): @@ -188,5 +192,5 @@ def _check_vpc_config(self, session: Session, config: AWSConfigInfoWithCredsPart for future in concurrent.futures.as_completed(futures): try: future.result() - except ComputeError as e: + except BackendError as e: raise ServerClientError(e.args[0]) diff --git a/src/dstack/_internal/server/services/backends/configurators/azure.py b/src/dstack/_internal/server/services/backends/configurators/azure.py index 8f2dd813c..78f8abc12 100644 --- a/src/dstack/_internal/server/services/backends/configurators/azure.py +++ b/src/dstack/_internal/server/services/backends/configurators/azure.py @@ -18,15 +18,12 @@ ) from azure.mgmt.resource.resources.models import ResourceGroup -from dstack._internal.core.backends.azure import AzureBackend, auth, resources +from dstack._internal.core.backends.azure import AzureBackend, auth, compute, resources from dstack._internal.core.backends.azure import utils as azure_utils -from dstack._internal.core.backends.azure.auth import AzureCredential -from dstack._internal.core.backends.azure.compute import get_resource_group_network_subnet_or_error from dstack._internal.core.backends.azure.config import AzureConfig from dstack._internal.core.errors import ( BackendAuthError, BackendError, - ComputeError, ServerClientError, ) from dstack._internal.core.models.backends.azure import ( @@ -318,7 +315,9 @@ def func(location: str): for location in locations: executor.submit(func, location) - def _check_config(self, config: AzureConfigInfoWithCredsPartial, credential: AzureCredential): + def _check_config( + self, config: AzureConfigInfoWithCredsPartial, credential: auth.AzureCredential + ): self._check_tags_config(config) self._check_vpc_config(config=config, credential=credential) @@ -331,11 +330,11 @@ def _check_tags_config(self, config: AzureConfigInfoWithCredsPartial): ) try: resources.validate_tags(config.tags) - except ComputeError as e: + except BackendError as e: raise ServerClientError(e.args[0]) def _check_vpc_config( - self, config: AzureConfigInfoWithCredsPartial, credential: AzureCredential + self, config: AzureConfigInfoWithCredsPartial, credential: auth.AzureCredential ): if config.subscription_id is None: return None @@ -366,7 +365,7 @@ def _check_vpc_config( futures = [] for location in locations: future = executor.submit( - get_resource_group_network_subnet_or_error, + compute.get_resource_group_network_subnet_or_error, network_client=network_client, resource_group=None, vpc_ids=config.vpc_ids, diff --git a/src/dstack/_internal/server/services/backends/configurators/gcp.py b/src/dstack/_internal/server/services/backends/configurators/gcp.py index 88e411aa3..5eef9962c 100644 --- a/src/dstack/_internal/server/services/backends/configurators/gcp.py +++ b/src/dstack/_internal/server/services/backends/configurators/gcp.py @@ -6,7 +6,7 @@ from dstack._internal.core.backends.gcp import GCPBackend, auth, resources from dstack._internal.core.backends.gcp.config import GCPConfig -from dstack._internal.core.errors import BackendAuthError, ComputeError, ServerClientError +from dstack._internal.core.errors import BackendAuthError, BackendError, ServerClientError from dstack._internal.core.models.backends.base import ( BackendType, ConfigElement, @@ -239,7 +239,7 @@ def _check_tags_config(self, config: GCPConfigInfoWithCredsPartial): ) try: resources.validate_labels(config.tags) - except ComputeError as e: + except BackendError as e: raise ServerClientError(e.args[0]) def _check_vpc_config( @@ -259,5 +259,5 @@ def _check_vpc_config( shared_vpc_project_id=config.vpc_project_id, allocate_public_ip=allocate_public_ip, ) - except ComputeError as e: + except BackendError as e: raise ServerClientError(e.args[0]) diff --git a/src/tests/_internal/core/backends/aws/test_resources.py b/src/tests/_internal/core/backends/aws/test_resources.py index 040c60e45..7b23bbf5b 100644 --- a/src/tests/_internal/core/backends/aws/test_resources.py +++ b/src/tests/_internal/core/backends/aws/test_resources.py @@ -5,7 +5,7 @@ _is_valid_tag_value, validate_tags, ) -from dstack._internal.core.errors import ComputeError +from dstack._internal.core.errors import BackendError class TestIsValidTagKey: @@ -69,5 +69,5 @@ def test_validate_valid_tags(self): def test_validate_invalid_tags(self): tags = {"aws:ReservedKey": "SomeValue", "ValidKey": "Invalid#Value"} - with pytest.raises(ComputeError, match="Invalid resource tags"): + with pytest.raises(BackendError, match="Invalid resource tags"): validate_tags(tags) diff --git a/src/tests/_internal/core/backends/azure/test_resources.py b/src/tests/_internal/core/backends/azure/test_resources.py index 643fff0b9..12498fa9a 100644 --- a/src/tests/_internal/core/backends/azure/test_resources.py +++ b/src/tests/_internal/core/backends/azure/test_resources.py @@ -5,7 +5,7 @@ _is_valid_tag_value, validate_tags, ) -from dstack._internal.core.errors import ComputeError +from dstack._internal.core.errors import BackendError class TestValidateTags: @@ -15,7 +15,7 @@ def test_valid_tags(self): def test_invalid_tags(self): tags = {"Invalid Date: Fri, 25 Oct 2024 10:50:17 +0500 Subject: [PATCH 4/6] Do not create default networks if custom specified --- .../services/backends/configurators/azure.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/dstack/_internal/server/services/backends/configurators/azure.py b/src/dstack/_internal/server/services/backends/configurators/azure.py index 78f8abc12..363d331c8 100644 --- a/src/dstack/_internal/server/services/backends/configurators/azure.py +++ b/src/dstack/_internal/server/services/backends/configurators/azure.py @@ -165,6 +165,7 @@ def create_backend( subscription_id=config.subscription_id, resource_group=resource_group, locations=config.locations, + create_default_network=config.vpc_ids is None, ) return BackendModel( project_id=project.id, @@ -289,17 +290,19 @@ def _create_network_resources( subscription_id: str, resource_group: str, locations: List[str], + create_default_network: bool, ): def func(location: str): network_manager = NetworkManager( credential=credential, subscription_id=subscription_id ) - network_manager.create_virtual_network( - resource_group=resource_group, - location=location, - name=azure_utils.get_default_network_name(resource_group, location), - subnet_name=azure_utils.get_default_subnet_name(resource_group, location), - ) + if create_default_network: + network_manager.create_virtual_network( + resource_group=resource_group, + location=location, + name=azure_utils.get_default_network_name(resource_group, location), + subnet_name=azure_utils.get_default_subnet_name(resource_group, location), + ) network_manager.create_network_security_group( resource_group=resource_group, location=location, From 4183df5c4835e83f751923be3a4d75826466cf9b Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 25 Oct 2024 11:50:52 +0500 Subject: [PATCH 5/6] Use bigger instances for azure gateways --- src/dstack/_internal/core/backends/azure/compute.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dstack/_internal/core/backends/azure/compute.py b/src/dstack/_internal/core/backends/azure/compute.py index d99795bbf..f4ffd0f04 100644 --- a/src/dstack/_internal/core/backends/azure/compute.py +++ b/src/dstack/_internal/core/backends/azure/compute.py @@ -226,7 +226,7 @@ def create_gateway( resource_group=self.config.resource_group, vpc_ids=self.config.vpc_ids, location=configuration.region, - allocate_public_ip=self.config.allocate_public_ips, + allocate_public_ip=True, ) network_security_group = azure_utils.get_default_network_security_group_name( resource_group=self.config.resource_group, @@ -252,7 +252,7 @@ def create_gateway( subnet=subnet, managed_identity=None, image_reference=_get_gateway_image_ref(), - vm_size="Standard_B1s", + vm_size="Standard_B1ms", instance_name=configuration.instance_name, user_data=get_gateway_user_data(configuration.ssh_key_pub), ssh_pub_keys=[configuration.ssh_key_pub], From 9402149aa49743c9b5afca5e2b36cc9b43deece2 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 25 Oct 2024 14:28:53 +0500 Subject: [PATCH 6/6] Update docs on custom vpcs and private subnets --- docs/docs/reference/server/config.yml.md | 53 ++++++++++++++++--- .../_internal/server/services/config.py | 16 ++++-- 2 files changed, 57 insertions(+), 12 deletions(-) diff --git a/docs/docs/reference/server/config.yml.md b/docs/docs/reference/server/config.yml.md index 84190def5..c96b0b84c 100644 --- a/docs/docs/reference/server/config.yml.md +++ b/docs/docs/reference/server/config.yml.md @@ -162,8 +162,8 @@ There are two ways to configure AWS: using an access key or using the default cr For the regions without configured `vpc_ids`, enable default VPCs by setting `default_vpcs` to `true`. ??? info "Private subnets" - By default, `dstack` utilizes public subnets and permits inbound SSH traffic exclusively for any provisioned instances. - If you want `dstack` to use private subnets, set `public_ips` to `false`. + By default, `dstack` provisions instances with public IPs and permits inbound SSH traffic. + If you want `dstack` to use private subnets and provision instances without public IPs, set `public_ips` to `false`. ```yaml projects: @@ -176,8 +176,8 @@ There are two ways to configure AWS: using an access key or using the default cr public_ips: false ``` - Using private subnets assumes that both the `dstack` server and users can access the configured VPC's private subnets - (e.g., through VPC peering). + Using private subnets assumes that both the `dstack` server and users can access the configured VPC's private subnets. + Additionally, private subnets must have outbound internet connectivity provided by NAT Gateway, Transit Gateway, or other mechanism. #### Azure @@ -287,6 +287,44 @@ There are two ways to configure Azure: using a client secret or using the defaul } ``` +??? info "VPC" + By default, `dstack` creates new Azure networks and subnets for every configured region. + It's possible to use custom networks by specifying `vpc_ids`: + + ```yaml + projects: + - name: main + backends: + - type: azure + creds: + type: default + regions: [westeurope] + vpc_ids: + westeurope: myNetworkResourceGroup/myNetworkName + ``` + + +??? info "Private subnets" + By default, `dstack` provisions instances with public IPs and permits inbound SSH traffic. + If you want `dstack` to use private subnets and provision instances without public IPs, + specify custom networks using `vpc_ids` and set `public_ips` to `false`. + + ```yaml + projects: + - name: main + backends: + - type: azure + creds: + type: default + regions: [westeurope] + vpc_ids: + westeurope: myNetworkResourceGroup/myNetworkName + public_ips: false + ``` + + Using private subnets assumes that both the `dstack` server and users can access the configured VPC's private subnets. + Additionally, private subnets must have outbound internet connectivity provided by [NAT Gateway or other mechanism](https://learn.microsoft.com/en-us/azure/nat-gateway/nat-overview). + #### GCP There are two ways to configure GCP: using a service account or using the default credentials. @@ -441,8 +479,8 @@ gcloud projects list --format="json(projectId)" * Allow `INGRESS` traffic on ports `22`, `80`, `443`, with the target tag `dstack-gateway-instance` ??? info "Private subnets" - By default, `dstack` utilizes public subnets and permits inbound SSH traffic exclusively for any provisioned instances. - If you want `dstack` to use private subnets, set `public_ips` to `false`. + By default, `dstack` provisions instances with public IPs and permits inbound SSH traffic. + If you want `dstack` to use private subnets and provision instances without public IPs, set `public_ips` to `false`. ```yaml projects: @@ -455,7 +493,8 @@ gcloud projects list --format="json(projectId)" public_ips: false ``` - Using private subnets assumes that both the `dstack` server and users can access the configured VPC's private subnets (e.g., through VPC peering). Additionally, [Cloud NAT](https://cloud.google.com/nat/docs/overview) must be configured to provide access to external resources for provisioned instances. + Using private subnets assumes that both the `dstack` server and users can access the configured VPC's private subnets. + Additionally, [Cloud NAT](https://cloud.google.com/nat/docs/overview) must be configured to provide access to external resources for provisioned instances. #### Lambda diff --git a/src/dstack/_internal/server/services/config.py b/src/dstack/_internal/server/services/config.py index 29113f93a..255de946b 100644 --- a/src/dstack/_internal/server/services/config.py +++ b/src/dstack/_internal/server/services/config.py @@ -65,7 +65,12 @@ class AWSConfig(CoreModel): regions: Annotated[Optional[List[str]], Field(description="The list of AWS regions")] = None vpc_name: Annotated[ Optional[str], - Field(description="The VPC name. All configured regions must have a VPC with this name"), + Field( + description=( + "The name of custom VPCs. All configured regions must have a VPC with this name." + " If your custom VPCs don't have names or have different names in different regions, use `vpc_ids` instead." + ) + ), ] = None vpc_ids: Annotated[ Optional[Dict[str, str]], @@ -91,7 +96,7 @@ class AWSConfig(CoreModel): Field( description=( "A flag to enable/disable public IP assigning on instances." - " `public_ips: false` requires at least one private subnet with outbound internet connection" + " `public_ips: false` requires at least one private subnet with outbound internet connectivity" " provided by a NAT Gateway or a Transit Gateway." " Defaults to `true`" ) @@ -127,7 +132,8 @@ class AzureConfig(CoreModel): Field( description=( "A flag to enable/disable public IP assigning on instances." - " `public_ips: false` requires `vpc_ids` that specifies custom networks with NAT gateway." + " `public_ips: false` requires `vpc_ids` that specifies custom networks with outbound internet connectivity" + " provided by NAT Gateway or other mechanism." " Defaults to `true`" ) ), @@ -195,7 +201,7 @@ class GCPConfig(CoreModel): type: Annotated[Literal["gcp"], Field(description="The type of backend")] = "gcp" project_id: Annotated[str, Field(description="The project ID")] regions: Optional[List[str]] = None - vpc_name: Annotated[Optional[str], Field(description="The VPC name")] = None + vpc_name: Annotated[Optional[str], Field(description="The name of a custom VPC")] = None vpc_project_id: Annotated[ Optional[str], Field(description="The shared VPC hosted project ID. Required for shared VPC only"), @@ -219,7 +225,7 @@ class GCPAPIConfig(CoreModel): type: Annotated[Literal["gcp"], Field(description="The type of backend")] = "gcp" project_id: Annotated[str, Field(description="The project ID")] regions: Optional[List[str]] = None - vpc_name: Annotated[Optional[str], Field(description="The VPC name")] = None + vpc_name: Annotated[Optional[str], Field(description="The name of a custom VPC")] = None vpc_project_id: Annotated[ Optional[str], Field(description="The shared VPC hosted project ID. Required for shared VPC only"),