From 7b1074f0cb17498be28bea2d85cc9916e3c6ffcb Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Mon, 20 Nov 2023 15:58:42 +0100 Subject: [PATCH] Refact cwhub (#2603) * Split RemoteHub.downloadIndex() = Hub.updateIndex() + RemoteHub.fetchIndex() * Functions safePath(), Item.installPath(), item.downloadPath() --- pkg/cwhub/cwhub.go | 22 ++++++++++++++ pkg/cwhub/dataset.go | 6 ++-- pkg/cwhub/enable.go | 40 ++++++++++++++++++------ pkg/cwhub/enable_test.go | 4 +-- pkg/cwhub/helpers.go | 66 +++++++++++++++++++--------------------- pkg/cwhub/hub.go | 45 +++++++++++++++++++++------ pkg/cwhub/hub_test.go | 12 +++++--- pkg/cwhub/remote.go | 38 ++++++----------------- pkg/cwhub/sync.go | 7 ++++- 9 files changed, 150 insertions(+), 90 deletions(-) diff --git a/pkg/cwhub/cwhub.go b/pkg/cwhub/cwhub.go index c7e17bd62f3..a2d10e5aa6f 100644 --- a/pkg/cwhub/cwhub.go +++ b/pkg/cwhub/cwhub.go @@ -5,10 +5,32 @@ package cwhub import ( + "fmt" "net/http" + "path/filepath" + "strings" "time" ) var hubClient = &http.Client{ Timeout: 120 * time.Second, } + +// safePath returns an error if the given file path would escape the base directory. +func safePath(dir, filePath string) (string, error) { + absBaseDir, err := filepath.Abs(filepath.Clean(dir)) + if err != nil { + return "", err + } + + absFilePath, err := filepath.Abs(filepath.Join(dir, filePath)) + if err != nil { + return "", err + } + + if !strings.HasPrefix(absFilePath, absBaseDir) { + return "", fmt.Errorf("path %s escapes base directory %s", filePath, dir) + } + + return absFilePath, nil +} diff --git a/pkg/cwhub/dataset.go b/pkg/cwhub/dataset.go index f002c668edd..79eb91573b0 100644 --- a/pkg/cwhub/dataset.go +++ b/pkg/cwhub/dataset.go @@ -6,7 +6,6 @@ import ( "io" "net/http" "os" - "path/filepath" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" @@ -67,7 +66,10 @@ func downloadDataSet(dataFolder string, force bool, reader io.Reader) error { } for _, dataS := range data.Data { - destPath := filepath.Join(dataFolder, dataS.DestPath) + destPath, err := safePath(dataFolder, dataS.DestPath) + if err != nil { + return err + } if _, err := os.Stat(destPath); os.IsNotExist(err) || force { log.Infof("downloading data '%s' in '%s'", dataS.SourceURL, destPath) diff --git a/pkg/cwhub/enable.go b/pkg/cwhub/enable.go index 1a5da53bf95..b38f74e9740 100644 --- a/pkg/cwhub/enable.go +++ b/pkg/cwhub/enable.go @@ -10,15 +10,33 @@ import ( log "github.com/sirupsen/logrus" ) -// installLink returns the location of the symlink to the downloaded config file +// installPath returns the location of the symlink to the item in the hub, or the path of the item itself if it's local // (eg. /etc/crowdsec/collections/xyz.yaml) -func (i *Item) installLinkPath() string { - return filepath.Join(i.hub.local.InstallDir, i.Type, i.Stage, i.FileName) +// raises an error if the path goes outside of the install dir +func (i *Item) installPath() (string, error) { + p := i.Type + if i.Stage != "" { + p = filepath.Join(p, i.Stage) + } + + return safePath(i.hub.local.InstallDir, filepath.Join(p, i.FileName)) +} + +// downloadPath returns the location of the actual config file in the hub +// (eg. /etc/crowdsec/hub/collections/author/xyz.yaml) +// raises an error if the path goes outside of the hub dir +func (i *Item) downloadPath() (string, error) { + ret, err := safePath(i.hub.local.HubDir, i.RemotePath) + if err != nil { + return "", err + } + + return ret, nil } // makeLink creates a symlink between the actual config file at hub.HubDir and hub.ConfigDir func (i *Item) createInstallLink() error { - dest, err := filepath.Abs(i.installLinkPath()) + dest, err := i.installPath() if err != nil { return err } @@ -33,7 +51,7 @@ func (i *Item) createInstallLink() error { return nil } - src, err := filepath.Abs(filepath.Join(i.hub.local.HubDir, i.RemotePath)) + src, err := i.downloadPath() if err != nil { return err } @@ -86,7 +104,10 @@ func (i *Item) purge() error { return nil } - src := filepath.Join(i.hub.local.HubDir, i.RemotePath) + src, err := i.downloadPath() + if err != nil { + return err + } if err := os.Remove(src); err != nil { if os.IsNotExist(err) { @@ -105,7 +126,7 @@ func (i *Item) purge() error { // removeInstallLink removes the symlink to the downloaded content func (i *Item) removeInstallLink() error { - syml, err := filepath.Abs(i.installLinkPath()) + syml, err := i.installPath() if err != nil { return err } @@ -126,7 +147,7 @@ func (i *Item) removeInstallLink() error { return fmt.Errorf("while reading symlink: %w", err) } - src, err := filepath.Abs(i.hub.local.HubDir + "/" + i.RemotePath) + src, err := i.downloadPath() if err != nil { return err } @@ -151,7 +172,8 @@ func (i *Item) disable(purge bool, force bool) error { err := i.removeInstallLink() if os.IsNotExist(err) { if !purge && !force { - return fmt.Errorf("link %s does not exist (override with --force or --purge)", i.installLinkPath()) + link, _ := i.installPath() + return fmt.Errorf("link %s does not exist (override with --force or --purge)", link) } } else if err != nil { return err diff --git a/pkg/cwhub/enable_test.go b/pkg/cwhub/enable_test.go index 9173024a4b0..fc9863d9ea5 100644 --- a/pkg/cwhub/enable_test.go +++ b/pkg/cwhub/enable_test.go @@ -10,7 +10,7 @@ import ( func testInstall(hub *Hub, t *testing.T, item *Item) { // Install the parser - err := item.downloadLatest(false, false) + _, err := item.downloadLatest(false, false) require.NoError(t, err, "failed to download %s", item.Name) err = hub.localSync() @@ -48,7 +48,7 @@ func testUpdate(hub *Hub, t *testing.T, item *Item) { assert.False(t, hub.Items[item.Type][item.Name].UpToDate, "%s should not be up-to-date", item.Name) // Update it + check status - err := item.downloadLatest(true, true) + _, err := item.downloadLatest(true, true) require.NoError(t, err, "failed to update %s", item.Name) // Local sync and check status diff --git a/pkg/cwhub/helpers.go b/pkg/cwhub/helpers.go index 320d5a5e84a..f014a2a7d1f 100644 --- a/pkg/cwhub/helpers.go +++ b/pkg/cwhub/helpers.go @@ -13,7 +13,6 @@ import ( "net/http" "os" "path/filepath" - "strings" "github.com/enescakir/emoji" log "github.com/sirupsen/logrus" @@ -31,13 +30,14 @@ func (i *Item) Install(force bool, downloadOnly bool) error { } // XXX: confusing semantic between force and updateOnly? - if err := i.downloadLatest(force, true); err != nil { + filePath, err := i.downloadLatest(force, true) + if err != nil { return fmt.Errorf("while downloading %s: %w", i.Name, err) } if downloadOnly { // XXX: should get the path from downloadLatest - log.Infof("Downloaded %s to %s", i.Name, filepath.Join(i.hub.local.HubDir, i.RemotePath)) + log.Infof("Downloaded %s to %s", i.Name, filePath) return nil } @@ -177,7 +177,7 @@ func (i *Item) Upgrade(force bool) (bool, error) { } } - if err := i.downloadLatest(force, true); err != nil { + if _, err := i.downloadLatest(force, true); err != nil { return false, fmt.Errorf("%s: download failed: %w", i.Name, err) } @@ -200,7 +200,7 @@ func (i *Item) Upgrade(force bool) (bool, error) { } // downloadLatest downloads the latest version of the item to the hub directory -func (i *Item) downloadLatest(overwrite bool, updateOnly bool) error { +func (i *Item) downloadLatest(overwrite bool, updateOnly bool) (string, error) { // XXX: should return the path of the downloaded file (taken from download()) log.Debugf("Downloading %s %s", i.Type, i.Name) @@ -216,39 +216,40 @@ func (i *Item) downloadLatest(overwrite bool, updateOnly bool) error { if sub.HasSubItems() { log.Tracef("collection, recurse") - if err := sub.downloadLatest(overwrite, updateOnly); err != nil { - return fmt.Errorf("while downloading %s: %w", sub.Name, err) + if _, err := sub.downloadLatest(overwrite, updateOnly); err != nil { + return "", fmt.Errorf("while downloading %s: %w", sub.Name, err) } } downloaded := sub.Downloaded - if err := sub.download(overwrite); err != nil { - return fmt.Errorf("while downloading %s: %w", sub.Name, err) + if _, err := sub.download(overwrite); err != nil { + return "", fmt.Errorf("while downloading %s: %w", sub.Name, err) } // We need to enable an item when it has been added to a collection since latest release of the collection. // We check if sub.Downloaded is false because maybe the item has been disabled by the user. if !sub.Installed && !downloaded { if err := sub.enable(); err != nil { - return fmt.Errorf("enabling '%s': %w", sub.Name, err) + return "", fmt.Errorf("enabling '%s': %w", sub.Name, err) } } } if !i.Installed && updateOnly && i.Downloaded { log.Debugf("skipping upgrade of %s: not installed", i.Name) - return nil + return "", nil } - if err := i.download(overwrite); err != nil { - return fmt.Errorf("failed to download item: %w", err) + ret, err := i.download(overwrite) + if err != nil { + return "", fmt.Errorf("failed to download item: %w", err) } - return nil + return ret, nil } -// fetch downloads the item from the hub, verifies the hash and returns the body +// fetch downloads the item from the hub, verifies the hash and returns the content func (i *Item) fetch() ([]byte, error) { url, err := i.hub.remote.urlTo(i.RemotePath) if err != nil { @@ -287,12 +288,12 @@ func (i *Item) fetch() ([]byte, error) { } // download downloads the item from the hub and writes it to the hub directory -func (i *Item) download(overwrite bool) error { +func (i *Item) download(overwrite bool) (string, error) { // if user didn't --force, don't overwrite local, tainted, up-to-date files if !overwrite { if i.Tainted { log.Debugf("%s: tainted, not updated", i.Name) - return nil + return "", nil } if i.UpToDate { @@ -303,39 +304,33 @@ func (i *Item) download(overwrite bool) error { body, err := i.fetch() if err != nil { - return err + return "", err } - tdir := i.hub.local.HubDir - - //all good, install - - finalPath, err := filepath.Abs(filepath.Join(tdir, i.RemotePath)) - if err != nil { - return err - } + // all good, install // ensure that target file is within target dir - if !strings.HasPrefix(finalPath, tdir) { - return fmt.Errorf("path %s escapes %s, abort", i.RemotePath, tdir) + finalPath, err := i.downloadPath() + if err != nil { + return "", err } parentDir := filepath.Dir(finalPath) if err = os.MkdirAll(parentDir, os.ModePerm); err != nil { - return fmt.Errorf("while creating %s: %w", parentDir, err) + return "", fmt.Errorf("while creating %s: %w", parentDir, err) } // check actual file if _, err = os.Stat(finalPath); !os.IsNotExist(err) { log.Warningf("%s: overwrite", i.Name) - log.Debugf("target: %s/%s", tdir, i.RemotePath) + log.Debugf("target: %s", finalPath) } else { log.Infof("%s: OK", i.Name) } if err = os.WriteFile(finalPath, body, 0o644); err != nil { - return fmt.Errorf("while writing %s: %w", finalPath, err) + return "", fmt.Errorf("while writing %s: %w", finalPath, err) } i.Downloaded = true @@ -343,15 +338,18 @@ func (i *Item) download(overwrite bool) error { i.UpToDate = true if err = downloadDataSet(i.hub.local.InstallDataDir, overwrite, bytes.NewReader(body)); err != nil { - return fmt.Errorf("while downloading data for %s: %w", i.FileName, err) + return "", fmt.Errorf("while downloading data for %s: %w", i.FileName, err) } - return nil + return finalPath, nil } // DownloadDataIfNeeded downloads the data files for the item func (i *Item) DownloadDataIfNeeded(force bool) error { - itemFilePath := fmt.Sprintf("%s/%s/%s/%s", i.hub.local.InstallDir, i.Type, i.Stage, i.FileName) + itemFilePath, err := i.installPath() + if err != nil { + return err + } itemFile, err := os.Open(itemFilePath) if err != nil { diff --git a/pkg/cwhub/hub.go b/pkg/cwhub/hub.go index ee3198fbb73..82be29b01d4 100644 --- a/pkg/cwhub/hub.go +++ b/pkg/cwhub/hub.go @@ -1,6 +1,7 @@ package cwhub import ( + "bytes" "encoding/json" "fmt" "os" @@ -24,25 +25,25 @@ func (h *Hub) GetDataDir() string { } // NewHub returns a new Hub instance with local and (optionally) remote configuration, and syncs the local state -// It also downloads the index if downloadIndex is true -func NewHub(local *csconfig.LocalHubCfg, remote *RemoteHubCfg, downloadIndex bool) (*Hub, error) { +// It also downloads the index if updateIndex is true +func NewHub(local *csconfig.LocalHubCfg, remote *RemoteHubCfg, updateIndex bool) (*Hub, error) { if local == nil { return nil, fmt.Errorf("no hub configuration found") } - if downloadIndex { - if err := remote.downloadIndex(local.HubIndexFile); err != nil { + hub := &Hub{ + local: local, + remote: remote, + } + + if updateIndex { + if err := hub.updateIndex(); err != nil { return nil, err } } log.Debugf("loading hub idx %s", local.HubIndexFile) - hub := &Hub{ - local: local, - remote: remote, - } - if err := hub.parseIndex(); err != nil { return nil, fmt.Errorf("failed to load index: %w", err) } @@ -129,3 +130,29 @@ func (h *Hub) ItemStats() []string { return ret } + +// updateIndex downloads the latest version of the index and writes it to disk if it changed +func (h *Hub) updateIndex() error { + body, err := h.remote.fetchIndex() + if err != nil { + return err + } + + oldContent, err := os.ReadFile(h.local.HubIndexFile) + if err != nil { + if !os.IsNotExist(err) { + log.Warningf("failed to read hub index: %s", err) + } + } else if bytes.Equal(body, oldContent) { + log.Info("hub index is up to date") + return nil + } + + if err = os.WriteFile(h.local.HubIndexFile, body, 0o644); err != nil { + return fmt.Errorf("failed to write hub index: %w", err) + } + + log.Infof("Wrote index to %s, %d bytes", h.local.HubIndexFile, len(body)) + + return nil +} diff --git a/pkg/cwhub/hub_test.go b/pkg/cwhub/hub_test.go index 56e2bf376a7..670f8d84356 100644 --- a/pkg/cwhub/hub_test.go +++ b/pkg/cwhub/hub_test.go @@ -23,7 +23,7 @@ func TestInitHubUpdate(t *testing.T) { require.NoError(t, err) } -func TestDownloadIndex(t *testing.T) { +func TestUpdateIndex(t *testing.T) { // bad url template fmt.Println("Test 'bad URL'") @@ -42,7 +42,9 @@ func TestDownloadIndex(t *testing.T) { IndexPath: "", } - err = hub.remote.downloadIndex(tmpIndex.Name()) + hub.local.HubIndexFile = tmpIndex.Name() + + err = hub.updateIndex() cstest.RequireErrorContains(t, err, "failed to build hub index request: invalid URL template 'x'") // bad domain @@ -54,7 +56,7 @@ func TestDownloadIndex(t *testing.T) { IndexPath: ".index.json", } - err = hub.remote.downloadIndex(tmpIndex.Name()) + err = hub.updateIndex() require.NoError(t, err) // XXX: this is not failing // cstest.RequireErrorContains(t, err, "failed http request for hub index: Get") @@ -68,6 +70,8 @@ func TestDownloadIndex(t *testing.T) { IndexPath: ".index.json", } - err = hub.remote.downloadIndex("/does/not/exist/index.json") + hub.local.HubIndexFile = "/does/not/exist/index.json" + + err = hub.updateIndex() cstest.RequireErrorContains(t, err, "failed to write hub index: open /does/not/exist/index.json:") } diff --git a/pkg/cwhub/remote.go b/pkg/cwhub/remote.go index 2b395681062..c98dfa8f5c2 100644 --- a/pkg/cwhub/remote.go +++ b/pkg/cwhub/remote.go @@ -1,13 +1,9 @@ package cwhub import ( - "bytes" "fmt" "io" "net/http" - "os" - - log "github.com/sirupsen/logrus" ) // RemoteHubCfg contains where to find the remote hub, which branch etc. @@ -31,51 +27,35 @@ func (r *RemoteHubCfg) urlTo(remotePath string) (string, error) { return fmt.Sprintf(r.URLTemplate, r.Branch, remotePath), nil } -// downloadIndex downloads the latest version of the index -func (r *RemoteHubCfg) downloadIndex(localPath string) error { +// fetchIndex downloads the index from the hub and returns the content +func (r *RemoteHubCfg) fetchIndex() ([]byte, error) { if r == nil { - return ErrNilRemoteHub + return nil, ErrNilRemoteHub } url, err := r.urlTo(r.IndexPath) if err != nil { - return fmt.Errorf("failed to build hub index request: %w", err) + return nil, fmt.Errorf("failed to build hub index request: %w", err) } resp, err := hubClient.Get(url) if err != nil { - return fmt.Errorf("failed http request for hub index: %w", err) + return nil, fmt.Errorf("failed http request for hub index: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { if resp.StatusCode == http.StatusNotFound { - return IndexNotFoundError{url, r.Branch} + return nil, IndexNotFoundError{url, r.Branch} } - return fmt.Errorf("bad http code %d for %s", resp.StatusCode, url) + return nil, fmt.Errorf("bad http code %d for %s", resp.StatusCode, url) } body, err := io.ReadAll(resp.Body) if err != nil { - return fmt.Errorf("failed to read request answer for hub index: %w", err) - } - - oldContent, err := os.ReadFile(localPath) - if err != nil { - if !os.IsNotExist(err) { - log.Warningf("failed to read hub index: %s", err) - } - } else if bytes.Equal(body, oldContent) { - log.Info("hub index is up to date") - return nil - } - - if err = os.WriteFile(localPath, body, 0o644); err != nil { - return fmt.Errorf("failed to write hub index: %w", err) + return nil, fmt.Errorf("failed to read request answer for hub index: %w", err) } - log.Infof("Wrote index to %s, %d bytes", localPath, len(body)) - - return nil + return body, nil } diff --git a/pkg/cwhub/sync.go b/pkg/cwhub/sync.go index 4bcf6df4419..a755e10fec2 100644 --- a/pkg/cwhub/sync.go +++ b/pkg/cwhub/sync.go @@ -237,7 +237,12 @@ func (h *Hub) itemVisit(path string, f os.DirEntry, err error) error { continue } - if path == h.local.HubDir+"/"+item.RemotePath { + src, err := item.downloadPath() + if err != nil { + return err + } + + if path == src { log.Tracef("marking %s as downloaded", item.Name) item.Downloaded = true }