From 9f6eb9d707bb2a01223a620cb96ea4c9df135147 Mon Sep 17 00:00:00 2001 From: Johannes Ziemke Date: Mon, 21 Aug 2023 15:48:11 +0200 Subject: [PATCH] Add secret source support for git and huggingface --- pkg/cmd/agent/submit.go | 5 +- pkg/cmd/agent/test.go | 5 +- pkg/container/docker.go | 2 +- pkg/diambra/config.go | 76 ++++++++++++++++++++----- pkg/diambra/config_test.go | 63 +++++++++++++++++++-- pkg/secretsources/credentials.go | 80 +++++++++++++++++++++++++++ pkg/secretsources/credentials_test.go | 67 ++++++++++++++++++++++ pkg/secretsources/get_token.py | 5 ++ pkg/secretsources/huggingface.go | 42 ++++++++++++++ test/mock-credential-helper.sh | 3 + 10 files changed, 323 insertions(+), 25 deletions(-) create mode 100644 pkg/secretsources/credentials.go create mode 100644 pkg/secretsources/credentials_test.go create mode 100755 pkg/secretsources/get_token.py create mode 100644 pkg/secretsources/huggingface.go create mode 100755 test/mock-credential-helper.sh diff --git a/pkg/cmd/agent/submit.go b/pkg/cmd/agent/submit.go index b676a70..05fccac 100644 --- a/pkg/cmd/agent/submit.go +++ b/pkg/cmd/agent/submit.go @@ -30,7 +30,8 @@ import ( func NewSubmitCmd(logger *log.Logger) *cobra.Command { dump := false - submissionConfig := diambra.NewSubmissionConfig(logger) + submissionConfig := diambra.SubmissionConfig{} + submissionConfig.RegisterCredentialsProviders() c, err := diambra.NewConfig(logger) if err != nil { level.Error(logger).Log("msg", err.Error()) @@ -46,7 +47,7 @@ func NewSubmitCmd(logger *log.Logger) *cobra.Command { level.Error(logger).Log("msg", err.Error()) os.Exit(1) } - submission, err := submissionConfig.Submission(c.CredPath, args) + submission, err := submissionConfig.Submission(c, args) if err != nil { level.Error(logger).Log("msg", "failed to configure manifest", "err", err.Error()) os.Exit(1) diff --git a/pkg/cmd/agent/test.go b/pkg/cmd/agent/test.go index 77c45b7..01839bb 100644 --- a/pkg/cmd/agent/test.go +++ b/pkg/cmd/agent/test.go @@ -24,7 
+24,8 @@ const ( ) func NewTestCmd(logger *log.Logger) *cobra.Command { - submissionConfig := diambra.NewSubmissionConfig(logger) + submissionConfig := diambra.SubmissionConfig{} + submissionConfig.RegisterCredentialsProviders() c, err := diambra.NewConfig(logger) if err != nil { level.Error(logger).Log("msg", err.Error()) @@ -37,7 +38,7 @@ func NewTestCmd(logger *log.Logger) *cobra.Command { Long: `This takes a docker image or submission manifest and runs it in the same way as it would be run when submitted to DIAMBRA. This is useful for testing your agent before submitting it. Optionally, you can pass in commands to run instead of the configured entrypoint.`, Run: func(cmd *cobra.Command, args []string) { - submission, err := submissionConfig.Submission(c.CredPath, args) + submission, err := submissionConfig.Submission(c, args) if err != nil { level.Error(logger).Log("msg", "failed to configure manifest", "err", err.Error()) os.Exit(1) diff --git a/pkg/container/docker.go b/pkg/container/docker.go index b8b9f9d..bd70939 100644 --- a/pkg/container/docker.go +++ b/pkg/container/docker.go @@ -60,7 +60,7 @@ func NewDockerRunner(logger log.Logger, client *client.Client, autoRemove bool) func (r *DockerRunner) Pull(c *Container, output *os.File) error { reader, err := r.Client.ImagePull(context.TODO(), c.Image, types.ImagePullOptions{}) if err != nil { - return fmt.Errorf("couldn't pull image %s: %w:\nTo disable pulling the image on start, retry with --images.pull=false", c.Image, err) + return fmt.Errorf("couldn't pull image %s: %w:\nTo disable pulling the image on start, retry with --images.no-pull", c.Image, err) } defer reader.Close() diff --git a/pkg/diambra/config.go b/pkg/diambra/config.go index 1938d62..2519e57 100644 --- a/pkg/diambra/config.go +++ b/pkg/diambra/config.go @@ -27,6 +27,7 @@ import ( "github.com/diambra/cli/pkg/container" "github.com/diambra/cli/pkg/diambra/client" + "github.com/diambra/cli/pkg/secretsources" 
"github.com/diambra/init/initializer" "github.com/go-kit/log" "github.com/go-kit/log/level" @@ -234,22 +235,28 @@ const ( var ErrInvalidArgs = errors.New("either image, manifest path or submission id must be provided") type SubmissionConfig struct { - logger log.Logger - Mode string Difficulty string EnvVars map[string]string Sources map[string]string Secrets map[string]string + SecretsFrom string ArgsIsCommand bool ManifestPath string SubmissionID int + + credentialsProvider map[string]secretsources.CredentialProvider } -func NewSubmissionConfig(logger log.Logger) *SubmissionConfig { - return &SubmissionConfig{ - logger: logger, +func (c *SubmissionConfig) RegisterCredentialsProvider(name string, provider secretsources.CredentialProvider) { + if c.credentialsProvider == nil { + c.credentialsProvider = make(map[string]secretsources.CredentialProvider) } + c.credentialsProvider[name] = provider +} +func (c *SubmissionConfig) RegisterCredentialsProviders() { + c.RegisterCredentialsProvider("git", &secretsources.GitCredentials{}) + c.RegisterCredentialsProvider("huggingface", &secretsources.HuggingfaceCredentials{}) } func (c *SubmissionConfig) AddFlags(flags *pflag.FlagSet) { @@ -258,12 +265,13 @@ func (c *SubmissionConfig) AddFlags(flags *pflag.FlagSet) { flags.StringToStringVarP(&c.EnvVars, "submission.env", "e", nil, "Environment variables to pass to the agent") flags.StringToStringVarP(&c.Sources, "submission.source", "u", nil, "Source urls to pass to the agent") flags.StringToStringVar(&c.Secrets, "submission.secret", nil, "Secrets to pass to the agent") + flags.StringVar(&c.SecretsFrom, "submission.secrets-from", "", "Automatically add secrets. 
Supported values: git, huggingface") flags.StringVar(&c.ManifestPath, "submission.manifest", "", "Path to manifest file.") flags.IntVar(&c.SubmissionID, "submission.id", 0, "Submission ID to retrieve manifest from") flags.BoolVar(&c.ArgsIsCommand, "submission.set-command", false, "Treat positional arguments are command instead of entrypoint") } -func (c *SubmissionConfig) Submission(credPath string, args []string) (*client.Submission, error) { +func (c *SubmissionConfig) Submission(config *EnvConfig, args []string) (*client.Submission, error) { var ( nargs = len(args) manifest *client.Manifest @@ -271,7 +279,7 @@ func (c *SubmissionConfig) Submission(credPath string, args []string) (*client.S switch { case c.SubmissionID != 0: - cl, err := client.NewClient(c.logger, credPath) + cl, err := client.NewClient(config.logger, config.CredPath) if err != nil { return nil, fmt.Errorf("failed to create client: %w", err) } @@ -320,22 +328,62 @@ func (c *SubmissionConfig) Submission(credPath string, args []string) (*client.S } if c.Sources != nil { - level.Debug(c.logger).Log("msg", "Using sources", "sources", c.Sources) + level.Debug(config.logger).Log("msg", "Using sources", "sources", c.Sources) manifest.Sources = make(map[string]string) for k, v := range c.Sources { manifest.Sources[k] = v } } - if manifest.Sources != nil { - init, err := initializer.NewInitializer(c.logger, manifest.Sources, c.Secrets, map[string]string{}, "") - if err != nil { - return nil, err + if c.SecretsFrom != "" { + if c.Secrets == nil { + c.Secrets = make(map[string]string) } + } - if err := init.Validate(); err != nil { - return nil, err + if c.SecretsFrom != "" { + ss, ok := c.credentialsProvider[c.SecretsFrom] + if !ok { + return nil, fmt.Errorf("invalid value for --submission.secrets-from: %s", c.SecretsFrom) } + switch c.SecretsFrom { + case "git": + secrets, err := secretsources.CredentialsFill(ss, manifest.Sources) + if err != nil { + return nil, err + } + if manifest.Sources == nil { + 
return nil, fmt.Errorf("sources are required to use --submission.secrets-from=git") + } + level.Debug(config.logger).Log("msg", "Adding git secrets") + for k, v := range secrets { + level.Info(config.logger).Log("msg", "Adding git secret", "key", k) + c.Secrets[k] = v + } + case "huggingface": + level.Debug(config.logger).Log("msg", "Adding huggingface secrets") + secrets, err := ss.Credentials("") + if err != nil { + return nil, err + } + c.Secrets["HF_TOKEN"] = secrets["HF_TOKEN"] + if manifest.Env == nil { + manifest.Env = make(map[string]string) + } + manifest.Env["HF_TOKEN"] = "{{ .Secrets.HF_TOKEN }}" + case "": + default: + return nil, fmt.Errorf("invalid value for --submission.secrets-from: %s", c.SecretsFrom) + } + } + + init, err := initializer.NewInitializer(config.logger, manifest.Sources, c.Secrets, map[string]string{}, "") + if err != nil { + return nil, err + } + + if err := init.Validate(); err != nil { + return nil, err } return &client.Submission{ diff --git a/pkg/diambra/config_test.go b/pkg/diambra/config_test.go index 2b99ec1..1631cd1 100644 --- a/pkg/diambra/config_test.go +++ b/pkg/diambra/config_test.go @@ -16,9 +16,13 @@ package diambra import ( + "os" + "path/filepath" "testing" "github.com/diambra/cli/pkg/diambra/client" + "github.com/diambra/cli/pkg/secretsources" + "github.com/go-kit/log" "github.com/stretchr/testify/assert" ) @@ -56,6 +60,13 @@ func TestAppArgs(t *testing.T) { } func TestSubmissionConfig(t *testing.T) { + envConfig := &EnvConfig{ + logger: log.NewNopLogger(), + CredPath: "", + } + cwd, err := os.Getwd() + assert.NoError(t, err) + for _, tc := range []struct { name string config SubmissionConfig @@ -113,20 +124,60 @@ func TestSubmissionConfig(t *testing.T) { nil, }, { - "from args, with secrets", - SubmissionConfig{}, - []string{"diambra/agent-random-1:main", "--gameId", "doapp"}, + "from args with sources and secrets", + SubmissionConfig{ + ManifestPath: "testdata/manifest.yaml", + ArgsIsCommand: true, + Sources: 
map[string]string{"model.zip": "https://user:{{ .Secrets.foo }}@example.com/model.zip"}, + Secrets: map[string]string{ + "foo": "bar", + }, + }, + []string{"python", "agent.py"}, &client.Submission{ Manifest: client.Manifest{ - Image: "diambra/agent-random-1:main", - Args: []string{"--gameId", "doapp"}, + Image: "diambra/agent-random-1:main", + Command: []string{"python", "agent.py"}, + Args: []string{"--gameId", "doapp"}, + Sources: map[string]string{ + "model.zip": "https://user:{{ .Secrets.foo }}@example.com/model.zip", + }, + }, + Secrets: map[string]string{ + "foo": "bar", + }, + }, + nil, + }, + { + "from args with sources and secrets from git", + SubmissionConfig{ + ManifestPath: "testdata/manifest.yaml", + ArgsIsCommand: true, + Sources: map[string]string{"model.zip": "https://example.com/mode.zip"}, + SecretsFrom: "git", + }, + []string{"python", "agent.py"}, + &client.Submission{ + Manifest: client.Manifest{ + Image: "diambra/agent-random-1:main", + Command: []string{"python", "agent.py"}, + Args: []string{"--gameId", "doapp"}, + Sources: map[string]string{ + "model.zip": "https://{{ .Secrets.git_username_1 }}:{{ .Secrets.git_password_1 }}@example.com/mode.zip", + }, + }, + Secrets: map[string]string{ + "git_username_1": "user1", + "git_password_1": "pass1", }, }, nil, }, } { t.Run(tc.name, func(t *testing.T) { - submission, err := tc.config.Submission("", tc.args) + tc.config.RegisterCredentialsProvider("git", &secretsources.GitCredentials{Helper: filepath.Join(cwd, "../../test/mock-credential-helper.sh")}) + submission, err := tc.config.Submission(envConfig, tc.args) assert.Equal(t, tc.expectedErr, err) assert.Equal(t, tc.expected, submission) }) diff --git a/pkg/secretsources/credentials.go b/pkg/secretsources/credentials.go new file mode 100644 index 0000000..b9edb42 --- /dev/null +++ b/pkg/secretsources/credentials.go @@ -0,0 +1,80 @@ +package secretsources + +import ( + "bytes" + "fmt" + "net/url" + "os/exec" + "strings" +) + +type 
CredentialProvider interface { + Credentials(url string) (map[string]string, error) +} + +type GitCredentials struct { + Helper string +} + +func (c *GitCredentials) Credentials(url string) (map[string]string, error) { + args := []string{} + if c.Helper != "" { + args = append(args, "-c", fmt.Sprintf("credential.helper=%s", c.Helper)) + } + args = append(args, "credential", "fill") + cmd := exec.Command("git", args...) + cmd.Stdin = strings.NewReader("url=" + url + "\n") + + var stdout bytes.Buffer + cmd.Stdout = &stdout + if err := cmd.Run(); err != nil { + return nil, fmt.Errorf("failed to run %v: %w", cmd, err) + } + + credentials := make(map[string]string) + lines := strings.Split(stdout.String(), "\n") + for _, line := range lines { + parts := strings.SplitN(line, "=", 2) + if len(parts) == 2 { + credentials[parts[0]] = parts[1] + } + } + + return credentials, nil +} + +// CredentialsFill calls the CredentialProvider for each source and rewrites, in place, every source +// that yields a password to use templated credentials; it returns the map of secrets those templates reference. 
+func CredentialsFill(provider CredentialProvider, sources map[string]string) (map[string]string, error) { + secrets := make(map[string]string) + i := 0 + for k, v := range sources { + i++ + u, err := url.Parse(v) + if err != nil { + return nil, fmt.Errorf("failed to parse url %s: %w", v, err) + } + credentials, err := provider.Credentials(v) + if err != nil { + return nil, err + } + if credentials["password"] == "" { + continue + } + + if credentials["host"] != u.Host { + return nil, fmt.Errorf("host %s does not match %s (this should never happend)", credentials["host"], u.Host) + } + + var ( + uservar = fmt.Sprintf("git_username_%d", i) + passvar = fmt.Sprintf("git_password_%d", i) + ) + + u.User = url.UserPassword(fmt.Sprintf("{{ %s }}", uservar), fmt.Sprintf("{{ %s }}", passvar)) + secrets[uservar] = credentials["username"] + secrets[passvar] = credentials["password"] + sources[k] = fmt.Sprintf("%s://{{ .Secrets.%s }}:{{ .Secrets.%s }}@%s%s", u.Scheme, uservar, passvar, u.Host, u.Path) + } + return secrets, nil +} diff --git a/pkg/secretsources/credentials_test.go b/pkg/secretsources/credentials_test.go new file mode 100644 index 0000000..9ff3ec5 --- /dev/null +++ b/pkg/secretsources/credentials_test.go @@ -0,0 +1,67 @@ +package secretsources + +import ( + "reflect" + "testing" +) + +type MockCredentialProvider struct { + creds map[string]string +} + +func (m *MockCredentialProvider) Credentials(url string) (map[string]string, error) { + return m.creds, nil +} + +func TestCredentialsFill(t *testing.T) { + for _, tc := range []struct { + name string + source map[string]string + credentials map[string]string + expectedSource map[string]string + expectedSecrets map[string]string + }{ + { + name: "no credentials", + source: map[string]string{ + "foo": "git+https://example.com/foo", + }, + credentials: map[string]string{}, + expectedSource: map[string]string{ + "foo": "git+https://example.com/foo", + }, + expectedSecrets: map[string]string{}, + }, + { + name: 
"single credential", + source: map[string]string{ + "foo": "git+https://example.com/foo", + }, + credentials: map[string]string{ + "username": "foo", + "password": "bar", + "host": "example.com", + }, + expectedSource: map[string]string{ + "foo": "git+https://{{ .Secrets.git_username_1 }}:{{ .Secrets.git_password_1 }}@example.com/foo", + }, + expectedSecrets: map[string]string{ + "git_username_1": "foo", + "git_password_1": "bar", + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + secrets, err := CredentialsFill(&MockCredentialProvider{tc.credentials}, tc.source) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if !reflect.DeepEqual(tc.expectedSource, tc.source) { + t.Fatalf("expected source %v, got %v", tc.expectedSource, tc.source) + } + if !reflect.DeepEqual(tc.expectedSecrets, secrets) { + t.Fatalf("expected secrets %v, got %v", tc.expectedSecrets, secrets) + } + }) + } +} diff --git a/pkg/secretsources/get_token.py b/pkg/secretsources/get_token.py new file mode 100755 index 0000000..21e6ae9 --- /dev/null +++ b/pkg/secretsources/get_token.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python + +from huggingface_hub.utils import HfFolder + +print(HfFolder.get_token()) \ No newline at end of file diff --git a/pkg/secretsources/huggingface.go b/pkg/secretsources/huggingface.go new file mode 100644 index 0000000..d2f6dc1 --- /dev/null +++ b/pkg/secretsources/huggingface.go @@ -0,0 +1,42 @@ +package secretsources + +import ( + "bytes" + _ "embed" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + + "github.com/go-kit/log" + + "github.com/diambra/cli/pkg/pyarena" + "github.com/go-kit/log/level" +) + +const HFTokenPath = ".cache/huggingface/token" + +//go:embed get_token.py +var GetHuggingfaceToken string + +type HuggingfaceCredentials struct { + logger log.Logger + Home string +} + +func (c *HuggingfaceCredentials) Credentials(url string) (map[string]string, error) { + cmd := exec.Command(pyarena.FindPython(), "-c", GetHuggingfaceToken) + stdout := 
&bytes.Buffer{} + + cmd.Stdout = stdout + if err := cmd.Run(); err != nil { + level.Debug(c.logger).Log("msg", "couldn't get huggingface token programtically, trying open token file directly", "err", err) + token, err := os.ReadFile(filepath.Join(c.Home, HFTokenPath)) + if err != nil { + return nil, fmt.Errorf("couldn't get huggingface token: %w", err) + } + return map[string]string{"HF_TOKEN": strings.TrimSpace(string(token))}, nil + } + return map[string]string{"HF_TOKEN": strings.TrimSpace(stdout.String())}, nil +} diff --git a/test/mock-credential-helper.sh b/test/mock-credential-helper.sh new file mode 100755 index 0000000..cd7639c --- /dev/null +++ b/test/mock-credential-helper.sh @@ -0,0 +1,3 @@ +#!/bin/sh +echo "username=user1" +echo "password=pass1" \ No newline at end of file