Support custom and private networks for Azure #1896

Merged
merged 6 commits into from
Oct 25, 2024
53 changes: 46 additions & 7 deletions docs/docs/reference/server/config.yml.md
@@ -162,8 +162,8 @@ There are two ways to configure AWS: using an access key or using the default cr
For the regions without configured `vpc_ids`, enable default VPCs by setting `default_vpcs` to `true`.

??? info "Private subnets"
By default, `dstack` utilizes public subnets and permits inbound SSH traffic exclusively for any provisioned instances.
If you want `dstack` to use private subnets, set `public_ips` to `false`.
By default, `dstack` provisions instances with public IPs and permits inbound SSH traffic.
If you want `dstack` to use private subnets and provision instances without public IPs, set `public_ips` to `false`.

```yaml
projects:
@@ -176,8 +176,8 @@ There are two ways to configure AWS: using an access key or using the default cr
public_ips: false
```

Using private subnets assumes that both the `dstack` server and users can access the configured VPC's private subnets
(e.g., through VPC peering).
Using private subnets assumes that both the `dstack` server and users can access the configured VPC's private subnets.
Additionally, private subnets must have outbound internet connectivity provided by a NAT Gateway, Transit Gateway, or another mechanism.

#### Azure

@@ -287,6 +287,44 @@ There are two ways to configure Azure: using a client secret or using the defaul
}
```

??? info "VPC"
By default, `dstack` creates new Azure networks and subnets for every configured region.
It's possible to use custom networks by specifying `vpc_ids`:

```yaml
projects:
- name: main
backends:
- type: azure
creds:
type: default
regions: [westeurope]
vpc_ids:
westeurope: myNetworkResourceGroup/myNetworkName
```


??? info "Private subnets"
By default, `dstack` provisions instances with public IPs and permits inbound SSH traffic.
If you want `dstack` to use private subnets and provision instances without public IPs,
specify custom networks using `vpc_ids` and set `public_ips` to `false`.

```yaml
projects:
- name: main
backends:
- type: azure
creds:
type: default
regions: [westeurope]
vpc_ids:
westeurope: myNetworkResourceGroup/myNetworkName
public_ips: false
```

Using private subnets assumes that both the `dstack` server and users can access the configured VPC's private subnets.
Additionally, private subnets must have outbound internet connectivity provided by a [NAT Gateway or another mechanism](https://learn.microsoft.com/en-us/azure/nat-gateway/nat-overview).
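
The Azure requirements described above (a `vpc_ids` entry in `resourceGroup/networkName` form for each region, plus `public_ips: false`) could be checked before the server is started. The following is a minimal sketch and not part of this PR: `check_azure_private_config` is a hypothetical helper, and it assumes the backend entry from `config.yml` has already been loaded into a plain dict with the keys shown in the YAML above.

```python
from typing import Any, Dict, List


def check_azure_private_config(backend: Dict[str, Any]) -> List[str]:
    """Return a list of problems found in an Azure backend config dict.

    Hypothetical helper for illustration only; not part of dstack.
    """
    problems: List[str] = []
    vpc_ids = backend.get("vpc_ids") or {}
    # Private subnets are only requested when public_ips is explicitly false.
    if backend.get("public_ips") is False:
        if not vpc_ids:
            problems.append("`public_ips: false` requires `vpc_ids`")
        else:
            for region in backend.get("regions", []):
                if region not in vpc_ids:
                    problems.append(f"no network configured for region {region}")
    # Every value must look like 'resourceGroup/networkName'.
    for region, vpc_id in vpc_ids.items():
        parts = vpc_id.split("/")
        if len(parts) != 2 or not all(parts):
            problems.append(
                f"{region}: expected 'resourceGroup/networkName', got {vpc_id!r}"
            )
    return problems
```

An empty list means the dict passes these checks; it does not confirm that the network exists or that its subnets have outbound connectivity.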

#### GCP

There are two ways to configure GCP: using a service account or using the default credentials.
@@ -441,8 +479,8 @@ gcloud projects list --format="json(projectId)"
* Allow `INGRESS` traffic on ports `22`, `80`, `443`, with the target tag `dstack-gateway-instance`

??? info "Private subnets"
By default, `dstack` utilizes public subnets and permits inbound SSH traffic exclusively for any provisioned instances.
If you want `dstack` to use private subnets, set `public_ips` to `false`.
By default, `dstack` provisions instances with public IPs and permits inbound SSH traffic.
If you want `dstack` to use private subnets and provision instances without public IPs, set `public_ips` to `false`.

```yaml
projects:
@@ -455,7 +493,8 @@ gcloud projects list --format="json(projectId)"
public_ips: false
```

Using private subnets assumes that both the `dstack` server and users can access the configured VPC's private subnets (e.g., through VPC peering). Additionally, [Cloud NAT](https://cloud.google.com/nat/docs/overview) must be configured to provide access to external resources for provisioned instances.
Using private subnets assumes that both the `dstack` server and users can access the configured VPC's private subnets.
Additionally, [Cloud NAT](https://cloud.google.com/nat/docs/overview) must be configured to provide access to external resources for provisioned instances.

#### Lambda

4 changes: 2 additions & 2 deletions src/dstack/_internal/core/backends/aws/resources.py
@@ -5,7 +5,7 @@
import botocore.exceptions

import dstack.version as version
from dstack._internal.core.errors import ComputeError, ComputeResourceNotFoundError
from dstack._internal.core.errors import BackendError, ComputeError, ComputeResourceNotFoundError


def get_image_id(ec2_client: botocore.client.BaseClient, cuda: bool) -> str:
@@ -408,7 +408,7 @@ def make_tags(tags: Dict[str, str]) -> List[Dict[str, str]]:
def validate_tags(tags: Dict[str, str]):
for k, v in tags.items():
if not _is_valid_tag(k, v):
raise ComputeError(
raise BackendError(
"Invalid resource tags. "
"See tags restrictions: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/Using_Tags.html#tag-restrictions"
)
143 changes: 109 additions & 34 deletions src/dstack/_internal/core/backends/azure/compute.py
@@ -4,7 +4,7 @@
from typing import Dict, List, Optional, Tuple

from azure.core.credentials import TokenCredential
from azure.core.exceptions import ResourceExistsError
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
from azure.mgmt import compute as compute_mgmt
from azure.mgmt import network as network_mgmt
from azure.mgmt.compute.models import (
@@ -33,6 +33,7 @@

from dstack import version
from dstack._internal import settings
from dstack._internal.core.backends.azure import resources as azure_resources
from dstack._internal.core.backends.azure import utils as azure_utils
from dstack._internal.core.backends.azure.config import AzureConfig
from dstack._internal.core.backends.base.compute import (
@@ -110,6 +111,19 @@ def create_instance(
ssh_pub_keys = instance_config.get_public_keys()
disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)

allocate_public_ip = self.config.allocate_public_ips
network_resource_group, network, subnet = get_resource_group_network_subnet_or_error(
network_client=self._network_client,
resource_group=self.config.resource_group,
vpc_ids=self.config.vpc_ids,
location=location,
allocate_public_ip=allocate_public_ip,
)
network_security_group = azure_utils.get_default_network_security_group_name(
resource_group=self.config.resource_group,
location=location,
)

tags = {
"owner": "dstack",
"dstack_project": instance_config.project_name,
@@ -122,18 +136,9 @@ def create_instance(
subscription_id=self.config.subscription_id,
location=location,
resource_group=self.config.resource_group,
network_security_group=azure_utils.get_default_network_security_group_name(
resource_group=self.config.resource_group,
location=location,
),
network=azure_utils.get_default_network_name(
resource_group=self.config.resource_group,
location=location,
),
subnet=azure_utils.get_default_subnet_name(
resource_group=self.config.resource_group,
location=location,
),
network_security_group=network_security_group,
network=network,
subnet=subnet,
managed_identity=None,
image_reference=_get_image_ref(
compute_client=self._compute_client,
@@ -149,6 +154,8 @@ def create_instance(
spot=instance_offer.instance.resources.spot,
disk_size=disk_size,
computer_name="runnervm",
allocate_public_ip=allocate_public_ip,
network_resource_group=network_resource_group,
tags=tags,
)
logger.info("Request succeeded")
@@ -157,11 +164,14 @@ def create_instance(
resource_group=self.config.resource_group,
vm=vm,
)
hostname = public_ip
if not allocate_public_ip:
hostname = private_ip
return JobProvisioningData(
backend=instance_offer.backend,
instance_type=instance_offer.instance,
instance_id=vm.name,
hostname=public_ip,
hostname=hostname,
internal_ip=private_ip,
region=location,
price=instance_offer.price,
@@ -211,6 +221,18 @@ def create_gateway(
configuration.region,
)

network_resource_group, network, subnet = get_resource_group_network_subnet_or_error(
network_client=self._network_client,
resource_group=self.config.resource_group,
vpc_ids=self.config.vpc_ids,
location=configuration.region,
allocate_public_ip=True,
)
network_security_group = azure_utils.get_gateway_network_security_group_name(
resource_group=self.config.resource_group,
location=configuration.region,
)

tags = {
"Name": configuration.instance_name,
"owner": "dstack",
@@ -225,27 +247,19 @@ def create_gateway(
subscription_id=self.config.subscription_id,
location=configuration.region,
resource_group=self.config.resource_group,
network_security_group=azure_utils.get_gateway_network_security_group_name(
resource_group=self.config.resource_group,
location=configuration.region,
),
network=azure_utils.get_default_network_name(
resource_group=self.config.resource_group,
location=configuration.region,
),
subnet=azure_utils.get_default_subnet_name(
resource_group=self.config.resource_group,
location=configuration.region,
),
network_security_group=network_security_group,
network=network,
subnet=subnet,
managed_identity=None,
image_reference=_get_gateway_image_ref(),
vm_size="Standard_B1s",
vm_size="Standard_B1ms",
instance_name=configuration.instance_name,
user_data=get_gateway_user_data(configuration.ssh_key_pub),
ssh_pub_keys=[configuration.ssh_key_pub],
spot=False,
disk_size=30,
computer_name="gatewayvm",
network_resource_group=network_resource_group,
tags=tags,
)
logger.info("Request succeeded")
@@ -273,6 +287,57 @@ def terminate_gateway(
)


def get_resource_group_network_subnet_or_error(
network_client: network_mgmt.NetworkManagementClient,
resource_group: Optional[str],
vpc_ids: Optional[Dict[str, str]],
location: str,
allocate_public_ip: bool,
) -> Tuple[str, str, str]:
if vpc_ids is not None:
vpc_id = vpc_ids.get(location)
if vpc_id is None:
raise ComputeError(f"Network not configured for location {location}")
try:
resource_group, network_name = _parse_config_vpc_id(vpc_id)
except Exception:
raise ComputeError(
"Network specified in incorrect format."
" Supported format for `vpc_ids` values: 'networkResourceGroupName/networkName'"
)
elif resource_group is not None:
network_name = azure_utils.get_default_network_name(resource_group, location)
else:
raise ComputeError("`resource_group` or `vpc_ids` must be specified")

try:
subnets = azure_resources.get_network_subnets(
network_client=network_client,
resource_group=resource_group,
network_name=network_name,
private=not allocate_public_ip,
)
except ResourceNotFoundError:
raise ComputeError(
f"Network {network_name} not found in location {location} in resource group {resource_group}"
)

if len(subnets) == 0:
if not allocate_public_ip:
raise ComputeError(
f"Failed to find private subnets with outbound internet connectivity in network {network_name}"
)
raise ComputeError(f"Failed to find subnets in network {network_name}")

subnet_name = subnets[0]
return resource_group, network_name, subnet_name


def _parse_config_vpc_id(vpc_id: str) -> Tuple[str, str]:
resource_group, network_name = vpc_id.split("/")
return resource_group, network_name
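
`_parse_config_vpc_id` relies on tuple unpacking to reject values with the wrong number of `/` separators, so a value such as `/myNetworkName` would still pass with an empty resource group name. A stricter variant might look like the sketch below. `parse_vpc_id_strict` is a hypothetical name and is not part of this PR; it raises `ValueError` rather than dstack's `ComputeError` so that the example stays self-contained.

```python
from typing import Tuple


def parse_vpc_id_strict(vpc_id: str) -> Tuple[str, str]:
    """Split 'networkResourceGroupName/networkName' and reject empty parts.

    Hypothetical helper for illustration only; not part of dstack.
    """
    parts = vpc_id.split("/")
    # Exactly one separator, and neither side may be empty or whitespace.
    if len(parts) != 2 or not all(part.strip() for part in parts):
        raise ValueError(
            f"Expected 'networkResourceGroupName/networkName', got {vpc_id!r}"
        )
    resource_group, network_name = parts
    return resource_group, network_name
```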


class VMImageVariant(enum.Enum):
GRID = enum.auto()
CUDA = enum.auto()
@@ -396,10 +461,19 @@ def _launch_instance(
spot: bool,
disk_size: int,
computer_name: str,
allocate_public_ip: bool = True,
network_resource_group: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
) -> VirtualMachine:
if tags is None:
tags = {}
if network_resource_group is None:
network_resource_group = resource_group
public_ip_address_configuration = None
if allocate_public_ip:
public_ip_address_configuration = VirtualMachinePublicIPAddressConfiguration(
name="public_ip_config",
)
try:
poller = compute_client.virtual_machines.begin_create_or_update(
resource_group,
@@ -451,14 +525,12 @@ def _launch_instance(
subnet=SubResource(
id=azure_utils.get_subnet_id(
subscription_id,
resource_group,
network_resource_group,
network,
subnet,
)
),
public_ip_address_configuration=VirtualMachinePublicIPAddressConfiguration(
name="public_ip_config",
),
public_ip_address_configuration=public_ip_address_configuration,
)
],
)
@@ -505,18 +577,21 @@ def _get_vm_public_private_ips(
network_client: network_mgmt.NetworkManagementClient,
resource_group: str,
vm: VirtualMachine,
) -> Tuple[str, str]:
) -> Tuple[Optional[str], str]:
nic_id = vm.network_profile.network_interfaces[0].id
nic_name = azure_utils.get_resource_name_from_resource_id(nic_id)
nic = network_client.network_interfaces.get(
resource_group_name=resource_group,
network_interface_name=nic_name,
)

private_ip = nic.ip_configurations[0].private_ip_address
if nic.ip_configurations[0].public_ip_address is None:
return None, private_ip

public_ip_id = nic.ip_configurations[0].public_ip_address.id
public_ip_name = azure_utils.get_resource_name_from_resource_id(public_ip_id)
public_ip = network_client.public_ip_addresses.get(resource_group, public_ip_name)

private_ip = nic.ip_configurations[0].private_ip_address
return public_ip.ip_address, private_ip


6 changes: 6 additions & 0 deletions src/dstack/_internal/core/backends/azure/config.py
@@ -4,3 +4,9 @@

class AzureConfig(AzureStoredConfig, BackendConfig):
creds: AnyAzureCreds

@property
def allocate_public_ips(self) -> bool:
if self.public_ips is not None:
return self.public_ips
return True
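
The property treats an unset `public_ips` as `true`, so only an explicit `public_ips: false` disables public IPs. Here is a self-contained sketch of the same tri-state pattern, using a hypothetical `ExampleConfig` dataclass in place of `AzureConfig`:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ExampleConfig:
    """Stand-in for AzureConfig; holds only the field relevant here."""

    public_ips: Optional[bool] = None

    @property
    def allocate_public_ips(self) -> bool:
        # None means the option was not set, which defaults to public IPs.
        if self.public_ips is not None:
            return self.public_ips
        return True
```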