Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use nvidia docker runtime for build if available #552

Merged
merged 2 commits into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion cli/dstack/_internal/backend/aws/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@

from dstack import version
from dstack._internal.backend.aws import utils as aws_utils
from dstack._internal.backend.base.compute import WS_PORT, NoCapacityError, choose_instance_type
from dstack._internal.backend.base.compute import (
WS_PORT,
NoCapacityError,
choose_instance_type,
get_dstack_runner,
)
from dstack._internal.backend.base.config import BACKEND_CONFIG_FILENAME, RUNNER_CONFIG_FILENAME
from dstack._internal.backend.base.runners import serialize_runner_yaml
from dstack._internal.core.instance import InstanceType, LaunchedInstanceInfo
Expand Down Expand Up @@ -460,6 +465,7 @@ def _user_data(
EC2_PUBLIC_HOSTNAME="`wget -q -O - http://169.254.169.254/latest/meta-data/public-hostname || die \"wget public-hostname has failed: $?\"`"
echo "hostname: $EC2_PUBLIC_HOSTNAME" >> /root/.dstack/{RUNNER_CONFIG_FILENAME}
mkdir ~/.ssh; chmod 700 ~/.ssh; echo "{ssh_key_pub}" > ~/.ssh/authorized_keys; chmod 600 ~/.ssh/authorized_keys
{get_dstack_runner()}
HOME=/root nohup dstack-runner --log-level 6 start --http-port {WS_PORT} &
"""
return user_data
Expand Down
8 changes: 7 additions & 1 deletion cli/dstack/_internal/backend/azure/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@
from dstack import version
from dstack._internal.backend.azure import utils as azure_utils
from dstack._internal.backend.azure.config import AzureConfig
from dstack._internal.backend.base.compute import WS_PORT, Compute, choose_instance_type
from dstack._internal.backend.base.compute import (
WS_PORT,
Compute,
choose_instance_type,
get_dstack_runner,
)
from dstack._internal.backend.base.config import BACKEND_CONFIG_FILENAME, RUNNER_CONFIG_FILENAME
from dstack._internal.backend.base.runners import serialize_runner_yaml
from dstack._internal.core.instance import InstanceType, LaunchedInstanceInfo
Expand Down Expand Up @@ -237,6 +242,7 @@ def _get_user_data_script(azure_config: AzureConfig, job: Job, instance_type: In
mkdir -p /root/.dstack/
echo '{config_content}' > /root/.dstack/{BACKEND_CONFIG_FILENAME}
echo '{runner_content}' > /root/.dstack/{RUNNER_CONFIG_FILENAME}
{get_dstack_runner()}
HOME=/root nohup dstack-runner --log-level 6 start --http-port {WS_PORT}
"""

Expand Down
17 changes: 17 additions & 0 deletions cli/dstack/_internal/backend/base/compute.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
from abc import ABC, abstractmethod
from functools import cmp_to_key
from typing import List, Optional

import dstack.version as version
from dstack._internal.core.error import DstackError
from dstack._internal.core.instance import InstanceType, LaunchedInstanceInfo
from dstack._internal.core.job import Job, Requirements
Expand Down Expand Up @@ -119,3 +121,18 @@ def _matches_requirements(resources: Resources, requirements: Optional[Requireme
):
return False
return True


def get_dstack_runner() -> str:
    """Return the shell commands that install the ``dstack-runner`` binary.

    The returned snippet is meant to be embedded into an instance start-up
    script. It downloads the linux-amd64 runner binary from an S3 bucket to
    ``/usr/local/bin/dstack-runner`` and marks it executable.

    Release builds (``version.__is_release__``) download from the
    ``dstack-runner-downloads`` bucket using ``version.__version__`` as the
    build id. Any other build uses the ``dstack-runner-downloads-stgn`` bucket
    and takes the build id from ``version.__version__`` or, when that is
    unset, from the ``DSTACK_RUNNER_BUILD`` environment variable.

    Returns:
        Two newline-separated shell commands (download, then ``chmod +x``).
    """
    if version.__is_release__:
        bucket = "dstack-runner-downloads"
        build = version.__version__
    else:  # stgn
        bucket = "dstack-runner-downloads-stgn"
        # Without a final fallback an unset version plus an unset env var
        # would interpolate the literal string "None" into the download URL.
        # NOTE(review): "latest" is assumed to be a valid build id in the
        # staging bucket (it matches the dev default of version.runner_build)
        # -- confirm against the bucket layout.
        build = version.__version__ or os.environ.get("DSTACK_RUNNER_BUILD") or "latest"

    url = f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-runner-linux-amd64"
    commands = [
        # --fail makes curl exit non-zero on an HTTP error instead of saving
        # the error response body as the runner "binary".
        f'sudo curl --fail --output /usr/local/bin/dstack-runner "{url}"',
        "sudo chmod +x /usr/local/bin/dstack-runner",
    ]
    return "\n".join(commands)
2 changes: 2 additions & 0 deletions cli/dstack/_internal/backend/gcp/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Compute,
NoCapacityError,
choose_instance_type,
get_dstack_runner,
)
from dstack._internal.backend.base.config import BACKEND_CONFIG_FILENAME, RUNNER_CONFIG_FILENAME
from dstack._internal.backend.base.runners import serialize_runner_yaml
Expand Down Expand Up @@ -446,6 +447,7 @@ def _get_user_data_script(gcp_config: GCPConfig, job: Job, instance_type: Instan
echo '{runner_content}' > /root/.dstack/{RUNNER_CONFIG_FILENAME}
EXTERNAL_IP=`curl -H "Metadata-Flavor: Google" http://169.254.169.254/computeMetadata/v1/instance/network-interfaces/0/access-configs/0/external-ip`
echo "hostname: $EXTERNAL_IP" >> /root/.dstack/{RUNNER_CONFIG_FILENAME}
{get_dstack_runner()}
HOME=/root nohup dstack-runner --log-level 6 start --http-port {WS_PORT}
"""

Expand Down
3 changes: 2 additions & 1 deletion cli/dstack/_internal/backend/lambdalabs/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import yaml

from dstack import version
from dstack._internal.backend.base.compute import WS_PORT, choose_instance_type
from dstack._internal.backend.base.compute import WS_PORT, choose_instance_type, get_dstack_runner
from dstack._internal.backend.base.config import BACKEND_CONFIG_FILENAME, RUNNER_CONFIG_FILENAME
from dstack._internal.backend.base.runners import serialize_runner_yaml
from dstack._internal.backend.lambdalabs.api_client import LambdaAPIClient
Expand Down Expand Up @@ -206,6 +206,7 @@ def _get_launch_script(lambda_config: LambdaConfig, job: Job, instance_type: Ins
echo '{config_content}' > /root/.dstack/{BACKEND_CONFIG_FILENAME}
echo '{runner_content}' > /root/.dstack/{RUNNER_CONFIG_FILENAME}
echo 'hostname: HOSTNAME_PLACEHOLDER' >> /root/.dstack/{RUNNER_CONFIG_FILENAME}
{get_dstack_runner()}
HOME=/root nohup dstack-runner --log-level 6 start --http-port {WS_PORT}
"""

Expand Down
1 change: 1 addition & 0 deletions cli/dstack/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
__version__ = None
__is_release__ = False
miniforge_image = "0.3"
runner_build = "latest"
89 changes: 1 addition & 88 deletions runner/internal/container/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,8 @@ package container

import (
"bytes"
"context"
"crypto/sha256"
"fmt"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/mount"
docker "github.com/docker/docker/client"
"github.com/dstackai/dstack/runner/internal/gerrors"
"github.com/dstackai/dstack/runner/internal/log"
"io"
)

type BuildSpec struct {
Expand All @@ -28,87 +20,8 @@ type BuildSpec struct {
RepoPath string
Platform string
RepoId string
}

func BuildImage(ctx context.Context, client docker.APIClient, spec *BuildSpec, imageName string, stoppedCh chan struct{}, logs io.Writer) error {
stopTimeout := 10 * 60
config := &container.Config{
Image: spec.BaseImageID,
WorkingDir: spec.WorkDir,
Cmd: spec.Commands,
Entrypoint: spec.Entrypoint,
Env: spec.Env,
StopTimeout: &stopTimeout,
Tty: true,
}
hostConfig := &container.HostConfig{
Mounts: []mount.Mount{
{
Type: mount.TypeBind,
Source: spec.RepoPath,
Target: "/workflow",
ReadOnly: true,
},
},
}
createResp, err := client.ContainerCreate(ctx, config, hostConfig, nil, nil, "")
if err != nil {
return gerrors.Wrap(err)
}
err = client.ContainerStart(ctx, createResp.ID, types.ContainerStartOptions{})
if err != nil {
return gerrors.Wrap(err)
}
defer func() {
_ = client.ContainerRemove(ctx, createResp.ID, types.ContainerRemoveOptions{Force: true})
}()

log.Trace(ctx, "Streaming build logs")
attachResp, err := client.ContainerAttach(ctx, createResp.ID, types.ContainerAttachOptions{
Stream: true,
Stdout: true,
Stderr: true,
Logs: true,
})
if err != nil {
return gerrors.Wrap(err)
}
go func() {
_, err := io.Copy(logs, attachResp.Reader)
if err != nil {
log.Error(ctx, "Failed to stream build logs", "err", err)
}
}()

statusCh, errCh := client.ContainerWait(ctx, createResp.ID, container.WaitConditionNotRunning)
if err != nil {
return gerrors.Wrap(err)
}
select {
case err := <-errCh:
if err != nil {
return gerrors.Wrap(err)
}
case <-stoppedCh:
err := client.ContainerKill(ctx, createResp.ID, "SIGTERM")
if err != nil {
return gerrors.Wrap(err)
}
case <-statusCh:
}
info, err := client.ContainerInspect(ctx, createResp.ID)
if err != nil {
return gerrors.Wrap(err)
}
if info.State.ExitCode != 0 {
return gerrors.Wrap(ContainerExitedError{info.State.ExitCode})
}
log.Trace(ctx, "Committing build image", "image", imageName)
_, err = client.ContainerCommit(ctx, createResp.ID, types.ContainerCommitOptions{Reference: imageName})
if err != nil {
return gerrors.Wrap(err)
}
return nil
ShmSize int64
}

func (s *BuildSpec) Hash() string {
Expand Down
56 changes: 55 additions & 1 deletion runner/internal/container/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package container
import (
"context"
"fmt"
"github.com/docker/docker/api/types/mount"
"github.com/dstackai/dstack/runner/internal/environment"
"github.com/dstackai/dstack/runner/internal/models"
"io"
Expand Down Expand Up @@ -326,6 +327,7 @@ func (r *Engine) NewBuildSpec(ctx context.Context, job *models.Job, spec *Spec,
RegistryAuthBase64: spec.RegistryAuthBase64,
RepoPath: repoPath,
RepoId: job.RepoId,
ShmSize: spec.ShmSize,
}
if daemonInfo.Architecture == "aarch64" {
buildSpec.Platform = "arm64"
Expand All @@ -336,7 +338,59 @@ func (r *Engine) NewBuildSpec(ctx context.Context, job *models.Job, spec *Spec,
}

func (r *Engine) Build(ctx context.Context, spec *BuildSpec, imageName string, stoppedCh chan struct{}, logs io.Writer) error {
if err := BuildImage(ctx, r.client, spec, imageName, stoppedCh, logs); err != nil {
containerSpec := &Spec{
Image: spec.BaseImageName,
RegistryAuthBase64: spec.RegistryAuthBase64,
WorkDir: spec.WorkDir,
Commands: spec.Commands,
Entrypoint: spec.Entrypoint,
Env: spec.Env,
ShmSize: spec.ShmSize,
Mounts: []mount.Mount{
{
Type: mount.TypeBind,
Source: spec.RepoPath,
Target: "/workflow",
ReadOnly: true,
},
},
}
dockerRuntime, err := r.Create(ctx, containerSpec, logs)
if err != nil {
return gerrors.Wrap(err)
}
if err = dockerRuntime.Run(ctx); err != nil {
return gerrors.Wrap(err)
}
defer func() {
_ = r.client.ContainerRemove(ctx, dockerRuntime.containerID, types.ContainerRemoveOptions{Force: true})
}()

statusCh, errCh := r.client.ContainerWait(ctx, dockerRuntime.containerID, container.WaitConditionNotRunning)
select {
// todo timeout
case err := <-errCh:
if err != nil {
return gerrors.Wrap(err)
}
case <-stoppedCh:
err := r.client.ContainerKill(ctx, dockerRuntime.containerID, "SIGTERM")
if err != nil {
return gerrors.Wrap(err)
}
case <-statusCh:
}

info, err := r.client.ContainerInspect(ctx, dockerRuntime.containerID)
if err != nil {
return gerrors.Wrap(err)
}
if info.State.ExitCode != 0 {
return gerrors.Wrap(ContainerExitedError{info.State.ExitCode})
}
log.Trace(ctx, "Committing build image", "image", imageName)
_, err = r.client.ContainerCommit(ctx, dockerRuntime.containerID, types.ContainerCommitOptions{Reference: imageName})
if err != nil {
return gerrors.Wrap(err)
}
return nil
Expand Down
Loading