Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Configuring if service path prefix is stripped #2254

Merged
merged 2 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 45 additions & 3 deletions docs/docs/concepts/services.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,48 @@ port: 8000

</div>

### Path prefix { #path-prefix }

If your `dstack` project doesn't have a [gateway](gateways.md), services are hosted with the
`/proxy/services/<project name>/<run name>/` path prefix in the URL.
When running web apps, you may need to set some app-specific settings
so that browser-side scripts and CSS work correctly with the path prefix.

<div editor-title="dash.dstack.yml">

```yaml
type: service
name: dash
gateway: false

# Disable authorization
auth: false
# Do not strip the path prefix
strip_prefix: false

env:
# Configure Dash to work with a path prefix
# Replace `main` with your dstack project name
- DASH_ROUTES_PATHNAME_PREFIX=/proxy/services/main/dash/

commands:
- pip install dash
# Assuming the Dash app is in your repo at app.py
- python app.py

port: 8050
```

</div>

By default, `dstack` strips the prefix before forwarding requests to your service,
so to the service it appears as if the prefix isn't there. This allows some apps
to work out of the box. If your app doesn't expect the prefix to be stripped,
set [`strip_prefix`](../reference/dstack.yml/service.md#strip_prefix) to `false`.

If your app cannot be configured to work with a path prefix, you can host it
on a dedicated domain name by setting up a [gateway](gateways.md).

### Model

If the service is running a chat model with an OpenAI-compatible interface,
Expand Down Expand Up @@ -345,9 +387,9 @@ via the [`spot_policy`](../reference/dstack.yml/service.md#spot_policy) property
Running services doesn't require [gateways](gateways.md) unless you need to enable auto-scaling or want the endpoint to
use HTTPS and map it to your domain.

!!! info "Websockets and base path"
A [gateway](gateways.md) may also be required if the service needs Websockets or cannot be used with
a base path.
!!! info "Websockets and path prefix"
A gateway may also be required if the service needs Websockets or cannot be used with
a [path prefix](#path-prefix).

> If you're using [dstack Sky :material-arrow-top-right-thin:{ .external }](https://sky.dstack.ai){:target="_blank"},
> a gateway is already pre-configured for you.
Expand Down
11 changes: 11 additions & 0 deletions src/dstack/_internal/core/models/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
CommandsList = List[str]
ValidPort = conint(gt=0, le=65536)
SERVICE_HTTPS_DEFAULT = True
STRIP_PREFIX_DEFAULT = True


class RunConfigurationType(str, Enum):
Expand Down Expand Up @@ -236,6 +237,16 @@ class ServiceConfigurationParams(CoreModel):
),
),
] = None
strip_prefix: Annotated[
bool,
Field(
description=(
"Strip the `/proxy/services/<project name>/<run name>/` path prefix"
" when forwarding requests to the service. Only takes effect"
" when running the service without a gateway"
)
),
] = STRIP_PREFIX_DEFAULT
model: Annotated[
Optional[Union[AnyModel, str]],
Field(
Expand Down
5 changes: 4 additions & 1 deletion src/dstack/_internal/core/services/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,14 @@ def _replace_url(self, match: re.Match) -> bytes:
qs = {k: v[0] for k, v in urllib.parse.parse_qs(url.query).items()}
if app_spec and app_spec.url_query_params is not None:
qs.update({k.encode(): v.encode() for k, v in app_spec.url_query_params.items()})
path = url.path
if not path.startswith(self.path_prefix.removesuffix(b"/")):
path = concat_url_path(self.path_prefix, path)

url = url._replace(
scheme=("https" if self.secure else "http").encode(),
netloc=(self.hostname if omit_port else f"{self.hostname}:{local_port}").encode(),
path=concat_url_path(self.path_prefix, url.path),
path=path,
query=urllib.parse.urlencode(qs).encode(),
)
return url.geturl()
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/lib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class Service(ImmutableModel):
https: Optional[bool] # only used on gateways
auth: bool
client_max_body_size: int # only enforced on gateways
strip_prefix: bool = True # only used in-server
replicas: tuple[Replica, ...]

@property
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/proxy/lib/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def make_service(
domain: Optional[str] = None,
https: Optional[bool] = None,
auth: bool = False,
strip_prefix: bool = True,
) -> Service:
return Service(
project_name=project_name,
Expand All @@ -37,6 +38,7 @@ def make_service(
https=https,
auth=auth,
client_max_body_size=2**20,
strip_prefix=strip_prefix,
replicas=(
Replica(
id="xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/server/services/proxy/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
https=None,
auth=run_spec.configuration.auth,
client_max_body_size=DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE,
strip_prefix=run_spec.configuration.strip_prefix,
replicas=tuple(replicas),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ServiceConnectionPool,
get_service_replica_client,
)
from dstack._internal.utils.common import concat_url_path
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)
Expand All @@ -37,6 +38,9 @@ async def proxy(

client = await get_service_replica_client(service, repo, service_conn_pool)

if not service.strip_prefix:
path = concat_url_path(request.scope.get("root_path", "/"), request.url.path)

try:
upstream_request = await build_upstream_request(request, path, client)
except ClientDisconnect:
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/server/services/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,7 +967,7 @@ def _validate_run_spec_and_set_defaults(run_spec: RunSpec):
_UPDATABLE_SPEC_FIELDS = ["repo_code_hash", "configuration"]
# Most service fields can be updated via replica redeployment.
# TODO: Allow updating other fields when a rolling deployment is supported.
_UPDATABLE_CONFIGURATION_FIELDS = ["replicas", "scaling"]
_UPDATABLE_CONFIGURATION_FIELDS = ["replicas", "scaling", "strip_prefix"]


def _can_update_run_spec(current_run_spec: RunSpec, new_run_spec: RunSpec) -> bool:
Expand Down
12 changes: 12 additions & 0 deletions src/dstack/api/server/_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

from pydantic import parse_obj_as

from dstack._internal.core.models.common import is_core_model_instance
from dstack._internal.core.models.configurations import (
STRIP_PREFIX_DEFAULT,
ServiceConfiguration,
)
from dstack._internal.core.models.pools import Instance
from dstack._internal.core.models.profiles import Profile
from dstack._internal.core.models.runs import (
Expand Down Expand Up @@ -145,6 +150,13 @@ def _get_run_spec_excludes(run_spec: RunSpec) -> Optional[dict]:
configuration_excludes.add("stop_duration")
if profile is not None and profile.stop_duration is None:
profile_excludes.add("stop_duration")
# client >= 0.18.40 / server <= 0.18.39 compatibility tweak
if (
is_core_model_instance(configuration, ServiceConfiguration)
and configuration.strip_prefix == STRIP_PREFIX_DEFAULT
):
configuration_excludes.add("strip_prefix")

if configuration_excludes:
spec_excludes["configuration"] = configuration_excludes
if profile_excludes:
Expand Down
22 changes: 16 additions & 6 deletions src/tests/_internal/core/services/test_logs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from dstack._internal.core.models.runs import AppSpec
from dstack._internal.core.services.logs import URLReplacer

Expand Down Expand Up @@ -126,7 +128,18 @@ def test_omit_https_default_port(self):
)
assert replacer(b"http://0.0.0.0:8000/qwerty") == b"https://secure.host.com/qwerty"

def test_in_server_proxy(self):
@pytest.mark.parametrize(
("in_path", "out_path"),
[
("", "/proxy/services/main/service/"),
("/", "/proxy/services/main/service/"),
("/a/b/c", "/proxy/services/main/service/a/b/c"),
("/proxy/services/main/service", "/proxy/services/main/service"),
("/proxy/services/main/service/", "/proxy/services/main/service/"),
("/proxy/services/main/service/a/b/c", "/proxy/services/main/service/a/b/c"),
],
)
def test_adds_prefix_unless_already_present(self, in_path: str, out_path: str) -> None:
replacer = URLReplacer(
ports={8888: 3000},
app_specs=[],
Expand All @@ -135,9 +148,6 @@ def test_in_server_proxy(self):
path_prefix="/proxy/services/main/service/",
)
assert (
replacer(b"http://0.0.0.0:8888") == b"http://0.0.0.0:3000/proxy/services/main/service/"
)
assert (
replacer(b"http://0.0.0.0:8888/qwerty")
== b"http://0.0.0.0:3000/proxy/services/main/service/qwerty"
replacer(f"http://0.0.0.0:8888{in_path}".encode())
== f"http://0.0.0.0:3000{out_path}".encode()
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import httpx
import pytest
from fastapi import FastAPI
from fastapi.responses import PlainTextResponse

from dstack._internal.proxy.gateway.repo.repo import GatewayProxyRepo
from dstack._internal.proxy.lib.auth import BaseProxyAuthProvider
Expand All @@ -25,6 +26,8 @@

@pytest.fixture
def mock_replica_client_httpbin(httpbin) -> Generator[None, None, None]:
"""Mocks deployed services. Replaces them with httpbin"""

with patch(
"dstack._internal.proxy.lib.services.service_connection.ServiceConnectionPool.get_or_add"
) as add_connection_mock:
Expand All @@ -34,6 +37,20 @@ def mock_replica_client_httpbin(httpbin) -> Generator[None, None, None]:
yield


@pytest.fixture
def mock_replica_client_path_reporter() -> Generator[None, None, None]:
"""Mocks deployed services. Replaces them with an app that returns the requested path"""

app = FastAPI()
app.get("{path:path}")(lambda path: PlainTextResponse(path))
client = ServiceClient(base_url="http://test/", transport=httpx.ASGITransport(app))
with patch(
"dstack._internal.proxy.lib.services.service_connection.ServiceConnectionPool.get_or_add"
) as add_connection_mock:
add_connection_mock.return_value.client.return_value = client
yield


def make_app(
repo: BaseProxyRepo, auth: BaseProxyAuthProvider = ProxyTestAuthProvider()
) -> FastAPI:
Expand Down Expand Up @@ -200,3 +217,25 @@ async def test_auth(mock_replica_client_httpbin, token: Optional[str], status: i
url = "http://test-host/proxy/services/test-proj/httpbin/"
resp = await client.get(url, headers=headers)
assert resp.status_code == status


@pytest.mark.asyncio
@pytest.mark.parametrize(
("strip", "downstream_path", "upstream_path"),
[
(True, "/proxy/services/my-proj/my-run/", "/"),
(True, "/proxy/services/my-proj/my-run/a/b", "/a/b"),
(False, "/proxy/services/my-proj/my-run/", "/proxy/services/my-proj/my-run/"),
(False, "/proxy/services/my-proj/my-run/a/b", "/proxy/services/my-proj/my-run/a/b"),
],
)
async def test_strip_prefix(
mock_replica_client_path_reporter, strip: bool, downstream_path: str, upstream_path: str
) -> None:
repo = ProxyTestRepo()
await repo.set_project(make_project("my-proj"))
await repo.set_service(make_service("my-proj", "my-run", strip_prefix=strip))
_, client = make_app_client(repo)
resp = await client.get(f"http://test-host{downstream_path}")
assert resp.status_code == 200
assert resp.text == upstream_path