diff --git a/server/modelpath.go b/server/modelpath.go index 7d333876d60..86908226748 100644 --- a/server/modelpath.go +++ b/server/modelpath.go @@ -6,6 +6,7 @@ import ( "net/url" "os" "path/filepath" + "regexp" "strings" ) @@ -25,9 +26,10 @@ const ( ) var ( - ErrInvalidImageFormat = errors.New("invalid image format") - ErrInvalidProtocol = errors.New("invalid protocol scheme") - ErrInsecureProtocol = errors.New("insecure protocol http") + ErrInvalidImageFormat = errors.New("invalid image format") + ErrInvalidProtocol = errors.New("invalid protocol scheme") + ErrInsecureProtocol = errors.New("insecure protocol http") + ErrInvalidDigestFormat = errors.New("invalid digest format") ) func ParseModelPath(name string) ModelPath { @@ -149,6 +151,17 @@ func GetBlobsPath(digest string) (string, error) { return "", err } + // only accept actual sha256 digests + pattern := "^sha256[:-][0-9a-fA-F]{64}$" + re := regexp.MustCompile(pattern) + if err != nil { + return "", err + } + + if digest != "" && !re.MatchString(digest) { + return "", ErrInvalidDigestFormat + } + digest = strings.ReplaceAll(digest, ":", "-") path := filepath.Join(dir, "blobs", digest) dirPath := filepath.Dir(path) diff --git a/server/modelpath_test.go b/server/modelpath_test.go index 8b26d52cfab..30741d872b5 100644 --- a/server/modelpath_test.go +++ b/server/modelpath_test.go @@ -1,6 +1,73 @@ package server -import "testing" +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetBlobsPath(t *testing.T) { + // GetBlobsPath expects an actual directory to exist + dir, err := os.MkdirTemp("", "ollama-test") + assert.Nil(t, err) + defer os.RemoveAll(dir) + + tests := []struct { + name string + digest string + expected string + err error + }{ + { + "empty digest", + "", + filepath.Join(dir, "blobs"), + nil, + }, + { + "valid with colon", + "sha256:456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9", + filepath.Join(dir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"), + nil, + }, + { + "valid with dash", + "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9", + filepath.Join(dir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"), + nil, + }, + { + "digest too short", + "sha256-45640291", + "", + ErrInvalidDigestFormat, + }, + { + "digest too long", + "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9aaaaaaaaaa", + "", + ErrInvalidDigestFormat, + }, + { + "digest invalid chars", + "../sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7a", + "", + ErrInvalidDigestFormat, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Setenv("OLLAMA_MODELS", dir) + + got, err := GetBlobsPath(tc.digest) + + assert.ErrorIs(t, tc.err, err, tc.name) + assert.Equal(t, tc.expected, got, tc.name) + }) + } +} func TestParseModelPath(t *testing.T) { tests := []struct {