diff --git a/modules/ollama/examples_test.go b/modules/ollama/examples_test.go index 3601e0b120..188be45bbb 100644 --- a/modules/ollama/examples_test.go +++ b/modules/ollama/examples_test.go @@ -178,7 +178,7 @@ func ExampleRun_withLocal() { ctx := context.Background() // localOllama { - ollamaContainer, err := tcollama.Run(ctx, "ollama/ollama:0.3.13", tcollama.WithUseLocal(map[string]string{"OLLAMA_DEBUG": "true"})) + ollamaContainer, err := tcollama.Run(ctx, "ollama/ollama:0.3.13", tcollama.WithUseLocal("OLLAMA_DEBUG=true")) defer func() { if err := testcontainers.TerminateContainer(ollamaContainer); err != nil { log.Printf("failed to terminate container: %s", err) diff --git a/modules/ollama/local_test.go b/modules/ollama/local_test.go index bb063fb361..7bd073ca5e 100644 --- a/modules/ollama/local_test.go +++ b/modules/ollama/local_test.go @@ -29,7 +29,7 @@ func TestRun_local(t *testing.T) { ollamaContainer, err := ollama.Run( ctx, "ollama/ollama:0.1.25", - ollama.WithUseLocal(map[string]string{"FOO": "BAR"}), + ollama.WithUseLocal("FOO=BAR"), ) testcontainers.CleanupContainer(t, ollamaContainer) require.NoError(t, err) @@ -266,7 +266,7 @@ func TestRun_localWithCustomLogFile(t *testing.T) { ctx := context.Background() - ollamaContainer, err := ollama.Run(ctx, "ollama/ollama:0.1.25", ollama.WithUseLocal(map[string]string{"FOO": "BAR"})) + ollamaContainer, err := ollama.Run(ctx, "ollama/ollama:0.1.25", ollama.WithUseLocal("FOO=BAR")) require.NoError(t, err) testcontainers.CleanupContainer(t, ollamaContainer) @@ -285,7 +285,7 @@ func TestRun_localWithCustomHost(t *testing.T) { ctx := context.Background() - ollamaContainer, err := ollama.Run(ctx, "ollama/ollama:0.1.25", ollama.WithUseLocal(nil)) + ollamaContainer, err := ollama.Run(ctx, "ollama/ollama:0.1.25", ollama.WithUseLocal()) require.NoError(t, err) testcontainers.CleanupContainer(t, ollamaContainer) diff --git a/modules/ollama/options.go b/modules/ollama/options.go index 4653b65169..4761a28530 100644 --- a/modules/ollama/options.go +++ b/modules/ollama/options.go @@ -2,6 +2,8 @@ package ollama import ( "context" + "fmt" + "strings" "github.com/docker/docker/api/types/container" @@ -42,23 +44,29 @@ var _ testcontainers.ContainerCustomizer = (*UseLocal)(nil) // UseLocal will use the local Ollama instance instead of pulling the Docker image. type UseLocal struct { - env map[string]string + env []string } // WithUseLocal the module will use the local Ollama instance instead of pulling the Docker image. // Pass the environment variables you need to set for the Ollama binary to be used, // in the format of "KEY=VALUE". KeyValue pairs with the wrong format will cause an error. -func WithUseLocal(keyVal map[string]string) UseLocal { - return UseLocal{env: keyVal} +func WithUseLocal(values ...string) UseLocal { + return UseLocal{env: values} } // Customize implements the ContainerCustomizer interface, taking the key value pairs // and setting them as environment variables for the Ollama binary. // In the case of an invalid key value pair, an error is returned. func (u UseLocal) Customize(req *testcontainers.GenericContainerRequest) error { - if len(u.env) == 0 { - return nil + env := make(map[string]string) + for _, kv := range u.env { + parts := strings.SplitN(kv, "=", 2) + if len(parts) != 2 { + return fmt.Errorf("invalid environment variable: %s", kv) + } + + env[parts[0]] = parts[1] } - return testcontainers.WithEnv(u.env)(req) + return testcontainers.WithEnv(env)(req) } diff --git a/modules/ollama/options_test.go b/modules/ollama/options_test.go index 46872d0dd4..f842d15a17 100644 --- a/modules/ollama/options_test.go +++ b/modules/ollama/options_test.go @@ -12,18 +12,38 @@ import ( func TestWithUseLocal(t *testing.T) { req := testcontainers.GenericContainerRequest{} - t.Run("empty", func(t *testing.T) { - opt := ollama.WithUseLocal(nil) + t.Run("keyVal/valid", func(t *testing.T) { + opt := ollama.WithUseLocal("OLLAMA_MODELS=/path/to/models") err := opt.Customize(&req) require.NoError(t, err) - require.Empty(t, req.Env) + require.Equal(t, "/path/to/models", req.Env["OLLAMA_MODELS"]) + }) + + t.Run("keyVal/invalid", func(t *testing.T) { + opt := ollama.WithUseLocal("OLLAMA_MODELS") + err := opt.Customize(&req) + require.Error(t, err) + }) + + t.Run("keyVal/valid/multiple", func(t *testing.T) { + opt := ollama.WithUseLocal("OLLAMA_MODELS=/path/to/models", "OLLAMA_HOST=localhost") + err := opt.Customize(&req) + require.NoError(t, err) + require.Equal(t, "/path/to/models", req.Env["OLLAMA_MODELS"]) + require.Equal(t, "localhost", req.Env["OLLAMA_HOST"]) }) - t.Run("valid", func(t *testing.T) { - opt := ollama.WithUseLocal(map[string]string{"OLLAMA_MODELS": "/path/to/models", "OLLAMA_HOST": "localhost:1234"}) + t.Run("keyVal/valid/multiple-equals", func(t *testing.T) { + opt := ollama.WithUseLocal("OLLAMA_MODELS=/path/to/models", "OLLAMA_HOST=localhost=127.0.0.1") err := opt.Customize(&req) require.NoError(t, err) require.Equal(t, "/path/to/models", req.Env["OLLAMA_MODELS"]) - require.Equal(t, "localhost:1234", req.Env["OLLAMA_HOST"]) + require.Equal(t, "localhost=127.0.0.1", req.Env["OLLAMA_HOST"]) + }) + + t.Run("keyVal/invalid/multiple", func(t *testing.T) { + opt := ollama.WithUseLocal("OLLAMA_MODELS=/path/to/models", "OLLAMA_HOST") + err := opt.Customize(&req) + require.Error(t, err) }) }