Skip to content

Commit

Permalink
Support AWS private subnets with Transit Gateway (#1881)
Browse files Browse the repository at this point in the history
* Support AWS private subnets with Transit Gateway

* Fix comment
  • Loading branch information
r4victor authored Oct 22, 2024
1 parent 9ac6a0e commit 5dff5be
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
6 changes: 5 additions & 1 deletion src/dstack/_internal/core/backends/aws/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,11 @@ def get_vpc_id_subnet_id_or_error(
return vpc_id, subnets_ids
if allocate_public_ip:
raise ComputeError(f"Failed to find public subnets for VPC {vpc_id}")
raise ComputeError(f"Failed to find private subnets for VPC {vpc_id}")
raise ComputeError(
f"Failed to find private subnets for VPC {vpc_id} with outbound internet access. "
"Ensure you've setup NAT Gateway, Transit Gateway, or other mechanism "
"to provide outbound internet access from private subnets."
)
if not config.use_default_vpcs:
raise ComputeError(f"No VPC ID configured for region {region}")

Expand Down
20 changes: 13 additions & 7 deletions src/dstack/_internal/core/backends/aws/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,13 +326,14 @@ def get_subnets_ids_for_vpc(
if is_public_subnet:
subnets_ids.append(subnet_id)
else:
subnet_behind_nat = _is_subnet_behind_nat(
is_eligible_private_subnet = _is_private_subnet_with_internet_egress(
ec2_client=ec2_client,
vpc_id=vpc_id,
subnet_id=subnet_id,
)
if subnet_behind_nat:
if is_eligible_private_subnet:
subnets_ids.append(subnet_id)

return subnets_ids


Expand Down Expand Up @@ -535,7 +536,10 @@ def _is_public_subnet(
return False


def _is_subnet_behind_nat(
_PRIVATE_SUBNET_EGRESS_ROUTE_KEYS = ["NatGatewayId", "TransitGatewayId", "VpcPeeringConnectionId"]


def _is_private_subnet_with_internet_egress(
ec2_client: botocore.client.BaseClient,
vpc_id: str,
subnet_id: str,
Expand All @@ -546,8 +550,9 @@ def _is_subnet_behind_nat(
)
for route_table in response["RouteTables"]:
for route in route_table["Routes"]:
if "NatGatewayId" in route and route["NatGatewayId"].startswith("nat-"):
return True
if route.get("DestinationCidrBlock") == "0.0.0.0/0":
if any(route.get(k) for k in _PRIVATE_SUBNET_EGRESS_ROUTE_KEYS):
return True

# Main route table controls the routing of all subnetes
# that are not explicitly associated with any other route table.
Expand All @@ -563,7 +568,8 @@ def _is_subnet_behind_nat(
)
for route_table in response["RouteTables"]:
for route in route_table["Routes"]:
if "NatGatewayId" in route and route["NatGatewayId"].startswith("nat-"):
return True
if route.get("DestinationCidrBlock") == "0.0.0.0/0":
if any(route.get(k) for k in _PRIVATE_SUBNET_EGRESS_ROUTE_KEYS):
return True

return False

0 comments on commit 5dff5be

Please sign in to comment.