diff --git a/cli/dstack/_internal/backend/aws/runners.py b/cli/dstack/_internal/backend/aws/runners.py index 09fcada87..4855973bb 100644 --- a/cli/dstack/_internal/backend/aws/runners.py +++ b/cli/dstack/_internal/backend/aws/runners.py @@ -9,7 +9,12 @@ from dstack import version from dstack._internal.backend.aws import utils as aws_utils -from dstack._internal.backend.base.compute import WS_PORT, NoCapacityError, choose_instance_type +from dstack._internal.backend.base.compute import ( + WS_PORT, + NoCapacityError, + choose_instance_type, + get_dstack_runner, +) from dstack._internal.backend.base.config import BACKEND_CONFIG_FILENAME, RUNNER_CONFIG_FILENAME from dstack._internal.backend.base.runners import serialize_runner_yaml from dstack._internal.core.instance import InstanceType, LaunchedInstanceInfo @@ -460,6 +465,7 @@ def _user_data( EC2_PUBLIC_HOSTNAME="`wget -q -O - http://169.254.169.254/latest/meta-data/public-hostname || die \"wget public-hostname has failed: $?\"`" echo "hostname: $EC2_PUBLIC_HOSTNAME" >> /root/.dstack/{RUNNER_CONFIG_FILENAME} mkdir ~/.ssh; chmod 700 ~/.ssh; echo "{ssh_key_pub}" > ~/.ssh/authorized_keys; chmod 600 ~/.ssh/authorized_keys +{get_dstack_runner()} HOME=/root nohup dstack-runner --log-level 6 start --http-port {WS_PORT} & """ return user_data diff --git a/cli/dstack/_internal/backend/azure/compute.py b/cli/dstack/_internal/backend/azure/compute.py index 8f37bc18e..7e02250c7 100644 --- a/cli/dstack/_internal/backend/azure/compute.py +++ b/cli/dstack/_internal/backend/azure/compute.py @@ -38,7 +38,12 @@ from dstack import version from dstack._internal.backend.azure import utils as azure_utils from dstack._internal.backend.azure.config import AzureConfig -from dstack._internal.backend.base.compute import WS_PORT, Compute, choose_instance_type +from dstack._internal.backend.base.compute import ( + WS_PORT, + Compute, + choose_instance_type, + get_dstack_runner, +) from dstack._internal.backend.base.config import BACKEND_CONFIG_FILENAME, RUNNER_CONFIG_FILENAME from dstack._internal.backend.base.runners import serialize_runner_yaml from dstack._internal.core.instance import InstanceType, LaunchedInstanceInfo @@ -237,6 +242,7 @@ def _get_user_data_script(azure_config: AzureConfig, job: Job, instance_type: In mkdir -p /root/.dstack/ echo '{config_content}' > /root/.dstack/{BACKEND_CONFIG_FILENAME} echo '{runner_content}' > /root/.dstack/{RUNNER_CONFIG_FILENAME} +{get_dstack_runner()} HOME=/root nohup dstack-runner --log-level 6 start --http-port {WS_PORT} """ diff --git a/cli/dstack/_internal/backend/base/compute.py b/cli/dstack/_internal/backend/base/compute.py index 61345b81b..b3574769e 100644 --- a/cli/dstack/_internal/backend/base/compute.py +++ b/cli/dstack/_internal/backend/base/compute.py @@ -1,7 +1,9 @@ +import os from abc import ABC, abstractmethod from functools import cmp_to_key from typing import List, Optional +import dstack.version as version from dstack._internal.core.error import DstackError from dstack._internal.core.instance import InstanceType, LaunchedInstanceInfo from dstack._internal.core.job import Job, Requirements @@ -119,3 +121,18 @@ def _matches_requirements(resources: Resources, requirements: Optional[Requireme ): return False return True + + +def get_dstack_runner() -> str: + if version.__is_release__: + bucket = "dstack-runner-downloads" + build = version.__version__ + else: # stgn + bucket = "dstack-runner-downloads-stgn" + build = version.__version__ or os.environ.get("DSTACK_RUNNER_BUILD", None) + + commands = [ + f'sudo curl --output /usr/local/bin/dstack-runner "https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-runner-linux-amd64"', + f"sudo chmod +x /usr/local/bin/dstack-runner", + ] + return "\n".join(commands) diff --git a/cli/dstack/_internal/backend/gcp/compute.py b/cli/dstack/_internal/backend/gcp/compute.py index 8c1027318..6b8a4285f 100644 --- a/cli/dstack/_internal/backend/gcp/compute.py +++ b/cli/dstack/_internal/backend/gcp/compute.py @@ -11,6 +11,7 @@ Compute, NoCapacityError, choose_instance_type, + get_dstack_runner, ) from dstack._internal.backend.base.config import BACKEND_CONFIG_FILENAME, RUNNER_CONFIG_FILENAME from dstack._internal.backend.base.runners import serialize_runner_yaml @@ -446,6 +447,7 @@ def _get_user_data_script(gcp_config: GCPConfig, job: Job, instance_type: Instan echo '{runner_content}' > /root/.dstack/{RUNNER_CONFIG_FILENAME} EXTERNAL_IP=`curl -H "Metadata-Flavor: Google" http://169.254.169.254/computeMetadata/v1/instance/network-interfaces/0/access-configs/0/external-ip` echo "hostname: $EXTERNAL_IP" >> /root/.dstack/{RUNNER_CONFIG_FILENAME} +{get_dstack_runner()} HOME=/root nohup dstack-runner --log-level 6 start --http-port {WS_PORT} """ diff --git a/cli/dstack/_internal/backend/lambdalabs/compute.py b/cli/dstack/_internal/backend/lambdalabs/compute.py index 4ddc8c547..78308836f 100644 --- a/cli/dstack/_internal/backend/lambdalabs/compute.py +++ b/cli/dstack/_internal/backend/lambdalabs/compute.py @@ -10,7 +10,7 @@ import yaml from dstack import version -from dstack._internal.backend.base.compute import WS_PORT, choose_instance_type +from dstack._internal.backend.base.compute import WS_PORT, choose_instance_type, get_dstack_runner from dstack._internal.backend.base.config import BACKEND_CONFIG_FILENAME, RUNNER_CONFIG_FILENAME from dstack._internal.backend.base.runners import serialize_runner_yaml from dstack._internal.backend.lambdalabs.api_client import LambdaAPIClient @@ -206,6 +206,7 @@ def _get_launch_script(lambda_config: LambdaConfig, job: Job, instance_type: Ins echo '{config_content}' > /root/.dstack/{BACKEND_CONFIG_FILENAME} echo '{runner_content}' > /root/.dstack/{RUNNER_CONFIG_FILENAME} echo 'hostname: HOSTNAME_PLACEHOLDER' >> /root/.dstack/{RUNNER_CONFIG_FILENAME} +{get_dstack_runner()} HOME=/root nohup dstack-runner --log-level 6 start --http-port {WS_PORT} """ diff --git a/cli/dstack/version.py b/cli/dstack/version.py index 8a621fa00..c0ea4762c 100644 --- a/cli/dstack/version.py +++ b/cli/dstack/version.py @@ -1,3 +1,4 @@ __version__ = None __is_release__ = False miniforge_image = "0.3" +runner_build = "latest" diff --git a/runner/internal/container/build.go b/runner/internal/container/build.go index c97def2c5..98652e930 100644 --- a/runner/internal/container/build.go +++ b/runner/internal/container/build.go @@ -2,16 +2,8 @@ package container import ( "bytes" - "context" "crypto/sha256" "fmt" - "github.com/docker/docker/api/types" - "github.com/docker/docker/api/types/container" - "github.com/docker/docker/api/types/mount" - docker "github.com/docker/docker/client" - "github.com/dstackai/dstack/runner/internal/gerrors" - "github.com/dstackai/dstack/runner/internal/log" - "io" ) type BuildSpec struct { @@ -28,87 +20,8 @@ type BuildSpec struct { RepoPath string Platform string RepoId string -} - -func BuildImage(ctx context.Context, client docker.APIClient, spec *BuildSpec, imageName string, stoppedCh chan struct{}, logs io.Writer) error { - stopTimeout := 10 * 60 - config := &container.Config{ - Image: spec.BaseImageID, - WorkingDir: spec.WorkDir, - Cmd: spec.Commands, - Entrypoint: spec.Entrypoint, - Env: spec.Env, - StopTimeout: &stopTimeout, - Tty: true, - } - hostConfig := &container.HostConfig{ - Mounts: []mount.Mount{ - { - Type: mount.TypeBind, - Source: spec.RepoPath, - Target: "/workflow", - ReadOnly: true, - }, - }, - } - createResp, err := client.ContainerCreate(ctx, config, hostConfig, nil, nil, "") - if err != nil { - return gerrors.Wrap(err) - } - err = client.ContainerStart(ctx, createResp.ID, types.ContainerStartOptions{}) - if err != nil { - return gerrors.Wrap(err) - } - defer func() { - _ = client.ContainerRemove(ctx, createResp.ID, types.ContainerRemoveOptions{Force: true}) - }() - - log.Trace(ctx, "Streaming build logs") - attachResp, err := client.ContainerAttach(ctx, createResp.ID, types.ContainerAttachOptions{ - Stream: true, - Stdout: true, - Stderr: true, - Logs: true, - }) - if err != nil { - return gerrors.Wrap(err) - } - go func() { - _, err := io.Copy(logs, attachResp.Reader) - if err != nil { - log.Error(ctx, "Failed to stream build logs", "err", err) - } - }() - statusCh, errCh := client.ContainerWait(ctx, createResp.ID, container.WaitConditionNotRunning) - if err != nil { - return gerrors.Wrap(err) - } - select { - case err := <-errCh: - if err != nil { - return gerrors.Wrap(err) - } - case <-stoppedCh: - err := client.ContainerKill(ctx, createResp.ID, "SIGTERM") - if err != nil { - return gerrors.Wrap(err) - } - case <-statusCh: - } - info, err := client.ContainerInspect(ctx, createResp.ID) - if err != nil { - return gerrors.Wrap(err) - } - if info.State.ExitCode != 0 { - return gerrors.Wrap(ContainerExitedError{info.State.ExitCode}) - } - log.Trace(ctx, "Committing build image", "image", imageName) - _, err = client.ContainerCommit(ctx, createResp.ID, types.ContainerCommitOptions{Reference: imageName}) - if err != nil { - return gerrors.Wrap(err) - } - return nil + ShmSize int64 } func (s *BuildSpec) Hash() string { diff --git a/runner/internal/container/engine.go b/runner/internal/container/engine.go index 52830d514..f1231dd94 100644 --- a/runner/internal/container/engine.go +++ b/runner/internal/container/engine.go @@ -3,6 +3,7 @@ package container import ( "context" "fmt" + "github.com/docker/docker/api/types/mount" "github.com/dstackai/dstack/runner/internal/environment" "github.com/dstackai/dstack/runner/internal/models" "io" @@ -326,6 +327,7 @@ func (r *Engine) NewBuildSpec(ctx context.Context, job *models.Job, spec *Spec, RegistryAuthBase64: spec.RegistryAuthBase64, RepoPath: repoPath, RepoId: job.RepoId, + ShmSize: spec.ShmSize, } if daemonInfo.Architecture == "aarch64" { buildSpec.Platform = "arm64" @@ -336,7 +338,59 @@ func (r *Engine) NewBuildSpec(ctx context.Context, job *models.Job, spec *Spec, } func (r *Engine) Build(ctx context.Context, spec *BuildSpec, imageName string, stoppedCh chan struct{}, logs io.Writer) error { - if err := BuildImage(ctx, r.client, spec, imageName, stoppedCh, logs); err != nil { + containerSpec := &Spec{ + Image: spec.BaseImageName, + RegistryAuthBase64: spec.RegistryAuthBase64, + WorkDir: spec.WorkDir, + Commands: spec.Commands, + Entrypoint: spec.Entrypoint, + Env: spec.Env, + ShmSize: spec.ShmSize, + Mounts: []mount.Mount{ + { + Type: mount.TypeBind, + Source: spec.RepoPath, + Target: "/workflow", + ReadOnly: true, + }, + }, + } + dockerRuntime, err := r.Create(ctx, containerSpec, logs) + if err != nil { + return gerrors.Wrap(err) + } + if err = dockerRuntime.Run(ctx); err != nil { + return gerrors.Wrap(err) + } + defer func() { + _ = r.client.ContainerRemove(ctx, dockerRuntime.containerID, types.ContainerRemoveOptions{Force: true}) + }() + + statusCh, errCh := r.client.ContainerWait(ctx, dockerRuntime.containerID, container.WaitConditionNotRunning) + select { + // todo timeout + case err := <-errCh: + if err != nil { + return gerrors.Wrap(err) + } + case <-stoppedCh: + err := r.client.ContainerKill(ctx, dockerRuntime.containerID, "SIGTERM") + if err != nil { + return gerrors.Wrap(err) + } + case <-statusCh: + } + + info, err := r.client.ContainerInspect(ctx, dockerRuntime.containerID) + if err != nil { + return gerrors.Wrap(err) + } + if info.State.ExitCode != 0 { + return gerrors.Wrap(ContainerExitedError{info.State.ExitCode}) + } + log.Trace(ctx, "Committing build image", "image", imageName) + _, err = r.client.ContainerCommit(ctx, dockerRuntime.containerID, types.ContainerCommitOptions{Reference: imageName}) + if err != nil { return gerrors.Wrap(err) } return nil